stable-diffusion-webui/modules/sd_samplers.py

Ignoring revisions in .git-blame-ignore-revs. Click here to bypass and see the normal blame view.

55 lines
1.5 KiB
Python
Raw Normal View History

2023-08-09 02:07:18 +08:00
from modules import sd_samplers_kdiffusion, sd_samplers_timesteps, shared
# imports for functions that previously were here and are used by other modules
2023-05-10 14:02:23 +08:00
from modules.sd_samplers_common import samples_to_image_grid, sample_to_image # noqa: F401
2022-09-03 22:21:15 +08:00
# Full registry of sampler definitions: the k-diffusion samplers first, then
# the timestep-based ones. find_sampler_config() returns the first entry when
# called with name=None, so this ordering is significant.
all_samplers = (
    list(sd_samplers_kdiffusion.samplers_data_k_diffusion)
    + list(sd_samplers_timesteps.samplers_data_timesteps)
)

# Exact (case-sensitive) name -> registry entry, covering every sampler,
# including ones the user has hidden.
all_samplers_map = dict((data.name, data) for data in all_samplers)

# Filled in by set_samplers(): registry entries whose names are not listed in
# shared.opts.hide_samplers.
samplers = []
samplers_for_img2img = []

# Filled in by set_samplers(): lowercased sampler name or alias -> canonical
# sampler name.
samplers_map = {}
def find_sampler_config(name):
    """Look up a sampler's registry entry by its exact registered name.

    Args:
        name: the sampler name as registered in ``all_samplers``. The match
            is case-sensitive, and aliases (``samplers_map``) are not
            consulted. Pass None to get the default entry.

    Returns:
        The matching entry from ``all_samplers_map``, or None if no sampler
        has that name. When ``name`` is None, the first entry of
        ``all_samplers`` is returned.
    """
    if name is None:
        # No explicit choice: default to the first registered sampler.
        return all_samplers[0]

    return all_samplers_map.get(name)
def create_sampler(name, model):
    """Instantiate the sampler registered under ``name`` for ``model``.

    Args:
        name: exact registered sampler name, or None to use the first entry
            of ``all_samplers`` (resolved through find_sampler_config).
        model: the loaded model. This function reads only its ``is_sdxl``
            attribute; the object is then passed to the sampler constructor.

    Returns:
        The object built by ``config.constructor(model)``, with the registry
        entry attached to it as ``config``.

    Raises:
        AssertionError: if ``name`` does not match any registered sampler.
        Exception: if the sampler's options set ``no_sdxl`` and ``model``
            is an SDXL model.
    """
    config = find_sampler_config(name)
    # Use an explicit raise instead of an ``assert`` statement. Asserts are
    # stripped under ``python -O``, which would let an unknown name fall
    # through to ``config.options`` and fail with an AttributeError on None.
    # AssertionError and its message are kept, so existing callers see the
    # same error.
    if config is None:
        raise AssertionError(f'bad sampler name: {name}')

    if model.is_sdxl and config.options.get("no_sdxl", False):
        raise Exception(f"Sampler {config.name} is not supported for SDXL")

    sampler = config.constructor(model)
    # Keep a back-reference to the registry entry on the sampler instance.
    sampler.config = config

    return sampler
def set_samplers():
    """Rebuild the visible sampler lists and the name/alias lookup table.

    Rebinds the module-level ``samplers`` and ``samplers_for_img2img`` to new
    lists. Each holds every entry of ``all_samplers`` whose name is not in
    ``shared.opts.hide_samplers``; both use the same filter but are separate
    list objects.

    ``samplers_map`` is cleared and refilled in place from the full registry,
    hidden samplers included. It maps each lowercased sampler name and alias
    to the canonical sampler name; later entries overwrite earlier ones on
    key collisions.
    """
    global samplers, samplers_for_img2img

    hidden_names = set(shared.opts.hide_samplers)

    def _visible():
        # Fresh list on every call, so the two module lists are independent.
        return [entry for entry in all_samplers if entry.name not in hidden_names]

    samplers = _visible()
    samplers_for_img2img = _visible()

    samplers_map.clear()
    for entry in all_samplers:
        canonical = entry.name
        # The name is registered first, then each alias, in declaration order.
        for key in (canonical, *entry.aliases):
            samplers_map[key.lower()] = canonical
# Populate samplers, samplers_for_img2img and samplers_map once at import time.
# This uses whatever shared.opts.hide_samplers holds at that moment.
# NOTE(review): presumably set_samplers() is called again elsewhere when that
# option changes — confirm.
set_samplers()