From 95821f0132f5437ef30b0dbcac7c51e55818c18f Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Wed, 9 Aug 2023 18:11:13 +0300 Subject: [PATCH] split webui.py's initialization and utility functions into separate files --- modules/gradio_extensons.py | 4 +- modules/initialize.py | 168 ++++++++++++++++ modules/initialize_util.py | 195 +++++++++++++++++++ modules/shared_init.py | 3 - modules/ui_extra_networks.py | 3 +- modules/ui_tempdir.py | 5 +- webui.py | 368 ++++------------------------------- 7 files changed, 405 insertions(+), 341 deletions(-) create mode 100644 modules/initialize.py create mode 100644 modules/initialize_util.py diff --git a/modules/gradio_extensons.py b/modules/gradio_extensons.py index 5af7fd8ec..77c34c8ba 100644 --- a/modules/gradio_extensons.py +++ b/modules/gradio_extensons.py @@ -1,6 +1,6 @@ import gradio as gr -from modules import scripts +from modules import scripts, ui_tempdir def add_classes_to_gradio_component(comp): """ @@ -58,3 +58,5 @@ original_BlockContext_init = gr.blocks.BlockContext.__init__ gr.components.IOComponent.__init__ = IOComponent_init gr.blocks.Block.get_config = Block_get_config gr.blocks.BlockContext.__init__ = BlockContext_init + +ui_tempdir.install_ui_tempdir_override() diff --git a/modules/initialize.py b/modules/initialize.py new file mode 100644 index 000000000..f24f76375 --- /dev/null +++ b/modules/initialize.py @@ -0,0 +1,168 @@ +import importlib +import logging +import sys +import warnings +from threading import Thread + +from modules.timer import startup_timer + + +def imports(): + logging.getLogger("torch.distributed.nn").setLevel(logging.ERROR) # sshh... + logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage()) + + import torch # noqa: F401 + startup_timer.record("import torch") + import pytorch_lightning # noqa: F401 + startup_timer.record("import torch") + warnings.filterwarnings(action="ignore", category=DeprecationWarning, module="pytorch_lightning") + warnings.filterwarnings(action="ignore", category=UserWarning, module="torchvision") + + import gradio # noqa: F401 + startup_timer.record("import gradio") + + from modules import paths, timer, import_hook, errors # noqa: F401 + startup_timer.record("setup paths") + + import ldm.modules.encoders.modules # noqa: F401 + startup_timer.record("import ldm") + + import sgm.modules.encoders.modules # noqa: F401 + startup_timer.record("import sgm") + + from modules import shared_init + shared_init.initialize() + startup_timer.record("initialize shared") + + from modules import processing, gradio_extensons, ui # noqa: F401 + startup_timer.record("other imports") + + +def check_versions(): + from modules.shared_cmd_options import cmd_opts + + if not cmd_opts.skip_version_check: + from modules import errors + errors.check_versions() + + +def initialize(): + from modules import initialize_util + initialize_util.fix_torch_version() + initialize_util.fix_asyncio_event_loop_policy() + initialize_util.validate_tls_options() + initialize_util.configure_sigint_handler() + initialize_util.configure_opts_onchange() + + from modules import modelloader + modelloader.cleanup_models() + + from modules import sd_models + sd_models.setup_model() + startup_timer.record("setup SD model") + + from modules.shared_cmd_options import cmd_opts + + from modules import codeformer_model + warnings.filterwarnings(action="ignore", category=UserWarning, module="torchvision.transforms.functional_tensor") + codeformer_model.setup_model(cmd_opts.codeformer_models_path) + startup_timer.record("setup codeformer") + + from modules import gfpgan_model + gfpgan_model.setup_model(cmd_opts.gfpgan_models_path) + startup_timer.record("setup gfpgan") + + initialize_rest(reload_script_modules=False) + + +def initialize_rest(*, reload_script_modules=False): + """ + Called both from initialize() and when reloading the webui. + """ + from modules.shared_cmd_options import cmd_opts + + from modules import sd_samplers + sd_samplers.set_samplers() + startup_timer.record("set samplers") + + from modules import extensions + extensions.list_extensions() + startup_timer.record("list extensions") + + from modules import initialize_util + initialize_util.restore_config_state_file() + startup_timer.record("restore config state file") + + from modules import shared, upscaler, scripts + if cmd_opts.ui_debug_mode: + shared.sd_upscalers = upscaler.UpscalerLanczos().scalers + scripts.load_scripts() + return + + from modules import sd_models + sd_models.list_models() + startup_timer.record("list SD models") + + from modules import localization + localization.list_localizations(cmd_opts.localizations_dir) + startup_timer.record("list localizations") + + with startup_timer.subcategory("load scripts"): + scripts.load_scripts() + + if reload_script_modules: + for module in [module for name, module in sys.modules.items() if name.startswith("modules.ui")]: + importlib.reload(module) + startup_timer.record("reload script modules") + + from modules import modelloader + modelloader.load_upscalers() + startup_timer.record("load upscalers") + + from modules import sd_vae + sd_vae.refresh_vae_list() + startup_timer.record("refresh VAE") + + from modules import textual_inversion + textual_inversion.textual_inversion.list_textual_inversion_templates() + startup_timer.record("refresh textual inversion templates") + + from modules import script_callbacks, sd_hijack_optimizations, sd_hijack + script_callbacks.on_list_optimizers(sd_hijack_optimizations.list_optimizers) + sd_hijack.list_optimizers() + startup_timer.record("scripts list_optimizers") + + from modules import sd_unet + sd_unet.list_unets() + startup_timer.record("scripts list_unets") + + def load_model(): + """ + Accesses shared.sd_model property to load model. + After it's available, if it has been loaded before this access by some extension, + its optimization may be None because the list of optimizaers has neet been filled + by that time, so we apply optimization again. + """ + + shared.sd_model # noqa: B018 + + if sd_hijack.current_optimizer is None: + sd_hijack.apply_optimizations() + + from modules import devices + devices.first_time_calculation() + + Thread(target=load_model).start() + + from modules import shared_items + shared_items.reload_hypernetworks() + startup_timer.record("reload hypernetworks") + + from modules import ui_extra_networks + ui_extra_networks.initialize() + ui_extra_networks.register_default_pages() + + from modules import extra_networks + extra_networks.initialize() + extra_networks.register_default_extra_networks() + startup_timer.record("initialize extra networks") diff --git a/modules/initialize_util.py b/modules/initialize_util.py new file mode 100644 index 000000000..e59bd3c49 --- /dev/null +++ b/modules/initialize_util.py @@ -0,0 +1,195 @@ +import json +import logging +import os +import signal +import sys +import re + +from modules.timer import startup_timer + +def setup_logging(): + # We can't use cmd_opts for this because it will not have been initialized at this point. + log_level = os.environ.get("SD_WEBUI_LOG_LEVEL") + if log_level: + log_level = getattr(logging, log_level.upper(), None) or logging.INFO + logging.basicConfig( + level=log_level, + format='%(asctime)s %(levelname)s [%(name)s] %(message)s', + datefmt='%Y-%m-%d %H:%M:%S', + ) + + +def gradio_server_name(): + from modules.shared_cmd_options import cmd_opts + + if cmd_opts.server_name: + return cmd_opts.server_name + else: + return "0.0.0.0" if cmd_opts.listen else None + + +def fix_torch_version(): + import torch + + # Truncate version number of nightly/local build of PyTorch to not cause exceptions with CodeFormer or Safetensors + if ".dev" in torch.__version__ or "+git" in torch.__version__: + torch.__long_version__ = torch.__version__ + torch.__version__ = re.search(r'[\d.]+[\d]', torch.__version__).group(0) + + +def fix_asyncio_event_loop_policy(): + """ + The default `asyncio` event loop policy only automatically creates + event loops in the main threads. Other threads must create event + loops explicitly or `asyncio.get_event_loop` (and therefore + `.IOLoop.current`) will fail. Installing this policy allows event + loops to be created automatically on any thread, matching the + behavior of Tornado versions prior to 5.0 (or 5.0 on Python 2). + """ + + import asyncio + + if sys.platform == "win32" and hasattr(asyncio, "WindowsSelectorEventLoopPolicy"): + # "Any thread" and "selector" should be orthogonal, but there's not a clean + # interface for composing policies so pick the right base. + _BasePolicy = asyncio.WindowsSelectorEventLoopPolicy # type: ignore + else: + _BasePolicy = asyncio.DefaultEventLoopPolicy + + class AnyThreadEventLoopPolicy(_BasePolicy): # type: ignore + """Event loop policy that allows loop creation on any thread. + Usage:: + + asyncio.set_event_loop_policy(AnyThreadEventLoopPolicy()) + """ + + def get_event_loop(self) -> asyncio.AbstractEventLoop: + try: + return super().get_event_loop() + except (RuntimeError, AssertionError): + # This was an AssertionError in python 3.4.2 (which ships with debian jessie) + # and changed to a RuntimeError in 3.4.3. + # "There is no current event loop in thread %r" + loop = self.new_event_loop() + self.set_event_loop(loop) + return loop + + asyncio.set_event_loop_policy(AnyThreadEventLoopPolicy()) + + +def restore_config_state_file(): + from modules import shared, config_states + + config_state_file = shared.opts.restore_config_state_file + if config_state_file == "": + return + + shared.opts.restore_config_state_file = "" + shared.opts.save(shared.config_filename) + + if os.path.isfile(config_state_file): + print(f"*** About to restore extension state from file: {config_state_file}") + with open(config_state_file, "r", encoding="utf-8") as f: + config_state = json.load(f) + config_states.restore_extension_config(config_state) + startup_timer.record("restore extension config") + elif config_state_file: + print(f"!!! Config state backup not found: {config_state_file}") + + +def validate_tls_options(): + from modules.shared_cmd_options import cmd_opts + + if not (cmd_opts.tls_keyfile and cmd_opts.tls_certfile): + return + + try: + if not os.path.exists(cmd_opts.tls_keyfile): + print("Invalid path to TLS keyfile given") + if not os.path.exists(cmd_opts.tls_certfile): + print(f"Invalid path to TLS certfile: '{cmd_opts.tls_certfile}'") + except TypeError: + cmd_opts.tls_keyfile = cmd_opts.tls_certfile = None + print("TLS setup invalid, running webui without TLS") + else: + print("Running with TLS") + startup_timer.record("TLS") + + +def get_gradio_auth_creds(): + """ + Convert the gradio_auth and gradio_auth_path commandline arguments into + an iterable of (username, password) tuples. + """ + from modules.shared_cmd_options import cmd_opts + + def process_credential_line(s): + s = s.strip() + if not s: + return None + return tuple(s.split(':', 1)) + + if cmd_opts.gradio_auth: + for cred in cmd_opts.gradio_auth.split(','): + cred = process_credential_line(cred) + if cred: + yield cred + + if cmd_opts.gradio_auth_path: + with open(cmd_opts.gradio_auth_path, 'r', encoding="utf8") as file: + for line in file.readlines(): + for cred in line.strip().split(','): + cred = process_credential_line(cred) + if cred: + yield cred + + +def configure_sigint_handler(): + # make the program just exit at ctrl+c without waiting for anything + def sigint_handler(sig, frame): + print(f'Interrupted with signal {sig} in {frame}') + os._exit(0) + + if not os.environ.get("COVERAGE_RUN"): + # Don't install the immediate-quit handler when running under coverage, + # as then the coverage report won't be generated. + signal.signal(signal.SIGINT, sigint_handler) + + +def configure_opts_onchange(): + from modules import shared, sd_models, sd_vae, ui_tempdir, sd_hijack + from modules.call_queue import wrap_queued_call + + shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: sd_models.reload_model_weights()), call=False) + shared.opts.onchange("sd_vae", wrap_queued_call(lambda: sd_vae.reload_vae_weights()), call=False) + shared.opts.onchange("sd_vae_overrides_per_model_preferences", wrap_queued_call(lambda: sd_vae.reload_vae_weights()), call=False) + shared.opts.onchange("temp_dir", ui_tempdir.on_tmpdir_changed) + shared.opts.onchange("gradio_theme", shared.reload_gradio_theme) + shared.opts.onchange("cross_attention_optimization", wrap_queued_call(lambda: sd_hijack.model_hijack.redo_hijack(shared.sd_model)), call=False) + startup_timer.record("opts onchange") + + +def setup_middleware(app): + from starlette.middleware.gzip import GZipMiddleware + + app.middleware_stack = None # reset current middleware to allow modifying user provided list + app.add_middleware(GZipMiddleware, minimum_size=1000) + configure_cors_middleware(app) + app.build_middleware_stack() # rebuild middleware stack on-the-fly + + +def configure_cors_middleware(app): + from starlette.middleware.cors import CORSMiddleware + from modules.shared_cmd_options import cmd_opts + + cors_options = { + "allow_methods": ["*"], + "allow_headers": ["*"], + "allow_credentials": True, + } + if cmd_opts.cors_allow_origins: + cors_options["allow_origins"] = cmd_opts.cors_allow_origins.split(',') + if cmd_opts.cors_allow_origins_regex: + cors_options["allow_origin_regex"] = cmd_opts.cors_allow_origins_regex + app.add_middleware(CORSMiddleware, **cors_options) + diff --git a/modules/shared_init.py b/modules/shared_init.py index b88d1d8e7..d3fb687e0 100644 --- a/modules/shared_init.py +++ b/modules/shared_init.py @@ -5,9 +5,6 @@ import torch from modules import shared from modules.shared import cmd_opts -import sys -sys.setrecursionlimit(1000) - def initialize(): """Initializes fields inside the shared module in a controlled manner. diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index e0b932b94..16d76a452 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -4,7 +4,6 @@ from pathlib import Path from modules import shared, ui_extra_networks_user_metadata, errors, extra_networks from modules.images import read_info_from_image, save_image_with_geninfo -from modules.ui import up_down_symbol import gradio as gr import json import html @@ -348,6 +347,8 @@ def pages_in_preferred_order(pages): def create_ui(interface: gr.Blocks, unrelated_tabs, tabname): + from modules.ui import up_down_symbol + ui = ExtraNetworksUi() ui.pages = [] ui.pages_contents = [] diff --git a/modules/ui_tempdir.py b/modules/ui_tempdir.py index fb75137e6..506017e50 100644 --- a/modules/ui_tempdir.py +++ b/modules/ui_tempdir.py @@ -57,8 +57,9 @@ def save_pil_to_file(self, pil_image, dir=None, format="png"): return file_obj.name -# override save to file function so that it also writes PNG info -gradio.components.IOComponent.pil_to_temp_file = save_pil_to_file +def install_ui_tempdir_override(): + """override save to file function so that it also writes PNG info""" + gradio.components.IOComponent.pil_to_temp_file = save_pil_to_file def on_tmpdir_changed(): diff --git a/webui.py b/webui.py index 0f1ace97a..738b3bef2 100644 --- a/webui.py +++ b/webui.py @@ -1,349 +1,43 @@ from __future__ import annotations import os -import sys import time -import importlib -import signal -import re -import warnings -import json -from threading import Thread -from typing import Iterable - -from fastapi import FastAPI -from fastapi.middleware.cors import CORSMiddleware -from fastapi.middleware.gzip import GZipMiddleware - -import logging - -# We can't use cmd_opts for this because it will not have been initialized at this point. -log_level = os.environ.get("SD_WEBUI_LOG_LEVEL") -if log_level: - log_level = getattr(logging, log_level.upper(), None) or logging.INFO - logging.basicConfig( - level=log_level, - format='%(asctime)s %(levelname)s [%(name)s] %(message)s', - datefmt='%Y-%m-%d %H:%M:%S', - ) - -logging.getLogger("torch.distributed.nn").setLevel(logging.ERROR) # sshh... -logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage()) from modules import timer +from modules import initialize_util +from modules import initialize + startup_timer = timer.startup_timer startup_timer.record("launcher") -import torch -import pytorch_lightning # noqa: F401 # pytorch_lightning should be imported after torch, but it re-enables warnings on import so import once to disable them -warnings.filterwarnings(action="ignore", category=DeprecationWarning, module="pytorch_lightning") -warnings.filterwarnings(action="ignore", category=UserWarning, module="torchvision") -startup_timer.record("import torch") +initialize_util.setup_logging() -import gradio # noqa: F401 -startup_timer.record("import gradio") +initialize.imports() -from modules import paths, timer, import_hook, errors # noqa: F401 -startup_timer.record("setup paths") - -import ldm.modules.encoders.modules # noqa: F401 -startup_timer.record("import ldm") - -from modules import shared_init, shared, shared_items -shared_init.initialize() -startup_timer.record("initialize shared") - -from modules import extra_networks -from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, queue_lock # noqa: F401 - -# Truncate version number of nightly/local build of PyTorch to not cause exceptions with CodeFormer or Safetensors -if ".dev" in torch.__version__ or "+git" in torch.__version__: - torch.__long_version__ = torch.__version__ - torch.__version__ = re.search(r'[\d.]+[\d]', torch.__version__).group(0) - -if not shared.cmd_opts.skip_version_check: - errors.check_versions() - -import modules.codeformer_model as codeformer -import modules.gfpgan_model as gfpgan -from modules import sd_samplers, upscaler, extensions, localization, ui_tempdir, ui_extra_networks, config_states -import modules.face_restoration -import modules.img2img - -import modules.lowvram -import modules.scripts -import modules.sd_hijack -import modules.sd_hijack_optimizations -import modules.sd_models -import modules.sd_vae -import modules.sd_unet -import modules.txt2img -import modules.script_callbacks -import modules.textual_inversion.textual_inversion -import modules.progress - -import modules.ui -from modules import modelloader, devices -from modules.shared import cmd_opts -import modules.hypernetworks.hypernetwork - -startup_timer.record("other imports") - - -if cmd_opts.server_name: - server_name = cmd_opts.server_name -else: - server_name = "0.0.0.0" if cmd_opts.listen else None - - -def fix_asyncio_event_loop_policy(): - """ - The default `asyncio` event loop policy only automatically creates - event loops in the main threads. Other threads must create event - loops explicitly or `asyncio.get_event_loop` (and therefore - `.IOLoop.current`) will fail. Installing this policy allows event - loops to be created automatically on any thread, matching the - behavior of Tornado versions prior to 5.0 (or 5.0 on Python 2). - """ - - import asyncio - - if sys.platform == "win32" and hasattr(asyncio, "WindowsSelectorEventLoopPolicy"): - # "Any thread" and "selector" should be orthogonal, but there's not a clean - # interface for composing policies so pick the right base. - _BasePolicy = asyncio.WindowsSelectorEventLoopPolicy # type: ignore - else: - _BasePolicy = asyncio.DefaultEventLoopPolicy - - class AnyThreadEventLoopPolicy(_BasePolicy): # type: ignore - """Event loop policy that allows loop creation on any thread. - Usage:: - - asyncio.set_event_loop_policy(AnyThreadEventLoopPolicy()) - """ - - def get_event_loop(self) -> asyncio.AbstractEventLoop: - try: - return super().get_event_loop() - except (RuntimeError, AssertionError): - # This was an AssertionError in python 3.4.2 (which ships with debian jessie) - # and changed to a RuntimeError in 3.4.3. - # "There is no current event loop in thread %r" - loop = self.new_event_loop() - self.set_event_loop(loop) - return loop - - asyncio.set_event_loop_policy(AnyThreadEventLoopPolicy()) - - -def restore_config_state_file(): - config_state_file = shared.opts.restore_config_state_file - if config_state_file == "": - return - - shared.opts.restore_config_state_file = "" - shared.opts.save(shared.config_filename) - - if os.path.isfile(config_state_file): - print(f"*** About to restore extension state from file: {config_state_file}") - with open(config_state_file, "r", encoding="utf-8") as f: - config_state = json.load(f) - config_states.restore_extension_config(config_state) - startup_timer.record("restore extension config") - elif config_state_file: - print(f"!!! Config state backup not found: {config_state_file}") - - -def validate_tls_options(): - if not (cmd_opts.tls_keyfile and cmd_opts.tls_certfile): - return - - try: - if not os.path.exists(cmd_opts.tls_keyfile): - print("Invalid path to TLS keyfile given") - if not os.path.exists(cmd_opts.tls_certfile): - print(f"Invalid path to TLS certfile: '{cmd_opts.tls_certfile}'") - except TypeError: - cmd_opts.tls_keyfile = cmd_opts.tls_certfile = None - print("TLS setup invalid, running webui without TLS") - else: - print("Running with TLS") - startup_timer.record("TLS") - - -def get_gradio_auth_creds() -> Iterable[tuple[str, ...]]: - """ - Convert the gradio_auth and gradio_auth_path commandline arguments into - an iterable of (username, password) tuples. - """ - def process_credential_line(s) -> tuple[str, ...] | None: - s = s.strip() - if not s: - return None - return tuple(s.split(':', 1)) - - if cmd_opts.gradio_auth: - for cred in cmd_opts.gradio_auth.split(','): - cred = process_credential_line(cred) - if cred: - yield cred - - if cmd_opts.gradio_auth_path: - with open(cmd_opts.gradio_auth_path, 'r', encoding="utf8") as file: - for line in file.readlines(): - for cred in line.strip().split(','): - cred = process_credential_line(cred) - if cred: - yield cred - - -def configure_sigint_handler(): - # make the program just exit at ctrl+c without waiting for anything - def sigint_handler(sig, frame): - print(f'Interrupted with signal {sig} in {frame}') - os._exit(0) - - if not os.environ.get("COVERAGE_RUN"): - # Don't install the immediate-quit handler when running under coverage, - # as then the coverage report won't be generated. - signal.signal(signal.SIGINT, sigint_handler) - - -def configure_opts_onchange(): - shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights()), call=False) - shared.opts.onchange("sd_vae", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False) - shared.opts.onchange("sd_vae_overrides_per_model_preferences", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False) - shared.opts.onchange("temp_dir", ui_tempdir.on_tmpdir_changed) - shared.opts.onchange("gradio_theme", shared.reload_gradio_theme) - shared.opts.onchange("cross_attention_optimization", wrap_queued_call(lambda: modules.sd_hijack.model_hijack.redo_hijack(shared.sd_model)), call=False) - startup_timer.record("opts onchange") - - -def initialize(): - fix_asyncio_event_loop_policy() - validate_tls_options() - configure_sigint_handler() - modelloader.cleanup_models() - configure_opts_onchange() - - modules.sd_models.setup_model() - startup_timer.record("setup SD model") - - codeformer.setup_model(cmd_opts.codeformer_models_path) - startup_timer.record("setup codeformer") - - gfpgan.setup_model(cmd_opts.gfpgan_models_path) - startup_timer.record("setup gfpgan") - - initialize_rest(reload_script_modules=False) - - -def initialize_rest(*, reload_script_modules=False): - """ - Called both from initialize() and when reloading the webui. - """ - sd_samplers.set_samplers() - extensions.list_extensions() - startup_timer.record("list extensions") - - restore_config_state_file() - - if cmd_opts.ui_debug_mode: - shared.sd_upscalers = upscaler.UpscalerLanczos().scalers - modules.scripts.load_scripts() - return - - modules.sd_models.list_models() - startup_timer.record("list SD models") - - localization.list_localizations(cmd_opts.localizations_dir) - - with startup_timer.subcategory("load scripts"): - modules.scripts.load_scripts() - - if reload_script_modules: - for module in [module for name, module in sys.modules.items() if name.startswith("modules.ui")]: - importlib.reload(module) - startup_timer.record("reload script modules") - - modelloader.load_upscalers() - startup_timer.record("load upscalers") - - modules.sd_vae.refresh_vae_list() - startup_timer.record("refresh VAE") - modules.textual_inversion.textual_inversion.list_textual_inversion_templates() - startup_timer.record("refresh textual inversion templates") - - modules.script_callbacks.on_list_optimizers(modules.sd_hijack_optimizations.list_optimizers) - modules.sd_hijack.list_optimizers() - startup_timer.record("scripts list_optimizers") - - modules.sd_unet.list_unets() - startup_timer.record("scripts list_unets") - - def load_model(): - """ - Accesses shared.sd_model property to load model. - After it's available, if it has been loaded before this access by some extension, - its optimization may be None because the list of optimizaers has neet been filled - by that time, so we apply optimization again. - """ - - shared.sd_model # noqa: B018 - - if modules.sd_hijack.current_optimizer is None: - modules.sd_hijack.apply_optimizations() - - devices.first_time_calculation() - - Thread(target=load_model).start() - - shared_items.reload_hypernetworks() - startup_timer.record("reload hypernetworks") - - ui_extra_networks.initialize() - ui_extra_networks.register_default_pages() - - extra_networks.initialize() - extra_networks.register_default_extra_networks() - startup_timer.record("initialize extra networks") - - -def setup_middleware(app): - app.middleware_stack = None # reset current middleware to allow modifying user provided list - app.add_middleware(GZipMiddleware, minimum_size=1000) - configure_cors_middleware(app) - app.build_middleware_stack() # rebuild middleware stack on-the-fly - - -def configure_cors_middleware(app): - cors_options = { - "allow_methods": ["*"], - "allow_headers": ["*"], - "allow_credentials": True, - } - if cmd_opts.cors_allow_origins: - cors_options["allow_origins"] = cmd_opts.cors_allow_origins.split(',') - if cmd_opts.cors_allow_origins_regex: - cors_options["allow_origin_regex"] = cmd_opts.cors_allow_origins_regex - app.add_middleware(CORSMiddleware, **cors_options) +initialize.check_versions() def create_api(app): from modules.api.api import Api + from modules.call_queue import queue_lock + api = Api(app, queue_lock) return api def api_only(): - initialize() + from fastapi import FastAPI + from modules.shared_cmd_options import cmd_opts + + initialize.initialize() app = FastAPI() - setup_middleware(app) + initialize_util.setup_middleware(app) api = create_api(app) - modules.script_callbacks.before_ui_callback() - modules.script_callbacks.app_started_callback(None, app) + from modules import script_callbacks + script_callbacks.before_ui_callback() + script_callbacks.app_started_callback(None, app) print(f"Startup time: {startup_timer.summary()}.") api.launch( @@ -354,24 +48,28 @@ def api_only(): def webui(): + from modules.shared_cmd_options import cmd_opts + launch_api = cmd_opts.api - initialize() + initialize.initialize() + + from modules import shared, ui_tempdir, script_callbacks, ui, progress, ui_extra_networks while 1: if shared.opts.clean_temp_dir_at_start: ui_tempdir.cleanup_tmpdr() startup_timer.record("cleanup temp dir") - modules.script_callbacks.before_ui_callback() + script_callbacks.before_ui_callback() startup_timer.record("scripts before_ui_callback") - shared.demo = modules.ui.create_ui() + shared.demo = ui.create_ui() startup_timer.record("create ui") if not cmd_opts.no_gradio_queue: shared.demo.queue(64) - gradio_auth_creds = list(get_gradio_auth_creds()) or None + gradio_auth_creds = list(initialize_util.get_gradio_auth_creds()) or None auto_launch_browser = False if os.getenv('SD_WEBUI_RESTARTING') != '1': @@ -382,7 +80,7 @@ def webui(): app, local_url, share_url = shared.demo.launch( share=cmd_opts.share, - server_name=server_name, + server_name=initialize_util.gradio_server_name(), server_port=cmd_opts.port, ssl_keyfile=cmd_opts.tls_keyfile, ssl_certfile=cmd_opts.tls_certfile, @@ -407,10 +105,10 @@ def webui(): # running its code. We disable this here. Suggested by RyotaK. app.user_middleware = [x for x in app.user_middleware if x.cls.__name__ != 'CORSMiddleware'] - setup_middleware(app) + initialize_util.setup_middleware(app) - modules.progress.setup_progress_api(app) - modules.ui.setup_ui_api(app) + progress.setup_progress_api(app) + ui.setup_ui_api(app) if launch_api: create_api(app) @@ -420,7 +118,7 @@ def webui(): startup_timer.record("add APIs") with startup_timer.subcategory("app_started_callback"): - modules.script_callbacks.app_started_callback(shared.demo, app) + script_callbacks.app_started_callback(shared.demo, app) timer.startup_record = startup_timer.dump() print(f"Startup time: {startup_timer.summary()}.") @@ -450,14 +148,16 @@ def webui(): shared.demo.close() time.sleep(0.5) startup_timer.reset() - modules.script_callbacks.app_reload_callback() + script_callbacks.app_reload_callback() startup_timer.record("app reload callback") - modules.script_callbacks.script_unloaded_callback() + script_callbacks.script_unloaded_callback() startup_timer.record("scripts unloaded callback") - initialize_rest(reload_script_modules=True) + initialize.initialize_rest(reload_script_modules=True) if __name__ == "__main__": + from modules.shared_cmd_options import cmd_opts + if cmd_opts.nowebui: api_only() else: