From 663a4d80dfae5510257b362fd0015c8dc8b8bb5e Mon Sep 17 00:00:00 2001
From: v0xie <28695009+v0xie@users.noreply.github.com>
Date: Sun, 16 Jun 2024 17:47:21 -0700
Subject: [PATCH] add new sampler DDIM CFG++

---
 modules/sd_samplers_cfg_denoiser.py   | 10 ++++++++
 modules/sd_samplers_timesteps.py      |  1 +
 modules/sd_samplers_timesteps_impl.py | 37 +++++++++++++++++++++++++++
 3 files changed, 48 insertions(+)

diff --git a/modules/sd_samplers_cfg_denoiser.py b/modules/sd_samplers_cfg_denoiser.py
index a86fa88ee..c8eeedad3 100644
--- a/modules/sd_samplers_cfg_denoiser.py
+++ b/modules/sd_samplers_cfg_denoiser.py
@@ -58,6 +58,8 @@ class CFGDenoiser(torch.nn.Module):
         self.model_wrap = None
         self.p = None
 
+        self.last_noise_uncond = None
+
         # NOTE: masking before denoising can cause the original latents to be oversmoothed
         # as the original latents do not have noise
         self.mask_before_denoising = False
@@ -160,6 +162,8 @@ class CFGDenoiser(torch.nn.Module):
         # so is_edit_model is set to False to support AND composition.
         is_edit_model = shared.sd_model.cond_stage_key == "edit" and self.image_cfg_scale is not None and self.image_cfg_scale != 1.0
 
+        is_cfg_pp = 'CFG++' in self.sampler.config.name
+
         conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
         uncond = prompt_parser.reconstruct_cond_batch(uncond, self.step)
 
@@ -273,10 +277,16 @@ class CFGDenoiser(torch.nn.Module):
         denoised_params = CFGDenoisedParams(x_out, state.sampling_step, state.sampling_steps, self.inner_model)
         cfg_denoised_callback(denoised_params)
 
+        if is_cfg_pp:
+            self.last_noise_uncond = x_out[-uncond.shape[0]:]  # cache the unconditional noise prediction for CFG++ samplers
+            self.last_noise_uncond = torch.clone(self.last_noise_uncond)  # clone so later changes to x_out don't affect the cache
+
         if is_edit_model:
             denoised = self.combine_denoised_for_edit_model(x_out, cond_scale)
         elif skip_uncond:
             denoised = self.combine_denoised(x_out, conds_list, uncond, 1.0)
+        elif is_cfg_pp:
+            denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale/12.5)  # map the user CFG scale from [0, 12.5] down to the [0, 1] range used by CFG++
         else:
             denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale)
 
diff --git a/modules/sd_samplers_timesteps.py b/modules/sd_samplers_timesteps.py
index 8cc7d3848..81edd67d6 100644
--- a/modules/sd_samplers_timesteps.py
+++ b/modules/sd_samplers_timesteps.py
@@ -10,6 +10,7 @@ import modules.shared as shared
 
 samplers_timesteps = [
     ('DDIM', sd_samplers_timesteps_impl.ddim, ['ddim'], {}),
+    ('DDIM CFG++', sd_samplers_timesteps_impl.ddim_cfgpp, ['ddim_cfgpp'], {}),
     ('PLMS', sd_samplers_timesteps_impl.plms, ['plms'], {}),
     ('UniPC', sd_samplers_timesteps_impl.unipc, ['unipc'], {}),
 ]
diff --git a/modules/sd_samplers_timesteps_impl.py b/modules/sd_samplers_timesteps_impl.py
index 84867d6ee..8896cfc9a 100644
--- a/modules/sd_samplers_timesteps_impl.py
+++ b/modules/sd_samplers_timesteps_impl.py
@@ -40,6 +40,43 @@ def ddim(model, x, timesteps, extra_args=None, callback=None, disable=None, eta=
     return x
 
 
+@torch.no_grad()
+def ddim_cfgpp(model, x, timesteps, extra_args=None, callback=None, disable=None, eta=0.0):
+    """ Implements CFG++: Manifold-constrained Classifier Free Guidance for Diffusion Models (2024).
+    Uses the unconditional noise prediction instead of the conditional noise to guide the denoising direction.
+    The user-specified CFG scale is divided by 12.5, mapping the usual [0.0, 12.5] range onto the [0.0, 1.0] guidance scale that CFG++ expects.
+ """ + alphas_cumprod = model.inner_model.inner_model.alphas_cumprod + alphas = alphas_cumprod[timesteps] + alphas_prev = alphas_cumprod[torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))].to(float64(x)) + sqrt_one_minus_alphas = torch.sqrt(1 - alphas) + sigmas = eta * np.sqrt((1 - alphas_prev.cpu().numpy()) / (1 - alphas.cpu()) * (1 - alphas.cpu() / alphas_prev.cpu().numpy())) + + extra_args = {} if extra_args is None else extra_args + s_in = x.new_ones((x.shape[0])) + s_x = x.new_ones((x.shape[0], 1, 1, 1)) + for i in tqdm.trange(len(timesteps) - 1, disable=disable): + index = len(timesteps) - 1 - i + + e_t = model(x, timesteps[index].item() * s_in, **extra_args) + last_noise_uncond = model.last_noise_uncond + + a_t = alphas[index].item() * s_x + a_prev = alphas_prev[index].item() * s_x + sigma_t = sigmas[index].item() * s_x + sqrt_one_minus_at = sqrt_one_minus_alphas[index].item() * s_x + + pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() + dir_xt = (1. - a_prev - sigma_t ** 2).sqrt() * last_noise_uncond + noise = sigma_t * k_diffusion.sampling.torch.randn_like(x) + x = a_prev.sqrt() * pred_x0 + dir_xt + noise + + if callback is not None: + callback({'x': x, 'i': i, 'sigma': 0, 'sigma_hat': 0, 'denoised': pred_x0}) + + return x + + @torch.no_grad() def plms(model, x, timesteps, extra_args=None, callback=None, disable=None): alphas_cumprod = model.inner_model.inner_model.alphas_cumprod