diff --git a/modules/processing.py b/modules/processing.py index f72185ac5..548eec297 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -129,6 +129,73 @@ class StableDiffusionProcessing(): self.all_seeds = None self.all_subseeds = None + def txt2img_image_conditioning(self, x, width=None, height=None): + if self.sampler.conditioning_key not in {'hybrid', 'concat'}: + # Dummy zero conditioning if we're not using inpainting model. + # Still takes up a bit of memory, but no encoder call. + # Pretty sure we can just make this a 1x1 image since its not going to be used besides its batch size. + return torch.zeros( + x.shape[0], 5, 1, 1, + dtype=x.dtype, + device=x.device + ) + + height = height or self.height + width = width or self.width + + # The "masked-image" in this case will just be all zeros since the entire image is masked. + image_conditioning = torch.zeros(x.shape[0], 3, height, width, device=x.device) + image_conditioning = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image_conditioning)) + + # Add the fake full 1s mask to the first dimension. + image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0) + image_conditioning = image_conditioning.to(x.dtype) + + return image_conditioning + + def img2img_image_conditioning(self, source_image, latent_image, image_mask = None): + if self.sampler.conditioning_key not in {'hybrid', 'concat'}: + # Dummy zero conditioning if we're not using inpainting model. + return torch.zeros( + latent_image.shape[0], 5, 1, 1, + dtype=latent_image.dtype, + device=latent_image.device + ) + + # Handle the different mask inputs + if image_mask is not None: + if torch.is_tensor(image_mask): + conditioning_mask = image_mask + else: + conditioning_mask = np.array(image_mask.convert("L")) + conditioning_mask = conditioning_mask.astype(np.float32) / 255.0 + conditioning_mask = torch.from_numpy(conditioning_mask[None, None]) + + # Inpainting model uses a discretized mask as input, so we round to either 1.0 or 0.0 + conditioning_mask = torch.round(conditioning_mask) + else: + conditioning_mask = torch.ones(1, 1, *source_image.shape[-2:]) + + # Create another latent image, this time with a masked version of the original input. + # Smoothly interpolate between the masked and unmasked latent conditioning image using a parameter. + conditioning_mask = conditioning_mask.to(source_image.device) + conditioning_image = torch.lerp( + source_image, + source_image * (1.0 - conditioning_mask), + getattr(self, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) + ) + + # Encode the new masked image using first stage of network. + conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(conditioning_image)) + + # Create the concatenated conditioning tensor to be fed to `c_concat` + conditioning_mask = torch.nn.functional.interpolate(conditioning_mask, size=latent_image.shape[-2:]) + conditioning_mask = conditioning_mask.expand(conditioning_image.shape[0], -1, -1, -1) + image_conditioning = torch.cat([conditioning_mask, conditioning_image], dim=1) + image_conditioning = image_conditioning.to(shared.device).type(self.sd_model.dtype) + + return image_conditioning + def init(self, all_prompts, all_seeds, all_subseeds): pass @@ -571,37 +638,16 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): self.truncate_x = int(self.firstphase_width - firstphase_width_truncated) // opt_f self.truncate_y = int(self.firstphase_height - firstphase_height_truncated) // opt_f - def create_dummy_mask(self, x, width=None, height=None): - if self.sampler.conditioning_key in {'hybrid', 'concat'}: - height = height or self.height - width = width or self.width - - # The "masked-image" in this case will just be all zeros since the entire image is masked. - image_conditioning = torch.zeros(x.shape[0], 3, height, width, device=x.device) - image_conditioning = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image_conditioning)) - - # Add the fake full 1s mask to the first dimension. - image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0) - image_conditioning = image_conditioning.to(x.dtype) - - else: - # Dummy zero conditioning if we're not using inpainting model. - # Still takes up a bit of memory, but no encoder call. - # Pretty sure we can just make this a 1x1 image since its not going to be used besides its batch size. - image_conditioning = torch.zeros(x.shape[0], 5, 1, 1, dtype=x.dtype, device=x.device) - - return image_conditioning - def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength): self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, self.sampler_index, self.sd_model) if not self.enable_hr: x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self) - samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.create_dummy_mask(x)) + samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x)) return samples x = create_random_tensors([opt_C, self.firstphase_height // opt_f, self.firstphase_width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self) - samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.create_dummy_mask(x, self.firstphase_width, self.firstphase_height)) + samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x, self.firstphase_width, self.firstphase_height)) samples = samples[:, :, self.truncate_y//2:samples.shape[2]-self.truncate_y//2, self.truncate_x//2:samples.shape[3]-self.truncate_x//2] @@ -638,7 +684,12 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): x = None devices.torch_gc() - samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.steps, image_conditioning=self.create_dummy_mask(samples)) + image_conditioning = self.img2img_image_conditioning( + decoded_samples, + samples, + decoded_samples.new_ones(decoded_samples.shape[0], 1, decoded_samples.shape[2], decoded_samples.shape[3]) + ) + samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.steps, image_conditioning=image_conditioning) return samples @@ -770,40 +821,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): elif self.inpainting_fill == 3: self.init_latent = self.init_latent * self.mask - if self.sampler.conditioning_key in {'hybrid', 'concat'}: - if self.image_mask is not None: - conditioning_mask = np.array(self.image_mask.convert("L")) - conditioning_mask = conditioning_mask.astype(np.float32) / 255.0 - conditioning_mask = torch.from_numpy(conditioning_mask[None, None]) - - # Inpainting model uses a discretized mask as input, so we round to either 1.0 or 0.0 - conditioning_mask = torch.round(conditioning_mask) - else: - conditioning_mask = torch.ones(1, 1, *image.shape[-2:]) - - # Create another latent image, this time with a masked version of the original input. - conditioning_mask = conditioning_mask.to(image.device) - - # Smoothly interpolate between the masked and unmasked latent conditioning image. - conditioning_image = torch.lerp( - image, - image * (1.0 - conditioning_mask), - getattr(self, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) - ) - - conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(conditioning_image)) - - # Create the concatenated conditioning tensor to be fed to `c_concat` - conditioning_mask = torch.nn.functional.interpolate(conditioning_mask, size=self.init_latent.shape[-2:]) - conditioning_mask = conditioning_mask.expand(conditioning_image.shape[0], -1, -1, -1) - self.image_conditioning = torch.cat([conditioning_mask, conditioning_image], dim=1) - self.image_conditioning = self.image_conditioning.to(shared.device).type(self.sd_model.dtype) - else: - self.image_conditioning = torch.zeros( - self.init_latent.shape[0], 5, 1, 1, - dtype=self.init_latent.dtype, - device=self.init_latent.device - ) + self.image_conditioning = self.img2img_image_conditioning(image, self.init_latent, self.image_mask) def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):