mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-01-06 15:15:05 +08:00
merge errors
This commit is contained in:
parent
54c3e5c913
commit
f8ff8c0638
@ -38,16 +38,24 @@ class CFGDenoiser(torch.nn.Module):
|
|||||||
negative prompt.
|
negative prompt.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, model, sampler):
|
def __init__(self, sampler):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.inner_model = model
|
self.model_wrap = None
|
||||||
self.mask = None
|
self.mask = None
|
||||||
self.nmask = None
|
self.nmask = None
|
||||||
self.init_latent = None
|
self.init_latent = None
|
||||||
|
self.steps = None
|
||||||
self.step = 0
|
self.step = 0
|
||||||
self.image_cfg_scale = None
|
self.image_cfg_scale = None
|
||||||
self.padded_cond_uncond = False
|
self.padded_cond_uncond = False
|
||||||
self.sampler = sampler
|
self.sampler = sampler
|
||||||
|
self.model_wrap = None
|
||||||
|
self.p = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def inner_model(self):
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
|
||||||
def combine_denoised(self, x_out, conds_list, uncond, cond_scale):
|
def combine_denoised(self, x_out, conds_list, uncond, cond_scale):
|
||||||
denoised_uncond = x_out[-uncond.shape[0]:]
|
denoised_uncond = x_out[-uncond.shape[0]:]
|
||||||
@ -68,10 +76,21 @@ class CFGDenoiser(torch.nn.Module):
|
|||||||
def get_pred_x0(self, x_in, x_out, sigma):
|
def get_pred_x0(self, x_in, x_out, sigma):
|
||||||
return x_out
|
return x_out
|
||||||
|
|
||||||
|
def update_inner_model(self):
|
||||||
|
self.model_wrap = None
|
||||||
|
|
||||||
|
c, uc = self.p.get_conds()
|
||||||
|
self.sampler.sampler_extra_args['cond'] = c
|
||||||
|
self.sampler.sampler_extra_args['uncond'] = uc
|
||||||
|
|
||||||
def forward(self, x, sigma, uncond, cond, cond_scale, s_min_uncond, image_cond):
|
def forward(self, x, sigma, uncond, cond, cond_scale, s_min_uncond, image_cond):
|
||||||
if state.interrupted or state.skipped:
|
if state.interrupted or state.skipped:
|
||||||
raise sd_samplers_common.InterruptedException
|
raise sd_samplers_common.InterruptedException
|
||||||
|
|
||||||
|
if sd_samplers_common.apply_refiner(self):
|
||||||
|
cond = self.sampler.sampler_extra_args['cond']
|
||||||
|
uncond = self.sampler.sampler_extra_args['uncond']
|
||||||
|
|
||||||
# at self.image_cfg_scale == 1.0 produced results for edit model are the same as with normal sampling,
|
# at self.image_cfg_scale == 1.0 produced results for edit model are the same as with normal sampling,
|
||||||
# so is_edit_model is set to False to support AND composition.
|
# so is_edit_model is set to False to support AND composition.
|
||||||
is_edit_model = shared.sd_model.cond_stage_key == "edit" and self.image_cfg_scale is not None and self.image_cfg_scale != 1.0
|
is_edit_model = shared.sd_model.cond_stage_key == "edit" and self.image_cfg_scale is not None and self.image_cfg_scale != 1.0
|
||||||
|
@ -202,8 +202,9 @@ class Sampler:
|
|||||||
|
|
||||||
self.conditioning_key = shared.sd_model.model.conditioning_key
|
self.conditioning_key = shared.sd_model.model.conditioning_key
|
||||||
|
|
||||||
self.model_wrap = None
|
self.p = None
|
||||||
self.model_wrap_cfg = None
|
self.model_wrap_cfg = None
|
||||||
|
self.sampler_extra_args = None
|
||||||
|
|
||||||
def callback_state(self, d):
|
def callback_state(self, d):
|
||||||
step = d['i']
|
step = d['i']
|
||||||
@ -215,6 +216,7 @@ class Sampler:
|
|||||||
shared.total_tqdm.update()
|
shared.total_tqdm.update()
|
||||||
|
|
||||||
def launch_sampling(self, steps, func):
|
def launch_sampling(self, steps, func):
|
||||||
|
self.model_wrap_cfg.steps = steps
|
||||||
state.sampling_steps = steps
|
state.sampling_steps = steps
|
||||||
state.sampling_step = 0
|
state.sampling_step = 0
|
||||||
|
|
||||||
@ -234,6 +236,8 @@ class Sampler:
|
|||||||
return p.steps
|
return p.steps
|
||||||
|
|
||||||
def initialize(self, p) -> dict:
|
def initialize(self, p) -> dict:
|
||||||
|
self.p = p
|
||||||
|
self.model_wrap_cfg.p = p
|
||||||
self.model_wrap_cfg.mask = p.mask if hasattr(p, 'mask') else None
|
self.model_wrap_cfg.mask = p.mask if hasattr(p, 'mask') else None
|
||||||
self.model_wrap_cfg.nmask = p.nmask if hasattr(p, 'nmask') else None
|
self.model_wrap_cfg.nmask = p.nmask if hasattr(p, 'nmask') else None
|
||||||
self.model_wrap_cfg.step = 0
|
self.model_wrap_cfg.step = 0
|
||||||
|
@ -52,17 +52,24 @@ k_diffusion_scheduler = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class CFGDenoiserKDiffusion(sd_samplers_cfg_denoiser.CFGDenoiser):
|
||||||
|
@property
|
||||||
|
def inner_model(self):
|
||||||
|
if self.model_wrap is None:
|
||||||
|
denoiser = k_diffusion.external.CompVisVDenoiser if shared.sd_model.parameterization == "v" else k_diffusion.external.CompVisDenoiser
|
||||||
|
self.model_wrap = denoiser(shared.sd_model, quantize=shared.opts.enable_quantization)
|
||||||
|
|
||||||
|
return self.model_wrap
|
||||||
|
|
||||||
|
|
||||||
class KDiffusionSampler(sd_samplers_common.Sampler):
|
class KDiffusionSampler(sd_samplers_common.Sampler):
|
||||||
def __init__(self, funcname, sd_model):
|
def __init__(self, funcname, sd_model):
|
||||||
|
|
||||||
super().__init__(funcname)
|
super().__init__(funcname)
|
||||||
|
|
||||||
self.extra_params = sampler_extra_params.get(funcname, [])
|
|
||||||
self.func = funcname if callable(funcname) else getattr(k_diffusion.sampling, self.funcname)
|
self.func = funcname if callable(funcname) else getattr(k_diffusion.sampling, self.funcname)
|
||||||
|
|
||||||
denoiser = k_diffusion.external.CompVisVDenoiser if sd_model.parameterization == "v" else k_diffusion.external.CompVisDenoiser
|
self.model_wrap_cfg = CFGDenoiserKDiffusion(self)
|
||||||
self.model_wrap = denoiser(sd_model, quantize=shared.opts.enable_quantization)
|
self.model_wrap = self.model_wrap_cfg.inner_model
|
||||||
self.model_wrap_cfg = sd_samplers_cfg_denoiser.CFGDenoiser(self.model_wrap, self)
|
|
||||||
|
|
||||||
def get_sigmas(self, p, steps):
|
def get_sigmas(self, p, steps):
|
||||||
discard_next_to_last_sigma = self.config is not None and self.config.options.get('discard_next_to_last_sigma', False)
|
discard_next_to_last_sigma = self.config is not None and self.config.options.get('discard_next_to_last_sigma', False)
|
||||||
|
@ -44,10 +44,10 @@ class CompVisTimestepsVDenoiser(torch.nn.Module):
|
|||||||
|
|
||||||
class CFGDenoiserTimesteps(CFGDenoiser):
|
class CFGDenoiserTimesteps(CFGDenoiser):
|
||||||
|
|
||||||
def __init__(self, model, sampler):
|
def __init__(self, sampler):
|
||||||
super().__init__(model, sampler)
|
super().__init__(sampler)
|
||||||
|
|
||||||
self.alphas = model.inner_model.alphas_cumprod
|
self.alphas = shared.sd_model.alphas_cumprod
|
||||||
|
|
||||||
def get_pred_x0(self, x_in, x_out, sigma):
|
def get_pred_x0(self, x_in, x_out, sigma):
|
||||||
ts = int(sigma.item())
|
ts = int(sigma.item())
|
||||||
@ -60,6 +60,14 @@ class CFGDenoiserTimesteps(CFGDenoiser):
|
|||||||
|
|
||||||
return pred_x0
|
return pred_x0
|
||||||
|
|
||||||
|
@property
|
||||||
|
def inner_model(self):
|
||||||
|
if self.model_wrap is None:
|
||||||
|
denoiser = CompVisTimestepsVDenoiser if shared.sd_model.parameterization == "v" else CompVisTimestepsDenoiser
|
||||||
|
self.model_wrap = denoiser(shared.sd_model)
|
||||||
|
|
||||||
|
return self.model_wrap
|
||||||
|
|
||||||
|
|
||||||
class CompVisSampler(sd_samplers_common.Sampler):
|
class CompVisSampler(sd_samplers_common.Sampler):
|
||||||
def __init__(self, funcname, sd_model):
|
def __init__(self, funcname, sd_model):
|
||||||
@ -68,9 +76,7 @@ class CompVisSampler(sd_samplers_common.Sampler):
|
|||||||
self.eta_option_field = 'eta_ddim'
|
self.eta_option_field = 'eta_ddim'
|
||||||
self.eta_infotext_field = 'Eta DDIM'
|
self.eta_infotext_field = 'Eta DDIM'
|
||||||
|
|
||||||
denoiser = CompVisTimestepsVDenoiser if sd_model.parameterization == "v" else CompVisTimestepsDenoiser
|
self.model_wrap_cfg = CFGDenoiserTimesteps(self)
|
||||||
self.model_wrap = denoiser(sd_model)
|
|
||||||
self.model_wrap_cfg = CFGDenoiserTimesteps(self.model_wrap, self)
|
|
||||||
|
|
||||||
def get_timesteps(self, p, steps):
|
def get_timesteps(self, p, steps):
|
||||||
discard_next_to_last_sigma = self.config is not None and self.config.options.get('discard_next_to_last_sigma', False)
|
discard_next_to_last_sigma = self.config is not None and self.config.options.get('discard_next_to_last_sigma', False)
|
||||||
@ -106,7 +112,7 @@ class CompVisSampler(sd_samplers_common.Sampler):
|
|||||||
|
|
||||||
self.model_wrap_cfg.init_latent = x
|
self.model_wrap_cfg.init_latent = x
|
||||||
self.last_latent = x
|
self.last_latent = x
|
||||||
extra_args = {
|
self.sampler_extra_args = {
|
||||||
'cond': conditioning,
|
'cond': conditioning,
|
||||||
'image_cond': image_conditioning,
|
'image_cond': image_conditioning,
|
||||||
'uncond': unconditional_conditioning,
|
'uncond': unconditional_conditioning,
|
||||||
@ -114,7 +120,7 @@ class CompVisSampler(sd_samplers_common.Sampler):
|
|||||||
's_min_uncond': self.s_min_uncond
|
's_min_uncond': self.s_min_uncond
|
||||||
}
|
}
|
||||||
|
|
||||||
samples = self.launch_sampling(t_enc + 1, lambda: self.func(self.model_wrap_cfg, xi, extra_args=extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs))
|
samples = self.launch_sampling(t_enc + 1, lambda: self.func(self.model_wrap_cfg, xi, extra_args=self.sampler_extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs))
|
||||||
|
|
||||||
if self.model_wrap_cfg.padded_cond_uncond:
|
if self.model_wrap_cfg.padded_cond_uncond:
|
||||||
p.extra_generation_params["Pad conds"] = True
|
p.extra_generation_params["Pad conds"] = True
|
||||||
@ -132,13 +138,14 @@ class CompVisSampler(sd_samplers_common.Sampler):
|
|||||||
extra_params_kwargs['timesteps'] = timesteps
|
extra_params_kwargs['timesteps'] = timesteps
|
||||||
|
|
||||||
self.last_latent = x
|
self.last_latent = x
|
||||||
samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args={
|
self.sampler_extra_args = {
|
||||||
'cond': conditioning,
|
'cond': conditioning,
|
||||||
'image_cond': image_conditioning,
|
'image_cond': image_conditioning,
|
||||||
'uncond': unconditional_conditioning,
|
'uncond': unconditional_conditioning,
|
||||||
'cond_scale': p.cfg_scale,
|
'cond_scale': p.cfg_scale,
|
||||||
's_min_uncond': self.s_min_uncond
|
's_min_uncond': self.s_min_uncond
|
||||||
}, disable=False, callback=self.callback_state, **extra_params_kwargs))
|
}
|
||||||
|
samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args=self.sampler_extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs))
|
||||||
|
|
||||||
if self.model_wrap_cfg.padded_cond_uncond:
|
if self.model_wrap_cfg.padded_cond_uncond:
|
||||||
p.extra_generation_params["Pad conds"] = True
|
p.extra_generation_params["Pad conds"] = True
|
||||||
|
Loading…
Reference in New Issue
Block a user