feat(rvcmd): add option --update

This commit is contained in:
源文雨 2024-04-23 03:16:27 +09:00
parent deca566ab2
commit 659f5d25b0
4 changed files with 27 additions and 19 deletions

View File

@ -58,6 +58,7 @@ class Config:
self.noautoopen, self.noautoopen,
self.dml, self.dml,
self.nocheck, self.nocheck,
self.update,
) = self.arg_parse() ) = self.arg_parse()
self.instead = "" self.instead = ""
self.preprocess_per = 3.7 self.preprocess_per = 3.7
@ -97,6 +98,9 @@ class Config:
parser.add_argument( parser.add_argument(
"--nocheck", action="store_true", help="Run without checking assets" "--nocheck", action="store_true", help="Run without checking assets"
) )
parser.add_argument(
"--update", action="store_true", help="Update to latest assets"
)
cmd_opts = parser.parse_args() cmd_opts = parser.parse_args()
cmd_opts.port = cmd_opts.port if 0 <= cmd_opts.port <= 65535 else 7865 cmd_opts.port = cmd_opts.port if 0 <= cmd_opts.port <= 65535 else 7865
@ -109,6 +113,7 @@ class Config:
cmd_opts.noautoopen, cmd_opts.noautoopen,
cmd_opts.dml, cmd_opts.dml,
cmd_opts.nocheck, cmd_opts.nocheck,
cmd_opts.update,
) )
# has_mps is only available in nightly pytorch (for now) and MasOS 12.3+. # has_mps is only available in nightly pytorch (for now) and MasOS 12.3+.

View File

@ -156,11 +156,12 @@ if __name__ == "__main__":
tmp = os.path.join(now_dir, "TEMP") tmp = os.path.join(now_dir, "TEMP")
shutil.rmtree(tmp, ignore_errors=True) shutil.rmtree(tmp, ignore_errors=True)
os.makedirs(tmp, exist_ok=True) os.makedirs(tmp, exist_ok=True)
if not check_all_assets(): if not check_all_assets(update=self.config.update):
download_all_assets(tmpdir=tmp) if self.config.update:
if not check_all_assets(): download_all_assets(tmpdir=tmp)
printt("counld not satisfy all assets needed.") if not check_all_assets(update=False):
exit(1) printt("counld not satisfy all assets needed.")
exit(1)
def load(self): def load(self):
try: try:

View File

@ -57,11 +57,12 @@ vc = VC(config)
if not config.nocheck: if not config.nocheck:
from infer.lib.rvcmd import check_all_assets, download_all_assets from infer.lib.rvcmd import check_all_assets, download_all_assets
if not check_all_assets(): if not check_all_assets(update=config.update):
download_all_assets(tmpdir=tmp) if config.update:
if not check_all_assets(): download_all_assets(tmpdir=tmp)
logging.error("counld not satisfy all assets needed.") if not check_all_assets(update=False):
exit(1) logging.error("counld not satisfy all assets needed.")
exit(1)
if config.dml == True: if config.dml == True:

View File

@ -16,7 +16,7 @@ def sha256(f) -> str:
return sha256_hash.hexdigest() return sha256_hash.hexdigest()
def check_model(dir_name: Path, model_name: str, hash: str) -> bool: def check_model(dir_name: Path, model_name: str, hash: str, remove_incorrect=False) -> bool:
target = dir_name / model_name target = dir_name / model_name
relname = target.as_posix() relname = target.as_posix()
relname = relname[relname.rindex("assets/") :] relname = relname[relname.rindex("assets/") :]
@ -30,12 +30,12 @@ def check_model(dir_name: Path, model_name: str, hash: str) -> bool:
logger.info(f"{target} sha256 hash mismatch.") logger.info(f"{target} sha256 hash mismatch.")
logger.info(f"expected: {hash}") logger.info(f"expected: {hash}")
logger.info(f"real val: {digest}") logger.info(f"real val: {digest}")
os.remove(str(target)) if remove_incorrect: os.remove(str(target))
return False return False
return True return True
def check_all_assets() -> bool: def check_all_assets(update=False) -> bool:
BASE_DIR = Path(__file__).resolve().parent.parent.parent BASE_DIR = Path(__file__).resolve().parent.parent.parent
logger.info("checking hubret & rmvpe...") logger.info("checking hubret & rmvpe...")
@ -44,14 +44,15 @@ def check_all_assets() -> bool:
BASE_DIR / "assets" / "hubert", BASE_DIR / "assets" / "hubert",
"hubert_base.pt", "hubert_base.pt",
os.environ["sha256_hubert_base_pt"], os.environ["sha256_hubert_base_pt"],
update,
): ):
return False return False
if not check_model( if not check_model(
BASE_DIR / "assets" / "rmvpe", "rmvpe.pt", os.environ["sha256_rmvpe_pt"] BASE_DIR / "assets" / "rmvpe", "rmvpe.pt", os.environ["sha256_rmvpe_pt"], update,
): ):
return False return False
if not check_model( if not check_model(
BASE_DIR / "assets" / "rmvpe", "rmvpe.onnx", os.environ["sha256_rmvpe_onnx"] BASE_DIR / "assets" / "rmvpe", "rmvpe.onnx", os.environ["sha256_rmvpe_onnx"], update,
): ):
return False return False
@ -73,14 +74,14 @@ def check_all_assets() -> bool:
] ]
for model in model_names: for model in model_names:
menv = model.replace(".", "_") menv = model.replace(".", "_")
if not check_model(rvc_models_dir, model, os.environ[f"sha256_v1_{menv}"]): if not check_model(rvc_models_dir, model, os.environ[f"sha256_v1_{menv}"], update):
return False return False
rvc_models_dir = BASE_DIR / "assets" / "pretrained_v2" rvc_models_dir = BASE_DIR / "assets" / "pretrained_v2"
logger.info("checking pretrained models v2...") logger.info("checking pretrained models v2...")
for model in model_names: for model in model_names:
menv = model.replace(".", "_") menv = model.replace(".", "_")
if not check_model(rvc_models_dir, model, os.environ[f"sha256_v2_{menv}"]): if not check_model(rvc_models_dir, model, os.environ[f"sha256_v2_{menv}"], update):
return False return False
logger.info("checking uvr5_weights...") logger.info("checking uvr5_weights...")
@ -97,12 +98,12 @@ def check_all_assets() -> bool:
] ]
for model in model_names: for model in model_names:
menv = model.replace(".", "_") menv = model.replace(".", "_")
if not check_model(rvc_models_dir, model, os.environ[f"sha256_uvr5_{menv}"]): if not check_model(rvc_models_dir, model, os.environ[f"sha256_uvr5_{menv}"], update):
return False return False
if not check_model( if not check_model(
BASE_DIR / "assets" / "uvr5_weights" / "onnx_dereverb_By_FoxJoy", BASE_DIR / "assets" / "uvr5_weights" / "onnx_dereverb_By_FoxJoy",
"vocals.onnx", "vocals.onnx",
os.environ[f"sha256_uvr5_vocals_onnx"], os.environ[f"sha256_uvr5_vocals_onnx"], update,
): ):
return False return False