mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-01-19 21:00:14 +08:00
fixed ai-toolkit flux lora support
* fixed some mistake * some ai-toolkit's lora do not have proj_mlp
This commit is contained in:
parent
4bea93bc06
commit
30d0f950b7
@ -526,7 +526,7 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
|
||||
|
||||
module_mlp = net.modules.get(network_layer_name + "_mlp_proj", None)
|
||||
|
||||
if any(isinstance(self, linear) for linear in (modules.models.sd3.mmdit.QkvLinear, modules.models.flux.modules.layers.QkvLinear)) and module_q and module_k and module_v and module_mlp is None:
|
||||
if any(isinstance(self, linear) for linear in (modules.models.sd3.mmdit.QkvLinear, modules.models.flux.modules.layers.QkvLinear)) and module_q and module_k and module_v and module_mlp is None and self.weight.shape[0] // 3 == module_q.up_model.weight.shape[0]:
|
||||
try:
|
||||
with torch.no_grad():
|
||||
# Send "real" orig_weight into MHA's lora module
|
||||
@ -545,14 +545,17 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
|
||||
|
||||
continue
|
||||
|
||||
if any(isinstance(self, linear) for linear in (modules.models.flux.modules.layers.QkvLinear,)) and module_q and module_k and module_v and module_mlp:
|
||||
if any(isinstance(self, linear) for linear in (modules.models.flux.modules.layers.QkvLinear,)) and module_q and module_k and module_v:
|
||||
try:
|
||||
with torch.no_grad():
|
||||
qw, kw, vw, mlp = torch.tensor_split(self.weight, (3072, 6144, 9216,), 0)
|
||||
updown_q, _ = module_q.calc_updown(qw)
|
||||
updown_k, _ = module_k.calc_updown(kw)
|
||||
updown_v, _ = module_v.calc_updown(vw)
|
||||
updown_mlp, _ = module_v.calc_updown(mlp)
|
||||
if module_mlp is not None:
|
||||
updown_mlp, _ = module_mlp.calc_updown(mlp)
|
||||
else:
|
||||
updown_mlp = torch.zeros(3072 * 4, 3072, dtype=updown_q.dtype, device=updown_q.device)
|
||||
del qw, kw, vw, mlp
|
||||
updown_qkv_mlp = torch.vstack([updown_q, updown_k, updown_v, updown_mlp])
|
||||
self.weight += updown_qkv_mlp
|
||||
|
Loading…
Reference in New Issue
Block a user