diff --git a/infer/lib/rvcmd.py b/infer/lib/rvcmd.py index 2f916a5..db4c89f 100644 --- a/infer/lib/rvcmd.py +++ b/infer/lib/rvcmd.py @@ -7,17 +7,19 @@ import logging logger = logging.getLogger(__name__) + def sha256(f) -> str: sha256_hash = hashlib.sha256() # Read and update hash in chunks of 4M - for byte_block in iter(lambda: f.read(4*1024*1024), b""): + for byte_block in iter(lambda: f.read(4 * 1024 * 1024), b""): sha256_hash.update(byte_block) return sha256_hash.hexdigest() + def check_model(dir_name: Path, model_name: str, hash: str) -> bool: target = dir_name / model_name relname = str(target) - relname = relname[relname.rindex("assets/"):] + relname = relname[relname.rindex("assets/") :] logger.debug(f"checking {relname}...") if not os.path.exists(target): logger.info(f"{target} not exist.") @@ -31,16 +33,25 @@ def check_model(dir_name: Path, model_name: str, hash: str) -> bool: return False return True + def check_all_assets() -> bool: BASE_DIR = Path(__file__).resolve().parent.parent.parent logger.info("checking hubret & rmvpe...") - if not check_model(BASE_DIR / "assets/hubert", "hubert_base.pt", os.environ["sha256_hubert_base_pt"]): + if not check_model( + BASE_DIR / "assets/hubert", + "hubert_base.pt", + os.environ["sha256_hubert_base_pt"], + ): return False - if not check_model(BASE_DIR / "assets/rmvpe", "rmvpe.pt", os.environ["sha256_rmvpe_pt"]): + if not check_model( + BASE_DIR / "assets/rmvpe", "rmvpe.pt", os.environ["sha256_rmvpe_pt"] + ): return False - if not check_model(BASE_DIR / "assets/rmvpe", "rmvpe.onnx", os.environ["sha256_rmvpe_onnx"]): + if not check_model( + BASE_DIR / "assets/rmvpe", "rmvpe.onnx", os.environ["sha256_rmvpe_onnx"] + ): return False rvc_models_dir = BASE_DIR / "assets/pretrained" @@ -91,13 +102,16 @@ def check_all_assets() -> bool: BASE_DIR / "assets/uvr5_weights/onnx_dereverb_By_FoxJoy", "vocals.onnx", os.environ[f"sha256_uvr5_vocals_onnx"], - ): return False + ): + return False logger.info("all assets are already latest.") return True + def download_and_extract_tar_gz(url: str, folder: str): import tarfile + logger.info(f"downloading {url}") response = requests.get(url, stream=True) with BytesIO() as out_file: @@ -108,8 +122,10 @@ def download_and_extract_tar_gz(url: str, folder: str): tar.extractall(folder) logger.info(f"extracted into {folder}") + def download_and_extract_zip(url: str, folder: str): import zipfile + logger.info(f"downloading {url}") response = requests.get(url) with BytesIO() as out_file: @@ -120,14 +136,22 @@ def download_and_extract_zip(url: str, folder: str): zip_ref.extractall(folder) logger.info(f"extracted into {folder}") + def download_all_assets(tmpdir: str, version="0.2.1"): import subprocess import platform archs = { - "aarch64": "arm64", "armv8l": "arm64", "arm64": "arm64", - "x86": "386", "i386": "386", "i686": "386", "386": "386", - "x86_64": "amd64", "x64": "amd64", "amd64": "amd64", + "aarch64": "arm64", + "armv8l": "arm64", + "arm64": "arm64", + "x86": "386", + "i386": "386", + "i686": "386", + "386": "386", + "x86_64": "amd64", + "x64": "amd64", + "amd64": "amd64", } system_type = platform.system().lower() architecture = platform.machine().lower() @@ -139,7 +163,7 @@ def download_all_assets(tmpdir: str, version="0.2.1"): exit(1) BASE_URL = "https://github.com/RVC-Project/RVC-Models-Downloader/releases/download/" suffix = "zip" if is_win else "tar.gz" - RVCMD_URL = BASE_URL+f"v{version}/rvcmd_{system_type}_{architecture}.{suffix}" + RVCMD_URL = BASE_URL + f"v{version}/rvcmd_{system_type}_{architecture}.{suffix}" cmdfile = tmpdir + "/rvcmd" if is_win: download_and_extract_zip(RVCMD_URL, tmpdir)