From 9c1ece89784e36a86b19f371e3b6e60bb630394e Mon Sep 17 00:00:00 2001 From: drhead <1313496+drhead@users.noreply.github.com> Date: Tue, 20 Feb 2024 19:23:21 -0500 Subject: [PATCH 1/5] Protect alphas_cumprod during refiner switchover --- modules/sd_samplers_common.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/modules/sd_samplers_common.py b/modules/sd_samplers_common.py index 6bd38e12a..c9578ffe6 100644 --- a/modules/sd_samplers_common.py +++ b/modules/sd_samplers_common.py @@ -181,8 +181,12 @@ def apply_refiner(cfg_denoiser): cfg_denoiser.p.extra_generation_params['Refiner'] = refiner_checkpoint_info.short_title cfg_denoiser.p.extra_generation_params['Refiner switch at'] = refiner_switch_at + alphas_cumprod_original = cfg_denoiser.p.sd_model.alphas_cumprod_original + alphas_cumprod = cfg_denoiser.p.sd_model.alphas_cumprod with sd_models.SkipWritingToConfig(): sd_models.reload_model_weights(info=refiner_checkpoint_info) + cfg_denoiser.p.sd_model.alphas_cumprod_original = alphas_cumprod_original + cfg_denoiser.p.sd_model.alphas_cumprod = alphas_cumprod devices.torch_gc() cfg_denoiser.p.setup_conds() From 648f6a8e0cdf5881cbec9697792e6294c54422d4 Mon Sep 17 00:00:00 2001 From: drhead <1313496+drhead@users.noreply.github.com> Date: Sun, 25 Feb 2024 23:28:36 -0500 Subject: [PATCH 2/5] dont need to preserve alphas_cumprod_original --- modules/sd_samplers_common.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/modules/sd_samplers_common.py b/modules/sd_samplers_common.py index c9578ffe6..7ab1bf65a 100644 --- a/modules/sd_samplers_common.py +++ b/modules/sd_samplers_common.py @@ -181,11 +181,9 @@ def apply_refiner(cfg_denoiser): cfg_denoiser.p.extra_generation_params['Refiner'] = refiner_checkpoint_info.short_title cfg_denoiser.p.extra_generation_params['Refiner switch at'] = refiner_switch_at - alphas_cumprod_original = cfg_denoiser.p.sd_model.alphas_cumprod_original alphas_cumprod = cfg_denoiser.p.sd_model.alphas_cumprod with sd_models.SkipWritingToConfig(): sd_models.reload_model_weights(info=refiner_checkpoint_info) - cfg_denoiser.p.sd_model.alphas_cumprod_original = alphas_cumprod_original cfg_denoiser.p.sd_model.alphas_cumprod = alphas_cumprod devices.torch_gc() From e2cd92ea230801ecc5fc7ed90e14ab55c946fb4a Mon Sep 17 00:00:00 2001 From: drhead <1313496+drhead@users.noreply.github.com> Date: Mon, 26 Feb 2024 23:43:27 -0500 Subject: [PATCH 3/5] move refiner fix to sd_models.py --- modules/sd_models.py | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/modules/sd_models.py b/modules/sd_models.py index 2c0457715..fbd53adba 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -15,6 +15,7 @@ from ldm.util import instantiate_from_config from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl, cache, extra_networks, processing, lowvram, sd_hijack, patches from modules.timer import Timer +from modules.shared import opts import tomesd import numpy as np @@ -549,6 +550,36 @@ def repair_config(sd_config): karlo_path = os.path.join(paths.models_path, 'karlo') sd_config.model.params.noise_aug_config.params.clip_stats_path = sd_config.model.params.noise_aug_config.params.clip_stats_path.replace("checkpoints/karlo_models", karlo_path) +def apply_alpha_schedule_override(sd_model, p=None): + def rescale_zero_terminal_snr_abar(alphas_cumprod): + alphas_bar_sqrt = alphas_cumprod.sqrt() + + # Store old values. + alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone() + alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone() + + # Shift so the last timestep is zero. + alphas_bar_sqrt -= (alphas_bar_sqrt_T) + + # Scale so the first timestep is back to the old value. + alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T) + + # Convert alphas_bar_sqrt to betas + alphas_bar = alphas_bar_sqrt**2 # Revert sqrt + alphas_bar[-1] = 4.8973451890853435e-08 + return alphas_bar + + if hasattr(sd_model, 'alphas_cumprod') and hasattr(sd_model, 'alphas_cumprod_original'): + sd_model.alphas_cumprod = sd_model.alphas_cumprod_original.to(shared.device) + + if opts.use_downcasted_alpha_bar: + if p is not None: + p.extra_generation_params['Downcast alphas_cumprod'] = opts.use_downcasted_alpha_bar + sd_model.alphas_cumprod = sd_model.alphas_cumprod.half().to(shared.device) + if opts.sd_noise_schedule == "Zero Terminal SNR": + if p is not None: + p.extra_generation_params['Noise Schedule'] = opts.sd_noise_schedule + sd_model.alphas_cumprod = rescale_zero_terminal_snr_abar(sd_model.alphas_cumprod).to(shared.device) sd1_clip_weight = 'cond_stage_model.transformer.text_model.embeddings.token_embedding.weight' sd2_clip_weight = 'cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight' @@ -812,6 +843,7 @@ def reload_model_weights(sd_model=None, info=None, forced_reload=False): sd_model = reuse_model_from_already_loaded(sd_model, checkpoint_info, timer) if not forced_reload and sd_model is not None and sd_model.sd_checkpoint_info.filename == checkpoint_info.filename: + apply_alpha_schedule_override(sd_model) return sd_model if sd_model is not None: From 94f23d00a76e7988f4b73ced1fa2922801e893fb Mon Sep 17 00:00:00 2001 From: drhead <1313496+drhead@users.noreply.github.com> Date: Mon, 26 Feb 2024 23:44:58 -0500 Subject: [PATCH 4/5] move alphas cumprod override out of processing --- modules/processing.py | 28 +--------------------------- 1 file changed, 1 insertion(+), 27 deletions(-) diff --git a/modules/processing.py b/modules/processing.py index d208a922d..411c7c3f4 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -915,33 +915,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: if p.n_iter > 1: shared.state.job = f"Batch {n+1} out of {p.n_iter}" - def rescale_zero_terminal_snr_abar(alphas_cumprod): - alphas_bar_sqrt = alphas_cumprod.sqrt() - - # Store old values. - alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone() - alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone() - - # Shift so the last timestep is zero. - alphas_bar_sqrt -= (alphas_bar_sqrt_T) - - # Scale so the first timestep is back to the old value. - alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T) - - # Convert alphas_bar_sqrt to betas - alphas_bar = alphas_bar_sqrt**2 # Revert sqrt - alphas_bar[-1] = 4.8973451890853435e-08 - return alphas_bar - - if hasattr(p.sd_model, 'alphas_cumprod') and hasattr(p.sd_model, 'alphas_cumprod_original'): - p.sd_model.alphas_cumprod = p.sd_model.alphas_cumprod_original.to(shared.device) - - if opts.use_downcasted_alpha_bar: - p.extra_generation_params['Downcast alphas_cumprod'] = opts.use_downcasted_alpha_bar - p.sd_model.alphas_cumprod = p.sd_model.alphas_cumprod.half().to(shared.device) - if opts.sd_noise_schedule == "Zero Terminal SNR": - p.extra_generation_params['Noise Schedule'] = opts.sd_noise_schedule - p.sd_model.alphas_cumprod = rescale_zero_terminal_snr_abar(p.sd_model.alphas_cumprod).to(shared.device) + sd_models.apply_alpha_schedule_override(p.sd_model, p) with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast(): samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts) From 4dae91a1fe960ad9a9774f8f5407ef67c1a109f9 Mon Sep 17 00:00:00 2001 From: drhead <1313496+drhead@users.noreply.github.com> Date: Mon, 26 Feb 2024 23:46:10 -0500 Subject: [PATCH 5/5] remove alphas cumprod fix from samplers_common --- modules/sd_samplers_common.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/modules/sd_samplers_common.py b/modules/sd_samplers_common.py index 7ab1bf65a..6bd38e12a 100644 --- a/modules/sd_samplers_common.py +++ b/modules/sd_samplers_common.py @@ -181,10 +181,8 @@ def apply_refiner(cfg_denoiser): cfg_denoiser.p.extra_generation_params['Refiner'] = refiner_checkpoint_info.short_title cfg_denoiser.p.extra_generation_params['Refiner switch at'] = refiner_switch_at - alphas_cumprod = cfg_denoiser.p.sd_model.alphas_cumprod with sd_models.SkipWritingToConfig(): sd_models.reload_model_weights(info=refiner_checkpoint_info) - cfg_denoiser.p.sd_model.alphas_cumprod = alphas_cumprod devices.torch_gc() cfg_denoiser.p.setup_conds()