Added parameters for the composite stage, fixed batched generation.

This commit is contained in:
CodeHatchling 2023-12-07 20:19:35 -07:00
parent 0ef4a4cb23
commit f284ae23bc

View File

@ -6,22 +6,34 @@ import modules.scripts as scripts
class SoftInpaintingSettings: class SoftInpaintingSettings:
def __init__(self, mask_blend_power, mask_blend_scale, inpaint_detail_preservation): def __init__(self,
mask_blend_power,
mask_blend_scale,
inpaint_detail_preservation,
composite_mask_influence,
composite_difference_threshold,
composite_difference_contrast):
self.mask_blend_power = mask_blend_power self.mask_blend_power = mask_blend_power
self.mask_blend_scale = mask_blend_scale self.mask_blend_scale = mask_blend_scale
self.inpaint_detail_preservation = inpaint_detail_preservation self.inpaint_detail_preservation = inpaint_detail_preservation
self.composite_mask_influence = composite_mask_influence
self.composite_difference_threshold = composite_difference_threshold
self.composite_difference_contrast = composite_difference_contrast
def add_generation_params(self, dest): def add_generation_params(self, dest):
dest[enabled_gen_param_label] = True dest[enabled_gen_param_label] = True
dest[gen_param_labels.mask_blend_power] = self.mask_blend_power dest[gen_param_labels.mask_blend_power] = self.mask_blend_power
dest[gen_param_labels.mask_blend_scale] = self.mask_blend_scale dest[gen_param_labels.mask_blend_scale] = self.mask_blend_scale
dest[gen_param_labels.inpaint_detail_preservation] = self.inpaint_detail_preservation dest[gen_param_labels.inpaint_detail_preservation] = self.inpaint_detail_preservation
dest[gen_param_labels.composite_mask_influence] = self.composite_mask_influence
dest[gen_param_labels.composite_difference_threshold] = self.composite_difference_threshold
dest[gen_param_labels.composite_difference_contrast] = self.composite_difference_contrast
# ------------------- Methods ------------------- # ------------------- Methods -------------------
def latent_blend(soft_inpainting, a, b, t): def latent_blend(settings, a, b, t):
""" """
Interpolates two latent image representations according to the parameter t, Interpolates two latent image representations according to the parameter t,
where the interpolated vectors' magnitudes are also interpolated separately. where the interpolated vectors' magnitudes are also interpolated separately.
@ -54,11 +66,11 @@ def latent_blend(soft_inpainting, a, b, t):
# Interpolate the powered magnitudes, then un-power them (bring them back to a power of 1). # Interpolate the powered magnitudes, then un-power them (bring them back to a power of 1).
a_magnitude = torch.norm(a, p=2, dim=1, keepdim=True).to(torch.float64).pow_( a_magnitude = torch.norm(a, p=2, dim=1, keepdim=True).to(torch.float64).pow_(
soft_inpainting.inpaint_detail_preservation) * one_minus_t3 settings.inpaint_detail_preservation) * one_minus_t3
b_magnitude = torch.norm(b, p=2, dim=1, keepdim=True).to(torch.float64).pow_( b_magnitude = torch.norm(b, p=2, dim=1, keepdim=True).to(torch.float64).pow_(
soft_inpainting.inpaint_detail_preservation) * t3 settings.inpaint_detail_preservation) * t3
desired_magnitude = a_magnitude desired_magnitude = a_magnitude
desired_magnitude.add_(b_magnitude).pow_(1 / soft_inpainting.inpaint_detail_preservation) desired_magnitude.add_(b_magnitude).pow_(1 / settings.inpaint_detail_preservation)
del a_magnitude, b_magnitude, t3, one_minus_t3 del a_magnitude, b_magnitude, t3, one_minus_t3
# Change the linearly interpolated image vectors' magnitudes to the value we want. # Change the linearly interpolated image vectors' magnitudes to the value we want.
@ -77,7 +89,7 @@ def latent_blend(soft_inpainting, a, b, t):
return image_interp_scaled return image_interp_scaled
def get_modified_nmask(soft_inpainting, nmask, sigma): def get_modified_nmask(settings, nmask, sigma):
""" """
Converts a negative mask representing the transparency of the original latent vectors being overlayed Converts a negative mask representing the transparency of the original latent vectors being overlayed
to a mask that is scaled according to the denoising strength for this step. to a mask that is scaled according to the denoising strength for this step.
@ -93,10 +105,12 @@ def get_modified_nmask(soft_inpainting, nmask, sigma):
NOTE: "mask" is not used NOTE: "mask" is not used
""" """
import torch import torch
return torch.pow(nmask, (sigma ** soft_inpainting.mask_blend_power) * soft_inpainting.mask_blend_scale) return torch.pow(nmask, (sigma ** settings.mask_blend_power) * settings.mask_blend_scale)
def apply_adaptive_masks( def apply_adaptive_masks(
settings:SoftInpaintingSettings,
nmask,
latent_orig, latent_orig,
latent_processed, latent_processed,
overlay_images, overlay_images,
@ -108,11 +122,13 @@ def apply_adaptive_masks(
from PIL import Image, ImageOps, ImageFilter from PIL import Image, ImageOps, ImageFilter
# TODO: Bias the blending according to the latent mask, add adjustable parameter for bias control. # TODO: Bias the blending according to the latent mask, add adjustable parameter for bias control.
# latent_mask = p.nmask[0].float().cpu() latent_mask = nmask[0].float()
# convert the original mask into a form we use to scale distances for thresholding # convert the original mask into a form we use to scale distances for thresholding
# mask_scalar = 1-(torch.clamp(latent_mask, min=0, max=1) ** (p.mask_blend_scale / 2)) mask_scalar = 1-(torch.clamp(latent_mask, min=0, max=1) ** (settings.mask_blend_scale / 2))
# mask_scalar = mask_scalar / (1.00001-mask_scalar) mask_scalar = (0.5 * (1-settings.composite_mask_influence)
# mask_scalar = mask_scalar.numpy() + mask_scalar * settings.composite_mask_influence)
mask_scalar = mask_scalar / (1.00001-mask_scalar)
mask_scalar = mask_scalar.cpu().numpy()
latent_distance = torch.norm(latent_processed - latent_orig, p=2, dim=1) latent_distance = torch.norm(latent_processed - latent_orig, p=2, dim=1)
@ -128,10 +144,10 @@ def apply_adaptive_masks(
percentile_min=0.25, percentile_max=0.75, min_width=1) percentile_min=0.25, percentile_max=0.75, min_width=1)
# The distance at which opacity of original decreases to 50% # The distance at which opacity of original decreases to 50%
# half_weighted_distance = 1 # * mask_scalar half_weighted_distance = settings.composite_difference_threshold * mask_scalar
# converted_mask = converted_mask / half_weighted_distance converted_mask = converted_mask / half_weighted_distance
converted_mask = 1 / (1 + converted_mask ** 2) converted_mask = 1 / (1 + converted_mask ** settings.composite_difference_contrast)
converted_mask = smootherstep(converted_mask) converted_mask = smootherstep(converted_mask)
converted_mask = 1 - converted_mask converted_mask = 1 - converted_mask
converted_mask = 255. * converted_mask converted_mask = 255. * converted_mask
@ -161,7 +177,7 @@ def apply_adaptive_masks(
def apply_masks( def apply_masks(
soft_inpainting, settings,
nmask, nmask,
overlay_images, overlay_images,
width, height, width, height,
@ -172,7 +188,7 @@ def apply_masks(
from PIL import Image, ImageOps, ImageFilter from PIL import Image, ImageOps, ImageFilter
converted_mask = nmask[0].float() converted_mask = nmask[0].float()
converted_mask = torch.clamp(converted_mask, min=0, max=1).pow_(soft_inpainting.mask_blend_scale / 2) converted_mask = torch.clamp(converted_mask, min=0, max=1).pow_(settings.mask_blend_scale / 2)
converted_mask = 255. * converted_mask converted_mask = 255. * converted_mask
converted_mask = converted_mask.cpu().numpy().astype(np.uint8) converted_mask = converted_mask.cpu().numpy().astype(np.uint8)
converted_mask = Image.fromarray(converted_mask) converted_mask = Image.fromarray(converted_mask)
@ -395,7 +411,7 @@ def get_gaussian_kernel(stddev_radius=1.0, max_radius=2):
# ------------------- Constants ------------------- # ------------------- Constants -------------------
default = SoftInpaintingSettings(1, 0.5, 4) default = SoftInpaintingSettings(1, 0.5, 4, 0, 0.5, 2)
enabled_ui_label = "Soft inpainting" enabled_ui_label = "Soft inpainting"
enabled_gen_param_label = "Soft inpainting enabled" enabled_gen_param_label = "Soft inpainting enabled"
@ -404,25 +420,37 @@ enabled_el_id = "soft_inpainting_enabled"
ui_labels = SoftInpaintingSettings( ui_labels = SoftInpaintingSettings(
"Schedule bias", "Schedule bias",
"Preservation strength", "Preservation strength",
"Transition contrast boost") "Transition contrast boost",
"Mask influence",
"Difference threshold",
"Difference contrast")
ui_info = SoftInpaintingSettings( ui_info = SoftInpaintingSettings(
"Shifts when preservation of original content occurs during denoising.", "Shifts when preservation of original content occurs during denoising.",
"How strongly partially masked content should be preserved.", "How strongly partially masked content should be preserved.",
"Amplifies the contrast that may be lost in partially masked regions.") "Amplifies the contrast that may be lost in partially masked regions.",
"How strongly the original mask should bias the difference threshold.",
"How much an image region can change before the original pixels are not blended in anymore.",
"How sharp the transition should be between blended and not blended.")
gen_param_labels = SoftInpaintingSettings( gen_param_labels = SoftInpaintingSettings(
"Soft inpainting schedule bias", "Soft inpainting schedule bias",
"Soft inpainting preservation strength", "Soft inpainting preservation strength",
"Soft inpainting transition contrast boost") "Soft inpainting transition contrast boost",
"Soft inpainting mask influence",
"Soft inpainting difference threshold",
"Soft inpainting difference contrast")
el_ids = SoftInpaintingSettings( el_ids = SoftInpaintingSettings(
"mask_blend_power", "mask_blend_power",
"mask_blend_scale", "mask_blend_scale",
"inpaint_detail_preservation") "inpaint_detail_preservation",
"composite_mask_influence",
"composite_difference_threshold",
"composite_difference_contrast")
# ----- # ------------------- Script -------------------
class Script(scripts.Script): class Script(scripts.Script):
@ -449,28 +477,62 @@ class Script(scripts.Script):
**High _Mask blur_** values are recommended! **High _Mask blur_** values are recommended!
""") """)
result = SoftInpaintingSettings( power = \
gr.Slider(label=ui_labels.mask_blend_power, gr.Slider(label=ui_labels.mask_blend_power,
info=ui_info.mask_blend_power, info=ui_info.mask_blend_power,
minimum=0, minimum=0,
maximum=8, maximum=8,
step=0.1, step=0.1,
value=default.mask_blend_power, value=default.mask_blend_power,
elem_id=el_ids.mask_blend_power), elem_id=el_ids.mask_blend_power)
scale = \
gr.Slider(label=ui_labels.mask_blend_scale, gr.Slider(label=ui_labels.mask_blend_scale,
info=ui_info.mask_blend_scale, info=ui_info.mask_blend_scale,
minimum=0, minimum=0,
maximum=8, maximum=8,
step=0.05, step=0.05,
value=default.mask_blend_scale, value=default.mask_blend_scale,
elem_id=el_ids.mask_blend_scale), elem_id=el_ids.mask_blend_scale)
detail = \
gr.Slider(label=ui_labels.inpaint_detail_preservation, gr.Slider(label=ui_labels.inpaint_detail_preservation,
info=ui_info.inpaint_detail_preservation, info=ui_info.inpaint_detail_preservation,
minimum=1, minimum=1,
maximum=32, maximum=32,
step=0.5, step=0.5,
value=default.inpaint_detail_preservation, value=default.inpaint_detail_preservation,
elem_id=el_ids.inpaint_detail_preservation)) elem_id=el_ids.inpaint_detail_preservation)
gr.Markdown(
"""
### Pixel Composite Settings
""")
mask_inf = \
gr.Slider(label=ui_labels.composite_mask_influence,
info=ui_info.composite_mask_influence,
minimum=0,
maximum=1,
step=0.05,
value=default.composite_mask_influence,
elem_id=el_ids.composite_mask_influence)
dif_thresh = \
gr.Slider(label=ui_labels.composite_difference_threshold,
info=ui_info.composite_difference_threshold,
minimum=0,
maximum=8,
step=0.25,
value=default.composite_difference_threshold,
elem_id=el_ids.composite_difference_threshold)
dif_contr = \
gr.Slider(label=ui_labels.composite_difference_contrast,
info=ui_info.composite_difference_contrast,
minimum=0,
maximum=8,
step=0.25,
value=default.composite_difference_contrast,
elem_id=el_ids.composite_difference_contrast)
with gr.Accordion("Help", open=False): with gr.Accordion("Help", open=False):
gr.Markdown( gr.Markdown(
@ -507,41 +569,86 @@ class Script(scripts.Script):
- **High values**: Stronger contrast, may over-saturate colors. - **High values**: Stronger contrast, may over-saturate colors.
""") """)
gr.Markdown(
"""
## Pixel Composite Settings
Masks are generated based on how much a part of the image changed after denoising.
These masks are used to blend the original and final images together.
If the difference is low, the original pixels are used instead of the pixels returned by the inpainting process.
""")
gr.Markdown(
f"""
### {ui_labels.composite_mask_influence}
This parameter controls how much the mask should bias this sensitivity to difference.
- **0**: Ignore the mask, only consider differences in image content.
- **1**: Follow the mask closely despite image content changes.
""")
gr.Markdown(
f"""
### {ui_labels.composite_difference_threshold}
This value represents the difference at which the opacity of the original pixels will have less than 50% opacity.
- **Low values**: Two images patches must be almost the same in order to retain original pixels.
- **High values**: Two images patches can be very different and still retain original pixels.
""")
gr.Markdown(
f"""
### {ui_labels.composite_difference_contrast}
This value represents the difference at which the opacity of the original pixels will have less than 50% opacity.
- **Low values**: Two images patches must be almost the same in order to retain original pixels.
- **High values**: Two images patches can be very different and still retain original pixels.
""")
self.infotext_fields = [(soft_inpainting_enabled, enabled_gen_param_label), self.infotext_fields = [(soft_inpainting_enabled, enabled_gen_param_label),
(result.mask_blend_power, gen_param_labels.mask_blend_power), (power, gen_param_labels.mask_blend_power),
(result.mask_blend_scale, gen_param_labels.mask_blend_scale), (scale, gen_param_labels.mask_blend_scale),
(result.inpaint_detail_preservation, gen_param_labels.inpaint_detail_preservation)] (detail, gen_param_labels.inpaint_detail_preservation),
(mask_inf, gen_param_labels.composite_mask_influence),
(dif_thresh, gen_param_labels.composite_difference_threshold),
(dif_contr, gen_param_labels.composite_difference_contrast)]
self.paste_field_names = [] self.paste_field_names = []
for _, field_name in self.infotext_fields: for _, field_name in self.infotext_fields:
self.paste_field_names.append(field_name) self.paste_field_names.append(field_name)
return [soft_inpainting_enabled, return [soft_inpainting_enabled,
result.mask_blend_power, power,
result.mask_blend_scale, scale,
result.inpaint_detail_preservation] detail,
mask_inf,
dif_thresh,
dif_contr]
def process(self, p, enabled, power, scale, detail_preservation): def process(self, p, enabled, power, scale, detail_preservation, mask_inf, dif_thresh, dif_contr):
if not enabled: if not enabled:
return return
# Shut off the rounding it normally does. # Shut off the rounding it normally does.
p.mask_round = False p.mask_round = False
settings = SoftInpaintingSettings(power, scale, detail_preservation) settings = SoftInpaintingSettings(power, scale, detail_preservation, mask_inf, dif_thresh, dif_contr)
# p.extra_generation_params["Mask rounding"] = False # p.extra_generation_params["Mask rounding"] = False
settings.add_generation_params(p.extra_generation_params) settings.add_generation_params(p.extra_generation_params)
def on_mask_blend(self, p, mba: scripts.MaskBlendArgs, enabled, power, scale, detail_preservation): def on_mask_blend(self, p, mba: scripts.MaskBlendArgs, enabled, power, scale, detail_preservation, mask_inf, dif_thresh, dif_contr):
if not enabled: if not enabled:
return return
if mba.sigma is None: if mba.is_final_blend:
mba.blended_latent = mba.current_latent mba.blended_latent = mba.current_latent
return return
settings = SoftInpaintingSettings(power, scale, detail_preservation) settings = SoftInpaintingSettings(power, scale, detail_preservation, mask_inf, dif_thresh, dif_contr)
# todo: Why is sigma 2D? Both values are the same. # todo: Why is sigma 2D? Both values are the same.
mba.blended_latent = latent_blend(settings, mba.blended_latent = latent_blend(settings,
@ -549,11 +656,11 @@ class Script(scripts.Script):
mba.current_latent, mba.current_latent,
get_modified_nmask(settings, mba.nmask, mba.sigma[0])) get_modified_nmask(settings, mba.nmask, mba.sigma[0]))
def post_sample(self, p, ps: scripts.PostSampleArgs, enabled, power, scale, detail_preservation): def post_sample(self, p, ps: scripts.PostSampleArgs, enabled, power, scale, detail_preservation, mask_inf, dif_thresh, dif_contr):
if not enabled: if not enabled:
return return
settings = SoftInpaintingSettings(power, scale, detail_preservation) settings = SoftInpaintingSettings(power, scale, detail_preservation, mask_inf, dif_thresh, dif_contr)
from modules import images from modules import images
from modules.shared import opts from modules.shared import opts
@ -570,15 +677,20 @@ class Script(scripts.Script):
self.overlay_images.append(image.convert('RGBA')) self.overlay_images.append(image.convert('RGBA'))
if len(p.init_images) == 1:
self.overlay_images = self.overlay_images * p.batch_size
if getattr(ps.samples, 'already_decoded', False): if getattr(ps.samples, 'already_decoded', False):
self.masks_for_overlay = apply_masks(soft_inpainting=settings, self.masks_for_overlay = apply_masks(settings=settings,
nmask=p.nmask, nmask=p.nmask,
overlay_images=self.overlay_images, overlay_images=self.overlay_images,
width=p.width, width=p.width,
height=p.height, height=p.height,
paste_to=p.paste_to) paste_to=p.paste_to)
else: else:
self.masks_for_overlay = apply_adaptive_masks(latent_orig=p.init_latent, self.masks_for_overlay = apply_adaptive_masks(settings=settings,
nmask=p.nmask,
latent_orig=p.init_latent,
latent_processed=ps.samples, latent_processed=ps.samples,
overlay_images=self.overlay_images, overlay_images=self.overlay_images,
width=p.width, width=p.width,
@ -586,7 +698,7 @@ class Script(scripts.Script):
paste_to=p.paste_to) paste_to=p.paste_to)
def postprocess_maskoverlay(self, p, ppmo: scripts.PostProcessMaskOverlayArgs, enabled, power, scale, detail_preservation): def postprocess_maskoverlay(self, p, ppmo: scripts.PostProcessMaskOverlayArgs, enabled, power, scale, detail_preservation, mask_inf, dif_thresh, dif_contr):
if not enabled: if not enabled:
return return