Mirror of https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-03-09 15:34:54 +08:00
Merge pull request #12543 from AUTOMATIC1111/extra-norm-module
Fix MHA error with ex_bias and support ex_bias for layers which don't have bias
commit 3a4bee1096
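For context, and not part of the commit: torch.nn.MultiheadAttention has no top-level bias parameter (its output bias lives on the nested out_proj linear), and layers built with bias=False still expose a bias attribute that is simply None. A minimal standalone sketch of both cases, with made-up sizes:

import torch

# Where the bias actually lives, for the two cases the PR handles.
mha = torch.nn.MultiheadAttention(embed_dim=8, num_heads=2)
print(getattr(mha, 'bias', None))        # None -- no top-level .bias on MultiheadAttention
print(mha.out_proj.bias.shape)           # torch.Size([8]) -- the output bias sits on out_proj

linear = torch.nn.Linear(8, 4, bias=False)
print(hasattr(linear, 'bias'), linear.bias)   # True None -- the attribute exists, but holds no Parameter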
@@ -277,7 +277,15 @@ def network_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Li
             self.weight.copy_(weights_backup)
 
     if bias_backup is not None:
-        self.bias.copy_(bias_backup)
+        if isinstance(self, torch.nn.MultiheadAttention):
+            self.out_proj.bias.copy_(bias_backup)
+        else:
+            self.bias.copy_(bias_backup)
     else:
-        self.bias = None
+        if isinstance(self, torch.nn.MultiheadAttention):
+            self.out_proj.bias = None
+        else:
+            self.bias = None
 
 
 def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention]):
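A hedged restatement of the restore hunk above as a standalone sketch; restore_bias is a hypothetical helper, not a function from the repository:

import torch

# Sketch only: pick the tensor that owns the bias, then restore or clear it.
def restore_bias(module, bias_backup):
    # MultiheadAttention keeps its output bias on the nested out_proj linear.
    target = module.out_proj if isinstance(module, torch.nn.MultiheadAttention) else module
    if bias_backup is not None:
        with torch.no_grad():
            target.bias.copy_(bias_backup)   # in-place restore of the saved bias
    else:
        target.bias = None                   # drop any bias a network added to a layer that had none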
@@ -304,8 +312,13 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
         self.network_weights_backup = weights_backup
 
     bias_backup = getattr(self, "network_bias_backup", None)
-    if bias_backup is None and getattr(self, 'bias', None) is not None:
-        bias_backup = self.bias.to(devices.cpu, copy=True)
+    if bias_backup is None:
+        if isinstance(self, torch.nn.MultiheadAttention) and self.out_proj.bias is not None:
+            bias_backup = self.out_proj.bias.to(devices.cpu, copy=True)
+        elif getattr(self, 'bias', None) is not None:
+            bias_backup = self.bias.to(devices.cpu, copy=True)
+        else:
+            bias_backup = None
         self.network_bias_backup = bias_backup
 
     if current_names != wanted_names:
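The hunk above extends the bias backup to MultiheadAttention. A rough standalone equivalent; backup_bias is hypothetical, and plain detach().cpu().clone() stands in for the repository's devices.cpu helper:

import torch

# Sketch only: decide which bias, if any, to stash on the CPU.
def backup_bias(module):
    if isinstance(module, torch.nn.MultiheadAttention) and module.out_proj.bias is not None:
        return module.out_proj.bias.detach().cpu().clone()
    if getattr(module, 'bias', None) is not None:
        return module.bias.detach().cpu().clone()
    return None   # nothing to back up; restore will clear any bias added later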
@@ -323,8 +336,11 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
                             updown = torch.nn.functional.pad(updown, (0, 0, 0, 0, 0, 5))
 
                         self.weight += updown
-                        if ex_bias is not None and getattr(self, 'bias', None) is not None:
-                            self.bias += ex_bias
+                        if ex_bias is not None and hasattr(self, 'bias'):
+                            if self.bias is None:
+                                self.bias = torch.nn.Parameter(ex_bias)
+                            else:
+                                self.bias += ex_bias
 
                 except RuntimeError as e:
                     logging.debug(f"Network {net.name} layer {network_layer_name}: {e}")
                     extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1
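The hunk above switches the guard from getattr(self, 'bias', None) is not None to hasattr(self, 'bias'), so a layer built without a bias can still receive ex_bias. A small illustration of that branch, with made-up sizes:

import torch

# Sketch only: give a bias-less layer an offset ("ex_bias") the way the new branch does.
layer = torch.nn.Linear(4, 4, bias=False)
ex_bias = torch.zeros(4)

if hasattr(layer, 'bias'):                         # True even when bias=False
    if layer.bias is None:
        layer.bias = torch.nn.Parameter(ex_bias)   # create the parameter on demand
    else:
        with torch.no_grad():
            layer.bias += ex_bias                  # fold the offset into the existing bias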
@@ -339,14 +355,19 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
             if isinstance(self, torch.nn.MultiheadAttention) and module_q and module_k and module_v and module_out:
                 try:
                     with torch.no_grad():
-                        updown_q = module_q.calc_updown(self.in_proj_weight)
-                        updown_k = module_k.calc_updown(self.in_proj_weight)
-                        updown_v = module_v.calc_updown(self.in_proj_weight)
+                        updown_q, _ = module_q.calc_updown(self.in_proj_weight)
+                        updown_k, _ = module_k.calc_updown(self.in_proj_weight)
+                        updown_v, _ = module_v.calc_updown(self.in_proj_weight)
                         updown_qkv = torch.vstack([updown_q, updown_k, updown_v])
-                        updown_out = module_out.calc_updown(self.out_proj.weight)
+                        updown_out, ex_bias = module_out.calc_updown(self.out_proj.weight)
 
                         self.in_proj_weight += updown_qkv
                         self.out_proj.weight += updown_out
+                        if ex_bias is not None:
+                            if self.out_proj.bias is None:
+                                self.out_proj.bias = torch.nn.Parameter(ex_bias)
+                            else:
+                                self.out_proj.bias += ex_bias
 
                 except RuntimeError as e:
                     logging.debug(f"Network {net.name} layer {network_layer_name}: {e}")