re-derive sqrt alpha bar and sqrt one minus alphabar

This is the only place these values are ever referenced outside of training code so this change is very justifiable and more consistent.
This commit is contained in:
drhead 2023-12-09 14:09:28 -05:00 committed by GitHub
parent 78acdcf677
commit 5381405eaa
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -36,7 +36,7 @@ class CompVisTimestepsVDenoiser(torch.nn.Module):
self.inner_model = model
def predict_eps_from_z_and_v(self, x_t, t, v):
return self.inner_model.sqrt_alphas_cumprod[t.to(torch.int), None, None, None] * v + self.inner_model.sqrt_one_minus_alphas_cumprod[t.to(torch.int), None, None, None] * x_t
return torch.sqrt(self.inner_model.alphas_cumprod)[t.to(torch.int), None, None, None] * v + torch.sqrt(1 - self.inner_model.alphas_cumprod)[t.to(torch.int), None, None, None] * x_t
def forward(self, input, timesteps, **kwargs):
model_output = self.inner_model.apply_model(input, timesteps, **kwargs)