fixed ai-toolkit flux lora support

* fixed some mistakes
 * some ai-toolkit LoRAs do not have proj_mlp
This commit is contained in:
Won-Kyu Park 2024-09-20 12:30:38 +09:00
parent 4bea93bc06
commit 30d0f950b7
No known key found for this signature in database
GPG Key ID: 53AA79C8C9535D15

View File

@ -526,7 +526,7 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
module_mlp = net.modules.get(network_layer_name + "_mlp_proj", None) module_mlp = net.modules.get(network_layer_name + "_mlp_proj", None)
if any(isinstance(self, linear) for linear in (modules.models.sd3.mmdit.QkvLinear, modules.models.flux.modules.layers.QkvLinear)) and module_q and module_k and module_v and module_mlp is None: if any(isinstance(self, linear) for linear in (modules.models.sd3.mmdit.QkvLinear, modules.models.flux.modules.layers.QkvLinear)) and module_q and module_k and module_v and module_mlp is None and self.weight.shape[0] // 3 == module_q.up_model.weight.shape[0]:
try: try:
with torch.no_grad(): with torch.no_grad():
# Send "real" orig_weight into MHA's lora module # Send "real" orig_weight into MHA's lora module
@ -545,14 +545,17 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
continue continue
if any(isinstance(self, linear) for linear in (modules.models.flux.modules.layers.QkvLinear,)) and module_q and module_k and module_v and module_mlp: if any(isinstance(self, linear) for linear in (modules.models.flux.modules.layers.QkvLinear,)) and module_q and module_k and module_v:
try: try:
with torch.no_grad(): with torch.no_grad():
qw, kw, vw, mlp = torch.tensor_split(self.weight, (3072, 6144, 9216,), 0) qw, kw, vw, mlp = torch.tensor_split(self.weight, (3072, 6144, 9216,), 0)
updown_q, _ = module_q.calc_updown(qw) updown_q, _ = module_q.calc_updown(qw)
updown_k, _ = module_k.calc_updown(kw) updown_k, _ = module_k.calc_updown(kw)
updown_v, _ = module_v.calc_updown(vw) updown_v, _ = module_v.calc_updown(vw)
updown_mlp, _ = module_v.calc_updown(mlp) if module_mlp is not None:
updown_mlp, _ = module_mlp.calc_updown(mlp)
else:
updown_mlp = torch.zeros(3072 * 4, 3072, dtype=updown_q.dtype, device=updown_q.device)
del qw, kw, vw, mlp del qw, kw, vw, mlp
updown_qkv_mlp = torch.vstack([updown_q, updown_k, updown_v, updown_mlp]) updown_qkv_mlp = torch.vstack([updown_q, updown_k, updown_v, updown_mlp])
self.weight += updown_qkv_mlp self.weight += updown_qkv_mlp