mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-01-04 13:55:06 +08:00
make existing script loading and new preload code use same code for loading modules
limit extension preload scripts to just one file named preload.py
This commit is contained in:
parent
e5690d0bf2
commit
a1a376331c
@ -1,7 +1,6 @@
|
|||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import traceback
|
import traceback
|
||||||
from importlib.machinery import SourceFileLoader
|
|
||||||
|
|
||||||
import git
|
import git
|
||||||
|
|
||||||
@ -85,23 +84,3 @@ def list_extensions():
|
|||||||
extension = Extension(name=dirname, path=path, enabled=dirname not in shared.opts.disabled_extensions)
|
extension = Extension(name=dirname, path=path, enabled=dirname not in shared.opts.disabled_extensions)
|
||||||
extensions.append(extension)
|
extensions.append(extension)
|
||||||
|
|
||||||
|
|
||||||
def preload_extensions(parser):
|
|
||||||
if not os.path.isdir(extensions_dir):
|
|
||||||
return
|
|
||||||
|
|
||||||
for dirname in sorted(os.listdir(extensions_dir)):
|
|
||||||
path = os.path.join(extensions_dir, dirname)
|
|
||||||
if not os.path.isdir(path):
|
|
||||||
continue
|
|
||||||
for file in os.listdir(path):
|
|
||||||
if "preload.py" in file:
|
|
||||||
full_file = os.path.join(path, file)
|
|
||||||
print(f"Got preload file: {full_file}")
|
|
||||||
|
|
||||||
try:
|
|
||||||
ext = SourceFileLoader("preload", full_file).load_module()
|
|
||||||
parser = ext.preload(parser)
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Exception preloading script: {e}")
|
|
||||||
return parser
|
|
34
modules/script_loading.py
Normal file
34
modules/script_loading.py
Normal file
@ -0,0 +1,34 @@
|
|||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import traceback
|
||||||
|
from types import ModuleType
|
||||||
|
|
||||||
|
|
||||||
|
def load_module(path):
|
||||||
|
with open(path, "r", encoding="utf8") as file:
|
||||||
|
text = file.read()
|
||||||
|
|
||||||
|
compiled = compile(text, path, 'exec')
|
||||||
|
module = ModuleType(os.path.basename(path))
|
||||||
|
exec(compiled, module.__dict__)
|
||||||
|
|
||||||
|
return module
|
||||||
|
|
||||||
|
|
||||||
|
def preload_extensions(extensions_dir, parser):
|
||||||
|
if not os.path.isdir(extensions_dir):
|
||||||
|
return
|
||||||
|
|
||||||
|
for dirname in sorted(os.listdir(extensions_dir)):
|
||||||
|
preload_script = os.path.join(extensions_dir, dirname, "preload.py")
|
||||||
|
if not os.path.isfile(preload_script):
|
||||||
|
continue
|
||||||
|
|
||||||
|
try:
|
||||||
|
module = load_module(preload_script)
|
||||||
|
if hasattr(module, 'preload'):
|
||||||
|
module.preload(parser)
|
||||||
|
|
||||||
|
except Exception:
|
||||||
|
print(f"Error running preload() for {preload_script}", file=sys.stderr)
|
||||||
|
print(traceback.format_exc(), file=sys.stderr)
|
@ -6,7 +6,7 @@ from collections import namedtuple
|
|||||||
import gradio as gr
|
import gradio as gr
|
||||||
|
|
||||||
from modules.processing import StableDiffusionProcessing
|
from modules.processing import StableDiffusionProcessing
|
||||||
from modules import shared, paths, script_callbacks, extensions
|
from modules import shared, paths, script_callbacks, extensions, script_loading
|
||||||
|
|
||||||
AlwaysVisible = object()
|
AlwaysVisible = object()
|
||||||
|
|
||||||
@ -161,13 +161,7 @@ def load_scripts():
|
|||||||
sys.path = [scriptfile.basedir] + sys.path
|
sys.path = [scriptfile.basedir] + sys.path
|
||||||
current_basedir = scriptfile.basedir
|
current_basedir = scriptfile.basedir
|
||||||
|
|
||||||
with open(scriptfile.path, "r", encoding="utf8") as file:
|
module = script_loading.load_module(scriptfile.path)
|
||||||
text = file.read()
|
|
||||||
|
|
||||||
from types import ModuleType
|
|
||||||
compiled = compile(text, scriptfile.path, 'exec')
|
|
||||||
module = ModuleType(scriptfile.filename)
|
|
||||||
exec(compiled, module.__dict__)
|
|
||||||
|
|
||||||
for key, script_class in module.__dict__.items():
|
for key, script_class in module.__dict__.items():
|
||||||
if type(script_class) == type and issubclass(script_class, Script):
|
if type(script_class) == type and issubclass(script_class, Script):
|
||||||
@ -328,19 +322,13 @@ class ScriptRunner:
|
|||||||
|
|
||||||
def reload_sources(self, cache):
|
def reload_sources(self, cache):
|
||||||
for si, script in list(enumerate(self.scripts)):
|
for si, script in list(enumerate(self.scripts)):
|
||||||
with open(script.filename, "r", encoding="utf8") as file:
|
|
||||||
args_from = script.args_from
|
args_from = script.args_from
|
||||||
args_to = script.args_to
|
args_to = script.args_to
|
||||||
filename = script.filename
|
filename = script.filename
|
||||||
text = file.read()
|
|
||||||
|
|
||||||
from types import ModuleType
|
|
||||||
|
|
||||||
module = cache.get(filename, None)
|
module = cache.get(filename, None)
|
||||||
if module is None:
|
if module is None:
|
||||||
compiled = compile(text, filename, 'exec')
|
module = script_loading.load_module(script.filename)
|
||||||
module = ModuleType(script.filename)
|
|
||||||
exec(compiled, module.__dict__)
|
|
||||||
cache[filename] = module
|
cache[filename] = module
|
||||||
|
|
||||||
for key, script_class in module.__dict__.items():
|
for key, script_class in module.__dict__.items():
|
||||||
|
@ -3,7 +3,6 @@ import datetime
|
|||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
from collections import OrderedDict
|
|
||||||
import time
|
import time
|
||||||
|
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
@ -15,7 +14,7 @@ import modules.memmon
|
|||||||
import modules.sd_models
|
import modules.sd_models
|
||||||
import modules.styles
|
import modules.styles
|
||||||
import modules.devices as devices
|
import modules.devices as devices
|
||||||
from modules import sd_samplers, sd_models, localization, sd_vae, extensions
|
from modules import sd_samplers, sd_models, localization, sd_vae, extensions, script_loading
|
||||||
from modules.hypernetworks import hypernetwork
|
from modules.hypernetworks import hypernetwork
|
||||||
from modules.paths import models_path, script_path, sd_path
|
from modules.paths import models_path, script_path, sd_path
|
||||||
|
|
||||||
@ -91,7 +90,7 @@ parser.add_argument("--tls-keyfile", type=str, help="Partially enables TLS, requ
|
|||||||
parser.add_argument("--tls-certfile", type=str, help="Partially enables TLS, requires --tls-keyfile to fully function", default=None)
|
parser.add_argument("--tls-certfile", type=str, help="Partially enables TLS, requires --tls-keyfile to fully function", default=None)
|
||||||
parser.add_argument("--server-name", type=str, help="Sets hostname of server", default=None)
|
parser.add_argument("--server-name", type=str, help="Sets hostname of server", default=None)
|
||||||
|
|
||||||
extensions.preload_extensions(parser)
|
script_loading.preload_extensions(extensions.extensions_dir, parser)
|
||||||
|
|
||||||
cmd_opts = parser.parse_args()
|
cmd_opts = parser.parse_args()
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user