call gc.collect() when wanted_names == ()

This commit is contained in:
Won-Kyu Park 2024-10-03 19:25:50 +09:00
parent 04f9084253
commit 2a1988fa67
No known key found for this signature in database
GPG Key ID: 53AA79C8C9535D15
2 changed files with 11 additions and 2 deletions

View File

@ -394,7 +394,7 @@ def restore_weights_backup(obj, field, weight):
getattr(obj, field).copy_(weight)
def network_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention]):
def network_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention], cleanup=False):
weights_backup = getattr(self, "network_weights_backup", None)
bias_backup = getattr(self, "network_bias_backup", None)
@ -416,6 +416,12 @@ def network_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Li
else:
restore_weights_backup(self, 'bias', bias_backup)
if cleanup:
if weights_backup is not None:
del self.network_weights_backup
if bias_backup is not None:
del self.network_bias_backup
def network_backup_weights(self):
network_layer_name = getattr(self, 'network_layer_name', None)

View File

@ -4,6 +4,7 @@ import torch
import gradio as gr
from fastapi import FastAPI
import gc
import network
import networks
import lora # noqa:F401
@ -168,8 +169,10 @@ class ScriptLora(scripts.Script):
if current_names == () and current_names != wanted_names and weights_backup is None:
networks.network_backup_weights(module)
elif current_names != () and current_names != wanted_names:
networks.network_restore_weights_from_backup(module)
networks.network_restore_weights_from_backup(module, wanted_names == ())
module.weights_restored = True
if current_names != wanted_names and wanted_names == ():
gc.collect()
script_callbacks.on_infotext_pasted(infotext_pasted)