From d255818097002efb5be84b977a080b88dc105b09 Mon Sep 17 00:00:00 2001 From: Ftps <63702646+Tps-F@users.noreply.github.com> Date: Sun, 5 May 2024 15:22:37 +0900 Subject: [PATCH] Fix memory doesn't unload on mps device (#2035) * Solving the cache not being cleared in mps * Fix protect not to be dict --- infer/modules/uvr5/modules.py | 3 +++ infer/modules/vc/modules.py | 16 ++++++---------- infer/modules/vc/pipeline.py | 4 ++++ 3 files changed, 13 insertions(+), 10 deletions(-) diff --git a/infer/modules/uvr5/modules.py b/infer/modules/uvr5/modules.py index bce3cef..2abb62d 100644 --- a/infer/modules/uvr5/modules.py +++ b/infer/modules/uvr5/modules.py @@ -105,4 +105,7 @@ def uvr(model_name, inp_root, save_root_vocal, paths, save_root_ins, agg, format if torch.cuda.is_available(): torch.cuda.empty_cache() logger.info("Executed torch.cuda.empty_cache()") + elif torch.backends.mps.is_available(): + torch.mps.empty_cache() + logger.info("Executed torch.mps.empty_cache()") yield "\n".join(infos) diff --git a/infer/modules/vc/modules.py b/infer/modules/vc/modules.py index 6f695cc..9069072 100644 --- a/infer/modules/vc/modules.py +++ b/infer/modules/vc/modules.py @@ -62,6 +62,8 @@ class VC: ) = None if torch.cuda.is_available(): torch.cuda.empty_cache() + elif torch.backends.mps.is_available(): + torch.mps.empty_cache() ###楼下不这么折腾清理不干净 self.if_f0 = self.cpt.get("f0", 1) self.version = self.cpt.get("version", "v1") @@ -82,18 +84,12 @@ class VC: del self.net_g, self.cpt if torch.cuda.is_available(): torch.cuda.empty_cache() + elif torch.backends.mps.is_available(): + torch.mps.empty_cache() return ( {"visible": False, "__type__": "update"}, - { - "visible": True, - "value": to_return_protect0, - "__type__": "update", - }, - { - "visible": True, - "value": to_return_protect1, - "__type__": "update", - }, + to_return_protect0, + to_return_protect1, "", "", ) diff --git a/infer/modules/vc/pipeline.py b/infer/modules/vc/pipeline.py index 371836b..4b6e39e 100644 --- a/infer/modules/vc/pipeline.py +++ b/infer/modules/vc/pipeline.py @@ -291,6 +291,8 @@ class Pipeline(object): del feats, p_len, padding_mask if torch.cuda.is_available(): torch.cuda.empty_cache() + elif torch.backends.mps.is_available(): + torch.mps.empty_cache() t2 = ttime() times[0] += t1 - t0 times[2] += t2 - t1 @@ -472,4 +474,6 @@ class Pipeline(object): del pitch, pitchf, sid if torch.cuda.is_available(): torch.cuda.empty_cache() + elif torch.backends.mps.is_available(): + torch.mps.empty_cache() return audio_opt