Merge pull request #15190 from AUTOMATIC1111/dora-weight-decompose
Fix built-in lora system bugs caused by torch.nn.MultiheadAttention
Commit 4c9a7b8a75
@@ -117,6 +117,12 @@ class NetworkModule:
         if hasattr(self.sd_module, 'weight'):
             self.shape = self.sd_module.weight.shape
+        elif isinstance(self.sd_module, nn.MultiheadAttention):
+            # For now, only self-attn use Pytorch's MHA
+            # So assume all qkvo proj have same shape
+            self.shape = self.sd_module.out_proj.weight.shape
+        else:
+            self.shape = None
 
         self.ops = None
         self.extra_kwargs = {}
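The new elif branch leans on the assumption stated in its comment: for a self-attention MultiheadAttention, the q/k/v projections stacked in in_proj_weight each have the same shape as out_proj.weight. A minimal sketch (illustration only, not part of the commit) of why out_proj.weight.shape works as that common shape:

import torch.nn as nn

# hypothetical sizes chosen only for demonstration
mha = nn.MultiheadAttention(embed_dim=768, num_heads=12)

print(mha.in_proj_weight.shape)   # torch.Size([2304, 768]) -> three stacked (768, 768) blocks
print(mha.out_proj.weight.shape)  # torch.Size([768, 768])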
@@ -146,7 +152,7 @@ class NetworkModule:
         self.alpha = weights.w["alpha"].item() if "alpha" in weights.w else None
         self.scale = weights.w["scale"].item() if "scale" in weights.w else None
 
-        self.dora_scale = weights.w["dora_scale"] if "dora_scale" in weights.w else None
+        self.dora_scale = weights.w.get("dora_scale", None)
         self.dora_mean_dim = tuple(i for i in range(len(self.shape)) if i != 1)
 
     def multiplier(self):
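The dora_scale change is behaviour-preserving: dict.get with a default of None matches the old conditional expression. The unchanged dora_mean_dim line right below it calls len(self.shape), which is presumably why the first hunk gives MultiheadAttention modules a shape at all. A small illustration (my own, not from the commit) of what dora_mean_dim evaluates to for typical weight shapes:

# hypothetical shapes, for demonstration only
shape_linear = (768, 768)        # e.g. an attention projection weight
shape_conv = (320, 320, 3, 3)    # e.g. a Conv2d weight

# keep every axis except dim 1
print(tuple(i for i in range(len(shape_linear)) if i != 1))  # (0,)
print(tuple(i for i in range(len(shape_conv)) if i != 1))    # (0, 2, 3)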
@@ -429,9 +429,12 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
     if isinstance(self, torch.nn.MultiheadAttention) and module_q and module_k and module_v and module_out:
         try:
             with torch.no_grad():
-                updown_q, _ = module_q.calc_updown(self.in_proj_weight)
-                updown_k, _ = module_k.calc_updown(self.in_proj_weight)
-                updown_v, _ = module_v.calc_updown(self.in_proj_weight)
+                # Send "real" orig_weight into MHA's lora module
+                qw, kw, vw = self.in_proj_weight.chunk(3, 0)
+                updown_q, _ = module_q.calc_updown(qw)
+                updown_k, _ = module_k.calc_updown(kw)
+                updown_v, _ = module_v.calc_updown(vw)
+                del qw, kw, vw
                 updown_qkv = torch.vstack([updown_q, updown_k, updown_v])
                 updown_out, ex_bias = module_out.calc_updown(self.out_proj.weight)
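Instead of handing the full stacked in_proj_weight to each of the q/k/v network modules, the fix passes each projection its own slice and then stacks the per-projection deltas back together. A hedged sketch of that flow (not the repository's code; fake_calc_updown is a stand-in for the module's calc_updown, which returns an update and an optional extra bias):

import torch
import torch.nn as nn

mha = nn.MultiheadAttention(embed_dim=768, num_heads=12)

def fake_calc_updown(orig_weight):
    # placeholder: a real network module would compute its weight delta here
    return torch.zeros_like(orig_weight), None

qw, kw, vw = mha.in_proj_weight.chunk(3, 0)           # each (768, 768), like out_proj.weight
updown_q, _ = fake_calc_updown(qw)
updown_k, _ = fake_calc_updown(kw)
updown_v, _ = fake_calc_updown(vw)
updown_qkv = torch.vstack([updown_q, updown_k, updown_v])
assert updown_qkv.shape == mha.in_proj_weight.shape   # (2304, 768)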