mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-01-17 11:50:18 +08:00
style changes for #14979
This commit is contained in:
parent
1a51b166a0
commit
ee470cc6a3
@ -552,36 +552,48 @@ def repair_config(sd_config):
|
|||||||
karlo_path = os.path.join(paths.models_path, 'karlo')
|
karlo_path = os.path.join(paths.models_path, 'karlo')
|
||||||
sd_config.model.params.noise_aug_config.params.clip_stats_path = sd_config.model.params.noise_aug_config.params.clip_stats_path.replace("checkpoints/karlo_models", karlo_path)
|
sd_config.model.params.noise_aug_config.params.clip_stats_path = sd_config.model.params.noise_aug_config.params.clip_stats_path.replace("checkpoints/karlo_models", karlo_path)
|
||||||
|
|
||||||
|
|
||||||
|
def rescale_zero_terminal_snr_abar(alphas_cumprod):
|
||||||
|
alphas_bar_sqrt = alphas_cumprod.sqrt()
|
||||||
|
|
||||||
|
# Store old values.
|
||||||
|
alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
|
||||||
|
alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
|
||||||
|
|
||||||
|
# Shift so the last timestep is zero.
|
||||||
|
alphas_bar_sqrt -= (alphas_bar_sqrt_T)
|
||||||
|
|
||||||
|
# Scale so the first timestep is back to the old value.
|
||||||
|
alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
|
||||||
|
|
||||||
|
# Convert alphas_bar_sqrt to betas
|
||||||
|
alphas_bar = alphas_bar_sqrt ** 2 # Revert sqrt
|
||||||
|
alphas_bar[-1] = 4.8973451890853435e-08
|
||||||
|
return alphas_bar
|
||||||
|
|
||||||
|
|
||||||
def apply_alpha_schedule_override(sd_model, p=None):
|
def apply_alpha_schedule_override(sd_model, p=None):
|
||||||
def rescale_zero_terminal_snr_abar(alphas_cumprod):
|
"""
|
||||||
alphas_bar_sqrt = alphas_cumprod.sqrt()
|
Applies an override to the alpha schedule of the model according to settings.
|
||||||
|
- downcasts the alpha schedule to half precision
|
||||||
|
- rescales the alpha schedule to have zero terminal SNR
|
||||||
|
"""
|
||||||
|
|
||||||
# Store old values.
|
if not hasattr(sd_model, 'alphas_cumprod') or not hasattr(sd_model, 'alphas_cumprod_original'):
|
||||||
alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
|
return
|
||||||
alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
|
|
||||||
|
|
||||||
# Shift so the last timestep is zero.
|
sd_model.alphas_cumprod = sd_model.alphas_cumprod_original.to(shared.device)
|
||||||
alphas_bar_sqrt -= (alphas_bar_sqrt_T)
|
|
||||||
|
|
||||||
# Scale so the first timestep is back to the old value.
|
if opts.use_downcasted_alpha_bar:
|
||||||
alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
|
if p is not None:
|
||||||
|
p.extra_generation_params['Downcast alphas_cumprod'] = opts.use_downcasted_alpha_bar
|
||||||
|
sd_model.alphas_cumprod = sd_model.alphas_cumprod.half().to(shared.device)
|
||||||
|
|
||||||
# Convert alphas_bar_sqrt to betas
|
if opts.sd_noise_schedule == "Zero Terminal SNR":
|
||||||
alphas_bar = alphas_bar_sqrt**2 # Revert sqrt
|
if p is not None:
|
||||||
alphas_bar[-1] = 4.8973451890853435e-08
|
p.extra_generation_params['Noise Schedule'] = opts.sd_noise_schedule
|
||||||
return alphas_bar
|
sd_model.alphas_cumprod = rescale_zero_terminal_snr_abar(sd_model.alphas_cumprod).to(shared.device)
|
||||||
|
|
||||||
if hasattr(sd_model, 'alphas_cumprod') and hasattr(sd_model, 'alphas_cumprod_original'):
|
|
||||||
sd_model.alphas_cumprod = sd_model.alphas_cumprod_original.to(shared.device)
|
|
||||||
|
|
||||||
if opts.use_downcasted_alpha_bar:
|
|
||||||
if p is not None:
|
|
||||||
p.extra_generation_params['Downcast alphas_cumprod'] = opts.use_downcasted_alpha_bar
|
|
||||||
sd_model.alphas_cumprod = sd_model.alphas_cumprod.half().to(shared.device)
|
|
||||||
if opts.sd_noise_schedule == "Zero Terminal SNR":
|
|
||||||
if p is not None:
|
|
||||||
p.extra_generation_params['Noise Schedule'] = opts.sd_noise_schedule
|
|
||||||
sd_model.alphas_cumprod = rescale_zero_terminal_snr_abar(sd_model.alphas_cumprod).to(shared.device)
|
|
||||||
|
|
||||||
sd1_clip_weight = 'cond_stage_model.transformer.text_model.embeddings.token_embedding.weight'
|
sd1_clip_weight = 'cond_stage_model.transformer.text_model.embeddings.token_embedding.weight'
|
||||||
sd2_clip_weight = 'cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight'
|
sd2_clip_weight = 'cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight'
|
||||||
|
Loading…
Reference in New Issue
Block a user