From f4869f8de3ed76735ea331fe5463abc6190bd4cf Mon Sep 17 00:00:00 2001 From: drhead <1313496+drhead@users.noreply.github.com> Date: Tue, 20 Feb 2024 16:18:13 -0500 Subject: [PATCH 1/4] Add compatibility option for refiner switching --- modules/shared_options.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/modules/shared_options.py b/modules/shared_options.py index bb3752ba6..e17eed512 100644 --- a/modules/shared_options.py +++ b/modules/shared_options.py @@ -227,7 +227,8 @@ options_templates.update(options_section(('compatibility', "Compatibility", "sd" "dont_fix_second_order_samplers_schedule": OptionInfo(False, "Do not fix prompt schedule for second order samplers."), "hires_fix_use_firstpass_conds": OptionInfo(False, "For hires fix, calculate conds of second pass using extra networks of first pass."), "use_old_scheduling": OptionInfo(False, "Use old prompt editing timelines.", infotext="Old prompt editing timelines").info("For [red:green:N]; old: If N < 1, it's a fraction of steps (and hires fix uses range from 0 to 1), if N >= 1, it's an absolute number of steps; new: If N has a decimal point in it, it's a fraction of steps (and hires fix uses range from 1 to 2), othewrwise it's an absolute number of steps"), - "use_downcasted_alpha_bar": OptionInfo(False, "Downcast model alphas_cumprod to fp16 before sampling. For reproducing old seeds.", infotext="Downcast alphas_cumprod") + "use_downcasted_alpha_bar": OptionInfo(False, "Downcast model alphas_cumprod to fp16 before sampling. For reproducing old seeds.", infotext="Downcast alphas_cumprod"), + "refiner_switch_by_sample_steps": OptionInfo(False, "Switch to refiner by sampling steps instead of model timesteps. Old behavior for refiner.", infotext="Refiner switch by sampling steps") })) options_templates.update(options_section(('interrogate', "Interrogate"), { From 09d2e5881120c4a51888633947062b40726c6fef Mon Sep 17 00:00:00 2001 From: drhead <1313496+drhead@users.noreply.github.com> Date: Tue, 20 Feb 2024 16:22:40 -0500 Subject: [PATCH 2/4] Pass sigma to apply_refiner --- modules/sd_samplers_cfg_denoiser.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/sd_samplers_cfg_denoiser.py b/modules/sd_samplers_cfg_denoiser.py index a73d3b036..93581c9ac 100644 --- a/modules/sd_samplers_cfg_denoiser.py +++ b/modules/sd_samplers_cfg_denoiser.py @@ -152,7 +152,7 @@ class CFGDenoiser(torch.nn.Module): if state.interrupted or state.skipped: raise sd_samplers_common.InterruptedException - if sd_samplers_common.apply_refiner(self): + if sd_samplers_common.apply_refiner(self, sigma): cond = self.sampler.sampler_extra_args['cond'] uncond = self.sampler.sampler_extra_args['uncond'] From 25eeeaa65f819bb40df427141b82b46d3fcf59e9 Mon Sep 17 00:00:00 2001 From: drhead <1313496+drhead@users.noreply.github.com> Date: Tue, 20 Feb 2024 16:37:29 -0500 Subject: [PATCH 3/4] Allow refiner to be triggered by model timestep instead of sampling --- modules/sd_samplers_common.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/modules/sd_samplers_common.py b/modules/sd_samplers_common.py index 6bd38e12a..8052b021a 100644 --- a/modules/sd_samplers_common.py +++ b/modules/sd_samplers_common.py @@ -156,7 +156,16 @@ replace_torchsde_browinan() def apply_refiner(cfg_denoiser): - completed_ratio = cfg_denoiser.step / cfg_denoiser.total_steps + if opts.refiner_switch_by_sample_steps: + completed_ratio = cfg_denoiser.step / cfg_denoiser.total_steps + else: + # torch.max(sigma) only to handle rare case where we might have different sigmas in the same batch + try: + timestep = torch.argmin(torch.abs(cfg_denoiser.inner_model.sigmas - torch.max(sigma))) + except AttributeError: # for samplers that dont use sigmas (DDIM) sigma is actually the timestep + timestep = torch.max(sigma).to(dtype=int) + completed_ratio = (999 - timestep) / 1000 + refiner_switch_at = cfg_denoiser.p.refiner_switch_at refiner_checkpoint_info = cfg_denoiser.p.refiner_checkpoint_info From bf348032bc07d48ec0b4fbb5be1c4648ee8bd49b Mon Sep 17 00:00:00 2001 From: drhead <1313496+drhead@users.noreply.github.com> Date: Tue, 20 Feb 2024 16:59:28 -0500 Subject: [PATCH 4/4] fix missing arg --- modules/sd_samplers_common.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/sd_samplers_common.py b/modules/sd_samplers_common.py index 8052b021a..045b9e2fe 100644 --- a/modules/sd_samplers_common.py +++ b/modules/sd_samplers_common.py @@ -155,7 +155,7 @@ def replace_torchsde_browinan(): replace_torchsde_browinan() -def apply_refiner(cfg_denoiser): +def apply_refiner(cfg_denoiser, sigma): if opts.refiner_switch_by_sample_steps: completed_ratio = cfg_denoiser.step / cfg_denoiser.total_steps else: