mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-02-07 22:22:53 +08:00
Fixed issue where batched inpainting (batch size > 1) wouldn't work because of mismatched tensor sizes. The 'already_decoded' decoded case should also be handled correctly (tested indirectly).
This commit is contained in:
parent
b32a334e3d
commit
6fc12428e3
@ -883,20 +883,27 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
|||||||
if getattr(samples_ddim, 'already_decoded', False):
|
if getattr(samples_ddim, 'already_decoded', False):
|
||||||
x_samples_ddim = samples_ddim
|
x_samples_ddim = samples_ddim
|
||||||
# todo: generate adaptive masks based on pixel differences.
|
# todo: generate adaptive masks based on pixel differences.
|
||||||
# if p.masks_for_overlay is used, it will already be populated with masks
|
if getattr(p, "image_mask", None) is not None and getattr(p, "soft_inpainting", None) is not None:
|
||||||
|
si.apply_masks(soft_inpainting=p.soft_inpainting,
|
||||||
|
nmask=p.nmask,
|
||||||
|
overlay_images=p.overlay_images,
|
||||||
|
masks_for_overlay=p.masks_for_overlay,
|
||||||
|
width=p.width,
|
||||||
|
height=p.height,
|
||||||
|
paste_to=p.paste_to)
|
||||||
else:
|
else:
|
||||||
if opts.sd_vae_decode_method != 'Full':
|
if opts.sd_vae_decode_method != 'Full':
|
||||||
p.extra_generation_params['VAE Decoder'] = opts.sd_vae_decode_method
|
p.extra_generation_params['VAE Decoder'] = opts.sd_vae_decode_method
|
||||||
|
|
||||||
# Generate the mask(s) based on similarity between the original and denoised latent vectors
|
# Generate the mask(s) based on similarity between the original and denoised latent vectors
|
||||||
if getattr(p, "image_mask", None) is not None and getattr(p, "soft_inpainting", None) is not None:
|
if getattr(p, "image_mask", None) is not None and getattr(p, "soft_inpainting", None) is not None:
|
||||||
si.generate_adaptive_masks(latent_orig=p.init_latent,
|
si.apply_adaptive_masks(latent_orig=p.init_latent,
|
||||||
latent_processed=samples_ddim,
|
latent_processed=samples_ddim,
|
||||||
overlay_images=p.overlay_images,
|
overlay_images=p.overlay_images,
|
||||||
masks_for_overlay=p.masks_for_overlay,
|
masks_for_overlay=p.masks_for_overlay,
|
||||||
width=p.width,
|
width=p.width,
|
||||||
height=p.height,
|
height=p.height,
|
||||||
paste_to=p.paste_to)
|
paste_to=p.paste_to)
|
||||||
|
|
||||||
x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim, target_device=devices.cpu, check_for_nans=True)
|
x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim, target_device=devices.cpu, check_for_nans=True)
|
||||||
|
|
||||||
|
@ -25,26 +25,32 @@ def latent_blend(soft_inpainting, a, b, t):
|
|||||||
|
|
||||||
# NOTE: We use inplace operations wherever possible.
|
# NOTE: We use inplace operations wherever possible.
|
||||||
|
|
||||||
one_minus_t = 1 - t
|
# [4][w][h] to [1][4][w][h]
|
||||||
|
t2 = t.unsqueeze(0)
|
||||||
|
# [4][w][h] to [1][1][w][h] - the [4] seem redundant.
|
||||||
|
t3 = t[0].unsqueeze(0).unsqueeze(0)
|
||||||
|
|
||||||
|
one_minus_t2 = 1 - t2
|
||||||
|
one_minus_t3 = 1 - t3
|
||||||
|
|
||||||
# Linearly interpolate the image vectors.
|
# Linearly interpolate the image vectors.
|
||||||
a_scaled = a * one_minus_t
|
a_scaled = a * one_minus_t2
|
||||||
b_scaled = b * t
|
b_scaled = b * t2
|
||||||
image_interp = a_scaled
|
image_interp = a_scaled
|
||||||
image_interp.add_(b_scaled)
|
image_interp.add_(b_scaled)
|
||||||
result_type = image_interp.dtype
|
result_type = image_interp.dtype
|
||||||
del a_scaled, b_scaled
|
del a_scaled, b_scaled, t2, one_minus_t2
|
||||||
|
|
||||||
# Calculate the magnitude of the interpolated vectors. (We will remove this magnitude.)
|
# Calculate the magnitude of the interpolated vectors. (We will remove this magnitude.)
|
||||||
# 64-bit operations are used here to allow large exponents.
|
# 64-bit operations are used here to allow large exponents.
|
||||||
current_magnitude = torch.norm(image_interp, p=2, dim=1).to(torch.float64).add_(0.00001)
|
current_magnitude = torch.norm(image_interp, p=2, dim=1, keepdim=True).to(torch.float64).add_(0.00001)
|
||||||
|
|
||||||
# Interpolate the powered magnitudes, then un-power them (bring them back to a power of 1).
|
# Interpolate the powered magnitudes, then un-power them (bring them back to a power of 1).
|
||||||
a_magnitude = torch.norm(a, p=2, dim=1).to(torch.float64).pow_(soft_inpainting.inpaint_detail_preservation) * one_minus_t
|
a_magnitude = torch.norm(a, p=2, dim=1, keepdim=True).to(torch.float64).pow_(soft_inpainting.inpaint_detail_preservation) * one_minus_t3
|
||||||
b_magnitude = torch.norm(b, p=2, dim=1).to(torch.float64).pow_(soft_inpainting.inpaint_detail_preservation) * t
|
b_magnitude = torch.norm(b, p=2, dim=1, keepdim=True).to(torch.float64).pow_(soft_inpainting.inpaint_detail_preservation) * t3
|
||||||
desired_magnitude = a_magnitude
|
desired_magnitude = a_magnitude
|
||||||
desired_magnitude.add_(b_magnitude).pow_(1 / soft_inpainting.inpaint_detail_preservation)
|
desired_magnitude.add_(b_magnitude).pow_(1 / soft_inpainting.inpaint_detail_preservation)
|
||||||
del a_magnitude, b_magnitude, one_minus_t
|
del a_magnitude, b_magnitude, t3, one_minus_t3
|
||||||
|
|
||||||
# Change the linearly interpolated image vectors' magnitudes to the value we want.
|
# Change the linearly interpolated image vectors' magnitudes to the value we want.
|
||||||
# This is the last 64-bit operation.
|
# This is the last 64-bit operation.
|
||||||
@ -78,10 +84,11 @@ def get_modified_nmask(soft_inpainting, nmask, sigma):
|
|||||||
NOTE: "mask" is not used
|
NOTE: "mask" is not used
|
||||||
"""
|
"""
|
||||||
import torch
|
import torch
|
||||||
return torch.pow(nmask, (sigma ** soft_inpainting.mask_blend_power) * soft_inpainting.mask_blend_scale)
|
# todo: Why is sigma 2D? Both values are the same.
|
||||||
|
return torch.pow(nmask, (sigma[0] ** soft_inpainting.mask_blend_power) * soft_inpainting.mask_blend_scale)
|
||||||
|
|
||||||
|
|
||||||
def generate_adaptive_masks(
|
def apply_adaptive_masks(
|
||||||
latent_orig,
|
latent_orig,
|
||||||
latent_processed,
|
latent_processed,
|
||||||
overlay_images,
|
overlay_images,
|
||||||
@ -142,6 +149,45 @@ def generate_adaptive_masks(
|
|||||||
|
|
||||||
overlay_images[i] = image_masked.convert('RGBA')
|
overlay_images[i] = image_masked.convert('RGBA')
|
||||||
|
|
||||||
|
def apply_masks(
|
||||||
|
soft_inpainting,
|
||||||
|
nmask,
|
||||||
|
overlay_images,
|
||||||
|
masks_for_overlay,
|
||||||
|
width, height,
|
||||||
|
paste_to):
|
||||||
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
import modules.processing as proc
|
||||||
|
import modules.images as images
|
||||||
|
from PIL import Image, ImageOps, ImageFilter
|
||||||
|
|
||||||
|
converted_mask = nmask[0].float()
|
||||||
|
converted_mask = torch.clamp(converted_mask, min=0, max=1).pow_(soft_inpainting.mask_blend_scale / 2)
|
||||||
|
converted_mask = 255. * converted_mask
|
||||||
|
converted_mask = converted_mask.cpu().numpy().astype(np.uint8)
|
||||||
|
converted_mask = Image.fromarray(converted_mask)
|
||||||
|
converted_mask = images.resize_image(2, converted_mask, width, height)
|
||||||
|
converted_mask = proc.create_binary_mask(converted_mask, round=False)
|
||||||
|
|
||||||
|
# Remove aliasing artifacts using a gaussian blur.
|
||||||
|
converted_mask = converted_mask.filter(ImageFilter.GaussianBlur(radius=4))
|
||||||
|
|
||||||
|
# Expand the mask to fit the whole image if needed.
|
||||||
|
if paste_to is not None:
|
||||||
|
converted_mask = proc.uncrop(converted_mask,
|
||||||
|
(width, height),
|
||||||
|
paste_to)
|
||||||
|
|
||||||
|
for i, overlay_image in enumerate(overlay_images):
|
||||||
|
masks_for_overlay[i] = converted_mask
|
||||||
|
|
||||||
|
image_masked = Image.new('RGBa', (overlay_image.width, overlay_image.height))
|
||||||
|
image_masked.paste(overlay_image.convert("RGBA").convert("RGBa"),
|
||||||
|
mask=ImageOps.invert(converted_mask.convert('L')))
|
||||||
|
|
||||||
|
overlay_images[i] = image_masked.convert('RGBA')
|
||||||
|
|
||||||
|
|
||||||
# ------------------- Constants -------------------
|
# ------------------- Constants -------------------
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user