diff --git a/config.py b/config.py index eadbcee..959b2b7 100644 --- a/config.py +++ b/config.py @@ -58,12 +58,13 @@ class Config: cmd_opts.noparallel, cmd_opts.noautoopen, ) - + # has_mps is only available in nightly pytorch (for now) and MasOS 12.3+. # check `getattr` and try it for compatibility @staticmethod def has_mps() -> bool: - if not torch.backends.mps.is_available(): return False + if not torch.backends.mps.is_available(): + return False try: torch.zeros(1).to(torch.device("mps")) return True diff --git a/extract_feature_print.py b/extract_feature_print.py index 605d5dd..dfa74e2 100644 --- a/extract_feature_print.py +++ b/extract_feature_print.py @@ -1,6 +1,6 @@ import os, sys, traceback -os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1' +os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" # device=sys.argv[1] n_part = int(sys.argv[2]) diff --git a/infer-web.py b/infer-web.py index ca80dbd..7de75cc 100644 --- a/infer-web.py +++ b/infer-web.py @@ -43,7 +43,9 @@ logging.getLogger("numba").setLevel(logging.WARNING) tmp = os.path.join(now_dir, "TEMP") shutil.rmtree(tmp, ignore_errors=True) -shutil.rmtree("%s/runtime/Lib/site-packages/lib.infer_pack" % (now_dir), ignore_errors=True) +shutil.rmtree( + "%s/runtime/Lib/site-packages/lib.infer_pack" % (now_dir), ignore_errors=True +) shutil.rmtree("%s/runtime/Lib/site-packages/uvr5_pack" % (now_dir), ignore_errors=True) os.makedirs(tmp, exist_ok=True) os.makedirs(os.path.join(now_dir, "logs"), exist_ok=True) @@ -328,6 +330,7 @@ def uvr(model_name, inp_root, save_root_vocal, paths, save_root_ins, agg, format ) if model_name == "onnx_dereverb_By_FoxJoy": from MDXNet import MDXNetDereverb + pre_fun = MDXNetDereverb(15) else: func = _audio_pre_ if "DeEcho" not in model_name else _audio_pre_new diff --git a/lib/infer_pack/onnx_inference.py b/lib/infer_pack/onnx_inference.py index a590627..b4aba75 100644 --- a/lib/infer_pack/onnx_inference.py +++ b/lib/infer_pack/onnx_inference.py @@ -39,7 +39,9 @@ def get_f0_predictor(f0_predictor, hop_length, sampling_rate, **kargs): hop_length=hop_length, sampling_rate=sampling_rate ) elif f0_predictor == "harvest": - from lib.infer_pack.modules.F0Predictor.HarvestF0Predictor import HarvestF0Predictor + from lib.infer_pack.modules.F0Predictor.HarvestF0Predictor import ( + HarvestF0Predictor, + ) f0_predictor_object = HarvestF0Predictor( hop_length=hop_length, sampling_rate=sampling_rate