diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py index 71581b763..c63b677c4 100644 --- a/modules/sd_samplers_kdiffusion.py +++ b/modules/sd_samplers_kdiffusion.py @@ -1,3 +1,5 @@ +# export PIP_CACHE_DIR=/scratch/dengm/cache +# export XDG_CACHE_HOME=/scratch/dengm/cache from collections import deque import torch import inspect @@ -30,12 +32,76 @@ samplers_k_diffusion = [ ('DPM++ 2M Karras', 'sample_dpmpp_2m', ['k_dpmpp_2m_ka'], {'scheduler': 'karras'}), ('DPM++ SDE Karras', 'sample_dpmpp_sde', ['k_dpmpp_sde_ka'], {'scheduler': 'karras', "second_order": True, "brownian_noise": True}), ('DPM++ 2M SDE Karras', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_ka'], {'scheduler': 'karras', "brownian_noise": True}), + ('Restart (new)', 'restart_sampler', ['restart'], {'scheduler': 'karras', "second_order": True}), ] + +@torch.no_grad() +def restart_sampler(model, x, sigmas, extra_args=None, callback=None, disable=None, s_noise=1., restart_list = {0.1: [10, 2, 2]}): + """Implements restart sampling in Restart Sampling for Improving Generative Processes (2023)""" + '''Restart_list format: {min_sigma: [ restart_steps, restart_times, max_sigma]}''' + + from tqdm.auto import trange, tqdm + extra_args = {} if extra_args is None else extra_args + s_in = x.new_ones([x.shape[0]]) + step_id = 0 + + from k_diffusion.sampling import to_d, append_zero + + def heun_step(x, old_sigma, new_sigma): + nonlocal step_id + denoised = model(x, old_sigma * s_in, **extra_args) + d = to_d(x, old_sigma, denoised) + if callback is not None: + callback({'x': x, 'i': step_id, 'sigma': new_sigma, 'sigma_hat': old_sigma, 'denoised': denoised}) + dt = new_sigma - old_sigma + if new_sigma == 0: + # Euler method + x = x + d * dt + else: + # Heun's method + x_2 = x + d * dt + denoised_2 = model(x_2, new_sigma * s_in, **extra_args) + d_2 = to_d(x_2, new_sigma, denoised_2) + d_prime = (d + d_2) / 2 + x = x + d_prime * dt + step_id += 1 + return x + # print(sigmas) + temp_list = dict() + for key, value in restart_list.items(): + temp_list[int(torch.argmin(abs(sigmas - key), dim=0))] = value + restart_list = temp_list + + + def get_sigmas_karras(n, sigma_min, sigma_max, rho=7., device='cpu'): + ramp = torch.linspace(0, 1, n).to(device) + min_inv_rho = (sigma_min ** (1 / rho)) + max_inv_rho = (sigma_max ** (1 / rho)) + if isinstance(min_inv_rho, torch.Tensor): + min_inv_rho = min_inv_rho.to(device) + if isinstance(max_inv_rho, torch.Tensor): + max_inv_rho = max_inv_rho.to(device) + sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho + return append_zero(sigmas).to(device) + + for i in trange(len(sigmas) - 1, disable=disable): + x = heun_step(x, sigmas[i], sigmas[i+1]) + if i + 1 in restart_list: + restart_steps, restart_times, restart_max = restart_list[i + 1] + min_idx = i + 1 + max_idx = int(torch.argmin(abs(sigmas - restart_max), dim=0)) + sigma_restart = get_sigmas_karras(restart_steps, sigmas[min_idx], sigmas[max_idx], device=sigmas.device)[:-1] # remove the zero at the end + for times in range(restart_times): + x = x + torch.randn_like(x) * s_noise * (sigmas[max_idx] ** 2 - sigmas[min_idx] ** 2) ** 0.5 + for (old_sigma, new_sigma) in zip(sigma_restart[:-1], sigma_restart[1:]): + x = heun_step(x, old_sigma, new_sigma) + return x + samplers_data_k_diffusion = [ sd_samplers_common.SamplerData(label, lambda model, funcname=funcname: KDiffusionSampler(funcname, model), aliases, options) for label, funcname, aliases, options in samplers_k_diffusion - if hasattr(k_diffusion.sampling, funcname) + if (hasattr(k_diffusion.sampling, funcname) or funcname == 'restart_sampler') ] sampler_extra_params = { @@ -245,7 +311,7 @@ class KDiffusionSampler: self.model_wrap = denoiser(sd_model, quantize=shared.opts.enable_quantization) self.funcname = funcname - self.func = getattr(k_diffusion.sampling, self.funcname) + self.func = getattr(k_diffusion.sampling, self.funcname) if funcname != "restart_sampler" else restart_sampler self.extra_params = sampler_extra_params.get(funcname, []) self.model_wrap_cfg = CFGDenoiser(self.model_wrap) self.sampler_noises = None