stable-diffusion-webui/modules/extra_networks.py

226 lines
7.1 KiB
Python
Raw Normal View History

import json
import os
import re
import logging
from collections import defaultdict
from modules import errors
extra_network_registry = {}
extra_network_aliases = {}
def initialize():
extra_network_registry.clear()
extra_network_aliases.clear()
def register_extra_network(extra_network):
extra_network_registry[extra_network.name] = extra_network
def register_extra_network_alias(extra_network, alias):
extra_network_aliases[alias] = extra_network
def register_default_extra_networks():
from modules.extra_networks_hypernet import ExtraNetworkHypernet
register_extra_network(ExtraNetworkHypernet())
class ExtraNetworkParams:
def __init__(self, items=None):
self.items = items or []
self.positional = []
self.named = {}
for item in self.items:
parts = item.split('=', 2) if isinstance(item, str) else [item]
if len(parts) == 2:
self.named[parts[0]] = parts[1]
else:
self.positional.append(item)
2023-06-04 03:24:44 +08:00
def __eq__(self, other):
2023-06-04 04:19:47 +08:00
return self.items == other.items
2023-06-04 03:24:44 +08:00
class ExtraNetwork:
def __init__(self, name):
self.name = name
def activate(self, p, params_list):
"""
Called by processing on every run. Whatever the extra network is meant to do should be activated here.
Passes arguments related to this extra network in params_list.
User passes arguments by specifying this in his prompt:
<name:arg1:arg2:arg3>
Where name matches the name of this ExtraNetwork object, and arg1:arg2:arg3 are any natural number of text arguments
separated by colon.
2024-03-04 14:37:23 +08:00
Even if the user does not mention this ExtraNetwork in his prompt, the call will still be made, with empty params_list -
in this case, all effects of this extra networks should be disabled.
Can be called multiple times before deactivate() - each new call should override the previous call completely.
For example, if this ExtraNetwork's name is 'hypernet' and user's prompt is:
> "1girl, <hypernet:agm:1.1> <extrasupernet:master:12:13:14> <hypernet:ray>"
params_list will be:
[
ExtraNetworkParams(items=["agm", "1.1"]),
ExtraNetworkParams(items=["ray"])
]
"""
raise NotImplementedError
def deactivate(self, p):
"""
Called at the end of processing for housekeeping. No need to do anything here.
"""
raise NotImplementedError
def lookup_extra_networks(extra_network_data):
"""returns a dict mapping ExtraNetwork objects to lists of arguments for those extra networks.
Example input:
{
'lora': [<modules.extra_networks.ExtraNetworkParams object at 0x0000020690D58310>],
'lyco': [<modules.extra_networks.ExtraNetworkParams object at 0x0000020690D58F70>],
'hypernet': [<modules.extra_networks.ExtraNetworkParams object at 0x0000020690D5A800>]
}
Example output:
{
<extra_networks_lora.ExtraNetworkLora object at 0x0000020581BEECE0>: [<modules.extra_networks.ExtraNetworkParams object at 0x0000020690D58310>, <modules.extra_networks.ExtraNetworkParams object at 0x0000020690D58F70>],
<modules.extra_networks_hypernet.ExtraNetworkHypernet object at 0x0000020581BEEE60>: [<modules.extra_networks.ExtraNetworkParams object at 0x0000020690D5A800>]
}
"""
res = {}
for extra_network_name, extra_network_args in list(extra_network_data.items()):
extra_network = extra_network_registry.get(extra_network_name, None)
alias = extra_network_aliases.get(extra_network_name, None)
if alias is not None and extra_network is None:
extra_network = alias
if extra_network is None:
logging.info(f"Skipping unknown extra network: {extra_network_name}")
continue
res.setdefault(extra_network, []).extend(extra_network_args)
return res
def activate(p, extra_network_data):
"""call activate for extra networks in extra_network_data in specified order, then call
activate for all remaining registered networks with an empty argument list"""
activated = []
for extra_network, extra_network_args in lookup_extra_networks(extra_network_data).items():
try:
extra_network.activate(p, extra_network_args)
activated.append(extra_network)
except Exception as e:
errors.display(e, f"activating extra network {extra_network.name} with arguments {extra_network_args}")
for extra_network_name, extra_network in extra_network_registry.items():
if extra_network in activated:
continue
try:
extra_network.activate(p, [])
except Exception as e:
errors.display(e, f"activating extra network {extra_network_name}")
if p.scripts is not None:
p.scripts.after_extra_networks_activate(p, batch_number=p.iteration, prompts=p.prompts, seeds=p.seeds, subseeds=p.subseeds, extra_network_data=extra_network_data)
def deactivate(p, extra_network_data):
"""call deactivate for extra networks in extra_network_data in specified order, then call
deactivate for all remaining registered networks"""
data = lookup_extra_networks(extra_network_data)
for extra_network in data:
try:
extra_network.deactivate(p)
except Exception as e:
errors.display(e, f"deactivating extra network {extra_network.name}")
for extra_network_name, extra_network in extra_network_registry.items():
if extra_network in data:
continue
try:
extra_network.deactivate(p)
except Exception as e:
errors.display(e, f"deactivating unmentioned extra network {extra_network_name}")
re_extra_net = re.compile(r"<(\w+):([^>]+)>")
def parse_prompt(prompt):
res = defaultdict(list)
def found(m):
name = m.group(1)
args = m.group(2)
res[name].append(ExtraNetworkParams(items=args.split(":")))
return ""
prompt = re.sub(re_extra_net, found, prompt)
return prompt, res
def parse_prompts(prompts):
res = []
extra_data = None
for prompt in prompts:
updated_prompt, parsed_extra_data = parse_prompt(prompt)
if extra_data is None:
extra_data = parsed_extra_data
res.append(updated_prompt)
return res, extra_data
def get_user_metadata(filename, lister=None):
if filename is None:
return {}
basename, ext = os.path.splitext(filename)
metadata_filename = basename + '.json'
metadata = {}
try:
exists = lister.exists(metadata_filename) if lister else os.path.exists(metadata_filename)
if exists:
with open(metadata_filename, "r", encoding="utf8") as file:
metadata = json.load(file)
except Exception as e:
errors.display(e, f"reading extra network user metadata from {metadata_filename}")
return metadata