mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-04-06 04:39:01 +08:00
Lora cache in memory
This commit is contained in:
parent
7ba8f11688
commit
eed963e972
@ -195,6 +195,15 @@ def load_network(name, network_on_disk):
|
|||||||
return net
|
return net
|
||||||
|
|
||||||
|
|
||||||
|
def purge_networks_from_memory():
|
||||||
|
while len(networks_in_memory) > shared.opts.lora_in_memory_limit and len(networks_in_memory) > 0:
|
||||||
|
name = next(iter(networks_in_memory))
|
||||||
|
networks_in_memory.pop(name, None)
|
||||||
|
|
||||||
|
devices.torch_gc()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=None):
|
def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=None):
|
||||||
already_loaded = {}
|
already_loaded = {}
|
||||||
|
|
||||||
@ -212,15 +221,19 @@ def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=No
|
|||||||
|
|
||||||
failed_to_load_networks = []
|
failed_to_load_networks = []
|
||||||
|
|
||||||
for i, name in enumerate(names):
|
for i, (network_on_disk, name) in enumerate(zip(networks_on_disk, names)):
|
||||||
net = already_loaded.get(name, None)
|
net = already_loaded.get(name, None)
|
||||||
|
|
||||||
network_on_disk = networks_on_disk[i]
|
|
||||||
|
|
||||||
if network_on_disk is not None:
|
if network_on_disk is not None:
|
||||||
|
if net is None:
|
||||||
|
net = networks_in_memory.get(name)
|
||||||
|
|
||||||
if net is None or os.path.getmtime(network_on_disk.filename) > net.mtime:
|
if net is None or os.path.getmtime(network_on_disk.filename) > net.mtime:
|
||||||
try:
|
try:
|
||||||
net = load_network(name, network_on_disk)
|
net = load_network(name, network_on_disk)
|
||||||
|
|
||||||
|
networks_in_memory.pop(name, None)
|
||||||
|
networks_in_memory[name] = net
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
errors.display(e, f"loading network {network_on_disk.filename}")
|
errors.display(e, f"loading network {network_on_disk.filename}")
|
||||||
continue
|
continue
|
||||||
@ -242,6 +255,8 @@ def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=No
|
|||||||
if failed_to_load_networks:
|
if failed_to_load_networks:
|
||||||
sd_hijack.model_hijack.comments.append("Failed to find networks: " + ", ".join(failed_to_load_networks))
|
sd_hijack.model_hijack.comments.append("Failed to find networks: " + ", ".join(failed_to_load_networks))
|
||||||
|
|
||||||
|
purge_networks_from_memory()
|
||||||
|
|
||||||
|
|
||||||
def network_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.MultiheadAttention]):
|
def network_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.MultiheadAttention]):
|
||||||
weights_backup = getattr(self, "network_weights_backup", None)
|
weights_backup = getattr(self, "network_weights_backup", None)
|
||||||
@ -462,6 +477,7 @@ def infotext_pasted(infotext, params):
|
|||||||
available_networks = {}
|
available_networks = {}
|
||||||
available_network_aliases = {}
|
available_network_aliases = {}
|
||||||
loaded_networks = []
|
loaded_networks = []
|
||||||
|
networks_in_memory = {}
|
||||||
available_network_hash_lookup = {}
|
available_network_hash_lookup = {}
|
||||||
forbidden_network_aliases = {}
|
forbidden_network_aliases = {}
|
||||||
|
|
||||||
|
@ -65,6 +65,7 @@ shared.options_templates.update(shared.options_section(('extra_networks', "Extra
|
|||||||
"lora_add_hashes_to_infotext": shared.OptionInfo(True, "Add Lora hashes to infotext"),
|
"lora_add_hashes_to_infotext": shared.OptionInfo(True, "Add Lora hashes to infotext"),
|
||||||
"lora_show_all": shared.OptionInfo(False, "Always show all networks on the Lora page").info("otherwise, those detected as for incompatible version of Stable Diffusion will be hidden"),
|
"lora_show_all": shared.OptionInfo(False, "Always show all networks on the Lora page").info("otherwise, those detected as for incompatible version of Stable Diffusion will be hidden"),
|
||||||
"lora_hide_unknown_for_versions": shared.OptionInfo([], "Hide networks of unknown versions for model versions", gr.CheckboxGroup, {"choices": ["SD1", "SD2", "SDXL"]}),
|
"lora_hide_unknown_for_versions": shared.OptionInfo([], "Hide networks of unknown versions for model versions", gr.CheckboxGroup, {"choices": ["SD1", "SD2", "SDXL"]}),
|
||||||
|
"lora_in_memory_limit": shared.OptionInfo(0, "Number of Lora networks to keep cached in memory", gr.Number, {"precision": 0}),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
|
|
||||||
@ -121,3 +122,5 @@ def infotext_pasted(infotext, d):
|
|||||||
|
|
||||||
|
|
||||||
script_callbacks.on_infotext_pasted(infotext_pasted)
|
script_callbacks.on_infotext_pasted(infotext_pasted)
|
||||||
|
|
||||||
|
shared.opts.onchange("lora_in_memory_limit", networks.purge_networks_from_memory)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user