Mirror of https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
A custom blending function can be provided by p, replacing the use of soft_inpainting.
commit e90d4334ad (parent 38864816fa)
modules/sd_samplers_cfg_denoiser.py

@@ -6,7 +6,6 @@ import modules.shared as shared
 from modules.script_callbacks import CFGDenoiserParams, cfg_denoiser_callback
 from modules.script_callbacks import CFGDenoisedParams, cfg_denoised_callback
 from modules.script_callbacks import AfterCFGCallbackParams, cfg_after_cfg_callback
-import modules.soft_inpainting as si
 
 
 def catenate_conds(conds):
@@ -44,7 +43,6 @@ class CFGDenoiser(torch.nn.Module):
         self.model_wrap = None
         self.mask = None
         self.nmask = None
-        self.soft_inpainting: si.SoftInpaintingParameters = None
         self.init_latent = None
         self.steps = None
         """number of steps as specified by user in UI"""
@@ -94,7 +92,6 @@ class CFGDenoiser(torch.nn.Module):
         self.sampler.sampler_extra_args['uncond'] = uc
 
     def forward(self, x, sigma, uncond, cond, cond_scale, s_min_uncond, image_cond):
-
         if state.interrupted or state.skipped:
             raise sd_samplers_common.InterruptedException
 
@@ -111,15 +108,24 @@
 
         assert not is_edit_model or all(len(conds) == 1 for conds in conds_list), "AND is not supported for InstructPix2Pix checkpoint (unless using Image CFG scale = 1.0)"
 
+        # If we use masks, blending between the denoised and original latent images occurs here.
+        def apply_blend(latent):
+            if hasattr(self.p, "denoiser_masked_blend_function") and callable(self.p.denoiser_masked_blend_function):
+                return self.p.denoiser_masked_blend_function(
+                    self,
+                    # Using an argument dictionary so that arguments can be added without breaking extensions.
+                    args=
+                    {
+                        "denoiser": self,
+                        "current_latent": latent,
+                        "sigma": sigma
+                    })
+            else:
+                return self.init_latent * self.mask + self.nmask * latent
+
         # Blend in the original latents (before)
         if self.mask_before_denoising and self.mask is not None:
-            if self.soft_inpainting is None:
-                x = self.init_latent * self.mask + self.nmask * x
-            else:
-                x = si.latent_blend(self.soft_inpainting,
-                                    self.init_latent,
-                                    x,
-                                    si.get_modified_nmask(self.soft_inpainting, self.nmask, sigma))
+            x = apply_blend(x)
 
         batch_size = len(conds_list)
         repeats = [len(conds_list[i]) for i in range(batch_size)]
@@ -222,13 +228,7 @@
 
         # Blend in the original latents (after)
         if not self.mask_before_denoising and self.mask is not None:
-            if self.soft_inpainting is None:
-                denoised = self.init_latent * self.mask + self.nmask * denoised
-            else:
-                denoised = si.latent_blend(self.soft_inpainting,
-                                           self.init_latent,
-                                           denoised,
-                                           si.get_modified_nmask(self.soft_inpainting, self.nmask, sigma))
+            denoised = apply_blend(denoised)
 
         self.sampler.last_latent = self.get_pred_x0(torch.cat([x_in[i:i + 1] for i in denoised_image_indexes]), torch.cat([x_out[i:i + 1] for i in denoised_image_indexes]), sigma)
 
modules/sd_samplers_common.py

@@ -277,7 +277,6 @@ class Sampler:
         self.model_wrap_cfg.p = p
         self.model_wrap_cfg.mask = p.mask if hasattr(p, 'mask') else None
         self.model_wrap_cfg.nmask = p.nmask if hasattr(p, 'nmask') else None
-        self.model_wrap_cfg.soft_inpainting = p.soft_inpainting if hasattr(p, 'soft_inpainting') else None
         self.model_wrap_cfg.step = 0
         self.model_wrap_cfg.image_cfg_scale = getattr(p, 'image_cfg_scale', None)
         self.eta = p.eta if p.eta is not None else getattr(opts, self.eta_option_field, 0.0)
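
For context, a minimal sketch of what the new hook allows an extension to do. It is illustrative only: `p` stands for the processing object handed to the sampler, and `my_masked_blend` is a hypothetical name; the call shape (the denoiser passed positionally, plus an `args` dictionary with "denoiser", "current_latent" and "sigma") follows the diff above.

    def my_masked_blend(denoiser, args):
        # CFGDenoiser.forward invokes this as:
        #   p.denoiser_masked_blend_function(denoiser,
        #       args={"denoiser": ..., "current_latent": ..., "sigma": ...})
        latent = args["current_latent"]
        # Default-equivalent blend; a custom function could instead vary the
        # mask with args["sigma"], e.g. for soft-edged inpainting.
        return denoiser.init_latent * denoiser.mask + denoiser.nmask * latent

    # Hypothetical wiring, done before sampling starts:
    p.denoiser_masked_blend_function = my_masked_blend

Because the extra inputs arrive in a single args dictionary, new keys can be added later without breaking blend functions written against this version of the hook.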
|
Loading…
x
Reference in New Issue
Block a user