mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2024-12-29 19:05:05 +08:00
style changes for #14979
This commit is contained in:
parent
1a51b166a0
commit
ee470cc6a3
@ -552,36 +552,48 @@ def repair_config(sd_config):
|
||||
karlo_path = os.path.join(paths.models_path, 'karlo')
|
||||
sd_config.model.params.noise_aug_config.params.clip_stats_path = sd_config.model.params.noise_aug_config.params.clip_stats_path.replace("checkpoints/karlo_models", karlo_path)
|
||||
|
||||
|
||||
def rescale_zero_terminal_snr_abar(alphas_cumprod):
|
||||
alphas_bar_sqrt = alphas_cumprod.sqrt()
|
||||
|
||||
# Store old values.
|
||||
alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
|
||||
alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
|
||||
|
||||
# Shift so the last timestep is zero.
|
||||
alphas_bar_sqrt -= (alphas_bar_sqrt_T)
|
||||
|
||||
# Scale so the first timestep is back to the old value.
|
||||
alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
|
||||
|
||||
# Convert alphas_bar_sqrt to betas
|
||||
alphas_bar = alphas_bar_sqrt ** 2 # Revert sqrt
|
||||
alphas_bar[-1] = 4.8973451890853435e-08
|
||||
return alphas_bar
|
||||
|
||||
|
||||
def apply_alpha_schedule_override(sd_model, p=None):
|
||||
def rescale_zero_terminal_snr_abar(alphas_cumprod):
|
||||
alphas_bar_sqrt = alphas_cumprod.sqrt()
|
||||
"""
|
||||
Applies an override to the alpha schedule of the model according to settings.
|
||||
- downcasts the alpha schedule to half precision
|
||||
- rescales the alpha schedule to have zero terminal SNR
|
||||
"""
|
||||
|
||||
# Store old values.
|
||||
alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
|
||||
alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
|
||||
if not hasattr(sd_model, 'alphas_cumprod') or not hasattr(sd_model, 'alphas_cumprod_original'):
|
||||
return
|
||||
|
||||
# Shift so the last timestep is zero.
|
||||
alphas_bar_sqrt -= (alphas_bar_sqrt_T)
|
||||
sd_model.alphas_cumprod = sd_model.alphas_cumprod_original.to(shared.device)
|
||||
|
||||
# Scale so the first timestep is back to the old value.
|
||||
alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
|
||||
if opts.use_downcasted_alpha_bar:
|
||||
if p is not None:
|
||||
p.extra_generation_params['Downcast alphas_cumprod'] = opts.use_downcasted_alpha_bar
|
||||
sd_model.alphas_cumprod = sd_model.alphas_cumprod.half().to(shared.device)
|
||||
|
||||
# Convert alphas_bar_sqrt to betas
|
||||
alphas_bar = alphas_bar_sqrt**2 # Revert sqrt
|
||||
alphas_bar[-1] = 4.8973451890853435e-08
|
||||
return alphas_bar
|
||||
if opts.sd_noise_schedule == "Zero Terminal SNR":
|
||||
if p is not None:
|
||||
p.extra_generation_params['Noise Schedule'] = opts.sd_noise_schedule
|
||||
sd_model.alphas_cumprod = rescale_zero_terminal_snr_abar(sd_model.alphas_cumprod).to(shared.device)
|
||||
|
||||
if hasattr(sd_model, 'alphas_cumprod') and hasattr(sd_model, 'alphas_cumprod_original'):
|
||||
sd_model.alphas_cumprod = sd_model.alphas_cumprod_original.to(shared.device)
|
||||
|
||||
if opts.use_downcasted_alpha_bar:
|
||||
if p is not None:
|
||||
p.extra_generation_params['Downcast alphas_cumprod'] = opts.use_downcasted_alpha_bar
|
||||
sd_model.alphas_cumprod = sd_model.alphas_cumprod.half().to(shared.device)
|
||||
if opts.sd_noise_schedule == "Zero Terminal SNR":
|
||||
if p is not None:
|
||||
p.extra_generation_params['Noise Schedule'] = opts.sd_noise_schedule
|
||||
sd_model.alphas_cumprod = rescale_zero_terminal_snr_abar(sd_model.alphas_cumprod).to(shared.device)
|
||||
|
||||
sd1_clip_weight = 'cond_stage_model.transformer.text_model.embeddings.token_embedding.weight'
|
||||
sd2_clip_weight = 'cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight'
|
||||
|
Loading…
Reference in New Issue
Block a user