mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2024-12-29 19:05:05 +08:00
Merge pull request #16235 from v0xie/beta-sampling
Feature: Beta scheduler
This commit is contained in:
commit
5a10bb9aa6
@ -120,6 +120,10 @@ class KDiffusionSampler(sd_samplers_common.Sampler):
|
|||||||
if scheduler.need_inner_model:
|
if scheduler.need_inner_model:
|
||||||
sigmas_kwargs['inner_model'] = self.model_wrap
|
sigmas_kwargs['inner_model'] = self.model_wrap
|
||||||
|
|
||||||
|
if scheduler.label == 'Beta':
|
||||||
|
p.extra_generation_params["Beta schedule alpha"] = opts.beta_dist_alpha
|
||||||
|
p.extra_generation_params["Beta schedule beta"] = opts.beta_dist_beta
|
||||||
|
|
||||||
sigmas = scheduler.function(n=steps, **sigmas_kwargs, device=devices.cpu)
|
sigmas = scheduler.function(n=steps, **sigmas_kwargs, device=devices.cpu)
|
||||||
|
|
||||||
if discard_next_to_last_sigma:
|
if discard_next_to_last_sigma:
|
||||||
|
@ -2,6 +2,7 @@ import dataclasses
|
|||||||
import torch
|
import torch
|
||||||
import k_diffusion
|
import k_diffusion
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
from scipy import stats
|
||||||
|
|
||||||
from modules import shared
|
from modules import shared
|
||||||
|
|
||||||
@ -115,6 +116,17 @@ def ddim_scheduler(n, sigma_min, sigma_max, inner_model, device):
|
|||||||
return torch.FloatTensor(sigs).to(device)
|
return torch.FloatTensor(sigs).to(device)
|
||||||
|
|
||||||
|
|
||||||
|
def beta_scheduler(n, sigma_min, sigma_max, inner_model, device):
|
||||||
|
# From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024) """
|
||||||
|
alpha = shared.opts.beta_dist_alpha
|
||||||
|
beta = shared.opts.beta_dist_beta
|
||||||
|
timesteps = 1 - np.linspace(0, 1, n)
|
||||||
|
timesteps = [stats.beta.ppf(x, alpha, beta) for x in timesteps]
|
||||||
|
sigmas = [sigma_min + (x * (sigma_max-sigma_min)) for x in timesteps]
|
||||||
|
sigmas += [0.0]
|
||||||
|
return torch.FloatTensor(sigmas).to(device)
|
||||||
|
|
||||||
|
|
||||||
schedulers = [
|
schedulers = [
|
||||||
Scheduler('automatic', 'Automatic', None),
|
Scheduler('automatic', 'Automatic', None),
|
||||||
Scheduler('uniform', 'Uniform', uniform, need_inner_model=True),
|
Scheduler('uniform', 'Uniform', uniform, need_inner_model=True),
|
||||||
@ -127,6 +139,7 @@ schedulers = [
|
|||||||
Scheduler('simple', 'Simple', simple_scheduler, need_inner_model=True),
|
Scheduler('simple', 'Simple', simple_scheduler, need_inner_model=True),
|
||||||
Scheduler('normal', 'Normal', normal_scheduler, need_inner_model=True),
|
Scheduler('normal', 'Normal', normal_scheduler, need_inner_model=True),
|
||||||
Scheduler('ddim', 'DDIM', ddim_scheduler, need_inner_model=True),
|
Scheduler('ddim', 'DDIM', ddim_scheduler, need_inner_model=True),
|
||||||
|
Scheduler('beta', 'Beta', beta_scheduler, need_inner_model=True),
|
||||||
]
|
]
|
||||||
|
|
||||||
schedulers_map = {**{x.name: x for x in schedulers}, **{x.label: x for x in schedulers}}
|
schedulers_map = {**{x.name: x for x in schedulers}, **{x.label: x for x in schedulers}}
|
||||||
|
@ -405,6 +405,8 @@ options_templates.update(options_section(('sampler-params', "Sampler parameters"
|
|||||||
'uni_pc_lower_order_final': OptionInfo(True, "UniPC lower order final", infotext='UniPC lower order final'),
|
'uni_pc_lower_order_final': OptionInfo(True, "UniPC lower order final", infotext='UniPC lower order final'),
|
||||||
'sd_noise_schedule': OptionInfo("Default", "Noise schedule for sampling", gr.Radio, {"choices": ["Default", "Zero Terminal SNR"]}, infotext="Noise Schedule").info("for use with zero terminal SNR trained models"),
|
'sd_noise_schedule': OptionInfo("Default", "Noise schedule for sampling", gr.Radio, {"choices": ["Default", "Zero Terminal SNR"]}, infotext="Noise Schedule").info("for use with zero terminal SNR trained models"),
|
||||||
'skip_early_cond': OptionInfo(0.0, "Ignore negative prompt during early sampling", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}, infotext="Skip Early CFG").info("disables CFG on a proportion of steps at the beginning of generation; 0=skip none; 1=skip all; can both improve sample diversity/quality and speed up sampling"),
|
'skip_early_cond': OptionInfo(0.0, "Ignore negative prompt during early sampling", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}, infotext="Skip Early CFG").info("disables CFG on a proportion of steps at the beginning of generation; 0=skip none; 1=skip all; can both improve sample diversity/quality and speed up sampling"),
|
||||||
|
'beta_dist_alpha': OptionInfo(0.6, "Beta scheduler - alpha", gr.Slider, {"minimum": 0.01, "maximum": 1.0, "step": 0.01}, infotext='Beta scheduler alpha').info('Default = 0.6; the alpha parameter of the beta distribution used in Beta sampling'),
|
||||||
|
'beta_dist_beta': OptionInfo(0.6, "Beta scheduler - beta", gr.Slider, {"minimum": 0.01, "maximum": 1.0, "step": 0.01}, infotext='Beta scheduler beta').info('Default = 0.6; the beta parameter of the beta distribution used in Beta sampling'),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
options_templates.update(options_section(('postprocessing', "Postprocessing", "postprocessing"), {
|
options_templates.update(options_section(('postprocessing', "Postprocessing", "postprocessing"), {
|
||||||
|
@ -259,6 +259,8 @@ axis_options = [
|
|||||||
AxisOption("Schedule min sigma", float, apply_override("sigma_min")),
|
AxisOption("Schedule min sigma", float, apply_override("sigma_min")),
|
||||||
AxisOption("Schedule max sigma", float, apply_override("sigma_max")),
|
AxisOption("Schedule max sigma", float, apply_override("sigma_max")),
|
||||||
AxisOption("Schedule rho", float, apply_override("rho")),
|
AxisOption("Schedule rho", float, apply_override("rho")),
|
||||||
|
AxisOption("Beta schedule alpha", float, apply_override("beta_dist_alpha")),
|
||||||
|
AxisOption("Beta schedule beta", float, apply_override("beta_dist_beta")),
|
||||||
AxisOption("Eta", float, apply_field("eta")),
|
AxisOption("Eta", float, apply_field("eta")),
|
||||||
AxisOption("Clip skip", int, apply_override('CLIP_stop_at_last_layers')),
|
AxisOption("Clip skip", int, apply_override('CLIP_stop_at_last_layers')),
|
||||||
AxisOption("Denoising", float, apply_field("denoising_strength")),
|
AxisOption("Denoising", float, apply_field("denoising_strength")),
|
||||||
|
Loading…
Reference in New Issue
Block a user