Implement extension script metadata and a DAG-based script load-order sorting mechanism

This commit is contained in:
wfjsw 2023-11-11 04:01:13 -06:00
parent 5e80d9ee99
commit 0fc7dc1c04
2 changed files with 196 additions and 23 deletions

View File

@ -1,3 +1,5 @@
import configparser
import functools
import os import os
import threading import threading
@ -23,8 +25,9 @@ class Extension:
lock = threading.Lock() lock = threading.Lock()
cached_fields = ['remote', 'commit_date', 'branch', 'commit_hash', 'version'] cached_fields = ['remote', 'commit_date', 'branch', 'commit_hash', 'version']
def __init__(self, name, path, enabled=True, is_builtin=False): def __init__(self, name, path, enabled=True, is_builtin=False, canonical_name=None):
self.name = name self.name = name
self.canonical_name = canonical_name or name.lower()
self.path = path self.path = path
self.enabled = enabled self.enabled = enabled
self.status = '' self.status = ''
@ -37,6 +40,17 @@ class Extension:
self.remote = None self.remote = None
self.have_info_from_repo = False self.have_info_from_repo = False
@functools.cached_property
def metadata(self):
if os.path.isfile(os.path.join(self.path, "sd_webui_metadata.ini")):
try:
config = configparser.ConfigParser()
config.read(os.path.join(self.path, "sd_webui_metadata.ini"))
return config
except Exception:
errors.report(f"Error reading sd_webui_metadata.ini for extension {self.canonical_name}.", exc_info=True)
return None
def to_dict(self): def to_dict(self):
return {x: getattr(self, x) for x in self.cached_fields} return {x: getattr(self, x) for x in self.cached_fields}
@ -136,9 +150,6 @@ class Extension:
def list_extensions():
    """Scan the extension directories, read per-extension metadata, resolve
    declared dependencies and populate the global ``extensions`` list.

    Extensions whose requirements are missing or disabled are still created,
    but with ``enabled=False``.
    """
    extensions.clear()

    if shared.cmd_opts.disable_all_extensions:
        print("*** \"--disable-all-extensions\" arg was used, will not load any extensions ***")
    elif shared.opts.disable_all_extensions == "all":
        print("*** \"Disable all extensions\" option was set, will not load any extensions ***")
    elif shared.opts.disable_all_extensions == "extra":
        print("*** \"Disable all extensions\" option was set, will only load built-in extensions ***")

    # canonical_name -> {dirname, path, requires, is_builtin}
    extension_dependency_map = {}

    # scan through extensions directory and load metadata
    for dirname in [extensions_dir, extensions_builtin_dir]:
        if not os.path.isdir(dirname):
            continue

        for extension_dirname in sorted(os.listdir(dirname)):
            path = os.path.join(dirname, extension_dirname)
            if not os.path.isdir(path):
                continue

            canonical_name = extension_dirname
            requires = None

            metadata_path = os.path.join(path, "sd_webui_metadata.ini")
            if os.path.isfile(metadata_path):
                try:
                    config = configparser.ConfigParser()
                    config.read(metadata_path)
                    canonical_name = config.get("Extension", "Name", fallback=canonical_name)
                    requires = config.get("Extension", "Requires", fallback=None)
                    # FIX: the original `continue` here skipped registration for every
                    # extension that shipped metadata, so they were never loaded.
                except Exception:
                    errors.report(f"Error reading sd_webui_metadata.ini for extension {extension_dirname}. "
                                  f"Will load regardless.", exc_info=True)

            # wash the data to lowercase and strip whitespace just in case
            canonical_name = canonical_name.lower().strip()

            # check for duplicated canonical names
            if canonical_name in extension_dependency_map:
                errors.report(f"Duplicate canonical name \"{canonical_name}\" found in extensions "
                              f"\"{extension_dirname}\" and \"{extension_dependency_map[canonical_name]['dirname']}\". "
                              f"The current loading extension will be discarded.", exc_info=False)
                continue

            requires = [x.strip() for x in requires.lower().split(',')] if requires else []

            extension_dependency_map[canonical_name] = {
                "dirname": extension_dirname,
                "path": path,
                "requires": requires,
                # FIX: builtin-ness must be decided here, where `dirname` is still the
                # parent directory; the original compared the extension's own directory
                # name to the builtin path later on, which was always False.
                "is_builtin": dirname == extensions_builtin_dir,
            }

    # check that every declared requirement is installed and enabled
    for canonical_name, extension_data in extension_dependency_map.items():
        dirname, path, requires = extension_data['dirname'], extension_data['path'], extension_data['requires']

        requirement_met = True
        for req in requires:
            if req not in extension_dependency_map:
                errors.report(f"Extension \"{dirname}\" requires \"{req}\" which is not installed. "
                              f"The current loading extension will be discarded.", exc_info=False)
                requirement_met = False
                break
            dep_dirname = extension_dependency_map[req]['dirname']
            if dep_dirname in shared.opts.disabled_extensions:
                errors.report(f"Extension \"{dirname}\" requires \"{dep_dirname}\" which is disabled. "
                              f"The current loading extension will be discarded.", exc_info=False)
                requirement_met = False
                break

        extension = Extension(name=dirname, path=path,
                              enabled=dirname not in shared.opts.disabled_extensions and requirement_met,
                              is_builtin=extension_data['is_builtin'],
                              # pass the resolved canonical name through so scripts can
                              # reference this extension via `ext:<canonical_name>`
                              canonical_name=canonical_name)
        extensions.append(extension)

View File

@ -2,6 +2,7 @@ import os
import re import re
import sys import sys
import inspect import inspect
from graphlib import TopologicalSorter, CycleError
from collections import namedtuple from collections import namedtuple
from dataclasses import dataclass from dataclasses import dataclass
@ -314,15 +315,131 @@ ScriptClassData = namedtuple("ScriptClassData", ["script_class", "path", "basedi
def list_scripts(scriptdirname, extension, *, include_extensions=True):
    """Collect script files from the webui and active extensions, ordered by
    the dependency constraints declared in extension metadata.

    Args:
        scriptdirname: subdirectory name to scan (e.g. "scripts").
        extension: required file extension, including the dot (e.g. ".py").
        include_extensions: also scan active extensions when True.

    Returns a list of ScriptFile in dependency-resolved load order.
    """
    scripts_list = []
    # script filename -> {extension, extension_dirname, script_file, requires, load_before, load_after}
    script_dependency_map = {}

    # build script dependency map: builtin webui scripts first, no constraints
    root_script_basedir = os.path.join(paths.script_path, scriptdirname)
    if os.path.exists(root_script_basedir):
        for filename in sorted(os.listdir(root_script_basedir)):
            script_dependency_map[filename] = {
                "extension": None,
                "extension_dirname": None,
                "script_file": ScriptFile(paths.script_path, filename, os.path.join(root_script_basedir, filename)),
                "requires": [],
                "load_before": [],
                "load_after": [],
            }

    if include_extensions:
        for ext in extensions.active():
            extension_scripts_list = ext.list_files(scriptdirname, extension)
            for extension_script in extension_scripts_list:
                # scripts are keyed by bare filename, so names must be unique;
                # a collision would shadow one script, so the later one is dropped
                if extension_script.filename in script_dependency_map:
                    errors.report(f"Duplicate script name \"{extension_script.filename}\" found in extensions "
                                  f"\"{ext.name}\" and \"{script_dependency_map[extension_script.filename]['extension_dirname'] or 'builtin'}\". "
                                  f"The current loading file will be discarded.", exc_info=False)
                    continue

                # metadata sections are keyed by forward-slash relative path
                relative_path = scriptdirname + "/" + extension_script.filename

                requires = None
                load_before = None
                load_after = None
                if ext.metadata is not None:
                    requires = ext.metadata.get(relative_path, "Requires", fallback=None)
                    load_before = ext.metadata.get(relative_path, "Before", fallback=None)
                    load_after = ext.metadata.get(relative_path, "After", fallback=None)

                requires = [x.strip() for x in requires.split(',')] if requires else []
                load_after = [x.strip() for x in load_after.split(',')] if load_after else []
                load_before = [x.strip() for x in load_before.split(',')] if load_before else []

                script_dependency_map[extension_script.filename] = {
                    "extension": ext.canonical_name,
                    "extension_dirname": ext.name,
                    "script_file": extension_script,
                    "requires": requires,
                    "load_before": load_before,
                    "load_after": load_after,
                }

    # set of canonical extension names that contributed at least one script
    loaded_extensions = {data['extension'] for data in script_dependency_map.values() if data['extension'] is not None}

    for script_filename, script_data in script_dependency_map.items():
        # "Before" is an inverse dependency: record this script in the
        # load_after list of the target(s)
        for load_before_script in script_data['load_before']:
            if load_before_script.startswith('ext:'):
                required_extension = load_before_script[4:]
                for script_data2 in script_dependency_map.values():
                    if script_data2['extension'] == required_extension:
                        script_data2['load_after'].append(script_filename)
                        # NOTE(review): only the first script of the extension is
                        # pinned after this one — presumably intentional; confirm
                        # whether "before ext:X" should cover all of X's scripts.
                        break
            else:
                if load_before_script in script_dependency_map:
                    script_dependency_map[load_before_script]['load_after'].append(script_filename)

        # FIX: resolve `ext:` entries into concrete script names by building a new
        # list; the original appended to load_after while iterating it.
        resolved_load_after = []
        for load_after_script in script_data['load_after']:
            if load_after_script.startswith('ext:'):
                required_extension = load_after_script[4:]
                for script_file_name2, script_data2 in script_dependency_map.items():
                    if script_data2['extension'] == required_extension:
                        resolved_load_after.append(script_file_name2)
            else:
                resolved_load_after.append(load_after_script)
        script_data['load_after'] = resolved_load_after

    # build the DAG, skipping scripts whose hard requirements are unmet
    sorter = TopologicalSorter()
    skipped_scripts = set()
    for script_filename, script_data in script_dependency_map.items():
        requirement_met = True
        for required_script in script_data['requires']:
            if required_script.startswith('ext:'):
                required_extension = required_script[4:]
                if required_extension not in loaded_extensions:
                    errors.report(f"Script \"{script_filename}\" requires extension \"{required_extension}\" to "
                                  f"be loaded, but it is not. Skipping.",
                                  exc_info=False)
                    requirement_met = False
                    break
            else:
                if required_script not in script_dependency_map:
                    errors.report(f"Script \"{script_filename}\" requires script \"{required_script}\" to "
                                  f"be loaded, but it is not. Skipping.",
                                  exc_info=False)
                    requirement_met = False
                    break

        if not requirement_met:
            skipped_scripts.add(script_filename)
            continue

        sorter.add(script_filename, *script_data['load_after'])

    # sort the scripts
    try:
        # FIX: static_order() is a generator — CycleError is raised on first
        # iteration, not at the call, so it must be materialized inside the try
        # for the fallback below to ever trigger.
        ordered_script = list(sorter.static_order())
    except CycleError:
        errors.report("Cycle detected in script dependencies. Scripts will load in ascending order.", exc_info=True)
        ordered_script = script_dependency_map.keys()

    for script_filename in ordered_script:
        # FIX: names that appear only as load_after predecessors surface in
        # static_order even when unknown (KeyError) or requirement-skipped
        # (would load anyway); guard against both.
        if script_filename not in script_dependency_map or script_filename in skipped_scripts:
            continue
        scripts_list.append(script_dependency_map[script_filename]['script_file'])

    scripts_list = [x for x in scripts_list if os.path.splitext(x.path)[1].lower() == extension and os.path.isfile(x.path)]
@ -365,15 +482,9 @@ def load_scripts():
elif issubclass(script_class, scripts_postprocessing.ScriptPostprocessing): elif issubclass(script_class, scripts_postprocessing.ScriptPostprocessing):
postprocessing_scripts_data.append(ScriptClassData(script_class, scriptfile.path, scriptfile.basedir, module)) postprocessing_scripts_data.append(ScriptClassData(script_class, scriptfile.path, scriptfile.basedir, module))
def orderby(basedir): # here the scripts_list is already ordered
# 1st webui, 2nd extensions-builtin, 3rd extensions # processing_script is not considered though
priority = {os.path.join(paths.script_path, "extensions-builtin"):1, paths.script_path:0} for scriptfile in scripts_list:
for key in priority:
if basedir.startswith(key):
return priority[key]
return 9999
for scriptfile in sorted(scripts_list, key=lambda x: [orderby(x.basedir), x]):
try: try:
if scriptfile.basedir != paths.script_path: if scriptfile.basedir != paths.script_path:
sys.path = [scriptfile.basedir] + sys.path sys.path = [scriptfile.basedir] + sys.path