mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-01-17 03:40:14 +08:00
Merge pull request #15943 from eatmoreapple/update-lora-load
feat: lora partial update precede full update
This commit is contained in:
commit
6de733c91e
@ -260,6 +260,16 @@ def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=No
|
|||||||
|
|
||||||
loaded_networks.clear()
|
loaded_networks.clear()
|
||||||
|
|
||||||
|
unavailable_networks = []
|
||||||
|
for name in names:
|
||||||
|
if name.lower() in forbidden_network_aliases and available_networks.get(name) is None:
|
||||||
|
unavailable_networks.append(name)
|
||||||
|
elif available_network_aliases.get(name) is None:
|
||||||
|
unavailable_networks.append(name)
|
||||||
|
|
||||||
|
if unavailable_networks:
|
||||||
|
update_available_networks_by_names(unavailable_networks)
|
||||||
|
|
||||||
networks_on_disk = [available_networks.get(name, None) if name.lower() in forbidden_network_aliases else available_network_aliases.get(name, None) for name in names]
|
networks_on_disk = [available_networks.get(name, None) if name.lower() in forbidden_network_aliases else available_network_aliases.get(name, None) for name in names]
|
||||||
if any(x is None for x in networks_on_disk):
|
if any(x is None for x in networks_on_disk):
|
||||||
list_available_networks()
|
list_available_networks()
|
||||||
@ -566,22 +576,16 @@ def network_MultiheadAttention_load_state_dict(self, *args, **kwargs):
|
|||||||
return originals.MultiheadAttention_load_state_dict(self, *args, **kwargs)
|
return originals.MultiheadAttention_load_state_dict(self, *args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
def list_available_networks():
|
def process_network_files(names: list[str] | None = None):
|
||||||
available_networks.clear()
|
|
||||||
available_network_aliases.clear()
|
|
||||||
forbidden_network_aliases.clear()
|
|
||||||
available_network_hash_lookup.clear()
|
|
||||||
forbidden_network_aliases.update({"none": 1, "Addams": 1})
|
|
||||||
|
|
||||||
os.makedirs(shared.cmd_opts.lora_dir, exist_ok=True)
|
|
||||||
|
|
||||||
candidates = list(shared.walk_files(shared.cmd_opts.lora_dir, allowed_extensions=[".pt", ".ckpt", ".safetensors"]))
|
candidates = list(shared.walk_files(shared.cmd_opts.lora_dir, allowed_extensions=[".pt", ".ckpt", ".safetensors"]))
|
||||||
candidates += list(shared.walk_files(shared.cmd_opts.lyco_dir_backcompat, allowed_extensions=[".pt", ".ckpt", ".safetensors"]))
|
candidates += list(shared.walk_files(shared.cmd_opts.lyco_dir_backcompat, allowed_extensions=[".pt", ".ckpt", ".safetensors"]))
|
||||||
for filename in candidates:
|
for filename in candidates:
|
||||||
if os.path.isdir(filename):
|
if os.path.isdir(filename):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
name = os.path.splitext(os.path.basename(filename))[0]
|
name = os.path.splitext(os.path.basename(filename))[0]
|
||||||
|
# if names is provided, only load networks with names in the list
|
||||||
|
if names and name not in names:
|
||||||
|
continue
|
||||||
try:
|
try:
|
||||||
entry = network.NetworkOnDisk(name, filename)
|
entry = network.NetworkOnDisk(name, filename)
|
||||||
except OSError: # should catch FileNotFoundError and PermissionError etc.
|
except OSError: # should catch FileNotFoundError and PermissionError etc.
|
||||||
@ -597,6 +601,22 @@ def list_available_networks():
|
|||||||
available_network_aliases[entry.alias] = entry
|
available_network_aliases[entry.alias] = entry
|
||||||
|
|
||||||
|
|
||||||
|
def update_available_networks_by_names(names: list[str]):
|
||||||
|
process_network_files(names)
|
||||||
|
|
||||||
|
|
||||||
|
def list_available_networks():
|
||||||
|
available_networks.clear()
|
||||||
|
available_network_aliases.clear()
|
||||||
|
forbidden_network_aliases.clear()
|
||||||
|
available_network_hash_lookup.clear()
|
||||||
|
forbidden_network_aliases.update({"none": 1, "Addams": 1})
|
||||||
|
|
||||||
|
os.makedirs(shared.cmd_opts.lora_dir, exist_ok=True)
|
||||||
|
|
||||||
|
process_network_files()
|
||||||
|
|
||||||
|
|
||||||
re_network_name = re.compile(r"(.*)\s*\([0-9a-fA-F]+\)")
|
re_network_name = re.compile(r"(.*)\s*\([0-9a-fA-F]+\)")
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user