fix for Lora flux

commit 51c285265f
parent d6a609a539
Author: Won-Kyu Park
Date:   2024-09-06 14:07:31 +09:00
2 changed files with 27 additions and 4 deletions

extensions-builtin/Lora/network_lora.py

@@ -2,6 +2,7 @@ import torch

 import lyco_helpers
 import modules.models.sd3.mmdit
+import modules.models.flux.modules.layers
 import network
 from modules import devices
@@ -37,7 +38,7 @@ class NetworkModuleLora(network.NetworkModule):
         if weight is None and none_ok:
             return None

-        is_linear = type(self.sd_module) in [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear, torch.nn.MultiheadAttention, modules.models.sd3.mmdit.QkvLinear]
+        is_linear = type(self.sd_module) in [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear, torch.nn.MultiheadAttention, modules.models.sd3.mmdit.QkvLinear, modules.models.flux.modules.layers.QkvLinear]
         is_conv = type(self.sd_module) in [torch.nn.Conv2d]

         if is_linear:

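For context: the whitelist above is an exact type() match, not isinstance(), so Flux's QkvLinear has to be listed even though it is (presumably, mirroring modules.models.sd3.mmdit.QkvLinear) a plain torch.nn.Linear subclass. A minimal sketch with a stand-in class:

import torch

# Stand-in for the real QkvLinear (assumed, like mmdit's, to be a plain
# torch.nn.Linear subclass that just holds the fused projection weight).
class QkvLinear(torch.nn.Linear):
    pass

layer = QkvLinear(8, 24)

# type() in [...] is an exact match, so the subclass is NOT caught by
# torch.nn.Linear alone and must appear in the whitelist itself:
print(type(layer) in [torch.nn.Linear])             # False
print(type(layer) in [torch.nn.Linear, QkvLinear])  # True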
extensions-builtin/Lora/networks.py

@@ -37,7 +37,7 @@ module_types = [
 re_digits = re.compile(r"\d+")
-re_x_proj = re.compile(r"(.*)_([qkv]_proj)$")
+re_x_proj = re.compile(r"(.*)_((?:[qkv]|mlp)_proj)$")
 re_compiled = {}

 suffix_conversion = {
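The widened pattern now also peels off a fused mlp projection suffix, not just q/k/v. A quick standalone check; the layer names are illustrative, not real checkpoint keys:

import re

re_x_proj = re.compile(r"(.*)_((?:[qkv]|mlp)_proj)$")

# Illustrative names only; real Flux LoRA keys differ.
for name in ["blocks_0_attn_q_proj", "blocks_0_attn_mlp_proj", "blocks_0_attn_out"]:
    m = re_x_proj.match(name)
    print(name, "->", m.groups() if m else None)
# blocks_0_attn_q_proj -> ('blocks_0_attn', 'q_proj')
# blocks_0_attn_mlp_proj -> ('blocks_0_attn', 'mlp_proj')
# blocks_0_attn_out -> None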
@@ -460,7 +460,7 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
     for net in loaded_networks:
         module = net.modules.get(network_layer_name, None)
-        if module is not None and hasattr(self, 'weight') and not isinstance(module, modules.models.sd3.mmdit.QkvLinear):
+        if module is not None and hasattr(self, 'weight') and not isinstance(module, (modules.models.sd3.mmdit.QkvLinear, modules.models.flux.modules.layers.QkvLinear)):
             try:
                 with torch.no_grad():
                     if getattr(self, 'fp16_weight', None) is None:
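This guard keeps either fused QkvLinear variant out of the generic weight-patching path, since those weights are handled by the dedicated split branches further down. The tuple form of isinstance is the operative detail: it matches any of the listed classes, whereas a per-class all() test could never fire because a module is only ever one variant. A toy illustration with stub classes:

# Stub classes standing in for the two QkvLinear variants.
class MmditQkvLinear: ...
class FluxQkvLinear: ...

module = FluxQkvLinear()

# A module is only ever one of the variants, so requiring all() is never
# satisfied; isinstance with a tuple (an implicit any()) is what the
# skip-guard needs.
print(all(isinstance(module, t) for t in (MmditQkvLinear, FluxQkvLinear)))  # False
print(isinstance(module, (MmditQkvLinear, FluxQkvLinear)))                  # True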
@@ -520,7 +520,9 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
             continue

-        if isinstance(self, modules.models.sd3.mmdit.QkvLinear) and module_q and module_k and module_v:
+        module_mlp = net.modules.get(network_layer_name + "_mlp_proj", None)
+
+        if isinstance(self, (modules.models.sd3.mmdit.QkvLinear, modules.models.flux.modules.layers.QkvLinear)) and module_q and module_k and module_v and module_mlp is None:
             try:
                 with torch.no_grad():
                     # Send "real" orig_weight into MHA's lora module
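The _mlp_proj lookup is what routes between the two fused branches: no mlp part means the qkv-only path above, a present mlp part means the qkv+mlp path below. A small sketch of that dispatch, with net.modules stubbed as a dict and hypothetical key names:

# Hypothetical key names and module table; in the real code net.modules
# maps converted checkpoint keys to network modules.
net_modules = {
    "single_blocks_0_linear1_q_proj": object(),
    "single_blocks_0_linear1_k_proj": object(),
    "single_blocks_0_linear1_v_proj": object(),
    "single_blocks_0_linear1_mlp_proj": object(),
}
network_layer_name = "single_blocks_0_linear1"

module_q = net_modules.get(network_layer_name + "_q_proj", None)
module_k = net_modules.get(network_layer_name + "_k_proj", None)
module_v = net_modules.get(network_layer_name + "_v_proj", None)
module_mlp = net_modules.get(network_layer_name + "_mlp_proj", None)

print("qkv-only" if module_mlp is None else "qkv+mlp")  # qkv+mlp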
@@ -531,6 +533,26 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
                     del qw, kw, vw
                     updown_qkv = torch.vstack([updown_q, updown_k, updown_v])
                     self.weight += updown_qkv
+                    del updown_qkv
             except RuntimeError as e:
                 logging.debug(f"Network {net.name} layer {network_layer_name}: {e}")
                 extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1

             continue

+        if isinstance(self, modules.models.flux.modules.layers.QkvLinear) and module_q and module_k and module_v and module_mlp:
+            try:
+                with torch.no_grad():
+                    qw, kw, vw, mlp = torch.tensor_split(self.weight, (3072, 6144, 9216), 0)
+                    updown_q, _ = module_q.calc_updown(qw)
+                    updown_k, _ = module_k.calc_updown(kw)
+                    updown_v, _ = module_v.calc_updown(vw)
+                    updown_mlp, _ = module_mlp.calc_updown(mlp)
+                    del qw, kw, vw, mlp
+                    updown_qkv_mlp = torch.vstack([updown_q, updown_k, updown_v, updown_mlp])
+                    self.weight += updown_qkv_mlp
+                    del updown_qkv_mlp
+            except RuntimeError as e:
+                logging.debug(f"Network {net.name} layer {network_layer_name}: {e}")
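The split offsets encode Flux's geometry: with hidden size 3072, rows [0, 3072) of the fused weight are q, [3072, 6144) are k, [6144, 9216) are v, and the remaining 12288 rows are the mlp input projection (Flux's single-stream blocks fuse qkv and a ratio-4 mlp into one linear). A standalone sketch of the split and the vstack round trip:

import torch

# Fused single-stream weight: q, k, v (3072 rows each) stacked on top of a
# 12288-row mlp projection. The column count is shrunk here just to keep
# the demo light; only the row offsets matter for the split.
fused = torch.randn(3 * 3072 + 4 * 3072, 4)

qw, kw, vw, mlp = torch.tensor_split(fused, (3072, 6144, 9216), 0)
print(qw.shape[0], kw.shape[0], vw.shape[0], mlp.shape[0])  # 3072 3072 3072 12288

# Stacking the per-part deltas restores the fused row layout, which is why
# self.weight += torch.vstack([...]) lines up with the original weight.
assert torch.equal(torch.vstack([qw, kw, vw, mlp]), fused)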