add support for specifying callback order in metadata

This commit is contained in:
AUTOMATIC1111 2024-03-10 15:14:04 +03:00
parent 7e5e67330b
commit 2f55d669a2
4 changed files with 84 additions and 25 deletions

View File

@ -1,6 +1,7 @@
from __future__ import annotations from __future__ import annotations
import configparser import configparser
import dataclasses
import os import os
import threading import threading
import re import re
@ -22,6 +23,13 @@ def active():
return [x for x in extensions if x.enabled] return [x for x in extensions if x.enabled]
@dataclasses.dataclass
class CallbackOrderInfo:
name: str
before: list
after: list
class ExtensionMetadata: class ExtensionMetadata:
filename = "metadata.ini" filename = "metadata.ini"
config: configparser.ConfigParser config: configparser.ConfigParser
@ -65,6 +73,22 @@ class ExtensionMetadata:
# both "," and " " are accepted as separator # both "," and " " are accepted as separator
return [x for x in re.split(r"[,\s]+", text.strip()) if x] return [x for x in re.split(r"[,\s]+", text.strip()) if x]
def list_callback_order_instructions(self):
for section in self.config.sections():
if not section.startswith("callbacks/"):
continue
callback_name = section[10:]
if not callback_name.startswith(self.canonical_name):
errors.report(f"Callback order section for extension {self.canonical_name} is referencing the wrong extension: {section}")
continue
before = self.parse_list(self.config.get(section, 'Before', fallback=''))
after = self.parse_list(self.config.get(section, 'After', fallback=''))
yield CallbackOrderInfo(callback_name, before, after)
class Extension: class Extension:
lock = threading.Lock() lock = threading.Lock()

View File

@ -8,7 +8,7 @@ from typing import Optional, Any
from fastapi import FastAPI from fastapi import FastAPI
from gradio import Blocks from gradio import Blocks
from modules import errors, timer, extensions, shared from modules import errors, timer, extensions, shared, util
def report_exception(c, job): def report_exception(c, job):
@ -149,6 +149,38 @@ def add_callback(callbacks, fun, *, name=None, category='unknown', filename=None
def sort_callbacks(category, unordered_callbacks, *, enable_user_sort=True): def sort_callbacks(category, unordered_callbacks, *, enable_user_sort=True):
callbacks = unordered_callbacks.copy() callbacks = unordered_callbacks.copy()
callback_lookup = {x.name: x for x in callbacks}
dependencies = {}
order_instructions = {}
for extension in extensions.extensions:
for order_instruction in extension.metadata.list_callback_order_instructions():
if order_instruction.name in callback_lookup:
if order_instruction.name not in order_instructions:
order_instructions[order_instruction.name] = []
order_instructions[order_instruction.name].append(order_instruction)
if order_instructions:
for callback in callbacks:
dependencies[callback.name] = []
for callback in callbacks:
for order_instruction in order_instructions.get(callback.name, []):
for after in order_instruction.after:
if after not in callback_lookup:
continue
dependencies[callback.name].append(after)
for before in order_instruction.before:
if before not in callback_lookup:
continue
dependencies[before].append(callback.name)
sorted_names = util.topological_sort(dependencies)
callbacks = [callback_lookup[x] for x in sorted_names]
if enable_user_sort: if enable_user_sort:
for name in reversed(getattr(shared.opts, 'prioritized_callbacks_' + category, [])): for name in reversed(getattr(shared.opts, 'prioritized_callbacks_' + category, [])):

View File

@ -7,7 +7,9 @@ from dataclasses import dataclass
import gradio as gr import gradio as gr
from modules import shared, paths, script_callbacks, extensions, script_loading, scripts_postprocessing, errors, timer from modules import shared, paths, script_callbacks, extensions, script_loading, scripts_postprocessing, errors, timer, util
topological_sort = util.topological_sort
AlwaysVisible = object() AlwaysVisible = object()
@ -368,29 +370,6 @@ scripts_data = []
postprocessing_scripts_data = [] postprocessing_scripts_data = []
ScriptClassData = namedtuple("ScriptClassData", ["script_class", "path", "basedir", "module"]) ScriptClassData = namedtuple("ScriptClassData", ["script_class", "path", "basedir", "module"])
def topological_sort(dependencies):
"""Accepts a dictionary mapping name to its dependencies, returns a list of names ordered according to dependencies.
Ignores errors relating to missing dependeencies or circular dependencies
"""
visited = {}
result = []
def inner(name):
visited[name] = True
for dep in dependencies.get(name, []):
if dep in dependencies and dep not in visited:
inner(dep)
result.append(name)
for depname in dependencies:
if depname not in visited:
inner(depname)
return result
@dataclass @dataclass
class ScriptWithDependencies: class ScriptWithDependencies:

View File

@ -136,3 +136,27 @@ class MassFileLister:
def reset(self): def reset(self):
"""Clear the cache of all directories.""" """Clear the cache of all directories."""
self.cached_dirs.clear() self.cached_dirs.clear()
def topological_sort(dependencies):
"""Accepts a dictionary mapping name to its dependencies, returns a list of names ordered according to dependencies.
Ignores errors relating to missing dependeencies or circular dependencies
"""
visited = {}
result = []
def inner(name):
visited[name] = True
for dep in dependencies.get(name, []):
if dep in dependencies and dep not in visited:
inner(dep)
result.append(name)
for depname in dependencies:
if depname not in visited:
inner(depname)
return result