Mirror of https://github.com/AUTOMATIC1111/stable-diffusion-webui.git (synced 2025-04-28 15:59:07 +08:00)

Commit 54c3e5c913: Merge branch 'dev' into refiner
@@ -1,5 +1,7 @@
+import math
+
 import gradio as gr
-from modules import scripts, shared, ui_components, ui_settings
+from modules import scripts, shared, ui_components, ui_settings, generation_parameters_copypaste
 from modules.ui_components import FormColumn


@@ -19,18 +21,37 @@ class ExtraOptionsSection(scripts.Script):
     def ui(self, is_img2img):
         self.comps = []
         self.setting_names = []
+        self.infotext_fields = []
+
+        mapping = {k: v for v, k in generation_parameters_copypaste.infotext_to_setting_name_mapping}

         with gr.Blocks() as interface:
-            with gr.Accordion("Options", open=False) if shared.opts.extra_options_accordion and shared.opts.extra_options else gr.Group(), gr.Row():
-                for setting_name in shared.opts.extra_options:
-                    with FormColumn():
-                        comp = ui_settings.create_setting_component(setting_name)
-
-                    self.comps.append(comp)
-                    self.setting_names.append(setting_name)
+            with gr.Accordion("Options", open=False) if shared.opts.extra_options_accordion and shared.opts.extra_options else gr.Group():
+
+                row_count = math.ceil(len(shared.opts.extra_options) / shared.opts.extra_options_cols)
+
+                for row in range(row_count):
+                    with gr.Row():
+                        for col in range(shared.opts.extra_options_cols):
+                            index = row * shared.opts.extra_options_cols + col
+                            if index >= len(shared.opts.extra_options):
+                                break
+
+                            setting_name = shared.opts.extra_options[index]
+
+                            with FormColumn():
+                                comp = ui_settings.create_setting_component(setting_name)
+
+                            self.comps.append(comp)
+                            self.setting_names.append(setting_name)
+
+                            setting_infotext_name = mapping.get(setting_name)
+                            if setting_infotext_name is not None:
+                                self.infotext_fields.append((comp, setting_infotext_name))

         def get_settings_values():
-            return [ui_settings.get_value_for_setting(key) for key in self.setting_names]
+            res = [ui_settings.get_value_for_setting(key) for key in self.setting_names]
+            return res[0] if len(res) == 1 else res

         interface.load(fn=get_settings_values, inputs=[], outputs=self.comps, queue=False, show_progress=False)

@@ -44,5 +65,8 @@ class ExtraOptionsSection(scripts.Script):

 shared.options_templates.update(shared.options_section(('ui', "User interface"), {
     "extra_options": shared.OptionInfo([], "Options in main UI", ui_components.DropdownMulti, lambda: {"choices": list(shared.opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that also appear in txt2img/img2img interfaces").needs_reload_ui(),
-    "extra_options_accordion": shared.OptionInfo(False, "Place options in main UI into an accordion").needs_restart()
+    "extra_options_cols": shared.OptionInfo(1, "Options in main UI - number of columns", gr.Number, {"precision": 0}).needs_reload_ui(),
+    "extra_options_accordion": shared.OptionInfo(False, "Options in main UI - place into an accordion").needs_reload_ui()
 }))
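The rewritten ui() above replaces the single gr.Row() with a rows-by-columns grid, and get_settings_values() now unwraps a single-element result (a Gradio load() callback with exactly one output component expects a bare value rather than a one-item list). A minimal sketch of the same index arithmetic, using a hypothetical list of setting names so it runs without Gradio:

import math

extra_options = ["opt_a", "opt_b", "opt_c", "opt_d", "opt_e"]  # hypothetical setting names
cols = 2

# Same arithmetic as the new ui(): ceil-divide into rows, then map (row, col) -> flat index.
row_count = math.ceil(len(extra_options) / cols)
for row in range(row_count):
    for col in range(cols):
        index = row * cols + col
        if index >= len(extra_options):
            break  # the last row may be only partially filled
        print(f"row {row}, col {col}: {extra_options[index]}")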
@@ -136,6 +136,11 @@ function setupImageForLightbox(e) {
     var event = isFirefox ? 'mousedown' : 'click';

     e.addEventListener(event, function(evt) {
+        if (evt.button == 1) {
+            open(evt.target.src);
+            evt.preventDefault();
+            return;
+        }
         if (!opts.js_modal_lightbox || evt.button != 0) return;

         modalZoomSet(gradioApp().getElementById('modalImage'), opts.js_modal_lightbox_initially_zoomed);
@@ -416,10 +416,15 @@ def connect_paste(button, paste_fields, input_comp, override_settings_component,
         return res

     if override_settings_component is not None:
+        already_handled_fields = {key: 1 for _, key in paste_fields}
+
         def paste_settings(params):
             vals = {}

             for param_name, setting_name in infotext_to_setting_name_mapping:
+                if param_name in already_handled_fields:
+                    continue
+
                 v = params.get(param_name, None)
                 if v is None:
                     continue
@@ -6,7 +6,7 @@ import numpy as np
 from PIL import Image, ImageOps, ImageFilter, ImageEnhance, UnidentifiedImageError
 import gradio as gr

-from modules import sd_samplers, images as imgutil
+from modules import images as imgutil
 from modules.generation_parameters_copypaste import create_override_settings_dict, parse_generation_parameters
 from modules.processing import Processed, StableDiffusionProcessingImg2Img, process_images
 from modules.shared import opts, state

@@ -116,7 +116,7 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=Fal
     process_images(p)


-def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, selected_scale_tab: int, height: int, width: int, scale_by: float, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, img2img_batch_use_png_info: bool, img2img_batch_png_info_props: list, img2img_batch_png_info_dir: str, request: gr.Request, *args):
+def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_name: str, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, selected_scale_tab: int, height: int, width: int, scale_by: float, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, img2img_batch_use_png_info: bool, img2img_batch_png_info_props: list, img2img_batch_png_info_dir: str, request: gr.Request, *args):
     override_settings = create_override_settings_dict(override_settings_texts)

     is_batch = mode == 5

@@ -172,7 +172,7 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s
         seed_resize_from_h=seed_resize_from_h,
         seed_resize_from_w=seed_resize_from_w,
         seed_enable_extras=seed_enable_extras,
-        sampler_name=sd_samplers.samplers_for_img2img[sampler_index].name,
+        sampler_name=sampler_name,
         batch_size=batch_size,
         n_iter=n_iter,
         steps=steps,
@@ -139,6 +139,27 @@ def check_run_python(code: str) -> bool:
     return result.returncode == 0


+def git_fix_workspace(dir, name):
+    run(f'"{git}" -C "{dir}" fetch --refetch --no-auto-gc', f"Fetching all contents for {name}", f"Couldn't fetch {name}", live=True)
+    run(f'"{git}" -C "{dir}" gc --aggressive --prune=now', f"Pruning {name}", f"Couldn't prune {name}", live=True)
+    return
+
+
+def run_git(dir, name, command, desc=None, errdesc=None, custom_env=None, live: bool = default_command_live, autofix=True):
+    try:
+        return run(f'"{git}" -C "{dir}" {command}', desc=desc, errdesc=errdesc, custom_env=custom_env, live=live)
+    except RuntimeError:
+        pass
+
+    if not autofix:
+        return None
+
+    print(f"{errdesc}, attempting autofix...")
+    git_fix_workspace(dir, name)
+
+    return run(f'"{git}" -C "{dir}" {command}', desc=desc, errdesc=errdesc, custom_env=custom_env, live=live)
+
+
 def git_clone(url, dir, name, commithash=None):
     # TODO clone into temporary dir and move if successful

@@ -146,12 +167,14 @@ def git_clone(url, dir, name, commithash=None):
     if commithash is None:
         return

-    current_hash = run(f'"{git}" -C "{dir}" rev-parse HEAD', None, f"Couldn't determine {name}'s hash: {commithash}", live=False).strip()
+    current_hash = run_git(dir, name, 'rev-parse HEAD', None, f"Couldn't determine {name}'s hash: {commithash}", live=False).strip()
     if current_hash == commithash:
         return

-    run(f'"{git}" -C "{dir}" fetch', f"Fetching updates for {name}...", f"Couldn't fetch {name}")
-    run(f'"{git}" -C "{dir}" checkout {commithash}', f"Checking out commit for {name} with hash: {commithash}...", f"Couldn't checkout commit {commithash} for {name}", live=True)
+    run_git('fetch', f"Fetching updates for {name}...", f"Couldn't fetch {name}")
+
+    run_git('checkout', f"Checking out commit for {name} with hash: {commithash}...", f"Couldn't checkout commit {commithash} for {name}", live=True)

     return

     run(f'"{git}" clone "{url}" "{dir}"', f"Cloning {name} into {dir}...", f"Couldn't clone {name}", live=True)
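The new run_git() helper retries a failed git command exactly once after git_fix_workspace() refetches and garbage-collects the repository. A minimal sketch of that retry shape, with run_command and fix_workspace as hypothetical stand-ins for the real run() and git_fix_workspace():

# Sketch of the retry-once-after-autofix pattern used by run_git above.
def run_with_autofix(run_command, fix_workspace, autofix=True):
    try:
        return run_command()
    except RuntimeError:
        pass  # swallow the first failure and decide whether to repair

    if not autofix:
        return None

    fix_workspace()       # refetch everything, then aggressively gc/prune
    return run_command()  # a second failure propagates to the caller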
@@ -1119,9 +1119,6 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):

         img2img_sampler_name = self.hr_sampler_name or self.sampler_name

-        if self.sampler_name in ['PLMS', 'UniPC']:  # PLMS/UniPC do not support img2img so we just silently switch to DDIM
-            img2img_sampler_name = 'DDIM'
-
         self.sampler = sd_samplers.create_sampler(img2img_sampler_name, self.sd_model)

         if self.latent_scale_mode is not None:
@@ -5,7 +5,7 @@ from types import MethodType
 from modules import devices, sd_hijack_optimizations, shared, script_callbacks, errors, sd_unet
 from modules.hypernetworks import hypernetwork
 from modules.shared import cmd_opts
-from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr, sd_hijack_inpainting
+from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr

 import ldm.modules.attention
 import ldm.modules.diffusionmodules.model

@@ -34,8 +34,6 @@ ldm.modules.diffusionmodules.model.print = shared.ldm_print
 ldm.util.print = shared.ldm_print
 ldm.models.diffusion.ddpm.print = shared.ldm_print

-sd_hijack_inpainting.do_inpainting_hijack()
-
 optimizers = []
 current_optimizer: sd_hijack_optimizations.SdOptimization = None
@@ -1,95 +0,0 @@
-import torch
-
-import ldm.models.diffusion.ddpm
-import ldm.models.diffusion.ddim
-import ldm.models.diffusion.plms
-
-from ldm.models.diffusion.ddim import noise_like
-from ldm.models.diffusion.sampling_util import norm_thresholding
-
-
-@torch.no_grad()
-def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
-                  temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
-                  unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None, dynamic_threshold=None):
-    b, *_, device = *x.shape, x.device
-
-    def get_model_output(x, t):
-        if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
-            e_t = self.model.apply_model(x, t, c)
-        else:
-            x_in = torch.cat([x] * 2)
-            t_in = torch.cat([t] * 2)
-
-            if isinstance(c, dict):
-                assert isinstance(unconditional_conditioning, dict)
-                c_in = {}
-                for k in c:
-                    if isinstance(c[k], list):
-                        c_in[k] = [
-                            torch.cat([unconditional_conditioning[k][i], c[k][i]])
-                            for i in range(len(c[k]))
-                        ]
-                    else:
-                        c_in[k] = torch.cat([unconditional_conditioning[k], c[k]])
-            else:
-                c_in = torch.cat([unconditional_conditioning, c])
-
-            e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
-            e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
-
-        if score_corrector is not None:
-            assert self.model.parameterization == "eps"
-            e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
-
-        return e_t
-
-    alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
-    alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
-    sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
-    sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
-
-    def get_x_prev_and_pred_x0(e_t, index):
-        # select parameters corresponding to the currently considered timestep
-        a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
-        a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
-        sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
-        sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
-
-        # current prediction for x_0
-        pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
-        if quantize_denoised:
-            pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
-        if dynamic_threshold is not None:
-            pred_x0 = norm_thresholding(pred_x0, dynamic_threshold)
-        # direction pointing to x_t
-        dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
-        noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
-        if noise_dropout > 0.:
-            noise = torch.nn.functional.dropout(noise, p=noise_dropout)
-        x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
-        return x_prev, pred_x0
-
-    e_t = get_model_output(x, t)
-    if len(old_eps) == 0:
-        # Pseudo Improved Euler (2nd order)
-        x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
-        e_t_next = get_model_output(x_prev, t_next)
-        e_t_prime = (e_t + e_t_next) / 2
-    elif len(old_eps) == 1:
-        # 2nd order Pseudo Linear Multistep (Adams-Bashforth)
-        e_t_prime = (3 * e_t - old_eps[-1]) / 2
-    elif len(old_eps) == 2:
-        # 3nd order Pseudo Linear Multistep (Adams-Bashforth)
-        e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
-    elif len(old_eps) >= 3:
-        # 4nd order Pseudo Linear Multistep (Adams-Bashforth)
-        e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24
-
-    x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)
-
-    return x_prev, pred_x0, e_t
-
-
-def do_inpainting_hijack():
-    ldm.models.diffusion.plms.PLMSSampler.p_sample_plms = p_sample_plms
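The deleted p_sample_plms() advances the noise prediction with Adams-Bashforth linear multistep weights: (3, -1)/2 at second order, (23, -16, 5)/12 at third, and (55, -59, 37, -9)/24 at fourth (the "3nd"/"4nd" labels in the deleted comments are upstream typos for "3rd"/"4th"). A tiny self-contained demo of the same weighting on the scalar ODE y' = -y, purely to illustrate the coefficient pattern:

# Adams-Bashforth steps for y' = f(y), matching the coefficients above.
def ab_step(f, y, old_fs, h):
    fy = f(y)
    if len(old_fs) == 1:
        f_prime = (3 * fy - old_fs[-1]) / 2                                             # AB2
    elif len(old_fs) == 2:
        f_prime = (23 * fy - 16 * old_fs[-1] + 5 * old_fs[-2]) / 12                     # AB3
    else:  # len(old_fs) >= 3
        f_prime = (55 * fy - 59 * old_fs[-1] + 37 * old_fs[-2] - 9 * old_fs[-3]) / 24   # AB4
    return y + h * f_prime, fy

f = lambda y: -y
y, old_fs, h = 1.0, [], 0.1
old_fs.append(f(y))
y += h * f(1.0)  # bootstrap with one Euler step, as the sampler bootstraps with its Euler-like step
for _ in range(5):
    y, fy = ab_step(f, y, old_fs, h)
    old_fs.append(fy)
print(y)  # close to exp(-0.6) ~ 0.5488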
@@ -372,7 +372,7 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer

         sd_vae.delete_base_vae()
         sd_vae.clear_loaded_vae()
-        vae_file, vae_source = sd_vae.resolve_vae(checkpoint_info.filename)
+        vae_file, vae_source = sd_vae.resolve_vae(checkpoint_info.filename).tuple()
         sd_vae.load_vae(model, vae_file, vae_source)
         timer.record("load VAE")

@@ -715,6 +715,7 @@ def reload_model_weights(sd_model=None, info=None):
     print(f"Weights loaded in {timer.summary()}.")

     model_data.set_sd_model(sd_model)
+    sd_unet.apply_unet()

     return sd_model

@@ -1,17 +1,18 @@
-from modules import sd_samplers_compvis, sd_samplers_kdiffusion, shared
+from modules import sd_samplers_kdiffusion, sd_samplers_timesteps, shared

 # imports for functions that previously were here and are used by other modules
 from modules.sd_samplers_common import samples_to_image_grid, sample_to_image  # noqa: F401

 all_samplers = [
     *sd_samplers_kdiffusion.samplers_data_k_diffusion,
-    *sd_samplers_compvis.samplers_data_compvis,
+    *sd_samplers_timesteps.samplers_data_timesteps,
 ]
 all_samplers_map = {x.name: x for x in all_samplers}

 samplers = []
 samplers_for_img2img = []
 samplers_map = {}
+samplers_hidden = {}


 def find_sampler_config(name):

@@ -38,13 +39,11 @@ def create_sampler(name, model):


 def set_samplers():
-    global samplers, samplers_for_img2img
+    global samplers, samplers_for_img2img, samplers_hidden

-    hidden = set(shared.opts.hide_samplers)
-    hidden_img2img = set(shared.opts.hide_samplers + ['PLMS', 'UniPC'])
-
-    samplers = [x for x in all_samplers if x.name not in hidden]
-    samplers_for_img2img = [x for x in all_samplers if x.name not in hidden_img2img]
+    samplers_hidden = set(shared.opts.hide_samplers)
+    samplers = all_samplers
+    samplers_for_img2img = all_samplers

     samplers_map.clear()
     for sampler in all_samplers:

@@ -53,4 +52,8 @@ def set_samplers():
             samplers_map[alias.lower()] = sampler.name


+def visible_sampler_names():
+    return [x.name for x in samplers if x.name not in samplers_hidden]
+
+
 set_samplers()
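After this change, samplers and samplers_for_img2img always contain every registered sampler, and hiding is deferred to display time through visible_sampler_names(). A toy version of that filter, with hypothetical sampler names:

# Toy version of the new display-time filtering: the registry keeps everything,
# only the UI-facing list is filtered.
all_sampler_names = ["Euler a", "Euler", "DDIM", "PLMS"]
samplers_hidden = {"PLMS"}

def visible_sampler_names():
    return [name for name in all_sampler_names if name not in samplers_hidden]

print(visible_sampler_names())  # ['Euler a', 'Euler', 'DDIM']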
modules/sd_samplers_cfg_denoiser.py (new file, 203 lines)

@@ -0,0 +1,203 @@
+import torch
+from modules import prompt_parser, devices, sd_samplers_common
+
+from modules.shared import opts, state
+import modules.shared as shared
+from modules.script_callbacks import CFGDenoiserParams, cfg_denoiser_callback
+from modules.script_callbacks import CFGDenoisedParams, cfg_denoised_callback
+from modules.script_callbacks import AfterCFGCallbackParams, cfg_after_cfg_callback
+
+
+def catenate_conds(conds):
+    if not isinstance(conds[0], dict):
+        return torch.cat(conds)
+
+    return {key: torch.cat([x[key] for x in conds]) for key in conds[0].keys()}
+
+
+def subscript_cond(cond, a, b):
+    if not isinstance(cond, dict):
+        return cond[a:b]
+
+    return {key: vec[a:b] for key, vec in cond.items()}
+
+
+def pad_cond(tensor, repeats, empty):
+    if not isinstance(tensor, dict):
+        return torch.cat([tensor, empty.repeat((tensor.shape[0], repeats, 1))], axis=1)
+
+    tensor['crossattn'] = pad_cond(tensor['crossattn'], repeats, empty)
+    return tensor
+
+
+class CFGDenoiser(torch.nn.Module):
+    """
+    Classifier free guidance denoiser. A wrapper for stable diffusion model (specifically for unet)
+    that can take a noisy picture and produce a noise-free picture using two guidances (prompts)
+    instead of one. Originally, the second prompt is just an empty string, but we use non-empty
+    negative prompt.
+    """
+
+    def __init__(self, model, sampler):
+        super().__init__()
+        self.inner_model = model
+        self.mask = None
+        self.nmask = None
+        self.init_latent = None
+        self.step = 0
+        self.image_cfg_scale = None
+        self.padded_cond_uncond = False
+        self.sampler = sampler
+
+    def combine_denoised(self, x_out, conds_list, uncond, cond_scale):
+        denoised_uncond = x_out[-uncond.shape[0]:]
+        denoised = torch.clone(denoised_uncond)
+
+        for i, conds in enumerate(conds_list):
+            for cond_index, weight in conds:
+                denoised[i] += (x_out[cond_index] - denoised_uncond[i]) * (weight * cond_scale)
+
+        return denoised
+
+    def combine_denoised_for_edit_model(self, x_out, cond_scale):
+        out_cond, out_img_cond, out_uncond = x_out.chunk(3)
+        denoised = out_uncond + cond_scale * (out_cond - out_img_cond) + self.image_cfg_scale * (out_img_cond - out_uncond)
+
+        return denoised
+
+    def get_pred_x0(self, x_in, x_out, sigma):
+        return x_out
+
+    def forward(self, x, sigma, uncond, cond, cond_scale, s_min_uncond, image_cond):
+        if state.interrupted or state.skipped:
+            raise sd_samplers_common.InterruptedException
+
+        # at self.image_cfg_scale == 1.0 produced results for edit model are the same as with normal sampling,
+        # so is_edit_model is set to False to support AND composition.
+        is_edit_model = shared.sd_model.cond_stage_key == "edit" and self.image_cfg_scale is not None and self.image_cfg_scale != 1.0
+
+        conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
+        uncond = prompt_parser.reconstruct_cond_batch(uncond, self.step)
+
+        assert not is_edit_model or all(len(conds) == 1 for conds in conds_list), "AND is not supported for InstructPix2Pix checkpoint (unless using Image CFG scale = 1.0)"
+
+        if self.mask is not None:
+            x = self.init_latent * self.mask + self.nmask * x
+
+        batch_size = len(conds_list)
+        repeats = [len(conds_list[i]) for i in range(batch_size)]
+
+        if shared.sd_model.model.conditioning_key == "crossattn-adm":
+            image_uncond = torch.zeros_like(image_cond)
+            make_condition_dict = lambda c_crossattn, c_adm: {"c_crossattn": [c_crossattn], "c_adm": c_adm}
+        else:
+            image_uncond = image_cond
+            if isinstance(uncond, dict):
+                make_condition_dict = lambda c_crossattn, c_concat: {**c_crossattn, "c_concat": [c_concat]}
+            else:
+                make_condition_dict = lambda c_crossattn, c_concat: {"c_crossattn": [c_crossattn], "c_concat": [c_concat]}
+
+        if not is_edit_model:
+            x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x])
+            sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma])
+            image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_uncond])
+        else:
+            x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x] + [x])
+            sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma] + [sigma])
+            image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_uncond] + [torch.zeros_like(self.init_latent)])
+
+        denoiser_params = CFGDenoiserParams(x_in, image_cond_in, sigma_in, state.sampling_step, state.sampling_steps, tensor, uncond)
+        cfg_denoiser_callback(denoiser_params)
+        x_in = denoiser_params.x
+        image_cond_in = denoiser_params.image_cond
+        sigma_in = denoiser_params.sigma
+        tensor = denoiser_params.text_cond
+        uncond = denoiser_params.text_uncond
+        skip_uncond = False
+
+        # alternating uncond allows for higher thresholds without the quality loss normally expected from raising it
+        if self.step % 2 and s_min_uncond > 0 and sigma[0] < s_min_uncond and not is_edit_model:
+            skip_uncond = True
+            x_in = x_in[:-batch_size]
+            sigma_in = sigma_in[:-batch_size]
+
+        self.padded_cond_uncond = False
+        if shared.opts.pad_cond_uncond and tensor.shape[1] != uncond.shape[1]:
+            empty = shared.sd_model.cond_stage_model_empty_prompt
+            num_repeats = (tensor.shape[1] - uncond.shape[1]) // empty.shape[1]
+
+            if num_repeats < 0:
+                tensor = pad_cond(tensor, -num_repeats, empty)
+                self.padded_cond_uncond = True
+            elif num_repeats > 0:
+                uncond = pad_cond(uncond, num_repeats, empty)
+                self.padded_cond_uncond = True
+
+        if tensor.shape[1] == uncond.shape[1] or skip_uncond:
+            if is_edit_model:
+                cond_in = catenate_conds([tensor, uncond, uncond])
+            elif skip_uncond:
+                cond_in = tensor
+            else:
+                cond_in = catenate_conds([tensor, uncond])
+
+            if shared.batch_cond_uncond:
+                x_out = self.inner_model(x_in, sigma_in, cond=make_condition_dict(cond_in, image_cond_in))
+            else:
+                x_out = torch.zeros_like(x_in)
+                for batch_offset in range(0, x_out.shape[0], batch_size):
+                    a = batch_offset
+                    b = a + batch_size
+                    x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict(subscript_cond(cond_in, a, b), image_cond_in[a:b]))
+        else:
+            x_out = torch.zeros_like(x_in)
+            batch_size = batch_size*2 if shared.batch_cond_uncond else batch_size
+            for batch_offset in range(0, tensor.shape[0], batch_size):
+                a = batch_offset
+                b = min(a + batch_size, tensor.shape[0])
+
+                if not is_edit_model:
+                    c_crossattn = subscript_cond(tensor, a, b)
+                else:
+                    c_crossattn = torch.cat([tensor[a:b]], uncond)
+
+                x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict(c_crossattn, image_cond_in[a:b]))
+
+            if not skip_uncond:
+                x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond=make_condition_dict(uncond, image_cond_in[-uncond.shape[0]:]))
+
+        denoised_image_indexes = [x[0][0] for x in conds_list]
+        if skip_uncond:
+            fake_uncond = torch.cat([x_out[i:i+1] for i in denoised_image_indexes])
+            x_out = torch.cat([x_out, fake_uncond])  # we skipped uncond denoising, so we put cond-denoised image to where the uncond-denoised image should be
+
+        denoised_params = CFGDenoisedParams(x_out, state.sampling_step, state.sampling_steps, self.inner_model)
+        cfg_denoised_callback(denoised_params)
+
+        devices.test_for_nans(x_out, "unet")
+
+        if is_edit_model:
+            denoised = self.combine_denoised_for_edit_model(x_out, cond_scale)
+        elif skip_uncond:
+            denoised = self.combine_denoised(x_out, conds_list, uncond, 1.0)
+        else:
+            denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale)
+
+        self.sampler.last_latent = self.get_pred_x0(torch.cat([x_in[i:i + 1] for i in denoised_image_indexes]), torch.cat([x_out[i:i + 1] for i in denoised_image_indexes]), sigma)
+
+        if opts.live_preview_content == "Prompt":
+            preview = self.sampler.last_latent
+        elif opts.live_preview_content == "Negative prompt":
+            preview = self.get_pred_x0(x_in[-uncond.shape[0]:], x_out[-uncond.shape[0]:], sigma)
+        else:
+            preview = self.get_pred_x0(torch.cat([x_in[i:i+1] for i in denoised_image_indexes]), torch.cat([denoised[i:i+1] for i in denoised_image_indexes]), sigma)
+
+        sd_samplers_common.store_latent(preview)
+
+        after_cfg_callback_params = AfterCFGCallbackParams(denoised, state.sampling_step, state.sampling_steps)
+        cfg_after_cfg_callback(after_cfg_callback_params)
+        denoised = after_cfg_callback_params.x
+
+        self.step += 1
+        return denoised
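combine_denoised() in the new file implements the standard classifier-free guidance update, denoised = uncond + sum_i weight_i * scale * (cond_i - uncond). A minimal tensor sketch for the common case of one prompt with weight 1.0 (x_cond and x_uncond are hypothetical stand-ins for model outputs):

import torch

# Minimal CFG combine, mirroring CFGDenoiser.combine_denoised for a single
# cond/uncond pair with weight 1.0.
def cfg_combine(x_cond, x_uncond, cond_scale):
    return x_uncond + cond_scale * (x_cond - x_uncond)

x_cond, x_uncond = torch.randn(1, 4, 8, 8), torch.randn(1, 4, 8, 8)
guided = cfg_combine(x_cond, x_uncond, cond_scale=7.0)
print(guided.shape)  # torch.Size([1, 4, 8, 8])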
@@ -1,9 +1,11 @@
-from collections import namedtuple
+import inspect
+from collections import namedtuple, deque
 import numpy as np
 import torch
 from PIL import Image
 from modules import devices, images, sd_vae_approx, sd_samplers, sd_vae_taesd, shared, sd_models
 from modules.shared import opts, state
+import k_diffusion.sampling

 SamplerData = namedtuple('SamplerData', ['name', 'constructor', 'aliases', 'options'])

@@ -155,3 +157,137 @@ def apply_refiner(sampler):
     return True


+class TorchHijack:
+    def __init__(self, sampler_noises):
+        # Using a deque to efficiently receive the sampler_noises in the same order as the previous index-based
+        # implementation.
+        self.sampler_noises = deque(sampler_noises)
+
+    def __getattr__(self, item):
+        if item == 'randn_like':
+            return self.randn_like
+
+        if hasattr(torch, item):
+            return getattr(torch, item)
+
+        raise AttributeError(f"'{type(self).__name__}' object has no attribute '{item}'")
+
+    def randn_like(self, x):
+        if self.sampler_noises:
+            noise = self.sampler_noises.popleft()
+            if noise.shape == x.shape:
+                return noise
+
+        return devices.randn_like(x)
+
+
+class Sampler:
+    def __init__(self, funcname):
+        self.funcname = funcname
+        self.func = funcname
+        self.extra_params = []
+        self.sampler_noises = None
+        self.stop_at = None
+        self.eta = None
+        self.config = None  # set by the function calling the constructor
+        self.last_latent = None
+        self.s_min_uncond = None
+        self.s_churn = 0.0
+        self.s_tmin = 0.0
+        self.s_tmax = float('inf')
+        self.s_noise = 1.0
+
+        self.eta_option_field = 'eta_ancestral'
+        self.eta_infotext_field = 'Eta'
+
+        self.conditioning_key = shared.sd_model.model.conditioning_key
+
+        self.model_wrap = None
+        self.model_wrap_cfg = None
+
+    def callback_state(self, d):
+        step = d['i']
+
+        if self.stop_at is not None and step > self.stop_at:
+            raise InterruptedException
+
+        state.sampling_step = step
+        shared.total_tqdm.update()
+
+    def launch_sampling(self, steps, func):
+        state.sampling_steps = steps
+        state.sampling_step = 0
+
+        try:
+            return func()
+        except RecursionError:
+            print(
+                'Encountered RecursionError during sampling, returning last latent. '
+                'rho >5 with a polyexponential scheduler may cause this error. '
+                'You should try to use a smaller rho value instead.'
+            )
+            return self.last_latent
+        except InterruptedException:
+            return self.last_latent
+
+    def number_of_needed_noises(self, p):
+        return p.steps
+
+    def initialize(self, p) -> dict:
+        self.model_wrap_cfg.mask = p.mask if hasattr(p, 'mask') else None
+        self.model_wrap_cfg.nmask = p.nmask if hasattr(p, 'nmask') else None
+        self.model_wrap_cfg.step = 0
+        self.model_wrap_cfg.image_cfg_scale = getattr(p, 'image_cfg_scale', None)
+        self.eta = p.eta if p.eta is not None else getattr(opts, self.eta_option_field, 0.0)
+        self.s_min_uncond = getattr(p, 's_min_uncond', 0.0)
+
+        k_diffusion.sampling.torch = TorchHijack(self.sampler_noises if self.sampler_noises is not None else [])
+
+        extra_params_kwargs = {}
+        for param_name in self.extra_params:
+            if hasattr(p, param_name) and param_name in inspect.signature(self.func).parameters:
+                extra_params_kwargs[param_name] = getattr(p, param_name)
+
+        if 'eta' in inspect.signature(self.func).parameters:
+            if self.eta != 1.0:
+                p.extra_generation_params[self.eta_infotext_field] = self.eta
+
+            extra_params_kwargs['eta'] = self.eta
+
+        if len(self.extra_params) > 0:
+            s_churn = getattr(opts, 's_churn', p.s_churn)
+            s_tmin = getattr(opts, 's_tmin', p.s_tmin)
+            s_tmax = getattr(opts, 's_tmax', p.s_tmax) or self.s_tmax  # 0 = inf
+            s_noise = getattr(opts, 's_noise', p.s_noise)
+
+            if s_churn != self.s_churn:
+                extra_params_kwargs['s_churn'] = s_churn
+                p.s_churn = s_churn
+                p.extra_generation_params['Sigma churn'] = s_churn
+            if s_tmin != self.s_tmin:
+                extra_params_kwargs['s_tmin'] = s_tmin
+                p.s_tmin = s_tmin
+                p.extra_generation_params['Sigma tmin'] = s_tmin
+            if s_tmax != self.s_tmax:
+                extra_params_kwargs['s_tmax'] = s_tmax
+                p.s_tmax = s_tmax
+                p.extra_generation_params['Sigma tmax'] = s_tmax
+            if s_noise != self.s_noise:
+                extra_params_kwargs['s_noise'] = s_noise
+                p.s_noise = s_noise
+                p.extra_generation_params['Sigma noise'] = s_noise
+
+        return extra_params_kwargs
+
+    def create_noise_sampler(self, x, sigmas, p):
+        """For DPM++ SDE: manually create noise sampler to enable deterministic results across different batch sizes"""
+        if shared.opts.no_dpmpp_sde_batch_determinism:
+            return None
+
+        from k_diffusion.sampling import BrownianTreeNoiseSampler
+        sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
+        current_iter_seeds = p.all_seeds[p.iteration * p.batch_size:(p.iteration + 1) * p.batch_size]
+        return BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=current_iter_seeds)
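TorchHijack above proxies attribute access to torch but serves pre-generated noise tensors from a deque via randn_like, so samplers consume them in submission order. A standalone sketch of that interception, using plain torch.randn_like as the fallback where the real class calls devices.randn_like:

from collections import deque
import torch

# Sketch of the TorchHijack idea: hand out queued noise tensors in FIFO order,
# falling back to fresh noise once the queue is empty or a shape mismatches.
class NoiseQueue:
    def __init__(self, noises):
        self.noises = deque(noises)

    def randn_like(self, x):
        if self.noises:
            noise = self.noises.popleft()
            if noise.shape == x.shape:
                return noise
        return torch.randn_like(x)

q = NoiseQueue([torch.zeros(2, 3)])
print(q.randn_like(torch.ones(2, 3)))  # the queued zeros, not fresh noise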
@@ -1,232 +0,0 @@
-import math
-import ldm.models.diffusion.ddim
-import ldm.models.diffusion.plms
-
-import numpy as np
-import torch
-
-from modules.shared import state
-from modules import sd_samplers_common, prompt_parser, shared
-import modules.models.diffusion.uni_pc
-
-
-samplers_data_compvis = [
-    sd_samplers_common.SamplerData('DDIM', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.ddim.DDIMSampler, model), [], {"default_eta_is_0": True, "uses_ensd": True, "no_sdxl": True}),
-    sd_samplers_common.SamplerData('PLMS', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.plms.PLMSSampler, model), [], {"no_sdxl": True}),
-    sd_samplers_common.SamplerData('UniPC', lambda model: VanillaStableDiffusionSampler(modules.models.diffusion.uni_pc.UniPCSampler, model), [], {"no_sdxl": True}),
-]
-
-
-class VanillaStableDiffusionSampler:
-    def __init__(self, constructor, sd_model):
-        self.p = None
-        self.sampler = constructor(shared.sd_model)
-        self.is_ddim = hasattr(self.sampler, 'p_sample_ddim')
-        self.is_plms = hasattr(self.sampler, 'p_sample_plms')
-        self.is_unipc = isinstance(self.sampler, modules.models.diffusion.uni_pc.UniPCSampler)
-        self.orig_p_sample_ddim = None
-        if self.is_plms:
-            self.orig_p_sample_ddim = self.sampler.p_sample_plms
-        elif self.is_ddim:
-            self.orig_p_sample_ddim = self.sampler.p_sample_ddim
-        self.mask = None
-        self.nmask = None
-        self.init_latent = None
-        self.sampler_noises = None
-        self.steps = None
-        self.step = 0
-        self.stop_at = None
-        self.eta = None
-        self.config = None
-        self.last_latent = None
-
-        self.conditioning_key = sd_model.model.conditioning_key
-
-    def number_of_needed_noises(self, p):
-        return 0
-
-    def launch_sampling(self, steps, func):
-        self.steps = steps
-        state.sampling_steps = steps
-        state.sampling_step = 0
-
-        try:
-            return func()
-        except sd_samplers_common.InterruptedException:
-            return self.last_latent
-
-    def p_sample_ddim_hook(self, x_dec, cond, ts, unconditional_conditioning, *args, **kwargs):
-        x_dec, ts, cond, unconditional_conditioning = self.before_sample(x_dec, ts, cond, unconditional_conditioning)
-
-        res = self.orig_p_sample_ddim(x_dec, cond, ts, *args, unconditional_conditioning=unconditional_conditioning, **kwargs)
-
-        x_dec, ts, cond, unconditional_conditioning, res = self.after_sample(x_dec, ts, cond, unconditional_conditioning, res)
-
-        return res
-
-    def update_inner_model(self):
-        self.sampler.model = shared.sd_model
-
-    def before_sample(self, x, ts, cond, unconditional_conditioning):
-        if state.interrupted or state.skipped:
-            raise sd_samplers_common.InterruptedException
-
-        if self.stop_at is not None and self.step > self.stop_at:
-            raise sd_samplers_common.InterruptedException
-
-        # Have to unwrap the inpainting conditioning here to perform pre-processing
-        image_conditioning = None
-        uc_image_conditioning = None
-        if isinstance(cond, dict):
-            if self.conditioning_key == "crossattn-adm":
-                image_conditioning = cond["c_adm"]
-                uc_image_conditioning = unconditional_conditioning["c_adm"]
-            else:
-                image_conditioning = cond["c_concat"][0]
-            cond = cond["c_crossattn"][0]
-            unconditional_conditioning = unconditional_conditioning["c_crossattn"][0]
-
-        conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
-        unconditional_conditioning = prompt_parser.reconstruct_cond_batch(unconditional_conditioning, self.step)
-
-        assert all(len(conds) == 1 for conds in conds_list), 'composition via AND is not supported for DDIM/PLMS samplers'
-        cond = tensor
-
-        # for DDIM, shapes must match, we can't just process cond and uncond independently;
-        # filling unconditional_conditioning with repeats of the last vector to match length is
-        # not 100% correct but should work well enough
-        if unconditional_conditioning.shape[1] < cond.shape[1]:
-            last_vector = unconditional_conditioning[:, -1:]
-            last_vector_repeated = last_vector.repeat([1, cond.shape[1] - unconditional_conditioning.shape[1], 1])
-            unconditional_conditioning = torch.hstack([unconditional_conditioning, last_vector_repeated])
-        elif unconditional_conditioning.shape[1] > cond.shape[1]:
-            unconditional_conditioning = unconditional_conditioning[:, :cond.shape[1]]
-
-        if self.mask is not None:
-            img_orig = self.sampler.model.q_sample(self.init_latent, ts)
-            x = img_orig * self.mask + self.nmask * x
-
-        # Wrap the image conditioning back up since the DDIM code can accept the dict directly.
-        # Note that they need to be lists because it just concatenates them later.
-        if image_conditioning is not None:
-            if self.conditioning_key == "crossattn-adm":
-                cond = {"c_adm": image_conditioning, "c_crossattn": [cond]}
-                unconditional_conditioning = {"c_adm": uc_image_conditioning, "c_crossattn": [unconditional_conditioning]}
-            else:
-                cond = {"c_concat": [image_conditioning], "c_crossattn": [cond]}
-                unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]}
-
-        return x, ts, cond, unconditional_conditioning
-
-    def update_step(self, last_latent):
-        if self.mask is not None:
-            self.last_latent = self.init_latent * self.mask + self.nmask * last_latent
-        else:
-            self.last_latent = last_latent
-
-        sd_samplers_common.store_latent(self.last_latent)
-
-        self.step += 1
-        state.sampling_step = self.step
-        shared.total_tqdm.update()
-
-    def after_sample(self, x, ts, cond, uncond, res):
-        if not self.is_unipc:
-            self.update_step(res[1])
-
-        return x, ts, cond, uncond, res
-
-    def unipc_after_update(self, x, model_x):
-        self.update_step(x)
-
-    def initialize(self, p):
-        self.p = p
-
-        if self.is_ddim:
-            self.eta = p.eta if p.eta is not None else shared.opts.eta_ddim
-        else:
-            self.eta = 0.0
-
-        if self.eta != 0.0:
-            p.extra_generation_params["Eta DDIM"] = self.eta
-
-        if self.is_unipc:
-            keys = [
-                ('UniPC variant', 'uni_pc_variant'),
-                ('UniPC skip type', 'uni_pc_skip_type'),
-                ('UniPC order', 'uni_pc_order'),
-                ('UniPC lower order final', 'uni_pc_lower_order_final'),
-            ]
-
-            for name, key in keys:
-                v = getattr(shared.opts, key)
-                if v != shared.opts.get_default(key):
-                    p.extra_generation_params[name] = v
-
-        for fieldname in ['p_sample_ddim', 'p_sample_plms']:
-            if hasattr(self.sampler, fieldname):
-                setattr(self.sampler, fieldname, self.p_sample_ddim_hook)
-        if self.is_unipc:
-            self.sampler.set_hooks(lambda x, t, c, u: self.before_sample(x, t, c, u), lambda x, t, c, u, r: self.after_sample(x, t, c, u, r), lambda x, mx: self.unipc_after_update(x, mx))
-
-        self.mask = p.mask if hasattr(p, 'mask') else None
-        self.nmask = p.nmask if hasattr(p, 'nmask') else None
-
-
-    def adjust_steps_if_invalid(self, p, num_steps):
-        if ((self.config.name == 'DDIM') and p.ddim_discretize == 'uniform') or (self.config.name == 'PLMS') or (self.config.name == 'UniPC'):
-            if self.config.name == 'UniPC' and num_steps < shared.opts.uni_pc_order:
-                num_steps = shared.opts.uni_pc_order
-            valid_step = 999 / (1000 // num_steps)
-            if valid_step == math.floor(valid_step):
-                return int(valid_step) + 1
-
-        return num_steps
-
-    def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
-        steps, t_enc = sd_samplers_common.setup_img2img_steps(p, steps)
-        steps = self.adjust_steps_if_invalid(p, steps)
-        self.initialize(p)
-
-        self.sampler.make_schedule(ddim_num_steps=steps, ddim_eta=self.eta, ddim_discretize=p.ddim_discretize, verbose=False)
-        x1 = self.sampler.stochastic_encode(x, torch.tensor([t_enc] * int(x.shape[0])).to(shared.device), noise=noise)
-
-        self.init_latent = x
-        self.last_latent = x
-        self.step = 0
-
-        # Wrap the conditioning models with additional image conditioning for inpainting model
-        if image_conditioning is not None:
-            if self.conditioning_key == "crossattn-adm":
-                conditioning = {"c_adm": image_conditioning, "c_crossattn": [conditioning]}
-                unconditional_conditioning = {"c_adm": torch.zeros_like(image_conditioning), "c_crossattn": [unconditional_conditioning]}
-            else:
-                conditioning = {"c_concat": [image_conditioning], "c_crossattn": [conditioning]}
-                unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]}
-
-        samples = self.launch_sampling(t_enc + 1, lambda: self.sampler.decode(x1, conditioning, t_enc, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning))
-
-        return samples
-
-    def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
-        self.initialize(p)
-
-        self.init_latent = None
-        self.last_latent = x
-        self.step = 0
-
-        steps = self.adjust_steps_if_invalid(p, steps or p.steps)
-
-        # Wrap the conditioning models with additional image conditioning for inpainting model
-        # dummy_for_plms is needed because PLMS code checks the first item in the dict to have the right shape
-        if image_conditioning is not None:
-            if self.conditioning_key == "crossattn-adm":
-                conditioning = {"dummy_for_plms": np.zeros((conditioning.shape[0],)), "c_crossattn": [conditioning], "c_adm": image_conditioning}
-                unconditional_conditioning = {"c_crossattn": [unconditional_conditioning], "c_adm": torch.zeros_like(image_conditioning)}
-            else:
-                conditioning = {"dummy_for_plms": np.zeros((conditioning.shape[0],)), "c_crossattn": [conditioning], "c_concat": [image_conditioning]}
-                unconditional_conditioning = {"c_crossattn": [unconditional_conditioning], "c_concat": [image_conditioning]}
-
-        samples_ddim = self.launch_sampling(steps, lambda: self.sampler.sample(S=steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=self.eta)[0])
-
-        return samples_ddim
@ -1,17 +1,16 @@
|
|||||||
from collections import deque
|
|
||||||
import torch
|
import torch
|
||||||
import inspect
|
import inspect
|
||||||
import k_diffusion.sampling
|
import k_diffusion.sampling
|
||||||
from modules import prompt_parser, devices, sd_samplers_common, sd_samplers_extra
|
from modules import sd_samplers_common, sd_samplers_extra, sd_samplers_cfg_denoiser
|
||||||
|
|
||||||
from modules.processing import StableDiffusionProcessing
|
from modules.shared import opts
|
||||||
from modules.shared import opts, state
|
|
||||||
import modules.shared as shared
|
import modules.shared as shared
|
||||||
from modules.script_callbacks import CFGDenoiserParams, cfg_denoiser_callback
|
|
||||||
from modules.script_callbacks import CFGDenoisedParams, cfg_denoised_callback
|
|
||||||
from modules.script_callbacks import AfterCFGCallbackParams, cfg_after_cfg_callback
|
|
||||||
|
|
||||||
samplers_k_diffusion = [
|
samplers_k_diffusion = [
|
||||||
|
('DPM++ 2M Karras', 'sample_dpmpp_2m', ['k_dpmpp_2m_ka'], {'scheduler': 'karras'}),
|
||||||
|
('DPM++ SDE Karras', 'sample_dpmpp_sde', ['k_dpmpp_sde_ka'], {'scheduler': 'karras', "second_order": True, "brownian_noise": True}),
|
||||||
|
('DPM++ 2M SDE Exponential', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_exp'], {'scheduler': 'exponential', "brownian_noise": True}),
|
||||||
|
('DPM++ 2M SDE Karras', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_ka'], {'scheduler': 'karras', "brownian_noise": True}),
|
||||||
('Euler a', 'sample_euler_ancestral', ['k_euler_a', 'k_euler_ancestral'], {"uses_ensd": True}),
|
('Euler a', 'sample_euler_ancestral', ['k_euler_a', 'k_euler_ancestral'], {"uses_ensd": True}),
|
||||||
('Euler', 'sample_euler', ['k_euler'], {}),
|
('Euler', 'sample_euler', ['k_euler'], {}),
|
||||||
('LMS', 'sample_lms', ['k_lms'], {}),
|
('LMS', 'sample_lms', ['k_lms'], {}),
|
||||||
@ -28,10 +27,6 @@ samplers_k_diffusion = [
|
|||||||
('DPM2 Karras', 'sample_dpm_2', ['k_dpm_2_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True, "uses_ensd": True, "second_order": True}),
|
('DPM2 Karras', 'sample_dpm_2', ['k_dpm_2_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True, "uses_ensd": True, "second_order": True}),
|
||||||
('DPM2 a Karras', 'sample_dpm_2_ancestral', ['k_dpm_2_a_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True, "uses_ensd": True, "second_order": True}),
|
('DPM2 a Karras', 'sample_dpm_2_ancestral', ['k_dpm_2_a_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True, "uses_ensd": True, "second_order": True}),
|
||||||
('DPM++ 2S a Karras', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a_ka'], {'scheduler': 'karras', "uses_ensd": True, "second_order": True}),
|
('DPM++ 2S a Karras', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a_ka'], {'scheduler': 'karras', "uses_ensd": True, "second_order": True}),
|
||||||
('DPM++ 2M Karras', 'sample_dpmpp_2m', ['k_dpmpp_2m_ka'], {'scheduler': 'karras'}),
|
|
||||||
('DPM++ SDE Karras', 'sample_dpmpp_sde', ['k_dpmpp_sde_ka'], {'scheduler': 'karras', "second_order": True, "brownian_noise": True}),
|
|
||||||
('DPM++ 2M SDE Karras', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_ka'], {'scheduler': 'karras', "brownian_noise": True}),
|
|
||||||
('DPM++ 2M SDE Exponential', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_exp'], {'scheduler': 'exponential', "brownian_noise": True}),
|
|
||||||
('Restart', sd_samplers_extra.restart_sampler, ['restart'], {'scheduler': 'karras'}),
|
('Restart', sd_samplers_extra.restart_sampler, ['restart'], {'scheduler': 'karras'}),
|
||||||
]
|
]
|
||||||
|
|
||||||
@ -57,342 +52,17 @@ k_diffusion_scheduler = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def catenate_conds(conds):
|
class KDiffusionSampler(sd_samplers_common.Sampler):
|
||||||
if not isinstance(conds[0], dict):
|
|
||||||
return torch.cat(conds)
|
|
||||||
|
|
||||||
return {key: torch.cat([x[key] for x in conds]) for key in conds[0].keys()}
|
|
||||||
|
|
||||||
|
|
||||||
def subscript_cond(cond, a, b):
|
|
||||||
if not isinstance(cond, dict):
|
|
||||||
return cond[a:b]
|
|
||||||
|
|
||||||
return {key: vec[a:b] for key, vec in cond.items()}
|
|
||||||
|
|
||||||
|
|
||||||
def pad_cond(tensor, repeats, empty):
|
|
||||||
if not isinstance(tensor, dict):
|
|
||||||
return torch.cat([tensor, empty.repeat((tensor.shape[0], repeats, 1))], axis=1)
|
|
||||||
|
|
||||||
tensor['crossattn'] = pad_cond(tensor['crossattn'], repeats, empty)
|
|
||||||
return tensor
|
|
||||||
|
|
||||||
|
|
||||||
-class CFGDenoiser(torch.nn.Module):
-    """
-    Classifier free guidance denoiser. A wrapper for stable diffusion model (specifically for unet)
-    that can take a noisy picture and produce a noise-free picture using two guidances (prompts)
-    instead of one. Originally, the second prompt is just an empty string, but we use non-empty
-    negative prompt.
-    """
-
-    def __init__(self, sampler):
-        super().__init__()
-        self.sampler = sampler
-        self.model_wrap = None
-        self.mask = None
-        self.nmask = None
-        self.init_latent = None
-        self.steps = None
-        self.step = 0
-        self.image_cfg_scale = None
-        self.padded_cond_uncond = False
-        self.p = None
-
-    @property
-    def inner_model(self):
-        if self.model_wrap is None:
-            denoiser = k_diffusion.external.CompVisVDenoiser if shared.sd_model.parameterization == "v" else k_diffusion.external.CompVisDenoiser
-            self.model_wrap = denoiser(shared.sd_model, quantize=shared.opts.enable_quantization)
-
-        return self.model_wrap
-
-    def combine_denoised(self, x_out, conds_list, uncond, cond_scale):
-        denoised_uncond = x_out[-uncond.shape[0]:]
-        denoised = torch.clone(denoised_uncond)
-
-        for i, conds in enumerate(conds_list):
-            for cond_index, weight in conds:
-                denoised[i] += (x_out[cond_index] - denoised_uncond[i]) * (weight * cond_scale)
-
-        return denoised
-
-    def combine_denoised_for_edit_model(self, x_out, cond_scale):
-        out_cond, out_img_cond, out_uncond = x_out.chunk(3)
-        denoised = out_uncond + cond_scale * (out_cond - out_img_cond) + self.image_cfg_scale * (out_img_cond - out_uncond)
-
-        return denoised
-
-    def update_inner_model(self):
-        self.model_wrap = None
-
-        c, uc = self.p.get_conds()
-        self.sampler.sampler_extra_args['cond'] = c
-        self.sampler.sampler_extra_args['uncond'] = uc
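A minimal numeric sketch of the classifier-free guidance combination that combine_denoised above implements, denoised = uncond + sum_i weight_i * cfg_scale * (cond_i - uncond); the values are made up purely for illustration:

import torch

x_out = torch.stack([torch.full((4,), 3.0),   # cond prediction (index 0)
                     torch.full((4,), 1.0)])  # uncond prediction (last)
conds_list = [[(0, 1.0)]]  # one image, one cond with weight 1.0
cond_scale = 7.0

denoised_uncond = x_out[-1:]
denoised = denoised_uncond.clone()
for i, conds in enumerate(conds_list):
    for cond_index, weight in conds:
        denoised[i] += (x_out[cond_index] - denoised_uncond[i]) * (weight * cond_scale)

# 1.0 + 7.0 * (3.0 - 1.0) = 15.0
assert torch.allclose(denoised, torch.full((1, 4), 15.0))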
-    def forward(self, x, sigma, uncond, cond, cond_scale, s_min_uncond, image_cond):
-        if state.interrupted or state.skipped:
-            raise sd_samplers_common.InterruptedException
-
-        if sd_samplers_common.apply_refiner(self):
-            cond = self.sampler.sampler_extra_args['cond']
-            uncond = self.sampler.sampler_extra_args['uncond']
-
-        # at self.image_cfg_scale == 1.0 produced results for edit model are the same as with normal sampling,
-        # so is_edit_model is set to False to support AND composition.
-        is_edit_model = shared.sd_model.cond_stage_key == "edit" and self.image_cfg_scale is not None and self.image_cfg_scale != 1.0
-
-        conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
-        uncond = prompt_parser.reconstruct_cond_batch(uncond, self.step)
-
-        assert not is_edit_model or all(len(conds) == 1 for conds in conds_list), "AND is not supported for InstructPix2Pix checkpoint (unless using Image CFG scale = 1.0)"
-
-        batch_size = len(conds_list)
-        repeats = [len(conds_list[i]) for i in range(batch_size)]
-
-        if shared.sd_model.model.conditioning_key == "crossattn-adm":
-            image_uncond = torch.zeros_like(image_cond)
-            make_condition_dict = lambda c_crossattn, c_adm: {"c_crossattn": [c_crossattn], "c_adm": c_adm}
-        else:
-            image_uncond = image_cond
-            if isinstance(uncond, dict):
-                make_condition_dict = lambda c_crossattn, c_concat: {**c_crossattn, "c_concat": [c_concat]}
-            else:
-                make_condition_dict = lambda c_crossattn, c_concat: {"c_crossattn": [c_crossattn], "c_concat": [c_concat]}
-
-        if not is_edit_model:
-            x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x])
-            sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma])
-            image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_uncond])
-        else:
-            x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x] + [x])
-            sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma] + [sigma])
-            image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_uncond] + [torch.zeros_like(self.init_latent)])
-
-        denoiser_params = CFGDenoiserParams(x_in, image_cond_in, sigma_in, state.sampling_step, state.sampling_steps, tensor, uncond)
-        cfg_denoiser_callback(denoiser_params)
-        x_in = denoiser_params.x
-        image_cond_in = denoiser_params.image_cond
-        sigma_in = denoiser_params.sigma
-        tensor = denoiser_params.text_cond
-        uncond = denoiser_params.text_uncond
-        skip_uncond = False
-
-        # alternating uncond allows for higher thresholds without the quality loss normally expected from raising it
-        if self.step % 2 and s_min_uncond > 0 and sigma[0] < s_min_uncond and not is_edit_model:
-            skip_uncond = True
-            x_in = x_in[:-batch_size]
-            sigma_in = sigma_in[:-batch_size]
-
-        self.padded_cond_uncond = False
-        if shared.opts.pad_cond_uncond and tensor.shape[1] != uncond.shape[1]:
-            empty = shared.sd_model.cond_stage_model_empty_prompt
-            num_repeats = (tensor.shape[1] - uncond.shape[1]) // empty.shape[1]
-
-            if num_repeats < 0:
-                tensor = pad_cond(tensor, -num_repeats, empty)
-                self.padded_cond_uncond = True
-            elif num_repeats > 0:
-                uncond = pad_cond(uncond, num_repeats, empty)
-                self.padded_cond_uncond = True
-
-        if tensor.shape[1] == uncond.shape[1] or skip_uncond:
-            if is_edit_model:
-                cond_in = catenate_conds([tensor, uncond, uncond])
-            elif skip_uncond:
-                cond_in = tensor
-            else:
-                cond_in = catenate_conds([tensor, uncond])
-
-            if shared.batch_cond_uncond:
-                x_out = self.inner_model(x_in, sigma_in, cond=make_condition_dict(cond_in, image_cond_in))
-            else:
-                x_out = torch.zeros_like(x_in)
-                for batch_offset in range(0, x_out.shape[0], batch_size):
-                    a = batch_offset
-                    b = a + batch_size
-                    x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict(subscript_cond(cond_in, a, b), image_cond_in[a:b]))
-        else:
-            x_out = torch.zeros_like(x_in)
-            batch_size = batch_size*2 if shared.batch_cond_uncond else batch_size
-            for batch_offset in range(0, tensor.shape[0], batch_size):
-                a = batch_offset
-                b = min(a + batch_size, tensor.shape[0])
-
-                if not is_edit_model:
-                    c_crossattn = subscript_cond(tensor, a, b)
-                else:
-                    c_crossattn = torch.cat([tensor[a:b], uncond])
-
-                x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict(c_crossattn, image_cond_in[a:b]))
-
-            if not skip_uncond:
-                x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond=make_condition_dict(uncond, image_cond_in[-uncond.shape[0]:]))
-
-        denoised_image_indexes = [x[0][0] for x in conds_list]
-        if skip_uncond:
-            fake_uncond = torch.cat([x_out[i:i+1] for i in denoised_image_indexes])
-            x_out = torch.cat([x_out, fake_uncond])  # we skipped uncond denoising, so we put cond-denoised image to where the uncond-denoised image should be
-
-        denoised_params = CFGDenoisedParams(x_out, state.sampling_step, state.sampling_steps, self.inner_model)
-        cfg_denoised_callback(denoised_params)
-
-        devices.test_for_nans(x_out, "unet")
-
-        if opts.live_preview_content == "Prompt":
-            sd_samplers_common.store_latent(torch.cat([x_out[i:i+1] for i in denoised_image_indexes]))
-        elif opts.live_preview_content == "Negative prompt":
-            sd_samplers_common.store_latent(x_out[-uncond.shape[0]:])
-
-        if is_edit_model:
-            denoised = self.combine_denoised_for_edit_model(x_out, cond_scale)
-        elif skip_uncond:
-            denoised = self.combine_denoised(x_out, conds_list, uncond, 1.0)
-        else:
-            denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale)
-
-        if self.mask is not None:
-            denoised = self.init_latent * self.mask + self.nmask * denoised
-
-        after_cfg_callback_params = AfterCFGCallbackParams(denoised, state.sampling_step, state.sampling_steps)
-        cfg_after_cfg_callback(after_cfg_callback_params)
-        denoised = after_cfg_callback_params.x
-
-        self.step += 1
-        return denoised
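A sketch of the batch layout the forward pass above builds: all conditional copies come first and the unconditional batch last, so x_out[-uncond.shape[0]:] is always the uncond half regardless of how many AND-composed conds each image has. Shapes are assumed for illustration:

import torch

x = torch.randn(2, 4)   # two images in the batch
uncond_len = 2
repeats = [2, 1]        # image 0 has two AND-composed conds, image 1 has one

x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x])
assert x_in.shape[0] == sum(repeats) + uncond_len  # cond copies, then uncond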
-class TorchHijack:
-    def __init__(self, sampler_noises):
-        # Using a deque to efficiently receive the sampler_noises in the same order as the previous index-based
-        # implementation.
-        self.sampler_noises = deque(sampler_noises)
-
-    def __getattr__(self, item):
-        if item == 'randn_like':
-            return self.randn_like
-
-        if hasattr(torch, item):
-            return getattr(torch, item)
-
-        raise AttributeError(f"'{type(self).__name__}' object has no attribute '{item}'")
-
-    def randn_like(self, x):
-        if self.sampler_noises:
-            noise = self.sampler_noises.popleft()
-            if noise.shape == x.shape:
-                return noise
-
-        return devices.randn_like(x)
-
-
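The removed TorchHijack above is a duck-typed stand-in for the torch module: k-diffusion samplers call k_diffusion.sampling.torch.randn_like(x), so swapping that attribute lets webui feed pre-generated, seed-stable noise. A minimal self-contained sketch of the same trick (not webui code; NoiseFeeder is a hypothetical name):

import torch
from collections import deque

class NoiseFeeder:
    def __init__(self, noises):
        self.noises = deque(noises)

    def __getattr__(self, item):
        # only reached for attributes not defined on the class;
        # fall through to the real torch module for everything else
        return getattr(torch, item)

    def randn_like(self, x):
        if self.noises and self.noises[0].shape == x.shape:
            return self.noises.popleft()  # serve prepared noise first
        return torch.randn_like(x)

feeder = NoiseFeeder([torch.zeros(2, 2)])
assert feeder.randn_like(torch.ones(2, 2)).abs().sum() == 0  # came from the queue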
-class KDiffusionSampler:
+class KDiffusionSampler(sd_samplers_common.Sampler):
     def __init__(self, funcname, sd_model):
-        self.p = None
-        self.funcname = funcname
-        self.func = funcname if callable(funcname) else getattr(k_diffusion.sampling, self.funcname)
+        super().__init__(funcname)
+
         self.extra_params = sampler_extra_params.get(funcname, [])
-        self.sampler_extra_args = {}
-        self.model_wrap_cfg = CFGDenoiser(self)
-        self.model_wrap = self.model_wrap_cfg.inner_model
-        self.sampler_noises = None
-        self.stop_at = None
-        self.eta = None
-        self.config = None  # set by the function calling the constructor
-        self.last_latent = None
-        self.s_min_uncond = None
-
-        # NOTE: These are also defined in the StableDiffusionProcessing class.
-        # They should have been here to begin with but we're going to
-        # leave that class __init__ signature alone.
-        self.s_churn = 0.0
-        self.s_tmin = 0.0
-        self.s_tmax = float('inf')
-        self.s_noise = 1.0
-
-        self.conditioning_key = sd_model.model.conditioning_key
+        self.func = funcname if callable(funcname) else getattr(k_diffusion.sampling, self.funcname)
+
+        denoiser = k_diffusion.external.CompVisVDenoiser if sd_model.parameterization == "v" else k_diffusion.external.CompVisDenoiser
+        self.model_wrap = denoiser(sd_model, quantize=shared.opts.enable_quantization)
+        self.model_wrap_cfg = sd_samplers_cfg_denoiser.CFGDenoiser(self.model_wrap, self)
 
-    def callback_state(self, d):
-        step = d['i']
-        latent = d["denoised"]
-        if opts.live_preview_content == "Combined":
-            sd_samplers_common.store_latent(latent)
-        self.last_latent = latent
-
-        if self.stop_at is not None and step > self.stop_at:
-            raise sd_samplers_common.InterruptedException
-
-        state.sampling_step = step
-        shared.total_tqdm.update()
-
-    def launch_sampling(self, steps, func):
-        self.model_wrap_cfg.steps = steps
-        state.sampling_steps = steps
-        state.sampling_step = 0
-
-        try:
-            return func()
-        except RecursionError:
-            print(
-                'Encountered RecursionError during sampling, returning last latent. '
-                'rho >5 with a polyexponential scheduler may cause this error. '
-                'You should try to use a smaller rho value instead.'
-            )
-            return self.last_latent
-        except sd_samplers_common.InterruptedException:
-            return self.last_latent
-
-    def number_of_needed_noises(self, p):
-        return p.steps
-
-    def initialize(self, p: StableDiffusionProcessing):
-        self.p = p
-        self.model_wrap_cfg.p = p
-        self.model_wrap_cfg.mask = p.mask if hasattr(p, 'mask') else None
-        self.model_wrap_cfg.nmask = p.nmask if hasattr(p, 'nmask') else None
-        self.model_wrap_cfg.step = 0
-        self.model_wrap_cfg.image_cfg_scale = getattr(p, 'image_cfg_scale', None)
-        self.eta = p.eta if p.eta is not None else opts.eta_ancestral
-        self.s_min_uncond = getattr(p, 's_min_uncond', 0.0)
-
-        k_diffusion.sampling.torch = TorchHijack(self.sampler_noises if self.sampler_noises is not None else [])
-
-        extra_params_kwargs = {}
-        for param_name in self.extra_params:
-            if hasattr(p, param_name) and param_name in inspect.signature(self.func).parameters:
-                extra_params_kwargs[param_name] = getattr(p, param_name)
-
-        if 'eta' in inspect.signature(self.func).parameters:
-            if self.eta != 1.0:
-                p.extra_generation_params["Eta"] = self.eta
-
-            extra_params_kwargs['eta'] = self.eta
-
-        if len(self.extra_params) > 0:
-            s_churn = getattr(opts, 's_churn', p.s_churn)
-            s_tmin = getattr(opts, 's_tmin', p.s_tmin)
-            s_tmax = getattr(opts, 's_tmax', p.s_tmax) or self.s_tmax  # 0 = inf
-            s_noise = getattr(opts, 's_noise', p.s_noise)
-
-            if s_churn != self.s_churn:
-                extra_params_kwargs['s_churn'] = s_churn
-                p.s_churn = s_churn
-                p.extra_generation_params['Sigma churn'] = s_churn
-            if s_tmin != self.s_tmin:
-                extra_params_kwargs['s_tmin'] = s_tmin
-                p.s_tmin = s_tmin
-                p.extra_generation_params['Sigma tmin'] = s_tmin
-            if s_tmax != self.s_tmax:
-                extra_params_kwargs['s_tmax'] = s_tmax
-                p.s_tmax = s_tmax
-                p.extra_generation_params['Sigma tmax'] = s_tmax
-            if s_noise != self.s_noise:
-                extra_params_kwargs['s_noise'] = s_noise
-                p.s_noise = s_noise
-                p.extra_generation_params['Sigma noise'] = s_noise
-
-        return extra_params_kwargs
-
     def get_sigmas(self, p, steps):
         discard_next_to_last_sigma = self.config is not None and self.config.options.get('discard_next_to_last_sigma', False)
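The removed initialize() above forwards an option to the sampler function only when that function actually accepts a parameter of that name. A generic, self-contained sketch of the pattern (the function and option names are made up):

import inspect

def sampler_func(x, eta=0.0, s_churn=0.0):
    return x + eta + s_churn

available = {'eta': 0.5, 's_churn': 1.0, 'solver_order': 3}
params = inspect.signature(sampler_func).parameters
kwargs = {k: v for k, v in available.items() if k in params}

assert kwargs == {'eta': 0.5, 's_churn': 1.0}  # solver_order is silently dropped
print(sampler_func(0.0, **kwargs))  # 1.5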
@@ -444,22 +114,12 @@ class KDiffusionSampler:
 
         return sigmas
 
-    def create_noise_sampler(self, x, sigmas, p):
-        """For DPM++ SDE: manually create noise sampler to enable deterministic results across different batch sizes"""
-        if shared.opts.no_dpmpp_sde_batch_determinism:
-            return None
-
-        from k_diffusion.sampling import BrownianTreeNoiseSampler
-        sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
-        current_iter_seeds = p.all_seeds[p.iteration * p.batch_size:(p.iteration + 1) * p.batch_size]
-        return BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=current_iter_seeds)
-
     def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
         steps, t_enc = sd_samplers_common.setup_img2img_steps(p, steps)
 
         sigmas = self.get_sigmas(p, steps)
 
         sigma_sched = sigmas[steps - t_enc - 1:]
 
         xi = x + noise * sigma_sched[0]
 
         extra_params_kwargs = self.initialize(p)
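Why create_noise_sampler (removed above) seeds per image: BrownianTreeNoiseSampler receives one seed per image of the current iteration, so a given image sees the same noise stream no matter how the batch is sized. A small sketch of just the seed slicing:

all_seeds = [100, 101, 102, 103, 104, 105]

def seeds_for(iteration, batch_size):
    return all_seeds[iteration * batch_size:(iteration + 1) * batch_size]

# image #2 gets seed 102 whether images are processed two at a time or one at a time
assert seeds_for(1, 2)[0] == seeds_for(2, 1)[0] == 102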
@@ -508,12 +168,14 @@ class KDiffusionSampler:
         extra_params_kwargs = self.initialize(p)
         parameters = inspect.signature(self.func).parameters
 
+        if 'n' in parameters:
+            extra_params_kwargs['n'] = steps
+
         if 'sigma_min' in parameters:
             extra_params_kwargs['sigma_min'] = self.model_wrap.sigmas[0].item()
             extra_params_kwargs['sigma_max'] = self.model_wrap.sigmas[-1].item()
-            if 'n' in parameters:
-                extra_params_kwargs['n'] = steps
-        else:
+
+        if 'sigmas' in parameters:
             extra_params_kwargs['sigmas'] = sigmas
 
         if self.config.options.get('brownian_noise', False):

@@ -535,3 +197,4 @@ class KDiffusionSampler:
 
         return samples
 
modules/sd_samplers_timesteps.py (new file, 147 lines)
@@ -0,0 +1,147 @@
+import torch
+import inspect
+from modules import devices, sd_samplers_common, sd_samplers_timesteps_impl
+from modules.sd_samplers_cfg_denoiser import CFGDenoiser
+
+from modules.shared import opts
+import modules.shared as shared
+
+samplers_timesteps = [
+    ('DDIM', sd_samplers_timesteps_impl.ddim, ['ddim'], {}),
+    ('PLMS', sd_samplers_timesteps_impl.plms, ['plms'], {}),
+    ('UniPC', sd_samplers_timesteps_impl.unipc, ['unipc'], {}),
+]
+
+
+samplers_data_timesteps = [
+    sd_samplers_common.SamplerData(label, lambda model, funcname=funcname: CompVisSampler(funcname, model), aliases, options)
+    for label, funcname, aliases, options in samplers_timesteps
+]
+
+
+class CompVisTimestepsDenoiser(torch.nn.Module):
+    def __init__(self, model, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.inner_model = model
+
+    def forward(self, input, timesteps, **kwargs):
+        return self.inner_model.apply_model(input, timesteps, **kwargs)
+
+
+class CompVisTimestepsVDenoiser(torch.nn.Module):
+    def __init__(self, model, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.inner_model = model
+
+    def predict_eps_from_z_and_v(self, x_t, t, v):
+        return self.inner_model.sqrt_alphas_cumprod[t.to(torch.int), None, None, None] * v + self.inner_model.sqrt_one_minus_alphas_cumprod[t.to(torch.int), None, None, None] * x_t
+
+    def forward(self, input, timesteps, **kwargs):
+        model_output = self.inner_model.apply_model(input, timesteps, **kwargs)
+        e_t = self.predict_eps_from_z_and_v(input, timesteps, model_output)
+        return e_t
+
+
+class CFGDenoiserTimesteps(CFGDenoiser):
+
+    def __init__(self, model, sampler):
+        super().__init__(model, sampler)
+
+        self.alphas = model.inner_model.alphas_cumprod
+
+    def get_pred_x0(self, x_in, x_out, sigma):
+        ts = int(sigma.item())
+
+        s_in = x_in.new_ones([x_in.shape[0]])
+        a_t = self.alphas[ts].item() * s_in
+        sqrt_one_minus_at = (1 - a_t).sqrt()
+
+        pred_x0 = (x_in - sqrt_one_minus_at * x_out) / a_t.sqrt()
+
+        return pred_x0
+
+
+class CompVisSampler(sd_samplers_common.Sampler):
+    def __init__(self, funcname, sd_model):
+        super().__init__(funcname)
+
+        self.eta_option_field = 'eta_ddim'
+        self.eta_infotext_field = 'Eta DDIM'
+
+        denoiser = CompVisTimestepsVDenoiser if sd_model.parameterization == "v" else CompVisTimestepsDenoiser
+        self.model_wrap = denoiser(sd_model)
+        self.model_wrap_cfg = CFGDenoiserTimesteps(self.model_wrap, self)
+
+    def get_timesteps(self, p, steps):
+        discard_next_to_last_sigma = self.config is not None and self.config.options.get('discard_next_to_last_sigma', False)
+        if opts.always_discard_next_to_last_sigma and not discard_next_to_last_sigma:
+            discard_next_to_last_sigma = True
+            p.extra_generation_params["Discard penultimate sigma"] = True
+
+        steps += 1 if discard_next_to_last_sigma else 0
+
+        timesteps = torch.clip(torch.asarray(list(range(0, 1000, 1000 // steps)), device=devices.device) + 1, 0, 999)
+
+        return timesteps
+
+    def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
+        steps, t_enc = sd_samplers_common.setup_img2img_steps(p, steps)
+
+        timesteps = self.get_timesteps(p, steps)
+        timesteps_sched = timesteps[:t_enc]
+
+        alphas_cumprod = shared.sd_model.alphas_cumprod
+        sqrt_alpha_cumprod = torch.sqrt(alphas_cumprod[timesteps[t_enc]])
+        sqrt_one_minus_alpha_cumprod = torch.sqrt(1 - alphas_cumprod[timesteps[t_enc]])
+
+        xi = x * sqrt_alpha_cumprod + noise * sqrt_one_minus_alpha_cumprod
+
+        extra_params_kwargs = self.initialize(p)
+        parameters = inspect.signature(self.func).parameters
+
+        if 'timesteps' in parameters:
+            extra_params_kwargs['timesteps'] = timesteps_sched
+        if 'is_img2img' in parameters:
+            extra_params_kwargs['is_img2img'] = True
+
+        self.model_wrap_cfg.init_latent = x
+        self.last_latent = x
+        extra_args = {
+            'cond': conditioning,
+            'image_cond': image_conditioning,
+            'uncond': unconditional_conditioning,
+            'cond_scale': p.cfg_scale,
+            's_min_uncond': self.s_min_uncond
+        }
+
+        samples = self.launch_sampling(t_enc + 1, lambda: self.func(self.model_wrap_cfg, xi, extra_args=extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs))
+
+        if self.model_wrap_cfg.padded_cond_uncond:
+            p.extra_generation_params["Pad conds"] = True
+
+        return samples
+
+    def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
+        steps = steps or p.steps
+        timesteps = self.get_timesteps(p, steps)
+
+        extra_params_kwargs = self.initialize(p)
+        parameters = inspect.signature(self.func).parameters
+
+        if 'timesteps' in parameters:
+            extra_params_kwargs['timesteps'] = timesteps
+
+        self.last_latent = x
+        samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args={
+            'cond': conditioning,
+            'image_cond': image_conditioning,
+            'uncond': unconditional_conditioning,
+            'cond_scale': p.cfg_scale,
+            's_min_uncond': self.s_min_uncond
+        }, disable=False, callback=self.callback_state, **extra_params_kwargs))
+
+        if self.model_wrap_cfg.padded_cond_uncond:
+            p.extra_generation_params["Pad conds"] = True
+
+        return samples
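A sketch of the timestep grid that get_timesteps() above builds: an evenly spaced walk over the 1000 DDPM training steps, shifted by one and clipped to [0, 999]. The device argument is dropped here so the snippet runs standalone on CPU:

import torch

steps = 20
timesteps = torch.clip(torch.asarray(list(range(0, 1000, 1000 // steps))) + 1, 0, 999)

assert len(timesteps) == steps
assert timesteps[0] == 1 and timesteps[-1] == 951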
modules/sd_samplers_timesteps_impl.py (new file, 135 lines)
@@ -0,0 +1,135 @@
+import torch
+import tqdm
+import k_diffusion.sampling
+import numpy as np
+
+from modules import shared
+from modules.models.diffusion.uni_pc import uni_pc
+
+
+@torch.no_grad()
+def ddim(model, x, timesteps, extra_args=None, callback=None, disable=None, eta=0.0):
+    alphas_cumprod = model.inner_model.inner_model.alphas_cumprod
+    alphas = alphas_cumprod[timesteps]
+    alphas_prev = alphas_cumprod[torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))].to(torch.float64)
+    sqrt_one_minus_alphas = torch.sqrt(1 - alphas)
+    sigmas = eta * np.sqrt((1 - alphas_prev.cpu().numpy()) / (1 - alphas.cpu()) * (1 - alphas.cpu() / alphas_prev.cpu().numpy()))
+
+    extra_args = {} if extra_args is None else extra_args
+    s_in = x.new_ones([x.shape[0]])
+    for i in tqdm.trange(len(timesteps) - 1, disable=disable):
+        index = len(timesteps) - 1 - i
+
+        e_t = model(x, timesteps[index].item() * s_in, **extra_args)
+
+        a_t = alphas[index].item() * s_in
+        a_prev = alphas_prev[index].item() * s_in
+        sigma_t = sigmas[index].item() * s_in
+        sqrt_one_minus_at = sqrt_one_minus_alphas[index].item() * s_in
+
+        pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
+        dir_xt = (1. - a_prev - sigma_t ** 2).sqrt() * e_t
+        noise = sigma_t * k_diffusion.sampling.torch.randn_like(x)
+        x = a_prev.sqrt() * pred_x0 + dir_xt + noise
+
+        if callback is not None:
+            callback({'x': x, 'i': i, 'sigma': 0, 'sigma_hat': 0, 'denoised': pred_x0})
+
+    return x
+
+
+@torch.no_grad()
+def plms(model, x, timesteps, extra_args=None, callback=None, disable=None):
+    alphas_cumprod = model.inner_model.inner_model.alphas_cumprod
+    alphas = alphas_cumprod[timesteps]
+    alphas_prev = alphas_cumprod[torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))].to(torch.float64)
+    sqrt_one_minus_alphas = torch.sqrt(1 - alphas)
+
+    extra_args = {} if extra_args is None else extra_args
+    s_in = x.new_ones([x.shape[0]])
+    old_eps = []
+
+    def get_x_prev_and_pred_x0(e_t, index):
+        # select parameters corresponding to the currently considered timestep
+        a_t = alphas[index].item() * s_in
+        a_prev = alphas_prev[index].item() * s_in
+        sqrt_one_minus_at = sqrt_one_minus_alphas[index].item() * s_in
+
+        # current prediction for x_0
+        pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
+
+        # direction pointing to x_t
+        dir_xt = (1. - a_prev).sqrt() * e_t
+        x_prev = a_prev.sqrt() * pred_x0 + dir_xt
+        return x_prev, pred_x0
+
+    for i in tqdm.trange(len(timesteps) - 1, disable=disable):
+        index = len(timesteps) - 1 - i
+        ts = timesteps[index].item() * s_in
+        t_next = timesteps[max(index - 1, 0)].item() * s_in
+
+        e_t = model(x, ts, **extra_args)
+
+        if len(old_eps) == 0:
+            # Pseudo Improved Euler (2nd order)
+            x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
+            e_t_next = model(x_prev, t_next, **extra_args)
+            e_t_prime = (e_t + e_t_next) / 2
+        elif len(old_eps) == 1:
+            # 2nd order Pseudo Linear Multistep (Adams-Bashforth)
+            e_t_prime = (3 * e_t - old_eps[-1]) / 2
+        elif len(old_eps) == 2:
+            # 3rd order Pseudo Linear Multistep (Adams-Bashforth)
+            e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
+        else:
+            # 4th order Pseudo Linear Multistep (Adams-Bashforth)
+            e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24
+
+        x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)
+
+        old_eps.append(e_t)
+        if len(old_eps) >= 4:
+            old_eps.pop(0)
+
+        x = x_prev
+
+        if callback is not None:
+            callback({'x': x, 'i': i, 'sigma': 0, 'sigma_hat': 0, 'denoised': pred_x0})
+
+    return x
+
+
+class UniPCCFG(uni_pc.UniPC):
+    def __init__(self, cfg_model, extra_args, callback, *args, **kwargs):
+        super().__init__(None, *args, **kwargs)
+
+        def after_update(x, model_x):
+            callback({'x': x, 'i': self.index, 'sigma': 0, 'sigma_hat': 0, 'denoised': model_x})
+            self.index += 1
+
+        self.cfg_model = cfg_model
+        self.extra_args = extra_args
+        self.callback = callback
+        self.index = 0
+        self.after_update = after_update
+
+    def get_model_input_time(self, t_continuous):
+        return (t_continuous - 1. / self.noise_schedule.total_N) * 1000.
+
+    def model(self, x, t):
+        t_input = self.get_model_input_time(t)
+
+        res = self.cfg_model(x, t_input, **self.extra_args)
+
+        return res
+
+
+def unipc(model, x, timesteps, extra_args=None, callback=None, disable=None, is_img2img=False):
+    alphas_cumprod = model.inner_model.inner_model.alphas_cumprod
+
+    ns = uni_pc.NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod)
+    t_start = timesteps[-1] / 1000 + 1 / 1000 if is_img2img else None  # this is likely off by a bit - if someone wants to fix it please by all means
+    unipc_sampler = UniPCCFG(model, extra_args, callback, ns, predict_x0=True, thresholding=False, variant=shared.opts.uni_pc_variant)
+    x = unipc_sampler.sample(x, steps=len(timesteps), t_start=t_start, skip_type=shared.opts.uni_pc_skip_type, method="multistep", order=shared.opts.uni_pc_order, lower_order_final=shared.opts.uni_pc_lower_order_final)
+
+    return x
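For reference, the ddim() loop above is the standard DDIM update written out with per-index scalars. In the code's notation (a_t and a_prev are the cumulative products alpha-bar at t and t-1), each step computes, as a sketch of the math rather than new code:

\sigma_t = \eta \sqrt{\frac{1 - \bar\alpha_{t-1}}{1 - \bar\alpha_t}\left(1 - \frac{\bar\alpha_t}{\bar\alpha_{t-1}}\right)}

\mathrm{pred\_x0} = \frac{x_t - \sqrt{1 - \bar\alpha_t}\,\epsilon_\theta(x_t, t)}{\sqrt{\bar\alpha_t}}

x_{t-1} = \sqrt{\bar\alpha_{t-1}}\,\mathrm{pred\_x0} + \sqrt{1 - \bar\alpha_{t-1} - \sigma_t^2}\,\epsilon_\theta(x_t, t) + \sigma_t z, \quad z \sim \mathcal{N}(0, I)

With eta = 0 the noise term vanishes and sampling is deterministic, which matches the code's sigma_t computation.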
modules/sd_vae.py
@@ -1,5 +1,7 @@
 import os
 import collections
+from dataclasses import dataclass
+
 from modules import paths, shared, devices, script_callbacks, sd_models, extra_networks
 import glob
 from copy import deepcopy
@@ -97,37 +99,74 @@ def find_vae_near_checkpoint(checkpoint_file):
     return None
 
 
-def resolve_vae(checkpoint_file):
-    if shared.cmd_opts.vae_path is not None:
-        return shared.cmd_opts.vae_path, 'from commandline argument'
+@dataclass
+class VaeResolution:
+    vae: str = None
+    source: str = None
+    resolved: bool = True
+
+    def tuple(self):
+        return self.vae, self.source
+
+
+def is_automatic():
+    return shared.opts.sd_vae in {"Automatic", "auto"}  # "auto" for people with old config
+
+
+def resolve_vae_from_setting() -> VaeResolution:
+    if shared.opts.sd_vae == "None":
+        return VaeResolution()
+
+    vae_from_options = vae_dict.get(shared.opts.sd_vae, None)
+    if vae_from_options is not None:
+        return VaeResolution(vae_from_options, 'specified in settings')
+
+    if not is_automatic():
+        print(f"Couldn't find VAE named {shared.opts.sd_vae}; using None instead")
+
+    return VaeResolution(resolved=False)
+
 
+def resolve_vae_from_user_metadata(checkpoint_file) -> VaeResolution:
     metadata = extra_networks.get_user_metadata(checkpoint_file)
     vae_metadata = metadata.get("vae", None)
     if vae_metadata is not None and vae_metadata != "Automatic":
         if vae_metadata == "None":
-            return None, None
+            return VaeResolution()
 
         vae_from_metadata = vae_dict.get(vae_metadata, None)
         if vae_from_metadata is not None:
-            return vae_from_metadata, "from user metadata"
+            return VaeResolution(vae_from_metadata, "from user metadata")
 
-    is_automatic = shared.opts.sd_vae in {"Automatic", "auto"}  # "auto" for people with old config
+    return VaeResolution(resolved=False)
 
+
+def resolve_vae_near_checkpoint(checkpoint_file) -> VaeResolution:
     vae_near_checkpoint = find_vae_near_checkpoint(checkpoint_file)
     if vae_near_checkpoint is not None and (shared.opts.sd_vae_as_default or is_automatic):
-        return vae_near_checkpoint, 'found near the checkpoint'
+        return VaeResolution(vae_near_checkpoint, 'found near the checkpoint')
 
-    if shared.opts.sd_vae == "None":
-        return None, None
+    return VaeResolution(resolved=False)
 
-    vae_from_options = vae_dict.get(shared.opts.sd_vae, None)
-    if vae_from_options is not None:
-        return vae_from_options, 'specified in settings'
 
-    if not is_automatic:
-        print(f"Couldn't find VAE named {shared.opts.sd_vae}; using None instead")
+def resolve_vae(checkpoint_file) -> VaeResolution:
+    if shared.cmd_opts.vae_path is not None:
+        return VaeResolution(shared.cmd_opts.vae_path, 'from commandline argument')
 
-    return None, None
+    if shared.opts.sd_vae_overrides_per_model_preferences and not is_automatic():
+        return resolve_vae_from_setting()
+
+    res = resolve_vae_from_user_metadata(checkpoint_file)
+    if res.resolved:
+        return res
+
+    res = resolve_vae_near_checkpoint(checkpoint_file)
+    if res.resolved:
+        return res
+
+    res = resolve_vae_from_setting()
+
+    return res
 
 
 def load_vae_dict(filename, map_location):
@@ -201,7 +240,7 @@ def reload_vae_weights(sd_model=None, vae_file=unspecified):
     checkpoint_file = checkpoint_info.filename
 
     if vae_file == unspecified:
-        vae_file, vae_source = resolve_vae(checkpoint_file)
+        vae_file, vae_source = resolve_vae(checkpoint_file).tuple()
     else:
         vae_source = "from function argument"
 
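A hypothetical walk through the resolution order the refactor above establishes (this is a sketch, not webui code): command line beats everything, then the global setting when it is set to override, then per-model user metadata, then a VAE file sitting next to the checkpoint, then the setting as a fallback. The first_resolved helper below is invented for illustration:

from dataclasses import dataclass

@dataclass
class VaeResolution:
    vae: str = None
    source: str = None
    resolved: bool = True

    def tuple(self):
        return self.vae, self.source

def first_resolved(*resolvers):
    # try each resolver in priority order, stop at the first definitive answer
    for resolve in resolvers:
        res = resolve()
        if res.resolved:
            return res
    return VaeResolution()

res = first_resolved(
    lambda: VaeResolution(resolved=False),                           # no user metadata
    lambda: VaeResolution('sd.vae.pt', 'found near the checkpoint'), # file next to checkpoint
)
assert res.tuple() == ('sd.vae.pt', 'found near the checkpoint')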
modules/shared.py
@@ -422,6 +422,7 @@ options_templates.update(options_section(('face-restoration', "Face restoration"
 }))
 
 options_templates.update(options_section(('system', "System"), {
+    "auto_launch_browser": OptionInfo("Local", "Automatically open webui in browser on startup", gr.Radio, lambda: {"choices": ["Disable", "Local", "Remote"]}),
     "show_warnings": OptionInfo(False, "Show warnings in console.").needs_reload_ui(),
     "show_gradio_deprecation_warnings": OptionInfo(True, "Show gradio deprecation warnings in console.").needs_reload_ui(),
     "memmon_poll_rate": OptionInfo(8, "VRAM usage polls per second during generation.", gr.Slider, {"minimum": 0, "maximum": 40, "step": 1}).info("0 = disable"),
@@ -481,7 +482,7 @@ For img2img, VAE is used to process user's input image before the sampling, and
 """),
     "sd_vae_checkpoint_cache": OptionInfo(0, "VAE Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
     "sd_vae": OptionInfo("Automatic", "SD VAE", gr.Dropdown, lambda: {"choices": shared_items.sd_vae_items()}, refresh=shared_items.refresh_vae_list).info("choose VAE model: Automatic = use one with same filename as checkpoint; None = use VAE from checkpoint"),
-    "sd_vae_as_default": OptionInfo(True, "Ignore selected VAE for stable diffusion checkpoints that have their own .vae.pt next to them"),
+    "sd_vae_overrides_per_model_preferences": OptionInfo(True, "Selected VAE overrides per-model preferences").info("you can set per-model VAE either by editing user metadata for checkpoints, or by making the VAE have same name as checkpoint"),
     "auto_vae_precision": OptionInfo(True, "Automatically revert VAE to 32-bit floats").info("triggers when a tensor with NaNs is produced in VAE; disabling the option in this case will result in a black square image"),
     "sd_vae_encode_method": OptionInfo("Full", "VAE type for encode", gr.Radio, {"choices": ["Full", "TAESD"]}).info("method to encode image to latent (use in img2img, hires-fix or inpaint mask)"),
     "sd_vae_decode_method": OptionInfo("Full", "VAE type for decode", gr.Radio, {"choices": ["Full", "TAESD"]}).info("method to decode latent to image"),
@@ -610,14 +611,14 @@ options_templates.update(options_section(('sampler-params', "Sampler parameters"
     "eta_ddim": OptionInfo(0.0, "Eta for DDIM", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}).info("noise multiplier; higher = more unpredictable results"),
     "eta_ancestral": OptionInfo(1.0, "Eta for ancestral samplers", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}).info("noise multiplier; applies to Euler a and other samplers that have a in them"),
     "ddim_discretize": OptionInfo('uniform', "img2img DDIM discretize", gr.Radio, {"choices": ['uniform', 'quad']}),
-    's_churn': OptionInfo(0.0, "sigma churn", gr.Slider, {"minimum": 0.0, "maximum": 100.0, "step": 0.01}),
+    's_churn': OptionInfo(0.0, "sigma churn", gr.Slider, {"minimum": 0.0, "maximum": 100.0, "step": 0.01}).info('amount of stochasticity; only applies to Euler, Heun, and DPM2'),
-    's_tmin': OptionInfo(0.0, "sigma tmin", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
+    's_tmin': OptionInfo(0.0, "sigma tmin", gr.Slider, {"minimum": 0.0, "maximum": 10.0, "step": 0.01}).info('enable stochasticity; start value of the sigma range; only applies to Euler, Heun, and DPM2'),
-    's_tmax': OptionInfo(0.0, "sigma tmax", gr.Slider, {"minimum": 0.0, "maximum": 999.0, "step": 0.01}).info("0 = inf"),
+    's_tmax': OptionInfo(0.0, "sigma tmax", gr.Slider, {"minimum": 0.0, "maximum": 999.0, "step": 0.01}).info("0 = inf; end value of the sigma range; only applies to Euler, Heun, and DPM2"),
-    's_noise': OptionInfo(1.0, "sigma noise", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
+    's_noise': OptionInfo(1.0, "sigma noise", gr.Slider, {"minimum": 0.0, "maximum": 1.1, "step": 0.001}).info('amount of additional noise to counteract loss of detail during sampling; only applies to Euler, Heun, and DPM2'),
-    'k_sched_type': OptionInfo("Automatic", "scheduler type", gr.Dropdown, {"choices": ["Automatic", "karras", "exponential", "polyexponential"]}).info("lets you override the noise schedule for k-diffusion samplers; choosing Automatic disables the three parameters below"),
+    'k_sched_type': OptionInfo("Automatic", "Scheduler type", gr.Dropdown, {"choices": ["Automatic", "karras", "exponential", "polyexponential"]}).info("lets you override the noise schedule for k-diffusion samplers; choosing Automatic disables the three parameters below"),
     'sigma_min': OptionInfo(0.0, "sigma min", gr.Number).info("0 = default (~0.03); minimum noise strength for k-diffusion noise scheduler"),
-    'sigma_max': OptionInfo(0.0, "sigma max", gr.Number).info("0 = default (~14.6); maximum noise strength for k-diffusion noise schedule"),
+    'sigma_max': OptionInfo(0.0, "sigma max", gr.Number).info("0 = default (~14.6); maximum noise strength for k-diffusion noise scheduler"),
-    'rho': OptionInfo(0.0, "rho", gr.Number).info("0 = default (7 for karras, 1 for polyexponential); higher values result in a more steep noise schedule (decreases faster)"),
+    'rho': OptionInfo(0.0, "rho", gr.Number).info("0 = default (7 for karras, 1 for polyexponential); higher values result in a steeper noise schedule (decreases faster)"),
     'eta_noise_seed_delta': OptionInfo(0, "Eta noise seed delta", gr.Number, {"precision": 0}).info("ENSD; does not improve anything, just produces different results for ancestral samplers - only useful for reproducing images"),
     'always_discard_next_to_last_sigma': OptionInfo(False, "Always discard next-to-last sigma").link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/6044"),
     'uni_pc_variant': OptionInfo("bh1", "UniPC variant", gr.Radio, {"choices": ["bh1", "bh2", "vary_coeff"]}),
@@ -735,6 +736,10 @@ class Options:
         with open(filename, "r", encoding="utf8") as file:
             self.data = json.load(file)
 
+        # 1.6.0 VAE defaults
+        if self.data.get('sd_vae_as_default') is not None and self.data.get('sd_vae_overrides_per_model_preferences') is None:
+            self.data['sd_vae_overrides_per_model_preferences'] = not self.data.get('sd_vae_as_default')
+
         # 1.1.1 quicksettings list migration
         if self.data.get('quicksettings') is not None and self.data.get('quicksettings_list') is None:
             self.data['quicksettings_list'] = [i.strip() for i in self.data.get('quicksettings').split(',')]
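A sketch of the 1.6.0 settings migration above: the old boolean sd_vae_as_default ("ignore selected VAE when a .vae.pt sits next to the checkpoint") maps to the negation of the new option, since letting the per-checkpoint file win means the selected VAE does not override per-model preferences:

data = {'sd_vae_as_default': True}  # old config value, made up for illustration

if data.get('sd_vae_as_default') is not None and data.get('sd_vae_overrides_per_model_preferences') is None:
    data['sd_vae_overrides_per_model_preferences'] = not data.get('sd_vae_as_default')

assert data['sd_vae_overrides_per_model_preferences'] is False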
modules/txt2img.py
@@ -1,7 +1,7 @@
 from contextlib import closing
 
 import modules.scripts
-from modules import sd_samplers, processing
+from modules import processing
 from modules.generation_parameters_copypaste import create_override_settings_dict
 from modules.shared import opts, cmd_opts
 import modules.shared as shared
@@ -9,7 +9,7 @@ from modules.ui import plaintext_to_html
 import gradio as gr
 
 
-def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, hr_checkpoint_name: str, hr_sampler_index: int, hr_prompt: str, hr_negative_prompt, override_settings_texts, request: gr.Request, *args):
+def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, steps: int, sampler_name: str, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, hr_checkpoint_name: str, hr_sampler_name: str, hr_prompt: str, hr_negative_prompt, override_settings_texts, request: gr.Request, *args):
     override_settings = create_override_settings_dict(override_settings_texts)
 
     p = processing.StableDiffusionProcessingTxt2Img(
@@ -25,7 +25,7 @@ def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, step
         seed_resize_from_h=seed_resize_from_h,
         seed_resize_from_w=seed_resize_from_w,
         seed_enable_extras=seed_enable_extras,
-        sampler_name=sd_samplers.samplers[sampler_index].name,
+        sampler_name=sampler_name,
         batch_size=batch_size,
         n_iter=n_iter,
         steps=steps,

@@ -42,7 +42,7 @@ def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, step
         hr_resize_x=hr_resize_x,
         hr_resize_y=hr_resize_y,
         hr_checkpoint_name=None if hr_checkpoint_name == 'Use same checkpoint' else hr_checkpoint_name,
-        hr_sampler_name=sd_samplers.samplers_for_img2img[hr_sampler_index - 1].name if hr_sampler_index != 0 else None,
+        hr_sampler_name=hr_sampler_name,
         hr_prompt=hr_prompt,
         hr_negative_prompt=hr_negative_prompt,
         override_settings=override_settings,
modules/ui.py
@@ -13,7 +13,7 @@ from PIL import Image, PngImagePlugin  # noqa: F401
 from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call
 
 from modules import gradio_extensons  # noqa: F401
-from modules import sd_hijack, sd_models, script_callbacks, ui_extensions, deepbooru, extra_networks, ui_common, ui_postprocessing, progress, ui_loadsave, errors, shared_items, ui_settings, timer, sysinfo, ui_checkpoint_merger, ui_prompt_styles, scripts
+from modules import sd_hijack, sd_models, script_callbacks, ui_extensions, deepbooru, extra_networks, ui_common, ui_postprocessing, progress, ui_loadsave, errors, shared_items, ui_settings, timer, sysinfo, ui_checkpoint_merger, ui_prompt_styles, scripts, sd_samplers
 from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML
 from modules.paths import script_path
 from modules.ui_common import create_refresh_button

@@ -29,7 +29,6 @@ import modules.shared as shared
 import modules.images
 from modules import prompt_parser
 from modules.sd_hijack import model_hijack
-from modules.sd_samplers import samplers, samplers_for_img2img
 from modules.generation_parameters_copypaste import image_from_url_text
 
 create_setting_component = ui_settings.create_setting_component

@@ -41,6 +40,9 @@ warnings.filterwarnings("default" if opts.show_gradio_deprecation_warnings else
 mimetypes.init()
 mimetypes.add_type('application/javascript', '.js')
 
+# Likewise, add explicit content-type header for certain missing image types
+mimetypes.add_type('image/webp', '.webp')
+
 if not cmd_opts.share and not cmd_opts.listen:
     # fix gradio phoning home
     gradio.utils.version_check = lambda: None
@@ -357,14 +359,14 @@ def create_output_panel(tabname, outdir):
 def create_sampler_and_steps_selection(choices, tabname):
     if opts.samplers_in_dropdown:
         with FormRow(elem_id=f"sampler_selection_{tabname}"):
-            sampler_index = gr.Dropdown(label='Sampling method', elem_id=f"{tabname}_sampling", choices=[x.name for x in choices], value=choices[0].name, type="index")
+            sampler_name = gr.Dropdown(label='Sampling method', elem_id=f"{tabname}_sampling", choices=choices, value=choices[0])
             steps = gr.Slider(minimum=1, maximum=150, step=1, elem_id=f"{tabname}_steps", label="Sampling steps", value=20)
     else:
         with FormGroup(elem_id=f"sampler_selection_{tabname}"):
             steps = gr.Slider(minimum=1, maximum=150, step=1, elem_id=f"{tabname}_steps", label="Sampling steps", value=20)
-            sampler_index = gr.Radio(label='Sampling method', elem_id=f"{tabname}_sampling", choices=[x.name for x in choices], value=choices[0].name, type="index")
+            sampler_name = gr.Radio(label='Sampling method', elem_id=f"{tabname}_sampling", choices=choices, value=choices[0])
 
-    return steps, sampler_index
+    return steps, sampler_name
 
 
 def ordered_ui_categories():
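The change above swaps index-based sampler selection for name-based: the dropdown now holds plain strings, so the selected value survives reordering of the sampler list and round-trips through infotexts cleanly. A tiny sketch of the difference, with an assumed list of visible names:

samplers = ['Euler a', 'DPM++ 2M Karras', 'DDIM']  # assumed visible sampler names

def make_processing_kwargs(sampler_name: str):
    assert sampler_name in samplers, f"unknown sampler {sampler_name!r}"
    return {'sampler_name': sampler_name}

# reordering `samplers` no longer changes what a stored value means,
# which an integer index could not guarantee
assert make_processing_kwargs('DDIM') == {'sampler_name': 'DDIM'}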
@ -405,13 +407,13 @@ def create_ui():
|
|||||||
extra_tabs = gr.Tabs(elem_id="txt2img_extra_tabs")
|
extra_tabs = gr.Tabs(elem_id="txt2img_extra_tabs")
|
||||||
extra_tabs.__enter__()
|
extra_tabs.__enter__()
|
||||||
|
|
||||||
with gr.Tab("Generation", id="txt2img_generation") as txt2img_generation_tab, gr.Row().style(equal_height=False):
|
with gr.Tab("Generation", id="txt2img_generation") as txt2img_generation_tab, gr.Row(equal_height=False):
|
||||||
with gr.Column(variant='compact', elem_id="txt2img_settings"):
|
with gr.Column(variant='compact', elem_id="txt2img_settings"):
|
||||||
scripts.scripts_txt2img.prepare_ui()
|
scripts.scripts_txt2img.prepare_ui()
|
||||||
|
|
||||||
for category in ordered_ui_categories():
|
for category in ordered_ui_categories():
|
||||||
if category == "sampler":
|
if category == "sampler":
|
||||||
steps, sampler_index = create_sampler_and_steps_selection(samplers, "txt2img")
|
steps, sampler_name = create_sampler_and_steps_selection(sd_samplers.visible_sampler_names(), "txt2img")
|
||||||
|
|
||||||
elif category == "dimensions":
|
elif category == "dimensions":
|
||||||
with FormRow():
|
with FormRow():
|
||||||
@ -457,7 +459,7 @@ def create_ui():
|
|||||||
hr_checkpoint_name = gr.Dropdown(label='Hires checkpoint', elem_id="hr_checkpoint", choices=["Use same checkpoint"] + modules.sd_models.checkpoint_tiles(use_short=True), value="Use same checkpoint")
|
hr_checkpoint_name = gr.Dropdown(label='Hires checkpoint', elem_id="hr_checkpoint", choices=["Use same checkpoint"] + modules.sd_models.checkpoint_tiles(use_short=True), value="Use same checkpoint")
|
||||||
create_refresh_button(hr_checkpoint_name, modules.sd_models.list_models, lambda: {"choices": ["Use same checkpoint"] + modules.sd_models.checkpoint_tiles(use_short=True)}, "hr_checkpoint_refresh")
|
create_refresh_button(hr_checkpoint_name, modules.sd_models.list_models, lambda: {"choices": ["Use same checkpoint"] + modules.sd_models.checkpoint_tiles(use_short=True)}, "hr_checkpoint_refresh")
|
||||||
|
|
||||||
hr_sampler_index = gr.Dropdown(label='Hires sampling method', elem_id="hr_sampler", choices=["Use same sampler"] + [x.name for x in samplers_for_img2img], value="Use same sampler", type="index")
|
hr_sampler_name = gr.Dropdown(label='Hires sampling method', elem_id="hr_sampler", choices=["Use same sampler"] + sd_samplers.visible_sampler_names(), value="Use same sampler")
|
||||||
|
|
||||||
with FormRow(elem_id="txt2img_hires_fix_row4", variant="compact", visible=opts.hires_fix_show_prompts) as hr_prompts_container:
|
with FormRow(elem_id="txt2img_hires_fix_row4", variant="compact", visible=opts.hires_fix_show_prompts) as hr_prompts_container:
|
||||||
with gr.Column(scale=80):
|
with gr.Column(scale=80):
|
||||||
@ -517,7 +519,7 @@ def create_ui():
|
|||||||
toprow.negative_prompt,
|
toprow.negative_prompt,
|
||||||
toprow.ui_styles.dropdown,
|
toprow.ui_styles.dropdown,
|
||||||
steps,
|
steps,
|
||||||
sampler_index,
|
sampler_name,
|
||||||
restore_faces,
|
restore_faces,
|
||||||
tiling,
|
tiling,
|
||||||
batch_count,
|
batch_count,
|
||||||
@ -535,7 +537,7 @@ def create_ui():
|
|||||||
hr_resize_x,
|
hr_resize_x,
|
||||||
hr_resize_y,
|
hr_resize_y,
|
||||||
hr_checkpoint_name,
|
hr_checkpoint_name,
|
||||||
hr_sampler_index,
|
hr_sampler_name,
|
||||||
hr_prompt,
|
hr_prompt,
|
||||||
hr_negative_prompt,
|
hr_negative_prompt,
|
||||||
override_settings,
|
override_settings,
|
||||||
@ -580,7 +582,7 @@ def create_ui():
|
|||||||
(toprow.prompt, "Prompt"),
|
(toprow.prompt, "Prompt"),
|
||||||
(toprow.negative_prompt, "Negative prompt"),
|
(toprow.negative_prompt, "Negative prompt"),
|
||||||
(steps, "Steps"),
|
(steps, "Steps"),
|
||||||
(sampler_index, "Sampler"),
|
(sampler_name, "Sampler"),
|
||||||
(restore_faces, "Face restoration"),
|
(restore_faces, "Face restoration"),
|
||||||
(cfg_scale, "CFG scale"),
|
(cfg_scale, "CFG scale"),
|
||||||
(seed, "Seed"),
|
(seed, "Seed"),
|
||||||
@@ -602,7 +604,7 @@ def create_ui():
            (hr_resize_x, "Hires resize-1"),
            (hr_resize_y, "Hires resize-2"),
            (hr_checkpoint_name, "Hires checkpoint"),
-            (hr_sampler_index, "Hires sampler"),
+            (hr_sampler_name, "Hires sampler"),
            (hr_sampler_container, lambda d: gr.update(visible=True) if d.get("Hires sampler", "Use same sampler") != "Use same sampler" or d.get("Hires checkpoint", "Use same checkpoint") != "Use same checkpoint" else gr.update()),
            (hr_prompt, "Hires prompt"),
            (hr_negative_prompt, "Hires negative prompt"),
@@ -618,7 +620,7 @@ def create_ui():
            toprow.prompt,
            toprow.negative_prompt,
            steps,
-            sampler_index,
+            sampler_name,
            cfg_scale,
            seed,
            width,
@@ -741,7 +743,7 @@ def create_ui():
 
        for category in ordered_ui_categories():
            if category == "sampler":
-                steps, sampler_index = create_sampler_and_steps_selection(samplers_for_img2img, "img2img")
+                steps, sampler_name = create_sampler_and_steps_selection(sd_samplers.visible_sampler_names(), "img2img")
 
            elif category == "dimensions":
                with FormRow():
@@ -873,7 +875,7 @@ def create_ui():
                init_img_inpaint,
                init_mask_inpaint,
                steps,
-                sampler_index,
+                sampler_name,
                mask_blur,
                mask_alpha,
                inpainting_fill,
@@ -969,7 +971,7 @@ def create_ui():
            (toprow.prompt, "Prompt"),
            (toprow.negative_prompt, "Negative prompt"),
            (steps, "Steps"),
-            (sampler_index, "Sampler"),
+            (sampler_name, "Sampler"),
            (restore_faces, "Face restoration"),
            (cfg_scale, "CFG scale"),
            (image_cfg_scale, "Image CFG scale"),
modules/ui_extra_networks_checkpoints_user_metadata.py

@@ -1,6 +1,6 @@
 import gradio as gr
 
-from modules import ui_extra_networks_user_metadata, sd_vae
+from modules import ui_extra_networks_user_metadata, sd_vae, shared
 from modules.ui_common import create_refresh_button
 
 
@@ -18,6 +18,10 @@ class CheckpointUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor):
 
        self.write_user_metadata(name, user_metadata)
 
+    def update_vae(self, name):
+        if name == shared.sd_model.sd_checkpoint_info.name_for_extra:
+            sd_vae.reload_vae_weights()
+
    def put_values_into_components(self, name):
        user_metadata = self.get_user_metadata(name)
        values = super().put_values_into_components(name)
@@ -58,3 +62,5 @@ class CheckpointUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor):
        ]
 
        self.setup_save_handler(self.button_save, self.save_user_metadata, edited_components)
+
+        self.button_save.click(fn=self.update_vae, inputs=[self.edit_name_input])
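The new update_vae hook reloads VAE weights only when the checkpoint being edited is the one currently loaded, so saving metadata for an inactive checkpoint stays cheap. A minimal sketch of the same guard-then-reload pattern, with stand-in names (VaeReloader, loaded_name, and reload_vae are illustrative placeholders, not webui APIs):

class VaeReloader:
    """Reload expensive state only when the edited item is the active one."""
    def __init__(self, loaded_name, reload_vae):
        self.loaded_name = loaded_name   # name of the currently loaded checkpoint
        self.reload_vae = reload_vae     # callback that actually reloads weights

    def update_vae(self, name):
        if name == self.loaded_name:     # no-op for checkpoints that aren't active
            self.reload_vae()

calls = []
r = VaeReloader("sd-v1-5", lambda: calls.append("reloaded"))
r.update_vae("other-model")   # ignored
r.update_vae("sd-v1-5")       # triggers a reload
assert calls == ["reloaded"]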
webui.py (18 lines changed)
@@ -211,7 +211,7 @@ def configure_sigint_handler():
 def configure_opts_onchange():
    shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights()), call=False)
    shared.opts.onchange("sd_vae", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
-    shared.opts.onchange("sd_vae_as_default", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
+    shared.opts.onchange("sd_vae_overrides_per_model_preferences", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
    shared.opts.onchange("temp_dir", ui_tempdir.on_tmpdir_changed)
    shared.opts.onchange("gradio_theme", shared.reload_gradio_theme)
    shared.opts.onchange("cross_attention_optimization", wrap_queued_call(lambda: modules.sd_hijack.model_hijack.redo_hijack(shared.sd_model)), call=False)
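The rename from sd_vae_as_default to sd_vae_overrides_per_model_preferences keeps the same reload-on-change wiring. For readers unfamiliar with the pattern, here is a toy onchange registry in the same spirit (a sketch, not the webui's actual Options implementation; call=False mirrors the idea of registering a handler without firing it at registration time):

class Options:
    def __init__(self):
        self._values = {}
        self._handlers = {}

    def onchange(self, key, handler, call=True):
        self._handlers[key] = handler
        if call:                          # optionally fire once at registration time
            handler()

    def set(self, key, value):
        changed = self._values.get(key) != value
        self._values[key] = value
        if changed and key in self._handlers:
            self._handlers[key]()         # notify only after the value actually changes

opts = Options()
opts.onchange("sd_vae_overrides_per_model_preferences", lambda: print("reloading VAE weights"), call=False)
opts.set("sd_vae_overrides_per_model_preferences", True)   # prints once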
@@ -341,6 +341,7 @@ def api_only():
    setup_middleware(app)
    api = create_api(app)
 
+    modules.script_callbacks.before_ui_callback()
    modules.script_callbacks.app_started_callback(None, app)
 
    print(f"Startup time: {startup_timer.summary()}.")
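Adding before_ui_callback() to api_only() gives extensions the same "before UI" lifecycle hook in API-only mode that they already get when the full UI starts; the ordering relative to app_started is the part that matters. A simplified sketch of such a two-phase callback sequence (not the actual script_callbacks module):

before_ui_callbacks = []
app_started_callbacks = []

def on_before_ui(fn):
    before_ui_callbacks.append(fn)

def on_app_started(fn):
    app_started_callbacks.append(fn)

def start(app):
    for fn in before_ui_callbacks:     # extensions prepare state first...
        fn()
    for fn in app_started_callbacks:   # ...then see the running app
        fn(app)

order = []
on_before_ui(lambda: order.append("before_ui"))
on_app_started(lambda app: order.append("app_started"))
start(app=None)
assert order == ["before_ui", "app_started"]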
@@ -371,6 +372,13 @@ def webui():
 
        gradio_auth_creds = list(get_gradio_auth_creds()) or None
 
+        auto_launch_browser = False
+        if os.getenv('SD_WEBUI_RESTARTING') != '1':
+            if shared.opts.auto_launch_browser == "Remote" or cmd_opts.autolaunch:
+                auto_launch_browser = True
+            elif shared.opts.auto_launch_browser == "Local":
+                auto_launch_browser = not any([cmd_opts.listen, cmd_opts.share, cmd_opts.ngrok])
+
        app, local_url, share_url = shared.demo.launch(
            share=cmd_opts.share,
            server_name=server_name,
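The new auto_launch_browser block centralizes the "should we open a browser?" decision that was previously inlined in the launch() call: never after a UI restart, always for --autolaunch or the "Remote" setting, and for "Local" only when the server is not exposed via --listen, --share, or ngrok. Extracted as a pure function for illustration (the function and argument names are stand-ins, not webui APIs):

def should_launch_browser(setting, autolaunch, listen, share, ngrok, restarting):
    """Mirror of the decision logic above, as a testable pure function."""
    if restarting:                     # never pop a browser on a UI reload
        return False
    if setting == "Remote" or autolaunch:
        return True
    if setting == "Local":             # only when serving this machine alone
        return not any([listen, share, ngrok])
    return False                       # e.g. an explicit "never launch" setting

assert should_launch_browser("Local", False, False, False, False, restarting=False)
assert not should_launch_browser("Local", False, True, False, False, restarting=False)
assert not should_launch_browser("Remote", False, False, False, False, restarting=True)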
@@ -380,7 +388,7 @@ def webui():
            ssl_verify=cmd_opts.disable_tls_verify,
            debug=cmd_opts.gradio_debug,
            auth=gradio_auth_creds,
-            inbrowser=cmd_opts.autolaunch and os.getenv('SD_WEBUI_RESTARTING') != '1',
+            inbrowser=auto_launch_browser,
            prevent_thread_lock=True,
            allowed_paths=cmd_opts.gradio_allowed_path,
            app_kwargs={
@@ -390,9 +398,6 @@ def webui():
            root_path=f"/{cmd_opts.subpath}" if cmd_opts.subpath else "",
        )
 
-        # after initial launch, disable --autolaunch for subsequent restarts
-        cmd_opts.autolaunch = False
-
        startup_timer.record("gradio launch")
 
        # gradio uses a very open CORS policy via app.user_middleware, which makes it possible for
@@ -437,6 +442,9 @@ def webui():
            shared.demo.close()
            break
 
+        # disable auto launch webui in browser for subsequent UI Reload
+        os.environ.setdefault('SD_WEBUI_RESTARTING', '1')
+
        print('Restarting UI...')
        shared.demo.close()
        time.sleep(0.5)
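Setting SD_WEBUI_RESTARTING via os.environ.setdefault before the UI reload replaces the old trick of mutating cmd_opts.autolaunch in place: the environment variable persists for the life of the process, so every later pass through the launch loop sees it and suppresses the browser. A condensed sketch of the loop shape (server_loop and launch_ui are placeholders; assumes the variable is unset when the process starts):

import os

def launch_ui():
    restarting = os.getenv('SD_WEBUI_RESTARTING') == '1'
    print("launch, open browser:", not restarting)

def server_loop(reloads=2):
    for _ in range(reloads + 1):
        launch_ui()
        # mark all later launches as restarts; setdefault leaves an
        # externally provided value untouched
        os.environ.setdefault('SD_WEBUI_RESTARTING', '1')

server_loop()   # opens the browser once, then suppresses it on reloads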