Merge pull request #15205 from AUTOMATIC1111/callback_order

Callback order
This commit is contained in:
AUTOMATIC1111 2024-03-16 09:45:41 +03:00 committed by GitHub
commit 5bd2724765
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 361 additions and 121 deletions

View File

@ -1,6 +1,7 @@
from __future__ import annotations from __future__ import annotations
import configparser import configparser
import dataclasses
import os import os
import threading import threading
import re import re
@ -22,6 +23,13 @@ def active():
return [x for x in extensions if x.enabled] return [x for x in extensions if x.enabled]
@dataclasses.dataclass
class CallbackOrderInfo:
name: str
before: list
after: list
class ExtensionMetadata: class ExtensionMetadata:
filename = "metadata.ini" filename = "metadata.ini"
config: configparser.ConfigParser config: configparser.ConfigParser
@ -65,6 +73,22 @@ class ExtensionMetadata:
# both "," and " " are accepted as separator # both "," and " " are accepted as separator
return [x for x in re.split(r"[,\s]+", text.strip()) if x] return [x for x in re.split(r"[,\s]+", text.strip()) if x]
def list_callback_order_instructions(self):
for section in self.config.sections():
if not section.startswith("callbacks/"):
continue
callback_name = section[10:]
if not callback_name.startswith(self.canonical_name):
errors.report(f"Callback order section for extension {self.canonical_name} is referencing the wrong extension: {section}")
continue
before = self.parse_list(self.config.get(section, 'Before', fallback=''))
after = self.parse_list(self.config.get(section, 'After', fallback=''))
yield CallbackOrderInfo(callback_name, before, after)
class Extension: class Extension:
lock = threading.Lock() lock = threading.Lock()
@ -188,6 +212,7 @@ class Extension:
def list_extensions(): def list_extensions():
extensions.clear() extensions.clear()
extension_paths.clear()
if shared.cmd_opts.disable_all_extensions: if shared.cmd_opts.disable_all_extensions:
print("*** \"--disable-all-extensions\" arg was used, will not load any extensions ***") print("*** \"--disable-all-extensions\" arg was used, will not load any extensions ***")
@ -222,6 +247,7 @@ def list_extensions():
is_builtin = dirname == extensions_builtin_dir is_builtin = dirname == extensions_builtin_dir
extension = Extension(name=extension_dirname, path=path, enabled=extension_dirname not in shared.opts.disabled_extensions, is_builtin=is_builtin, metadata=metadata) extension = Extension(name=extension_dirname, path=path, enabled=extension_dirname not in shared.opts.disabled_extensions, is_builtin=is_builtin, metadata=metadata)
extensions.append(extension) extensions.append(extension)
extension_paths[extension.path] = extension
loaded_extensions[canonical_name] = extension loaded_extensions[canonical_name] = extension
# check for requirements # check for requirements
@ -240,4 +266,19 @@ def list_extensions():
continue continue
def find_extension(filename):
parentdir = os.path.dirname(os.path.realpath(filename))
while parentdir != filename:
extension = extension_paths.get(parentdir)
if extension is not None:
return extension
filename = parentdir
parentdir = os.path.dirname(filename)
return None
extensions: list[Extension] = [] extensions: list[Extension] = []
extension_paths: dict[str, Extension] = {}

View File

@ -1,13 +1,14 @@
from __future__ import annotations
import dataclasses import dataclasses
import inspect import inspect
import os import os
from collections import namedtuple
from typing import Optional, Any from typing import Optional, Any
from fastapi import FastAPI from fastapi import FastAPI
from gradio import Blocks from gradio import Blocks
from modules import errors, timer from modules import errors, timer, extensions, shared, util
def report_exception(c, job): def report_exception(c, job):
@ -116,7 +117,105 @@ class BeforeTokenCounterParams:
is_positive: bool = True is_positive: bool = True
ScriptCallback = namedtuple("ScriptCallback", ["script", "callback"]) @dataclasses.dataclass
class ScriptCallback:
script: str
callback: any
name: str = None
def add_callback(callbacks, fun, *, name=None, category='unknown', filename=None):
if filename is None:
stack = [x for x in inspect.stack() if x.filename != __file__]
filename = stack[0].filename if stack else 'unknown file'
extension = extensions.find_extension(filename)
extension_name = extension.canonical_name if extension else 'base'
callback_name = f"{extension_name}/{os.path.basename(filename)}/{category}"
if name is not None:
callback_name += f'/{name}'
unique_callback_name = callback_name
for index in range(1000):
existing = any(x.name == unique_callback_name for x in callbacks)
if not existing:
break
unique_callback_name = f'{callback_name}-{index+1}'
callbacks.append(ScriptCallback(filename, fun, unique_callback_name))
def sort_callbacks(category, unordered_callbacks, *, enable_user_sort=True):
callbacks = unordered_callbacks.copy()
callback_lookup = {x.name: x for x in callbacks}
dependencies = {}
order_instructions = {}
for extension in extensions.extensions:
for order_instruction in extension.metadata.list_callback_order_instructions():
if order_instruction.name in callback_lookup:
if order_instruction.name not in order_instructions:
order_instructions[order_instruction.name] = []
order_instructions[order_instruction.name].append(order_instruction)
if order_instructions:
for callback in callbacks:
dependencies[callback.name] = []
for callback in callbacks:
for order_instruction in order_instructions.get(callback.name, []):
for after in order_instruction.after:
if after not in callback_lookup:
continue
dependencies[callback.name].append(after)
for before in order_instruction.before:
if before not in callback_lookup:
continue
dependencies[before].append(callback.name)
sorted_names = util.topological_sort(dependencies)
callbacks = [callback_lookup[x] for x in sorted_names]
if enable_user_sort:
for name in reversed(getattr(shared.opts, 'prioritized_callbacks_' + category, [])):
index = next((i for i, callback in enumerate(callbacks) if callback.name == name), None)
if index is not None:
callbacks.insert(0, callbacks.pop(index))
return callbacks
def ordered_callbacks(category, unordered_callbacks=None, *, enable_user_sort=True):
if unordered_callbacks is None:
unordered_callbacks = callback_map.get('callbacks_' + category, [])
if not enable_user_sort:
return sort_callbacks(category, unordered_callbacks, enable_user_sort=False)
callbacks = ordered_callbacks_map.get(category)
if callbacks is not None and len(callbacks) == len(unordered_callbacks):
return callbacks
callbacks = sort_callbacks(category, unordered_callbacks)
ordered_callbacks_map[category] = callbacks
return callbacks
def enumerate_callbacks():
for category, callbacks in callback_map.items():
if category.startswith('callbacks_'):
category = category[10:]
yield category, callbacks
callback_map = dict( callback_map = dict(
callbacks_app_started=[], callbacks_app_started=[],
callbacks_model_loaded=[], callbacks_model_loaded=[],
@ -141,14 +240,18 @@ callback_map = dict(
callbacks_before_token_counter=[], callbacks_before_token_counter=[],
) )
ordered_callbacks_map = {}
def clear_callbacks(): def clear_callbacks():
for callback_list in callback_map.values(): for callback_list in callback_map.values():
callback_list.clear() callback_list.clear()
ordered_callbacks_map.clear()
def app_started_callback(demo: Optional[Blocks], app: FastAPI): def app_started_callback(demo: Optional[Blocks], app: FastAPI):
for c in callback_map['callbacks_app_started']: for c in ordered_callbacks('app_started'):
try: try:
c.callback(demo, app) c.callback(demo, app)
timer.startup_timer.record(os.path.basename(c.script)) timer.startup_timer.record(os.path.basename(c.script))
@ -157,7 +260,7 @@ def app_started_callback(demo: Optional[Blocks], app: FastAPI):
def app_reload_callback(): def app_reload_callback():
for c in callback_map['callbacks_on_reload']: for c in ordered_callbacks('on_reload'):
try: try:
c.callback() c.callback()
except Exception: except Exception:
@ -165,7 +268,7 @@ def app_reload_callback():
def model_loaded_callback(sd_model): def model_loaded_callback(sd_model):
for c in callback_map['callbacks_model_loaded']: for c in ordered_callbacks('model_loaded'):
try: try:
c.callback(sd_model) c.callback(sd_model)
except Exception: except Exception:
@ -175,7 +278,7 @@ def model_loaded_callback(sd_model):
def ui_tabs_callback(): def ui_tabs_callback():
res = [] res = []
for c in callback_map['callbacks_ui_tabs']: for c in ordered_callbacks('ui_tabs'):
try: try:
res += c.callback() or [] res += c.callback() or []
except Exception: except Exception:
@ -185,7 +288,7 @@ def ui_tabs_callback():
def ui_train_tabs_callback(params: UiTrainTabParams): def ui_train_tabs_callback(params: UiTrainTabParams):
for c in callback_map['callbacks_ui_train_tabs']: for c in ordered_callbacks('ui_train_tabs'):
try: try:
c.callback(params) c.callback(params)
except Exception: except Exception:
@ -193,7 +296,7 @@ def ui_train_tabs_callback(params: UiTrainTabParams):
def ui_settings_callback(): def ui_settings_callback():
for c in callback_map['callbacks_ui_settings']: for c in ordered_callbacks('ui_settings'):
try: try:
c.callback() c.callback()
except Exception: except Exception:
@ -201,7 +304,7 @@ def ui_settings_callback():
def before_image_saved_callback(params: ImageSaveParams): def before_image_saved_callback(params: ImageSaveParams):
for c in callback_map['callbacks_before_image_saved']: for c in ordered_callbacks('before_image_saved'):
try: try:
c.callback(params) c.callback(params)
except Exception: except Exception:
@ -209,7 +312,7 @@ def before_image_saved_callback(params: ImageSaveParams):
def image_saved_callback(params: ImageSaveParams): def image_saved_callback(params: ImageSaveParams):
for c in callback_map['callbacks_image_saved']: for c in ordered_callbacks('image_saved'):
try: try:
c.callback(params) c.callback(params)
except Exception: except Exception:
@ -217,7 +320,7 @@ def image_saved_callback(params: ImageSaveParams):
def extra_noise_callback(params: ExtraNoiseParams): def extra_noise_callback(params: ExtraNoiseParams):
for c in callback_map['callbacks_extra_noise']: for c in ordered_callbacks('extra_noise'):
try: try:
c.callback(params) c.callback(params)
except Exception: except Exception:
@ -225,7 +328,7 @@ def extra_noise_callback(params: ExtraNoiseParams):
def cfg_denoiser_callback(params: CFGDenoiserParams): def cfg_denoiser_callback(params: CFGDenoiserParams):
for c in callback_map['callbacks_cfg_denoiser']: for c in ordered_callbacks('cfg_denoiser'):
try: try:
c.callback(params) c.callback(params)
except Exception: except Exception:
@ -233,7 +336,7 @@ def cfg_denoiser_callback(params: CFGDenoiserParams):
def cfg_denoised_callback(params: CFGDenoisedParams): def cfg_denoised_callback(params: CFGDenoisedParams):
for c in callback_map['callbacks_cfg_denoised']: for c in ordered_callbacks('cfg_denoised'):
try: try:
c.callback(params) c.callback(params)
except Exception: except Exception:
@ -241,7 +344,7 @@ def cfg_denoised_callback(params: CFGDenoisedParams):
def cfg_after_cfg_callback(params: AfterCFGCallbackParams): def cfg_after_cfg_callback(params: AfterCFGCallbackParams):
for c in callback_map['callbacks_cfg_after_cfg']: for c in ordered_callbacks('cfg_after_cfg'):
try: try:
c.callback(params) c.callback(params)
except Exception: except Exception:
@ -249,7 +352,7 @@ def cfg_after_cfg_callback(params: AfterCFGCallbackParams):
def before_component_callback(component, **kwargs): def before_component_callback(component, **kwargs):
for c in callback_map['callbacks_before_component']: for c in ordered_callbacks('before_component'):
try: try:
c.callback(component, **kwargs) c.callback(component, **kwargs)
except Exception: except Exception:
@ -257,7 +360,7 @@ def before_component_callback(component, **kwargs):
def after_component_callback(component, **kwargs): def after_component_callback(component, **kwargs):
for c in callback_map['callbacks_after_component']: for c in ordered_callbacks('after_component'):
try: try:
c.callback(component, **kwargs) c.callback(component, **kwargs)
except Exception: except Exception:
@ -265,7 +368,7 @@ def after_component_callback(component, **kwargs):
def image_grid_callback(params: ImageGridLoopParams): def image_grid_callback(params: ImageGridLoopParams):
for c in callback_map['callbacks_image_grid']: for c in ordered_callbacks('image_grid'):
try: try:
c.callback(params) c.callback(params)
except Exception: except Exception:
@ -273,7 +376,7 @@ def image_grid_callback(params: ImageGridLoopParams):
def infotext_pasted_callback(infotext: str, params: dict[str, Any]): def infotext_pasted_callback(infotext: str, params: dict[str, Any]):
for c in callback_map['callbacks_infotext_pasted']: for c in ordered_callbacks('infotext_pasted'):
try: try:
c.callback(infotext, params) c.callback(infotext, params)
except Exception: except Exception:
@ -281,7 +384,7 @@ def infotext_pasted_callback(infotext: str, params: dict[str, Any]):
def script_unloaded_callback(): def script_unloaded_callback():
for c in reversed(callback_map['callbacks_script_unloaded']): for c in reversed(ordered_callbacks('script_unloaded')):
try: try:
c.callback() c.callback()
except Exception: except Exception:
@ -289,7 +392,7 @@ def script_unloaded_callback():
def before_ui_callback(): def before_ui_callback():
for c in reversed(callback_map['callbacks_before_ui']): for c in reversed(ordered_callbacks('before_ui')):
try: try:
c.callback() c.callback()
except Exception: except Exception:
@ -299,7 +402,7 @@ def before_ui_callback():
def list_optimizers_callback(): def list_optimizers_callback():
res = [] res = []
for c in callback_map['callbacks_list_optimizers']: for c in ordered_callbacks('list_optimizers'):
try: try:
c.callback(res) c.callback(res)
except Exception: except Exception:
@ -311,7 +414,7 @@ def list_optimizers_callback():
def list_unets_callback(): def list_unets_callback():
res = [] res = []
for c in callback_map['callbacks_list_unets']: for c in ordered_callbacks('list_unets'):
try: try:
c.callback(res) c.callback(res)
except Exception: except Exception:
@ -321,20 +424,13 @@ def list_unets_callback():
def before_token_counter_callback(params: BeforeTokenCounterParams): def before_token_counter_callback(params: BeforeTokenCounterParams):
for c in callback_map['callbacks_before_token_counter']: for c in ordered_callbacks('before_token_counter'):
try: try:
c.callback(params) c.callback(params)
except Exception: except Exception:
report_exception(c, 'before_token_counter') report_exception(c, 'before_token_counter')
def add_callback(callbacks, fun):
stack = [x for x in inspect.stack() if x.filename != __file__]
filename = stack[0].filename if stack else 'unknown file'
callbacks.append(ScriptCallback(filename, fun))
def remove_current_script_callbacks(): def remove_current_script_callbacks():
stack = [x for x in inspect.stack() if x.filename != __file__] stack = [x for x in inspect.stack() if x.filename != __file__]
filename = stack[0].filename if stack else 'unknown file' filename = stack[0].filename if stack else 'unknown file'
@ -351,24 +447,24 @@ def remove_callbacks_for_function(callback_func):
callback_list.remove(callback_to_remove) callback_list.remove(callback_to_remove)
def on_app_started(callback): def on_app_started(callback, *, name=None):
"""register a function to be called when the webui started, the gradio `Block` component and """register a function to be called when the webui started, the gradio `Block` component and
fastapi `FastAPI` object are passed as the arguments""" fastapi `FastAPI` object are passed as the arguments"""
add_callback(callback_map['callbacks_app_started'], callback) add_callback(callback_map['callbacks_app_started'], callback, name=name, category='app_started')
def on_before_reload(callback): def on_before_reload(callback, *, name=None):
"""register a function to be called just before the server reloads.""" """register a function to be called just before the server reloads."""
add_callback(callback_map['callbacks_on_reload'], callback) add_callback(callback_map['callbacks_on_reload'], callback, name=name, category='on_reload')
def on_model_loaded(callback): def on_model_loaded(callback, *, name=None):
"""register a function to be called when the stable diffusion model is created; the model is """register a function to be called when the stable diffusion model is created; the model is
passed as an argument; this function is also called when the script is reloaded. """ passed as an argument; this function is also called when the script is reloaded. """
add_callback(callback_map['callbacks_model_loaded'], callback) add_callback(callback_map['callbacks_model_loaded'], callback, name=name, category='model_loaded')
def on_ui_tabs(callback): def on_ui_tabs(callback, *, name=None):
"""register a function to be called when the UI is creating new tabs. """register a function to be called when the UI is creating new tabs.
The function must either return a None, which means no new tabs to be added, or a list, where The function must either return a None, which means no new tabs to be added, or a list, where
each element is a tuple: each element is a tuple:
@ -378,71 +474,71 @@ def on_ui_tabs(callback):
title is tab text displayed to user in the UI title is tab text displayed to user in the UI
elem_id is HTML id for the tab elem_id is HTML id for the tab
""" """
add_callback(callback_map['callbacks_ui_tabs'], callback) add_callback(callback_map['callbacks_ui_tabs'], callback, name=name, category='ui_tabs')
def on_ui_train_tabs(callback): def on_ui_train_tabs(callback, *, name=None):
"""register a function to be called when the UI is creating new tabs for the train tab. """register a function to be called when the UI is creating new tabs for the train tab.
Create your new tabs with gr.Tab. Create your new tabs with gr.Tab.
""" """
add_callback(callback_map['callbacks_ui_train_tabs'], callback) add_callback(callback_map['callbacks_ui_train_tabs'], callback, name=name, category='ui_train_tabs')
def on_ui_settings(callback): def on_ui_settings(callback, *, name=None):
"""register a function to be called before UI settings are populated; add your settings """register a function to be called before UI settings are populated; add your settings
by using shared.opts.add_option(shared.OptionInfo(...)) """ by using shared.opts.add_option(shared.OptionInfo(...)) """
add_callback(callback_map['callbacks_ui_settings'], callback) add_callback(callback_map['callbacks_ui_settings'], callback, name=name, category='ui_settings')
def on_before_image_saved(callback): def on_before_image_saved(callback, *, name=None):
"""register a function to be called before an image is saved to a file. """register a function to be called before an image is saved to a file.
The callback is called with one argument: The callback is called with one argument:
- params: ImageSaveParams - parameters the image is to be saved with. You can change fields in this object. - params: ImageSaveParams - parameters the image is to be saved with. You can change fields in this object.
""" """
add_callback(callback_map['callbacks_before_image_saved'], callback) add_callback(callback_map['callbacks_before_image_saved'], callback, name=name, category='before_image_saved')
def on_image_saved(callback): def on_image_saved(callback, *, name=None):
"""register a function to be called after an image is saved to a file. """register a function to be called after an image is saved to a file.
The callback is called with one argument: The callback is called with one argument:
- params: ImageSaveParams - parameters the image was saved with. Changing fields in this object does nothing. - params: ImageSaveParams - parameters the image was saved with. Changing fields in this object does nothing.
""" """
add_callback(callback_map['callbacks_image_saved'], callback) add_callback(callback_map['callbacks_image_saved'], callback, name=name, category='image_saved')
def on_extra_noise(callback): def on_extra_noise(callback, *, name=None):
"""register a function to be called before adding extra noise in img2img or hires fix; """register a function to be called before adding extra noise in img2img or hires fix;
The callback is called with one argument: The callback is called with one argument:
- params: ExtraNoiseParams - contains noise determined by seed and latent representation of image - params: ExtraNoiseParams - contains noise determined by seed and latent representation of image
""" """
add_callback(callback_map['callbacks_extra_noise'], callback) add_callback(callback_map['callbacks_extra_noise'], callback, name=name, category='extra_noise')
def on_cfg_denoiser(callback): def on_cfg_denoiser(callback, *, name=None):
"""register a function to be called in the kdiffussion cfg_denoiser method after building the inner model inputs. """register a function to be called in the kdiffussion cfg_denoiser method after building the inner model inputs.
The callback is called with one argument: The callback is called with one argument:
- params: CFGDenoiserParams - parameters to be passed to the inner model and sampling state details. - params: CFGDenoiserParams - parameters to be passed to the inner model and sampling state details.
""" """
add_callback(callback_map['callbacks_cfg_denoiser'], callback) add_callback(callback_map['callbacks_cfg_denoiser'], callback, name=name, category='cfg_denoiser')
def on_cfg_denoised(callback): def on_cfg_denoised(callback, *, name=None):
"""register a function to be called in the kdiffussion cfg_denoiser method after building the inner model inputs. """register a function to be called in the kdiffussion cfg_denoiser method after building the inner model inputs.
The callback is called with one argument: The callback is called with one argument:
- params: CFGDenoisedParams - parameters to be passed to the inner model and sampling state details. - params: CFGDenoisedParams - parameters to be passed to the inner model and sampling state details.
""" """
add_callback(callback_map['callbacks_cfg_denoised'], callback) add_callback(callback_map['callbacks_cfg_denoised'], callback, name=name, category='cfg_denoised')
def on_cfg_after_cfg(callback): def on_cfg_after_cfg(callback, *, name=None):
"""register a function to be called in the kdiffussion cfg_denoiser method after cfg calculations are completed. """register a function to be called in the kdiffussion cfg_denoiser method after cfg calculations are completed.
The callback is called with one argument: The callback is called with one argument:
- params: AfterCFGCallbackParams - parameters to be passed to the script for post-processing after cfg calculation. - params: AfterCFGCallbackParams - parameters to be passed to the script for post-processing after cfg calculation.
""" """
add_callback(callback_map['callbacks_cfg_after_cfg'], callback) add_callback(callback_map['callbacks_cfg_after_cfg'], callback, name=name, category='cfg_after_cfg')
def on_before_component(callback): def on_before_component(callback, *, name=None):
"""register a function to be called before a component is created. """register a function to be called before a component is created.
The callback is called with arguments: The callback is called with arguments:
- component - gradio component that is about to be created. - component - gradio component that is about to be created.
@ -451,61 +547,61 @@ def on_before_component(callback):
Use elem_id/label fields of kwargs to figure out which component it is. Use elem_id/label fields of kwargs to figure out which component it is.
This can be useful to inject your own components somewhere in the middle of vanilla UI. This can be useful to inject your own components somewhere in the middle of vanilla UI.
""" """
add_callback(callback_map['callbacks_before_component'], callback) add_callback(callback_map['callbacks_before_component'], callback, name=name, category='before_component')
def on_after_component(callback): def on_after_component(callback, *, name=None):
"""register a function to be called after a component is created. See on_before_component for more.""" """register a function to be called after a component is created. See on_before_component for more."""
add_callback(callback_map['callbacks_after_component'], callback) add_callback(callback_map['callbacks_after_component'], callback, name=name, category='after_component')
def on_image_grid(callback): def on_image_grid(callback, *, name=None):
"""register a function to be called before making an image grid. """register a function to be called before making an image grid.
The callback is called with one argument: The callback is called with one argument:
- params: ImageGridLoopParams - parameters to be used for grid creation. Can be modified. - params: ImageGridLoopParams - parameters to be used for grid creation. Can be modified.
""" """
add_callback(callback_map['callbacks_image_grid'], callback) add_callback(callback_map['callbacks_image_grid'], callback, name=name, category='image_grid')
def on_infotext_pasted(callback): def on_infotext_pasted(callback, *, name=None):
"""register a function to be called before applying an infotext. """register a function to be called before applying an infotext.
The callback is called with two arguments: The callback is called with two arguments:
- infotext: str - raw infotext. - infotext: str - raw infotext.
- result: dict[str, any] - parsed infotext parameters. - result: dict[str, any] - parsed infotext parameters.
""" """
add_callback(callback_map['callbacks_infotext_pasted'], callback) add_callback(callback_map['callbacks_infotext_pasted'], callback, name=name, category='infotext_pasted')
def on_script_unloaded(callback): def on_script_unloaded(callback, *, name=None):
"""register a function to be called before the script is unloaded. Any hooks/hijacks/monkeying about that """register a function to be called before the script is unloaded. Any hooks/hijacks/monkeying about that
the script did should be reverted here""" the script did should be reverted here"""
add_callback(callback_map['callbacks_script_unloaded'], callback) add_callback(callback_map['callbacks_script_unloaded'], callback, name=name, category='script_unloaded')
def on_before_ui(callback): def on_before_ui(callback, *, name=None):
"""register a function to be called before the UI is created.""" """register a function to be called before the UI is created."""
add_callback(callback_map['callbacks_before_ui'], callback) add_callback(callback_map['callbacks_before_ui'], callback, name=name, category='before_ui')
def on_list_optimizers(callback): def on_list_optimizers(callback, *, name=None):
"""register a function to be called when UI is making a list of cross attention optimization options. """register a function to be called when UI is making a list of cross attention optimization options.
The function will be called with one argument, a list, and shall add objects of type modules.sd_hijack_optimizations.SdOptimization The function will be called with one argument, a list, and shall add objects of type modules.sd_hijack_optimizations.SdOptimization
to it.""" to it."""
add_callback(callback_map['callbacks_list_optimizers'], callback) add_callback(callback_map['callbacks_list_optimizers'], callback, name=name, category='list_optimizers')
def on_list_unets(callback): def on_list_unets(callback, *, name=None):
"""register a function to be called when UI is making a list of alternative options for unet. """register a function to be called when UI is making a list of alternative options for unet.
The function will be called with one argument, a list, and shall add objects of type modules.sd_unet.SdUnetOption to it.""" The function will be called with one argument, a list, and shall add objects of type modules.sd_unet.SdUnetOption to it."""
add_callback(callback_map['callbacks_list_unets'], callback) add_callback(callback_map['callbacks_list_unets'], callback, name=name, category='list_unets')
def on_before_token_counter(callback): def on_before_token_counter(callback, *, name=None):
"""register a function to be called when UI is counting tokens for a prompt. """register a function to be called when UI is counting tokens for a prompt.
The function will be called with one argument of type BeforeTokenCounterParams, and should modify its fields if necessary.""" The function will be called with one argument of type BeforeTokenCounterParams, and should modify its fields if necessary."""
add_callback(callback_map['callbacks_before_token_counter'], callback) add_callback(callback_map['callbacks_before_token_counter'], callback, name=name, category='before_token_counter')

View File

@ -7,7 +7,9 @@ from dataclasses import dataclass
import gradio as gr import gradio as gr
from modules import shared, paths, script_callbacks, extensions, script_loading, scripts_postprocessing, errors, timer from modules import shared, paths, script_callbacks, extensions, script_loading, scripts_postprocessing, errors, timer, util
topological_sort = util.topological_sort
AlwaysVisible = object() AlwaysVisible = object()
@ -138,7 +140,6 @@ class Script:
""" """
pass pass
def before_process(self, p, *args): def before_process(self, p, *args):
""" """
This function is called very early during processing begins for AlwaysVisible scripts. This function is called very early during processing begins for AlwaysVisible scripts.
@ -369,29 +370,6 @@ scripts_data = []
postprocessing_scripts_data = [] postprocessing_scripts_data = []
ScriptClassData = namedtuple("ScriptClassData", ["script_class", "path", "basedir", "module"]) ScriptClassData = namedtuple("ScriptClassData", ["script_class", "path", "basedir", "module"])
def topological_sort(dependencies):
"""Accepts a dictionary mapping name to its dependencies, returns a list of names ordered according to dependencies.
Ignores errors relating to missing dependeencies or circular dependencies
"""
visited = {}
result = []
def inner(name):
visited[name] = True
for dep in dependencies.get(name, []):
if dep in dependencies and dep not in visited:
inner(dep)
result.append(name)
for depname in dependencies:
if depname not in visited:
inner(depname)
return result
@dataclass @dataclass
class ScriptWithDependencies: class ScriptWithDependencies:
@ -562,6 +540,25 @@ class ScriptRunner:
self.paste_field_names = [] self.paste_field_names = []
self.inputs = [None] self.inputs = [None]
self.callback_map = {}
self.callback_names = [
'before_process',
'process',
'before_process_batch',
'after_extra_networks_activate',
'process_batch',
'postprocess',
'postprocess_batch',
'postprocess_batch_list',
'post_sample',
'on_mask_blend',
'postprocess_image',
'postprocess_maskoverlay',
'postprocess_image_after_composite',
'before_component',
'after_component',
]
self.on_before_component_elem_id = {} self.on_before_component_elem_id = {}
"""dict of callbacks to be called before an element is created; key=elem_id, value=list of callbacks""" """dict of callbacks to be called before an element is created; key=elem_id, value=list of callbacks"""
@ -600,6 +597,8 @@ class ScriptRunner:
self.scripts.append(script) self.scripts.append(script)
self.selectable_scripts.append(script) self.selectable_scripts.append(script)
self.callback_map.clear()
self.apply_on_before_component_callbacks() self.apply_on_before_component_callbacks()
def apply_on_before_component_callbacks(self): def apply_on_before_component_callbacks(self):
@ -769,8 +768,42 @@ class ScriptRunner:
return processed return processed
def list_scripts_for_method(self, method_name):
if method_name in ('before_component', 'after_component'):
return self.scripts
else:
return self.alwayson_scripts
def create_ordered_callbacks_list(self, method_name, *, enable_user_sort=True):
script_list = self.list_scripts_for_method(method_name)
category = f'script_{method_name}'
callbacks = []
for script in script_list:
if getattr(script.__class__, method_name, None) == getattr(Script, method_name, None):
continue
script_callbacks.add_callback(callbacks, script, category=category, name=script.__class__.__name__, filename=script.filename)
return script_callbacks.sort_callbacks(category, callbacks, enable_user_sort=enable_user_sort)
def ordered_callbacks(self, method_name, *, enable_user_sort=True):
script_list = self.list_scripts_for_method(method_name)
category = f'script_{method_name}'
scrpts_len, callbacks = self.callback_map.get(category, (-1, None))
if callbacks is None or scrpts_len != len(script_list):
callbacks = self.create_ordered_callbacks_list(method_name, enable_user_sort=enable_user_sort)
self.callback_map[category] = len(script_list), callbacks
return callbacks
def ordered_scripts(self, method_name):
return [x.callback for x in self.ordered_callbacks(method_name)]
def before_process(self, p): def before_process(self, p):
for script in self.alwayson_scripts: for script in self.ordered_scripts('before_process'):
try: try:
script_args = p.script_args[script.args_from:script.args_to] script_args = p.script_args[script.args_from:script.args_to]
script.before_process(p, *script_args) script.before_process(p, *script_args)
@ -778,7 +811,7 @@ class ScriptRunner:
errors.report(f"Error running before_process: {script.filename}", exc_info=True) errors.report(f"Error running before_process: {script.filename}", exc_info=True)
def process(self, p): def process(self, p):
for script in self.alwayson_scripts: for script in self.ordered_scripts('process'):
try: try:
script_args = p.script_args[script.args_from:script.args_to] script_args = p.script_args[script.args_from:script.args_to]
script.process(p, *script_args) script.process(p, *script_args)
@ -786,7 +819,7 @@ class ScriptRunner:
errors.report(f"Error running process: {script.filename}", exc_info=True) errors.report(f"Error running process: {script.filename}", exc_info=True)
def before_process_batch(self, p, **kwargs): def before_process_batch(self, p, **kwargs):
for script in self.alwayson_scripts: for script in self.ordered_scripts('before_process_batch'):
try: try:
script_args = p.script_args[script.args_from:script.args_to] script_args = p.script_args[script.args_from:script.args_to]
script.before_process_batch(p, *script_args, **kwargs) script.before_process_batch(p, *script_args, **kwargs)
@ -794,7 +827,7 @@ class ScriptRunner:
errors.report(f"Error running before_process_batch: {script.filename}", exc_info=True) errors.report(f"Error running before_process_batch: {script.filename}", exc_info=True)
def after_extra_networks_activate(self, p, **kwargs): def after_extra_networks_activate(self, p, **kwargs):
for script in self.alwayson_scripts: for script in self.ordered_scripts('after_extra_networks_activate'):
try: try:
script_args = p.script_args[script.args_from:script.args_to] script_args = p.script_args[script.args_from:script.args_to]
script.after_extra_networks_activate(p, *script_args, **kwargs) script.after_extra_networks_activate(p, *script_args, **kwargs)
@ -802,7 +835,7 @@ class ScriptRunner:
errors.report(f"Error running after_extra_networks_activate: {script.filename}", exc_info=True) errors.report(f"Error running after_extra_networks_activate: {script.filename}", exc_info=True)
def process_batch(self, p, **kwargs): def process_batch(self, p, **kwargs):
for script in self.alwayson_scripts: for script in self.ordered_scripts('process_batch'):
try: try:
script_args = p.script_args[script.args_from:script.args_to] script_args = p.script_args[script.args_from:script.args_to]
script.process_batch(p, *script_args, **kwargs) script.process_batch(p, *script_args, **kwargs)
@ -810,7 +843,7 @@ class ScriptRunner:
errors.report(f"Error running process_batch: {script.filename}", exc_info=True) errors.report(f"Error running process_batch: {script.filename}", exc_info=True)
def postprocess(self, p, processed): def postprocess(self, p, processed):
for script in self.alwayson_scripts: for script in self.ordered_scripts('postprocess'):
try: try:
script_args = p.script_args[script.args_from:script.args_to] script_args = p.script_args[script.args_from:script.args_to]
script.postprocess(p, processed, *script_args) script.postprocess(p, processed, *script_args)
@ -818,7 +851,7 @@ class ScriptRunner:
errors.report(f"Error running postprocess: {script.filename}", exc_info=True) errors.report(f"Error running postprocess: {script.filename}", exc_info=True)
def postprocess_batch(self, p, images, **kwargs): def postprocess_batch(self, p, images, **kwargs):
for script in self.alwayson_scripts: for script in self.ordered_scripts('postprocess_batch'):
try: try:
script_args = p.script_args[script.args_from:script.args_to] script_args = p.script_args[script.args_from:script.args_to]
script.postprocess_batch(p, *script_args, images=images, **kwargs) script.postprocess_batch(p, *script_args, images=images, **kwargs)
@ -826,7 +859,7 @@ class ScriptRunner:
errors.report(f"Error running postprocess_batch: {script.filename}", exc_info=True) errors.report(f"Error running postprocess_batch: {script.filename}", exc_info=True)
def postprocess_batch_list(self, p, pp: PostprocessBatchListArgs, **kwargs): def postprocess_batch_list(self, p, pp: PostprocessBatchListArgs, **kwargs):
for script in self.alwayson_scripts: for script in self.ordered_scripts('postprocess_batch_list'):
try: try:
script_args = p.script_args[script.args_from:script.args_to] script_args = p.script_args[script.args_from:script.args_to]
script.postprocess_batch_list(p, pp, *script_args, **kwargs) script.postprocess_batch_list(p, pp, *script_args, **kwargs)
@ -834,7 +867,7 @@ class ScriptRunner:
errors.report(f"Error running postprocess_batch_list: {script.filename}", exc_info=True) errors.report(f"Error running postprocess_batch_list: {script.filename}", exc_info=True)
def post_sample(self, p, ps: PostSampleArgs): def post_sample(self, p, ps: PostSampleArgs):
for script in self.alwayson_scripts: for script in self.ordered_scripts('post_sample'):
try: try:
script_args = p.script_args[script.args_from:script.args_to] script_args = p.script_args[script.args_from:script.args_to]
script.post_sample(p, ps, *script_args) script.post_sample(p, ps, *script_args)
@ -842,7 +875,7 @@ class ScriptRunner:
errors.report(f"Error running post_sample: {script.filename}", exc_info=True) errors.report(f"Error running post_sample: {script.filename}", exc_info=True)
def on_mask_blend(self, p, mba: MaskBlendArgs): def on_mask_blend(self, p, mba: MaskBlendArgs):
for script in self.alwayson_scripts: for script in self.ordered_scripts('on_mask_blend'):
try: try:
script_args = p.script_args[script.args_from:script.args_to] script_args = p.script_args[script.args_from:script.args_to]
script.on_mask_blend(p, mba, *script_args) script.on_mask_blend(p, mba, *script_args)
@ -850,7 +883,7 @@ class ScriptRunner:
errors.report(f"Error running post_sample: {script.filename}", exc_info=True) errors.report(f"Error running post_sample: {script.filename}", exc_info=True)
def postprocess_image(self, p, pp: PostprocessImageArgs): def postprocess_image(self, p, pp: PostprocessImageArgs):
for script in self.alwayson_scripts: for script in self.ordered_scripts('postprocess_image'):
try: try:
script_args = p.script_args[script.args_from:script.args_to] script_args = p.script_args[script.args_from:script.args_to]
script.postprocess_image(p, pp, *script_args) script.postprocess_image(p, pp, *script_args)
@ -858,7 +891,7 @@ class ScriptRunner:
errors.report(f"Error running postprocess_image: {script.filename}", exc_info=True) errors.report(f"Error running postprocess_image: {script.filename}", exc_info=True)
def postprocess_maskoverlay(self, p, ppmo: PostProcessMaskOverlayArgs): def postprocess_maskoverlay(self, p, ppmo: PostProcessMaskOverlayArgs):
for script in self.alwayson_scripts: for script in self.ordered_scripts('postprocess_maskoverlay'):
try: try:
script_args = p.script_args[script.args_from:script.args_to] script_args = p.script_args[script.args_from:script.args_to]
script.postprocess_maskoverlay(p, ppmo, *script_args) script.postprocess_maskoverlay(p, ppmo, *script_args)
@ -866,7 +899,7 @@ class ScriptRunner:
errors.report(f"Error running postprocess_image: {script.filename}", exc_info=True) errors.report(f"Error running postprocess_image: {script.filename}", exc_info=True)
def postprocess_image_after_composite(self, p, pp: PostprocessImageArgs): def postprocess_image_after_composite(self, p, pp: PostprocessImageArgs):
for script in self.alwayson_scripts: for script in self.ordered_scripts('postprocess_image_after_composite'):
try: try:
script_args = p.script_args[script.args_from:script.args_to] script_args = p.script_args[script.args_from:script.args_to]
script.postprocess_image_after_composite(p, pp, *script_args) script.postprocess_image_after_composite(p, pp, *script_args)
@ -880,7 +913,7 @@ class ScriptRunner:
except Exception: except Exception:
errors.report(f"Error running on_before_component: {script.filename}", exc_info=True) errors.report(f"Error running on_before_component: {script.filename}", exc_info=True)
for script in self.scripts: for script in self.ordered_scripts('before_component'):
try: try:
script.before_component(component, **kwargs) script.before_component(component, **kwargs)
except Exception: except Exception:
@ -893,7 +926,7 @@ class ScriptRunner:
except Exception: except Exception:
errors.report(f"Error running on_after_component: {script.filename}", exc_info=True) errors.report(f"Error running on_after_component: {script.filename}", exc_info=True)
for script in self.scripts: for script in self.ordered_scripts('after_component'):
try: try:
script.after_component(component, **kwargs) script.after_component(component, **kwargs)
except Exception: except Exception:
@ -921,7 +954,7 @@ class ScriptRunner:
self.scripts[si].args_to = args_to self.scripts[si].args_to = args_to
def before_hr(self, p): def before_hr(self, p):
for script in self.alwayson_scripts: for script in self.ordered_scripts('before_hr'):
try: try:
script_args = p.script_args[script.args_from:script.args_to] script_args = p.script_args[script.args_from:script.args_to]
script.before_hr(p, *script_args) script.before_hr(p, *script_args)
@ -929,7 +962,7 @@ class ScriptRunner:
errors.report(f"Error running before_hr: {script.filename}", exc_info=True) errors.report(f"Error running before_hr: {script.filename}", exc_info=True)
def setup_scrips(self, p, *, is_ui=True): def setup_scrips(self, p, *, is_ui=True):
for script in self.alwayson_scripts: for script in self.ordered_scripts('setup'):
if not is_ui and script.setup_for_ui_only: if not is_ui and script.setup_for_ui_only:
continue continue

View File

@ -1,5 +1,8 @@
import html
import sys import sys
from modules import script_callbacks, scripts, ui_components
from modules.options import OptionHTML, OptionInfo
from modules.shared_cmd_options import cmd_opts from modules.shared_cmd_options import cmd_opts
@ -118,6 +121,45 @@ def ui_reorder_categories():
yield "scripts" yield "scripts"
def callbacks_order_settings():
options = {
"sd_vae_explanation": OptionHTML("""
For categories below, callbacks added to dropdowns happen before others, in order listed.
"""),
}
callback_options = {}
for category, _ in script_callbacks.enumerate_callbacks():
callback_options[category] = script_callbacks.ordered_callbacks(category, enable_user_sort=False)
for method_name in scripts.scripts_txt2img.callback_names:
callback_options["script_" + method_name] = scripts.scripts_txt2img.create_ordered_callbacks_list(method_name, enable_user_sort=False)
for method_name in scripts.scripts_img2img.callback_names:
callbacks = callback_options.get("script_" + method_name, [])
for addition in scripts.scripts_img2img.create_ordered_callbacks_list(method_name, enable_user_sort=False):
if any(x.name == addition.name for x in callbacks):
continue
callbacks.append(addition)
callback_options["script_" + method_name] = callbacks
for category, callbacks in callback_options.items():
if not callbacks:
continue
option_info = OptionInfo([], f"{category} callback priority", ui_components.DropdownMulti, {"choices": [x.name for x in callbacks]})
option_info.needs_restart()
option_info.html("<div class='info'>Default order: <ol>" + "".join(f"<li>{html.escape(x.name)}</li>\n" for x in callbacks) + "</ol></div>")
options['prioritized_callbacks_' + category] = option_info
return options
class Shared(sys.modules[__name__].__class__): class Shared(sys.modules[__name__].__class__):
""" """
this class is here to provide sd_model field as a property, so that it can be created and loaded on demand rather than this class is here to provide sd_model field as a property, so that it can be created and loaded on demand rather than

View File

@ -1,7 +1,8 @@
import gradio as gr import gradio as gr
from modules import ui_common, shared, script_callbacks, scripts, sd_models, sysinfo, timer from modules import ui_common, shared, script_callbacks, scripts, sd_models, sysinfo, timer, shared_items
from modules.call_queue import wrap_gradio_call from modules.call_queue import wrap_gradio_call
from modules.options import options_section
from modules.shared import opts from modules.shared import opts
from modules.ui_components import FormRow from modules.ui_components import FormRow
from modules.ui_gradio_extensions import reload_javascript from modules.ui_gradio_extensions import reload_javascript
@ -108,6 +109,11 @@ class UiSettings:
shared.settings_components = self.component_dict shared.settings_components = self.component_dict
# we add this as late as possible so that scripts have already registered their callbacks
opts.data_labels.update(options_section(('callbacks', "Callbacks", "system"), {
**shared_items.callbacks_order_settings(),
}))
opts.reorder() opts.reorder()
with gr.Blocks(analytics_enabled=False) as settings_interface: with gr.Blocks(analytics_enabled=False) as settings_interface:

View File

@ -148,8 +148,26 @@ class MassFileLister:
"""Clear the cache of all directories.""" """Clear the cache of all directories."""
self.cached_dirs.clear() self.cached_dirs.clear()
def update_file_entry(self, path):
"""Update the cache for a specific directory.""" def topological_sort(dependencies):
dirname, filename = os.path.split(path) """Accepts a dictionary mapping name to its dependencies, returns a list of names ordered according to dependencies.
if cached_dir := self.cached_dirs.get(dirname): Ignores errors relating to missing dependeencies or circular dependencies
cached_dir.update_entry(filename) """
visited = {}
result = []
def inner(name):
visited[name] = True
for dep in dependencies.get(name, []):
if dep in dependencies and dep not in visited:
inner(dep)
result.append(name)
for depname in dependencies:
if depname not in visited:
inner(depname)
return result

View File

@ -528,6 +528,10 @@ table.popup-table .link{
opacity: 0.75; opacity: 0.75;
} }
.settings-comment .info ol{
margin: 0.4em 0 0.8em 1em;
}
#sysinfo_download a.sysinfo_big_link{ #sysinfo_download a.sysinfo_big_link{
font-size: 24pt; font-size: 24pt;
} }