call gc.collect() when wanted_names == ()

This commit is contained in:
Won-Kyu Park 2024-10-03 19:25:50 +09:00
parent 04f9084253
commit 2a1988fa67
No known key found for this signature in database
GPG Key ID: 53AA79C8C9535D15
2 changed files with 11 additions and 2 deletions

View File

@ -394,7 +394,7 @@ def restore_weights_backup(obj, field, weight):
getattr(obj, field).copy_(weight) getattr(obj, field).copy_(weight)
def network_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention]): def network_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention], cleanup=False):
weights_backup = getattr(self, "network_weights_backup", None) weights_backup = getattr(self, "network_weights_backup", None)
bias_backup = getattr(self, "network_bias_backup", None) bias_backup = getattr(self, "network_bias_backup", None)
@ -416,6 +416,12 @@ def network_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Li
else: else:
restore_weights_backup(self, 'bias', bias_backup) restore_weights_backup(self, 'bias', bias_backup)
if cleanup:
if weights_backup is not None:
del self.network_weights_backup
if bias_backup is not None:
del self.network_bias_backup
def network_backup_weights(self): def network_backup_weights(self):
network_layer_name = getattr(self, 'network_layer_name', None) network_layer_name = getattr(self, 'network_layer_name', None)

View File

@ -4,6 +4,7 @@ import torch
import gradio as gr import gradio as gr
from fastapi import FastAPI from fastapi import FastAPI
import gc
import network import network
import networks import networks
import lora # noqa:F401 import lora # noqa:F401
@ -168,8 +169,10 @@ class ScriptLora(scripts.Script):
if current_names == () and current_names != wanted_names and weights_backup is None: if current_names == () and current_names != wanted_names and weights_backup is None:
networks.network_backup_weights(module) networks.network_backup_weights(module)
elif current_names != () and current_names != wanted_names: elif current_names != () and current_names != wanted_names:
networks.network_restore_weights_from_backup(module) networks.network_restore_weights_from_backup(module, wanted_names == ())
module.weights_restored = True module.weights_restored = True
if current_names != wanted_names and wanted_names == ():
gc.collect()
script_callbacks.on_infotext_pasted(infotext_pasted) script_callbacks.on_infotext_pasted(infotext_pasted)