diff --git a/modules/processing.py b/modules/processing.py index 317450065..61ba5f115 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -1112,9 +1112,6 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): img2img_sampler_name = self.hr_sampler_name or self.sampler_name - if self.sampler_name in ['PLMS', 'UniPC']: # PLMS/UniPC do not support img2img so we just silently switch to DDIM - img2img_sampler_name = 'DDIM' - self.sampler = sd_samplers.create_sampler(img2img_sampler_name, self.sd_model) if self.latent_scale_mode is not None: diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 9ad981998..46652fbd3 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -5,7 +5,7 @@ from types import MethodType from modules import devices, sd_hijack_optimizations, shared, script_callbacks, errors, sd_unet from modules.hypernetworks import hypernetwork from modules.shared import cmd_opts -from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr, sd_hijack_inpainting +from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr import ldm.modules.attention import ldm.modules.diffusionmodules.model @@ -34,8 +34,6 @@ ldm.modules.diffusionmodules.model.print = shared.ldm_print ldm.util.print = shared.ldm_print ldm.models.diffusion.ddpm.print = shared.ldm_print -sd_hijack_inpainting.do_inpainting_hijack() - optimizers = [] current_optimizer: sd_hijack_optimizations.SdOptimization = None diff --git a/modules/sd_hijack_inpainting.py b/modules/sd_hijack_inpainting.py deleted file mode 100644 index 2d44b8566..000000000 --- a/modules/sd_hijack_inpainting.py +++ /dev/null @@ -1,95 +0,0 @@ -import torch - -import ldm.models.diffusion.ddpm -import ldm.models.diffusion.ddim -import ldm.models.diffusion.plms - -from ldm.models.diffusion.ddim import noise_like -from ldm.models.diffusion.sampling_util import norm_thresholding - - -@torch.no_grad() -def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, - temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, - unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None, dynamic_threshold=None): - b, *_, device = *x.shape, x.device - - def get_model_output(x, t): - if unconditional_conditioning is None or unconditional_guidance_scale == 1.: - e_t = self.model.apply_model(x, t, c) - else: - x_in = torch.cat([x] * 2) - t_in = torch.cat([t] * 2) - - if isinstance(c, dict): - assert isinstance(unconditional_conditioning, dict) - c_in = {} - for k in c: - if isinstance(c[k], list): - c_in[k] = [ - torch.cat([unconditional_conditioning[k][i], c[k][i]]) - for i in range(len(c[k])) - ] - else: - c_in[k] = torch.cat([unconditional_conditioning[k], c[k]]) - else: - c_in = torch.cat([unconditional_conditioning, c]) - - e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2) - e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) - - if score_corrector is not None: - assert self.model.parameterization == "eps" - e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs) - - return e_t - - alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas - alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev - sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas - sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas - - def get_x_prev_and_pred_x0(e_t, index): - # select parameters corresponding to the currently considered timestep - a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) - a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) - sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) - sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device) - - # current prediction for x_0 - pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() - if quantize_denoised: - pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) - if dynamic_threshold is not None: - pred_x0 = norm_thresholding(pred_x0, dynamic_threshold) - # direction pointing to x_t - dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t - noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature - if noise_dropout > 0.: - noise = torch.nn.functional.dropout(noise, p=noise_dropout) - x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise - return x_prev, pred_x0 - - e_t = get_model_output(x, t) - if len(old_eps) == 0: - # Pseudo Improved Euler (2nd order) - x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index) - e_t_next = get_model_output(x_prev, t_next) - e_t_prime = (e_t + e_t_next) / 2 - elif len(old_eps) == 1: - # 2nd order Pseudo Linear Multistep (Adams-Bashforth) - e_t_prime = (3 * e_t - old_eps[-1]) / 2 - elif len(old_eps) == 2: - # 3nd order Pseudo Linear Multistep (Adams-Bashforth) - e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12 - elif len(old_eps) >= 3: - # 4nd order Pseudo Linear Multistep (Adams-Bashforth) - e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24 - - x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index) - - return x_prev, pred_x0, e_t - - -def do_inpainting_hijack(): - ldm.models.diffusion.plms.PLMSSampler.p_sample_plms = p_sample_plms diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py index fe206894e..05dbe2b50 100644 --- a/modules/sd_samplers.py +++ b/modules/sd_samplers.py @@ -1,11 +1,10 @@ -from modules import sd_samplers_compvis, sd_samplers_kdiffusion, sd_samplers_timesteps, shared +from modules import sd_samplers_kdiffusion, sd_samplers_timesteps, shared # imports for functions that previously were here and are used by other modules from modules.sd_samplers_common import samples_to_image_grid, sample_to_image # noqa: F401 all_samplers = [ *sd_samplers_kdiffusion.samplers_data_k_diffusion, - *sd_samplers_compvis.samplers_data_compvis, *sd_samplers_timesteps.samplers_data_timesteps, ] all_samplers_map = {x.name: x for x in all_samplers} @@ -42,10 +41,8 @@ def set_samplers(): global samplers, samplers_for_img2img hidden = set(shared.opts.hide_samplers) - hidden_img2img = set(shared.opts.hide_samplers + ['PLMS', 'UniPC']) - samplers = [x for x in all_samplers if x.name not in hidden] - samplers_for_img2img = [x for x in all_samplers if x.name not in hidden_img2img] + samplers_for_img2img = [x for x in all_samplers if x.name not in hidden] samplers_map.clear() for sampler in all_samplers: diff --git a/modules/sd_samplers_cfg_denoiser.py b/modules/sd_samplers_cfg_denoiser.py index 166a00c7c..d826222cd 100644 --- a/modules/sd_samplers_cfg_denoiser.py +++ b/modules/sd_samplers_cfg_denoiser.py @@ -1,4 +1,3 @@ -from collections import deque import torch from modules import prompt_parser, devices, sd_samplers_common diff --git a/modules/sd_samplers_compvis.py b/modules/sd_samplers_compvis.py deleted file mode 100644 index 4a8396f97..000000000 --- a/modules/sd_samplers_compvis.py +++ /dev/null @@ -1,224 +0,0 @@ -import math -import ldm.models.diffusion.ddim -import ldm.models.diffusion.plms - -import numpy as np -import torch - -from modules.shared import state -from modules import sd_samplers_common, prompt_parser, shared -import modules.models.diffusion.uni_pc - - -samplers_data_compvis = [ - sd_samplers_common.SamplerData('DDIM', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.ddim.DDIMSampler, model), [], {"default_eta_is_0": True, "uses_ensd": True, "no_sdxl": True}), - sd_samplers_common.SamplerData('PLMS', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.plms.PLMSSampler, model), [], {"no_sdxl": True}), - sd_samplers_common.SamplerData('UniPC', lambda model: VanillaStableDiffusionSampler(modules.models.diffusion.uni_pc.UniPCSampler, model), [], {"no_sdxl": True}), -] - - -class VanillaStableDiffusionSampler: - def __init__(self, constructor, sd_model): - self.sampler = constructor(sd_model) - self.is_ddim = hasattr(self.sampler, 'p_sample_ddim') - self.is_plms = hasattr(self.sampler, 'p_sample_plms') - self.is_unipc = isinstance(self.sampler, modules.models.diffusion.uni_pc.UniPCSampler) - self.orig_p_sample_ddim = None - if self.is_plms: - self.orig_p_sample_ddim = self.sampler.p_sample_plms - elif self.is_ddim: - self.orig_p_sample_ddim = self.sampler.p_sample_ddim - self.mask = None - self.nmask = None - self.init_latent = None - self.sampler_noises = None - self.step = 0 - self.stop_at = None - self.eta = None - self.config = None - self.last_latent = None - - self.conditioning_key = sd_model.model.conditioning_key - - def number_of_needed_noises(self, p): - return 0 - - def launch_sampling(self, steps, func): - state.sampling_steps = steps - state.sampling_step = 0 - - try: - return func() - except sd_samplers_common.InterruptedException: - return self.last_latent - - def p_sample_ddim_hook(self, x_dec, cond, ts, unconditional_conditioning, *args, **kwargs): - x_dec, ts, cond, unconditional_conditioning = self.before_sample(x_dec, ts, cond, unconditional_conditioning) - - res = self.orig_p_sample_ddim(x_dec, cond, ts, *args, unconditional_conditioning=unconditional_conditioning, **kwargs) - - x_dec, ts, cond, unconditional_conditioning, res = self.after_sample(x_dec, ts, cond, unconditional_conditioning, res) - - return res - - def before_sample(self, x, ts, cond, unconditional_conditioning): - if state.interrupted or state.skipped: - raise sd_samplers_common.InterruptedException - - if self.stop_at is not None and self.step > self.stop_at: - raise sd_samplers_common.InterruptedException - - # Have to unwrap the inpainting conditioning here to perform pre-processing - image_conditioning = None - uc_image_conditioning = None - if isinstance(cond, dict): - if self.conditioning_key == "crossattn-adm": - image_conditioning = cond["c_adm"] - uc_image_conditioning = unconditional_conditioning["c_adm"] - else: - image_conditioning = cond["c_concat"][0] - cond = cond["c_crossattn"][0] - unconditional_conditioning = unconditional_conditioning["c_crossattn"][0] - - conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step) - unconditional_conditioning = prompt_parser.reconstruct_cond_batch(unconditional_conditioning, self.step) - - assert all(len(conds) == 1 for conds in conds_list), 'composition via AND is not supported for DDIM/PLMS samplers' - cond = tensor - - # for DDIM, shapes must match, we can't just process cond and uncond independently; - # filling unconditional_conditioning with repeats of the last vector to match length is - # not 100% correct but should work well enough - if unconditional_conditioning.shape[1] < cond.shape[1]: - last_vector = unconditional_conditioning[:, -1:] - last_vector_repeated = last_vector.repeat([1, cond.shape[1] - unconditional_conditioning.shape[1], 1]) - unconditional_conditioning = torch.hstack([unconditional_conditioning, last_vector_repeated]) - elif unconditional_conditioning.shape[1] > cond.shape[1]: - unconditional_conditioning = unconditional_conditioning[:, :cond.shape[1]] - - if self.mask is not None: - img_orig = self.sampler.model.q_sample(self.init_latent, ts) - x = img_orig * self.mask + self.nmask * x - - # Wrap the image conditioning back up since the DDIM code can accept the dict directly. - # Note that they need to be lists because it just concatenates them later. - if image_conditioning is not None: - if self.conditioning_key == "crossattn-adm": - cond = {"c_adm": image_conditioning, "c_crossattn": [cond]} - unconditional_conditioning = {"c_adm": uc_image_conditioning, "c_crossattn": [unconditional_conditioning]} - else: - cond = {"c_concat": [image_conditioning], "c_crossattn": [cond]} - unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]} - - return x, ts, cond, unconditional_conditioning - - def update_step(self, last_latent): - if self.mask is not None: - self.last_latent = self.init_latent * self.mask + self.nmask * last_latent - else: - self.last_latent = last_latent - - sd_samplers_common.store_latent(self.last_latent) - - self.step += 1 - state.sampling_step = self.step - shared.total_tqdm.update() - - def after_sample(self, x, ts, cond, uncond, res): - if not self.is_unipc: - self.update_step(res[1]) - - return x, ts, cond, uncond, res - - def unipc_after_update(self, x, model_x): - self.update_step(x) - - def initialize(self, p): - if self.is_ddim: - self.eta = p.eta if p.eta is not None else shared.opts.eta_ddim - else: - self.eta = 0.0 - - if self.eta != 0.0: - p.extra_generation_params["Eta DDIM"] = self.eta - - if self.is_unipc: - keys = [ - ('UniPC variant', 'uni_pc_variant'), - ('UniPC skip type', 'uni_pc_skip_type'), - ('UniPC order', 'uni_pc_order'), - ('UniPC lower order final', 'uni_pc_lower_order_final'), - ] - - for name, key in keys: - v = getattr(shared.opts, key) - if v != shared.opts.get_default(key): - p.extra_generation_params[name] = v - - for fieldname in ['p_sample_ddim', 'p_sample_plms']: - if hasattr(self.sampler, fieldname): - setattr(self.sampler, fieldname, self.p_sample_ddim_hook) - if self.is_unipc: - self.sampler.set_hooks(lambda x, t, c, u: self.before_sample(x, t, c, u), lambda x, t, c, u, r: self.after_sample(x, t, c, u, r), lambda x, mx: self.unipc_after_update(x, mx)) - - self.mask = p.mask if hasattr(p, 'mask') else None - self.nmask = p.nmask if hasattr(p, 'nmask') else None - - - def adjust_steps_if_invalid(self, p, num_steps): - if ((self.config.name == 'DDIM') and p.ddim_discretize == 'uniform') or (self.config.name == 'PLMS') or (self.config.name == 'UniPC'): - if self.config.name == 'UniPC' and num_steps < shared.opts.uni_pc_order: - num_steps = shared.opts.uni_pc_order - valid_step = 999 / (1000 // num_steps) - if valid_step == math.floor(valid_step): - return int(valid_step) + 1 - - return num_steps - - def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None): - steps, t_enc = sd_samplers_common.setup_img2img_steps(p, steps) - steps = self.adjust_steps_if_invalid(p, steps) - self.initialize(p) - - self.sampler.make_schedule(ddim_num_steps=steps, ddim_eta=self.eta, ddim_discretize=p.ddim_discretize, verbose=False) - x1 = self.sampler.stochastic_encode(x, torch.tensor([t_enc] * int(x.shape[0])).to(shared.device), noise=noise) - - self.init_latent = x - self.last_latent = x - self.step = 0 - - # Wrap the conditioning models with additional image conditioning for inpainting model - if image_conditioning is not None: - if self.conditioning_key == "crossattn-adm": - conditioning = {"c_adm": image_conditioning, "c_crossattn": [conditioning]} - unconditional_conditioning = {"c_adm": torch.zeros_like(image_conditioning), "c_crossattn": [unconditional_conditioning]} - else: - conditioning = {"c_concat": [image_conditioning], "c_crossattn": [conditioning]} - unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]} - - samples = self.launch_sampling(t_enc + 1, lambda: self.sampler.decode(x1, conditioning, t_enc, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning)) - - return samples - - def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning=None): - self.initialize(p) - - self.init_latent = None - self.last_latent = x - self.step = 0 - - steps = self.adjust_steps_if_invalid(p, steps or p.steps) - - # Wrap the conditioning models with additional image conditioning for inpainting model - # dummy_for_plms is needed because PLMS code checks the first item in the dict to have the right shape - if image_conditioning is not None: - if self.conditioning_key == "crossattn-adm": - conditioning = {"dummy_for_plms": np.zeros((conditioning.shape[0],)), "c_crossattn": [conditioning], "c_adm": image_conditioning} - unconditional_conditioning = {"c_crossattn": [unconditional_conditioning], "c_adm": torch.zeros_like(image_conditioning)} - else: - conditioning = {"dummy_for_plms": np.zeros((conditioning.shape[0],)), "c_crossattn": [conditioning], "c_concat": [image_conditioning]} - unconditional_conditioning = {"c_crossattn": [unconditional_conditioning], "c_concat": [image_conditioning]} - - samples_ddim = self.launch_sampling(steps, lambda: self.sampler.sample(S=steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=self.eta)[0]) - - return samples_ddim diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py index 3a2e01b77..27a734869 100644 --- a/modules/sd_samplers_kdiffusion.py +++ b/modules/sd_samplers_kdiffusion.py @@ -1,8 +1,7 @@ -from collections import deque import torch import inspect import k_diffusion.sampling -from modules import devices, sd_samplers_common, sd_samplers_extra, sd_samplers_cfg_denoiser +from modules import sd_samplers_common, sd_samplers_extra, sd_samplers_cfg_denoiser from modules.shared import opts import modules.shared as shared diff --git a/modules/sd_samplers_timesteps.py b/modules/sd_samplers_timesteps.py index 8560d009e..d89d0efb3 100644 --- a/modules/sd_samplers_timesteps.py +++ b/modules/sd_samplers_timesteps.py @@ -7,9 +7,9 @@ from modules.shared import opts import modules.shared as shared samplers_timesteps = [ - ('k_DDIM', sd_samplers_timesteps_impl.ddim, ['k_ddim'], {}), - ('k_PLMS', sd_samplers_timesteps_impl.plms, ['k_plms'], {}), - ('k_UniPC', sd_samplers_timesteps_impl.unipc, ['k_unipc'], {}), + ('DDIM', sd_samplers_timesteps_impl.ddim, ['ddim'], {}), + ('PLMS', sd_samplers_timesteps_impl.plms, ['plms'], {}), + ('UniPC', sd_samplers_timesteps_impl.unipc, ['unipc'], {}), ]