mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-01-20 05:10:15 +08:00
call gc.collect() when wanted_names == ()
This commit is contained in:
parent
04f9084253
commit
2a1988fa67
@ -394,7 +394,7 @@ def restore_weights_backup(obj, field, weight):
|
|||||||
getattr(obj, field).copy_(weight)
|
getattr(obj, field).copy_(weight)
|
||||||
|
|
||||||
|
|
||||||
def network_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention]):
|
def network_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention], cleanup=False):
|
||||||
weights_backup = getattr(self, "network_weights_backup", None)
|
weights_backup = getattr(self, "network_weights_backup", None)
|
||||||
bias_backup = getattr(self, "network_bias_backup", None)
|
bias_backup = getattr(self, "network_bias_backup", None)
|
||||||
|
|
||||||
@ -416,6 +416,12 @@ def network_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Li
|
|||||||
else:
|
else:
|
||||||
restore_weights_backup(self, 'bias', bias_backup)
|
restore_weights_backup(self, 'bias', bias_backup)
|
||||||
|
|
||||||
|
if cleanup:
|
||||||
|
if weights_backup is not None:
|
||||||
|
del self.network_weights_backup
|
||||||
|
if bias_backup is not None:
|
||||||
|
del self.network_bias_backup
|
||||||
|
|
||||||
|
|
||||||
def network_backup_weights(self):
|
def network_backup_weights(self):
|
||||||
network_layer_name = getattr(self, 'network_layer_name', None)
|
network_layer_name = getattr(self, 'network_layer_name', None)
|
||||||
|
@ -4,6 +4,7 @@ import torch
|
|||||||
import gradio as gr
|
import gradio as gr
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
|
|
||||||
|
import gc
|
||||||
import network
|
import network
|
||||||
import networks
|
import networks
|
||||||
import lora # noqa:F401
|
import lora # noqa:F401
|
||||||
@ -168,8 +169,10 @@ class ScriptLora(scripts.Script):
|
|||||||
if current_names == () and current_names != wanted_names and weights_backup is None:
|
if current_names == () and current_names != wanted_names and weights_backup is None:
|
||||||
networks.network_backup_weights(module)
|
networks.network_backup_weights(module)
|
||||||
elif current_names != () and current_names != wanted_names:
|
elif current_names != () and current_names != wanted_names:
|
||||||
networks.network_restore_weights_from_backup(module)
|
networks.network_restore_weights_from_backup(module, wanted_names == ())
|
||||||
module.weights_restored = True
|
module.weights_restored = True
|
||||||
|
if current_names != wanted_names and wanted_names == ():
|
||||||
|
gc.collect()
|
||||||
|
|
||||||
|
|
||||||
script_callbacks.on_infotext_pasted(infotext_pasted)
|
script_callbacks.on_infotext_pasted(infotext_pasted)
|
||||||
|
Loading…
Reference in New Issue
Block a user