From 2a1988fa67e204805b8d63942527c3843fd9a3df Mon Sep 17 00:00:00 2001 From: Won-Kyu Park Date: Thu, 3 Oct 2024 19:25:50 +0900 Subject: [PATCH] call gc.collect() when wanted_names == () --- extensions-builtin/Lora/networks.py | 8 +++++++- extensions-builtin/Lora/scripts/lora_script.py | 5 ++++- 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/extensions-builtin/Lora/networks.py b/extensions-builtin/Lora/networks.py index e45b82387..78d7407a0 100644 --- a/extensions-builtin/Lora/networks.py +++ b/extensions-builtin/Lora/networks.py @@ -394,7 +394,7 @@ def restore_weights_backup(obj, field, weight): getattr(obj, field).copy_(weight) -def network_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention]): +def network_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention], cleanup=False): weights_backup = getattr(self, "network_weights_backup", None) bias_backup = getattr(self, "network_bias_backup", None) @@ -416,6 +416,12 @@ def network_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Li else: restore_weights_backup(self, 'bias', bias_backup) + if cleanup: + if weights_backup is not None: + del self.network_weights_backup + if bias_backup is not None: + del self.network_bias_backup + def network_backup_weights(self): network_layer_name = getattr(self, 'network_layer_name', None) diff --git a/extensions-builtin/Lora/scripts/lora_script.py b/extensions-builtin/Lora/scripts/lora_script.py index 8163a05f3..7a23b8d57 100644 --- a/extensions-builtin/Lora/scripts/lora_script.py +++ b/extensions-builtin/Lora/scripts/lora_script.py @@ -4,6 +4,7 @@ import torch import gradio as gr from fastapi import FastAPI +import gc import network import networks import lora # noqa:F401 @@ -168,8 +169,10 @@ class ScriptLora(scripts.Script): if current_names == () and current_names != wanted_names and weights_backup is None: networks.network_backup_weights(module) elif current_names != () and current_names != wanted_names: - networks.network_restore_weights_from_backup(module) + networks.network_restore_weights_from_backup(module, wanted_names == ()) module.weights_restored = True + if current_names != wanted_names and wanted_names == (): + gc.collect() script_callbacks.on_infotext_pasted(infotext_pasted)