diff --git a/modules/sd_samplers_timesteps.py b/modules/sd_samplers_timesteps.py index 16572c7e0..6aed29742 100644 --- a/modules/sd_samplers_timesteps.py +++ b/modules/sd_samplers_timesteps.py @@ -51,10 +51,9 @@ class CFGDenoiserTimesteps(CFGDenoiser): self.alphas = shared.sd_model.alphas_cumprod def get_pred_x0(self, x_in, x_out, sigma): - ts = int(sigma.item()) + ts = sigma.to(dtype=int) - s_in = x_in.new_ones([x_in.shape[0]]) - a_t = self.alphas[ts].item() * s_in + a_t = self.alphas[ts][:, None, None, None] sqrt_one_minus_at = (1 - a_t).sqrt() pred_x0 = (x_in - sqrt_one_minus_at * x_out) / a_t.sqrt() diff --git a/modules/sd_samplers_timesteps_impl.py b/modules/sd_samplers_timesteps_impl.py index d32e35213..a72daafd4 100644 --- a/modules/sd_samplers_timesteps_impl.py +++ b/modules/sd_samplers_timesteps_impl.py @@ -16,16 +16,17 @@ def ddim(model, x, timesteps, extra_args=None, callback=None, disable=None, eta= sigmas = eta * np.sqrt((1 - alphas_prev.cpu().numpy()) / (1 - alphas.cpu()) * (1 - alphas.cpu() / alphas_prev.cpu().numpy())) extra_args = {} if extra_args is None else extra_args - s_in = x.new_ones([x.shape[0]]) + s_in = x.new_ones((x.shape[0])) + s_x = x.new_ones((x.shape[0], 1, 1, 1)) for i in tqdm.trange(len(timesteps) - 1, disable=disable): index = len(timesteps) - 1 - i e_t = model(x, timesteps[index].item() * s_in, **extra_args) - a_t = alphas[index].item() * s_in - a_prev = alphas_prev[index].item() * s_in - sigma_t = sigmas[index].item() * s_in - sqrt_one_minus_at = sqrt_one_minus_alphas[index].item() * s_in + a_t = alphas[index].item() * s_x + a_prev = alphas_prev[index].item() * s_x + sigma_t = sigmas[index].item() * s_x + sqrt_one_minus_at = sqrt_one_minus_alphas[index].item() * s_x pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() dir_xt = (1. - a_prev - sigma_t ** 2).sqrt() * e_t @@ -47,13 +48,14 @@ def plms(model, x, timesteps, extra_args=None, callback=None, disable=None): extra_args = {} if extra_args is None else extra_args s_in = x.new_ones([x.shape[0]]) + s_x = x.new_ones((x.shape[0], 1, 1, 1)) old_eps = [] def get_x_prev_and_pred_x0(e_t, index): # select parameters corresponding to the currently considered timestep - a_t = alphas[index].item() * s_in - a_prev = alphas_prev[index].item() * s_in - sqrt_one_minus_at = sqrt_one_minus_alphas[index].item() * s_in + a_t = alphas[index].item() * s_x + a_prev = alphas_prev[index].item() * s_x + sqrt_one_minus_at = sqrt_one_minus_alphas[index].item() * s_x # current prediction for x_0 pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()