improve get_crop_region

This commit is contained in:
w-e-w 2024-01-21 07:20:52 +09:00
parent f939bce845
commit e36827af32
2 changed files with 10 additions and 35 deletions

View File

@ -3,40 +3,15 @@ from PIL import Image, ImageFilter, ImageOps
def get_crop_region(mask, pad=0): def get_crop_region(mask, pad=0):
"""finds a rectangular region that contains all masked ares in an image. Returns (x1, y1, x2, y2) coordinates of the rectangle. """finds a rectangular region that contains all masked ares in an image. Returns (x1, y1, x2, y2) coordinates of the rectangle.
For example, if a user has painted the top-right part of a 512x512 image", the result may be (256, 0, 512, 256)""" For example, if a user has painted the top-right part of a 512x512 image, the result may be (256, 0, 512, 256)"""
mask_img = mask if isinstance(mask, Image.Image) else Image.fromarray(mask)
h, w = mask.shape box = mask_img.getbbox()
if box:
crop_left = 0 x1, y1, x2, y2 = box
for i in range(w): else: # when no box is found
if not (mask[:, i] == 0).all(): x1, y1 = mask_img.size
break x2 = y2 = 0
crop_left += 1 return max(x1 - pad, 0), max(y1 - pad, 0), min(x2 + pad, mask_img.size[0]), min(y2 + pad, mask_img.size[1])
crop_right = 0
for i in reversed(range(w)):
if not (mask[:, i] == 0).all():
break
crop_right += 1
crop_top = 0
for i in range(h):
if not (mask[i] == 0).all():
break
crop_top += 1
crop_bottom = 0
for i in reversed(range(h)):
if not (mask[i] == 0).all():
break
crop_bottom += 1
return (
int(max(crop_left-pad, 0)),
int(max(crop_top-pad, 0)),
int(min(w - crop_right + pad, w)),
int(min(h - crop_bottom + pad, h))
)
def expand_crop_region(crop_region, processing_width, processing_height, image_width, image_height): def expand_crop_region(crop_region, processing_width, processing_height, image_width, image_height):

View File

@ -1562,7 +1562,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
if self.inpaint_full_res: if self.inpaint_full_res:
self.mask_for_overlay = image_mask self.mask_for_overlay = image_mask
mask = image_mask.convert('L') mask = image_mask.convert('L')
crop_region = masking.get_crop_region(np.array(mask), self.inpaint_full_res_padding) crop_region = masking.get_crop_region(mask, self.inpaint_full_res_padding)
crop_region = masking.expand_crop_region(crop_region, self.width, self.height, mask.width, mask.height) crop_region = masking.expand_crop_region(crop_region, self.width, self.height, mask.width, mask.height)
x1, y1, x2, y2 = crop_region x1, y1, x2, y2 = crop_region