add eta to k ancestral

This commit is contained in:
C43H66N12O12S2 2022-09-28 05:09:22 +03:00 committed by GitHub
parent f2a4a2c3a6
commit 8644e494be
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -39,8 +39,10 @@ samplers_for_img2img = [x for x in samplers if x.name != 'PLMS']
sampler_extra_params = { sampler_extra_params = {
'sample_euler':['s_churn','s_tmin','s_tmax','s_noise'], 'sample_euler':['s_churn','s_tmin','s_tmax','s_noise'],
'sample_euler_ancestral':['eta'],
'sample_heun' :['s_churn','s_tmin','s_tmax','s_noise'], 'sample_heun' :['s_churn','s_tmin','s_tmax','s_noise'],
'sample_dpm_2':['s_churn','s_tmin','s_tmax','s_noise'], 'sample_dpm_2':['s_churn','s_tmin','s_tmax','s_noise'],
'sample_dpm_2_ancestral':['eta'],
} }
def setup_img2img_steps(p, steps=None): def setup_img2img_steps(p, steps=None):
@ -154,9 +156,9 @@ class VanillaStableDiffusionSampler:
# existing code fails with cetin step counts, like 9 # existing code fails with cetin step counts, like 9
try: try:
samples_ddim, _ = self.sampler.sample(S=steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=p.ddim_eta) samples_ddim, _ = self.sampler.sample(S=steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=p.eta)
except Exception: except Exception:
samples_ddim, _ = self.sampler.sample(S=steps+1, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=p.ddim_eta) samples_ddim, _ = self.sampler.sample(S=steps+1, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=p.eta)
return samples_ddim return samples_ddim