diff --git a/modules/sd_samplers_common.py b/modules/sd_samplers_common.py index 6bd38e12a..8052b021a 100644 --- a/modules/sd_samplers_common.py +++ b/modules/sd_samplers_common.py @@ -156,7 +156,16 @@ replace_torchsde_browinan() def apply_refiner(cfg_denoiser): - completed_ratio = cfg_denoiser.step / cfg_denoiser.total_steps + if opts.refiner_switch_by_sample_steps: + completed_ratio = cfg_denoiser.step / cfg_denoiser.total_steps + else: + # torch.max(sigma) only to handle rare case where we might have different sigmas in the same batch + try: + timestep = torch.argmin(torch.abs(cfg_denoiser.inner_model.sigmas - torch.max(sigma))) + except AttributeError: # for samplers that dont use sigmas (DDIM) sigma is actually the timestep + timestep = torch.max(sigma).to(dtype=int) + completed_ratio = (999 - timestep) / 1000 + refiner_switch_at = cfg_denoiser.p.refiner_switch_at refiner_checkpoint_info = cfg_denoiser.p.refiner_checkpoint_info