From b1cd0189bc30ed7c8dca6fb7bc3d248a056f3e15 Mon Sep 17 00:00:00 2001 From: Andray Date: Sun, 17 Mar 2024 12:51:40 +0400 Subject: [PATCH] allow variants for extension name in metadata.ini --- modules/extensions.py | 42 ++++++++++++++++++++++++++++++++++++------ 1 file changed, 36 insertions(+), 6 deletions(-) diff --git a/modules/extensions.py b/modules/extensions.py index 88a389388..6a3c6c7ee 100644 --- a/modules/extensions.py +++ b/modules/extensions.py @@ -10,6 +10,10 @@ from modules import shared, errors, cache, scripts from modules.gitpython_hack import Repo from modules.paths_internal import extensions_dir, extensions_builtin_dir, script_path # noqa: F401 +extensions: list[Extension] = [] +extension_paths: dict[str, Extension] = {} +loaded_extensions: dict[str, Exception] = {} + os.makedirs(extensions_dir, exist_ok=True) @@ -50,7 +54,7 @@ class ExtensionMetadata: self.canonical_name = self.config.get("Extension", "Name", fallback=canonical_name) self.canonical_name = canonical_name.lower().strip() - self.requires = self.get_script_requirements("Requires", "Extension") + self.requires = None def get_script_requirements(self, field, section, extra_section=None): """reads a list of requirements from the config; field is the name of the field in the ini file, @@ -62,7 +66,33 @@ class ExtensionMetadata: if extra_section: x = x + ', ' + self.config.get(extra_section, field, fallback='') - return self.parse_list(x.lower()) + tmp_list = self.parse_list(x.lower()) + + if len(tmp_list) >= 3: + names_variants = [] + i = 0 + while i < len(tmp_list) - 2: + if tmp_list[i] != "|": + names_variants.append([tmp_list[i]]) + i += 1 + else: + names_variants[-1].append(tmp_list[i + 1]) + i += 2 + while i < len(tmp_list): + names_variants.append([tmp_list[i]]) + i += 1 + + result_list = [] + + for name_variants in names_variants: + for variant in name_variants: + if variant in loaded_extensions.keys(): + break + result_list.append(variant) + else: + result_list = tmp_list + + return result_list def parse_list(self, text): """converts a line from config ("ext1 ext2, ext3 ") into a python list (["ext1", "ext2", "ext3"])""" @@ -213,6 +243,7 @@ class Extension: def list_extensions(): extensions.clear() extension_paths.clear() + loaded_extensions.clear() if shared.cmd_opts.disable_all_extensions: print("*** \"--disable-all-extensions\" arg was used, will not load any extensions ***") @@ -223,7 +254,6 @@ def list_extensions(): elif shared.opts.disable_all_extensions == "extra": print("*** \"Disable all extensions\" option was set, will only load built-in extensions ***") - loaded_extensions = {} # scan through extensions directory and load metadata for dirname in [extensions_builtin_dir, extensions_dir]: @@ -250,6 +280,9 @@ def list_extensions(): extension_paths[extension.path] = extension loaded_extensions[canonical_name] = extension + for extension in extensions: + extension.metadata.requires = extension.metadata.get_script_requirements("Requires", "Extension") + # check for requirements for extension in extensions: if not extension.enabled: @@ -279,6 +312,3 @@ def find_extension(filename): return None - -extensions: list[Extension] = [] -extension_paths: dict[str, Extension] = {}