From 0411eced89688f55aec356eea0c7d29377d37bc8 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sun, 10 Mar 2024 07:52:57 +0300 Subject: [PATCH] add names to callbacks --- modules/extensions.py | 17 +++++ modules/script_callbacks.py | 126 +++++++++++++++++++++--------------- 2 files changed, 91 insertions(+), 52 deletions(-) diff --git a/modules/extensions.py b/modules/extensions.py index 04bda297e..ab835d3f2 100644 --- a/modules/extensions.py +++ b/modules/extensions.py @@ -186,6 +186,7 @@ class Extension: def list_extensions(): extensions.clear() + extension_paths.clear() if shared.cmd_opts.disable_all_extensions: print("*** \"--disable-all-extensions\" arg was used, will not load any extensions ***") @@ -220,6 +221,7 @@ def list_extensions(): is_builtin = dirname == extensions_builtin_dir extension = Extension(name=extension_dirname, path=path, enabled=extension_dirname not in shared.opts.disabled_extensions, is_builtin=is_builtin, metadata=metadata) extensions.append(extension) + extension_paths[extension.path] = extension loaded_extensions[canonical_name] = extension # check for requirements @@ -238,4 +240,19 @@ def list_extensions(): continue +def find_extension(filename): + parentdir = os.path.dirname(os.path.realpath(filename)) + + while parentdir != filename: + extension = extension_paths.get(parentdir) + if extension is not None: + return extension + + filename = parentdir + parentdir = os.path.dirname(filename) + + return None + + extensions: list[Extension] = [] +extension_paths: dict[str, Extension] = {} diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py index 08bc52564..98952ec78 100644 --- a/modules/script_callbacks.py +++ b/modules/script_callbacks.py @@ -1,13 +1,14 @@ +from __future__ import annotations + import dataclasses import inspect import os -from collections import namedtuple from typing import Optional, Any from fastapi import FastAPI from gradio import Blocks -from modules import errors, timer +from modules import errors, timer, extensions def report_exception(c, job): @@ -116,7 +117,35 @@ class BeforeTokenCounterParams: is_positive: bool = True -ScriptCallback = namedtuple("ScriptCallback", ["script", "callback"]) +@dataclasses.dataclass +class ScriptCallback: + script: str + callback: '' + name: str = None + + +def add_callback(callbacks, fun, *, name=None, category='unknown'): + stack = [x for x in inspect.stack() if x.filename != __file__] + filename = stack[0].filename if stack else 'unknown file' + + extension = extensions.find_extension(filename) + extension_name = extension.canonical_name if extension else 'base' + + callback_name = f"{extension_name}/{os.path.basename(filename)}/{category}" + if name is not None: + callback_name += f'/{name}' + + unique_callback_name = callback_name + for index in range(1000): + existing = any(x.name == unique_callback_name for x in callbacks) + if not existing: + break + + unique_callback_name = f'{callback_name}-{index+1}' + + callbacks.append(ScriptCallback(filename, fun, unique_callback_name)) + + callback_map = dict( callbacks_app_started=[], callbacks_model_loaded=[], @@ -328,13 +357,6 @@ def before_token_counter_callback(params: BeforeTokenCounterParams): report_exception(c, 'before_token_counter') -def add_callback(callbacks, fun): - stack = [x for x in inspect.stack() if x.filename != __file__] - filename = stack[0].filename if stack else 'unknown file' - - callbacks.append(ScriptCallback(filename, fun)) - - def remove_current_script_callbacks(): stack = [x for x in inspect.stack() if x.filename != __file__] filename = stack[0].filename if stack else 'unknown file' @@ -351,24 +373,24 @@ def remove_callbacks_for_function(callback_func): callback_list.remove(callback_to_remove) -def on_app_started(callback): +def on_app_started(callback, *, name=None): """register a function to be called when the webui started, the gradio `Block` component and fastapi `FastAPI` object are passed as the arguments""" - add_callback(callback_map['callbacks_app_started'], callback) + add_callback(callback_map['callbacks_app_started'], callback, name=name, category='app_started') -def on_before_reload(callback): +def on_before_reload(callback, *, name=None): """register a function to be called just before the server reloads.""" - add_callback(callback_map['callbacks_on_reload'], callback) + add_callback(callback_map['callbacks_on_reload'], callback, name=name, category='on_reload') -def on_model_loaded(callback): +def on_model_loaded(callback, *, name=None): """register a function to be called when the stable diffusion model is created; the model is passed as an argument; this function is also called when the script is reloaded. """ - add_callback(callback_map['callbacks_model_loaded'], callback) + add_callback(callback_map['callbacks_model_loaded'], callback, name=name, category='model_loaded') -def on_ui_tabs(callback): +def on_ui_tabs(callback, *, name=None): """register a function to be called when the UI is creating new tabs. The function must either return a None, which means no new tabs to be added, or a list, where each element is a tuple: @@ -378,71 +400,71 @@ def on_ui_tabs(callback): title is tab text displayed to user in the UI elem_id is HTML id for the tab """ - add_callback(callback_map['callbacks_ui_tabs'], callback) + add_callback(callback_map['callbacks_ui_tabs'], callback, name=name, category='ui_tabs') -def on_ui_train_tabs(callback): +def on_ui_train_tabs(callback, *, name=None): """register a function to be called when the UI is creating new tabs for the train tab. Create your new tabs with gr.Tab. """ - add_callback(callback_map['callbacks_ui_train_tabs'], callback) + add_callback(callback_map['callbacks_ui_train_tabs'], callback, name=name, category='ui_train_tabs') -def on_ui_settings(callback): +def on_ui_settings(callback, *, name=None): """register a function to be called before UI settings are populated; add your settings by using shared.opts.add_option(shared.OptionInfo(...)) """ - add_callback(callback_map['callbacks_ui_settings'], callback) + add_callback(callback_map['callbacks_ui_settings'], callback, name=name, category='ui_settings') -def on_before_image_saved(callback): +def on_before_image_saved(callback, *, name=None): """register a function to be called before an image is saved to a file. The callback is called with one argument: - params: ImageSaveParams - parameters the image is to be saved with. You can change fields in this object. """ - add_callback(callback_map['callbacks_before_image_saved'], callback) + add_callback(callback_map['callbacks_before_image_saved'], callback, name=name, category='before_image_saved') -def on_image_saved(callback): +def on_image_saved(callback, *, name=None): """register a function to be called after an image is saved to a file. The callback is called with one argument: - params: ImageSaveParams - parameters the image was saved with. Changing fields in this object does nothing. """ - add_callback(callback_map['callbacks_image_saved'], callback) + add_callback(callback_map['callbacks_image_saved'], callback, name=name, category='image_saved') -def on_extra_noise(callback): +def on_extra_noise(callback, *, name=None): """register a function to be called before adding extra noise in img2img or hires fix; The callback is called with one argument: - params: ExtraNoiseParams - contains noise determined by seed and latent representation of image """ - add_callback(callback_map['callbacks_extra_noise'], callback) + add_callback(callback_map['callbacks_extra_noise'], callback, name=name, category='extra_noise') -def on_cfg_denoiser(callback): +def on_cfg_denoiser(callback, *, name=None): """register a function to be called in the kdiffussion cfg_denoiser method after building the inner model inputs. The callback is called with one argument: - params: CFGDenoiserParams - parameters to be passed to the inner model and sampling state details. """ - add_callback(callback_map['callbacks_cfg_denoiser'], callback) + add_callback(callback_map['callbacks_cfg_denoiser'], callback, name=name, category='cfg_denoiser') -def on_cfg_denoised(callback): +def on_cfg_denoised(callback, *, name=None): """register a function to be called in the kdiffussion cfg_denoiser method after building the inner model inputs. The callback is called with one argument: - params: CFGDenoisedParams - parameters to be passed to the inner model and sampling state details. """ - add_callback(callback_map['callbacks_cfg_denoised'], callback) + add_callback(callback_map['callbacks_cfg_denoised'], callback, name=name, category='cfg_denoised') -def on_cfg_after_cfg(callback): +def on_cfg_after_cfg(callback, *, name=None): """register a function to be called in the kdiffussion cfg_denoiser method after cfg calculations are completed. The callback is called with one argument: - params: AfterCFGCallbackParams - parameters to be passed to the script for post-processing after cfg calculation. """ - add_callback(callback_map['callbacks_cfg_after_cfg'], callback) + add_callback(callback_map['callbacks_cfg_after_cfg'], callback, name=name, category='cfg_after_cfg') -def on_before_component(callback): +def on_before_component(callback, *, name=None): """register a function to be called before a component is created. The callback is called with arguments: - component - gradio component that is about to be created. @@ -451,61 +473,61 @@ def on_before_component(callback): Use elem_id/label fields of kwargs to figure out which component it is. This can be useful to inject your own components somewhere in the middle of vanilla UI. """ - add_callback(callback_map['callbacks_before_component'], callback) + add_callback(callback_map['callbacks_before_component'], callback, name=name, category='before_component') -def on_after_component(callback): +def on_after_component(callback, *, name=None): """register a function to be called after a component is created. See on_before_component for more.""" - add_callback(callback_map['callbacks_after_component'], callback) + add_callback(callback_map['callbacks_after_component'], callback, name=name, category='after_component') -def on_image_grid(callback): +def on_image_grid(callback, *, name=None): """register a function to be called before making an image grid. The callback is called with one argument: - params: ImageGridLoopParams - parameters to be used for grid creation. Can be modified. """ - add_callback(callback_map['callbacks_image_grid'], callback) + add_callback(callback_map['callbacks_image_grid'], callback, name=name, category='image_grid') -def on_infotext_pasted(callback): +def on_infotext_pasted(callback, *, name=None): """register a function to be called before applying an infotext. The callback is called with two arguments: - infotext: str - raw infotext. - result: dict[str, any] - parsed infotext parameters. """ - add_callback(callback_map['callbacks_infotext_pasted'], callback) + add_callback(callback_map['callbacks_infotext_pasted'], callback, name=name, category='infotext_pasted') -def on_script_unloaded(callback): +def on_script_unloaded(callback, *, name=None): """register a function to be called before the script is unloaded. Any hooks/hijacks/monkeying about that the script did should be reverted here""" - add_callback(callback_map['callbacks_script_unloaded'], callback) + add_callback(callback_map['callbacks_script_unloaded'], callback, name=name, category='script_unloaded') -def on_before_ui(callback): +def on_before_ui(callback, *, name=None): """register a function to be called before the UI is created.""" - add_callback(callback_map['callbacks_before_ui'], callback) + add_callback(callback_map['callbacks_before_ui'], callback, name=name, category='before_ui') -def on_list_optimizers(callback): +def on_list_optimizers(callback, *, name=None): """register a function to be called when UI is making a list of cross attention optimization options. The function will be called with one argument, a list, and shall add objects of type modules.sd_hijack_optimizations.SdOptimization to it.""" - add_callback(callback_map['callbacks_list_optimizers'], callback) + add_callback(callback_map['callbacks_list_optimizers'], callback, name=name, category='list_optimizers') -def on_list_unets(callback): +def on_list_unets(callback, *, name=None): """register a function to be called when UI is making a list of alternative options for unet. The function will be called with one argument, a list, and shall add objects of type modules.sd_unet.SdUnetOption to it.""" - add_callback(callback_map['callbacks_list_unets'], callback) + add_callback(callback_map['callbacks_list_unets'], callback, name=name, category='list_unets') -def on_before_token_counter(callback): +def on_before_token_counter(callback, *, name=None): """register a function to be called when UI is counting tokens for a prompt. The function will be called with one argument of type BeforeTokenCounterParams, and should modify its fields if necessary.""" - add_callback(callback_map['callbacks_before_token_counter'], callback) + add_callback(callback_map['callbacks_before_token_counter'], callback, name=name, category='before_token_counter')