diff --git a/extensions-builtin/Lora/networks.py b/extensions-builtin/Lora/networks.py
index 76cef0a55..e45b82387 100644
--- a/extensions-builtin/Lora/networks.py
+++ b/extensions-builtin/Lora/networks.py
@@ -417,16 +417,8 @@ def network_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Li
             restore_weights_backup(self, 'bias', bias_backup)
 
 
-def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention]):
-    """
-    Applies the currently selected set of networks to the weights of torch layer self.
-    If weights already have this particular set of networks applied, does nothing.
-    If not, restores original weights from backup and alters weights according to networks.
-    """
-
+def network_backup_weights(self):
     network_layer_name = getattr(self, 'network_layer_name', None)
-    if network_layer_name is None:
-        return
 
     current_names = getattr(self, "network_current_names", ())
     wanted_names = tuple((x.name, x.te_multiplier, x.unet_multiplier, x.dyn_dim) for x in loaded_networks)
@@ -459,9 +451,31 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
 
         self.network_bias_backup = bias_backup
 
-    if current_names != wanted_names:
+
+def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention]):
+    """
+    Applies the currently selected set of networks to the weights of torch layer self.
+    If weights already have this particular set of networks applied, does nothing.
+    If not, restores original weights from backup and alters weights according to networks.
+    """
+
+    network_layer_name = getattr(self, 'network_layer_name', None)
+    if network_layer_name is None:
+        return
+
+    current_names = getattr(self, "network_current_names", ())
+    wanted_names = tuple((x.name, x.te_multiplier, x.unet_multiplier, x.dyn_dim) for x in loaded_networks)
+
+    weights_backup = getattr(self, "network_weights_backup", None)
+    if weights_backup is None and wanted_names != ():
+        network_backup_weights(self)
+    elif current_names != () and current_names != wanted_names and not getattr(self, "weights_restored", False):
         network_restore_weights_from_backup(self)
 
+    if current_names != wanted_names:
+        if hasattr(self, "weights_restored"):
+            self.weights_restored = False
+
         for net in loaded_networks:
             module = net.modules.get(network_layer_name, None)
             if module is not None and hasattr(self, 'weight') and not all(isinstance(module, linear) for linear in (modules.models.sd3.mmdit.QkvLinear, modules.models.flux.modules.layers.QkvLinear)):
diff --git a/extensions-builtin/Lora/scripts/lora_script.py b/extensions-builtin/Lora/scripts/lora_script.py
index 8ee93efef..8163a05f3 100644
--- a/extensions-builtin/Lora/scripts/lora_script.py
+++ b/extensions-builtin/Lora/scripts/lora_script.py
@@ -143,6 +143,10 @@ class ScriptLora(scripts.Script):
         target_dtype = devices.dtype_inference
 
         for module in modules:
+            network_layer_name = getattr(module, 'network_layer_name', None)
+            if network_layer_name is None:
+                continue
+
             if isinstance(module, torch.nn.MultiheadAttention):
                 org_dtype = torch.float32
             else:
@@ -155,6 +159,18 @@ class ScriptLora(scripts.Script):
             # set org_dtype
             module.org_dtype = org_dtype
 
+            # backup/restore weights
+            current_names = getattr(module, "network_current_names", ())
+            wanted_names = tuple((x.name, x.te_multiplier, x.unet_multiplier, x.dyn_dim) for x in networks.loaded_networks)
+
+            weights_backup = getattr(module, "network_weights_backup", None)
+
+            if current_names == () and current_names != wanted_names and weights_backup is None:
+                networks.network_backup_weights(module)
+            elif current_names != () and current_names != wanted_names:
+                networks.network_restore_weights_from_backup(module)
+                module.weights_restored = True
+
 
 script_callbacks.on_infotext_pasted(infotext_pasted)
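
For context, the refactor splits the old monolithic function in two: network_backup_weights snapshots a layer's original weights, while network_apply_weights decides whether to back up (first application), restore (the wanted set of networks changed and lora_script.py has not already restored it, signalled via weights_restored), or skip (the set is unchanged). The short sketch below mirrors that control flow outside the webui; Layer, the float "weight", and the fake LoRA delta are illustrative stand-ins, not real webui objects.

# Standalone sketch, not part of the patch: mirrors the backup/restore ordering
# of the refactored network_apply_weights() using plain Python values.

class Layer:
    def __init__(self):
        self.weight = 1.0                    # stand-in for a torch parameter
        self.network_weights_backup = None   # filled by the "backup" branch
        self.network_current_names = ()      # networks currently baked into weight
        self.weights_restored = False        # set externally, as lora_script.py does

def apply_networks(layer, wanted_names):
    # First application: take a backup of the original weights once.
    if layer.network_weights_backup is None and wanted_names != ():
        layer.network_weights_backup = layer.weight          # ~ network_backup_weights()
    # Set changed and nothing restored it for us yet: go back to the original weights.
    elif layer.network_current_names != () and layer.network_current_names != wanted_names and not layer.weights_restored:
        layer.weight = layer.network_weights_backup           # ~ network_restore_weights_from_backup()

    if layer.network_current_names != wanted_names:
        layer.weights_restored = False                        # consume the externally set flag
        delta = sum(mult for _, mult in wanted_names)         # fake LoRA contribution
        layer.weight = layer.network_weights_backup + delta
        layer.network_current_names = wanted_names

layer = Layer()
apply_networks(layer, (("lora_a", 0.8),))   # backs up 1.0, applies lora_a -> 1.8
apply_networks(layer, (("lora_a", 0.8),))   # same set: nothing to do
apply_networks(layer, (("lora_b", 0.5),))   # restores 1.0, applies lora_b -> 1.5
print(layer.weight, layer.network_weights_backup)  # 1.5 1.0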