mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-04-25 06:19:00 +08:00
add some codes for robust
This commit is contained in:
parent
9feb034e34
commit
bfe418a58d
@ -108,6 +108,7 @@ def txt2img_image_conditioning(sd_model, x, width, height):
|
|||||||
else:
|
else:
|
||||||
sd = sd_model.model.state_dict()
|
sd = sd_model.model.state_dict()
|
||||||
diffusion_model_input = sd.get('diffusion_model.input_blocks.0.0.weight', None)
|
diffusion_model_input = sd.get('diffusion_model.input_blocks.0.0.weight', None)
|
||||||
|
if diffusion_model_input is not None:
|
||||||
if diffusion_model_input.shape[1] == 9:
|
if diffusion_model_input.shape[1] == 9:
|
||||||
# The "masked-image" in this case will just be all 0.5 since the entire image is masked.
|
# The "masked-image" in this case will just be all 0.5 since the entire image is masked.
|
||||||
image_conditioning = torch.ones(x.shape[0], 3, height, width, device=x.device) * 0.5
|
image_conditioning = torch.ones(x.shape[0], 3, height, width, device=x.device) * 0.5
|
||||||
@ -378,6 +379,7 @@ class StableDiffusionProcessing:
|
|||||||
|
|
||||||
sd = self.sampler.model_wrap.inner_model.model.state_dict()
|
sd = self.sampler.model_wrap.inner_model.model.state_dict()
|
||||||
diffusion_model_input = sd.get('diffusion_model.input_blocks.0.0.weight', None)
|
diffusion_model_input = sd.get('diffusion_model.input_blocks.0.0.weight', None)
|
||||||
|
if diffusion_model_input is not None:
|
||||||
if diffusion_model_input.shape[1] == 9:
|
if diffusion_model_input.shape[1] == 9:
|
||||||
return self.inpainting_image_conditioning(source_image, latent_image, image_mask=image_mask)
|
return self.inpainting_image_conditioning(source_image, latent_image, image_mask=image_mask)
|
||||||
|
|
||||||
|
@ -36,6 +36,7 @@ def get_learned_conditioning(self: sgm.models.diffusion.DiffusionEngine, batch:
|
|||||||
def apply_model(self: sgm.models.diffusion.DiffusionEngine, x, t, cond):
|
def apply_model(self: sgm.models.diffusion.DiffusionEngine, x, t, cond):
|
||||||
sd = self.model.state_dict()
|
sd = self.model.state_dict()
|
||||||
diffusion_model_input = sd.get('diffusion_model.input_blocks.0.0.weight', None)
|
diffusion_model_input = sd.get('diffusion_model.input_blocks.0.0.weight', None)
|
||||||
|
if diffusion_model_input is not None:
|
||||||
if diffusion_model_input.shape[1] == 9:
|
if diffusion_model_input.shape[1] == 9:
|
||||||
x = torch.cat([x] + cond['c_concat'], dim=1)
|
x = torch.cat([x] + cond['c_concat'], dim=1)
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user