diff --git a/infer/modules/uvr5/modules.py b/infer/modules/uvr5/modules.py
index bce3cef..2abb62d 100644
--- a/infer/modules/uvr5/modules.py
+++ b/infer/modules/uvr5/modules.py
@@ -105,4 +105,7 @@ def uvr(model_name, inp_root, save_root_vocal, paths, save_root_ins, agg, format
         if torch.cuda.is_available():
             torch.cuda.empty_cache()
             logger.info("Executed torch.cuda.empty_cache()")
+        elif torch.backends.mps.is_available():
+            torch.mps.empty_cache()
+            logger.info("Executed torch.mps.empty_cache()")
     yield "\n".join(infos)
diff --git a/infer/modules/vc/modules.py b/infer/modules/vc/modules.py
index 6f695cc..9069072 100644
--- a/infer/modules/vc/modules.py
+++ b/infer/modules/vc/modules.py
@@ -62,6 +62,8 @@ class VC:
                 ) = None
                 if torch.cuda.is_available():
                     torch.cuda.empty_cache()
+                elif torch.backends.mps.is_available():
+                    torch.mps.empty_cache()
             ###楼下不这么折腾清理不干净
             self.if_f0 = self.cpt.get("f0", 1)
             self.version = self.cpt.get("version", "v1")
@@ -82,18 +84,12 @@ class VC:
             del self.net_g, self.cpt
             if torch.cuda.is_available():
                 torch.cuda.empty_cache()
+            elif torch.backends.mps.is_available():
+                torch.mps.empty_cache()
             return (
                 {"visible": False, "__type__": "update"},
-                {
-                    "visible": True,
-                    "value": to_return_protect0,
-                    "__type__": "update",
-                },
-                {
-                    "visible": True,
-                    "value": to_return_protect1,
-                    "__type__": "update",
-                },
+                to_return_protect0,
+                to_return_protect1,
                 "",
                 "",
             )
diff --git a/infer/modules/vc/pipeline.py b/infer/modules/vc/pipeline.py
index 371836b..4b6e39e 100644
--- a/infer/modules/vc/pipeline.py
+++ b/infer/modules/vc/pipeline.py
@@ -291,6 +291,8 @@ class Pipeline(object):
         del feats, p_len, padding_mask
         if torch.cuda.is_available():
             torch.cuda.empty_cache()
+        elif torch.backends.mps.is_available():
+            torch.mps.empty_cache()
         t2 = ttime()
         times[0] += t1 - t0
         times[2] += t2 - t1
@@ -472,4 +474,6 @@ class Pipeline(object):
         del pitch, pitchf, sid
         if torch.cuda.is_available():
             torch.cuda.empty_cache()
+        elif torch.backends.mps.is_available():
+            torch.mps.empty_cache()
         return audio_opt