mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-01-17 11:50:18 +08:00
Merge pull request #15333 from AUTOMATIC1111/scheduler_selection
Scheduler selection in main UI
This commit is contained in:
commit
bf2f7b3af4
@ -146,7 +146,7 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=Fal
|
|||||||
return batch_results
|
return batch_results
|
||||||
|
|
||||||
|
|
||||||
def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_name: str, mask_blur: int, mask_alpha: float, inpainting_fill: int, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, selected_scale_tab: int, height: int, width: int, scale_by: float, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, img2img_batch_use_png_info: bool, img2img_batch_png_info_props: list, img2img_batch_png_info_dir: str, request: gr.Request, *args):
|
def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, mask_blur: int, mask_alpha: float, inpainting_fill: int, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, selected_scale_tab: int, height: int, width: int, scale_by: float, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, img2img_batch_use_png_info: bool, img2img_batch_png_info_props: list, img2img_batch_png_info_dir: str, request: gr.Request, *args):
|
||||||
override_settings = create_override_settings_dict(override_settings_texts)
|
override_settings = create_override_settings_dict(override_settings_texts)
|
||||||
|
|
||||||
is_batch = mode == 5
|
is_batch = mode == 5
|
||||||
@ -193,10 +193,8 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s
|
|||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
negative_prompt=negative_prompt,
|
negative_prompt=negative_prompt,
|
||||||
styles=prompt_styles,
|
styles=prompt_styles,
|
||||||
sampler_name=sampler_name,
|
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
n_iter=n_iter,
|
n_iter=n_iter,
|
||||||
steps=steps,
|
|
||||||
cfg_scale=cfg_scale,
|
cfg_scale=cfg_scale,
|
||||||
width=width,
|
width=width,
|
||||||
height=height,
|
height=height,
|
||||||
|
@ -152,6 +152,7 @@ class StableDiffusionProcessing:
|
|||||||
seed_resize_from_w: int = -1
|
seed_resize_from_w: int = -1
|
||||||
seed_enable_extras: bool = True
|
seed_enable_extras: bool = True
|
||||||
sampler_name: str = None
|
sampler_name: str = None
|
||||||
|
scheduler: str = None
|
||||||
batch_size: int = 1
|
batch_size: int = 1
|
||||||
n_iter: int = 1
|
n_iter: int = 1
|
||||||
steps: int = 50
|
steps: int = 50
|
||||||
@ -721,6 +722,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
|
|||||||
generation_params = {
|
generation_params = {
|
||||||
"Steps": p.steps,
|
"Steps": p.steps,
|
||||||
"Sampler": p.sampler_name,
|
"Sampler": p.sampler_name,
|
||||||
|
"Schedule type": p.scheduler,
|
||||||
"CFG scale": p.cfg_scale,
|
"CFG scale": p.cfg_scale,
|
||||||
"Image CFG scale": getattr(p, 'image_cfg_scale', None),
|
"Image CFG scale": getattr(p, 'image_cfg_scale', None),
|
||||||
"Seed": p.all_seeds[0] if use_main_prompt else all_seeds[index],
|
"Seed": p.all_seeds[0] if use_main_prompt else all_seeds[index],
|
||||||
|
79
modules/processing_scripts/sampler.py
Normal file
79
modules/processing_scripts/sampler.py
Normal file
@ -0,0 +1,79 @@
|
|||||||
|
import gradio as gr
|
||||||
|
import functools
|
||||||
|
|
||||||
|
from modules import scripts, sd_samplers, sd_schedulers, shared
|
||||||
|
from modules.infotext_utils import PasteField
|
||||||
|
from modules.ui_components import FormRow, FormGroup
|
||||||
|
|
||||||
|
|
||||||
|
def get_sampler_from_infotext(d: dict):
|
||||||
|
return get_sampler_and_scheduler(d.get("Sampler"), d.get("Schedule type"))[0]
|
||||||
|
|
||||||
|
|
||||||
|
def get_scheduler_from_infotext(d: dict):
|
||||||
|
return get_sampler_and_scheduler(d.get("Sampler"), d.get("Schedule type"))[1]
|
||||||
|
|
||||||
|
|
||||||
|
@functools.cache
|
||||||
|
def get_sampler_and_scheduler(sampler_name, scheduler_name):
|
||||||
|
default_sampler = sd_samplers.samplers[0]
|
||||||
|
found_scheduler = sd_schedulers.schedulers_map.get(scheduler_name, sd_schedulers.schedulers[0])
|
||||||
|
|
||||||
|
name = sampler_name or default_sampler.name
|
||||||
|
|
||||||
|
for scheduler in sd_schedulers.schedulers:
|
||||||
|
name_options = [scheduler.label, scheduler.name, *(scheduler.aliases or [])]
|
||||||
|
|
||||||
|
for name_option in name_options:
|
||||||
|
if name.endswith(" " + name_option):
|
||||||
|
found_scheduler = scheduler
|
||||||
|
name = name[0:-(len(name_option) + 1)]
|
||||||
|
break
|
||||||
|
|
||||||
|
sampler = sd_samplers.all_samplers_map.get(name, default_sampler)
|
||||||
|
|
||||||
|
# revert back to Automatic if it's the default scheduler for the selected sampler
|
||||||
|
if sampler.options.get('scheduler', None) == found_scheduler.name:
|
||||||
|
found_scheduler = sd_schedulers.schedulers[0]
|
||||||
|
|
||||||
|
return sampler.name, found_scheduler.label
|
||||||
|
|
||||||
|
|
||||||
|
class ScriptSampler(scripts.ScriptBuiltinUI):
|
||||||
|
section = "sampler"
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.steps = None
|
||||||
|
self.sampler_name = None
|
||||||
|
self.scheduler = None
|
||||||
|
|
||||||
|
def title(self):
|
||||||
|
return "Sampler"
|
||||||
|
|
||||||
|
def ui(self, is_img2img):
|
||||||
|
sampler_names = [x.name for x in sd_samplers.visible_samplers()]
|
||||||
|
scheduler_names = [x.label for x in sd_schedulers.schedulers]
|
||||||
|
|
||||||
|
if shared.opts.samplers_in_dropdown:
|
||||||
|
with FormRow(elem_id=f"sampler_selection_{self.tabname}"):
|
||||||
|
self.sampler_name = gr.Dropdown(label='Sampling method', elem_id=f"{self.tabname}_sampling", choices=sampler_names, value=sampler_names[0])
|
||||||
|
self.scheduler = gr.Dropdown(label='Schedule type', elem_id=f"{self.tabname}_scheduler", choices=scheduler_names, value=scheduler_names[0])
|
||||||
|
self.steps = gr.Slider(minimum=1, maximum=150, step=1, elem_id=f"{self.tabname}_steps", label="Sampling steps", value=20)
|
||||||
|
else:
|
||||||
|
with FormGroup(elem_id=f"sampler_selection_{self.tabname}"):
|
||||||
|
self.steps = gr.Slider(minimum=1, maximum=150, step=1, elem_id=f"{self.tabname}_steps", label="Sampling steps", value=20)
|
||||||
|
self.sampler_name = gr.Radio(label='Sampling method', elem_id=f"{self.tabname}_sampling", choices=sampler_names, value=sampler_names[0])
|
||||||
|
self.scheduler = gr.Dropdown(label='Schedule type', elem_id=f"{self.tabname}_scheduler", choices=scheduler_names, value=scheduler_names[0])
|
||||||
|
|
||||||
|
self.infotext_fields = [
|
||||||
|
PasteField(self.steps, "Steps", api="steps"),
|
||||||
|
PasteField(self.sampler_name, get_sampler_from_infotext, api="sampler_name"),
|
||||||
|
PasteField(self.scheduler, get_scheduler_from_infotext, api="scheduler"),
|
||||||
|
]
|
||||||
|
|
||||||
|
return self.steps, self.sampler_name, self.scheduler
|
||||||
|
|
||||||
|
def setup(self, p, steps, sampler_name, scheduler):
|
||||||
|
p.steps = steps
|
||||||
|
p.sampler_name = sampler_name
|
||||||
|
p.scheduler = scheduler
|
@ -352,6 +352,9 @@ class ScriptBuiltinUI(Script):
|
|||||||
|
|
||||||
return f'{tabname}{item_id}'
|
return f'{tabname}{item_id}'
|
||||||
|
|
||||||
|
def show(self, is_img2img):
|
||||||
|
return AlwaysVisible
|
||||||
|
|
||||||
|
|
||||||
current_basedir = paths.script_path
|
current_basedir = paths.script_path
|
||||||
|
|
||||||
|
@ -1,7 +1,10 @@
|
|||||||
from modules import sd_samplers_kdiffusion, sd_samplers_timesteps, sd_samplers_lcm, shared
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from modules import sd_samplers_kdiffusion, sd_samplers_timesteps, sd_samplers_lcm, shared, sd_samplers_common
|
||||||
|
|
||||||
# imports for functions that previously were here and are used by other modules
|
# imports for functions that previously were here and are used by other modules
|
||||||
from modules.sd_samplers_common import samples_to_image_grid, sample_to_image # noqa: F401
|
samples_to_image_grid = sd_samplers_common.samples_to_image_grid
|
||||||
|
sample_to_image = sd_samplers_common.sample_to_image
|
||||||
|
|
||||||
all_samplers = [
|
all_samplers = [
|
||||||
*sd_samplers_kdiffusion.samplers_data_k_diffusion,
|
*sd_samplers_kdiffusion.samplers_data_k_diffusion,
|
||||||
@ -10,8 +13,8 @@ all_samplers = [
|
|||||||
]
|
]
|
||||||
all_samplers_map = {x.name: x for x in all_samplers}
|
all_samplers_map = {x.name: x for x in all_samplers}
|
||||||
|
|
||||||
samplers = []
|
samplers: list[sd_samplers_common.SamplerData] = []
|
||||||
samplers_for_img2img = []
|
samplers_for_img2img: list[sd_samplers_common.SamplerData] = []
|
||||||
samplers_map = {}
|
samplers_map = {}
|
||||||
samplers_hidden = {}
|
samplers_hidden = {}
|
||||||
|
|
||||||
@ -57,4 +60,8 @@ def visible_sampler_names():
|
|||||||
return [x.name for x in samplers if x.name not in samplers_hidden]
|
return [x.name for x in samplers if x.name not in samplers_hidden]
|
||||||
|
|
||||||
|
|
||||||
|
def visible_samplers():
|
||||||
|
return [x for x in samplers if x.name not in samplers_hidden]
|
||||||
|
|
||||||
|
|
||||||
set_samplers()
|
set_samplers()
|
||||||
|
@ -1,12 +0,0 @@
|
|||||||
import torch
|
|
||||||
|
|
||||||
|
|
||||||
def sgm_uniform(n, sigma_min, sigma_max, inner_model, device):
|
|
||||||
start = inner_model.sigma_to_t(torch.tensor(sigma_max))
|
|
||||||
end = inner_model.sigma_to_t(torch.tensor(sigma_min))
|
|
||||||
sigs = [
|
|
||||||
inner_model.t_to_sigma(ts)
|
|
||||||
for ts in torch.linspace(start, end, n)[:-1]
|
|
||||||
]
|
|
||||||
sigs += [0.0]
|
|
||||||
return torch.FloatTensor(sigs).to(device)
|
|
@ -1,41 +1,28 @@
|
|||||||
import torch
|
import torch
|
||||||
import inspect
|
import inspect
|
||||||
import k_diffusion.sampling
|
import k_diffusion.sampling
|
||||||
from modules import sd_samplers_common, sd_samplers_extra, sd_samplers_cfg_denoiser
|
from modules import sd_samplers_common, sd_samplers_extra, sd_samplers_cfg_denoiser, sd_schedulers
|
||||||
from modules.sd_samplers_cfg_denoiser import CFGDenoiser # noqa: F401
|
from modules.sd_samplers_cfg_denoiser import CFGDenoiser # noqa: F401
|
||||||
from modules.sd_samplers_custom_schedulers import sgm_uniform
|
|
||||||
from modules.script_callbacks import ExtraNoiseParams, extra_noise_callback
|
from modules.script_callbacks import ExtraNoiseParams, extra_noise_callback
|
||||||
|
|
||||||
from modules.shared import opts
|
from modules.shared import opts
|
||||||
import modules.shared as shared
|
import modules.shared as shared
|
||||||
|
|
||||||
samplers_k_diffusion = [
|
samplers_k_diffusion = [
|
||||||
('DPM++ 2M Karras', 'sample_dpmpp_2m', ['k_dpmpp_2m_ka'], {'scheduler': 'karras'}),
|
('DPM++ 2M', 'sample_dpmpp_2m', ['k_dpmpp_2m'], {'scheduler': 'karras'}),
|
||||||
('DPM++ SDE Karras', 'sample_dpmpp_sde', ['k_dpmpp_sde_ka'], {'scheduler': 'karras', "second_order": True, "brownian_noise": True}),
|
('DPM++ SDE', 'sample_dpmpp_sde', ['k_dpmpp_sde'], {'scheduler': 'karras', "second_order": True, "brownian_noise": True}),
|
||||||
('DPM++ 2M SDE Exponential', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_exp'], {'scheduler': 'exponential', "brownian_noise": True}),
|
('DPM++ 2M SDE', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde'], {'scheduler': 'exponential', "brownian_noise": True}),
|
||||||
('DPM++ 2M SDE Karras', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_ka'], {'scheduler': 'karras', "brownian_noise": True}),
|
('DPM++ 2M SDE Heun', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_heun'], {'scheduler': 'exponential', "brownian_noise": True, "solver_type": "heun"}),
|
||||||
|
('DPM++ 2S a', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a'], {'scheduler': 'karras', "uses_ensd": True, "second_order": True}),
|
||||||
|
('DPM++ 3M SDE', 'sample_dpmpp_3m_sde', ['k_dpmpp_3m_sde'], {'scheduler': 'exponential', 'discard_next_to_last_sigma': True, "brownian_noise": True}),
|
||||||
('Euler a', 'sample_euler_ancestral', ['k_euler_a', 'k_euler_ancestral'], {"uses_ensd": True}),
|
('Euler a', 'sample_euler_ancestral', ['k_euler_a', 'k_euler_ancestral'], {"uses_ensd": True}),
|
||||||
('Euler', 'sample_euler', ['k_euler'], {}),
|
('Euler', 'sample_euler', ['k_euler'], {}),
|
||||||
('LMS', 'sample_lms', ['k_lms'], {}),
|
('LMS', 'sample_lms', ['k_lms'], {}),
|
||||||
('Heun', 'sample_heun', ['k_heun'], {"second_order": True}),
|
('Heun', 'sample_heun', ['k_heun'], {"second_order": True}),
|
||||||
('DPM2', 'sample_dpm_2', ['k_dpm_2'], {'discard_next_to_last_sigma': True, "second_order": True}),
|
('DPM2', 'sample_dpm_2', ['k_dpm_2'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True, "second_order": True}),
|
||||||
('DPM2 a', 'sample_dpm_2_ancestral', ['k_dpm_2_a'], {'discard_next_to_last_sigma': True, "uses_ensd": True, "second_order": True}),
|
('DPM2 a', 'sample_dpm_2_ancestral', ['k_dpm_2_a'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True, "uses_ensd": True, "second_order": True}),
|
||||||
('DPM++ 2S a', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a'], {"uses_ensd": True, "second_order": True}),
|
|
||||||
('DPM++ 2M', 'sample_dpmpp_2m', ['k_dpmpp_2m'], {}),
|
|
||||||
('DPM++ SDE', 'sample_dpmpp_sde', ['k_dpmpp_sde'], {"second_order": True, "brownian_noise": True}),
|
|
||||||
('DPM++ 2M SDE', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_ka'], {"brownian_noise": True}),
|
|
||||||
('DPM++ 2M SDE Heun', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_heun'], {"brownian_noise": True, "solver_type": "heun"}),
|
|
||||||
('DPM++ 2M SDE Heun Karras', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_heun_ka'], {'scheduler': 'karras', "brownian_noise": True, "solver_type": "heun"}),
|
|
||||||
('DPM++ 2M SDE Heun Exponential', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_heun_exp'], {'scheduler': 'exponential', "brownian_noise": True, "solver_type": "heun"}),
|
|
||||||
('DPM++ 3M SDE', 'sample_dpmpp_3m_sde', ['k_dpmpp_3m_sde'], {'discard_next_to_last_sigma': True, "brownian_noise": True}),
|
|
||||||
('DPM++ 3M SDE Karras', 'sample_dpmpp_3m_sde', ['k_dpmpp_3m_sde_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True, "brownian_noise": True}),
|
|
||||||
('DPM++ 3M SDE Exponential', 'sample_dpmpp_3m_sde', ['k_dpmpp_3m_sde_exp'], {'scheduler': 'exponential', 'discard_next_to_last_sigma': True, "brownian_noise": True}),
|
|
||||||
('DPM fast', 'sample_dpm_fast', ['k_dpm_fast'], {"uses_ensd": True}),
|
('DPM fast', 'sample_dpm_fast', ['k_dpm_fast'], {"uses_ensd": True}),
|
||||||
('DPM adaptive', 'sample_dpm_adaptive', ['k_dpm_ad'], {"uses_ensd": True}),
|
('DPM adaptive', 'sample_dpm_adaptive', ['k_dpm_ad'], {"uses_ensd": True}),
|
||||||
('LMS Karras', 'sample_lms', ['k_lms_ka'], {'scheduler': 'karras'}),
|
|
||||||
('DPM2 Karras', 'sample_dpm_2', ['k_dpm_2_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True, "uses_ensd": True, "second_order": True}),
|
|
||||||
('DPM2 a Karras', 'sample_dpm_2_ancestral', ['k_dpm_2_a_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True, "uses_ensd": True, "second_order": True}),
|
|
||||||
('DPM++ 2S a Karras', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a_ka'], {'scheduler': 'karras', "uses_ensd": True, "second_order": True}),
|
|
||||||
('Restart', sd_samplers_extra.restart_sampler, ['restart'], {'scheduler': 'karras', "second_order": True}),
|
('Restart', sd_samplers_extra.restart_sampler, ['restart'], {'scheduler': 'karras', "second_order": True}),
|
||||||
]
|
]
|
||||||
|
|
||||||
@ -59,13 +46,7 @@ sampler_extra_params = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
k_diffusion_samplers_map = {x.name: x for x in samplers_data_k_diffusion}
|
k_diffusion_samplers_map = {x.name: x for x in samplers_data_k_diffusion}
|
||||||
k_diffusion_scheduler = {
|
k_diffusion_scheduler = {x.name: x.function for x in sd_schedulers.schedulers}
|
||||||
'Automatic': None,
|
|
||||||
'karras': k_diffusion.sampling.get_sigmas_karras,
|
|
||||||
'exponential': k_diffusion.sampling.get_sigmas_exponential,
|
|
||||||
'polyexponential': k_diffusion.sampling.get_sigmas_polyexponential,
|
|
||||||
'sgm_uniform' : sgm_uniform,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class CFGDenoiserKDiffusion(sd_samplers_cfg_denoiser.CFGDenoiser):
|
class CFGDenoiserKDiffusion(sd_samplers_cfg_denoiser.CFGDenoiser):
|
||||||
@ -98,47 +79,41 @@ class KDiffusionSampler(sd_samplers_common.Sampler):
|
|||||||
|
|
||||||
steps += 1 if discard_next_to_last_sigma else 0
|
steps += 1 if discard_next_to_last_sigma else 0
|
||||||
|
|
||||||
|
scheduler_name = p.scheduler or 'Automatic'
|
||||||
|
if scheduler_name == 'Automatic':
|
||||||
|
scheduler_name = self.config.options.get('scheduler', None)
|
||||||
|
|
||||||
|
scheduler = sd_schedulers.schedulers_map.get(scheduler_name)
|
||||||
|
|
||||||
|
m_sigma_min, m_sigma_max = self.model_wrap.sigmas[0].item(), self.model_wrap.sigmas[-1].item()
|
||||||
|
sigma_min, sigma_max = (0.1, 10) if opts.use_old_karras_scheduler_sigmas else (m_sigma_min, m_sigma_max)
|
||||||
|
|
||||||
if p.sampler_noise_scheduler_override:
|
if p.sampler_noise_scheduler_override:
|
||||||
sigmas = p.sampler_noise_scheduler_override(steps)
|
sigmas = p.sampler_noise_scheduler_override(steps)
|
||||||
elif opts.k_sched_type != "Automatic":
|
elif scheduler is None or scheduler.function is None:
|
||||||
m_sigma_min, m_sigma_max = (self.model_wrap.sigmas[0].item(), self.model_wrap.sigmas[-1].item())
|
sigmas = self.model_wrap.get_sigmas(steps)
|
||||||
sigma_min, sigma_max = (0.1, 10) if opts.use_old_karras_scheduler_sigmas else (m_sigma_min, m_sigma_max)
|
else:
|
||||||
sigmas_kwargs = {
|
sigmas_kwargs = {'sigma_min': sigma_min, 'sigma_max': sigma_max}
|
||||||
'sigma_min': sigma_min,
|
|
||||||
'sigma_max': sigma_max,
|
|
||||||
}
|
|
||||||
|
|
||||||
sigmas_func = k_diffusion_scheduler[opts.k_sched_type]
|
if scheduler.label != 'Automatic':
|
||||||
p.extra_generation_params["Schedule type"] = opts.k_sched_type
|
p.extra_generation_params["Schedule type"] = scheduler.label
|
||||||
|
|
||||||
if opts.sigma_min != m_sigma_min and opts.sigma_min != 0:
|
if opts.sigma_min != 0 and opts.sigma_min != m_sigma_min:
|
||||||
sigmas_kwargs['sigma_min'] = opts.sigma_min
|
sigmas_kwargs['sigma_min'] = opts.sigma_min
|
||||||
p.extra_generation_params["Schedule min sigma"] = opts.sigma_min
|
p.extra_generation_params["Schedule min sigma"] = opts.sigma_min
|
||||||
if opts.sigma_max != m_sigma_max and opts.sigma_max != 0:
|
|
||||||
|
if opts.sigma_max != 0 and opts.sigma_max != m_sigma_max:
|
||||||
sigmas_kwargs['sigma_max'] = opts.sigma_max
|
sigmas_kwargs['sigma_max'] = opts.sigma_max
|
||||||
p.extra_generation_params["Schedule max sigma"] = opts.sigma_max
|
p.extra_generation_params["Schedule max sigma"] = opts.sigma_max
|
||||||
|
|
||||||
default_rho = 1. if opts.k_sched_type == "polyexponential" else 7.
|
if scheduler.default_rho != -1 and opts.rho != 0 and opts.rho != scheduler.default_rho:
|
||||||
|
|
||||||
if opts.k_sched_type != 'exponential' and opts.rho != 0 and opts.rho != default_rho:
|
|
||||||
sigmas_kwargs['rho'] = opts.rho
|
sigmas_kwargs['rho'] = opts.rho
|
||||||
p.extra_generation_params["Schedule rho"] = opts.rho
|
p.extra_generation_params["Schedule rho"] = opts.rho
|
||||||
if opts.k_sched_type == 'sgm_uniform':
|
|
||||||
# Ensure the "step" will be target step + 1
|
if scheduler.need_inner_model:
|
||||||
steps += 1 if not discard_next_to_last_sigma else 0
|
|
||||||
sigmas_kwargs['inner_model'] = self.model_wrap
|
sigmas_kwargs['inner_model'] = self.model_wrap
|
||||||
sigmas_kwargs.pop('rho', None)
|
|
||||||
|
|
||||||
sigmas = sigmas_func(n=steps, **sigmas_kwargs, device=shared.device)
|
sigmas = scheduler.function(n=steps, **sigmas_kwargs, device=shared.device)
|
||||||
elif self.config is not None and self.config.options.get('scheduler', None) == 'karras':
|
|
||||||
sigma_min, sigma_max = (0.1, 10) if opts.use_old_karras_scheduler_sigmas else (self.model_wrap.sigmas[0].item(), self.model_wrap.sigmas[-1].item())
|
|
||||||
|
|
||||||
sigmas = k_diffusion.sampling.get_sigmas_karras(n=steps, sigma_min=sigma_min, sigma_max=sigma_max, device=shared.device)
|
|
||||||
elif self.config is not None and self.config.options.get('scheduler', None) == 'exponential':
|
|
||||||
m_sigma_min, m_sigma_max = (self.model_wrap.sigmas[0].item(), self.model_wrap.sigmas[-1].item())
|
|
||||||
sigmas = k_diffusion.sampling.get_sigmas_exponential(n=steps, sigma_min=m_sigma_min, sigma_max=m_sigma_max, device=shared.device)
|
|
||||||
else:
|
|
||||||
sigmas = self.model_wrap.get_sigmas(steps)
|
|
||||||
|
|
||||||
if discard_next_to_last_sigma:
|
if discard_next_to_last_sigma:
|
||||||
sigmas = torch.cat([sigmas[:-2], sigmas[-1:]])
|
sigmas = torch.cat([sigmas[:-2], sigmas[-1:]])
|
||||||
|
43
modules/sd_schedulers.py
Normal file
43
modules/sd_schedulers.py
Normal file
@ -0,0 +1,43 @@
|
|||||||
|
import dataclasses
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
import k_diffusion
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass
|
||||||
|
class Scheduler:
|
||||||
|
name: str
|
||||||
|
label: str
|
||||||
|
function: any
|
||||||
|
|
||||||
|
default_rho: float = -1
|
||||||
|
need_inner_model: bool = False
|
||||||
|
aliases: list = None
|
||||||
|
|
||||||
|
|
||||||
|
def uniform(n, sigma_min, sigma_max, inner_model, device):
|
||||||
|
return inner_model.get_sigmas(n)
|
||||||
|
|
||||||
|
|
||||||
|
def sgm_uniform(n, sigma_min, sigma_max, inner_model, device):
|
||||||
|
start = inner_model.sigma_to_t(torch.tensor(sigma_max))
|
||||||
|
end = inner_model.sigma_to_t(torch.tensor(sigma_min))
|
||||||
|
sigs = [
|
||||||
|
inner_model.t_to_sigma(ts)
|
||||||
|
for ts in torch.linspace(start, end, n + 1)[:-1]
|
||||||
|
]
|
||||||
|
sigs += [0.0]
|
||||||
|
return torch.FloatTensor(sigs).to(device)
|
||||||
|
|
||||||
|
|
||||||
|
schedulers = [
|
||||||
|
Scheduler('automatic', 'Automatic', None),
|
||||||
|
Scheduler('uniform', 'Uniform', uniform, need_inner_model=True),
|
||||||
|
Scheduler('karras', 'Karras', k_diffusion.sampling.get_sigmas_karras, default_rho=7.0),
|
||||||
|
Scheduler('exponential', 'Exponential', k_diffusion.sampling.get_sigmas_exponential),
|
||||||
|
Scheduler('polyexponential', 'Polyexponential', k_diffusion.sampling.get_sigmas_polyexponential, default_rho=1.0),
|
||||||
|
Scheduler('sgm_uniform', 'SGM Uniform', sgm_uniform, need_inner_model=True, aliases=["SGMUniform"]),
|
||||||
|
]
|
||||||
|
|
||||||
|
schedulers_map = {**{x.name: x for x in schedulers}, **{x.label: x for x in schedulers}}
|
@ -368,7 +368,6 @@ options_templates.update(options_section(('sampler-params', "Sampler parameters"
|
|||||||
's_tmin': OptionInfo(0.0, "sigma tmin", gr.Slider, {"minimum": 0.0, "maximum": 10.0, "step": 0.01}, infotext='Sigma tmin').info('enable stochasticity; start value of the sigma range; only applies to Euler, Heun, and DPM2'),
|
's_tmin': OptionInfo(0.0, "sigma tmin", gr.Slider, {"minimum": 0.0, "maximum": 10.0, "step": 0.01}, infotext='Sigma tmin').info('enable stochasticity; start value of the sigma range; only applies to Euler, Heun, and DPM2'),
|
||||||
's_tmax': OptionInfo(0.0, "sigma tmax", gr.Slider, {"minimum": 0.0, "maximum": 999.0, "step": 0.01}, infotext='Sigma tmax').info("0 = inf; end value of the sigma range; only applies to Euler, Heun, and DPM2"),
|
's_tmax': OptionInfo(0.0, "sigma tmax", gr.Slider, {"minimum": 0.0, "maximum": 999.0, "step": 0.01}, infotext='Sigma tmax').info("0 = inf; end value of the sigma range; only applies to Euler, Heun, and DPM2"),
|
||||||
's_noise': OptionInfo(1.0, "sigma noise", gr.Slider, {"minimum": 0.0, "maximum": 1.1, "step": 0.001}, infotext='Sigma noise').info('amount of additional noise to counteract loss of detail during sampling'),
|
's_noise': OptionInfo(1.0, "sigma noise", gr.Slider, {"minimum": 0.0, "maximum": 1.1, "step": 0.001}, infotext='Sigma noise').info('amount of additional noise to counteract loss of detail during sampling'),
|
||||||
'k_sched_type': OptionInfo("Automatic", "Scheduler type", gr.Dropdown, {"choices": ["Automatic", "karras", "exponential", "polyexponential", "sgm_uniform"]}, infotext='Schedule type').info("lets you override the noise schedule for k-diffusion samplers; choosing Automatic disables the three parameters below"),
|
|
||||||
'sigma_min': OptionInfo(0.0, "sigma min", gr.Number, infotext='Schedule min sigma').info("0 = default (~0.03); minimum noise strength for k-diffusion noise scheduler"),
|
'sigma_min': OptionInfo(0.0, "sigma min", gr.Number, infotext='Schedule min sigma').info("0 = default (~0.03); minimum noise strength for k-diffusion noise scheduler"),
|
||||||
'sigma_max': OptionInfo(0.0, "sigma max", gr.Number, infotext='Schedule max sigma').info("0 = default (~14.6); maximum noise strength for k-diffusion noise scheduler"),
|
'sigma_max': OptionInfo(0.0, "sigma max", gr.Number, infotext='Schedule max sigma').info("0 = default (~14.6); maximum noise strength for k-diffusion noise scheduler"),
|
||||||
'rho': OptionInfo(0.0, "rho", gr.Number, infotext='Schedule rho').info("0 = default (7 for karras, 1 for polyexponential); higher values result in a steeper noise schedule (decreases faster)"),
|
'rho': OptionInfo(0.0, "rho", gr.Number, infotext='Schedule rho').info("0 = default (7 for karras, 1 for polyexponential); higher values result in a steeper noise schedule (decreases faster)"),
|
||||||
|
@ -11,7 +11,7 @@ from PIL import Image
|
|||||||
import gradio as gr
|
import gradio as gr
|
||||||
|
|
||||||
|
|
||||||
def txt2img_create_processing(id_task: str, request: gr.Request, prompt: str, negative_prompt: str, prompt_styles, steps: int, sampler_name: str, n_iter: int, batch_size: int, cfg_scale: float, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, hr_checkpoint_name: str, hr_sampler_name: str, hr_prompt: str, hr_negative_prompt, override_settings_texts, *args, force_enable_hr=False):
|
def txt2img_create_processing(id_task: str, request: gr.Request, prompt: str, negative_prompt: str, prompt_styles, n_iter: int, batch_size: int, cfg_scale: float, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, hr_checkpoint_name: str, hr_sampler_name: str, hr_prompt: str, hr_negative_prompt, override_settings_texts, *args, force_enable_hr=False):
|
||||||
override_settings = create_override_settings_dict(override_settings_texts)
|
override_settings = create_override_settings_dict(override_settings_texts)
|
||||||
|
|
||||||
if force_enable_hr:
|
if force_enable_hr:
|
||||||
@ -24,10 +24,8 @@ def txt2img_create_processing(id_task: str, request: gr.Request, prompt: str, ne
|
|||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
styles=prompt_styles,
|
styles=prompt_styles,
|
||||||
negative_prompt=negative_prompt,
|
negative_prompt=negative_prompt,
|
||||||
sampler_name=sampler_name,
|
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
n_iter=n_iter,
|
n_iter=n_iter,
|
||||||
steps=steps,
|
|
||||||
cfg_scale=cfg_scale,
|
cfg_scale=cfg_scale,
|
||||||
width=width,
|
width=width,
|
||||||
height=height,
|
height=height,
|
||||||
|
@ -12,7 +12,7 @@ import numpy as np
|
|||||||
from PIL import Image, PngImagePlugin # noqa: F401
|
from PIL import Image, PngImagePlugin # noqa: F401
|
||||||
from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call
|
from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call
|
||||||
|
|
||||||
from modules import gradio_extensons # noqa: F401
|
from modules import gradio_extensons, sd_schedulers # noqa: F401
|
||||||
from modules import sd_hijack, sd_models, script_callbacks, ui_extensions, deepbooru, extra_networks, ui_common, ui_postprocessing, progress, ui_loadsave, shared_items, ui_settings, timer, sysinfo, ui_checkpoint_merger, scripts, sd_samplers, processing, ui_extra_networks, ui_toprow, launch_utils
|
from modules import sd_hijack, sd_models, script_callbacks, ui_extensions, deepbooru, extra_networks, ui_common, ui_postprocessing, progress, ui_loadsave, shared_items, ui_settings, timer, sysinfo, ui_checkpoint_merger, scripts, sd_samplers, processing, ui_extra_networks, ui_toprow, launch_utils
|
||||||
from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML, InputAccordion, ResizeHandleRow
|
from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML, InputAccordion, ResizeHandleRow
|
||||||
from modules.paths import script_path
|
from modules.paths import script_path
|
||||||
@ -229,19 +229,6 @@ def create_output_panel(tabname, outdir, toprow=None):
|
|||||||
return ui_common.create_output_panel(tabname, outdir, toprow)
|
return ui_common.create_output_panel(tabname, outdir, toprow)
|
||||||
|
|
||||||
|
|
||||||
def create_sampler_and_steps_selection(choices, tabname):
|
|
||||||
if opts.samplers_in_dropdown:
|
|
||||||
with FormRow(elem_id=f"sampler_selection_{tabname}"):
|
|
||||||
sampler_name = gr.Dropdown(label='Sampling method', elem_id=f"{tabname}_sampling", choices=choices, value=choices[0])
|
|
||||||
steps = gr.Slider(minimum=1, maximum=150, step=1, elem_id=f"{tabname}_steps", label="Sampling steps", value=20)
|
|
||||||
else:
|
|
||||||
with FormGroup(elem_id=f"sampler_selection_{tabname}"):
|
|
||||||
steps = gr.Slider(minimum=1, maximum=150, step=1, elem_id=f"{tabname}_steps", label="Sampling steps", value=20)
|
|
||||||
sampler_name = gr.Radio(label='Sampling method', elem_id=f"{tabname}_sampling", choices=choices, value=choices[0])
|
|
||||||
|
|
||||||
return steps, sampler_name
|
|
||||||
|
|
||||||
|
|
||||||
def ordered_ui_categories():
|
def ordered_ui_categories():
|
||||||
user_order = {x.strip(): i * 2 + 1 for i, x in enumerate(shared.opts.ui_reorder_list)}
|
user_order = {x.strip(): i * 2 + 1 for i, x in enumerate(shared.opts.ui_reorder_list)}
|
||||||
|
|
||||||
@ -295,9 +282,6 @@ def create_ui():
|
|||||||
if category == "prompt":
|
if category == "prompt":
|
||||||
toprow.create_inline_toprow_prompts()
|
toprow.create_inline_toprow_prompts()
|
||||||
|
|
||||||
if category == "sampler":
|
|
||||||
steps, sampler_name = create_sampler_and_steps_selection(sd_samplers.visible_sampler_names(), "txt2img")
|
|
||||||
|
|
||||||
elif category == "dimensions":
|
elif category == "dimensions":
|
||||||
with FormRow():
|
with FormRow():
|
||||||
with gr.Column(elem_id="txt2img_column_size", scale=4):
|
with gr.Column(elem_id="txt2img_column_size", scale=4):
|
||||||
@ -396,8 +380,6 @@ def create_ui():
|
|||||||
toprow.prompt,
|
toprow.prompt,
|
||||||
toprow.negative_prompt,
|
toprow.negative_prompt,
|
||||||
toprow.ui_styles.dropdown,
|
toprow.ui_styles.dropdown,
|
||||||
steps,
|
|
||||||
sampler_name,
|
|
||||||
batch_count,
|
batch_count,
|
||||||
batch_size,
|
batch_size,
|
||||||
cfg_scale,
|
cfg_scale,
|
||||||
@ -461,8 +443,6 @@ def create_ui():
|
|||||||
txt2img_paste_fields = [
|
txt2img_paste_fields = [
|
||||||
PasteField(toprow.prompt, "Prompt", api="prompt"),
|
PasteField(toprow.prompt, "Prompt", api="prompt"),
|
||||||
PasteField(toprow.negative_prompt, "Negative prompt", api="negative_prompt"),
|
PasteField(toprow.negative_prompt, "Negative prompt", api="negative_prompt"),
|
||||||
PasteField(steps, "Steps", api="steps"),
|
|
||||||
PasteField(sampler_name, "Sampler", api="sampler_name"),
|
|
||||||
PasteField(cfg_scale, "CFG scale", api="cfg_scale"),
|
PasteField(cfg_scale, "CFG scale", api="cfg_scale"),
|
||||||
PasteField(width, "Size-1", api="width"),
|
PasteField(width, "Size-1", api="width"),
|
||||||
PasteField(height, "Size-2", api="height"),
|
PasteField(height, "Size-2", api="height"),
|
||||||
@ -488,11 +468,13 @@ def create_ui():
|
|||||||
paste_button=toprow.paste, tabname="txt2img", source_text_component=toprow.prompt, source_image_component=None,
|
paste_button=toprow.paste, tabname="txt2img", source_text_component=toprow.prompt, source_image_component=None,
|
||||||
))
|
))
|
||||||
|
|
||||||
|
steps = scripts.scripts_txt2img.script('Sampler').steps
|
||||||
|
|
||||||
txt2img_preview_params = [
|
txt2img_preview_params = [
|
||||||
toprow.prompt,
|
toprow.prompt,
|
||||||
toprow.negative_prompt,
|
toprow.negative_prompt,
|
||||||
steps,
|
steps,
|
||||||
sampler_name,
|
scripts.scripts_txt2img.script('Sampler').sampler_name,
|
||||||
cfg_scale,
|
cfg_scale,
|
||||||
scripts.scripts_txt2img.script('Seed').seed,
|
scripts.scripts_txt2img.script('Seed').seed,
|
||||||
width,
|
width,
|
||||||
@ -623,9 +605,6 @@ def create_ui():
|
|||||||
with FormRow():
|
with FormRow():
|
||||||
resize_mode = gr.Radio(label="Resize mode", elem_id="resize_mode", choices=["Just resize", "Crop and resize", "Resize and fill", "Just resize (latent upscale)"], type="index", value="Just resize")
|
resize_mode = gr.Radio(label="Resize mode", elem_id="resize_mode", choices=["Just resize", "Crop and resize", "Resize and fill", "Just resize (latent upscale)"], type="index", value="Just resize")
|
||||||
|
|
||||||
if category == "sampler":
|
|
||||||
steps, sampler_name = create_sampler_and_steps_selection(sd_samplers.visible_sampler_names(), "img2img")
|
|
||||||
|
|
||||||
elif category == "dimensions":
|
elif category == "dimensions":
|
||||||
with FormRow():
|
with FormRow():
|
||||||
with gr.Column(elem_id="img2img_column_size", scale=4):
|
with gr.Column(elem_id="img2img_column_size", scale=4):
|
||||||
@ -754,8 +733,6 @@ def create_ui():
|
|||||||
inpaint_color_sketch_orig,
|
inpaint_color_sketch_orig,
|
||||||
init_img_inpaint,
|
init_img_inpaint,
|
||||||
init_mask_inpaint,
|
init_mask_inpaint,
|
||||||
steps,
|
|
||||||
sampler_name,
|
|
||||||
mask_blur,
|
mask_blur,
|
||||||
mask_alpha,
|
mask_alpha,
|
||||||
inpainting_fill,
|
inpainting_fill,
|
||||||
@ -840,6 +817,8 @@ def create_ui():
|
|||||||
**interrogate_args,
|
**interrogate_args,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
steps = scripts.scripts_img2img.script('Sampler').steps
|
||||||
|
|
||||||
toprow.ui_styles.dropdown.change(fn=wrap_queued_call(update_token_counter), inputs=[toprow.prompt, steps, toprow.ui_styles.dropdown], outputs=[toprow.token_counter])
|
toprow.ui_styles.dropdown.change(fn=wrap_queued_call(update_token_counter), inputs=[toprow.prompt, steps, toprow.ui_styles.dropdown], outputs=[toprow.token_counter])
|
||||||
toprow.ui_styles.dropdown.change(fn=wrap_queued_call(update_negative_prompt_token_counter), inputs=[toprow.negative_prompt, steps, toprow.ui_styles.dropdown], outputs=[toprow.negative_token_counter])
|
toprow.ui_styles.dropdown.change(fn=wrap_queued_call(update_negative_prompt_token_counter), inputs=[toprow.negative_prompt, steps, toprow.ui_styles.dropdown], outputs=[toprow.negative_token_counter])
|
||||||
toprow.token_button.click(fn=update_token_counter, inputs=[toprow.prompt, steps, toprow.ui_styles.dropdown], outputs=[toprow.token_counter])
|
toprow.token_button.click(fn=update_token_counter, inputs=[toprow.prompt, steps, toprow.ui_styles.dropdown], outputs=[toprow.token_counter])
|
||||||
@ -848,8 +827,6 @@ def create_ui():
|
|||||||
img2img_paste_fields = [
|
img2img_paste_fields = [
|
||||||
(toprow.prompt, "Prompt"),
|
(toprow.prompt, "Prompt"),
|
||||||
(toprow.negative_prompt, "Negative prompt"),
|
(toprow.negative_prompt, "Negative prompt"),
|
||||||
(steps, "Steps"),
|
|
||||||
(sampler_name, "Sampler"),
|
|
||||||
(cfg_scale, "CFG scale"),
|
(cfg_scale, "CFG scale"),
|
||||||
(image_cfg_scale, "Image CFG scale"),
|
(image_cfg_scale, "Image CFG scale"),
|
||||||
(width, "Size-1"),
|
(width, "Size-1"),
|
||||||
|
@ -11,7 +11,7 @@ import numpy as np
|
|||||||
import modules.scripts as scripts
|
import modules.scripts as scripts
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
|
|
||||||
from modules import images, sd_samplers, processing, sd_models, sd_vae, sd_samplers_kdiffusion, errors
|
from modules import images, sd_samplers, processing, sd_models, sd_vae, sd_schedulers, errors
|
||||||
from modules.processing import process_images, Processed, StableDiffusionProcessingTxt2Img
|
from modules.processing import process_images, Processed, StableDiffusionProcessingTxt2Img
|
||||||
from modules.shared import opts, state
|
from modules.shared import opts, state
|
||||||
import modules.shared as shared
|
import modules.shared as shared
|
||||||
@ -248,7 +248,7 @@ axis_options = [
|
|||||||
AxisOption("Sigma min", float, apply_field("s_tmin")),
|
AxisOption("Sigma min", float, apply_field("s_tmin")),
|
||||||
AxisOption("Sigma max", float, apply_field("s_tmax")),
|
AxisOption("Sigma max", float, apply_field("s_tmax")),
|
||||||
AxisOption("Sigma noise", float, apply_field("s_noise")),
|
AxisOption("Sigma noise", float, apply_field("s_noise")),
|
||||||
AxisOption("Schedule type", str, apply_override("k_sched_type"), choices=lambda: list(sd_samplers_kdiffusion.k_diffusion_scheduler)),
|
AxisOption("Schedule type", str, apply_field("scheduler"), choices=lambda: [x.label for x in sd_schedulers.schedulers]),
|
||||||
AxisOption("Schedule min sigma", float, apply_override("sigma_min")),
|
AxisOption("Schedule min sigma", float, apply_override("sigma_min")),
|
||||||
AxisOption("Schedule max sigma", float, apply_override("sigma_max")),
|
AxisOption("Schedule max sigma", float, apply_override("sigma_max")),
|
||||||
AxisOption("Schedule rho", float, apply_override("rho")),
|
AxisOption("Schedule rho", float, apply_override("rho")),
|
||||||
|
Loading…
Reference in New Issue
Block a user