mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-01-19 21:00:14 +08:00
call gc.collect() when wanted_names == ()
This commit is contained in:
parent
04f9084253
commit
2a1988fa67
@ -394,7 +394,7 @@ def restore_weights_backup(obj, field, weight):
|
||||
getattr(obj, field).copy_(weight)
|
||||
|
||||
|
||||
def network_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention]):
|
||||
def network_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention], cleanup=False):
|
||||
weights_backup = getattr(self, "network_weights_backup", None)
|
||||
bias_backup = getattr(self, "network_bias_backup", None)
|
||||
|
||||
@ -416,6 +416,12 @@ def network_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Li
|
||||
else:
|
||||
restore_weights_backup(self, 'bias', bias_backup)
|
||||
|
||||
if cleanup:
|
||||
if weights_backup is not None:
|
||||
del self.network_weights_backup
|
||||
if bias_backup is not None:
|
||||
del self.network_bias_backup
|
||||
|
||||
|
||||
def network_backup_weights(self):
|
||||
network_layer_name = getattr(self, 'network_layer_name', None)
|
||||
|
@ -4,6 +4,7 @@ import torch
|
||||
import gradio as gr
|
||||
from fastapi import FastAPI
|
||||
|
||||
import gc
|
||||
import network
|
||||
import networks
|
||||
import lora # noqa:F401
|
||||
@ -168,8 +169,10 @@ class ScriptLora(scripts.Script):
|
||||
if current_names == () and current_names != wanted_names and weights_backup is None:
|
||||
networks.network_backup_weights(module)
|
||||
elif current_names != () and current_names != wanted_names:
|
||||
networks.network_restore_weights_from_backup(module)
|
||||
networks.network_restore_weights_from_backup(module, wanted_names == ())
|
||||
module.weights_restored = True
|
||||
if current_names != wanted_names and wanted_names == ():
|
||||
gc.collect()
|
||||
|
||||
|
||||
script_callbacks.on_infotext_pasted(infotext_pasted)
|
||||
|
Loading…
Reference in New Issue
Block a user