fix mps in gui-v1.py (#769)

* Fix mps on realtime

* Added back repeat chs
This commit is contained in:
Naozumi 2023-07-17 22:54:15 +08:00 committed by GitHub
parent 2e0dfeec50
commit 2b3fe8cf1b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 24 additions and 8 deletions

View File

@ -1,5 +1,8 @@
import os, sys
if sys.platform == "darwin":
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
now_dir = os.getcwd()
sys.path.append(now_dir)
import multiprocessing
@ -45,7 +48,7 @@ if __name__ == "__main__":
from i18n import I18nAuto
i18n = I18nAuto()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = torch.device("cuda" if torch.cuda.is_available() else ("mps" if torch.backends.mps.is_available() else "cpu"))
current_dir = os.getcwd()
inp_q = Queue()
opt_q = Queue()
@ -441,8 +444,9 @@ if __name__ == "__main__":
"""
接受音频输入
"""
channels = 1 if sys.platform == "darwin" else 2
with sd.Stream(
channels=2,
channels=channels,
callback=self.audio_callback,
blocksize=self.block_frame,
samplerate=self.config.samplerate,
@ -524,6 +528,9 @@ if __name__ == "__main__":
)
+ 1e-8
)
if sys.platform == "darwin":
cor_nom = cor_nom.cpu()
cor_den = cor_den.cpu()
sola_offset = torch.argmax(cor_nom[0, 0] / cor_den[0, 0])
print("sola offset: " + str(int(sola_offset)))
self.output_wav[:] = infer_wav[sola_offset : sola_offset + self.block_frame]
@ -545,14 +552,23 @@ if __name__ == "__main__":
infer_wav[-self.crossfade_frame :] * self.fade_out_window
)
if self.config.O_noise_reduce:
outdata[:] = np.tile(
nr.reduce_noise(
if sys.platform == "darwin":
noise_reduced_signal = nr.reduce_noise(
y=self.output_wav[:].cpu().numpy(), sr=self.config.samplerate
),
(2, 1),
).T
)
outdata[:] = noise_reduced_signal[:, np.newaxis]
else:
outdata[:] = np.tile(
nr.reduce_noise(
y=self.output_wav[:].cpu().numpy(), sr=self.config.samplerate
),
(2, 1),
).T
else:
outdata[:] = self.output_wav[:].repeat(2, 1).t().cpu().numpy()
if sys.platform == "darwin":
outdata[:] = self.output_wav[:].cpu().numpy()[:, np.newaxis]
else:
outdata[:] = self.output_wav[:].repeat(2, 1).t().cpu().numpy()
total_time = time.perf_counter() - start_time
self.window["infer_time"].update(int(total_time * 1000))
print("infer time:" + str(total_time))