mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-01-06 15:15:05 +08:00
Merge pull request #7637 from brkirch/fix-hypernetworks-pix2pix
Fix hypernetworks and instruct pix2pix not working with `--upcast-sampling`
This commit is contained in:
commit
226bc04653
@ -380,8 +380,8 @@ def apply_single_hypernetwork(hypernetwork, context_k, context_v, layer=None):
|
|||||||
layer.hyper_k = hypernetwork_layers[0]
|
layer.hyper_k = hypernetwork_layers[0]
|
||||||
layer.hyper_v = hypernetwork_layers[1]
|
layer.hyper_v = hypernetwork_layers[1]
|
||||||
|
|
||||||
context_k = hypernetwork_layers[0](context_k)
|
context_k = devices.cond_cast_unet(hypernetwork_layers[0](devices.cond_cast_float(context_k)))
|
||||||
context_v = hypernetwork_layers[1](context_v)
|
context_v = devices.cond_cast_unet(hypernetwork_layers[1](devices.cond_cast_float(context_v)))
|
||||||
return context_k, context_v
|
return context_k, context_v
|
||||||
|
|
||||||
|
|
||||||
|
@ -104,6 +104,9 @@ class StableDiffusionModelHijack:
|
|||||||
m.cond_stage_model.model.token_embedding = EmbeddingsWithFixes(m.cond_stage_model.model.token_embedding, self)
|
m.cond_stage_model.model.token_embedding = EmbeddingsWithFixes(m.cond_stage_model.model.token_embedding, self)
|
||||||
m.cond_stage_model = sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)
|
m.cond_stage_model = sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)
|
||||||
|
|
||||||
|
if m.cond_stage_key == "edit":
|
||||||
|
sd_hijack_unet.hijack_ddpm_edit()
|
||||||
|
|
||||||
self.optimization_method = apply_optimizations()
|
self.optimization_method = apply_optimizations()
|
||||||
|
|
||||||
self.clip = m.cond_stage_model
|
self.clip = m.cond_stage_model
|
||||||
|
@ -44,6 +44,7 @@ def apply_model(orig_func, self, x_noisy, t, cond, **kwargs):
|
|||||||
with devices.autocast():
|
with devices.autocast():
|
||||||
return orig_func(self, x_noisy.to(devices.dtype_unet), t.to(devices.dtype_unet), cond, **kwargs).float()
|
return orig_func(self, x_noisy.to(devices.dtype_unet), t.to(devices.dtype_unet), cond, **kwargs).float()
|
||||||
|
|
||||||
|
|
||||||
class GELUHijack(torch.nn.GELU, torch.nn.Module):
|
class GELUHijack(torch.nn.GELU, torch.nn.Module):
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
torch.nn.GELU.__init__(self, *args, **kwargs)
|
torch.nn.GELU.__init__(self, *args, **kwargs)
|
||||||
@ -53,6 +54,16 @@ class GELUHijack(torch.nn.GELU, torch.nn.Module):
|
|||||||
else:
|
else:
|
||||||
return torch.nn.GELU.forward(self, x)
|
return torch.nn.GELU.forward(self, x)
|
||||||
|
|
||||||
|
|
||||||
|
ddpm_edit_hijack = None
|
||||||
|
def hijack_ddpm_edit():
|
||||||
|
global ddpm_edit_hijack
|
||||||
|
if not ddpm_edit_hijack:
|
||||||
|
CondFunc('modules.models.diffusion.ddpm_edit.LatentDiffusion.decode_first_stage', first_stage_sub, first_stage_cond)
|
||||||
|
CondFunc('modules.models.diffusion.ddpm_edit.LatentDiffusion.encode_first_stage', first_stage_sub, first_stage_cond)
|
||||||
|
ddpm_edit_hijack = CondFunc('modules.models.diffusion.ddpm_edit.LatentDiffusion.apply_model', apply_model, unet_needs_upcast)
|
||||||
|
|
||||||
|
|
||||||
unet_needs_upcast = lambda *args, **kwargs: devices.unet_needs_upcast
|
unet_needs_upcast = lambda *args, **kwargs: devices.unet_needs_upcast
|
||||||
CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.apply_model', apply_model, unet_needs_upcast)
|
CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.apply_model', apply_model, unet_needs_upcast)
|
||||||
CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', lambda orig_func, timesteps, *args, **kwargs: orig_func(timesteps, *args, **kwargs).to(torch.float32 if timesteps.dtype == torch.int64 else devices.dtype_unet), unet_needs_upcast)
|
CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', lambda orig_func, timesteps, *args, **kwargs: orig_func(timesteps, *args, **kwargs).to(torch.float32 if timesteps.dtype == torch.int64 else devices.dtype_unet), unet_needs_upcast)
|
||||||
|
Loading…
Reference in New Issue
Block a user