From 2b3fe8cf1bab92194dbde96b6d454032d312de6f Mon Sep 17 00:00:00 2001 From: Naozumi Date: Mon, 17 Jul 2023 22:54:15 +0800 Subject: [PATCH] fix mps in gui-v1.py (#769) * Fix mps on realtime * Added back repeat chs --- gui_v1.py | 32 ++++++++++++++++++++++++-------- 1 file changed, 24 insertions(+), 8 deletions(-) diff --git a/gui_v1.py b/gui_v1.py index 8aa999e..1391efe 100644 --- a/gui_v1.py +++ b/gui_v1.py @@ -1,5 +1,8 @@ import os, sys +if sys.platform == "darwin": + os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" + now_dir = os.getcwd() sys.path.append(now_dir) import multiprocessing @@ -45,7 +48,7 @@ if __name__ == "__main__": from i18n import I18nAuto i18n = I18nAuto() - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + device = torch.device("cuda" if torch.cuda.is_available() else ("mps" if torch.backends.mps.is_available() else "cpu")) current_dir = os.getcwd() inp_q = Queue() opt_q = Queue() @@ -441,8 +444,9 @@ if __name__ == "__main__": """ 接受音频输入 """ + channels = 1 if sys.platform == "darwin" else 2 with sd.Stream( - channels=2, + channels=channels, callback=self.audio_callback, blocksize=self.block_frame, samplerate=self.config.samplerate, @@ -524,6 +528,9 @@ if __name__ == "__main__": ) + 1e-8 ) + if sys.platform == "darwin": + cor_nom = cor_nom.cpu() + cor_den = cor_den.cpu() sola_offset = torch.argmax(cor_nom[0, 0] / cor_den[0, 0]) print("sola offset: " + str(int(sola_offset))) self.output_wav[:] = infer_wav[sola_offset : sola_offset + self.block_frame] @@ -545,14 +552,23 @@ if __name__ == "__main__": infer_wav[-self.crossfade_frame :] * self.fade_out_window ) if self.config.O_noise_reduce: - outdata[:] = np.tile( - nr.reduce_noise( + if sys.platform == "darwin": + noise_reduced_signal = nr.reduce_noise( y=self.output_wav[:].cpu().numpy(), sr=self.config.samplerate - ), - (2, 1), - ).T + ) + outdata[:] = noise_reduced_signal[:, np.newaxis] + else: + outdata[:] = np.tile( + nr.reduce_noise( + y=self.output_wav[:].cpu().numpy(), sr=self.config.samplerate + ), + (2, 1), + ).T else: - outdata[:] = self.output_wav[:].repeat(2, 1).t().cpu().numpy() + if sys.platform == "darwin": + outdata[:] = self.output_wav[:].cpu().numpy()[:, np.newaxis] + else: + outdata[:] = self.output_wav[:].repeat(2, 1).t().cpu().numpy() total_time = time.perf_counter() - start_time self.window["infer_time"].update(int(total_time * 1000)) print("infer time:" + str(total_time))