diff --git a/modules/colorfix.py b/modules/colorfix.py
new file mode 100644
index 000000000..77097cab2
--- /dev/null
+++ b/modules/colorfix.py
@@ -0,0 +1,114 @@
+import torch
+from PIL import Image
+from torch import Tensor
+from torch.nn import functional as F
+
+from torchvision.transforms import ToTensor, ToPILImage
+
+def adain_color_fix(target: Image.Image, source: Image.Image):
+    # Convert images to tensors
+    to_tensor = ToTensor()
+    target_tensor = to_tensor(target).unsqueeze(0)
+    source_tensor = to_tensor(source).unsqueeze(0)
+
+    # Apply adaptive instance normalization
+    result_tensor = adaptive_instance_normalization(target_tensor, source_tensor)
+
+    # Convert tensor back to image
+    to_image = ToPILImage()
+    result_image = to_image(result_tensor.squeeze(0).clamp_(0.0, 1.0))
+
+    return result_image
+
+def wavelet_color_fix(target: Image.Image, source: Image.Image):
+    # Convert images to tensors
+    to_tensor = ToTensor()
+    target_tensor = to_tensor(target).unsqueeze(0)
+    source_tensor = to_tensor(source).unsqueeze(0)
+
+    # Apply wavelet reconstruction
+    result_tensor = wavelet_reconstruction(target_tensor, source_tensor)
+
+    # Convert tensor back to image
+    to_image = ToPILImage()
+    result_image = to_image(result_tensor.squeeze(0).clamp_(0.0, 1.0))
+
+    return result_image
+
+def calc_mean_std(feat: Tensor, eps=1e-5):
+    """Calculate per-channel mean and std for adaptive_instance_normalization.
+    Args:
+        feat (Tensor): 4D tensor.
+        eps (float): A small value added to the variance to avoid
+            divide-by-zero. Default: 1e-5.
+    """
+    size = feat.size()
+    assert len(size) == 4, 'The input feature should be a 4D tensor.'
+    b, c = size[:2]
+    feat_var = feat.view(b, c, -1).var(dim=2) + eps
+    feat_std = feat_var.sqrt().view(b, c, 1, 1)
+    feat_mean = feat.view(b, c, -1).mean(dim=2).view(b, c, 1, 1)
+    return feat_mean, feat_std
+
+def adaptive_instance_normalization(content_feat: Tensor, style_feat: Tensor):
+    """Adaptive instance normalization.
+    Adjust the reference features to match the color and illumination
+    of the degraded features.
+    Args:
+        content_feat (Tensor): The reference features.
+        style_feat (Tensor): The degraded features.
+    """
+    size = content_feat.size()
+    style_mean, style_std = calc_mean_std(style_feat)
+    content_mean, content_std = calc_mean_std(content_feat)
+    normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size)
+    return normalized_feat * style_std.expand(size) + style_mean.expand(size)
+
+def wavelet_blur(image: Tensor, radius: int):
+    """
+    Apply wavelet blur to the input tensor.
+    """
+    # input shape: (1, 3, H, W)
+    # 3x3 blur kernel (outer product of [1/4, 1/2, 1/4] with itself)
+    kernel_vals = [
+        [0.0625, 0.125, 0.0625],
+        [0.125, 0.25, 0.125],
+        [0.0625, 0.125, 0.0625],
+    ]
+    kernel = torch.tensor(kernel_vals, dtype=image.dtype, device=image.device)
+    # add channel dimensions to the kernel to make it a 4D tensor
+    kernel = kernel[None, None]
+    # repeat the kernel across all input channels
+    kernel = kernel.repeat(3, 1, 1, 1)
+    image = F.pad(image, (radius, radius, radius, radius), mode='replicate')
+    # depthwise convolution; the dilation cancels the padding, so the output keeps the input size
+    output = F.conv2d(image, kernel, groups=3, dilation=radius)
+    return output
+
+def wavelet_decomposition(image: Tensor, levels=5):
+    """
+    Apply wavelet decomposition to the input tensor.
+    Returns the accumulated high-frequency components and the final low-frequency residual.
+ """ + high_freq = torch.zeros_like(image) + for i in range(levels): + radius = 2 ** i + low_freq = wavelet_blur(image, radius) + high_freq += (image - low_freq) + image = low_freq + + return high_freq, low_freq + +def wavelet_reconstruction(content_feat:Tensor, style_feat:Tensor): + """ + Apply wavelet decomposition, so that the content will have the same color as the style. + """ + # calculate the wavelet decomposition of the content feature + content_high_freq, content_low_freq = wavelet_decomposition(content_feat) + del content_low_freq + # calculate the wavelet decomposition of the style feature + style_high_freq, style_low_freq = wavelet_decomposition(style_feat) + del style_high_freq + # reconstruct the content feature with the style's high frequency + return content_high_freq + style_low_freq + diff --git a/modules/images.py b/modules/images.py index cfdfb3384..9a6e18b6a 100644 --- a/modules/images.py +++ b/modules/images.py @@ -22,6 +22,7 @@ import hashlib from modules import sd_samplers, shared, script_callbacks, errors from modules.paths_internal import roboto_ttf_file from modules.shared import opts +from modules.colorfix import wavelet_color_fix LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS) @@ -249,7 +250,7 @@ def draw_prompt_matrix(im, width, height, all_prompts, margin=0): return draw_grid_annotations(im, width, height, hor_texts, ver_texts, margin) -def resize_image(resize_mode, im, width, height, upscaler_name=None): +def resize_image(resize_mode, im, width, height, upscaler_name=None, preserve_colors=False): """ Resizes an image with the specified resize_mode, width, and height. @@ -263,7 +264,7 @@ def resize_image(resize_mode, im, width, height, upscaler_name=None): height: The height to resize the image to. upscaler_name: The name of the upscaler to use. If not provided, defaults to opts.upscaler_for_img2img. 
""" - + before_resize = im upscaler_name = upscaler_name or opts.upscaler_for_img2img def resize(im, w, h): @@ -285,6 +286,9 @@ def resize_image(resize_mode, im, width, height, upscaler_name=None): if im.width != w or im.height != h: im = im.resize((w, h), resample=LANCZOS) + if preserve_colors: + im = wavelet_color_fix(im, before_resize.resize(im.size)) + return im if resize_mode == 0: diff --git a/modules/processing.py b/modules/processing.py index a7fd3e031..f1bbf7706 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -67,7 +67,7 @@ def uncrop(image, dest_size, paste_loc): base_image = Image.new('RGBA', dest_size) factor_x = w // image.size[0] factor_y = h // image.size[1] - image = images.resize_image(1, image, w, h) + image = images.resize_image(1, image, w, h, preserve_colors=True) paste_x = max(x - factor_x, 0) paste_y = max(y - factor_y, 0) base_image.paste(image, (paste_x, paste_y)) @@ -1683,7 +1683,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): image = images.flatten(img, opts.img2img_background_color) if crop_region is None and self.resize_mode != 3: - image = images.resize_image(self.resize_mode, image, self.width, self.height) + image = images.resize_image(self.resize_mode, image, self.width, self.height, preserve_colors=True) if image_mask is not None: if self.mask_for_overlay.size != (image.width, image.height): @@ -1696,7 +1696,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): # crop_region is not None if we are doing inpaint full res if crop_region is not None: image = image.crop(crop_region) - image = images.resize_image(2, image, self.width, self.height) + image = images.resize_image(2, image, self.width, self.height, preserve_colors=True) if image_mask is not None: if self.inpainting_fill != 1: