mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-03-09 23:44:55 +08:00
add support for specifying callback order in metadata
This commit is contained in:
parent
7e5e67330b
commit
2f55d669a2
@ -1,6 +1,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import configparser
|
import configparser
|
||||||
|
import dataclasses
|
||||||
import os
|
import os
|
||||||
import threading
|
import threading
|
||||||
import re
|
import re
|
||||||
@ -22,6 +23,13 @@ def active():
|
|||||||
return [x for x in extensions if x.enabled]
|
return [x for x in extensions if x.enabled]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass
|
||||||
|
class CallbackOrderInfo:
|
||||||
|
name: str
|
||||||
|
before: list
|
||||||
|
after: list
|
||||||
|
|
||||||
|
|
||||||
class ExtensionMetadata:
|
class ExtensionMetadata:
|
||||||
filename = "metadata.ini"
|
filename = "metadata.ini"
|
||||||
config: configparser.ConfigParser
|
config: configparser.ConfigParser
|
||||||
@ -65,6 +73,22 @@ class ExtensionMetadata:
|
|||||||
# both "," and " " are accepted as separator
|
# both "," and " " are accepted as separator
|
||||||
return [x for x in re.split(r"[,\s]+", text.strip()) if x]
|
return [x for x in re.split(r"[,\s]+", text.strip()) if x]
|
||||||
|
|
||||||
|
def list_callback_order_instructions(self):
|
||||||
|
for section in self.config.sections():
|
||||||
|
if not section.startswith("callbacks/"):
|
||||||
|
continue
|
||||||
|
|
||||||
|
callback_name = section[10:]
|
||||||
|
|
||||||
|
if not callback_name.startswith(self.canonical_name):
|
||||||
|
errors.report(f"Callback order section for extension {self.canonical_name} is referencing the wrong extension: {section}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
before = self.parse_list(self.config.get(section, 'Before', fallback=''))
|
||||||
|
after = self.parse_list(self.config.get(section, 'After', fallback=''))
|
||||||
|
|
||||||
|
yield CallbackOrderInfo(callback_name, before, after)
|
||||||
|
|
||||||
|
|
||||||
class Extension:
|
class Extension:
|
||||||
lock = threading.Lock()
|
lock = threading.Lock()
|
||||||
|
@ -8,7 +8,7 @@ from typing import Optional, Any
|
|||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
from gradio import Blocks
|
from gradio import Blocks
|
||||||
|
|
||||||
from modules import errors, timer, extensions, shared
|
from modules import errors, timer, extensions, shared, util
|
||||||
|
|
||||||
|
|
||||||
def report_exception(c, job):
|
def report_exception(c, job):
|
||||||
@ -149,6 +149,38 @@ def add_callback(callbacks, fun, *, name=None, category='unknown', filename=None
|
|||||||
|
|
||||||
def sort_callbacks(category, unordered_callbacks, *, enable_user_sort=True):
|
def sort_callbacks(category, unordered_callbacks, *, enable_user_sort=True):
|
||||||
callbacks = unordered_callbacks.copy()
|
callbacks = unordered_callbacks.copy()
|
||||||
|
callback_lookup = {x.name: x for x in callbacks}
|
||||||
|
dependencies = {}
|
||||||
|
|
||||||
|
order_instructions = {}
|
||||||
|
for extension in extensions.extensions:
|
||||||
|
for order_instruction in extension.metadata.list_callback_order_instructions():
|
||||||
|
if order_instruction.name in callback_lookup:
|
||||||
|
if order_instruction.name not in order_instructions:
|
||||||
|
order_instructions[order_instruction.name] = []
|
||||||
|
|
||||||
|
order_instructions[order_instruction.name].append(order_instruction)
|
||||||
|
|
||||||
|
if order_instructions:
|
||||||
|
for callback in callbacks:
|
||||||
|
dependencies[callback.name] = []
|
||||||
|
|
||||||
|
for callback in callbacks:
|
||||||
|
for order_instruction in order_instructions.get(callback.name, []):
|
||||||
|
for after in order_instruction.after:
|
||||||
|
if after not in callback_lookup:
|
||||||
|
continue
|
||||||
|
|
||||||
|
dependencies[callback.name].append(after)
|
||||||
|
|
||||||
|
for before in order_instruction.before:
|
||||||
|
if before not in callback_lookup:
|
||||||
|
continue
|
||||||
|
|
||||||
|
dependencies[before].append(callback.name)
|
||||||
|
|
||||||
|
sorted_names = util.topological_sort(dependencies)
|
||||||
|
callbacks = [callback_lookup[x] for x in sorted_names]
|
||||||
|
|
||||||
if enable_user_sort:
|
if enable_user_sort:
|
||||||
for name in reversed(getattr(shared.opts, 'prioritized_callbacks_' + category, [])):
|
for name in reversed(getattr(shared.opts, 'prioritized_callbacks_' + category, [])):
|
||||||
|
@ -7,7 +7,9 @@ from dataclasses import dataclass
|
|||||||
|
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
|
|
||||||
from modules import shared, paths, script_callbacks, extensions, script_loading, scripts_postprocessing, errors, timer
|
from modules import shared, paths, script_callbacks, extensions, script_loading, scripts_postprocessing, errors, timer, util
|
||||||
|
|
||||||
|
topological_sort = util.topological_sort
|
||||||
|
|
||||||
AlwaysVisible = object()
|
AlwaysVisible = object()
|
||||||
|
|
||||||
@ -368,29 +370,6 @@ scripts_data = []
|
|||||||
postprocessing_scripts_data = []
|
postprocessing_scripts_data = []
|
||||||
ScriptClassData = namedtuple("ScriptClassData", ["script_class", "path", "basedir", "module"])
|
ScriptClassData = namedtuple("ScriptClassData", ["script_class", "path", "basedir", "module"])
|
||||||
|
|
||||||
def topological_sort(dependencies):
|
|
||||||
"""Accepts a dictionary mapping name to its dependencies, returns a list of names ordered according to dependencies.
|
|
||||||
Ignores errors relating to missing dependeencies or circular dependencies
|
|
||||||
"""
|
|
||||||
|
|
||||||
visited = {}
|
|
||||||
result = []
|
|
||||||
|
|
||||||
def inner(name):
|
|
||||||
visited[name] = True
|
|
||||||
|
|
||||||
for dep in dependencies.get(name, []):
|
|
||||||
if dep in dependencies and dep not in visited:
|
|
||||||
inner(dep)
|
|
||||||
|
|
||||||
result.append(name)
|
|
||||||
|
|
||||||
for depname in dependencies:
|
|
||||||
if depname not in visited:
|
|
||||||
inner(depname)
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ScriptWithDependencies:
|
class ScriptWithDependencies:
|
||||||
|
@ -136,3 +136,27 @@ class MassFileLister:
|
|||||||
def reset(self):
|
def reset(self):
|
||||||
"""Clear the cache of all directories."""
|
"""Clear the cache of all directories."""
|
||||||
self.cached_dirs.clear()
|
self.cached_dirs.clear()
|
||||||
|
|
||||||
|
|
||||||
|
def topological_sort(dependencies):
|
||||||
|
"""Accepts a dictionary mapping name to its dependencies, returns a list of names ordered according to dependencies.
|
||||||
|
Ignores errors relating to missing dependeencies or circular dependencies
|
||||||
|
"""
|
||||||
|
|
||||||
|
visited = {}
|
||||||
|
result = []
|
||||||
|
|
||||||
|
def inner(name):
|
||||||
|
visited[name] = True
|
||||||
|
|
||||||
|
for dep in dependencies.get(name, []):
|
||||||
|
if dep in dependencies and dep not in visited:
|
||||||
|
inner(dep)
|
||||||
|
|
||||||
|
result.append(name)
|
||||||
|
|
||||||
|
for depname in dependencies:
|
||||||
|
if depname not in visited:
|
||||||
|
inner(depname)
|
||||||
|
|
||||||
|
return result
|
||||||
|
Loading…
Reference in New Issue
Block a user