Mirror of https://github.com/AUTOMATIC1111/stable-diffusion-webui.git (synced 2025-02-08 14:42:54 +08:00)

fix for Lora flux

commit 51c285265f, parent d6a609a539
extensions-builtin/Lora/network_lora.py
@@ -2,6 +2,7 @@ import torch
 
 import lyco_helpers
 import modules.models.sd3.mmdit
+import modules.models.flux.modules.layers
 import network
 from modules import devices
 
@@ -37,7 +38,7 @@ class NetworkModuleLora(network.NetworkModule):
         if weight is None and none_ok:
             return None
 
-        is_linear = type(self.sd_module) in [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear, torch.nn.MultiheadAttention, modules.models.sd3.mmdit.QkvLinear]
+        is_linear = type(self.sd_module) in [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear, torch.nn.MultiheadAttention, modules.models.sd3.mmdit.QkvLinear, modules.models.flux.modules.layers.QkvLinear]
         is_conv = type(self.sd_module) in [torch.nn.Conv2d]
 
         if is_linear:
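The is_linear check matches by exact type, not by inheritance, which is why Flux's QkvLinear must be listed explicitly even if it subclasses torch.nn.Linear. A minimal sketch of that distinction, using a stand-in class in place of modules.models.flux.modules.layers.QkvLinear:

    import torch

    class QkvLinear(torch.nn.Linear):
        # stand-in for modules.models.flux.modules.layers.QkvLinear
        # (assumed here to subclass nn.Linear)
        pass

    layer = QkvLinear(8, 24)
    print(type(layer) in [torch.nn.Linear])             # False: exact-type membership ignores subclasses
    print(type(layer) in [torch.nn.Linear, QkvLinear])  # True once the class is listed explicitly
    print(isinstance(layer, torch.nn.Linear))           # True: isinstance does follow inheritance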
extensions-builtin/Lora/networks.py
@@ -37,7 +37,7 @@ module_types = [
 
 
 re_digits = re.compile(r"\d+")
-re_x_proj = re.compile(r"(.*)_([qkv]_proj)$")
+re_x_proj = re.compile(r"(.*)_((?:[qkv]|mlp)_proj)$")
 re_compiled = {}
 
 suffix_conversion = {
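The widened re_x_proj pattern now captures an mlp_proj suffix alongside q_proj/k_proj/v_proj. A quick self-contained check with hypothetical layer names:

    import re

    re_x_proj = re.compile(r"(.*)_((?:[qkv]|mlp)_proj)$")

    for name in ["blocks_0_attn_q_proj", "blocks_0_attn_mlp_proj", "blocks_0_attn_out"]:
        m = re_x_proj.match(name)
        print(name, "->", m.groups() if m else None)
    # blocks_0_attn_q_proj -> ('blocks_0_attn', 'q_proj')
    # blocks_0_attn_mlp_proj -> ('blocks_0_attn', 'mlp_proj')
    # blocks_0_attn_out -> None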
@@ -460,7 +460,7 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
 
     for net in loaded_networks:
         module = net.modules.get(network_layer_name, None)
-        if module is not None and hasattr(self, 'weight') and not isinstance(module, modules.models.sd3.mmdit.QkvLinear):
+        if module is not None and hasattr(self, 'weight') and not all(isinstance(module, linear) for linear in (modules.models.sd3.mmdit.QkvLinear, modules.models.flux.modules.layers.QkvLinear)):
             try:
                 with torch.no_grad():
                     if getattr(self, 'fp16_weight', None) is None:
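The rewritten guard uses not all(isinstance(...)) over the two QkvLinear classes. In plain Python, all only suppresses this path when the module is an instance of both classes at once, whereas any (or the tuple form of isinstance) would suppress it for either one; which behavior was intended is not visible from the diff alone. A self-contained illustration with hypothetical stand-in classes:

    class SD3Qkv: pass    # stand-ins for the two QkvLinear classes
    class FluxQkv: pass

    module = SD3Qkv()
    print(not all(isinstance(module, t) for t in (SD3Qkv, FluxQkv)))  # True: not an instance of both
    print(not any(isinstance(module, t) for t in (SD3Qkv, FluxQkv)))  # False: it is an instance of one
    print(not isinstance(module, (SD3Qkv, FluxQkv)))                  # False: tuple form behaves like any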
@@ -520,7 +520,9 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
 
             continue
 
-        if isinstance(self, modules.models.sd3.mmdit.QkvLinear) and module_q and module_k and module_v:
+        module_mlp = net.modules.get(network_layer_name + "_mlp_proj", None)
+
+        if any(isinstance(self, linear) for linear in (modules.models.sd3.mmdit.QkvLinear, modules.models.flux.modules.layers.QkvLinear)) and module_q and module_k and module_v and module_mlp is None:
             try:
                 with torch.no_grad():
                     # Send "real" orig_weight into MHA's lora module
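With the new module_mlp lookup, the existing qkv branch is gated on module_mlp is None, leaving fused qkv+mlp weights (Flux's fused attention-plus-MLP input projection) to the second branch added in the next hunk. A hypothetical reduction of the dispatch:

    def pick_branch(module_q, module_k, module_v, module_mlp):
        # mirrors the two guards: qkv-only LoRA vs fused qkv+mlp LoRA
        if module_q and module_k and module_v and module_mlp is None:
            return "qkv"
        if module_q and module_k and module_v and module_mlp:
            return "qkv+mlp"
        return "per-module path"

    print(pick_branch(object(), object(), object(), None))      # qkv
    print(pick_branch(object(), object(), object(), object()))  # qkv+mlp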
@@ -531,6 +533,26 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
                     del qw, kw, vw
                     updown_qkv = torch.vstack([updown_q, updown_k, updown_v])
                     self.weight += updown_qkv
+                    del updown_qkv
 
+            except RuntimeError as e:
+                logging.debug(f"Network {net.name} layer {network_layer_name}: {e}")
+                extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1
+
+            continue
+
+        if any(isinstance(self, linear) for linear in (modules.models.flux.modules.layers.QkvLinear,)) and module_q and module_k and module_v and module_mlp:
+            try:
+                with torch.no_grad():
+                    qw, kw, vw, mlp = torch.tensor_split(self.weight, (3072, 6144, 9216,), 0)
+                    updown_q, _ = module_q.calc_updown(qw)
+                    updown_k, _ = module_k.calc_updown(kw)
+                    updown_v, _ = module_v.calc_updown(vw)
+                    updown_mlp, _ = module_v.calc_updown(mlp)
+                    del qw, kw, vw, mlp
+                    updown_qkv_mlp = torch.vstack([updown_q, updown_k, updown_v, updown_mlp])
+                    self.weight += updown_qkv_mlp
+                    del updown_qkv_mlp
+
             except RuntimeError as e:
                 logging.debug(f"Network {net.name} layer {network_layer_name}: {e}")
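The new branch carves the fused weight with torch.tensor_split at row offsets 3072/6144/9216: three 3072-row Q/K/V slices, with whatever rows remain going to the MLP projection. A sketch under an assumed 21504-row fused weight (3 x 3072 + 12288, the shape a Flux-style fused qkv+mlp linear would have):

    import torch

    fused = torch.zeros(21504, 3072)  # assumed fused qkv+mlp weight: rows = 3*3072 + 12288
    qw, kw, vw, mlp = torch.tensor_split(fused, (3072, 6144, 9216), dim=0)
    print(qw.shape, kw.shape, vw.shape, mlp.shape)
    # torch.Size([3072, 3072]) torch.Size([3072, 3072]) torch.Size([3072, 3072]) torch.Size([12288, 3072])

Note that the added branch computes updown_mlp via module_v.calc_updown(mlp) rather than module_mlp.calc_updown(mlp); from the diff alone it is unclear whether that is intentional.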