diff --git a/extensions-builtin/extra-options-section/scripts/extra_options_section.py b/extensions-builtin/extra-options-section/scripts/extra_options_section.py index 7bb0a1bb7..588b64d23 100644 --- a/extensions-builtin/extra-options-section/scripts/extra_options_section.py +++ b/extensions-builtin/extra-options-section/scripts/extra_options_section.py @@ -1,5 +1,7 @@ +import math + import gradio as gr -from modules import scripts, shared, ui_components, ui_settings +from modules import scripts, shared, ui_components, ui_settings, generation_parameters_copypaste from modules.ui_components import FormColumn @@ -19,18 +21,37 @@ class ExtraOptionsSection(scripts.Script): def ui(self, is_img2img): self.comps = [] self.setting_names = [] + self.infotext_fields = [] + + mapping = {k: v for v, k in generation_parameters_copypaste.infotext_to_setting_name_mapping} with gr.Blocks() as interface: - with gr.Accordion("Options", open=False) if shared.opts.extra_options_accordion and shared.opts.extra_options else gr.Group(), gr.Row(): - for setting_name in shared.opts.extra_options: - with FormColumn(): - comp = ui_settings.create_setting_component(setting_name) + with gr.Accordion("Options", open=False) if shared.opts.extra_options_accordion and shared.opts.extra_options else gr.Group(): - self.comps.append(comp) - self.setting_names.append(setting_name) + row_count = math.ceil(len(shared.opts.extra_options) / shared.opts.extra_options_cols) + + for row in range(row_count): + with gr.Row(): + for col in range(shared.opts.extra_options_cols): + index = row * shared.opts.extra_options_cols + col + if index >= len(shared.opts.extra_options): + break + + setting_name = shared.opts.extra_options[index] + + with FormColumn(): + comp = ui_settings.create_setting_component(setting_name) + + self.comps.append(comp) + self.setting_names.append(setting_name) + + setting_infotext_name = mapping.get(setting_name) + if setting_infotext_name is not None: + self.infotext_fields.append((comp, setting_infotext_name)) def get_settings_values(): - return [ui_settings.get_value_for_setting(key) for key in self.setting_names] + res = [ui_settings.get_value_for_setting(key) for key in self.setting_names] + return res[0] if len(res) == 1 else res interface.load(fn=get_settings_values, inputs=[], outputs=self.comps, queue=False, show_progress=False) @@ -44,5 +65,8 @@ class ExtraOptionsSection(scripts.Script): shared.options_templates.update(shared.options_section(('ui', "User interface"), { "extra_options": shared.OptionInfo([], "Options in main UI", ui_components.DropdownMulti, lambda: {"choices": list(shared.opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that also appear in txt2img/img2img interfaces").needs_reload_ui(), - "extra_options_accordion": shared.OptionInfo(False, "Place options in main UI into an accordion").needs_restart() + "extra_options_cols": shared.OptionInfo(1, "Options in main UI - number of columns", gr.Number, {"precision": 0}).needs_reload_ui(), + "extra_options_accordion": shared.OptionInfo(False, "Options in main UI - place into an accordion").needs_reload_ui() })) + + diff --git a/javascript/imageviewer.js b/javascript/imageviewer.js index 677e95c1b..c21d396ee 100644 --- a/javascript/imageviewer.js +++ b/javascript/imageviewer.js @@ -136,6 +136,11 @@ function setupImageForLightbox(e) { var event = isFirefox ? 'mousedown' : 'click'; e.addEventListener(event, function(evt) { + if (evt.button == 1) { + open(evt.target.src); + evt.preventDefault(); + return; + } if (!opts.js_modal_lightbox || evt.button != 0) return; modalZoomSet(gradioApp().getElementById('modalImage'), opts.js_modal_lightbox_initially_zoomed); diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py index 6711ca16e..20e30b53b 100644 --- a/modules/generation_parameters_copypaste.py +++ b/modules/generation_parameters_copypaste.py @@ -416,10 +416,15 @@ def connect_paste(button, paste_fields, input_comp, override_settings_component, return res if override_settings_component is not None: + already_handled_fields = {key: 1 for _, key in paste_fields} + def paste_settings(params): vals = {} for param_name, setting_name in infotext_to_setting_name_mapping: + if param_name in already_handled_fields: + continue + v = params.get(param_name, None) if v is None: continue diff --git a/modules/img2img.py b/modules/img2img.py index d8e1c534c..e06ac1d6e 100644 --- a/modules/img2img.py +++ b/modules/img2img.py @@ -6,7 +6,7 @@ import numpy as np from PIL import Image, ImageOps, ImageFilter, ImageEnhance, UnidentifiedImageError import gradio as gr -from modules import sd_samplers, images as imgutil +from modules import images as imgutil from modules.generation_parameters_copypaste import create_override_settings_dict, parse_generation_parameters from modules.processing import Processed, StableDiffusionProcessingImg2Img, process_images from modules.shared import opts, state @@ -116,7 +116,7 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=Fal process_images(p) -def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, selected_scale_tab: int, height: int, width: int, scale_by: float, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, img2img_batch_use_png_info: bool, img2img_batch_png_info_props: list, img2img_batch_png_info_dir: str, request: gr.Request, *args): +def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_name: str, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, selected_scale_tab: int, height: int, width: int, scale_by: float, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, img2img_batch_use_png_info: bool, img2img_batch_png_info_props: list, img2img_batch_png_info_dir: str, request: gr.Request, *args): override_settings = create_override_settings_dict(override_settings_texts) is_batch = mode == 5 @@ -172,7 +172,7 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s seed_resize_from_h=seed_resize_from_h, seed_resize_from_w=seed_resize_from_w, seed_enable_extras=seed_enable_extras, - sampler_name=sd_samplers.samplers_for_img2img[sampler_index].name, + sampler_name=sampler_name, batch_size=batch_size, n_iter=n_iter, steps=steps, diff --git a/modules/launch_utils.py b/modules/launch_utils.py index f77b577a5..5be30a183 100644 --- a/modules/launch_utils.py +++ b/modules/launch_utils.py @@ -139,6 +139,27 @@ def check_run_python(code: str) -> bool: return result.returncode == 0 +def git_fix_workspace(dir, name): + run(f'"{git}" -C "{dir}" fetch --refetch --no-auto-gc', f"Fetching all contents for {name}", f"Couldn't fetch {name}", live=True) + run(f'"{git}" -C "{dir}" gc --aggressive --prune=now', f"Pruning {name}", f"Couldn't prune {name}", live=True) + return + + +def run_git(dir, name, command, desc=None, errdesc=None, custom_env=None, live: bool = default_command_live, autofix=True): + try: + return run(f'"{git}" -C "{dir}" {command}', desc=desc, errdesc=errdesc, custom_env=custom_env, live=live) + except RuntimeError: + pass + + if not autofix: + return None + + print(f"{errdesc}, attempting autofix...") + git_fix_workspace(dir, name) + + return run(f'"{git}" -C "{dir}" {command}', desc=desc, errdesc=errdesc, custom_env=custom_env, live=live) + + def git_clone(url, dir, name, commithash=None): # TODO clone into temporary dir and move if successful @@ -146,12 +167,14 @@ def git_clone(url, dir, name, commithash=None): if commithash is None: return - current_hash = run(f'"{git}" -C "{dir}" rev-parse HEAD', None, f"Couldn't determine {name}'s hash: {commithash}", live=False).strip() + current_hash = run_git(dir, name, 'rev-parse HEAD', None, f"Couldn't determine {name}'s hash: {commithash}", live=False).strip() if current_hash == commithash: return - run(f'"{git}" -C "{dir}" fetch', f"Fetching updates for {name}...", f"Couldn't fetch {name}") - run(f'"{git}" -C "{dir}" checkout {commithash}', f"Checking out commit for {name} with hash: {commithash}...", f"Couldn't checkout commit {commithash} for {name}", live=True) + run_git('fetch', f"Fetching updates for {name}...", f"Couldn't fetch {name}") + + run_git('checkout', f"Checking out commit for {name} with hash: {commithash}...", f"Couldn't checkout commit {commithash} for {name}", live=True) + return run(f'"{git}" clone "{url}" "{dir}"', f"Cloning {name} into {dir}...", f"Couldn't clone {name}", live=True) diff --git a/modules/processing.py b/modules/processing.py index ec66fd8e2..b635cc745 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -1119,9 +1119,6 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): img2img_sampler_name = self.hr_sampler_name or self.sampler_name - if self.sampler_name in ['PLMS', 'UniPC']: # PLMS/UniPC do not support img2img so we just silently switch to DDIM - img2img_sampler_name = 'DDIM' - self.sampler = sd_samplers.create_sampler(img2img_sampler_name, self.sd_model) if self.latent_scale_mode is not None: diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 9ad981998..46652fbd3 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -5,7 +5,7 @@ from types import MethodType from modules import devices, sd_hijack_optimizations, shared, script_callbacks, errors, sd_unet from modules.hypernetworks import hypernetwork from modules.shared import cmd_opts -from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr, sd_hijack_inpainting +from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr import ldm.modules.attention import ldm.modules.diffusionmodules.model @@ -34,8 +34,6 @@ ldm.modules.diffusionmodules.model.print = shared.ldm_print ldm.util.print = shared.ldm_print ldm.models.diffusion.ddpm.print = shared.ldm_print -sd_hijack_inpainting.do_inpainting_hijack() - optimizers = [] current_optimizer: sd_hijack_optimizations.SdOptimization = None diff --git a/modules/sd_hijack_inpainting.py b/modules/sd_hijack_inpainting.py deleted file mode 100644 index 2d44b8566..000000000 --- a/modules/sd_hijack_inpainting.py +++ /dev/null @@ -1,95 +0,0 @@ -import torch - -import ldm.models.diffusion.ddpm -import ldm.models.diffusion.ddim -import ldm.models.diffusion.plms - -from ldm.models.diffusion.ddim import noise_like -from ldm.models.diffusion.sampling_util import norm_thresholding - - -@torch.no_grad() -def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, - temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, - unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None, dynamic_threshold=None): - b, *_, device = *x.shape, x.device - - def get_model_output(x, t): - if unconditional_conditioning is None or unconditional_guidance_scale == 1.: - e_t = self.model.apply_model(x, t, c) - else: - x_in = torch.cat([x] * 2) - t_in = torch.cat([t] * 2) - - if isinstance(c, dict): - assert isinstance(unconditional_conditioning, dict) - c_in = {} - for k in c: - if isinstance(c[k], list): - c_in[k] = [ - torch.cat([unconditional_conditioning[k][i], c[k][i]]) - for i in range(len(c[k])) - ] - else: - c_in[k] = torch.cat([unconditional_conditioning[k], c[k]]) - else: - c_in = torch.cat([unconditional_conditioning, c]) - - e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2) - e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) - - if score_corrector is not None: - assert self.model.parameterization == "eps" - e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs) - - return e_t - - alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas - alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev - sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas - sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas - - def get_x_prev_and_pred_x0(e_t, index): - # select parameters corresponding to the currently considered timestep - a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) - a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) - sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) - sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device) - - # current prediction for x_0 - pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() - if quantize_denoised: - pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) - if dynamic_threshold is not None: - pred_x0 = norm_thresholding(pred_x0, dynamic_threshold) - # direction pointing to x_t - dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t - noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature - if noise_dropout > 0.: - noise = torch.nn.functional.dropout(noise, p=noise_dropout) - x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise - return x_prev, pred_x0 - - e_t = get_model_output(x, t) - if len(old_eps) == 0: - # Pseudo Improved Euler (2nd order) - x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index) - e_t_next = get_model_output(x_prev, t_next) - e_t_prime = (e_t + e_t_next) / 2 - elif len(old_eps) == 1: - # 2nd order Pseudo Linear Multistep (Adams-Bashforth) - e_t_prime = (3 * e_t - old_eps[-1]) / 2 - elif len(old_eps) == 2: - # 3nd order Pseudo Linear Multistep (Adams-Bashforth) - e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12 - elif len(old_eps) >= 3: - # 4nd order Pseudo Linear Multistep (Adams-Bashforth) - e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24 - - x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index) - - return x_prev, pred_x0, e_t - - -def do_inpainting_hijack(): - ldm.models.diffusion.plms.PLMSSampler.p_sample_plms = p_sample_plms diff --git a/modules/sd_models.py b/modules/sd_models.py index 981aa93d7..a97af215c 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -372,7 +372,7 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer sd_vae.delete_base_vae() sd_vae.clear_loaded_vae() - vae_file, vae_source = sd_vae.resolve_vae(checkpoint_info.filename) + vae_file, vae_source = sd_vae.resolve_vae(checkpoint_info.filename).tuple() sd_vae.load_vae(model, vae_file, vae_source) timer.record("load VAE") @@ -715,6 +715,7 @@ def reload_model_weights(sd_model=None, info=None): print(f"Weights loaded in {timer.summary()}.") model_data.set_sd_model(sd_model) + sd_unet.apply_unet() return sd_model diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py index bea2684c4..45faae628 100644 --- a/modules/sd_samplers.py +++ b/modules/sd_samplers.py @@ -1,17 +1,18 @@ -from modules import sd_samplers_compvis, sd_samplers_kdiffusion, shared +from modules import sd_samplers_kdiffusion, sd_samplers_timesteps, shared # imports for functions that previously were here and are used by other modules from modules.sd_samplers_common import samples_to_image_grid, sample_to_image # noqa: F401 all_samplers = [ *sd_samplers_kdiffusion.samplers_data_k_diffusion, - *sd_samplers_compvis.samplers_data_compvis, + *sd_samplers_timesteps.samplers_data_timesteps, ] all_samplers_map = {x.name: x for x in all_samplers} samplers = [] samplers_for_img2img = [] samplers_map = {} +samplers_hidden = {} def find_sampler_config(name): @@ -38,13 +39,11 @@ def create_sampler(name, model): def set_samplers(): - global samplers, samplers_for_img2img + global samplers, samplers_for_img2img, samplers_hidden - hidden = set(shared.opts.hide_samplers) - hidden_img2img = set(shared.opts.hide_samplers + ['PLMS', 'UniPC']) - - samplers = [x for x in all_samplers if x.name not in hidden] - samplers_for_img2img = [x for x in all_samplers if x.name not in hidden_img2img] + samplers_hidden = set(shared.opts.hide_samplers) + samplers = all_samplers + samplers_for_img2img = all_samplers samplers_map.clear() for sampler in all_samplers: @@ -53,4 +52,8 @@ def set_samplers(): samplers_map[alias.lower()] = sampler.name +def visible_sampler_names(): + return [x.name for x in samplers if x.name not in samplers_hidden] + + set_samplers() diff --git a/modules/sd_samplers_cfg_denoiser.py b/modules/sd_samplers_cfg_denoiser.py new file mode 100644 index 000000000..d826222cd --- /dev/null +++ b/modules/sd_samplers_cfg_denoiser.py @@ -0,0 +1,203 @@ +import torch +from modules import prompt_parser, devices, sd_samplers_common + +from modules.shared import opts, state +import modules.shared as shared +from modules.script_callbacks import CFGDenoiserParams, cfg_denoiser_callback +from modules.script_callbacks import CFGDenoisedParams, cfg_denoised_callback +from modules.script_callbacks import AfterCFGCallbackParams, cfg_after_cfg_callback + + +def catenate_conds(conds): + if not isinstance(conds[0], dict): + return torch.cat(conds) + + return {key: torch.cat([x[key] for x in conds]) for key in conds[0].keys()} + + +def subscript_cond(cond, a, b): + if not isinstance(cond, dict): + return cond[a:b] + + return {key: vec[a:b] for key, vec in cond.items()} + + +def pad_cond(tensor, repeats, empty): + if not isinstance(tensor, dict): + return torch.cat([tensor, empty.repeat((tensor.shape[0], repeats, 1))], axis=1) + + tensor['crossattn'] = pad_cond(tensor['crossattn'], repeats, empty) + return tensor + + +class CFGDenoiser(torch.nn.Module): + """ + Classifier free guidance denoiser. A wrapper for stable diffusion model (specifically for unet) + that can take a noisy picture and produce a noise-free picture using two guidances (prompts) + instead of one. Originally, the second prompt is just an empty string, but we use non-empty + negative prompt. + """ + + def __init__(self, model, sampler): + super().__init__() + self.inner_model = model + self.mask = None + self.nmask = None + self.init_latent = None + self.step = 0 + self.image_cfg_scale = None + self.padded_cond_uncond = False + self.sampler = sampler + + def combine_denoised(self, x_out, conds_list, uncond, cond_scale): + denoised_uncond = x_out[-uncond.shape[0]:] + denoised = torch.clone(denoised_uncond) + + for i, conds in enumerate(conds_list): + for cond_index, weight in conds: + denoised[i] += (x_out[cond_index] - denoised_uncond[i]) * (weight * cond_scale) + + return denoised + + def combine_denoised_for_edit_model(self, x_out, cond_scale): + out_cond, out_img_cond, out_uncond = x_out.chunk(3) + denoised = out_uncond + cond_scale * (out_cond - out_img_cond) + self.image_cfg_scale * (out_img_cond - out_uncond) + + return denoised + + def get_pred_x0(self, x_in, x_out, sigma): + return x_out + + def forward(self, x, sigma, uncond, cond, cond_scale, s_min_uncond, image_cond): + if state.interrupted or state.skipped: + raise sd_samplers_common.InterruptedException + + # at self.image_cfg_scale == 1.0 produced results for edit model are the same as with normal sampling, + # so is_edit_model is set to False to support AND composition. + is_edit_model = shared.sd_model.cond_stage_key == "edit" and self.image_cfg_scale is not None and self.image_cfg_scale != 1.0 + + conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step) + uncond = prompt_parser.reconstruct_cond_batch(uncond, self.step) + + assert not is_edit_model or all(len(conds) == 1 for conds in conds_list), "AND is not supported for InstructPix2Pix checkpoint (unless using Image CFG scale = 1.0)" + + if self.mask is not None: + x = self.init_latent * self.mask + self.nmask * x + + batch_size = len(conds_list) + repeats = [len(conds_list[i]) for i in range(batch_size)] + + if shared.sd_model.model.conditioning_key == "crossattn-adm": + image_uncond = torch.zeros_like(image_cond) + make_condition_dict = lambda c_crossattn, c_adm: {"c_crossattn": [c_crossattn], "c_adm": c_adm} + else: + image_uncond = image_cond + if isinstance(uncond, dict): + make_condition_dict = lambda c_crossattn, c_concat: {**c_crossattn, "c_concat": [c_concat]} + else: + make_condition_dict = lambda c_crossattn, c_concat: {"c_crossattn": [c_crossattn], "c_concat": [c_concat]} + + if not is_edit_model: + x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x]) + sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma]) + image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_uncond]) + else: + x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x] + [x]) + sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma] + [sigma]) + image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_uncond] + [torch.zeros_like(self.init_latent)]) + + denoiser_params = CFGDenoiserParams(x_in, image_cond_in, sigma_in, state.sampling_step, state.sampling_steps, tensor, uncond) + cfg_denoiser_callback(denoiser_params) + x_in = denoiser_params.x + image_cond_in = denoiser_params.image_cond + sigma_in = denoiser_params.sigma + tensor = denoiser_params.text_cond + uncond = denoiser_params.text_uncond + skip_uncond = False + + # alternating uncond allows for higher thresholds without the quality loss normally expected from raising it + if self.step % 2 and s_min_uncond > 0 and sigma[0] < s_min_uncond and not is_edit_model: + skip_uncond = True + x_in = x_in[:-batch_size] + sigma_in = sigma_in[:-batch_size] + + self.padded_cond_uncond = False + if shared.opts.pad_cond_uncond and tensor.shape[1] != uncond.shape[1]: + empty = shared.sd_model.cond_stage_model_empty_prompt + num_repeats = (tensor.shape[1] - uncond.shape[1]) // empty.shape[1] + + if num_repeats < 0: + tensor = pad_cond(tensor, -num_repeats, empty) + self.padded_cond_uncond = True + elif num_repeats > 0: + uncond = pad_cond(uncond, num_repeats, empty) + self.padded_cond_uncond = True + + if tensor.shape[1] == uncond.shape[1] or skip_uncond: + if is_edit_model: + cond_in = catenate_conds([tensor, uncond, uncond]) + elif skip_uncond: + cond_in = tensor + else: + cond_in = catenate_conds([tensor, uncond]) + + if shared.batch_cond_uncond: + x_out = self.inner_model(x_in, sigma_in, cond=make_condition_dict(cond_in, image_cond_in)) + else: + x_out = torch.zeros_like(x_in) + for batch_offset in range(0, x_out.shape[0], batch_size): + a = batch_offset + b = a + batch_size + x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict(subscript_cond(cond_in, a, b), image_cond_in[a:b])) + else: + x_out = torch.zeros_like(x_in) + batch_size = batch_size*2 if shared.batch_cond_uncond else batch_size + for batch_offset in range(0, tensor.shape[0], batch_size): + a = batch_offset + b = min(a + batch_size, tensor.shape[0]) + + if not is_edit_model: + c_crossattn = subscript_cond(tensor, a, b) + else: + c_crossattn = torch.cat([tensor[a:b]], uncond) + + x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict(c_crossattn, image_cond_in[a:b])) + + if not skip_uncond: + x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond=make_condition_dict(uncond, image_cond_in[-uncond.shape[0]:])) + + denoised_image_indexes = [x[0][0] for x in conds_list] + if skip_uncond: + fake_uncond = torch.cat([x_out[i:i+1] for i in denoised_image_indexes]) + x_out = torch.cat([x_out, fake_uncond]) # we skipped uncond denoising, so we put cond-denoised image to where the uncond-denoised image should be + + denoised_params = CFGDenoisedParams(x_out, state.sampling_step, state.sampling_steps, self.inner_model) + cfg_denoised_callback(denoised_params) + + devices.test_for_nans(x_out, "unet") + + if is_edit_model: + denoised = self.combine_denoised_for_edit_model(x_out, cond_scale) + elif skip_uncond: + denoised = self.combine_denoised(x_out, conds_list, uncond, 1.0) + else: + denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale) + + self.sampler.last_latent = self.get_pred_x0(torch.cat([x_in[i:i + 1] for i in denoised_image_indexes]), torch.cat([x_out[i:i + 1] for i in denoised_image_indexes]), sigma) + + if opts.live_preview_content == "Prompt": + preview = self.sampler.last_latent + elif opts.live_preview_content == "Negative prompt": + preview = self.get_pred_x0(x_in[-uncond.shape[0]:], x_out[-uncond.shape[0]:], sigma) + else: + preview = self.get_pred_x0(torch.cat([x_in[i:i+1] for i in denoised_image_indexes]), torch.cat([denoised[i:i+1] for i in denoised_image_indexes]), sigma) + + sd_samplers_common.store_latent(preview) + + after_cfg_callback_params = AfterCFGCallbackParams(denoised, state.sampling_step, state.sampling_steps) + cfg_after_cfg_callback(after_cfg_callback_params) + denoised = after_cfg_callback_params.x + + self.step += 1 + return denoised + diff --git a/modules/sd_samplers_common.py b/modules/sd_samplers_common.py index 92bf0ca1d..15f279707 100644 --- a/modules/sd_samplers_common.py +++ b/modules/sd_samplers_common.py @@ -1,9 +1,11 @@ -from collections import namedtuple +import inspect +from collections import namedtuple, deque import numpy as np import torch from PIL import Image from modules import devices, images, sd_vae_approx, sd_samplers, sd_vae_taesd, shared, sd_models from modules.shared import opts, state +import k_diffusion.sampling SamplerData = namedtuple('SamplerData', ['name', 'constructor', 'aliases', 'options']) @@ -155,3 +157,137 @@ def apply_refiner(sampler): return True +class TorchHijack: + def __init__(self, sampler_noises): + # Using a deque to efficiently receive the sampler_noises in the same order as the previous index-based + # implementation. + self.sampler_noises = deque(sampler_noises) + + def __getattr__(self, item): + if item == 'randn_like': + return self.randn_like + + if hasattr(torch, item): + return getattr(torch, item) + + raise AttributeError(f"'{type(self).__name__}' object has no attribute '{item}'") + + def randn_like(self, x): + if self.sampler_noises: + noise = self.sampler_noises.popleft() + if noise.shape == x.shape: + return noise + + return devices.randn_like(x) + + +class Sampler: + def __init__(self, funcname): + self.funcname = funcname + self.func = funcname + self.extra_params = [] + self.sampler_noises = None + self.stop_at = None + self.eta = None + self.config = None # set by the function calling the constructor + self.last_latent = None + self.s_min_uncond = None + self.s_churn = 0.0 + self.s_tmin = 0.0 + self.s_tmax = float('inf') + self.s_noise = 1.0 + + self.eta_option_field = 'eta_ancestral' + self.eta_infotext_field = 'Eta' + + self.conditioning_key = shared.sd_model.model.conditioning_key + + self.model_wrap = None + self.model_wrap_cfg = None + + def callback_state(self, d): + step = d['i'] + + if self.stop_at is not None and step > self.stop_at: + raise InterruptedException + + state.sampling_step = step + shared.total_tqdm.update() + + def launch_sampling(self, steps, func): + state.sampling_steps = steps + state.sampling_step = 0 + + try: + return func() + except RecursionError: + print( + 'Encountered RecursionError during sampling, returning last latent. ' + 'rho >5 with a polyexponential scheduler may cause this error. ' + 'You should try to use a smaller rho value instead.' + ) + return self.last_latent + except InterruptedException: + return self.last_latent + + def number_of_needed_noises(self, p): + return p.steps + + def initialize(self, p) -> dict: + self.model_wrap_cfg.mask = p.mask if hasattr(p, 'mask') else None + self.model_wrap_cfg.nmask = p.nmask if hasattr(p, 'nmask') else None + self.model_wrap_cfg.step = 0 + self.model_wrap_cfg.image_cfg_scale = getattr(p, 'image_cfg_scale', None) + self.eta = p.eta if p.eta is not None else getattr(opts, self.eta_option_field, 0.0) + self.s_min_uncond = getattr(p, 's_min_uncond', 0.0) + + k_diffusion.sampling.torch = TorchHijack(self.sampler_noises if self.sampler_noises is not None else []) + + extra_params_kwargs = {} + for param_name in self.extra_params: + if hasattr(p, param_name) and param_name in inspect.signature(self.func).parameters: + extra_params_kwargs[param_name] = getattr(p, param_name) + + if 'eta' in inspect.signature(self.func).parameters: + if self.eta != 1.0: + p.extra_generation_params[self.eta_infotext_field] = self.eta + + extra_params_kwargs['eta'] = self.eta + + if len(self.extra_params) > 0: + s_churn = getattr(opts, 's_churn', p.s_churn) + s_tmin = getattr(opts, 's_tmin', p.s_tmin) + s_tmax = getattr(opts, 's_tmax', p.s_tmax) or self.s_tmax # 0 = inf + s_noise = getattr(opts, 's_noise', p.s_noise) + + if s_churn != self.s_churn: + extra_params_kwargs['s_churn'] = s_churn + p.s_churn = s_churn + p.extra_generation_params['Sigma churn'] = s_churn + if s_tmin != self.s_tmin: + extra_params_kwargs['s_tmin'] = s_tmin + p.s_tmin = s_tmin + p.extra_generation_params['Sigma tmin'] = s_tmin + if s_tmax != self.s_tmax: + extra_params_kwargs['s_tmax'] = s_tmax + p.s_tmax = s_tmax + p.extra_generation_params['Sigma tmax'] = s_tmax + if s_noise != self.s_noise: + extra_params_kwargs['s_noise'] = s_noise + p.s_noise = s_noise + p.extra_generation_params['Sigma noise'] = s_noise + + return extra_params_kwargs + + def create_noise_sampler(self, x, sigmas, p): + """For DPM++ SDE: manually create noise sampler to enable deterministic results across different batch sizes""" + if shared.opts.no_dpmpp_sde_batch_determinism: + return None + + from k_diffusion.sampling import BrownianTreeNoiseSampler + sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max() + current_iter_seeds = p.all_seeds[p.iteration * p.batch_size:(p.iteration + 1) * p.batch_size] + return BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=current_iter_seeds) + + + diff --git a/modules/sd_samplers_compvis.py b/modules/sd_samplers_compvis.py index 2eeec18a4..e69de29bb 100644 --- a/modules/sd_samplers_compvis.py +++ b/modules/sd_samplers_compvis.py @@ -1,232 +0,0 @@ -import math -import ldm.models.diffusion.ddim -import ldm.models.diffusion.plms - -import numpy as np -import torch - -from modules.shared import state -from modules import sd_samplers_common, prompt_parser, shared -import modules.models.diffusion.uni_pc - - -samplers_data_compvis = [ - sd_samplers_common.SamplerData('DDIM', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.ddim.DDIMSampler, model), [], {"default_eta_is_0": True, "uses_ensd": True, "no_sdxl": True}), - sd_samplers_common.SamplerData('PLMS', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.plms.PLMSSampler, model), [], {"no_sdxl": True}), - sd_samplers_common.SamplerData('UniPC', lambda model: VanillaStableDiffusionSampler(modules.models.diffusion.uni_pc.UniPCSampler, model), [], {"no_sdxl": True}), -] - - -class VanillaStableDiffusionSampler: - def __init__(self, constructor, sd_model): - self.p = None - self.sampler = constructor(shared.sd_model) - self.is_ddim = hasattr(self.sampler, 'p_sample_ddim') - self.is_plms = hasattr(self.sampler, 'p_sample_plms') - self.is_unipc = isinstance(self.sampler, modules.models.diffusion.uni_pc.UniPCSampler) - self.orig_p_sample_ddim = None - if self.is_plms: - self.orig_p_sample_ddim = self.sampler.p_sample_plms - elif self.is_ddim: - self.orig_p_sample_ddim = self.sampler.p_sample_ddim - self.mask = None - self.nmask = None - self.init_latent = None - self.sampler_noises = None - self.steps = None - self.step = 0 - self.stop_at = None - self.eta = None - self.config = None - self.last_latent = None - - self.conditioning_key = sd_model.model.conditioning_key - - def number_of_needed_noises(self, p): - return 0 - - def launch_sampling(self, steps, func): - self.steps = steps - state.sampling_steps = steps - state.sampling_step = 0 - - try: - return func() - except sd_samplers_common.InterruptedException: - return self.last_latent - - def p_sample_ddim_hook(self, x_dec, cond, ts, unconditional_conditioning, *args, **kwargs): - x_dec, ts, cond, unconditional_conditioning = self.before_sample(x_dec, ts, cond, unconditional_conditioning) - - res = self.orig_p_sample_ddim(x_dec, cond, ts, *args, unconditional_conditioning=unconditional_conditioning, **kwargs) - - x_dec, ts, cond, unconditional_conditioning, res = self.after_sample(x_dec, ts, cond, unconditional_conditioning, res) - - return res - - def update_inner_model(self): - self.sampler.model = shared.sd_model - - def before_sample(self, x, ts, cond, unconditional_conditioning): - if state.interrupted or state.skipped: - raise sd_samplers_common.InterruptedException - - if self.stop_at is not None and self.step > self.stop_at: - raise sd_samplers_common.InterruptedException - - # Have to unwrap the inpainting conditioning here to perform pre-processing - image_conditioning = None - uc_image_conditioning = None - if isinstance(cond, dict): - if self.conditioning_key == "crossattn-adm": - image_conditioning = cond["c_adm"] - uc_image_conditioning = unconditional_conditioning["c_adm"] - else: - image_conditioning = cond["c_concat"][0] - cond = cond["c_crossattn"][0] - unconditional_conditioning = unconditional_conditioning["c_crossattn"][0] - - conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step) - unconditional_conditioning = prompt_parser.reconstruct_cond_batch(unconditional_conditioning, self.step) - - assert all(len(conds) == 1 for conds in conds_list), 'composition via AND is not supported for DDIM/PLMS samplers' - cond = tensor - - # for DDIM, shapes must match, we can't just process cond and uncond independently; - # filling unconditional_conditioning with repeats of the last vector to match length is - # not 100% correct but should work well enough - if unconditional_conditioning.shape[1] < cond.shape[1]: - last_vector = unconditional_conditioning[:, -1:] - last_vector_repeated = last_vector.repeat([1, cond.shape[1] - unconditional_conditioning.shape[1], 1]) - unconditional_conditioning = torch.hstack([unconditional_conditioning, last_vector_repeated]) - elif unconditional_conditioning.shape[1] > cond.shape[1]: - unconditional_conditioning = unconditional_conditioning[:, :cond.shape[1]] - - if self.mask is not None: - img_orig = self.sampler.model.q_sample(self.init_latent, ts) - x = img_orig * self.mask + self.nmask * x - - # Wrap the image conditioning back up since the DDIM code can accept the dict directly. - # Note that they need to be lists because it just concatenates them later. - if image_conditioning is not None: - if self.conditioning_key == "crossattn-adm": - cond = {"c_adm": image_conditioning, "c_crossattn": [cond]} - unconditional_conditioning = {"c_adm": uc_image_conditioning, "c_crossattn": [unconditional_conditioning]} - else: - cond = {"c_concat": [image_conditioning], "c_crossattn": [cond]} - unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]} - - return x, ts, cond, unconditional_conditioning - - def update_step(self, last_latent): - if self.mask is not None: - self.last_latent = self.init_latent * self.mask + self.nmask * last_latent - else: - self.last_latent = last_latent - - sd_samplers_common.store_latent(self.last_latent) - - self.step += 1 - state.sampling_step = self.step - shared.total_tqdm.update() - - def after_sample(self, x, ts, cond, uncond, res): - if not self.is_unipc: - self.update_step(res[1]) - - return x, ts, cond, uncond, res - - def unipc_after_update(self, x, model_x): - self.update_step(x) - - def initialize(self, p): - self.p = p - - if self.is_ddim: - self.eta = p.eta if p.eta is not None else shared.opts.eta_ddim - else: - self.eta = 0.0 - - if self.eta != 0.0: - p.extra_generation_params["Eta DDIM"] = self.eta - - if self.is_unipc: - keys = [ - ('UniPC variant', 'uni_pc_variant'), - ('UniPC skip type', 'uni_pc_skip_type'), - ('UniPC order', 'uni_pc_order'), - ('UniPC lower order final', 'uni_pc_lower_order_final'), - ] - - for name, key in keys: - v = getattr(shared.opts, key) - if v != shared.opts.get_default(key): - p.extra_generation_params[name] = v - - for fieldname in ['p_sample_ddim', 'p_sample_plms']: - if hasattr(self.sampler, fieldname): - setattr(self.sampler, fieldname, self.p_sample_ddim_hook) - if self.is_unipc: - self.sampler.set_hooks(lambda x, t, c, u: self.before_sample(x, t, c, u), lambda x, t, c, u, r: self.after_sample(x, t, c, u, r), lambda x, mx: self.unipc_after_update(x, mx)) - - self.mask = p.mask if hasattr(p, 'mask') else None - self.nmask = p.nmask if hasattr(p, 'nmask') else None - - - def adjust_steps_if_invalid(self, p, num_steps): - if ((self.config.name == 'DDIM') and p.ddim_discretize == 'uniform') or (self.config.name == 'PLMS') or (self.config.name == 'UniPC'): - if self.config.name == 'UniPC' and num_steps < shared.opts.uni_pc_order: - num_steps = shared.opts.uni_pc_order - valid_step = 999 / (1000 // num_steps) - if valid_step == math.floor(valid_step): - return int(valid_step) + 1 - - return num_steps - - def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None): - steps, t_enc = sd_samplers_common.setup_img2img_steps(p, steps) - steps = self.adjust_steps_if_invalid(p, steps) - self.initialize(p) - - self.sampler.make_schedule(ddim_num_steps=steps, ddim_eta=self.eta, ddim_discretize=p.ddim_discretize, verbose=False) - x1 = self.sampler.stochastic_encode(x, torch.tensor([t_enc] * int(x.shape[0])).to(shared.device), noise=noise) - - self.init_latent = x - self.last_latent = x - self.step = 0 - - # Wrap the conditioning models with additional image conditioning for inpainting model - if image_conditioning is not None: - if self.conditioning_key == "crossattn-adm": - conditioning = {"c_adm": image_conditioning, "c_crossattn": [conditioning]} - unconditional_conditioning = {"c_adm": torch.zeros_like(image_conditioning), "c_crossattn": [unconditional_conditioning]} - else: - conditioning = {"c_concat": [image_conditioning], "c_crossattn": [conditioning]} - unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]} - - samples = self.launch_sampling(t_enc + 1, lambda: self.sampler.decode(x1, conditioning, t_enc, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning)) - - return samples - - def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning=None): - self.initialize(p) - - self.init_latent = None - self.last_latent = x - self.step = 0 - - steps = self.adjust_steps_if_invalid(p, steps or p.steps) - - # Wrap the conditioning models with additional image conditioning for inpainting model - # dummy_for_plms is needed because PLMS code checks the first item in the dict to have the right shape - if image_conditioning is not None: - if self.conditioning_key == "crossattn-adm": - conditioning = {"dummy_for_plms": np.zeros((conditioning.shape[0],)), "c_crossattn": [conditioning], "c_adm": image_conditioning} - unconditional_conditioning = {"c_crossattn": [unconditional_conditioning], "c_adm": torch.zeros_like(image_conditioning)} - else: - conditioning = {"dummy_for_plms": np.zeros((conditioning.shape[0],)), "c_crossattn": [conditioning], "c_concat": [image_conditioning]} - unconditional_conditioning = {"c_crossattn": [unconditional_conditioning], "c_concat": [image_conditioning]} - - samples_ddim = self.launch_sampling(steps, lambda: self.sampler.sample(S=steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=self.eta)[0]) - - return samples_ddim diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py index 46da0a97f..3ff4b6345 100644 --- a/modules/sd_samplers_kdiffusion.py +++ b/modules/sd_samplers_kdiffusion.py @@ -1,17 +1,16 @@ -from collections import deque import torch import inspect import k_diffusion.sampling -from modules import prompt_parser, devices, sd_samplers_common, sd_samplers_extra +from modules import sd_samplers_common, sd_samplers_extra, sd_samplers_cfg_denoiser -from modules.processing import StableDiffusionProcessing -from modules.shared import opts, state +from modules.shared import opts import modules.shared as shared -from modules.script_callbacks import CFGDenoiserParams, cfg_denoiser_callback -from modules.script_callbacks import CFGDenoisedParams, cfg_denoised_callback -from modules.script_callbacks import AfterCFGCallbackParams, cfg_after_cfg_callback samplers_k_diffusion = [ + ('DPM++ 2M Karras', 'sample_dpmpp_2m', ['k_dpmpp_2m_ka'], {'scheduler': 'karras'}), + ('DPM++ SDE Karras', 'sample_dpmpp_sde', ['k_dpmpp_sde_ka'], {'scheduler': 'karras', "second_order": True, "brownian_noise": True}), + ('DPM++ 2M SDE Exponential', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_exp'], {'scheduler': 'exponential', "brownian_noise": True}), + ('DPM++ 2M SDE Karras', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_ka'], {'scheduler': 'karras', "brownian_noise": True}), ('Euler a', 'sample_euler_ancestral', ['k_euler_a', 'k_euler_ancestral'], {"uses_ensd": True}), ('Euler', 'sample_euler', ['k_euler'], {}), ('LMS', 'sample_lms', ['k_lms'], {}), @@ -28,10 +27,6 @@ samplers_k_diffusion = [ ('DPM2 Karras', 'sample_dpm_2', ['k_dpm_2_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True, "uses_ensd": True, "second_order": True}), ('DPM2 a Karras', 'sample_dpm_2_ancestral', ['k_dpm_2_a_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True, "uses_ensd": True, "second_order": True}), ('DPM++ 2S a Karras', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a_ka'], {'scheduler': 'karras', "uses_ensd": True, "second_order": True}), - ('DPM++ 2M Karras', 'sample_dpmpp_2m', ['k_dpmpp_2m_ka'], {'scheduler': 'karras'}), - ('DPM++ SDE Karras', 'sample_dpmpp_sde', ['k_dpmpp_sde_ka'], {'scheduler': 'karras', "second_order": True, "brownian_noise": True}), - ('DPM++ 2M SDE Karras', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_ka'], {'scheduler': 'karras', "brownian_noise": True}), - ('DPM++ 2M SDE Exponential', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_exp'], {'scheduler': 'exponential', "brownian_noise": True}), ('Restart', sd_samplers_extra.restart_sampler, ['restart'], {'scheduler': 'karras'}), ] @@ -57,342 +52,17 @@ k_diffusion_scheduler = { } -def catenate_conds(conds): - if not isinstance(conds[0], dict): - return torch.cat(conds) - - return {key: torch.cat([x[key] for x in conds]) for key in conds[0].keys()} - - -def subscript_cond(cond, a, b): - if not isinstance(cond, dict): - return cond[a:b] - - return {key: vec[a:b] for key, vec in cond.items()} - - -def pad_cond(tensor, repeats, empty): - if not isinstance(tensor, dict): - return torch.cat([tensor, empty.repeat((tensor.shape[0], repeats, 1))], axis=1) - - tensor['crossattn'] = pad_cond(tensor['crossattn'], repeats, empty) - return tensor - - -class CFGDenoiser(torch.nn.Module): - """ - Classifier free guidance denoiser. A wrapper for stable diffusion model (specifically for unet) - that can take a noisy picture and produce a noise-free picture using two guidances (prompts) - instead of one. Originally, the second prompt is just an empty string, but we use non-empty - negative prompt. - """ - - def __init__(self, sampler): - super().__init__() - self.sampler = sampler - self.model_wrap = None - self.mask = None - self.nmask = None - self.init_latent = None - self.steps = None - self.step = 0 - self.image_cfg_scale = None - self.padded_cond_uncond = False - self.p = None - - @property - def inner_model(self): - if self.model_wrap is None: - denoiser = k_diffusion.external.CompVisVDenoiser if shared.sd_model.parameterization == "v" else k_diffusion.external.CompVisDenoiser - self.model_wrap = denoiser(shared.sd_model, quantize=shared.opts.enable_quantization) - - return self.model_wrap - - def combine_denoised(self, x_out, conds_list, uncond, cond_scale): - denoised_uncond = x_out[-uncond.shape[0]:] - denoised = torch.clone(denoised_uncond) - - for i, conds in enumerate(conds_list): - for cond_index, weight in conds: - denoised[i] += (x_out[cond_index] - denoised_uncond[i]) * (weight * cond_scale) - - return denoised - - def combine_denoised_for_edit_model(self, x_out, cond_scale): - out_cond, out_img_cond, out_uncond = x_out.chunk(3) - denoised = out_uncond + cond_scale * (out_cond - out_img_cond) + self.image_cfg_scale * (out_img_cond - out_uncond) - - return denoised - - def update_inner_model(self): - self.model_wrap = None - - c, uc = self.p.get_conds() - self.sampler.sampler_extra_args['cond'] = c - self.sampler.sampler_extra_args['uncond'] = uc - - def forward(self, x, sigma, uncond, cond, cond_scale, s_min_uncond, image_cond): - if state.interrupted or state.skipped: - raise sd_samplers_common.InterruptedException - - if sd_samplers_common.apply_refiner(self): - cond = self.sampler.sampler_extra_args['cond'] - uncond = self.sampler.sampler_extra_args['uncond'] - - # at self.image_cfg_scale == 1.0 produced results for edit model are the same as with normal sampling, - # so is_edit_model is set to False to support AND composition. - is_edit_model = shared.sd_model.cond_stage_key == "edit" and self.image_cfg_scale is not None and self.image_cfg_scale != 1.0 - - conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step) - uncond = prompt_parser.reconstruct_cond_batch(uncond, self.step) - - assert not is_edit_model or all(len(conds) == 1 for conds in conds_list), "AND is not supported for InstructPix2Pix checkpoint (unless using Image CFG scale = 1.0)" - - batch_size = len(conds_list) - repeats = [len(conds_list[i]) for i in range(batch_size)] - - if shared.sd_model.model.conditioning_key == "crossattn-adm": - image_uncond = torch.zeros_like(image_cond) - make_condition_dict = lambda c_crossattn, c_adm: {"c_crossattn": [c_crossattn], "c_adm": c_adm} - else: - image_uncond = image_cond - if isinstance(uncond, dict): - make_condition_dict = lambda c_crossattn, c_concat: {**c_crossattn, "c_concat": [c_concat]} - else: - make_condition_dict = lambda c_crossattn, c_concat: {"c_crossattn": [c_crossattn], "c_concat": [c_concat]} - - if not is_edit_model: - x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x]) - sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma]) - image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_uncond]) - else: - x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x] + [x]) - sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma] + [sigma]) - image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_uncond] + [torch.zeros_like(self.init_latent)]) - - denoiser_params = CFGDenoiserParams(x_in, image_cond_in, sigma_in, state.sampling_step, state.sampling_steps, tensor, uncond) - cfg_denoiser_callback(denoiser_params) - x_in = denoiser_params.x - image_cond_in = denoiser_params.image_cond - sigma_in = denoiser_params.sigma - tensor = denoiser_params.text_cond - uncond = denoiser_params.text_uncond - skip_uncond = False - - # alternating uncond allows for higher thresholds without the quality loss normally expected from raising it - if self.step % 2 and s_min_uncond > 0 and sigma[0] < s_min_uncond and not is_edit_model: - skip_uncond = True - x_in = x_in[:-batch_size] - sigma_in = sigma_in[:-batch_size] - - self.padded_cond_uncond = False - if shared.opts.pad_cond_uncond and tensor.shape[1] != uncond.shape[1]: - empty = shared.sd_model.cond_stage_model_empty_prompt - num_repeats = (tensor.shape[1] - uncond.shape[1]) // empty.shape[1] - - if num_repeats < 0: - tensor = pad_cond(tensor, -num_repeats, empty) - self.padded_cond_uncond = True - elif num_repeats > 0: - uncond = pad_cond(uncond, num_repeats, empty) - self.padded_cond_uncond = True - - if tensor.shape[1] == uncond.shape[1] or skip_uncond: - if is_edit_model: - cond_in = catenate_conds([tensor, uncond, uncond]) - elif skip_uncond: - cond_in = tensor - else: - cond_in = catenate_conds([tensor, uncond]) - - if shared.batch_cond_uncond: - x_out = self.inner_model(x_in, sigma_in, cond=make_condition_dict(cond_in, image_cond_in)) - else: - x_out = torch.zeros_like(x_in) - for batch_offset in range(0, x_out.shape[0], batch_size): - a = batch_offset - b = a + batch_size - x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict(subscript_cond(cond_in, a, b), image_cond_in[a:b])) - else: - x_out = torch.zeros_like(x_in) - batch_size = batch_size*2 if shared.batch_cond_uncond else batch_size - for batch_offset in range(0, tensor.shape[0], batch_size): - a = batch_offset - b = min(a + batch_size, tensor.shape[0]) - - if not is_edit_model: - c_crossattn = subscript_cond(tensor, a, b) - else: - c_crossattn = torch.cat([tensor[a:b]], uncond) - - x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict(c_crossattn, image_cond_in[a:b])) - - if not skip_uncond: - x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond=make_condition_dict(uncond, image_cond_in[-uncond.shape[0]:])) - - denoised_image_indexes = [x[0][0] for x in conds_list] - if skip_uncond: - fake_uncond = torch.cat([x_out[i:i+1] for i in denoised_image_indexes]) - x_out = torch.cat([x_out, fake_uncond]) # we skipped uncond denoising, so we put cond-denoised image to where the uncond-denoised image should be - - denoised_params = CFGDenoisedParams(x_out, state.sampling_step, state.sampling_steps, self.inner_model) - cfg_denoised_callback(denoised_params) - - devices.test_for_nans(x_out, "unet") - - if opts.live_preview_content == "Prompt": - sd_samplers_common.store_latent(torch.cat([x_out[i:i+1] for i in denoised_image_indexes])) - elif opts.live_preview_content == "Negative prompt": - sd_samplers_common.store_latent(x_out[-uncond.shape[0]:]) - - if is_edit_model: - denoised = self.combine_denoised_for_edit_model(x_out, cond_scale) - elif skip_uncond: - denoised = self.combine_denoised(x_out, conds_list, uncond, 1.0) - else: - denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale) - - if self.mask is not None: - denoised = self.init_latent * self.mask + self.nmask * denoised - - after_cfg_callback_params = AfterCFGCallbackParams(denoised, state.sampling_step, state.sampling_steps) - cfg_after_cfg_callback(after_cfg_callback_params) - denoised = after_cfg_callback_params.x - - self.step += 1 - return denoised - - -class TorchHijack: - def __init__(self, sampler_noises): - # Using a deque to efficiently receive the sampler_noises in the same order as the previous index-based - # implementation. - self.sampler_noises = deque(sampler_noises) - - def __getattr__(self, item): - if item == 'randn_like': - return self.randn_like - - if hasattr(torch, item): - return getattr(torch, item) - - raise AttributeError(f"'{type(self).__name__}' object has no attribute '{item}'") - - def randn_like(self, x): - if self.sampler_noises: - noise = self.sampler_noises.popleft() - if noise.shape == x.shape: - return noise - - return devices.randn_like(x) - - -class KDiffusionSampler: +class KDiffusionSampler(sd_samplers_common.Sampler): def __init__(self, funcname, sd_model): - self.p = None - self.funcname = funcname - self.func = funcname if callable(funcname) else getattr(k_diffusion.sampling, self.funcname) + + super().__init__(funcname) + self.extra_params = sampler_extra_params.get(funcname, []) - self.sampler_extra_args = {} - self.model_wrap_cfg = CFGDenoiser(self) - self.model_wrap = self.model_wrap_cfg.inner_model - self.sampler_noises = None - self.stop_at = None - self.eta = None - self.config = None # set by the function calling the constructor - self.last_latent = None - self.s_min_uncond = None + self.func = funcname if callable(funcname) else getattr(k_diffusion.sampling, self.funcname) - # NOTE: These are also defined in the StableDiffusionProcessing class. - # They should have been here to begin with but we're going to - # leave that class __init__ signature alone. - self.s_churn = 0.0 - self.s_tmin = 0.0 - self.s_tmax = float('inf') - self.s_noise = 1.0 - - self.conditioning_key = sd_model.model.conditioning_key - - def callback_state(self, d): - step = d['i'] - latent = d["denoised"] - if opts.live_preview_content == "Combined": - sd_samplers_common.store_latent(latent) - self.last_latent = latent - - if self.stop_at is not None and step > self.stop_at: - raise sd_samplers_common.InterruptedException - - state.sampling_step = step - shared.total_tqdm.update() - - def launch_sampling(self, steps, func): - self.model_wrap_cfg.steps = steps - state.sampling_steps = steps - state.sampling_step = 0 - - try: - return func() - except RecursionError: - print( - 'Encountered RecursionError during sampling, returning last latent. ' - 'rho >5 with a polyexponential scheduler may cause this error. ' - 'You should try to use a smaller rho value instead.' - ) - return self.last_latent - except sd_samplers_common.InterruptedException: - return self.last_latent - - def number_of_needed_noises(self, p): - return p.steps - - def initialize(self, p: StableDiffusionProcessing): - self.p = p - self.model_wrap_cfg.p = p - self.model_wrap_cfg.mask = p.mask if hasattr(p, 'mask') else None - self.model_wrap_cfg.nmask = p.nmask if hasattr(p, 'nmask') else None - self.model_wrap_cfg.step = 0 - self.model_wrap_cfg.image_cfg_scale = getattr(p, 'image_cfg_scale', None) - self.eta = p.eta if p.eta is not None else opts.eta_ancestral - self.s_min_uncond = getattr(p, 's_min_uncond', 0.0) - - k_diffusion.sampling.torch = TorchHijack(self.sampler_noises if self.sampler_noises is not None else []) - - extra_params_kwargs = {} - for param_name in self.extra_params: - if hasattr(p, param_name) and param_name in inspect.signature(self.func).parameters: - extra_params_kwargs[param_name] = getattr(p, param_name) - - if 'eta' in inspect.signature(self.func).parameters: - if self.eta != 1.0: - p.extra_generation_params["Eta"] = self.eta - - extra_params_kwargs['eta'] = self.eta - - if len(self.extra_params) > 0: - s_churn = getattr(opts, 's_churn', p.s_churn) - s_tmin = getattr(opts, 's_tmin', p.s_tmin) - s_tmax = getattr(opts, 's_tmax', p.s_tmax) or self.s_tmax # 0 = inf - s_noise = getattr(opts, 's_noise', p.s_noise) - - if s_churn != self.s_churn: - extra_params_kwargs['s_churn'] = s_churn - p.s_churn = s_churn - p.extra_generation_params['Sigma churn'] = s_churn - if s_tmin != self.s_tmin: - extra_params_kwargs['s_tmin'] = s_tmin - p.s_tmin = s_tmin - p.extra_generation_params['Sigma tmin'] = s_tmin - if s_tmax != self.s_tmax: - extra_params_kwargs['s_tmax'] = s_tmax - p.s_tmax = s_tmax - p.extra_generation_params['Sigma tmax'] = s_tmax - if s_noise != self.s_noise: - extra_params_kwargs['s_noise'] = s_noise - p.s_noise = s_noise - p.extra_generation_params['Sigma noise'] = s_noise - - return extra_params_kwargs + denoiser = k_diffusion.external.CompVisVDenoiser if sd_model.parameterization == "v" else k_diffusion.external.CompVisDenoiser + self.model_wrap = denoiser(sd_model, quantize=shared.opts.enable_quantization) + self.model_wrap_cfg = sd_samplers_cfg_denoiser.CFGDenoiser(self.model_wrap, self) def get_sigmas(self, p, steps): discard_next_to_last_sigma = self.config is not None and self.config.options.get('discard_next_to_last_sigma', False) @@ -444,22 +114,12 @@ class KDiffusionSampler: return sigmas - def create_noise_sampler(self, x, sigmas, p): - """For DPM++ SDE: manually create noise sampler to enable deterministic results across different batch sizes""" - if shared.opts.no_dpmpp_sde_batch_determinism: - return None - - from k_diffusion.sampling import BrownianTreeNoiseSampler - sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max() - current_iter_seeds = p.all_seeds[p.iteration * p.batch_size:(p.iteration + 1) * p.batch_size] - return BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=current_iter_seeds) - def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None): steps, t_enc = sd_samplers_common.setup_img2img_steps(p, steps) sigmas = self.get_sigmas(p, steps) - sigma_sched = sigmas[steps - t_enc - 1:] + xi = x + noise * sigma_sched[0] extra_params_kwargs = self.initialize(p) @@ -508,12 +168,14 @@ class KDiffusionSampler: extra_params_kwargs = self.initialize(p) parameters = inspect.signature(self.func).parameters + if 'n' in parameters: + extra_params_kwargs['n'] = steps + if 'sigma_min' in parameters: extra_params_kwargs['sigma_min'] = self.model_wrap.sigmas[0].item() extra_params_kwargs['sigma_max'] = self.model_wrap.sigmas[-1].item() - if 'n' in parameters: - extra_params_kwargs['n'] = steps - else: + + if 'sigmas' in parameters: extra_params_kwargs['sigmas'] = sigmas if self.config.options.get('brownian_noise', False): @@ -535,3 +197,4 @@ class KDiffusionSampler: return samples + diff --git a/modules/sd_samplers_timesteps.py b/modules/sd_samplers_timesteps.py new file mode 100644 index 000000000..d89d0efb3 --- /dev/null +++ b/modules/sd_samplers_timesteps.py @@ -0,0 +1,147 @@ +import torch +import inspect +from modules import devices, sd_samplers_common, sd_samplers_timesteps_impl +from modules.sd_samplers_cfg_denoiser import CFGDenoiser + +from modules.shared import opts +import modules.shared as shared + +samplers_timesteps = [ + ('DDIM', sd_samplers_timesteps_impl.ddim, ['ddim'], {}), + ('PLMS', sd_samplers_timesteps_impl.plms, ['plms'], {}), + ('UniPC', sd_samplers_timesteps_impl.unipc, ['unipc'], {}), +] + + +samplers_data_timesteps = [ + sd_samplers_common.SamplerData(label, lambda model, funcname=funcname: CompVisSampler(funcname, model), aliases, options) + for label, funcname, aliases, options in samplers_timesteps +] + + +class CompVisTimestepsDenoiser(torch.nn.Module): + def __init__(self, model, *args, **kwargs): + super().__init__(*args, **kwargs) + self.inner_model = model + + def forward(self, input, timesteps, **kwargs): + return self.inner_model.apply_model(input, timesteps, **kwargs) + + +class CompVisTimestepsVDenoiser(torch.nn.Module): + def __init__(self, model, *args, **kwargs): + super().__init__(*args, **kwargs) + self.inner_model = model + + def predict_eps_from_z_and_v(self, x_t, t, v): + return self.inner_model.sqrt_alphas_cumprod[t.to(torch.int), None, None, None] * v + self.inner_model.sqrt_one_minus_alphas_cumprod[t.to(torch.int), None, None, None] * x_t + + def forward(self, input, timesteps, **kwargs): + model_output = self.inner_model.apply_model(input, timesteps, **kwargs) + e_t = self.predict_eps_from_z_and_v(input, timesteps, model_output) + return e_t + + +class CFGDenoiserTimesteps(CFGDenoiser): + + def __init__(self, model, sampler): + super().__init__(model, sampler) + + self.alphas = model.inner_model.alphas_cumprod + + def get_pred_x0(self, x_in, x_out, sigma): + ts = int(sigma.item()) + + s_in = x_in.new_ones([x_in.shape[0]]) + a_t = self.alphas[ts].item() * s_in + sqrt_one_minus_at = (1 - a_t).sqrt() + + pred_x0 = (x_in - sqrt_one_minus_at * x_out) / a_t.sqrt() + + return pred_x0 + + +class CompVisSampler(sd_samplers_common.Sampler): + def __init__(self, funcname, sd_model): + super().__init__(funcname) + + self.eta_option_field = 'eta_ddim' + self.eta_infotext_field = 'Eta DDIM' + + denoiser = CompVisTimestepsVDenoiser if sd_model.parameterization == "v" else CompVisTimestepsDenoiser + self.model_wrap = denoiser(sd_model) + self.model_wrap_cfg = CFGDenoiserTimesteps(self.model_wrap, self) + + def get_timesteps(self, p, steps): + discard_next_to_last_sigma = self.config is not None and self.config.options.get('discard_next_to_last_sigma', False) + if opts.always_discard_next_to_last_sigma and not discard_next_to_last_sigma: + discard_next_to_last_sigma = True + p.extra_generation_params["Discard penultimate sigma"] = True + + steps += 1 if discard_next_to_last_sigma else 0 + + timesteps = torch.clip(torch.asarray(list(range(0, 1000, 1000 // steps)), device=devices.device) + 1, 0, 999) + + return timesteps + + def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None): + steps, t_enc = sd_samplers_common.setup_img2img_steps(p, steps) + + timesteps = self.get_timesteps(p, steps) + timesteps_sched = timesteps[:t_enc] + + alphas_cumprod = shared.sd_model.alphas_cumprod + sqrt_alpha_cumprod = torch.sqrt(alphas_cumprod[timesteps[t_enc]]) + sqrt_one_minus_alpha_cumprod = torch.sqrt(1 - alphas_cumprod[timesteps[t_enc]]) + + xi = x * sqrt_alpha_cumprod + noise * sqrt_one_minus_alpha_cumprod + + extra_params_kwargs = self.initialize(p) + parameters = inspect.signature(self.func).parameters + + if 'timesteps' in parameters: + extra_params_kwargs['timesteps'] = timesteps_sched + if 'is_img2img' in parameters: + extra_params_kwargs['is_img2img'] = True + + self.model_wrap_cfg.init_latent = x + self.last_latent = x + extra_args = { + 'cond': conditioning, + 'image_cond': image_conditioning, + 'uncond': unconditional_conditioning, + 'cond_scale': p.cfg_scale, + 's_min_uncond': self.s_min_uncond + } + + samples = self.launch_sampling(t_enc + 1, lambda: self.func(self.model_wrap_cfg, xi, extra_args=extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs)) + + if self.model_wrap_cfg.padded_cond_uncond: + p.extra_generation_params["Pad conds"] = True + + return samples + + def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning=None): + steps = steps or p.steps + timesteps = self.get_timesteps(p, steps) + + extra_params_kwargs = self.initialize(p) + parameters = inspect.signature(self.func).parameters + + if 'timesteps' in parameters: + extra_params_kwargs['timesteps'] = timesteps + + self.last_latent = x + samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args={ + 'cond': conditioning, + 'image_cond': image_conditioning, + 'uncond': unconditional_conditioning, + 'cond_scale': p.cfg_scale, + 's_min_uncond': self.s_min_uncond + }, disable=False, callback=self.callback_state, **extra_params_kwargs)) + + if self.model_wrap_cfg.padded_cond_uncond: + p.extra_generation_params["Pad conds"] = True + + return samples + diff --git a/modules/sd_samplers_timesteps_impl.py b/modules/sd_samplers_timesteps_impl.py new file mode 100644 index 000000000..48d7e6491 --- /dev/null +++ b/modules/sd_samplers_timesteps_impl.py @@ -0,0 +1,135 @@ +import torch +import tqdm +import k_diffusion.sampling +import numpy as np + +from modules import shared +from modules.models.diffusion.uni_pc import uni_pc + + +@torch.no_grad() +def ddim(model, x, timesteps, extra_args=None, callback=None, disable=None, eta=0.0): + alphas_cumprod = model.inner_model.inner_model.alphas_cumprod + alphas = alphas_cumprod[timesteps] + alphas_prev = alphas_cumprod[torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))].to(torch.float64) + sqrt_one_minus_alphas = torch.sqrt(1 - alphas) + sigmas = eta * np.sqrt((1 - alphas_prev.cpu().numpy()) / (1 - alphas.cpu()) * (1 - alphas.cpu() / alphas_prev.cpu().numpy())) + + extra_args = {} if extra_args is None else extra_args + s_in = x.new_ones([x.shape[0]]) + for i in tqdm.trange(len(timesteps) - 1, disable=disable): + index = len(timesteps) - 1 - i + + e_t = model(x, timesteps[index].item() * s_in, **extra_args) + + a_t = alphas[index].item() * s_in + a_prev = alphas_prev[index].item() * s_in + sigma_t = sigmas[index].item() * s_in + sqrt_one_minus_at = sqrt_one_minus_alphas[index].item() * s_in + + pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() + dir_xt = (1. - a_prev - sigma_t ** 2).sqrt() * e_t + noise = sigma_t * k_diffusion.sampling.torch.randn_like(x) + x = a_prev.sqrt() * pred_x0 + dir_xt + noise + + if callback is not None: + callback({'x': x, 'i': i, 'sigma': 0, 'sigma_hat': 0, 'denoised': pred_x0}) + + return x + + +@torch.no_grad() +def plms(model, x, timesteps, extra_args=None, callback=None, disable=None): + alphas_cumprod = model.inner_model.inner_model.alphas_cumprod + alphas = alphas_cumprod[timesteps] + alphas_prev = alphas_cumprod[torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))].to(torch.float64) + sqrt_one_minus_alphas = torch.sqrt(1 - alphas) + + extra_args = {} if extra_args is None else extra_args + s_in = x.new_ones([x.shape[0]]) + old_eps = [] + + def get_x_prev_and_pred_x0(e_t, index): + # select parameters corresponding to the currently considered timestep + a_t = alphas[index].item() * s_in + a_prev = alphas_prev[index].item() * s_in + sqrt_one_minus_at = sqrt_one_minus_alphas[index].item() * s_in + + # current prediction for x_0 + pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() + + # direction pointing to x_t + dir_xt = (1. - a_prev).sqrt() * e_t + x_prev = a_prev.sqrt() * pred_x0 + dir_xt + return x_prev, pred_x0 + + for i in tqdm.trange(len(timesteps) - 1, disable=disable): + index = len(timesteps) - 1 - i + ts = timesteps[index].item() * s_in + t_next = timesteps[max(index - 1, 0)].item() * s_in + + e_t = model(x, ts, **extra_args) + + if len(old_eps) == 0: + # Pseudo Improved Euler (2nd order) + x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index) + e_t_next = model(x_prev, t_next, **extra_args) + e_t_prime = (e_t + e_t_next) / 2 + elif len(old_eps) == 1: + # 2nd order Pseudo Linear Multistep (Adams-Bashforth) + e_t_prime = (3 * e_t - old_eps[-1]) / 2 + elif len(old_eps) == 2: + # 3nd order Pseudo Linear Multistep (Adams-Bashforth) + e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12 + else: + # 4nd order Pseudo Linear Multistep (Adams-Bashforth) + e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24 + + x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index) + + old_eps.append(e_t) + if len(old_eps) >= 4: + old_eps.pop(0) + + x = x_prev + + if callback is not None: + callback({'x': x, 'i': i, 'sigma': 0, 'sigma_hat': 0, 'denoised': pred_x0}) + + return x + + +class UniPCCFG(uni_pc.UniPC): + def __init__(self, cfg_model, extra_args, callback, *args, **kwargs): + super().__init__(None, *args, **kwargs) + + def after_update(x, model_x): + callback({'x': x, 'i': self.index, 'sigma': 0, 'sigma_hat': 0, 'denoised': model_x}) + self.index += 1 + + self.cfg_model = cfg_model + self.extra_args = extra_args + self.callback = callback + self.index = 0 + self.after_update = after_update + + def get_model_input_time(self, t_continuous): + return (t_continuous - 1. / self.noise_schedule.total_N) * 1000. + + def model(self, x, t): + t_input = self.get_model_input_time(t) + + res = self.cfg_model(x, t_input, **self.extra_args) + + return res + + +def unipc(model, x, timesteps, extra_args=None, callback=None, disable=None, is_img2img=False): + alphas_cumprod = model.inner_model.inner_model.alphas_cumprod + + ns = uni_pc.NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod) + t_start = timesteps[-1] / 1000 + 1 / 1000 if is_img2img else None # this is likely off by a bit - if someone wants to fix it please by all means + unipc_sampler = UniPCCFG(model, extra_args, callback, ns, predict_x0=True, thresholding=False, variant=shared.opts.uni_pc_variant) + x = unipc_sampler.sample(x, steps=len(timesteps), t_start=t_start, skip_type=shared.opts.uni_pc_skip_type, method="multistep", order=shared.opts.uni_pc_order, lower_order_final=shared.opts.uni_pc_lower_order_final) + + return x diff --git a/modules/sd_vae.py b/modules/sd_vae.py index 0bd5e19bb..38bcb8402 100644 --- a/modules/sd_vae.py +++ b/modules/sd_vae.py @@ -1,5 +1,7 @@ import os import collections +from dataclasses import dataclass + from modules import paths, shared, devices, script_callbacks, sd_models, extra_networks import glob from copy import deepcopy @@ -97,37 +99,74 @@ def find_vae_near_checkpoint(checkpoint_file): return None -def resolve_vae(checkpoint_file): - if shared.cmd_opts.vae_path is not None: - return shared.cmd_opts.vae_path, 'from commandline argument' +@dataclass +class VaeResolution: + vae: str = None + source: str = None + resolved: bool = True + def tuple(self): + return self.vae, self.source + + +def is_automatic(): + return shared.opts.sd_vae in {"Automatic", "auto"} # "auto" for people with old config + + +def resolve_vae_from_setting() -> VaeResolution: + if shared.opts.sd_vae == "None": + return VaeResolution() + + vae_from_options = vae_dict.get(shared.opts.sd_vae, None) + if vae_from_options is not None: + return VaeResolution(vae_from_options, 'specified in settings') + + if not is_automatic(): + print(f"Couldn't find VAE named {shared.opts.sd_vae}; using None instead") + + return VaeResolution(resolved=False) + + +def resolve_vae_from_user_metadata(checkpoint_file) -> VaeResolution: metadata = extra_networks.get_user_metadata(checkpoint_file) vae_metadata = metadata.get("vae", None) if vae_metadata is not None and vae_metadata != "Automatic": if vae_metadata == "None": - return None, None + return VaeResolution() vae_from_metadata = vae_dict.get(vae_metadata, None) if vae_from_metadata is not None: - return vae_from_metadata, "from user metadata" + return VaeResolution(vae_from_metadata, "from user metadata") - is_automatic = shared.opts.sd_vae in {"Automatic", "auto"} # "auto" for people with old config + return VaeResolution(resolved=False) + +def resolve_vae_near_checkpoint(checkpoint_file) -> VaeResolution: vae_near_checkpoint = find_vae_near_checkpoint(checkpoint_file) if vae_near_checkpoint is not None and (shared.opts.sd_vae_as_default or is_automatic): - return vae_near_checkpoint, 'found near the checkpoint' + return VaeResolution(vae_near_checkpoint, 'found near the checkpoint') - if shared.opts.sd_vae == "None": - return None, None + return VaeResolution(resolved=False) - vae_from_options = vae_dict.get(shared.opts.sd_vae, None) - if vae_from_options is not None: - return vae_from_options, 'specified in settings' - if not is_automatic: - print(f"Couldn't find VAE named {shared.opts.sd_vae}; using None instead") +def resolve_vae(checkpoint_file) -> VaeResolution: + if shared.cmd_opts.vae_path is not None: + return VaeResolution(shared.cmd_opts.vae_path, 'from commandline argument') - return None, None + if shared.opts.sd_vae_overrides_per_model_preferences and not is_automatic(): + return resolve_vae_from_setting() + + res = resolve_vae_from_user_metadata(checkpoint_file) + if res.resolved: + return res + + res = resolve_vae_near_checkpoint(checkpoint_file) + if res.resolved: + return res + + res = resolve_vae_from_setting() + + return res def load_vae_dict(filename, map_location): @@ -201,7 +240,7 @@ def reload_vae_weights(sd_model=None, vae_file=unspecified): checkpoint_file = checkpoint_info.filename if vae_file == unspecified: - vae_file, vae_source = resolve_vae(checkpoint_file) + vae_file, vae_source = resolve_vae(checkpoint_file).tuple() else: vae_source = "from function argument" diff --git a/modules/shared.py b/modules/shared.py index ed8395dc4..2fd299048 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -422,6 +422,7 @@ options_templates.update(options_section(('face-restoration', "Face restoration" })) options_templates.update(options_section(('system', "System"), { + "auto_launch_browser": OptionInfo("Local", "Automatically open webui in browser on startup", gr.Radio, lambda: {"choices": ["Disable", "Local", "Remote"]}), "show_warnings": OptionInfo(False, "Show warnings in console.").needs_reload_ui(), "show_gradio_deprecation_warnings": OptionInfo(True, "Show gradio deprecation warnings in console.").needs_reload_ui(), "memmon_poll_rate": OptionInfo(8, "VRAM usage polls per second during generation.", gr.Slider, {"minimum": 0, "maximum": 40, "step": 1}).info("0 = disable"), @@ -481,7 +482,7 @@ For img2img, VAE is used to process user's input image before the sampling, and """), "sd_vae_checkpoint_cache": OptionInfo(0, "VAE Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}), "sd_vae": OptionInfo("Automatic", "SD VAE", gr.Dropdown, lambda: {"choices": shared_items.sd_vae_items()}, refresh=shared_items.refresh_vae_list).info("choose VAE model: Automatic = use one with same filename as checkpoint; None = use VAE from checkpoint"), - "sd_vae_as_default": OptionInfo(True, "Ignore selected VAE for stable diffusion checkpoints that have their own .vae.pt next to them"), + "sd_vae_overrides_per_model_preferences": OptionInfo(True, "Selected VAE overrides per-model preferences").info("you can set per-model VAE either by editing user metadata for checkpoints, or by making the VAE have same name as checkpoint"), "auto_vae_precision": OptionInfo(True, "Automaticlly revert VAE to 32-bit floats").info("triggers when a tensor with NaNs is produced in VAE; disabling the option in this case will result in a black square image"), "sd_vae_encode_method": OptionInfo("Full", "VAE type for encode", gr.Radio, {"choices": ["Full", "TAESD"]}).info("method to encode image to latent (use in img2img, hires-fix or inpaint mask)"), "sd_vae_decode_method": OptionInfo("Full", "VAE type for decode", gr.Radio, {"choices": ["Full", "TAESD"]}).info("method to decode latent to image"), @@ -610,14 +611,14 @@ options_templates.update(options_section(('sampler-params', "Sampler parameters" "eta_ddim": OptionInfo(0.0, "Eta for DDIM", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}).info("noise multiplier; higher = more unperdictable results"), "eta_ancestral": OptionInfo(1.0, "Eta for ancestral samplers", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}).info("noise multiplier; applies to Euler a and other samplers that have a in them"), "ddim_discretize": OptionInfo('uniform', "img2img DDIM discretize", gr.Radio, {"choices": ['uniform', 'quad']}), - 's_churn': OptionInfo(0.0, "sigma churn", gr.Slider, {"minimum": 0.0, "maximum": 100.0, "step": 0.01}), - 's_tmin': OptionInfo(0.0, "sigma tmin", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), - 's_tmax': OptionInfo(0.0, "sigma tmax", gr.Slider, {"minimum": 0.0, "maximum": 999.0, "step": 0.01}).info("0 = inf"), - 's_noise': OptionInfo(1.0, "sigma noise", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), - 'k_sched_type': OptionInfo("Automatic", "scheduler type", gr.Dropdown, {"choices": ["Automatic", "karras", "exponential", "polyexponential"]}).info("lets you override the noise schedule for k-diffusion samplers; choosing Automatic disables the three parameters below"), + 's_churn': OptionInfo(0.0, "sigma churn", gr.Slider, {"minimum": 0.0, "maximum": 100.0, "step": 0.01}).info('amount of stochasticity; only applies to Euler, Heun, and DPM2'), + 's_tmin': OptionInfo(0.0, "sigma tmin", gr.Slider, {"minimum": 0.0, "maximum": 10.0, "step": 0.01}).info('enable stochasticity; start value of the sigma range; only applies to Euler, Heun, and DPM2'), + 's_tmax': OptionInfo(0.0, "sigma tmax", gr.Slider, {"minimum": 0.0, "maximum": 999.0, "step": 0.01}).info("0 = inf; end value of the sigma range; only applies to Euler, Heun, and DPM2"), + 's_noise': OptionInfo(1.0, "sigma noise", gr.Slider, {"minimum": 0.0, "maximum": 1.1, "step": 0.001}).info('amount of additional noise to counteract loss of detail during sampling; only applies to Euler, Heun, and DPM2'), + 'k_sched_type': OptionInfo("Automatic", "Scheduler type", gr.Dropdown, {"choices": ["Automatic", "karras", "exponential", "polyexponential"]}).info("lets you override the noise schedule for k-diffusion samplers; choosing Automatic disables the three parameters below"), 'sigma_min': OptionInfo(0.0, "sigma min", gr.Number).info("0 = default (~0.03); minimum noise strength for k-diffusion noise scheduler"), - 'sigma_max': OptionInfo(0.0, "sigma max", gr.Number).info("0 = default (~14.6); maximum noise strength for k-diffusion noise schedule"), - 'rho': OptionInfo(0.0, "rho", gr.Number).info("0 = default (7 for karras, 1 for polyexponential); higher values result in a more steep noise schedule (decreases faster)"), + 'sigma_max': OptionInfo(0.0, "sigma max", gr.Number).info("0 = default (~14.6); maximum noise strength for k-diffusion noise scheduler"), + 'rho': OptionInfo(0.0, "rho", gr.Number).info("0 = default (7 for karras, 1 for polyexponential); higher values result in a steeper noise schedule (decreases faster)"), 'eta_noise_seed_delta': OptionInfo(0, "Eta noise seed delta", gr.Number, {"precision": 0}).info("ENSD; does not improve anything, just produces different results for ancestral samplers - only useful for reproducing images"), 'always_discard_next_to_last_sigma': OptionInfo(False, "Always discard next-to-last sigma").link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/6044"), 'uni_pc_variant': OptionInfo("bh1", "UniPC variant", gr.Radio, {"choices": ["bh1", "bh2", "vary_coeff"]}), @@ -735,6 +736,10 @@ class Options: with open(filename, "r", encoding="utf8") as file: self.data = json.load(file) + # 1.6.0 VAE defaults + if self.data.get('sd_vae_as_default') is not None and self.data.get('sd_vae_overrides_per_model_preferences') is None: + self.data['sd_vae_overrides_per_model_preferences'] = not self.data.get('sd_vae_as_default') + # 1.1.1 quicksettings list migration if self.data.get('quicksettings') is not None and self.data.get('quicksettings_list') is None: self.data['quicksettings_list'] = [i.strip() for i in self.data.get('quicksettings').split(',')] diff --git a/modules/txt2img.py b/modules/txt2img.py index 935ed4181..8fa389b56 100644 --- a/modules/txt2img.py +++ b/modules/txt2img.py @@ -1,7 +1,7 @@ from contextlib import closing import modules.scripts -from modules import sd_samplers, processing +from modules import processing from modules.generation_parameters_copypaste import create_override_settings_dict from modules.shared import opts, cmd_opts import modules.shared as shared @@ -9,7 +9,7 @@ from modules.ui import plaintext_to_html import gradio as gr -def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, hr_checkpoint_name: str, hr_sampler_index: int, hr_prompt: str, hr_negative_prompt, override_settings_texts, request: gr.Request, *args): +def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, steps: int, sampler_name: str, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, hr_checkpoint_name: str, hr_sampler_name: str, hr_prompt: str, hr_negative_prompt, override_settings_texts, request: gr.Request, *args): override_settings = create_override_settings_dict(override_settings_texts) p = processing.StableDiffusionProcessingTxt2Img( @@ -25,7 +25,7 @@ def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, step seed_resize_from_h=seed_resize_from_h, seed_resize_from_w=seed_resize_from_w, seed_enable_extras=seed_enable_extras, - sampler_name=sd_samplers.samplers[sampler_index].name, + sampler_name=sampler_name, batch_size=batch_size, n_iter=n_iter, steps=steps, @@ -42,7 +42,7 @@ def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, step hr_resize_x=hr_resize_x, hr_resize_y=hr_resize_y, hr_checkpoint_name=None if hr_checkpoint_name == 'Use same checkpoint' else hr_checkpoint_name, - hr_sampler_name=sd_samplers.samplers_for_img2img[hr_sampler_index - 1].name if hr_sampler_index != 0 else None, + hr_sampler_name=hr_sampler_name, hr_prompt=hr_prompt, hr_negative_prompt=hr_negative_prompt, override_settings=override_settings, diff --git a/modules/ui.py b/modules/ui.py index 1af6b4c88..e3753e970 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -13,7 +13,7 @@ from PIL import Image, PngImagePlugin # noqa: F401 from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call from modules import gradio_extensons # noqa: F401 -from modules import sd_hijack, sd_models, script_callbacks, ui_extensions, deepbooru, extra_networks, ui_common, ui_postprocessing, progress, ui_loadsave, errors, shared_items, ui_settings, timer, sysinfo, ui_checkpoint_merger, ui_prompt_styles, scripts +from modules import sd_hijack, sd_models, script_callbacks, ui_extensions, deepbooru, extra_networks, ui_common, ui_postprocessing, progress, ui_loadsave, errors, shared_items, ui_settings, timer, sysinfo, ui_checkpoint_merger, ui_prompt_styles, scripts, sd_samplers from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML from modules.paths import script_path from modules.ui_common import create_refresh_button @@ -29,7 +29,6 @@ import modules.shared as shared import modules.images from modules import prompt_parser from modules.sd_hijack import model_hijack -from modules.sd_samplers import samplers, samplers_for_img2img from modules.generation_parameters_copypaste import image_from_url_text create_setting_component = ui_settings.create_setting_component @@ -41,6 +40,9 @@ warnings.filterwarnings("default" if opts.show_gradio_deprecation_warnings else mimetypes.init() mimetypes.add_type('application/javascript', '.js') +# Likewise, add explicit content-type header for certain missing image types +mimetypes.add_type('image/webp', '.webp') + if not cmd_opts.share and not cmd_opts.listen: # fix gradio phoning home gradio.utils.version_check = lambda: None @@ -357,14 +359,14 @@ def create_output_panel(tabname, outdir): def create_sampler_and_steps_selection(choices, tabname): if opts.samplers_in_dropdown: with FormRow(elem_id=f"sampler_selection_{tabname}"): - sampler_index = gr.Dropdown(label='Sampling method', elem_id=f"{tabname}_sampling", choices=[x.name for x in choices], value=choices[0].name, type="index") + sampler_name = gr.Dropdown(label='Sampling method', elem_id=f"{tabname}_sampling", choices=choices, value=choices[0]) steps = gr.Slider(minimum=1, maximum=150, step=1, elem_id=f"{tabname}_steps", label="Sampling steps", value=20) else: with FormGroup(elem_id=f"sampler_selection_{tabname}"): steps = gr.Slider(minimum=1, maximum=150, step=1, elem_id=f"{tabname}_steps", label="Sampling steps", value=20) - sampler_index = gr.Radio(label='Sampling method', elem_id=f"{tabname}_sampling", choices=[x.name for x in choices], value=choices[0].name, type="index") + sampler_name = gr.Radio(label='Sampling method', elem_id=f"{tabname}_sampling", choices=choices, value=choices[0]) - return steps, sampler_index + return steps, sampler_name def ordered_ui_categories(): @@ -405,13 +407,13 @@ def create_ui(): extra_tabs = gr.Tabs(elem_id="txt2img_extra_tabs") extra_tabs.__enter__() - with gr.Tab("Generation", id="txt2img_generation") as txt2img_generation_tab, gr.Row().style(equal_height=False): + with gr.Tab("Generation", id="txt2img_generation") as txt2img_generation_tab, gr.Row(equal_height=False): with gr.Column(variant='compact', elem_id="txt2img_settings"): scripts.scripts_txt2img.prepare_ui() for category in ordered_ui_categories(): if category == "sampler": - steps, sampler_index = create_sampler_and_steps_selection(samplers, "txt2img") + steps, sampler_name = create_sampler_and_steps_selection(sd_samplers.visible_sampler_names(), "txt2img") elif category == "dimensions": with FormRow(): @@ -457,7 +459,7 @@ def create_ui(): hr_checkpoint_name = gr.Dropdown(label='Hires checkpoint', elem_id="hr_checkpoint", choices=["Use same checkpoint"] + modules.sd_models.checkpoint_tiles(use_short=True), value="Use same checkpoint") create_refresh_button(hr_checkpoint_name, modules.sd_models.list_models, lambda: {"choices": ["Use same checkpoint"] + modules.sd_models.checkpoint_tiles(use_short=True)}, "hr_checkpoint_refresh") - hr_sampler_index = gr.Dropdown(label='Hires sampling method', elem_id="hr_sampler", choices=["Use same sampler"] + [x.name for x in samplers_for_img2img], value="Use same sampler", type="index") + hr_sampler_name = gr.Dropdown(label='Hires sampling method', elem_id="hr_sampler", choices=["Use same sampler"] + sd_samplers.visible_sampler_names(), value="Use same sampler") with FormRow(elem_id="txt2img_hires_fix_row4", variant="compact", visible=opts.hires_fix_show_prompts) as hr_prompts_container: with gr.Column(scale=80): @@ -517,7 +519,7 @@ def create_ui(): toprow.negative_prompt, toprow.ui_styles.dropdown, steps, - sampler_index, + sampler_name, restore_faces, tiling, batch_count, @@ -535,7 +537,7 @@ def create_ui(): hr_resize_x, hr_resize_y, hr_checkpoint_name, - hr_sampler_index, + hr_sampler_name, hr_prompt, hr_negative_prompt, override_settings, @@ -580,7 +582,7 @@ def create_ui(): (toprow.prompt, "Prompt"), (toprow.negative_prompt, "Negative prompt"), (steps, "Steps"), - (sampler_index, "Sampler"), + (sampler_name, "Sampler"), (restore_faces, "Face restoration"), (cfg_scale, "CFG scale"), (seed, "Seed"), @@ -602,7 +604,7 @@ def create_ui(): (hr_resize_x, "Hires resize-1"), (hr_resize_y, "Hires resize-2"), (hr_checkpoint_name, "Hires checkpoint"), - (hr_sampler_index, "Hires sampler"), + (hr_sampler_name, "Hires sampler"), (hr_sampler_container, lambda d: gr.update(visible=True) if d.get("Hires sampler", "Use same sampler") != "Use same sampler" or d.get("Hires checkpoint", "Use same checkpoint") != "Use same checkpoint" else gr.update()), (hr_prompt, "Hires prompt"), (hr_negative_prompt, "Hires negative prompt"), @@ -618,7 +620,7 @@ def create_ui(): toprow.prompt, toprow.negative_prompt, steps, - sampler_index, + sampler_name, cfg_scale, seed, width, @@ -741,7 +743,7 @@ def create_ui(): for category in ordered_ui_categories(): if category == "sampler": - steps, sampler_index = create_sampler_and_steps_selection(samplers_for_img2img, "img2img") + steps, sampler_name = create_sampler_and_steps_selection(sd_samplers.visible_sampler_names(), "img2img") elif category == "dimensions": with FormRow(): @@ -873,7 +875,7 @@ def create_ui(): init_img_inpaint, init_mask_inpaint, steps, - sampler_index, + sampler_name, mask_blur, mask_alpha, inpainting_fill, @@ -969,7 +971,7 @@ def create_ui(): (toprow.prompt, "Prompt"), (toprow.negative_prompt, "Negative prompt"), (steps, "Steps"), - (sampler_index, "Sampler"), + (sampler_name, "Sampler"), (restore_faces, "Face restoration"), (cfg_scale, "CFG scale"), (image_cfg_scale, "Image CFG scale"), diff --git a/modules/ui_extra_networks_checkpoints_user_metadata.py b/modules/ui_extra_networks_checkpoints_user_metadata.py index 2c69aab86..25df0a807 100644 --- a/modules/ui_extra_networks_checkpoints_user_metadata.py +++ b/modules/ui_extra_networks_checkpoints_user_metadata.py @@ -1,6 +1,6 @@ import gradio as gr -from modules import ui_extra_networks_user_metadata, sd_vae +from modules import ui_extra_networks_user_metadata, sd_vae, shared from modules.ui_common import create_refresh_button @@ -18,6 +18,10 @@ class CheckpointUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataE self.write_user_metadata(name, user_metadata) + def update_vae(self, name): + if name == shared.sd_model.sd_checkpoint_info.name_for_extra: + sd_vae.reload_vae_weights() + def put_values_into_components(self, name): user_metadata = self.get_user_metadata(name) values = super().put_values_into_components(name) @@ -58,3 +62,5 @@ class CheckpointUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataE ] self.setup_save_handler(self.button_save, self.save_user_metadata, edited_components) + self.button_save.click(fn=self.update_vae, inputs=[self.edit_name_input]) + diff --git a/webui.py b/webui.py index 1803ea8ae..6d36f8806 100644 --- a/webui.py +++ b/webui.py @@ -211,7 +211,7 @@ def configure_sigint_handler(): def configure_opts_onchange(): shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights()), call=False) shared.opts.onchange("sd_vae", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False) - shared.opts.onchange("sd_vae_as_default", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False) + shared.opts.onchange("sd_vae_overrides_per_model_preferences", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False) shared.opts.onchange("temp_dir", ui_tempdir.on_tmpdir_changed) shared.opts.onchange("gradio_theme", shared.reload_gradio_theme) shared.opts.onchange("cross_attention_optimization", wrap_queued_call(lambda: modules.sd_hijack.model_hijack.redo_hijack(shared.sd_model)), call=False) @@ -341,6 +341,7 @@ def api_only(): setup_middleware(app) api = create_api(app) + modules.script_callbacks.before_ui_callback() modules.script_callbacks.app_started_callback(None, app) print(f"Startup time: {startup_timer.summary()}.") @@ -371,6 +372,13 @@ def webui(): gradio_auth_creds = list(get_gradio_auth_creds()) or None + auto_launch_browser = False + if os.getenv('SD_WEBUI_RESTARTING') != '1': + if shared.opts.auto_launch_browser == "Remote" or cmd_opts.autolaunch: + auto_launch_browser = True + elif shared.opts.auto_launch_browser == "Local": + auto_launch_browser = not any([cmd_opts.listen, cmd_opts.share, cmd_opts.ngrok]) + app, local_url, share_url = shared.demo.launch( share=cmd_opts.share, server_name=server_name, @@ -380,7 +388,7 @@ def webui(): ssl_verify=cmd_opts.disable_tls_verify, debug=cmd_opts.gradio_debug, auth=gradio_auth_creds, - inbrowser=cmd_opts.autolaunch and os.getenv('SD_WEBUI_RESTARTING') != '1', + inbrowser=auto_launch_browser, prevent_thread_lock=True, allowed_paths=cmd_opts.gradio_allowed_path, app_kwargs={ @@ -390,9 +398,6 @@ def webui(): root_path=f"/{cmd_opts.subpath}" if cmd_opts.subpath else "", ) - # after initial launch, disable --autolaunch for subsequent restarts - cmd_opts.autolaunch = False - startup_timer.record("gradio launch") # gradio uses a very open CORS policy via app.user_middleware, which makes it possible for @@ -437,6 +442,9 @@ def webui(): shared.demo.close() break + # disable auto launch webui in browser for subsequent UI Reload + os.environ.setdefault('SD_WEBUI_RESTARTING', '1') + print('Restarting UI...') shared.demo.close() time.sleep(0.5)