mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-01-06 15:15:05 +08:00
implementing script metadata and DAG sorting mechanism
This commit is contained in:
parent
5e80d9ee99
commit
0fc7dc1c04
@ -1,3 +1,5 @@
|
|||||||
|
import configparser
|
||||||
|
import functools
|
||||||
import os
|
import os
|
||||||
import threading
|
import threading
|
||||||
|
|
||||||
@ -23,8 +25,9 @@ class Extension:
|
|||||||
lock = threading.Lock()
|
lock = threading.Lock()
|
||||||
cached_fields = ['remote', 'commit_date', 'branch', 'commit_hash', 'version']
|
cached_fields = ['remote', 'commit_date', 'branch', 'commit_hash', 'version']
|
||||||
|
|
||||||
def __init__(self, name, path, enabled=True, is_builtin=False):
|
def __init__(self, name, path, enabled=True, is_builtin=False, canonical_name=None):
|
||||||
self.name = name
|
self.name = name
|
||||||
|
self.canonical_name = canonical_name or name.lower()
|
||||||
self.path = path
|
self.path = path
|
||||||
self.enabled = enabled
|
self.enabled = enabled
|
||||||
self.status = ''
|
self.status = ''
|
||||||
@ -37,6 +40,17 @@ class Extension:
|
|||||||
self.remote = None
|
self.remote = None
|
||||||
self.have_info_from_repo = False
|
self.have_info_from_repo = False
|
||||||
|
|
||||||
|
@functools.cached_property
|
||||||
|
def metadata(self):
|
||||||
|
if os.path.isfile(os.path.join(self.path, "sd_webui_metadata.ini")):
|
||||||
|
try:
|
||||||
|
config = configparser.ConfigParser()
|
||||||
|
config.read(os.path.join(self.path, "sd_webui_metadata.ini"))
|
||||||
|
return config
|
||||||
|
except Exception:
|
||||||
|
errors.report(f"Error reading sd_webui_metadata.ini for extension {self.canonical_name}.", exc_info=True)
|
||||||
|
return None
|
||||||
|
|
||||||
def to_dict(self):
|
def to_dict(self):
|
||||||
return {x: getattr(self, x) for x in self.cached_fields}
|
return {x: getattr(self, x) for x in self.cached_fields}
|
||||||
|
|
||||||
@ -136,9 +150,6 @@ class Extension:
|
|||||||
def list_extensions():
|
def list_extensions():
|
||||||
extensions.clear()
|
extensions.clear()
|
||||||
|
|
||||||
if not os.path.isdir(extensions_dir):
|
|
||||||
return
|
|
||||||
|
|
||||||
if shared.cmd_opts.disable_all_extensions:
|
if shared.cmd_opts.disable_all_extensions:
|
||||||
print("*** \"--disable-all-extensions\" arg was used, will not load any extensions ***")
|
print("*** \"--disable-all-extensions\" arg was used, will not load any extensions ***")
|
||||||
elif shared.opts.disable_all_extensions == "all":
|
elif shared.opts.disable_all_extensions == "all":
|
||||||
@ -148,18 +159,69 @@ def list_extensions():
|
|||||||
elif shared.opts.disable_all_extensions == "extra":
|
elif shared.opts.disable_all_extensions == "extra":
|
||||||
print("*** \"Disable all extensions\" option was set, will only load built-in extensions ***")
|
print("*** \"Disable all extensions\" option was set, will only load built-in extensions ***")
|
||||||
|
|
||||||
extension_paths = []
|
extension_dependency_map = {}
|
||||||
|
|
||||||
|
# scan through extensions directory and load metadata
|
||||||
for dirname in [extensions_dir, extensions_builtin_dir]:
|
for dirname in [extensions_dir, extensions_builtin_dir]:
|
||||||
if not os.path.isdir(dirname):
|
if not os.path.isdir(dirname):
|
||||||
return
|
continue
|
||||||
|
|
||||||
for extension_dirname in sorted(os.listdir(dirname)):
|
for extension_dirname in sorted(os.listdir(dirname)):
|
||||||
path = os.path.join(dirname, extension_dirname)
|
path = os.path.join(dirname, extension_dirname)
|
||||||
if not os.path.isdir(path):
|
if not os.path.isdir(path):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
extension_paths.append((extension_dirname, path, dirname == extensions_builtin_dir))
|
canonical_name = extension_dirname
|
||||||
|
requires = None
|
||||||
|
|
||||||
for dirname, path, is_builtin in extension_paths:
|
if os.path.isfile(os.path.join(path, "sd_webui_metadata.ini")):
|
||||||
extension = Extension(name=dirname, path=path, enabled=dirname not in shared.opts.disabled_extensions, is_builtin=is_builtin)
|
try:
|
||||||
|
config = configparser.ConfigParser()
|
||||||
|
config.read(os.path.join(path, "sd_webui_metadata.ini"))
|
||||||
|
canonical_name = config.get("Extension", "Name", fallback=canonical_name)
|
||||||
|
requires = config.get("Extension", "Requires", fallback=None)
|
||||||
|
continue
|
||||||
|
except Exception:
|
||||||
|
errors.report(f"Error reading sd_webui_metadata.ini for extension {extension_dirname}. "
|
||||||
|
f"Will load regardless.", exc_info=True)
|
||||||
|
|
||||||
|
canonical_name = canonical_name.lower().strip()
|
||||||
|
|
||||||
|
# check for duplicated canonical names
|
||||||
|
if canonical_name in extension_dependency_map:
|
||||||
|
errors.report(f"Duplicate canonical name \"{canonical_name}\" found in extensions "
|
||||||
|
f"\"{extension_dirname}\" and \"{extension_dependency_map[canonical_name]['dirname']}\". "
|
||||||
|
f"The current loading extension will be discarded.", exc_info=False)
|
||||||
|
continue
|
||||||
|
|
||||||
|
# we want to wash the data to lowercase and remove whitespaces just in case
|
||||||
|
requires = [x.strip() for x in requires.lower().split(',')] if requires else []
|
||||||
|
|
||||||
|
extension_dependency_map[canonical_name] = {
|
||||||
|
"dirname": extension_dirname,
|
||||||
|
"path": path,
|
||||||
|
"requires": requires,
|
||||||
|
}
|
||||||
|
|
||||||
|
# check for requirements
|
||||||
|
for (_, extension_data) in extension_dependency_map.items():
|
||||||
|
dirname, path, requires = extension_data['dirname'], extension_data['path'], extension_data['requires']
|
||||||
|
requirement_met = True
|
||||||
|
for req in requires:
|
||||||
|
if req not in extension_dependency_map:
|
||||||
|
errors.report(f"Extension \"{dirname}\" requires \"{req}\" which is not installed. "
|
||||||
|
f"The current loading extension will be discarded.", exc_info=False)
|
||||||
|
requirement_met = False
|
||||||
|
break
|
||||||
|
dep_dirname = extension_dependency_map[req]['dirname']
|
||||||
|
if dep_dirname in shared.opts.disabled_extensions:
|
||||||
|
errors.report(f"Extension \"{dirname}\" requires \"{dep_dirname}\" which is disabled. "
|
||||||
|
f"The current loading extension will be discarded.", exc_info=False)
|
||||||
|
requirement_met = False
|
||||||
|
break
|
||||||
|
|
||||||
|
is_builtin = dirname == extensions_builtin_dir
|
||||||
|
extension = Extension(name=dirname, path=path,
|
||||||
|
enabled=dirname not in shared.opts.disabled_extensions and requirement_met,
|
||||||
|
is_builtin=is_builtin)
|
||||||
extensions.append(extension)
|
extensions.append(extension)
|
||||||
|
@ -2,6 +2,7 @@ import os
|
|||||||
import re
|
import re
|
||||||
import sys
|
import sys
|
||||||
import inspect
|
import inspect
|
||||||
|
from graphlib import TopologicalSorter, CycleError
|
||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
|
||||||
@ -314,15 +315,131 @@ ScriptClassData = namedtuple("ScriptClassData", ["script_class", "path", "basedi
|
|||||||
|
|
||||||
def list_scripts(scriptdirname, extension, *, include_extensions=True):
|
def list_scripts(scriptdirname, extension, *, include_extensions=True):
|
||||||
scripts_list = []
|
scripts_list = []
|
||||||
|
script_dependency_map = {}
|
||||||
|
|
||||||
basedir = os.path.join(paths.script_path, scriptdirname)
|
# build script dependency map
|
||||||
if os.path.exists(basedir):
|
|
||||||
for filename in sorted(os.listdir(basedir)):
|
root_script_basedir = os.path.join(paths.script_path, scriptdirname)
|
||||||
scripts_list.append(ScriptFile(paths.script_path, filename, os.path.join(basedir, filename)))
|
if os.path.exists(root_script_basedir):
|
||||||
|
for filename in sorted(os.listdir(root_script_basedir)):
|
||||||
|
script_dependency_map[filename] = {
|
||||||
|
"extension": None,
|
||||||
|
"extension_dirname": None,
|
||||||
|
"script_file": ScriptFile(paths.script_path, filename, os.path.join(root_script_basedir, filename)),
|
||||||
|
"requires": [],
|
||||||
|
"load_before": [],
|
||||||
|
"load_after": [],
|
||||||
|
}
|
||||||
|
|
||||||
if include_extensions:
|
if include_extensions:
|
||||||
for ext in extensions.active():
|
for ext in extensions.active():
|
||||||
scripts_list += ext.list_files(scriptdirname, extension)
|
extension_scripts_list = ext.list_files(scriptdirname, extension)
|
||||||
|
for extension_script in extension_scripts_list:
|
||||||
|
# this is built on the assumption that script name is unique.
|
||||||
|
# I think bad thing is gonna happen if name collide in the current implementation anyway, but we
|
||||||
|
# will need to refactor here if this assumption is broken later on.
|
||||||
|
if extension_script.filename in script_dependency_map:
|
||||||
|
errors.report(f"Duplicate script name \"{extension_script.filename}\" found in extensions "
|
||||||
|
f"\"{ext.name}\" and \"{script_dependency_map[extension_script.filename]['extension_dirname'] or 'builtin'}\". "
|
||||||
|
f"The current loading file will be discarded.", exc_info=False)
|
||||||
|
continue
|
||||||
|
|
||||||
|
relative_path = scriptdirname + "/" + extension_script.filename
|
||||||
|
|
||||||
|
requires = None
|
||||||
|
load_before = None
|
||||||
|
load_after = None
|
||||||
|
|
||||||
|
if ext.metadata is not None:
|
||||||
|
requires = ext.metadata.get(relative_path, "Requires", fallback=None)
|
||||||
|
load_before = ext.metadata.get(relative_path, "Before", fallback=None)
|
||||||
|
load_after = ext.metadata.get(relative_path, "After", fallback=None)
|
||||||
|
|
||||||
|
requires = [x.strip() for x in requires.split(',')] if requires else []
|
||||||
|
load_after = [x.strip() for x in load_after.split(',')] if load_after else []
|
||||||
|
load_before = [x.strip() for x in load_before.split(',')] if load_before else []
|
||||||
|
|
||||||
|
script_dependency_map[extension_script.filename] = {
|
||||||
|
"extension": ext.canonical_name,
|
||||||
|
"extension_dirname": ext.name,
|
||||||
|
"script_file": extension_script,
|
||||||
|
"requires": requires,
|
||||||
|
"load_before": load_before,
|
||||||
|
"load_after": load_after,
|
||||||
|
}
|
||||||
|
|
||||||
|
# resolve dependencies
|
||||||
|
|
||||||
|
loaded_extensions = set()
|
||||||
|
for _, script_data in script_dependency_map.items():
|
||||||
|
if script_data['extension'] is not None:
|
||||||
|
loaded_extensions.add(script_data['extension'])
|
||||||
|
|
||||||
|
for script_filename, script_data in script_dependency_map.items():
|
||||||
|
# load before requires inverse dependency
|
||||||
|
# in this case, append the script name into the load_after list of the specified script
|
||||||
|
for load_before_script in script_data['load_before']:
|
||||||
|
if load_before_script.startswith('ext:'):
|
||||||
|
# if this requires an extension to be loaded before
|
||||||
|
required_extension = load_before_script[4:]
|
||||||
|
for _, script_data2 in script_dependency_map.items():
|
||||||
|
if script_data2['extension'] == required_extension:
|
||||||
|
script_data2['load_after'].append(script_filename)
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
# if this requires an individual script to be loaded before
|
||||||
|
if load_before_script in script_dependency_map:
|
||||||
|
script_dependency_map[load_before_script]['load_after'].append(script_filename)
|
||||||
|
|
||||||
|
# resolve extension name in load_after lists
|
||||||
|
for load_after_script in script_data['load_after']:
|
||||||
|
if load_after_script.startswith('ext:'):
|
||||||
|
# if this requires an extension to be loaded after
|
||||||
|
required_extension = load_after_script[4:]
|
||||||
|
for script_file_name2, script_data2 in script_dependency_map.items():
|
||||||
|
if script_data2['extension'] == required_extension:
|
||||||
|
script_data['load_after'].append(script_file_name2)
|
||||||
|
|
||||||
|
# remove all extension names in load_after lists
|
||||||
|
script_data['load_after'] = [x for x in script_data['load_after'] if not x.startswith('ext:')]
|
||||||
|
|
||||||
|
# build the DAG
|
||||||
|
sorter = TopologicalSorter()
|
||||||
|
for script_filename, script_data in script_dependency_map.items():
|
||||||
|
requirement_met = True
|
||||||
|
for required_script in script_data['requires']:
|
||||||
|
if required_script.startswith('ext:'):
|
||||||
|
# if this requires an extension to be installed
|
||||||
|
required_extension = required_script[4:]
|
||||||
|
if required_extension not in loaded_extensions:
|
||||||
|
errors.report(f"Script \"{script_filename}\" requires extension \"{required_extension}\" to "
|
||||||
|
f"be loaded, but it is not. Skipping.",
|
||||||
|
exc_info=False)
|
||||||
|
requirement_met = False
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
# if this requires an individual script to be loaded
|
||||||
|
if required_script not in script_dependency_map:
|
||||||
|
errors.report(f"Script \"{script_filename}\" requires script \"{required_script}\" to "
|
||||||
|
f"be loaded, but it is not. Skipping.",
|
||||||
|
exc_info=False)
|
||||||
|
requirement_met = False
|
||||||
|
break
|
||||||
|
if not requirement_met:
|
||||||
|
continue
|
||||||
|
|
||||||
|
sorter.add(script_filename, *script_data['load_after'])
|
||||||
|
|
||||||
|
# sort the scripts
|
||||||
|
try:
|
||||||
|
ordered_script = sorter.static_order()
|
||||||
|
except CycleError:
|
||||||
|
errors.report("Cycle detected in script dependencies. Scripts will load in ascending order.", exc_info=True)
|
||||||
|
ordered_script = script_dependency_map.keys()
|
||||||
|
|
||||||
|
for script_filename in ordered_script:
|
||||||
|
script_data = script_dependency_map[script_filename]
|
||||||
|
scripts_list.append(script_data['script_file'])
|
||||||
|
|
||||||
scripts_list = [x for x in scripts_list if os.path.splitext(x.path)[1].lower() == extension and os.path.isfile(x.path)]
|
scripts_list = [x for x in scripts_list if os.path.splitext(x.path)[1].lower() == extension and os.path.isfile(x.path)]
|
||||||
|
|
||||||
@ -365,15 +482,9 @@ def load_scripts():
|
|||||||
elif issubclass(script_class, scripts_postprocessing.ScriptPostprocessing):
|
elif issubclass(script_class, scripts_postprocessing.ScriptPostprocessing):
|
||||||
postprocessing_scripts_data.append(ScriptClassData(script_class, scriptfile.path, scriptfile.basedir, module))
|
postprocessing_scripts_data.append(ScriptClassData(script_class, scriptfile.path, scriptfile.basedir, module))
|
||||||
|
|
||||||
def orderby(basedir):
|
# here the scripts_list is already ordered
|
||||||
# 1st webui, 2nd extensions-builtin, 3rd extensions
|
# processing_script is not considered though
|
||||||
priority = {os.path.join(paths.script_path, "extensions-builtin"):1, paths.script_path:0}
|
for scriptfile in scripts_list:
|
||||||
for key in priority:
|
|
||||||
if basedir.startswith(key):
|
|
||||||
return priority[key]
|
|
||||||
return 9999
|
|
||||||
|
|
||||||
for scriptfile in sorted(scripts_list, key=lambda x: [orderby(x.basedir), x]):
|
|
||||||
try:
|
try:
|
||||||
if scriptfile.basedir != paths.script_path:
|
if scriptfile.basedir != paths.script_path:
|
||||||
sys.path = [scriptfile.basedir] + sys.path
|
sys.path = [scriptfile.basedir] + sys.path
|
||||||
|
Loading…
Reference in New Issue
Block a user