From 39a6d5655f6c162e2b8da024a1719d79304332a2 Mon Sep 17 00:00:00 2001 From: drhead <1313496+drhead@users.noreply.github.com> Date: Sat, 8 Jun 2024 18:55:07 -0400 Subject: [PATCH] patch k_diffusion to_d and strip device from schedulers --- modules/sd_schedulers.py | 24 +++++++++++++++--------- 1 file changed, 15 insertions(+), 9 deletions(-) diff --git a/modules/sd_schedulers.py b/modules/sd_schedulers.py index 0c09af8d0..a2b9eb290 100644 --- a/modules/sd_schedulers.py +++ b/modules/sd_schedulers.py @@ -4,6 +4,12 @@ import torch import k_diffusion +def to_d(x, sigma, denoised): + """Converts a denoiser output to a Karras ODE derivative.""" + return (x - denoised) / sigma + +k_diffusion.sampling.to_d = to_d + import numpy as np from modules import shared @@ -19,11 +25,11 @@ class Scheduler: aliases: list = None -def uniform(n, sigma_min, sigma_max, inner_model, device): +def uniform(n, sigma_min, sigma_max, inner_model): return inner_model.get_sigmas(n) -def sgm_uniform(n, sigma_min, sigma_max, inner_model, device): +def sgm_uniform(n, sigma_min, sigma_max, inner_model): start = inner_model.sigma_to_t(torch.tensor(sigma_max)) end = inner_model.sigma_to_t(torch.tensor(sigma_min)) sigs = [ @@ -31,9 +37,9 @@ def sgm_uniform(n, sigma_min, sigma_max, inner_model, device): for ts in torch.linspace(start, end, n + 1)[:-1] ] sigs += [0.0] - return torch.FloatTensor(sigs).to(device) + return torch.FloatTensor(sigs) -def get_align_your_steps_sigmas(n, sigma_min, sigma_max, device='cpu'): +def get_align_your_steps_sigmas(n, sigma_min, sigma_max): # https://research.nvidia.com/labs/toronto-ai/AlignYourSteps/howto.html def loglinear_interp(t_steps, num_steps): """ @@ -59,12 +65,12 @@ def get_align_your_steps_sigmas(n, sigma_min, sigma_max, device='cpu'): else: sigmas.append(0.0) - return torch.FloatTensor(sigmas).to(device) + return torch.FloatTensor(sigmas) -def kl_optimal(n, sigma_min, sigma_max, device): - alpha_min = torch.arctan(torch.tensor(sigma_min, device=device)) - alpha_max = torch.arctan(torch.tensor(sigma_max, device=device)) - step_indices = torch.arange(n + 1, device=device) +def kl_optimal(n, sigma_min, sigma_max): + alpha_min = torch.arctan(torch.tensor(sigma_min)) + alpha_max = torch.arctan(torch.tensor(sigma_max)) + step_indices = torch.arange(n + 1) sigmas = torch.tan(step_indices / n * alpha_min + (1.0 - step_indices / n) * alpha_max) return sigmas