mirror of https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
add restart sampler
This commit is contained in:
parent 394ffa7b0a
commit 40a18d38a8
@@ -1,3 +1,5 @@
+# export PIP_CACHE_DIR=/scratch/dengm/cache
+# export XDG_CACHE_HOME=/scratch/dengm/cache
 from collections import deque
 import torch
 import inspect
@@ -30,12 +32,76 @@ samplers_k_diffusion = [
     ('DPM++ 2M Karras', 'sample_dpmpp_2m', ['k_dpmpp_2m_ka'], {'scheduler': 'karras'}),
     ('DPM++ SDE Karras', 'sample_dpmpp_sde', ['k_dpmpp_sde_ka'], {'scheduler': 'karras', "second_order": True, "brownian_noise": True}),
     ('DPM++ 2M SDE Karras', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_ka'], {'scheduler': 'karras', "brownian_noise": True}),
+    ('Restart (new)', 'restart_sampler', ['restart'], {'scheduler': 'karras', "second_order": True}),
 ]

+@torch.no_grad()
+def restart_sampler(model, x, sigmas, extra_args=None, callback=None, disable=None, s_noise=1., restart_list={0.1: [10, 2, 2]}):
+    """Implements restart sampling from "Restart Sampling for Improving Generative Processes" (2023).
+
+    restart_list format: {min_sigma: [restart_steps, restart_times, max_sigma]}
+    """
+    from tqdm.auto import trange, tqdm
+    extra_args = {} if extra_args is None else extra_args
+    s_in = x.new_ones([x.shape[0]])
+    step_id = 0
+
+    from k_diffusion.sampling import to_d, append_zero
+
+    def heun_step(x, old_sigma, new_sigma):
+        nonlocal step_id
+        denoised = model(x, old_sigma * s_in, **extra_args)
+        d = to_d(x, old_sigma, denoised)
+        if callback is not None:
+            callback({'x': x, 'i': step_id, 'sigma': new_sigma, 'sigma_hat': old_sigma, 'denoised': denoised})
+        dt = new_sigma - old_sigma
+        if new_sigma == 0:
+            # Euler method
+            x = x + d * dt
+        else:
+            # Heun's method
+            x_2 = x + d * dt
+            denoised_2 = model(x_2, new_sigma * s_in, **extra_args)
+            d_2 = to_d(x_2, new_sigma, denoised_2)
+            d_prime = (d + d_2) / 2
+            x = x + d_prime * dt
+        step_id += 1
+        return x
+
+    # map each restart trigger sigma (dict key) to the index of the closest sigma in the schedule
+    temp_list = dict()
+    for key, value in restart_list.items():
+        temp_list[int(torch.argmin(abs(sigmas - key), dim=0))] = value
+    restart_list = temp_list
+
+    def get_sigmas_karras(n, sigma_min, sigma_max, rho=7., device='cpu'):
+        ramp = torch.linspace(0, 1, n).to(device)
+        min_inv_rho = (sigma_min ** (1 / rho))
+        max_inv_rho = (sigma_max ** (1 / rho))
+        if isinstance(min_inv_rho, torch.Tensor):
+            min_inv_rho = min_inv_rho.to(device)
+        if isinstance(max_inv_rho, torch.Tensor):
+            max_inv_rho = max_inv_rho.to(device)
+        sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
+        return append_zero(sigmas).to(device)
+
+    for i in trange(len(sigmas) - 1, disable=disable):
+        x = heun_step(x, sigmas[i], sigmas[i+1])
+        if i + 1 in restart_list:
+            restart_steps, restart_times, restart_max = restart_list[i + 1]
+            min_idx = i + 1
+            max_idx = int(torch.argmin(abs(sigmas - restart_max), dim=0))
+            sigma_restart = get_sigmas_karras(restart_steps, sigmas[min_idx], sigmas[max_idx], device=sigmas.device)[:-1]  # remove the zero at the end
+            for times in range(restart_times):
+                x = x + torch.randn_like(x) * s_noise * (sigmas[max_idx] ** 2 - sigmas[min_idx] ** 2) ** 0.5
+                for (old_sigma, new_sigma) in zip(sigma_restart[:-1], sigma_restart[1:]):
+                    x = heun_step(x, old_sigma, new_sigma)
+
+    return x
+
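The re-noising line in the restart loop above adds Gaussian noise with standard deviation (sigmas[max_idx] ** 2 - sigmas[min_idx] ** 2) ** 0.5. Since independent Gaussian noise adds in variance, this lifts a sample sitting at noise level sigma_min back up to exactly sigma_max before the Heun steps walk it down again. A minimal standalone check of that identity (torch only; the sigma values here are arbitrary, not webui defaults):

import torch

torch.manual_seed(0)
sigma_min, sigma_max = 0.1, 2.0
x0 = torch.randn(1_000_000)                         # stand-in for clean data
x_low = x0 + sigma_min * torch.randn_like(x0)       # sample at noise level sigma_min
# the restart jump: add noise with variance sigma_max^2 - sigma_min^2
x_high = x_low + (sigma_max ** 2 - sigma_min ** 2) ** 0.5 * torch.randn_like(x0)
print(float((x_high - x0).std()))                   # ~2.0, i.e. noise level sigma_max
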
 samplers_data_k_diffusion = [
     sd_samplers_common.SamplerData(label, lambda model, funcname=funcname: KDiffusionSampler(funcname, model), aliases, options)
     for label, funcname, aliases, options in samplers_k_diffusion
-    if hasattr(k_diffusion.sampling, funcname)
+    if (hasattr(k_diffusion.sampling, funcname) or funcname == 'restart_sampler')
 ]

 sampler_extra_params = {
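The default restart_list={0.1: [10, 2, 2]} requests a single restart block: once the main schedule reaches the sigma closest to 0.1, the sampler re-noises back up to the sigma closest to 2 and runs a 10-step Karras segment between those two levels, repeated twice. A small sketch of how those keys resolve to schedule indices, using k_diffusion's get_sigmas_karras with an arbitrary 20-step schedule (the exact indices and sigma values depend on the schedule actually in use):

import torch
from k_diffusion.sampling import get_sigmas_karras

sigmas = get_sigmas_karras(n=20, sigma_min=0.03, sigma_max=10.0)
restart_list = {0.1: [10, 2, 2]}   # {min_sigma: [restart_steps, restart_times, max_sigma]}

for min_sigma, (restart_steps, restart_times, max_sigma) in restart_list.items():
    min_idx = int(torch.argmin(abs(sigmas - min_sigma)))
    max_idx = int(torch.argmin(abs(sigmas - max_sigma)))
    print(f"restart triggers near step {min_idx} (sigma ~ {float(sigmas[min_idx]):.3f}), "
          f"re-noising up to sigma ~ {float(sigmas[max_idx]):.3f}, "
          f"{restart_steps}-step segment repeated {restart_times}x")
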
@@ -245,7 +311,7 @@ class KDiffusionSampler:

         self.model_wrap = denoiser(sd_model, quantize=shared.opts.enable_quantization)
         self.funcname = funcname
-        self.func = getattr(k_diffusion.sampling, self.funcname)
+        self.func = getattr(k_diffusion.sampling, self.funcname) if funcname != "restart_sampler" else restart_sampler
         self.extra_params = sampler_extra_params.get(funcname, [])
         self.model_wrap_cfg = CFGDenoiser(self.model_wrap)
         self.sampler_noises = None
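For a quick smoke test of the new sampler outside the webui pipeline, the sketch below assumes k_diffusion (and tqdm) is installed and that restart_sampler is in scope, e.g. copied from the diff above; the dummy denoiser and sigma range are placeholders, not webui defaults:

import torch
from k_diffusion.sampling import get_sigmas_karras

def dummy_denoiser(x, sigma, **extra_args):
    # stand-in for the wrapped UNet: always predicts an all-zero "clean" image
    return torch.zeros_like(x)

sigmas = get_sigmas_karras(n=20, sigma_min=0.03, sigma_max=10.0)
x = torch.randn(1, 4, 64, 64) * sigmas[0]        # start from pure noise at sigma_max
out = restart_sampler(dummy_denoiser, x, sigmas, restart_list={0.1: [10, 2, 2]})
print(out.shape)                                 # torch.Size([1, 4, 64, 64])

In the webui itself the function is reached through KDiffusionSampler, which falls back to restart_sampler whenever funcname is "restart_sampler" (the last hunk above).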