From 0d7a17add52d8d7bfb3b7294afcb2ffb3adf9357 Mon Sep 17 00:00:00 2001 From: Andray Date: Sat, 20 Jul 2024 12:25:00 +0400 Subject: [PATCH 1/3] fix inpaint only masked --- modules/masking.py | 38 ++++++++++++++++++++++++++++++++++++++ modules/processing.py | 8 +++++++- modules/shared_options.py | 1 + 3 files changed, 46 insertions(+), 1 deletion(-) diff --git a/modules/masking.py b/modules/masking.py index 2fc830319..29548bd8d 100644 --- a/modules/masking.py +++ b/modules/masking.py @@ -1,3 +1,4 @@ +import math from PIL import Image, ImageFilter, ImageOps @@ -77,6 +78,43 @@ def expand_crop_region(crop_region, processing_width, processing_height, image_w return x1, y1, x2, y2 +def fix_crop_region_integer_scale(crop_region, processing_width, processing_height, image_width, image_height): + """expands crop region get_crop_region() to avoid non-integer scaling artifacts (different pixels size) after applying overlay""" + + x1, y1, x2, y2 = crop_region + + ratio_w = (x2 - x1) / processing_width + ratio_h = (y2 - y1) / processing_height + + desired_w = math.ceil(ratio_w) * processing_width + diff_w = desired_w - (x2 - x1) + diff_w_l = diff_w // 2 + diff_w_r = diff_w - diff_w_l + x1 -= diff_w_l + x2 += diff_w_r + if x1 < 0: + x2 -= x1 + x1 -= x1 + if x2 >= image_width: + x2 = image_width + + desired_h = math.ceil(ratio_h) * processing_height + diff_h = desired_h - (y2 - y1) + diff_h_u = diff_h // 2 + diff_h_d = diff_h - diff_h_u + y1 -= diff_h_u + y2 += diff_h_d + if y1 < 0: + y2 -= y1 + y1 -= y1 + if y2 >= image_height: + y2 = image_height + + print(f"padding was increased by {max(diff_w_l, diff_w_r, diff_h_u, diff_h_d)} after integer upscale correction") + + return x1, y1, x2, y2 + + def fill(image, mask): """fills masked regions with colors from image using blur. Not extremely effective.""" diff --git a/modules/processing.py b/modules/processing.py index 7535b56e1..a7fd3e031 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -65,8 +65,12 @@ def apply_color_correction(correction, original_image): def uncrop(image, dest_size, paste_loc): x, y, w, h = paste_loc base_image = Image.new('RGBA', dest_size) + factor_x = w // image.size[0] + factor_y = h // image.size[1] image = images.resize_image(1, image, w, h) - base_image.paste(image, (x, y)) + paste_x = max(x - factor_x, 0) + paste_y = max(y - factor_y, 0) + base_image.paste(image, (paste_x, paste_y)) image = base_image return image @@ -1639,6 +1643,8 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): crop_region = masking.get_crop_region_v2(mask, self.inpaint_full_res_padding) if crop_region: crop_region = masking.expand_crop_region(crop_region, self.width, self.height, mask.width, mask.height) + if shared.opts.integer_only_masked: + crop_region = masking.fix_crop_region_integer_scale(crop_region, self.width, self.height, mask.width, mask.height) x1, y1, x2, y2 = crop_region mask = mask.crop(crop_region) image_mask = images.resize_image(2, mask, self.width, self.height) diff --git a/modules/shared_options.py b/modules/shared_options.py index 096366e0a..15331d4b5 100644 --- a/modules/shared_options.py +++ b/modules/shared_options.py @@ -226,6 +226,7 @@ options_templates.update(options_section(('img2img', "img2img", "sd"), { "return_mask_composite": OptionInfo(False, "For inpainting, include masked composite in results for web"), "img2img_batch_show_results_limit": OptionInfo(32, "Show the first N batch img2img results in UI", gr.Slider, {"minimum": -1, "maximum": 1000, "step": 1}).info('0: disable, -1: show all images. Too many images can cause lag'), "overlay_inpaint": OptionInfo(True, "Overlay original for inpaint").info("when inpainting, overlay the original image over the areas that weren't inpainted."), + "integer_only_masked": OptionInfo(False, "Integer upscale in inpaint only masked").info("Correct inpaint padding for only masked to have integer upscaling after fitting cropped region in original image"), })) options_templates.update(options_section(('optimizations', "Optimizations", "sd"), { From ad4fac2f8efee740a9c17ab377cce7b6286239d1 Mon Sep 17 00:00:00 2001 From: Andray Date: Wed, 24 Jul 2024 22:11:53 +0400 Subject: [PATCH 2/3] add preserve_colors flag for images.resize_image --- modules/colorfix.py | 114 ++++++++++++++++++++++++++++++++++++++++++ modules/images.py | 8 ++- modules/processing.py | 6 +-- 3 files changed, 123 insertions(+), 5 deletions(-) create mode 100644 modules/colorfix.py diff --git a/modules/colorfix.py b/modules/colorfix.py new file mode 100644 index 000000000..77097cab2 --- /dev/null +++ b/modules/colorfix.py @@ -0,0 +1,114 @@ +import torch +from PIL import Image +from torch import Tensor +from torch.nn import functional as F + +from torchvision.transforms import ToTensor, ToPILImage + +def adain_color_fix(target: Image, source: Image): + # Convert images to tensors + to_tensor = ToTensor() + target_tensor = to_tensor(target).unsqueeze(0) + source_tensor = to_tensor(source).unsqueeze(0) + + # Apply adaptive instance normalization + result_tensor = adaptive_instance_normalization(target_tensor, source_tensor) + + # Convert tensor back to image + to_image = ToPILImage() + result_image = to_image(result_tensor.squeeze(0).clamp_(0.0, 1.0)) + + return result_image + +def wavelet_color_fix(target: Image, source: Image): + # Convert images to tensors + to_tensor = ToTensor() + target_tensor = to_tensor(target).unsqueeze(0) + source_tensor = to_tensor(source).unsqueeze(0) + + # Apply wavelet reconstruction + result_tensor = wavelet_reconstruction(target_tensor, source_tensor) + + # Convert tensor back to image + to_image = ToPILImage() + result_image = to_image(result_tensor.squeeze(0).clamp_(0.0, 1.0)) + + return result_image + +def calc_mean_std(feat: Tensor, eps=1e-5): + """Calculate mean and std for adaptive_instance_normalization. + Args: + feat (Tensor): 4D tensor. + eps (float): A small value added to the variance to avoid + divide-by-zero. Default: 1e-5. + """ + size = feat.size() + assert len(size) == 4, 'The input feature should be 4D tensor.' + b, c = size[:2] + feat_var = feat.view(b, c, -1).var(dim=2) + eps + feat_std = feat_var.sqrt().view(b, c, 1, 1) + feat_mean = feat.view(b, c, -1).mean(dim=2).view(b, c, 1, 1) + return feat_mean, feat_std + +def adaptive_instance_normalization(content_feat:Tensor, style_feat:Tensor): + """Adaptive instance normalization. + Adjust the reference features to have the similar color and illuminations + as those in the degradate features. + Args: + content_feat (Tensor): The reference feature. + style_feat (Tensor): The degradate features. + """ + size = content_feat.size() + style_mean, style_std = calc_mean_std(style_feat) + content_mean, content_std = calc_mean_std(content_feat) + normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size) + return normalized_feat * style_std.expand(size) + style_mean.expand(size) + +def wavelet_blur(image: Tensor, radius: int): + """ + Apply wavelet blur to the input tensor. + """ + # input shape: (1, 3, H, W) + # convolution kernel + kernel_vals = [ + [0.0625, 0.125, 0.0625], + [0.125, 0.25, 0.125], + [0.0625, 0.125, 0.0625], + ] + kernel = torch.tensor(kernel_vals, dtype=image.dtype, device=image.device) + # add channel dimensions to the kernel to make it a 4D tensor + kernel = kernel[None, None] + # repeat the kernel across all input channels + kernel = kernel.repeat(3, 1, 1, 1) + image = F.pad(image, (radius, radius, radius, radius), mode='replicate') + # apply convolution + output = F.conv2d(image, kernel, groups=3, dilation=radius) + return output + +def wavelet_decomposition(image: Tensor, levels=5): + """ + Apply wavelet decomposition to the input tensor. + This function only returns the low frequency & the high frequency. + """ + high_freq = torch.zeros_like(image) + for i in range(levels): + radius = 2 ** i + low_freq = wavelet_blur(image, radius) + high_freq += (image - low_freq) + image = low_freq + + return high_freq, low_freq + +def wavelet_reconstruction(content_feat:Tensor, style_feat:Tensor): + """ + Apply wavelet decomposition, so that the content will have the same color as the style. + """ + # calculate the wavelet decomposition of the content feature + content_high_freq, content_low_freq = wavelet_decomposition(content_feat) + del content_low_freq + # calculate the wavelet decomposition of the style feature + style_high_freq, style_low_freq = wavelet_decomposition(style_feat) + del style_high_freq + # reconstruct the content feature with the style's high frequency + return content_high_freq + style_low_freq + diff --git a/modules/images.py b/modules/images.py index cfdfb3384..9a6e18b6a 100644 --- a/modules/images.py +++ b/modules/images.py @@ -22,6 +22,7 @@ import hashlib from modules import sd_samplers, shared, script_callbacks, errors from modules.paths_internal import roboto_ttf_file from modules.shared import opts +from modules.colorfix import wavelet_color_fix LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS) @@ -249,7 +250,7 @@ def draw_prompt_matrix(im, width, height, all_prompts, margin=0): return draw_grid_annotations(im, width, height, hor_texts, ver_texts, margin) -def resize_image(resize_mode, im, width, height, upscaler_name=None): +def resize_image(resize_mode, im, width, height, upscaler_name=None, preserve_colors=False): """ Resizes an image with the specified resize_mode, width, and height. @@ -263,7 +264,7 @@ def resize_image(resize_mode, im, width, height, upscaler_name=None): height: The height to resize the image to. upscaler_name: The name of the upscaler to use. If not provided, defaults to opts.upscaler_for_img2img. """ - + before_resize = im upscaler_name = upscaler_name or opts.upscaler_for_img2img def resize(im, w, h): @@ -285,6 +286,9 @@ def resize_image(resize_mode, im, width, height, upscaler_name=None): if im.width != w or im.height != h: im = im.resize((w, h), resample=LANCZOS) + if preserve_colors: + im = wavelet_color_fix(im, before_resize.resize(im.size)) + return im if resize_mode == 0: diff --git a/modules/processing.py b/modules/processing.py index a7fd3e031..f1bbf7706 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -67,7 +67,7 @@ def uncrop(image, dest_size, paste_loc): base_image = Image.new('RGBA', dest_size) factor_x = w // image.size[0] factor_y = h // image.size[1] - image = images.resize_image(1, image, w, h) + image = images.resize_image(1, image, w, h, preserve_colors=True) paste_x = max(x - factor_x, 0) paste_y = max(y - factor_y, 0) base_image.paste(image, (paste_x, paste_y)) @@ -1683,7 +1683,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): image = images.flatten(img, opts.img2img_background_color) if crop_region is None and self.resize_mode != 3: - image = images.resize_image(self.resize_mode, image, self.width, self.height) + image = images.resize_image(self.resize_mode, image, self.width, self.height, preserve_colors=True) if image_mask is not None: if self.mask_for_overlay.size != (image.width, image.height): @@ -1696,7 +1696,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): # crop_region is not None if we are doing inpaint full res if crop_region is not None: image = image.crop(crop_region) - image = images.resize_image(2, image, self.width, self.height) + image = images.resize_image(2, image, self.width, self.height, preserve_colors=True) if image_mask is not None: if self.inpainting_fill != 1: From 56d79f6dc216a67d1fc5cf2ca702872462f10a8c Mon Sep 17 00:00:00 2001 From: Andray Date: Thu, 25 Jul 2024 13:43:37 +0400 Subject: [PATCH 3/3] add `img2img_upscaler_preserve_colors` option --- modules/processing.py | 6 +++--- modules/shared_options.py | 1 + 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/modules/processing.py b/modules/processing.py index f1bbf7706..80e6dece7 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -67,7 +67,7 @@ def uncrop(image, dest_size, paste_loc): base_image = Image.new('RGBA', dest_size) factor_x = w // image.size[0] factor_y = h // image.size[1] - image = images.resize_image(1, image, w, h, preserve_colors=True) + image = images.resize_image(1, image, w, h, preserve_colors=shared.opts.img2img_upscaler_preserve_colors) paste_x = max(x - factor_x, 0) paste_y = max(y - factor_y, 0) base_image.paste(image, (paste_x, paste_y)) @@ -1683,7 +1683,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): image = images.flatten(img, opts.img2img_background_color) if crop_region is None and self.resize_mode != 3: - image = images.resize_image(self.resize_mode, image, self.width, self.height, preserve_colors=True) + image = images.resize_image(self.resize_mode, image, self.width, self.height, preserve_colors=shared.opts.img2img_upscaler_preserve_colors) if image_mask is not None: if self.mask_for_overlay.size != (image.width, image.height): @@ -1696,7 +1696,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): # crop_region is not None if we are doing inpaint full res if crop_region is not None: image = image.crop(crop_region) - image = images.resize_image(2, image, self.width, self.height, preserve_colors=True) + image = images.resize_image(2, image, self.width, self.height, preserve_colors=shared.opts.img2img_upscaler_preserve_colors) if image_mask is not None: if self.inpainting_fill != 1: diff --git a/modules/shared_options.py b/modules/shared_options.py index 15331d4b5..256669aa5 100644 --- a/modules/shared_options.py +++ b/modules/shared_options.py @@ -103,6 +103,7 @@ options_templates.update(options_section(('upscaling', "Upscaling", "postprocess "DAT_tile": OptionInfo(192, "Tile size for DAT upscalers.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}).info("0 = no tiling"), "DAT_tile_overlap": OptionInfo(8, "Tile overlap for DAT upscalers.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}).info("Low values = visible seam"), "upscaler_for_img2img": OptionInfo(None, "Upscaler for img2img", gr.Dropdown, lambda: {"choices": [x.name for x in shared.sd_upscalers]}), + "img2img_upscaler_preserve_colors": OptionInfo(False, "Preserve colors in upscaler for img2img"), "set_scale_by_when_changing_upscaler": OptionInfo(False, "Automatically set the Scale by factor based on the name of the selected Upscaler."), }))