From c3e65cdf96c3cb3ca1897ef2b69a74266ca69759 Mon Sep 17 00:00:00 2001
From: yxlllc <33565655+yxlllc@users.noreply.github.com>
Date: Tue, 16 Jan 2024 19:22:55 +0800
Subject: [PATCH] optimize: realtime inference (#1693)

* update real-time gui

* update real-time gui

* update real-time gui
---
 configs/config.json       |   2 +-
 gui_v1.py                 | 170 +++++++++++++++++++++++++-------
 i18n/locale/en_US.json    |   2 +-
 i18n/locale/es_ES.json    |   2 +-
 i18n/locale/fr_FR.json    |   2 +-
 i18n/locale/it_IT.json    |   2 +-
 i18n/locale/ja_JP.json    |   2 +-
 i18n/locale/ru_RU.json    |   2 +-
 i18n/locale/tr_TR.json    |   2 +-
 i18n/locale/zh_CN.json    |   4 +-
 i18n/locale/zh_HK.json    |   2 +-
 i18n/locale/zh_SG.json    |   2 +-
 i18n/locale/zh_TW.json    |   2 +-
 tools/rvc_for_realtime.py |  43 ++++------
 14 files changed, 143 insertions(+), 96 deletions(-)

diff --git a/configs/config.json b/configs/config.json
index f874bd5..3c99324 100644
--- a/configs/config.json
+++ b/configs/config.json
@@ -1 +1 @@
-{"pth_path": "assets/weights/kikiV1.pth", "index_path": "logs/kikiV1.index", "sg_input_device": "VoiceMeeter Output (VB-Audio Vo (MME)", "sg_output_device": "VoiceMeeter Input (VB-Audio Voi (MME)", "sr_type": "sr_model", "threhold": -60.0, "pitch": 12.0, "rms_mix_rate": 0.5, "index_rate": 0.0, "block_time": 0.2, "crossfade_length": 0.08, "extra_time": 2.00, "n_cpu": 4.0, "use_jit": false, "use_pv": false, "f0method": "fcpe"}
\ No newline at end of file
+{"pth_path": "assets/weights/kikiV1.pth", "index_path": "logs/kikiV1.index", "sg_hostapi": "MME", "sg_wasapi_exclusive": false, "sg_input_device": "VoiceMeeter Output (VB-Audio Vo", "sg_output_device": "VoiceMeeter Input (VB-Audio Voi", "sr_type": "sr_device", "threhold": -60.0, "pitch": 12.0, "rms_mix_rate": 0.5, "index_rate": 0.0, "block_time": 0.15, "crossfade_length": 0.08, "extra_time": 2.0, "n_cpu": 4.0, "use_jit": false, "use_pv": false, "f0method": "fcpe"}
\ No newline at end of file
diff --git a/gui_v1.py b/gui_v1.py
index ca5219e..2ce9a10 100644
--- a/gui_v1.py
+++ b/gui_v1.py
@@ -125,6 +125,8 @@ if __name__ == "__main__":
             self.index_rate: float = 0.0
             self.n_cpu: int = min(n_cpu, 4)
             self.f0method: str = "fcpe"
+            self.sg_hostapi: str = ""
+            self.wasapi_exclusive: bool = False
             self.sg_input_device: str = ""
             self.sg_output_device: str = ""

@@ -134,6 +136,7 @@ if __name__ == "__main__":
             self.config = Config()
             self.function = "vc"
             self.delay_time = 0
+            self.hostapis = None
             self.input_devices = None
             self.output_devices = None
             self.input_devices_indices = None
@@ -153,11 +156,22 @@ if __name__ == "__main__":
                     data["crepe"] = data["f0method"] == "crepe"
                     data["rmvpe"] = data["f0method"] == "rmvpe"
                     data["fcpe"] = data["f0method"] == "fcpe"
-                    if data["sg_input_device"] not in self.input_devices:
+                    if data["sg_hostapi"] in self.hostapis:
+                        self.update_devices(hostapi_name=data["sg_hostapi"])
+                        if data["sg_input_device"] not in self.input_devices or data["sg_output_device"] not in self.output_devices:
+                            self.update_devices()
+                            data["sg_hostapi"] = self.hostapis[0]
+                            data["sg_input_device"] = self.input_devices[
+                                self.input_devices_indices.index(sd.default.device[0])
+                            ]
+                            data["sg_output_device"] = self.output_devices[
+                                self.output_devices_indices.index(sd.default.device[1])
+                            ]
+                    else:
+                        data["sg_hostapi"] = self.hostapis[0]
                         data["sg_input_device"] = self.input_devices[
                             self.input_devices_indices.index(sd.default.device[0])
                         ]
-                    if data["sg_output_device"] not in self.output_devices:
                         data["sg_output_device"] = self.output_devices[
                             self.output_devices_indices.index(sd.default.device[1])
                         ]
@@ -166,6 +180,8 @@ if __name__ == "__main__":
__name__ == "__main__": data = { "pth_path": "", "index_path": "", + "sg_hostapi": self.hostapis[0], + "sg_wasapi_exclusive": False, "sg_input_device": self.input_devices[ self.input_devices_indices.index(sd.default.device[0]) ], @@ -233,12 +249,30 @@ if __name__ == "__main__": [ sg.Frame( layout=[ + [ + sg.Text(i18n("设备类型")), + sg.Combo( + self.hostapis, + key="sg_hostapi", + default_value=data.get("sg_hostapi", ""), + enable_events=True, + size=(20, 1) + ), + sg.Checkbox( + i18n("独占 WASAPI 设备"), + key="sg_wasapi_exclusive", + default=data.get("sg_wasapi_exclusive", False), + enable_events=True, + ), + ], [ sg.Text(i18n("输入设备")), sg.Combo( self.input_devices, key="sg_input_device", default_value=data.get("sg_input_device", ""), + enable_events=True, + size=(45, 1), ), ], [ @@ -247,6 +281,8 @@ if __name__ == "__main__": self.output_devices, key="sg_output_device", default_value=data.get("sg_output_device", ""), + enable_events=True, + size=(45, 1), ), ], [ @@ -269,7 +305,7 @@ if __name__ == "__main__": sg.Text("", key="sr_stream"), ], ], - title=i18n("音频设备(请使用同种类驱动)"), + title=i18n("音频设备"), ) ], [ @@ -365,7 +401,7 @@ if __name__ == "__main__": [ sg.Text(i18n("采样长度")), sg.Slider( - range=(0.02, 2.4), + range=(0.02, 1.5), key="block_time", resolution=0.01, orientation="h", @@ -481,8 +517,15 @@ if __name__ == "__main__": if event == sg.WINDOW_CLOSED: self.stop_stream() exit() - if event == "reload_devices": - self.update_devices() + if event == "reload_devices" or event == "sg_hostapi": + self.gui_config.sg_hostapi = values["sg_hostapi"] + self.update_devices(hostapi_name=values["sg_hostapi"]) + if self.gui_config.sg_hostapi not in self.hostapis: + self.gui_config.sg_hostapi = self.hostapis[0] + self.window["sg_hostapi"].Update(values=self.hostapis) + self.window["sg_hostapi"].Update( + value=self.gui_config.sg_hostapi + ) if self.gui_config.sg_input_device not in self.input_devices: self.gui_config.sg_input_device = self.input_devices[0] self.window["sg_input_device"].Update(values=self.input_devices) @@ -502,6 +545,8 @@ if __name__ == "__main__": settings = { "pth_path": values["pth_path"], "index_path": values["index_path"], + "sg_hostapi": values["sg_hostapi"], + "sg_wasapi_exclusive": values["sg_wasapi_exclusive"], "sg_input_device": values["sg_input_device"], "sg_output_device": values["sg_output_device"], "sr_type": ["sr_model", "sr_device"][ @@ -544,7 +589,7 @@ if __name__ == "__main__": if values["I_noise_reduce"]: self.delay_time += min(values["crossfade_length"], 0.04) self.window["sr_stream"].update(self.gui_config.samplerate) - self.window["delay_time"].update(int(self.delay_time * 1000)) + self.window["delay_time"].update(int(np.round(self.delay_time * 1000))) # Parameter hot update if event == "threhold": self.gui_config.threhold = values["threhold"] @@ -566,7 +611,7 @@ if __name__ == "__main__": self.delay_time += ( 1 if values["I_noise_reduce"] else -1 ) * min(values["crossfade_length"], 0.04) - self.window["delay_time"].update(int(self.delay_time * 1000)) + self.window["delay_time"].update(int(np.round(self.delay_time * 1000))) elif event == "O_noise_reduce": self.gui_config.O_noise_reduce = values["O_noise_reduce"] elif event == "use_pv": @@ -594,6 +639,8 @@ if __name__ == "__main__": self.set_devices(values["sg_input_device"], values["sg_output_device"]) self.config.use_jit = False # values["use_jit"] # self.device_latency = values["device_latency"] + self.gui_config.sg_hostapi = values["sg_hostapi"] + self.gui_config.sg_wasapi_exclusive = values["sg_wasapi_exclusive"] 
             self.gui_config.sg_input_device = values["sg_input_device"]
             self.gui_config.sg_output_device = values["sg_output_device"]
             self.gui_config.pth_path = values["pth_path"]
@@ -644,6 +691,7 @@ if __name__ == "__main__":
                 if self.gui_config.sr_type == "sr_model"
                 else self.get_device_samplerate()
             )
+            self.gui_config.channels = self.get_device_channels()
             self.zc = self.gui_config.samplerate // 100
             self.block_frame = (
                 int(
@@ -686,19 +734,18 @@ if __name__ == "__main__":
                 device=self.config.device,
                 dtype=torch.float32,
             )
+            self.input_wav_denoise: torch.Tensor = self.input_wav.clone()
             self.input_wav_res: torch.Tensor = torch.zeros(
                 160 * self.input_wav.shape[0] // self.zc,
                 device=self.config.device,
                 dtype=torch.float32,
             )
+            self.rms_buffer: np.ndarray = np.zeros(4 * self.zc, dtype="float32")
             self.sola_buffer: torch.Tensor = torch.zeros(
                 self.sola_buffer_frame, device=self.config.device, dtype=torch.float32
             )
             self.nr_buffer: torch.Tensor = self.sola_buffer.clone()
             self.output_buffer: torch.Tensor = self.input_wav.clone()
-            self.res_buffer: torch.Tensor = torch.zeros(
-                2 * self.zc, device=self.config.device, dtype=torch.float32
-            )
             self.skip_head = self.extra_frame // self.zc
             self.return_length = (
                 self.block_frame + self.sola_buffer_frame + self.sola_search_frame
@@ -740,13 +787,17 @@ if __name__ == "__main__":
             global flag_vc
             if not flag_vc:
                 flag_vc = True
-                channels = 1 if sys.platform == "darwin" else 2
+                if 'WASAPI' in self.gui_config.sg_hostapi and self.gui_config.sg_wasapi_exclusive:
+                    extra_settings = sd.WasapiSettings(exclusive=True)
+                else:
+                    extra_settings = None
                 self.stream = sd.Stream(
-                    channels=channels,
                     callback=self.audio_callback,
                     blocksize=self.block_frame,
                     samplerate=self.gui_config.samplerate,
+                    channels=self.gui_config.channels,
                     dtype="float32",
+                    extra_settings=extra_settings,
                 )
                 self.stream.start()

@@ -755,7 +806,7 @@ if __name__ == "__main__":
             if flag_vc:
                 flag_vc = False
                 if self.stream is not None:
-                    self.stream.stop()
+                    self.stream.abort()
                     self.stream.close()
                     self.stream = None

@@ -769,47 +820,51 @@ if __name__ == "__main__":
             start_time = time.perf_counter()
             indata = librosa.to_mono(indata.T)
             if self.gui_config.threhold > -60:
+                indata = np.append(self.rms_buffer, indata)
                 rms = librosa.feature.rms(
                     y=indata, frame_length=4 * self.zc, hop_length=self.zc
-                )
+                )[:, 2:]
+                self.rms_buffer[:] = indata[-4 * self.zc :]
+                indata = indata[2 * self.zc - self.zc // 2 :]
                 db_threhold = (
                     librosa.amplitude_to_db(rms, ref=1.0)[0] < self.gui_config.threhold
                 )
                 for i in range(db_threhold.shape[0]):
                     if db_threhold[i]:
                         indata[i * self.zc : (i + 1) * self.zc] = 0
+                indata = indata[self.zc // 2 :]
             self.input_wav[: -self.block_frame] = self.input_wav[
                 self.block_frame :
             ].clone()
-            self.input_wav[-self.block_frame :] = torch.from_numpy(indata).to(
+            self.input_wav[-indata.shape[0] :] = torch.from_numpy(indata).to(
                 self.config.device
             )
             self.input_wav_res[: -self.block_frame_16k] = self.input_wav_res[
                 self.block_frame_16k :
             ].clone()
             # input noise reduction and resampling
-            if self.gui_config.I_noise_reduce and self.function == "vc":
+            if self.gui_config.I_noise_reduce:
+                self.input_wav_denoise[: -self.block_frame] = self.input_wav_denoise[
+                    self.block_frame :
+                ].clone()
                 input_wav = self.input_wav[
-                    -self.sola_buffer_frame - self.block_frame - 2 * self.zc :
+                    -self.sola_buffer_frame - self.block_frame :
                 ]
                 input_wav = self.tg(
                     input_wav.unsqueeze(0), self.input_wav.unsqueeze(0)
-                )[0, 2 * self.zc :]
+                ).squeeze(0)
                 input_wav[: self.sola_buffer_frame] *= self.fade_in_window
                 input_wav[: self.sola_buffer_frame] += (
                     self.nr_buffer * self.fade_out_window
                 )
+                self.input_wav_denoise[-self.block_frame :] = input_wav[: self.block_frame]
                 self.nr_buffer[:] = input_wav[self.block_frame :]
-                input_wav = torch.cat(
-                    (self.res_buffer[:], input_wav[: self.block_frame])
-                )
-                self.res_buffer[:] = input_wav[-2 * self.zc :]
                 self.input_wav_res[-self.block_frame_16k - 160 :] = self.resampler(
-                    input_wav
+                    self.input_wav_denoise[-self.block_frame - 2 * self.zc :]
                 )[160:]
             else:
-                self.input_wav_res[-self.block_frame_16k - 160 :] = self.resampler(
-                    self.input_wav[-self.block_frame - 2 * self.zc :]
+                self.input_wav_res[-160 * (indata.shape[0] // self.zc + 1) :] = self.resampler(
+                    self.input_wav[-indata.shape[0] - 2 * self.zc :]
                 )[160:]
             # infer
             if self.function == "vc":
@@ -822,14 +877,12 @@ if __name__ == "__main__":
                 )
                 if self.resampler2 is not None:
                     infer_wav = self.resampler2(infer_wav)
+            elif self.gui_config.I_noise_reduce:
+                infer_wav = self.input_wav_denoise[self.extra_frame :].clone()
             else:
-                infer_wav = self.input_wav[
-                    -self.crossfade_frame - self.sola_search_frame - self.block_frame :
-                ].clone()
+                infer_wav = self.input_wav[self.extra_frame :].clone()
             # output noise reduction
-            if (self.gui_config.O_noise_reduce and self.function == "vc") or (
-                self.gui_config.I_noise_reduce and self.function == "im"
-            ):
+            if self.gui_config.O_noise_reduce and self.function == "vc":
                 self.output_buffer[: -self.block_frame] = self.output_buffer[
                     self.block_frame :
                 ].clone()
@@ -839,16 +892,14 @@ if __name__ == "__main__":
                 ).squeeze(0)
             # volume envelop mixing
             if self.gui_config.rms_mix_rate < 1 and self.function == "vc":
+                if self.gui_config.I_noise_reduce:
+                    input_wav = self.input_wav_denoise[self.extra_frame :]
+                else:
+                    input_wav = self.input_wav[self.extra_frame :]
                 rms1 = librosa.feature.rms(
-                    y=self.input_wav_res[
-                        160
-                        * self.skip_head : 160
-                        * (self.skip_head + self.return_length)
-                    ]
-                    .cpu()
-                    .numpy(),
-                    frame_length=640,
-                    hop_length=160,
+                    y=input_wav[: infer_wav.shape[0]].cpu().numpy(),
+                    frame_length=4 * self.zc,
+                    hop_length=self.zc,
                 )
                 rms1 = torch.from_numpy(rms1).to(self.config.device)
                 rms1 = F.interpolate(
@@ -907,19 +958,18 @@ if __name__ == "__main__":
             self.sola_buffer[:] = infer_wav[
                 self.block_frame : self.block_frame + self.sola_buffer_frame
             ]
-            if sys.platform == "darwin":
-                outdata[:] = infer_wav[: self.block_frame].cpu().numpy()[:, np.newaxis]
-            else:
-                outdata[:] = (
-                    infer_wav[: self.block_frame].repeat(2, 1).t().cpu().numpy()
-                )
+            outdata[:] = (
+                infer_wav[: self.block_frame].repeat(self.gui_config.channels, 1).t().cpu().numpy()
+            )
             total_time = time.perf_counter() - start_time
             if flag_vc:
                 self.window["infer_time"].update(int(total_time * 1000))
                 printt("Infer time: %.2f", total_time)
-
-        def update_devices(self):
+
+        def update_devices(self, hostapi_name=None):
             """获取设备列表"""
+            global flag_vc
+            flag_vc = False
             sd._terminate()
             sd._initialize()
             devices = sd.query_devices()
@@ -927,25 +977,26 @@ if __name__ == "__main__":
             for hostapi in hostapis:
                 for device_idx in hostapi["devices"]:
                     devices[device_idx]["hostapi_name"] = hostapi["name"]
+            self.hostapis = [hostapi["name"] for hostapi in hostapis]
+            if hostapi_name not in self.hostapis:
+                hostapi_name = self.hostapis[0]
             self.input_devices = [
-                f"{d['name']} ({d['hostapi_name']})"
-                for d in devices
-                if d["max_input_channels"] > 0
+                d['name'] for d in devices
+                if d["max_input_channels"] > 0 and d['hostapi_name'] == hostapi_name
             ]
             self.output_devices = [
-                f"{d['name']} ({d['hostapi_name']})"
-                for d in devices
d["max_output_channels"] > 0 + d['name'] for d in devices + if d["max_output_channels"] > 0 and d['hostapi_name'] == hostapi_name ] self.input_devices_indices = [ d["index"] if "index" in d else d["name"] for d in devices - if d["max_input_channels"] > 0 + if d["max_input_channels"] > 0 and d['hostapi_name'] == hostapi_name ] self.output_devices_indices = [ d["index"] if "index" in d else d["name"] for d in devices - if d["max_output_channels"] > 0 + if d["max_output_channels"] > 0 and d['hostapi_name'] == hostapi_name ] def set_devices(self, input_device, output_device): @@ -963,5 +1014,10 @@ if __name__ == "__main__": return int( sd.query_devices(device=sd.default.device[0])["default_samplerate"] ) + + def get_device_channels(self): + max_input_channels = sd.query_devices(device=sd.default.device[0])["max_input_channels"] + max_output_channels = sd.query_devices(device=sd.default.device[1])["max_output_channels"] + return min(max_input_channels, max_output_channels, 2) gui = GUI() diff --git a/i18n/locale/en_US.json b/i18n/locale/en_US.json index d585505..002844f 100644 --- a/i18n/locale/en_US.json +++ b/i18n/locale/en_US.json @@ -129,7 +129,7 @@ "采样长度": "Sample length", "重载设备列表": "Reload device list", "音调设置": "Pitch settings", - "音频设备(请使用同种类驱动)": "Audio device (please use the same type of driver)", + "音频设备": "Audio device", "音高算法": "pitch detection algorithm", "额外推理时长": "Extra inference time" } diff --git a/i18n/locale/es_ES.json b/i18n/locale/es_ES.json index 08b8176..742564e 100644 --- a/i18n/locale/es_ES.json +++ b/i18n/locale/es_ES.json @@ -129,7 +129,7 @@ "采样长度": "Longitud de muestreo", "重载设备列表": "Actualizar lista de dispositivos", "音调设置": "Ajuste de tono", - "音频设备(请使用同种类驱动)": "Dispositivo de audio (utilice el mismo tipo de controlador)", + "音频设备": "Dispositivo de audio", "音高算法": "Algoritmo de tono", "额外推理时长": "Tiempo de inferencia adicional" } diff --git a/i18n/locale/fr_FR.json b/i18n/locale/fr_FR.json index db93e9a..95b746f 100644 --- a/i18n/locale/fr_FR.json +++ b/i18n/locale/fr_FR.json @@ -129,7 +129,7 @@ "采样长度": "Longueur de l'échantillon", "重载设备列表": "Recharger la liste des dispositifs", "音调设置": "Réglages de la hauteur", - "音频设备(请使用同种类驱动)": "Périphérique audio (veuillez utiliser le même type de pilote)", + "音频设备": "Périphérique audio", "音高算法": "algorithme de détection de la hauteur", "额外推理时长": "Temps d'inférence supplémentaire" } diff --git a/i18n/locale/it_IT.json b/i18n/locale/it_IT.json index dc089be..34c05e2 100644 --- a/i18n/locale/it_IT.json +++ b/i18n/locale/it_IT.json @@ -129,7 +129,7 @@ "采样长度": "Lunghezza del campione", "重载设备列表": "Ricaricare l'elenco dei dispositivi", "音调设置": "Impostazioni del tono", - "音频设备(请使用同种类驱动)": "Dispositivo audio (utilizzare lo stesso tipo di driver)", + "音频设备": "Dispositivo audio", "音高算法": "音高算法", "额外推理时长": "Tempo di inferenza extra" } diff --git a/i18n/locale/ja_JP.json b/i18n/locale/ja_JP.json index c5b33ff..061b7ec 100644 --- a/i18n/locale/ja_JP.json +++ b/i18n/locale/ja_JP.json @@ -129,7 +129,7 @@ "采样长度": "サンプル長", "重载设备列表": "デバイスリストをリロードする", "音调设置": "音程設定", - "音频设备(请使用同种类驱动)": "オーディオデバイス(同じ種類のドライバーを使用してください)", + "音频设备": "オーディオデバイス", "音高算法": "ピッチアルゴリズム", "额外推理时长": "追加推論時間" } diff --git a/i18n/locale/ru_RU.json b/i18n/locale/ru_RU.json index f01bc8f..6f6b0c0 100644 --- a/i18n/locale/ru_RU.json +++ b/i18n/locale/ru_RU.json @@ -129,7 +129,7 @@ "采样长度": "Длина сэмпла", "重载设备列表": "Обновить список устройств", "音调设置": "Настройка высоты звука", - "音频设备(请使用同种类驱动)": "Аудиоустройство (пожалуйста, используйте такой же тип драйвера)", + "音频设备": 
"Аудиоустройство", "音高算法": "Алгоритм оценки высоты звука", "额外推理时长": "Доп. время переработки" } diff --git a/i18n/locale/tr_TR.json b/i18n/locale/tr_TR.json index bd1c17b..a40ee0a 100644 --- a/i18n/locale/tr_TR.json +++ b/i18n/locale/tr_TR.json @@ -129,7 +129,7 @@ "采样长度": "Örnekleme uzunluğu", "重载设备列表": "Cihaz listesini yeniden yükle", "音调设置": "Pitch ayarları", - "音频设备(请使用同种类驱动)": "Ses cihazı (aynı tür sürücüyü kullanın)", + "音频设备": "Ses cihazı", "音高算法": "音高算法", "额外推理时长": "Ekstra çıkartma süresi" } diff --git a/i18n/locale/zh_CN.json b/i18n/locale/zh_CN.json index 32ca5ef..9c12fa6 100644 --- a/i18n/locale/zh_CN.json +++ b/i18n/locale/zh_CN.json @@ -106,6 +106,8 @@ "请选择pth文件": "请选择pth文件", "请选择说话人id": "请选择说话人id", "转换": "转换", + "设备类型": "设备类型", + "独占 WASAPI 设备": "独占 WASAPI 设备", "输入实验名": "输入实验名", "输入待处理音频文件夹路径": "输入待处理音频文件夹路径", "输入待处理音频文件夹路径(去文件管理器地址栏拷就行了)": "输入待处理音频文件夹路径(去文件管理器地址栏拷就行了)", @@ -129,7 +131,7 @@ "采样长度": "采样长度", "重载设备列表": "重载设备列表", "音调设置": "音调设置", - "音频设备(请使用同种类驱动)": "音频设备(请使用同种类驱动)", + "音频设备": "音频设备", "音高算法": "音高算法", "额外推理时长": "额外推理时长" } diff --git a/i18n/locale/zh_HK.json b/i18n/locale/zh_HK.json index 93aaff3..2395614 100644 --- a/i18n/locale/zh_HK.json +++ b/i18n/locale/zh_HK.json @@ -129,7 +129,7 @@ "采样长度": "取樣長度", "重载设备列表": "重載設備列表", "音调设置": "音調設定", - "音频设备(请使用同种类驱动)": "音訊設備 (請使用同種類驅動)", + "音频设备": "音訊設備", "音高算法": "音高演算法", "额外推理时长": "額外推理時長" } diff --git a/i18n/locale/zh_SG.json b/i18n/locale/zh_SG.json index 93aaff3..2395614 100644 --- a/i18n/locale/zh_SG.json +++ b/i18n/locale/zh_SG.json @@ -129,7 +129,7 @@ "采样长度": "取樣長度", "重载设备列表": "重載設備列表", "音调设置": "音調設定", - "音频设备(请使用同种类驱动)": "音訊設備 (請使用同種類驅動)", + "音频设备": "音訊設備", "音高算法": "音高演算法", "额外推理时长": "額外推理時長" } diff --git a/i18n/locale/zh_TW.json b/i18n/locale/zh_TW.json index 93aaff3..2395614 100644 --- a/i18n/locale/zh_TW.json +++ b/i18n/locale/zh_TW.json @@ -129,7 +129,7 @@ "采样长度": "取樣長度", "重载设备列表": "重載設備列表", "音调设置": "音調設定", - "音频设备(请使用同种类驱动)": "音訊設備 (請使用同種類驅動)", + "音频设备": "音訊設備", "音高算法": "音高演算法", "额外推理时长": "額外推理時長" } diff --git a/tools/rvc_for_realtime.py b/tools/rvc_for_realtime.py index ff1ea88..a4daf90 100644 --- a/tools/rvc_for_realtime.py +++ b/tools/rvc_for_realtime.py @@ -91,8 +91,8 @@ class RVC: self.pth_path: str = pth_path self.index_path = index_path self.index_rate = index_rate - self.cache_pitch: np.ndarray = np.zeros(1024, dtype="int32") - self.cache_pitchf = np.zeros(1024, dtype="float32") + self.cache_pitch: torch.Tensor = torch.zeros(1024, device=self.device, dtype=torch.long) + self.cache_pitchf = torch.zeros(1024, device=self.device, dtype=torch.float32) if last_rvc is None: models, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task( @@ -199,15 +199,17 @@ class RVC: self.index_rate = new_index_rate def get_f0_post(self, f0): - f0bak = f0.copy() - f0_mel = 1127 * np.log(1 + f0 / 700) + if not torch.is_tensor(f0): + f0 = torch.from_numpy(f0) + f0 = f0.float().to(self.device).squeeze() + f0_mel = 1127 * torch.log(1 + f0 / 700) f0_mel[f0_mel > 0] = (f0_mel[f0_mel > 0] - self.f0_mel_min) * 254 / ( self.f0_mel_max - self.f0_mel_min ) + 1 f0_mel[f0_mel <= 1] = 1 f0_mel[f0_mel > 255] = 255 - f0_coarse = np.rint(f0_mel).astype(np.int32) - return f0_coarse, f0bak + f0_coarse = torch.round(f0_mel).long() + return f0_coarse, f0 def get_f0(self, x, f0_up_key, n_cpu, method="harvest"): n_cpu = int(n_cpu) @@ -299,14 +301,12 @@ class RVC: pd = torchcrepe.filter.median(pd, 3) f0 = torchcrepe.filter.mean(f0, 3) f0[pd < 0.1] = 0 - f0 = f0[0].cpu().numpy() f0 *= pow(2, f0_up_key / 12) return self.get_f0_post(f0) def 
     def get_f0_rmvpe(self, x, f0_up_key):
         if hasattr(self, "model_rmvpe") == False:
             from infer.lib.rmvpe import RMVPE
-
             printt("Loading rmvpe model")
             self.model_rmvpe = RMVPE(
                 "assets/rmvpe/rmvpe.pt",
@@ -335,7 +335,6 @@ class RVC:
             threshold=0.006,
         )
         f0 *= pow(2, f0_up_key / 12)
-        f0 = f0.squeeze().cpu().numpy()
         return self.get_f0_post(f0)

     def infer(
@@ -383,6 +382,7 @@ class RVC:
                 traceback.print_exc()
                 printt("Index search FAILED")
         t3 = ttime()
+        p_len = input_wav.shape[0] // 160
         if self.if_f0 == 1:
             f0_extractor_frame = block_frame_16k + 800
             if f0method == "rmvpe":
@@ -390,25 +390,14 @@ class RVC:
             pitch, pitchf = self.get_f0(
                 input_wav[-f0_extractor_frame:], self.f0_up_key, self.n_cpu, f0method
             )
-            start_frame = block_frame_16k // 160
-            end_frame = len(self.cache_pitch) - (pitch.shape[0] - 4) + start_frame
-            self.cache_pitch[:] = np.append(
-                self.cache_pitch[start_frame:end_frame], pitch[3:-1]
-            )
-            self.cache_pitchf[:] = np.append(
-                self.cache_pitchf[start_frame:end_frame], pitchf[3:-1]
-            )
+            shift = block_frame_16k // 160
+            self.cache_pitch[: -shift] = self.cache_pitch[shift :].clone()
+            self.cache_pitchf[: -shift] = self.cache_pitchf[shift :].clone()
+            self.cache_pitch[4 - pitch.shape[0] :] = pitch[3:-1]
+            self.cache_pitchf[4 - pitch.shape[0] :] = pitchf[3:-1]
+            cache_pitch = self.cache_pitch[None, -p_len:]
+            cache_pitchf = self.cache_pitchf[None, -p_len:]
         t4 = ttime()
-        p_len = input_wav.shape[0] // 160
-        if self.if_f0 == 1:
-            cache_pitch = (
-                torch.LongTensor(self.cache_pitch[-p_len:]).to(self.device).unsqueeze(0)
-            )
-            cache_pitchf = (
-                torch.FloatTensor(self.cache_pitchf[-p_len:])
-                .to(self.device)
-                .unsqueeze(0)
-            )
         feats = F.interpolate(feats.permute(0, 2, 1), scale_factor=2).permute(0, 2, 1)
         feats = feats[:, :p_len, :]
         p_len = torch.LongTensor([p_len]).to(self.device)