Removed soft inpainting, added hooks for softpainting to work instead.

This commit is contained in:
CodeHatchling 2023-12-06 21:16:27 -07:00
parent 4608f6236f
commit ac45789123
3 changed files with 118 additions and 69 deletions

View File

@ -30,7 +30,6 @@ import modules.sd_models as sd_models
import modules.sd_vae as sd_vae import modules.sd_vae as sd_vae
from ldm.data.util import AddMiDaS from ldm.data.util import AddMiDaS
from ldm.models.diffusion.ddpm import LatentDepth2ImageDiffusion from ldm.models.diffusion.ddpm import LatentDepth2ImageDiffusion
import modules.soft_inpainting as si
from einops import repeat, rearrange from einops import repeat, rearrange
from blendmodes.blend import blendLayers, BlendType from blendmodes.blend import blendLayers, BlendType
@ -73,12 +72,10 @@ def uncrop(image, dest_size, paste_loc):
return image return image
def apply_overlay(image, paste_loc, index, overlays): def apply_overlay(image, paste_loc, overlay):
if overlays is None or index >= len(overlays): if overlay is None:
return image return image
overlay = overlays[index]
if paste_loc is not None: if paste_loc is not None:
image = uncrop(image, (overlay.width, overlay.height), paste_loc) image = uncrop(image, (overlay.width, overlay.height), paste_loc)
@ -150,7 +147,6 @@ class StableDiffusionProcessing:
do_not_save_grid: bool = False do_not_save_grid: bool = False
extra_generation_params: dict[str, Any] = None extra_generation_params: dict[str, Any] = None
overlay_images: list = None overlay_images: list = None
masks_for_overlay: list = None
eta: float = None eta: float = None
do_not_reload_embeddings: bool = False do_not_reload_embeddings: bool = False
denoising_strength: float = None denoising_strength: float = None
@ -880,31 +876,17 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast(): with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast():
samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts) samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts)
if p.scripts is not None:
ps = scripts.PostSampleArgs(samples_ddim)
p.scripts.post_sample(p, ps)
samples_ddim = pp.samples
if getattr(samples_ddim, 'already_decoded', False): if getattr(samples_ddim, 'already_decoded', False):
x_samples_ddim = samples_ddim x_samples_ddim = samples_ddim
# todo: generate adaptive masks based on pixel differences.
if getattr(p, "image_mask", None) is not None and getattr(p, "soft_inpainting", None) is not None:
si.apply_masks(soft_inpainting=p.soft_inpainting,
nmask=p.nmask,
overlay_images=p.overlay_images,
masks_for_overlay=p.masks_for_overlay,
width=p.width,
height=p.height,
paste_to=p.paste_to)
else: else:
if opts.sd_vae_decode_method != 'Full': if opts.sd_vae_decode_method != 'Full':
p.extra_generation_params['VAE Decoder'] = opts.sd_vae_decode_method p.extra_generation_params['VAE Decoder'] = opts.sd_vae_decode_method
# Generate the mask(s) based on similarity between the original and denoised latent vectors
if getattr(p, "image_mask", None) is not None and getattr(p, "soft_inpainting", None) is not None:
si.apply_adaptive_masks(latent_orig=p.init_latent,
latent_processed=samples_ddim,
overlay_images=p.overlay_images,
masks_for_overlay=p.masks_for_overlay,
width=p.width,
height=p.height,
paste_to=p.paste_to)
x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim, target_device=devices.cpu, check_for_nans=True) x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim, target_device=devices.cpu, check_for_nans=True)
x_samples_ddim = torch.stack(x_samples_ddim).float() x_samples_ddim = torch.stack(x_samples_ddim).float()
@ -955,9 +937,18 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
pp = scripts.PostprocessImageArgs(image) pp = scripts.PostprocessImageArgs(image)
p.scripts.postprocess_image(p, pp) p.scripts.postprocess_image(p, pp)
image = pp.image image = pp.image
mask_for_overlay = p.mask_for_overlay
overlay_image = p.overlay_images[i] if p.overlay_images is not None and i < len(p.overlay_images) else None
if p.scripts is not None:
ppmo = scripts.PostProcessMaskOverlayArgs(i, mask_for_overlay, overlay_image)
p.scripts.postprocess_maskoverlay(p, ppmo)
mask_for_overlay, overlay_image = pp.mask_for_overlay, pp.overlay_image
if p.color_corrections is not None and i < len(p.color_corrections): if p.color_corrections is not None and i < len(p.color_corrections):
if save_samples and opts.save_images_before_color_correction: if save_samples and opts.save_images_before_color_correction:
image_without_cc = apply_overlay(image, p.paste_to, i, p.overlay_images) image_without_cc = apply_overlay(image, p.paste_to, overlay_image)
images.save_image(image_without_cc, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-before-color-correction") images.save_image(image_without_cc, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-before-color-correction")
image = apply_color_correction(p.color_corrections[i], image) image = apply_color_correction(p.color_corrections[i], image)
@ -968,9 +959,9 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
original_denoised_image = image.copy() original_denoised_image = image.copy()
if p.paste_to is not None: if p.paste_to is not None:
original_denoised_image = uncrop(original_denoised_image, (p.overlay_images[i].width, p.overlay_images[i].height), p.paste_to) original_denoised_image = uncrop(original_denoised_image, (p.overlay_image.width, p.overlay_image.height), p.paste_to)
image = apply_overlay(image, p.paste_to, i, p.overlay_images) image = apply_overlay(image, p.paste_to, overlay_image)
if save_samples: if save_samples:
images.save_image(image, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p) images.save_image(image, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p)
@ -981,13 +972,6 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
image.info["parameters"] = text image.info["parameters"] = text
output_images.append(image) output_images.append(image)
if hasattr(p, 'mask_for_overlay') and p.mask_for_overlay:
mask_for_overlay = p.mask_for_overlay
elif hasattr(p, 'masks_for_overlay') and p.masks_for_overlay and p.masks_for_overlay[i]:
mask_for_overlay = p.masks_for_overlay[i]
else:
mask_for_overlay = None
if mask_for_overlay is not None: if mask_for_overlay is not None:
if opts.return_mask or opts.save_mask: if opts.return_mask or opts.save_mask:
image_mask = mask_for_overlay.convert('RGB') image_mask = mask_for_overlay.convert('RGB')
@ -1401,7 +1385,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
mask_blur_x: int = 4 mask_blur_x: int = 4
mask_blur_y: int = 4 mask_blur_y: int = 4
mask_blur: int = None mask_blur: int = None
soft_inpainting: si.SoftInpaintingParameters = si.default mask_round: bool = True
inpainting_fill: int = 0 inpainting_fill: int = 0
inpaint_full_res: bool = True inpaint_full_res: bool = True
inpaint_full_res_padding: int = 0 inpaint_full_res_padding: int = 0
@ -1447,7 +1431,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
if image_mask is not None: if image_mask is not None:
# image_mask is passed in as RGBA by Gradio to support alpha masks, # image_mask is passed in as RGBA by Gradio to support alpha masks,
# but we still want to support binary masks. # but we still want to support binary masks.
image_mask = create_binary_mask(image_mask, round=(self.soft_inpainting is None)) image_mask = create_binary_mask(image_mask, round=self.mask_round)
if self.inpainting_mask_invert: if self.inpainting_mask_invert:
image_mask = ImageOps.invert(image_mask) image_mask = ImageOps.invert(image_mask)
@ -1465,7 +1449,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
image_mask = Image.fromarray(np_mask) image_mask = Image.fromarray(np_mask)
if self.inpaint_full_res: if self.inpaint_full_res:
self.mask_for_overlay = image_mask if self.soft_inpainting is None else None self.mask_for_overlay = image_mask
mask = image_mask.convert('L') mask = image_mask.convert('L')
crop_region = masking.get_crop_region(np.array(mask), self.inpaint_full_res_padding) crop_region = masking.get_crop_region(np.array(mask), self.inpaint_full_res_padding)
crop_region = masking.expand_crop_region(crop_region, self.width, self.height, mask.width, mask.height) crop_region = masking.expand_crop_region(crop_region, self.width, self.height, mask.width, mask.height)
@ -1476,13 +1460,10 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
self.paste_to = (x1, y1, x2-x1, y2-y1) self.paste_to = (x1, y1, x2-x1, y2-y1)
else: else:
image_mask = images.resize_image(self.resize_mode, image_mask, self.width, self.height) image_mask = images.resize_image(self.resize_mode, image_mask, self.width, self.height)
np_mask = np.array(image_mask)
np_mask = np.clip((np_mask.astype(np.float32)) * 2, 0, 255).astype(np.uint8)
self.mask_for_overlay = Image.fromarray(np_mask)
if self.soft_inpainting is None:
np_mask = np.array(image_mask)
np_mask = np.clip((np_mask.astype(np.float32)) * 2, 0, 255).astype(np.uint8)
self.mask_for_overlay = Image.fromarray(np_mask)
self.masks_for_overlay = [] if self.soft_inpainting is not None else None
self.overlay_images = [] self.overlay_images = []
latent_mask = self.latent_mask if self.latent_mask is not None else image_mask latent_mask = self.latent_mask if self.latent_mask is not None else image_mask
@ -1504,15 +1485,10 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
image = images.resize_image(self.resize_mode, image, self.width, self.height) image = images.resize_image(self.resize_mode, image, self.width, self.height)
if image_mask is not None: if image_mask is not None:
if self.soft_inpainting is not None: image_masked = Image.new('RGBa', (image.width, image.height))
# We apply the masks AFTER to adjust mask based on changed content. image_masked.paste(image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(self.mask_for_overlay.convert('L')))
self.overlay_images.append(image.convert('RGBA'))
self.masks_for_overlay.append(image_mask)
else:
image_masked = Image.new('RGBa', (image.width, image.height))
image_masked.paste(image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(self.mask_for_overlay.convert('L')))
self.overlay_images.append(image_masked.convert('RGBA')) self.overlay_images.append(image_masked.convert('RGBA'))
# crop_region is not None if we are doing inpaint full res # crop_region is not None if we are doing inpaint full res
if crop_region is not None: if crop_region is not None:
@ -1565,7 +1541,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
latmask = init_mask.convert('RGB').resize((self.init_latent.shape[3], self.init_latent.shape[2])) latmask = init_mask.convert('RGB').resize((self.init_latent.shape[3], self.init_latent.shape[2]))
latmask = np.moveaxis(np.array(latmask, dtype=np.float32), 2, 0) / 255 latmask = np.moveaxis(np.array(latmask, dtype=np.float32), 2, 0) / 255
latmask = latmask[0] latmask = latmask[0]
if self.soft_inpainting is None: if self.mask_round:
latmask = np.around(latmask) latmask = np.around(latmask)
latmask = np.tile(latmask[None], (4, 1, 1)) latmask = np.tile(latmask[None], (4, 1, 1))
@ -1578,7 +1554,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
elif self.inpainting_fill == 3: elif self.inpainting_fill == 3:
self.init_latent = self.init_latent * self.mask self.init_latent = self.init_latent * self.mask
self.image_conditioning = self.img2img_image_conditioning(image * 2 - 1, self.init_latent, image_mask, self.soft_inpainting is None) self.image_conditioning = self.img2img_image_conditioning(image * 2 - 1, self.init_latent, image_mask, self.mask_round)
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts): def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
x = self.rng.next() x = self.rng.next()
@ -1589,8 +1565,14 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning, image_conditioning=self.image_conditioning) samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning, image_conditioning=self.image_conditioning)
if self.mask is not None and self.soft_inpainting is None: blended_samples = samples * self.nmask + self.init_latent * self.mask
samples = samples * self.nmask + self.init_latent * self.mask
if self.scripts is not None:
mba = scripts.MaskBlendArgs(self, samples, self.nmask, self.init_latent, self.mask, blended_samples, sigma=None, is_final_blend=True)
self.scripts.on_mask_blend(self, mba)
blended_samples = mba.blended_latent
samples = blended_samples
del x del x
devices.torch_gc() devices.torch_gc()

View File

@ -11,11 +11,31 @@ from modules import shared, paths, script_callbacks, extensions, script_loading,
AlwaysVisible = object() AlwaysVisible = object()
class MaskBlendArgs:
def __init__(self, current_latent, nmask, init_latent, mask, blended_samples, denoiser=None, sigma=None):
self.current_latent = current_latent
self.nmask = nmask
self.init_latent = init_latent
self.mask = mask
self.blended_samples = blended_samples
self.denoiser = denoiser
self.is_final_blend = denoiser is None
self.sigma = sigma
class PostSampleArgs:
def __init__(self, samples):
self.samples = samples
class PostprocessImageArgs: class PostprocessImageArgs:
def __init__(self, image): def __init__(self, image):
self.image = image self.image = image
class PostProcessMaskOverlayArgs:
def __init__(self, index, mask_for_overlay, overlay_image):
self.index = index
self.mask_for_overlay = mask_for_overlay
self.overlay_image = overlay_image
class PostprocessBatchListArgs: class PostprocessBatchListArgs:
def __init__(self, images): def __init__(self, images):
@ -206,6 +226,25 @@ class Script:
pass pass
def on_mask_blend(self, p, mba: MaskBlendArgs, *args):
"""
Called in inpainting mode when the original content is blended with the inpainted content.
This is called at every step in the denoising process and once at the end.
If is_final_blend is true, this is called for the final blending stage.
Otherwise, denoiser and sigma are defined and may be used to inform the procedure.
"""
pass
def post_sample(self, p, ps: PostSampleArgs, *args):
"""
Called after the samples have been generated,
but before they have been decoded by the VAE, if applicable.
Check getattr(samples, 'already_decoded', False) to test if the images are decoded.
"""
pass
def postprocess_image(self, p, pp: PostprocessImageArgs, *args): def postprocess_image(self, p, pp: PostprocessImageArgs, *args):
""" """
Called for every image after it has been generated. Called for every image after it has been generated.
@ -213,6 +252,13 @@ class Script:
pass pass
def postprocess_maskoverlay(self, p, ppmo: PostProcessMaskOverlayArgs, *args):
"""
Called for every image after it has been generated.
"""
pass
def postprocess(self, p, processed, *args): def postprocess(self, p, processed, *args):
""" """
This function is called after processing ends for AlwaysVisible scripts. This function is called after processing ends for AlwaysVisible scripts.
@ -767,6 +813,22 @@ class ScriptRunner:
except Exception: except Exception:
errors.report(f"Error running postprocess_batch_list: {script.filename}", exc_info=True) errors.report(f"Error running postprocess_batch_list: {script.filename}", exc_info=True)
def post_sample(self, p, ps: PostSampleArgs):
for script in self.alwayson_scripts:
try:
script_args = p.script_args[script.args_from:script.args_to]
script.post_sample(p, ps, *script_args)
except Exception:
errors.report(f"Error running post_sample: {script.filename}", exc_info=True)
def on_mask_blend(self, p, mba: MaskBlendArgs):
for script in self.alwayson_scripts:
try:
script_args = p.script_args[script.args_from:script.args_to]
script.on_mask_blend(p, mba, *script_args)
except Exception:
errors.report(f"Error running post_sample: {script.filename}", exc_info=True)
def postprocess_image(self, p, pp: PostprocessImageArgs): def postprocess_image(self, p, pp: PostprocessImageArgs):
for script in self.alwayson_scripts: for script in self.alwayson_scripts:
try: try:
@ -775,6 +837,14 @@ class ScriptRunner:
except Exception: except Exception:
errors.report(f"Error running postprocess_image: {script.filename}", exc_info=True) errors.report(f"Error running postprocess_image: {script.filename}", exc_info=True)
def postprocess_maskoverlay(self, p, ppmo: PostProcessMaskOverlayArgs):
for script in self.alwayson_scripts:
try:
script_args = p.script_args[script.args_from:script.args_to]
script.postprocess_maskoverlay(p, ppmo, *script_args)
except Exception:
errors.report(f"Error running postprocess_image: {script.filename}", exc_info=True)
def before_component(self, component, **kwargs): def before_component(self, component, **kwargs):
for callback, script in self.on_before_component_elem_id.get(kwargs.get("elem_id"), []): for callback, script in self.on_before_component_elem_id.get(kwargs.get("elem_id"), []):
try: try:

View File

@ -109,19 +109,16 @@ class CFGDenoiser(torch.nn.Module):
assert not is_edit_model or all(len(conds) == 1 for conds in conds_list), "AND is not supported for InstructPix2Pix checkpoint (unless using Image CFG scale = 1.0)" assert not is_edit_model or all(len(conds) == 1 for conds in conds_list), "AND is not supported for InstructPix2Pix checkpoint (unless using Image CFG scale = 1.0)"
# If we use masks, blending between the denoised and original latent images occurs here. # If we use masks, blending between the denoised and original latent images occurs here.
def apply_blend(latent): def apply_blend(current_latent):
if hasattr(self.p, "denoiser_masked_blend_function") and callable(self.p.denoiser_masked_blend_function): blended_latent = current_latent * self.nmask + self.init_latent * self.mask
return self.p.denoiser_masked_blend_function(
self, if self.p.scripts is not None:
# Using an argument dictionary so that arguments can be added without breaking extensions. from modules import scripts
args= mba = scripts.MaskBlendArgs(current_latent, self.nmask, self.init_latent, self.mask, blended_latent, denoiser=self, sigma=sigma)
{ self.p.scripts.on_mask_blend(self.p, mba)
"denoiser": self, blended_latent = mba.blended_latent
"current_latent": latent,
"sigma": sigma return blended_latent
})
else:
return self.init_latent * self.mask + self.nmask * latent
# Blend in the original latents (before) # Blend in the original latents (before)
if self.mask_before_denoising and self.mask is not None: if self.mask_before_denoising and self.mask is not None: