mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-01-31 02:32:57 +08:00
Add memory cache for VAE weights
This commit is contained in:
parent
c6f347b81f
commit
893933e05a
@ -1,5 +1,6 @@
|
|||||||
import torch
|
import torch
|
||||||
import os
|
import os
|
||||||
|
import collections
|
||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
from modules import shared, devices, script_callbacks
|
from modules import shared, devices, script_callbacks
|
||||||
from modules.paths import models_path
|
from modules.paths import models_path
|
||||||
@ -30,6 +31,7 @@ base_vae = None
|
|||||||
loaded_vae_file = None
|
loaded_vae_file = None
|
||||||
checkpoint_info = None
|
checkpoint_info = None
|
||||||
|
|
||||||
|
checkpoints_loaded = collections.OrderedDict()
|
||||||
|
|
||||||
def get_base_vae(model):
|
def get_base_vae(model):
|
||||||
if base_vae is not None and checkpoint_info == model.sd_checkpoint_info and model:
|
if base_vae is not None and checkpoint_info == model.sd_checkpoint_info and model:
|
||||||
@ -149,13 +151,30 @@ def load_vae(model, vae_file=None):
|
|||||||
global first_load, vae_dict, vae_list, loaded_vae_file
|
global first_load, vae_dict, vae_list, loaded_vae_file
|
||||||
# save_settings = False
|
# save_settings = False
|
||||||
|
|
||||||
|
cache_enabled = shared.opts.sd_vae_checkpoint_cache > 0
|
||||||
|
|
||||||
if vae_file:
|
if vae_file:
|
||||||
assert os.path.isfile(vae_file), f"VAE file doesn't exist: {vae_file}"
|
if cache_enabled and vae_file in checkpoints_loaded:
|
||||||
print(f"Loading VAE weights from: {vae_file}")
|
# use vae checkpoint cache
|
||||||
store_base_vae(model)
|
print(f"Loading VAE weights [{get_filename(vae_file)}] from cache")
|
||||||
vae_ckpt = torch.load(vae_file, map_location=shared.weight_load_location)
|
store_base_vae(model)
|
||||||
vae_dict_1 = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss" and k not in vae_ignore_keys}
|
_load_vae_dict(model, checkpoints_loaded[vae_file])
|
||||||
_load_vae_dict(model, vae_dict_1)
|
else:
|
||||||
|
assert os.path.isfile(vae_file), f"VAE file doesn't exist: {vae_file}"
|
||||||
|
print(f"Loading VAE weights from: {vae_file}")
|
||||||
|
store_base_vae(model)
|
||||||
|
vae_ckpt = torch.load(vae_file, map_location=shared.weight_load_location)
|
||||||
|
vae_dict_1 = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss" and k not in vae_ignore_keys}
|
||||||
|
_load_vae_dict(model, vae_dict_1)
|
||||||
|
|
||||||
|
if cache_enabled:
|
||||||
|
# cache newly loaded vae
|
||||||
|
checkpoints_loaded[vae_file] = vae_dict_1.copy()
|
||||||
|
|
||||||
|
# clean up cache if limit is reached
|
||||||
|
if cache_enabled:
|
||||||
|
while len(checkpoints_loaded) > shared.opts.sd_vae_checkpoint_cache + 1: # we need to count the current model
|
||||||
|
checkpoints_loaded.popitem(last=False) # LRU
|
||||||
|
|
||||||
# If vae used is not in dict, update it
|
# If vae used is not in dict, update it
|
||||||
# It will be removed on refresh though
|
# It will be removed on refresh though
|
||||||
|
@ -356,6 +356,7 @@ options_templates.update(options_section(('training', "Training"), {
|
|||||||
options_templates.update(options_section(('sd', "Stable Diffusion"), {
|
options_templates.update(options_section(('sd', "Stable Diffusion"), {
|
||||||
"sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": list_checkpoint_tiles()}, refresh=refresh_checkpoints),
|
"sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": list_checkpoint_tiles()}, refresh=refresh_checkpoints),
|
||||||
"sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
|
"sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
|
||||||
|
"sd_vae_checkpoint_cache": OptionInfo(0, "VAE Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
|
||||||
"sd_vae": OptionInfo("auto", "SD VAE", gr.Dropdown, lambda: {"choices": sd_vae.vae_list}, refresh=sd_vae.refresh_vae_list),
|
"sd_vae": OptionInfo("auto", "SD VAE", gr.Dropdown, lambda: {"choices": sd_vae.vae_list}, refresh=sd_vae.refresh_vae_list),
|
||||||
"sd_vae_as_default": OptionInfo(False, "Ignore selected VAE for stable diffusion checkpoints that have their own .vae.pt next to them"),
|
"sd_vae_as_default": OptionInfo(False, "Ignore selected VAE for stable diffusion checkpoints that have their own .vae.pt next to them"),
|
||||||
"sd_hypernetwork": OptionInfo("None", "Hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks),
|
"sd_hypernetwork": OptionInfo("None", "Hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks),
|
||||||
|
Loading…
Reference in New Issue
Block a user