diff --git a/extensions-builtin/Lora/network_lora.py b/extensions-builtin/Lora/network_lora.py
index a7a088949..2bc6af5d2 100644
--- a/extensions-builtin/Lora/network_lora.py
+++ b/extensions-builtin/Lora/network_lora.py
@@ -2,6 +2,7 @@ import torch
 
 import lyco_helpers
 import modules.models.sd3.mmdit
+import modules.models.flux.modules.layers
 import network
 from modules import devices
 
@@ -37,7 +38,7 @@ class NetworkModuleLora(network.NetworkModule):
         if weight is None and none_ok:
             return None
 
-        is_linear = type(self.sd_module) in [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear, torch.nn.MultiheadAttention, modules.models.sd3.mmdit.QkvLinear]
+        is_linear = type(self.sd_module) in [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear, torch.nn.MultiheadAttention, modules.models.sd3.mmdit.QkvLinear, modules.models.flux.modules.layers.QkvLinear]
         is_conv = type(self.sd_module) in [torch.nn.Conv2d]
 
         if is_linear:
diff --git a/extensions-builtin/Lora/networks.py b/extensions-builtin/Lora/networks.py
index e58e1fb56..46c9ac1ff 100644
--- a/extensions-builtin/Lora/networks.py
+++ b/extensions-builtin/Lora/networks.py
@@ -37,7 +37,7 @@ module_types = [
 ]
 
 re_digits = re.compile(r"\d+")
-re_x_proj = re.compile(r"(.*)_([qkv]_proj)$")
+re_x_proj = re.compile(r"(.*)_((?:[qkv]|mlp)_proj)$")
 re_compiled = {}
 
 suffix_conversion = {
@@ -460,7 +460,7 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
 
         for net in loaded_networks:
             module = net.modules.get(network_layer_name, None)
-            if module is not None and hasattr(self, 'weight') and not isinstance(module, modules.models.sd3.mmdit.QkvLinear):
+            if module is not None and hasattr(self, 'weight') and not isinstance(module, (modules.models.sd3.mmdit.QkvLinear, modules.models.flux.modules.layers.QkvLinear)):
                 try:
                     with torch.no_grad():
                         if getattr(self, 'fp16_weight', None) is None:
@@ -520,7 +520,9 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
 
                 continue
 
-            if isinstance(self, modules.models.sd3.mmdit.QkvLinear) and module_q and module_k and module_v:
+            module_mlp = net.modules.get(network_layer_name + "_mlp_proj", None)
+
+            if isinstance(self, (modules.models.sd3.mmdit.QkvLinear, modules.models.flux.modules.layers.QkvLinear)) and module_q and module_k and module_v and module_mlp is None:
                 try:
                     with torch.no_grad():
                         # Send "real" orig_weight into MHA's lora module
@@ -531,6 +533,26 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
                         del qw, kw, vw
                         updown_qkv = torch.vstack([updown_q, updown_k, updown_v])
                         self.weight += updown_qkv
+                        del updown_qkv
+
+                except RuntimeError as e:
+                    logging.debug(f"Network {net.name} layer {network_layer_name}: {e}")
+                    extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1
+
+                continue
+
+            if isinstance(self, modules.models.flux.modules.layers.QkvLinear) and module_q and module_k and module_v and module_mlp:
+                try:
+                    with torch.no_grad():
+                        qw, kw, vw, mlp = torch.tensor_split(self.weight, (3072, 6144, 9216,), 0)  # rows 0-9215 are q/k/v (3072 each); the remainder is the MLP up-projection
+                        updown_q, _ = module_q.calc_updown(qw)
+                        updown_k, _ = module_k.calc_updown(kw)
+                        updown_v, _ = module_v.calc_updown(vw)
+                        updown_mlp, _ = module_mlp.calc_updown(mlp)
+                        del qw, kw, vw, mlp
+                        updown_qkv_mlp = torch.vstack([updown_q, updown_k, updown_v, updown_mlp])
+                        self.weight += updown_qkv_mlp
+                        del updown_qkv_mlp
 
             except RuntimeError as e:
                 logging.debug(f"Network {net.name} layer {network_layer_name}: {e}")