From 686598387f49a1e21298204e5f162a6ec93cddff Mon Sep 17 00:00:00 2001
From: AUTOMATIC1111 <16777216c@gmail.com>
Date: Mon, 7 Aug 2023 12:10:16 +0300
Subject: [PATCH] send noisy latent into refiner without adding noise

---
 modules/processing.py             | 33 +++++++++++++++++----------------
 modules/sd_samplers_kdiffusion.py |  2 ++
 2 files changed, 19 insertions(+), 16 deletions(-)

diff --git a/modules/processing.py b/modules/processing.py
index cb5e3f725..36e7ecab4 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -384,11 +384,11 @@ class StableDiffusionProcessing:
         shared.state.nextjob()
 
         stopped_at = self.sampler.stop_at
+        noisy_output = self.sampler.noisy_output
         self.sampler = None
 
         a_is_sdxl = shared.sd_model.is_sdxl
-
-        decoded_samples = decode_latent_batch(shared.sd_model, samples, target_device=devices.cpu, check_for_nans=True)
+        decoded_noisy = decode_latent_batch(shared.sd_model, noisy_output, target_device=devices.cpu, check_for_nans=True)
 
         refiner_checkpoint_info = sd_models.get_closet_checkpoint_match(shared.opts.sd_refiner_checkpoint)
         if refiner_checkpoint_info is None:
@@ -408,21 +408,21 @@ class StableDiffusionProcessing:
         b_is_sdxl = shared.sd_model.is_sdxl
 
         if a_is_sdxl != b_is_sdxl:
-            decoded_samples = torch.stack(decoded_samples).float()
-            decoded_samples = torch.clamp((decoded_samples + 1.0) / 2.0, min=0.0, max=1.0)
-            latent = images_tensor_to_samples(decoded_samples, approximation_indexes.get(opts.sd_vae_encode_method), shared.sd_model)
+            decoded_noisy = torch.stack(decoded_noisy).float()
+            decoded_noisy = torch.clamp((decoded_noisy + 1.0) / 2.0, min=0.0, max=1.0)
+            noisy_latent = images_tensor_to_samples(decoded_noisy, approximation_indexes.get(opts.sd_vae_encode_method), shared.sd_model)
         else:
-            latent = samples
+            noisy_latent = noisy_output
 
-        x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=self.seeds, subseeds=self.subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
+        x = torch.zeros_like(noisy_latent)
 
         with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast():
             denoising_strength = self.denoising_strength
-            self.denoising_strength = 1.0 - stopped_at / self.steps
-            self.image_conditioning = txt2img_image_conditioning(shared.sd_model, latent, self.width, self.height)
+            self.denoising_strength = 1.0 - (stopped_at + 1) / self.steps
+            self.image_conditioning = txt2img_image_conditioning(shared.sd_model, noisy_latent, self.width, self.height)
             self.sampler = sd_samplers.create_sampler(self.sampler_name, shared.sd_model)
-            samples = self.sampler.sample_img2img(self, latent, x, self.c, self.uc, image_conditioning=self.image_conditioning, steps=max(1, self.steps - stopped_at - 1))
+            samples = self.sampler.sample_img2img(self, noisy_latent, x, self.c, self.uc, image_conditioning=self.image_conditioning, steps=max(1, self.steps - stopped_at - 1))
 
             self.denoising_strength = denoising_strength
 
@@ -823,6 +823,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
             if state.interrupted:
                 break
 
+            sd_models.reload_model_weights()  # model can be changed for example by refiner
+
             p.prompts = p.all_prompts[n * p.batch_size:(n + 1) * p.batch_size]
             p.negative_prompts = p.all_negative_prompts[n * p.batch_size:(n + 1) * p.batch_size]
             p.seeds = p.all_seeds[n * p.batch_size:(n + 1) * p.batch_size]
@@ -862,10 +864,12 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
             if p.n_iter > 1:
                 shared.state.job = f"Batch {n+1} out of {p.n_iter}"
 
-            if have_refiner:
-                p.sampler.stop_at = max(1, int(shared.opts.sd_refiner_switch_at * p.steps - 1))
-
             with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast():
+                p.sampler = sd_samplers.create_sampler(p.sampler_name, p.sd_model)
+
+                if have_refiner:
+                    p.sampler.stop_at = max(1, int(shared.opts.sd_refiner_switch_at * p.steps - 1))
+
                 samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts)
 
             if opts.sd_vae_decode_method != 'Full':
@@ -1056,8 +1060,6 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
         self.hr_uc = None
 
     def init(self, all_prompts, all_seeds, all_subseeds):
-        self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)
-
         if self.enable_hr:
             if self.hr_checkpoint_name:
                 self.hr_checkpoint_info = sd_models.get_closet_checkpoint_match(self.hr_checkpoint_name)
@@ -1355,7 +1357,6 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
         self.image_conditioning = None
 
     def init(self, all_prompts, all_seeds, all_subseeds):
-        self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)
         crop_region = None
 
         image_mask = self.image_mask
diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py
index 359b2d52a..51e640078 100644
--- a/modules/sd_samplers_kdiffusion.py
+++ b/modules/sd_samplers_kdiffusion.py
@@ -276,6 +276,7 @@ class KDiffusionSampler:
         self.model_wrap_cfg = CFGDenoiser(self.model_wrap)
         self.sampler_noises = None
         self.stop_at = None
+        self.noisy_output = None
         self.eta = None
         self.config = None  # set by the function calling the constructor
         self.last_latent = None
@@ -297,6 +298,7 @@ class KDiffusionSampler:
         if opts.live_preview_content == "Combined":
             sd_samplers_common.store_latent(latent)
         self.last_latent = latent
+        self.noisy_output = d['x']
 
         if self.stop_at is not None and step > self.stop_at:
             raise sd_samplers_common.InterruptedException