mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-04-04 19:49:01 +08:00
Highres fix works with unmasked latent.
Also refactor the mask creation to make it more accesible.
This commit is contained in:
parent
f3f2ffd448
commit
26a3fd2fe9
@ -129,6 +129,73 @@ class StableDiffusionProcessing():
|
|||||||
self.all_seeds = None
|
self.all_seeds = None
|
||||||
self.all_subseeds = None
|
self.all_subseeds = None
|
||||||
|
|
||||||
|
def txt2img_image_conditioning(self, x, width=None, height=None):
|
||||||
|
if self.sampler.conditioning_key not in {'hybrid', 'concat'}:
|
||||||
|
# Dummy zero conditioning if we're not using inpainting model.
|
||||||
|
# Still takes up a bit of memory, but no encoder call.
|
||||||
|
# Pretty sure we can just make this a 1x1 image since its not going to be used besides its batch size.
|
||||||
|
return torch.zeros(
|
||||||
|
x.shape[0], 5, 1, 1,
|
||||||
|
dtype=x.dtype,
|
||||||
|
device=x.device
|
||||||
|
)
|
||||||
|
|
||||||
|
height = height or self.height
|
||||||
|
width = width or self.width
|
||||||
|
|
||||||
|
# The "masked-image" in this case will just be all zeros since the entire image is masked.
|
||||||
|
image_conditioning = torch.zeros(x.shape[0], 3, height, width, device=x.device)
|
||||||
|
image_conditioning = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image_conditioning))
|
||||||
|
|
||||||
|
# Add the fake full 1s mask to the first dimension.
|
||||||
|
image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0)
|
||||||
|
image_conditioning = image_conditioning.to(x.dtype)
|
||||||
|
|
||||||
|
return image_conditioning
|
||||||
|
|
||||||
|
def img2img_image_conditioning(self, source_image, latent_image, image_mask = None):
|
||||||
|
if self.sampler.conditioning_key not in {'hybrid', 'concat'}:
|
||||||
|
# Dummy zero conditioning if we're not using inpainting model.
|
||||||
|
return torch.zeros(
|
||||||
|
latent_image.shape[0], 5, 1, 1,
|
||||||
|
dtype=latent_image.dtype,
|
||||||
|
device=latent_image.device
|
||||||
|
)
|
||||||
|
|
||||||
|
# Handle the different mask inputs
|
||||||
|
if image_mask is not None:
|
||||||
|
if torch.is_tensor(image_mask):
|
||||||
|
conditioning_mask = image_mask
|
||||||
|
else:
|
||||||
|
conditioning_mask = np.array(image_mask.convert("L"))
|
||||||
|
conditioning_mask = conditioning_mask.astype(np.float32) / 255.0
|
||||||
|
conditioning_mask = torch.from_numpy(conditioning_mask[None, None])
|
||||||
|
|
||||||
|
# Inpainting model uses a discretized mask as input, so we round to either 1.0 or 0.0
|
||||||
|
conditioning_mask = torch.round(conditioning_mask)
|
||||||
|
else:
|
||||||
|
conditioning_mask = torch.ones(1, 1, *source_image.shape[-2:])
|
||||||
|
|
||||||
|
# Create another latent image, this time with a masked version of the original input.
|
||||||
|
# Smoothly interpolate between the masked and unmasked latent conditioning image using a parameter.
|
||||||
|
conditioning_mask = conditioning_mask.to(source_image.device)
|
||||||
|
conditioning_image = torch.lerp(
|
||||||
|
source_image,
|
||||||
|
source_image * (1.0 - conditioning_mask),
|
||||||
|
getattr(self, "inpainting_mask_weight", shared.opts.inpainting_mask_weight)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Encode the new masked image using first stage of network.
|
||||||
|
conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(conditioning_image))
|
||||||
|
|
||||||
|
# Create the concatenated conditioning tensor to be fed to `c_concat`
|
||||||
|
conditioning_mask = torch.nn.functional.interpolate(conditioning_mask, size=latent_image.shape[-2:])
|
||||||
|
conditioning_mask = conditioning_mask.expand(conditioning_image.shape[0], -1, -1, -1)
|
||||||
|
image_conditioning = torch.cat([conditioning_mask, conditioning_image], dim=1)
|
||||||
|
image_conditioning = image_conditioning.to(shared.device).type(self.sd_model.dtype)
|
||||||
|
|
||||||
|
return image_conditioning
|
||||||
|
|
||||||
def init(self, all_prompts, all_seeds, all_subseeds):
|
def init(self, all_prompts, all_seeds, all_subseeds):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@ -571,37 +638,16 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
|||||||
self.truncate_x = int(self.firstphase_width - firstphase_width_truncated) // opt_f
|
self.truncate_x = int(self.firstphase_width - firstphase_width_truncated) // opt_f
|
||||||
self.truncate_y = int(self.firstphase_height - firstphase_height_truncated) // opt_f
|
self.truncate_y = int(self.firstphase_height - firstphase_height_truncated) // opt_f
|
||||||
|
|
||||||
def create_dummy_mask(self, x, width=None, height=None):
|
|
||||||
if self.sampler.conditioning_key in {'hybrid', 'concat'}:
|
|
||||||
height = height or self.height
|
|
||||||
width = width or self.width
|
|
||||||
|
|
||||||
# The "masked-image" in this case will just be all zeros since the entire image is masked.
|
|
||||||
image_conditioning = torch.zeros(x.shape[0], 3, height, width, device=x.device)
|
|
||||||
image_conditioning = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image_conditioning))
|
|
||||||
|
|
||||||
# Add the fake full 1s mask to the first dimension.
|
|
||||||
image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0)
|
|
||||||
image_conditioning = image_conditioning.to(x.dtype)
|
|
||||||
|
|
||||||
else:
|
|
||||||
# Dummy zero conditioning if we're not using inpainting model.
|
|
||||||
# Still takes up a bit of memory, but no encoder call.
|
|
||||||
# Pretty sure we can just make this a 1x1 image since its not going to be used besides its batch size.
|
|
||||||
image_conditioning = torch.zeros(x.shape[0], 5, 1, 1, dtype=x.dtype, device=x.device)
|
|
||||||
|
|
||||||
return image_conditioning
|
|
||||||
|
|
||||||
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
|
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
|
||||||
self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, self.sampler_index, self.sd_model)
|
self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, self.sampler_index, self.sd_model)
|
||||||
|
|
||||||
if not self.enable_hr:
|
if not self.enable_hr:
|
||||||
x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
|
x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
|
||||||
samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.create_dummy_mask(x))
|
samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x))
|
||||||
return samples
|
return samples
|
||||||
|
|
||||||
x = create_random_tensors([opt_C, self.firstphase_height // opt_f, self.firstphase_width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
|
x = create_random_tensors([opt_C, self.firstphase_height // opt_f, self.firstphase_width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
|
||||||
samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.create_dummy_mask(x, self.firstphase_width, self.firstphase_height))
|
samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x, self.firstphase_width, self.firstphase_height))
|
||||||
|
|
||||||
samples = samples[:, :, self.truncate_y//2:samples.shape[2]-self.truncate_y//2, self.truncate_x//2:samples.shape[3]-self.truncate_x//2]
|
samples = samples[:, :, self.truncate_y//2:samples.shape[2]-self.truncate_y//2, self.truncate_x//2:samples.shape[3]-self.truncate_x//2]
|
||||||
|
|
||||||
@ -638,7 +684,12 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
|||||||
x = None
|
x = None
|
||||||
devices.torch_gc()
|
devices.torch_gc()
|
||||||
|
|
||||||
samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.steps, image_conditioning=self.create_dummy_mask(samples))
|
image_conditioning = self.img2img_image_conditioning(
|
||||||
|
decoded_samples,
|
||||||
|
samples,
|
||||||
|
decoded_samples.new_ones(decoded_samples.shape[0], 1, decoded_samples.shape[2], decoded_samples.shape[3])
|
||||||
|
)
|
||||||
|
samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.steps, image_conditioning=image_conditioning)
|
||||||
|
|
||||||
return samples
|
return samples
|
||||||
|
|
||||||
@ -770,40 +821,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
|
|||||||
elif self.inpainting_fill == 3:
|
elif self.inpainting_fill == 3:
|
||||||
self.init_latent = self.init_latent * self.mask
|
self.init_latent = self.init_latent * self.mask
|
||||||
|
|
||||||
if self.sampler.conditioning_key in {'hybrid', 'concat'}:
|
self.image_conditioning = self.img2img_image_conditioning(image, self.init_latent, self.image_mask)
|
||||||
if self.image_mask is not None:
|
|
||||||
conditioning_mask = np.array(self.image_mask.convert("L"))
|
|
||||||
conditioning_mask = conditioning_mask.astype(np.float32) / 255.0
|
|
||||||
conditioning_mask = torch.from_numpy(conditioning_mask[None, None])
|
|
||||||
|
|
||||||
# Inpainting model uses a discretized mask as input, so we round to either 1.0 or 0.0
|
|
||||||
conditioning_mask = torch.round(conditioning_mask)
|
|
||||||
else:
|
|
||||||
conditioning_mask = torch.ones(1, 1, *image.shape[-2:])
|
|
||||||
|
|
||||||
# Create another latent image, this time with a masked version of the original input.
|
|
||||||
conditioning_mask = conditioning_mask.to(image.device)
|
|
||||||
|
|
||||||
# Smoothly interpolate between the masked and unmasked latent conditioning image.
|
|
||||||
conditioning_image = torch.lerp(
|
|
||||||
image,
|
|
||||||
image * (1.0 - conditioning_mask),
|
|
||||||
getattr(self, "inpainting_mask_weight", shared.opts.inpainting_mask_weight)
|
|
||||||
)
|
|
||||||
|
|
||||||
conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(conditioning_image))
|
|
||||||
|
|
||||||
# Create the concatenated conditioning tensor to be fed to `c_concat`
|
|
||||||
conditioning_mask = torch.nn.functional.interpolate(conditioning_mask, size=self.init_latent.shape[-2:])
|
|
||||||
conditioning_mask = conditioning_mask.expand(conditioning_image.shape[0], -1, -1, -1)
|
|
||||||
self.image_conditioning = torch.cat([conditioning_mask, conditioning_image], dim=1)
|
|
||||||
self.image_conditioning = self.image_conditioning.to(shared.device).type(self.sd_model.dtype)
|
|
||||||
else:
|
|
||||||
self.image_conditioning = torch.zeros(
|
|
||||||
self.init_latent.shape[0], 5, 1, 1,
|
|
||||||
dtype=self.init_latent.dtype,
|
|
||||||
device=self.init_latent.device
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
|
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user