diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py index e71c96015..6711ca16e 100644 --- a/modules/generation_parameters_copypaste.py +++ b/modules/generation_parameters_copypaste.py @@ -344,6 +344,8 @@ infotext_to_setting_name_mapping = [ ('Pad conds', 'pad_cond_uncond'), ('VAE Encoder', 'sd_vae_encode_method'), ('VAE Decoder', 'sd_vae_decode_method'), + ('Refiner', 'sd_refiner_checkpoint'), + ('Refiner switch at', 'sd_refiner_switch_at'), ] diff --git a/modules/processing.py b/modules/processing.py index f4748d6d6..ec66fd8e2 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -370,6 +370,9 @@ class StableDiffusionProcessing: self.uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, negative_prompts, self.steps * self.step_multiplier, [self.cached_uc], self.extra_network_data) self.c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, prompts, self.steps * self.step_multiplier, [self.cached_c], self.extra_network_data) + def get_conds(self): + return self.c, self.uc + def parse_extra_network_prompts(self): self.prompts, self.extra_network_data = extra_networks.parse_prompts(self.prompts) @@ -1251,6 +1254,13 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): with devices.autocast(): extra_networks.activate(self, self.extra_network_data) + def get_conds(self): + if self.is_hr_pass: + return self.hr_c, self.hr_uc + + return super().get_conds() + + def parse_extra_network_prompts(self): res = super().parse_extra_network_prompts() diff --git a/modules/sd_samplers_common.py b/modules/sd_samplers_common.py index 3f3e83e33..92bf0ca1d 100644 --- a/modules/sd_samplers_common.py +++ b/modules/sd_samplers_common.py @@ -131,16 +131,27 @@ replace_torchsde_browinan() def apply_refiner(sampler): completed_ratio = sampler.step / sampler.steps - if completed_ratio > shared.opts.sd_refiner_switch_at and shared.sd_model.sd_checkpoint_info.title != shared.opts.sd_refiner_checkpoint: - refiner_checkpoint_info = sd_models.get_closet_checkpoint_match(shared.opts.sd_refiner_checkpoint) - if refiner_checkpoint_info is None: - raise Exception(f'Could not find checkpoint with name {shared.opts.sd_refiner_checkpoint}') - with sd_models.SkipWritingToConfig(): - sd_models.reload_model_weights(info=refiner_checkpoint_info) + if completed_ratio <= shared.opts.sd_refiner_switch_at: + return False - devices.torch_gc() + if shared.sd_model.sd_checkpoint_info.title == shared.opts.sd_refiner_checkpoint: + return False + + refiner_checkpoint_info = sd_models.get_closet_checkpoint_match(shared.opts.sd_refiner_checkpoint) + if refiner_checkpoint_info is None: + raise Exception(f'Could not find checkpoint with name {shared.opts.sd_refiner_checkpoint}') + + sampler.p.extra_generation_params['Refiner'] = refiner_checkpoint_info.short_title + sampler.p.extra_generation_params['Refiner switch at'] = shared.opts.sd_refiner_switch_at + + with sd_models.SkipWritingToConfig(): + sd_models.reload_model_weights(info=refiner_checkpoint_info) + + devices.torch_gc() + sampler.p.setup_conds() + sampler.update_inner_model() + + return True - sampler.update_inner_model() - sampler.p.setup_conds() diff --git a/modules/sd_samplers_compvis.py b/modules/sd_samplers_compvis.py index 5df926d3d..2eeec18a4 100644 --- a/modules/sd_samplers_compvis.py +++ b/modules/sd_samplers_compvis.py @@ -71,8 +71,6 @@ class VanillaStableDiffusionSampler: if state.interrupted or state.skipped: raise sd_samplers_common.InterruptedException - sd_samplers_common.apply_refiner(self) - if self.stop_at is not None and self.step > self.stop_at: raise sd_samplers_common.InterruptedException diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py index 3aee4e3a9..46da0a97f 100644 --- a/modules/sd_samplers_kdiffusion.py +++ b/modules/sd_samplers_kdiffusion.py @@ -87,8 +87,9 @@ class CFGDenoiser(torch.nn.Module): negative prompt. """ - def __init__(self): + def __init__(self, sampler): super().__init__() + self.sampler = sampler self.model_wrap = None self.mask = None self.nmask = None @@ -126,11 +127,17 @@ class CFGDenoiser(torch.nn.Module): def update_inner_model(self): self.model_wrap = None + c, uc = self.p.get_conds() + self.sampler.sampler_extra_args['cond'] = c + self.sampler.sampler_extra_args['uncond'] = uc + def forward(self, x, sigma, uncond, cond, cond_scale, s_min_uncond, image_cond): if state.interrupted or state.skipped: raise sd_samplers_common.InterruptedException - sd_samplers_common.apply_refiner(self) + if sd_samplers_common.apply_refiner(self): + cond = self.sampler.sampler_extra_args['cond'] + uncond = self.sampler.sampler_extra_args['uncond'] # at self.image_cfg_scale == 1.0 produced results for edit model are the same as with normal sampling, # so is_edit_model is set to False to support AND composition. @@ -282,12 +289,12 @@ class TorchHijack: class KDiffusionSampler: def __init__(self, funcname, sd_model): - self.p = None self.funcname = funcname self.func = funcname if callable(funcname) else getattr(k_diffusion.sampling, self.funcname) self.extra_params = sampler_extra_params.get(funcname, []) - self.model_wrap_cfg = CFGDenoiser() + self.sampler_extra_args = {} + self.model_wrap_cfg = CFGDenoiser(self) self.model_wrap = self.model_wrap_cfg.inner_model self.sampler_noises = None self.stop_at = None @@ -476,7 +483,7 @@ class KDiffusionSampler: self.model_wrap_cfg.init_latent = x self.last_latent = x - extra_args = { + self.sampler_extra_args = { 'cond': conditioning, 'image_cond': image_conditioning, 'uncond': unconditional_conditioning, @@ -484,7 +491,7 @@ class KDiffusionSampler: 's_min_uncond': self.s_min_uncond } - samples = self.launch_sampling(t_enc + 1, lambda: self.func(self.model_wrap_cfg, xi, extra_args=extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs)) + samples = self.launch_sampling(t_enc + 1, lambda: self.func(self.model_wrap_cfg, xi, extra_args=self.sampler_extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs)) if self.model_wrap_cfg.padded_cond_uncond: p.extra_generation_params["Pad conds"] = True @@ -514,13 +521,14 @@ class KDiffusionSampler: extra_params_kwargs['noise_sampler'] = noise_sampler self.last_latent = x - samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args={ + self.sampler_extra_args = { 'cond': conditioning, 'image_cond': image_conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale, 's_min_uncond': self.s_min_uncond - }, disable=False, callback=self.callback_state, **extra_params_kwargs)) + } + samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args=self.sampler_extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs)) if self.model_wrap_cfg.padded_cond_uncond: p.extra_generation_params["Pad conds"] = True