allow variants for extension name in metadata.ini

This commit is contained in:
Andray 2024-03-17 12:51:40 +04:00
parent c95c46004a
commit b1cd0189bc

View File

@ -10,6 +10,10 @@ from modules import shared, errors, cache, scripts
from modules.gitpython_hack import Repo from modules.gitpython_hack import Repo
from modules.paths_internal import extensions_dir, extensions_builtin_dir, script_path # noqa: F401 from modules.paths_internal import extensions_dir, extensions_builtin_dir, script_path # noqa: F401
extensions: list[Extension] = []
extension_paths: dict[str, Extension] = {}
loaded_extensions: dict[str, Exception] = {}
os.makedirs(extensions_dir, exist_ok=True) os.makedirs(extensions_dir, exist_ok=True)
@ -50,7 +54,7 @@ class ExtensionMetadata:
self.canonical_name = self.config.get("Extension", "Name", fallback=canonical_name) self.canonical_name = self.config.get("Extension", "Name", fallback=canonical_name)
self.canonical_name = canonical_name.lower().strip() self.canonical_name = canonical_name.lower().strip()
self.requires = self.get_script_requirements("Requires", "Extension") self.requires = None
def get_script_requirements(self, field, section, extra_section=None): def get_script_requirements(self, field, section, extra_section=None):
"""reads a list of requirements from the config; field is the name of the field in the ini file, """reads a list of requirements from the config; field is the name of the field in the ini file,
@ -62,7 +66,33 @@ class ExtensionMetadata:
if extra_section: if extra_section:
x = x + ', ' + self.config.get(extra_section, field, fallback='') x = x + ', ' + self.config.get(extra_section, field, fallback='')
return self.parse_list(x.lower()) tmp_list = self.parse_list(x.lower())
if len(tmp_list) >= 3:
names_variants = []
i = 0
while i < len(tmp_list) - 2:
if tmp_list[i] != "|":
names_variants.append([tmp_list[i]])
i += 1
else:
names_variants[-1].append(tmp_list[i + 1])
i += 2
while i < len(tmp_list):
names_variants.append([tmp_list[i]])
i += 1
result_list = []
for name_variants in names_variants:
for variant in name_variants:
if variant in loaded_extensions.keys():
break
result_list.append(variant)
else:
result_list = tmp_list
return result_list
def parse_list(self, text): def parse_list(self, text):
"""converts a line from config ("ext1 ext2, ext3 ") into a python list (["ext1", "ext2", "ext3"])""" """converts a line from config ("ext1 ext2, ext3 ") into a python list (["ext1", "ext2", "ext3"])"""
@ -213,6 +243,7 @@ class Extension:
def list_extensions(): def list_extensions():
extensions.clear() extensions.clear()
extension_paths.clear() extension_paths.clear()
loaded_extensions.clear()
if shared.cmd_opts.disable_all_extensions: if shared.cmd_opts.disable_all_extensions:
print("*** \"--disable-all-extensions\" arg was used, will not load any extensions ***") print("*** \"--disable-all-extensions\" arg was used, will not load any extensions ***")
@ -223,7 +254,6 @@ def list_extensions():
elif shared.opts.disable_all_extensions == "extra": elif shared.opts.disable_all_extensions == "extra":
print("*** \"Disable all extensions\" option was set, will only load built-in extensions ***") print("*** \"Disable all extensions\" option was set, will only load built-in extensions ***")
loaded_extensions = {}
# scan through extensions directory and load metadata # scan through extensions directory and load metadata
for dirname in [extensions_builtin_dir, extensions_dir]: for dirname in [extensions_builtin_dir, extensions_dir]:
@ -250,6 +280,9 @@ def list_extensions():
extension_paths[extension.path] = extension extension_paths[extension.path] = extension
loaded_extensions[canonical_name] = extension loaded_extensions[canonical_name] = extension
for extension in extensions:
extension.metadata.requires = extension.metadata.get_script_requirements("Requires", "Extension")
# check for requirements # check for requirements
for extension in extensions: for extension in extensions:
if not extension.enabled: if not extension.enabled:
@ -279,6 +312,3 @@ def find_extension(filename):
return None return None
extensions: list[Extension] = []
extension_paths: dict[str, Extension] = {}