From 25eeeaa65f819bb40df427141b82b46d3fcf59e9 Mon Sep 17 00:00:00 2001 From: drhead <1313496+drhead@users.noreply.github.com> Date: Tue, 20 Feb 2024 16:37:29 -0500 Subject: [PATCH] Allow refiner to be triggered by model timestep instead of sampling --- modules/sd_samplers_common.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/modules/sd_samplers_common.py b/modules/sd_samplers_common.py index 6bd38e12a..8052b021a 100644 --- a/modules/sd_samplers_common.py +++ b/modules/sd_samplers_common.py @@ -156,7 +156,16 @@ replace_torchsde_browinan() def apply_refiner(cfg_denoiser): - completed_ratio = cfg_denoiser.step / cfg_denoiser.total_steps + if opts.refiner_switch_by_sample_steps: + completed_ratio = cfg_denoiser.step / cfg_denoiser.total_steps + else: + # torch.max(sigma) only to handle rare case where we might have different sigmas in the same batch + try: + timestep = torch.argmin(torch.abs(cfg_denoiser.inner_model.sigmas - torch.max(sigma))) + except AttributeError: # for samplers that dont use sigmas (DDIM) sigma is actually the timestep + timestep = torch.max(sigma).to(dtype=int) + completed_ratio = (999 - timestep) / 1000 + refiner_switch_at = cfg_denoiser.p.refiner_switch_at refiner_checkpoint_info = cfg_denoiser.p.refiner_checkpoint_info