diff --git a/configs/config.py b/configs/config.py index 78bfdae..c8aaff6 100644 --- a/configs/config.py +++ b/configs/config.py @@ -58,6 +58,7 @@ class Config: self.noautoopen, self.dml, self.nocheck, + self.update, ) = self.arg_parse() self.instead = "" self.preprocess_per = 3.7 @@ -97,6 +98,9 @@ class Config: parser.add_argument( "--nocheck", action="store_true", help="Run without checking assets" ) + parser.add_argument( + "--update", action="store_true", help="Update to latest assets" + ) cmd_opts = parser.parse_args() cmd_opts.port = cmd_opts.port if 0 <= cmd_opts.port <= 65535 else 7865 @@ -109,6 +113,7 @@ class Config: cmd_opts.noautoopen, cmd_opts.dml, cmd_opts.nocheck, + cmd_opts.update, ) # has_mps is only available in nightly pytorch (for now) and MasOS 12.3+. diff --git a/gui_v1.py b/gui_v1.py index 2b30b00..8f910d9 100644 --- a/gui_v1.py +++ b/gui_v1.py @@ -156,11 +156,12 @@ if __name__ == "__main__": tmp = os.path.join(now_dir, "TEMP") shutil.rmtree(tmp, ignore_errors=True) os.makedirs(tmp, exist_ok=True) - if not check_all_assets(): - download_all_assets(tmpdir=tmp) - if not check_all_assets(): - printt("counld not satisfy all assets needed.") - exit(1) + if not check_all_assets(update=self.config.update): + if self.config.update: + download_all_assets(tmpdir=tmp) + if not check_all_assets(update=False): + printt("counld not satisfy all assets needed.") + exit(1) def load(self): try: diff --git a/infer-web.py b/infer-web.py index c38a8fc..bdf3357 100644 --- a/infer-web.py +++ b/infer-web.py @@ -57,11 +57,12 @@ vc = VC(config) if not config.nocheck: from infer.lib.rvcmd import check_all_assets, download_all_assets - if not check_all_assets(): - download_all_assets(tmpdir=tmp) - if not check_all_assets(): - logging.error("counld not satisfy all assets needed.") - exit(1) + if not check_all_assets(update=config.update): + if config.update: + download_all_assets(tmpdir=tmp) + if not check_all_assets(update=False): + logging.error("counld not satisfy all assets needed.") + exit(1) if config.dml == True: diff --git a/infer/lib/rvcmd.py b/infer/lib/rvcmd.py index 06971f8..56e9277 100644 --- a/infer/lib/rvcmd.py +++ b/infer/lib/rvcmd.py @@ -16,7 +16,7 @@ def sha256(f) -> str: return sha256_hash.hexdigest() -def check_model(dir_name: Path, model_name: str, hash: str) -> bool: +def check_model(dir_name: Path, model_name: str, hash: str, remove_incorrect=False) -> bool: target = dir_name / model_name relname = target.as_posix() relname = relname[relname.rindex("assets/") :] @@ -30,12 +30,12 @@ def check_model(dir_name: Path, model_name: str, hash: str) -> bool: logger.info(f"{target} sha256 hash mismatch.") logger.info(f"expected: {hash}") logger.info(f"real val: {digest}") - os.remove(str(target)) + if remove_incorrect: os.remove(str(target)) return False return True -def check_all_assets() -> bool: +def check_all_assets(update=False) -> bool: BASE_DIR = Path(__file__).resolve().parent.parent.parent logger.info("checking hubret & rmvpe...") @@ -44,14 +44,15 @@ def check_all_assets() -> bool: BASE_DIR / "assets" / "hubert", "hubert_base.pt", os.environ["sha256_hubert_base_pt"], + update, ): return False if not check_model( - BASE_DIR / "assets" / "rmvpe", "rmvpe.pt", os.environ["sha256_rmvpe_pt"] + BASE_DIR / "assets" / "rmvpe", "rmvpe.pt", os.environ["sha256_rmvpe_pt"], update, ): return False if not check_model( - BASE_DIR / "assets" / "rmvpe", "rmvpe.onnx", os.environ["sha256_rmvpe_onnx"] + BASE_DIR / "assets" / "rmvpe", "rmvpe.onnx", os.environ["sha256_rmvpe_onnx"], update, ): return False @@ -73,14 +74,14 @@ def check_all_assets() -> bool: ] for model in model_names: menv = model.replace(".", "_") - if not check_model(rvc_models_dir, model, os.environ[f"sha256_v1_{menv}"]): + if not check_model(rvc_models_dir, model, os.environ[f"sha256_v1_{menv}"], update): return False rvc_models_dir = BASE_DIR / "assets" / "pretrained_v2" logger.info("checking pretrained models v2...") for model in model_names: menv = model.replace(".", "_") - if not check_model(rvc_models_dir, model, os.environ[f"sha256_v2_{menv}"]): + if not check_model(rvc_models_dir, model, os.environ[f"sha256_v2_{menv}"], update): return False logger.info("checking uvr5_weights...") @@ -97,12 +98,12 @@ def check_all_assets() -> bool: ] for model in model_names: menv = model.replace(".", "_") - if not check_model(rvc_models_dir, model, os.environ[f"sha256_uvr5_{menv}"]): + if not check_model(rvc_models_dir, model, os.environ[f"sha256_uvr5_{menv}"], update): return False if not check_model( BASE_DIR / "assets" / "uvr5_weights" / "onnx_dereverb_By_FoxJoy", "vocals.onnx", - os.environ[f"sha256_uvr5_vocals_onnx"], + os.environ[f"sha256_uvr5_vocals_onnx"], update, ): return False