From c2ce1d3b9c9be2bef7efd1fe5b5a53424105a1c5 Mon Sep 17 00:00:00 2001 From: catboxanon <122327233+catboxanon@users.noreply.github.com> Date: Sat, 19 Oct 2024 19:58:13 -0400 Subject: [PATCH] Automatically enable ztSNR based on existence of key in state_dict --- modules/sd_models.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/modules/sd_models.py b/modules/sd_models.py index 167d4ff36..1c7d370e9 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -423,6 +423,10 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer set_model_type(model, state_dict) set_model_fields(model) + if 'ztsnr' in state_dict: + model.ztsnr = True + else: + model.ztsnr = False if model.is_sdxl: sd_models_xl.extend_sdxl(model) @@ -661,7 +665,7 @@ def apply_alpha_schedule_override(sd_model, p=None): p.extra_generation_params['Downcast alphas_cumprod'] = opts.use_downcasted_alpha_bar sd_model.alphas_cumprod = sd_model.alphas_cumprod.half().to(shared.device) - if opts.sd_noise_schedule == "Zero Terminal SNR": + if opts.sd_noise_schedule == "Zero Terminal SNR" or (hasattr(sd_model, 'ztsnr') and sd_model.ztsnr): if p is not None: p.extra_generation_params['Noise Schedule'] = opts.sd_noise_schedule sd_model.alphas_cumprod = rescale_zero_terminal_snr_abar(sd_model.alphas_cumprod).to(shared.device)