From d35bf649456da2558cbb6f2ea16fa1606022b7e7 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Tue, 1 Nov 2022 14:19:24 +0300 Subject: [PATCH] make launch.py run installers for extensions that have ones add some more classes to safety module for an extension --- launch.py | 22 ++++++++++++++++++++-- modules/safe.py | 2 +- 2 files changed, 21 insertions(+), 3 deletions(-) diff --git a/launch.py b/launch.py index 958336f21..0d90d5539 100644 --- a/launch.py +++ b/launch.py @@ -7,6 +7,7 @@ import shlex import platform dir_repos = "repositories" +dir_extensions = "extensions" python = sys.executable git = os.environ.get('GIT', "git") index_url = os.environ.get('INDEX_URL', "") @@ -101,9 +102,24 @@ def version_check(commit): else: print("Not a git clone, can't perform version check.") except Exception as e: - print("versipm check failed",e) + print("version check failed", e) + + +def run_extensions_installers(): + if not os.path.isdir(dir_extensions): + return + + for dirname_extension in os.listdir(dir_extensions): + path_installer = os.path.join(dir_extensions, dirname_extension, "install.py") + if not os.path.isfile(path_installer): + continue + + try: + print(run(f'"{python}" "{path_installer}"', errdesc=f"Error running install.py for extension {dirname_extension}")) + except Exception as e: + print(e, file=sys.stderr) + - def prepare_enviroment(): torch_command = os.environ.get('TORCH_COMMAND', "pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 --extra-index-url https://download.pytorch.org/whl/cu113") requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt") @@ -189,6 +205,8 @@ def prepare_enviroment(): run_pip(f"install -r {requirements_file}", "requirements for Web UI") + run_extensions_installers() + if update_check: version_check(commit) diff --git a/modules/safe.py b/modules/safe.py index 399165a19..348a24fcd 100644 --- a/modules/safe.py +++ b/modules/safe.py @@ -32,7 +32,7 @@ class RestrictedUnpickler(pickle.Unpickler): return getattr(collections, name) if module == 'torch._utils' and name in ['_rebuild_tensor_v2', '_rebuild_parameter']: return getattr(torch._utils, name) - if module == 'torch' and name in ['FloatStorage', 'HalfStorage', 'IntStorage', 'LongStorage', 'DoubleStorage']: + if module == 'torch' and name in ['FloatStorage', 'HalfStorage', 'IntStorage', 'LongStorage', 'DoubleStorage', 'ByteStorage']: return getattr(torch, name) if module == 'torch.nn.modules.container' and name in ['ParameterDict']: return getattr(torch.nn.modules.container, name)