diff --git a/go-realtime-gui.bat b/go-realtime-gui.bat
index ed07321..8c08290 100644
--- a/go-realtime-gui.bat
+++ b/go-realtime-gui.bat
@@ -1,2 +1,2 @@
-runtime\python.exe gui.py
+runtime\python.exe gui_v1.py
 pause
diff --git a/gui_v1.py b/gui_v1.py
new file mode 100644
index 0000000..9dff0e5
--- /dev/null
+++ b/gui_v1.py
@@ -0,0 +1,498 @@
+import os, sys
+
+now_dir = os.getcwd()
+sys.path.append(now_dir)
+import multiprocessing
+
+
+class Harvest(multiprocessing.Process):
+    def __init__(self, inp_q, opt_q):
+        multiprocessing.Process.__init__(self)
+        self.inp_q = inp_q
+        self.opt_q = opt_q
+
+    def run(self):
+        # Imported here so they load inside the worker process.
+        import numpy as np, pyworld
+
+        while True:
+            idx, x, res_f0, n_cpu, ts = self.inp_q.get()
+            f0, t = pyworld.harvest(
+                x.astype(np.double),
+                fs=16000,
+                f0_ceil=1100,
+                f0_floor=50,
+                frame_period=10,
+            )
+            res_f0[idx] = f0
+            # The worker that fills the last chunk signals completion.
+            if len(res_f0.keys()) >= n_cpu:
+                self.opt_q.put(ts)
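+
+
+# Harvest instances form a small worker pool for parallel f0 extraction:
+# rvc_for_realtime.RVC.get_f0() enqueues (idx, chunk, shared_dict, n_cpu,
+# timestamp) jobs on inp_q, each worker writes its pyworld.harvest result
+# into the shared dict, and whichever worker stores the final chunk echoes
+# the timestamp on opt_q so the caller knows the full curve is ready.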
index"),), + ), + ], + ], + ) + ], + [ + sg.Frame( + layout=[ + [ + sg.Text(i18n("输入设备")), + sg.Combo( + input_devices, + key="sg_input_device", + default_value=data.get("sg_input_device", ""), + ), + ], + [ + sg.Text(i18n("输出设备")), + sg.Combo( + output_devices, + key="sg_output_device", + default_value=data.get("sg_output_device", ""), + ), + ], + ], + title=i18n("音频设备(请使用同种类驱动)"), + ) + ], + [ + sg.Frame( + layout=[ + [ + sg.Text(i18n("响应阈值")), + sg.Slider( + range=(-60, 0), + key="threhold", + resolution=1, + orientation="h", + default_value=data.get("threhold", ""), + ), + ], + [ + sg.Text(i18n("音调设置")), + sg.Slider( + range=(-24, 24), + key="pitch", + resolution=1, + orientation="h", + default_value=data.get("pitch", ""), + ), + ], + [ + sg.Text(i18n("Index Rate")), + sg.Slider( + range=(0.0, 1.0), + key="index_rate", + resolution=0.01, + orientation="h", + default_value=data.get("index_rate", ""), + ), + ], + [ + sg.Text(i18n("音高算法")), + sg.Radio("pm","f0method",key="pm",default=data.get("pm","")==True), + sg.Radio("harvest","f0method",key="harvest",default=data.get("harvest","")==True), + sg.Radio("crepe","f0method",key="crepe",default=data.get("crepe","")==True), + ], + ], + title=i18n("常规设置"), + ), + sg.Frame( + layout=[ + [ + sg.Text(i18n("采样长度")), + sg.Slider( + range=(0.12, 2.4), + key="block_time", + resolution=0.03, + orientation="h", + default_value=data.get("block_time", ""), + ), + ], + [ + sg.Text(i18n("harvest进程数")), + sg.Slider( + range=(1, n_cpu), + key="n_cpu", + resolution=1, + orientation="h", + default_value=data.get("n_cpu", min(self.config.n_cpu,n_cpu)), + ), + ], + [ + sg.Text(i18n("淡入淡出长度")), + sg.Slider( + range=(0.01, 0.15), + key="crossfade_length", + resolution=0.01, + orientation="h", + default_value=data.get("crossfade_length", ""), + ), + ], + [ + sg.Text(i18n("额外推理时长")), + sg.Slider( + range=(0.05, 3.00), + key="extra_time", + resolution=0.01, + orientation="h", + default_value=data.get("extra_time", ""), + ), + ], + [ + sg.Checkbox(i18n("输入降噪"), key="I_noise_reduce"), + sg.Checkbox(i18n("输出降噪"), key="O_noise_reduce"), + ], + ], + title=i18n("性能设置"), + ), + ], + [ + sg.Button(i18n("开始音频转换"), key="start_vc"), + sg.Button(i18n("停止音频转换"), key="stop_vc"), + sg.Text(i18n("推理时间(ms):")), + sg.Text("0", key="infer_time"), + ], + ] + self.window = sg.Window("RVC - GUI", layout=layout) + self.event_handler() + + def event_handler(self): + while True: + event, values = self.window.read() + if event == sg.WINDOW_CLOSED: + self.flag_vc = False + exit() + if event == "start_vc" and self.flag_vc == False: + if self.set_values(values) == True: + print("using_cuda:" + str(torch.cuda.is_available())) + self.start_vc() + settings = { + "pth_path": values["pth_path"], + "index_path": values["index_path"], + "sg_input_device": values["sg_input_device"], + "sg_output_device": values["sg_output_device"], + "threhold": values["threhold"], + "pitch": values["pitch"], + "index_rate": values["index_rate"], + "block_time": values["block_time"], + "crossfade_length": values["crossfade_length"], + "extra_time": values["extra_time"], + "n_cpu": values["n_cpu"], + "f0method": ["pm","harvest","crepe"][[values["pm"],values["harvest"],values["crepe"]].index(True)], + } + with open("values1.json", "w") as j: + json.dump(settings, j) + if event == "stop_vc" and self.flag_vc == True: + self.flag_vc = False + + def set_values(self, values): + if len(values["pth_path"].strip()) == 0: + sg.popup(i18n("请选择pth文件")) + return False + if len(values["index_path"].strip()) == 0: + 
+
+        def event_handler(self):
+            while True:
+                event, values = self.window.read()
+                if event == sg.WINDOW_CLOSED:
+                    self.flag_vc = False
+                    exit()
+                if event == "start_vc" and not self.flag_vc:
+                    if self.set_values(values):
+                        print("using_cuda:" + str(torch.cuda.is_available()))
+                        self.start_vc()
+                        settings = {
+                            "pth_path": values["pth_path"],
+                            "index_path": values["index_path"],
+                            "sg_input_device": values["sg_input_device"],
+                            "sg_output_device": values["sg_output_device"],
+                            "threhold": values["threhold"],
+                            "pitch": values["pitch"],
+                            "index_rate": values["index_rate"],
+                            "block_time": values["block_time"],
+                            "crossfade_length": values["crossfade_length"],
+                            "extra_time": values["extra_time"],
+                            "n_cpu": values["n_cpu"],
+                            "f0method": ["pm", "harvest", "crepe"][
+                                [values["pm"], values["harvest"], values["crepe"]].index(True)
+                            ],
+                        }
+                        with open("values1.json", "w") as j:
+                            json.dump(settings, j)
+                if event == "stop_vc" and self.flag_vc:
+                    self.flag_vc = False
+
+        def set_values(self, values):
+            if len(values["pth_path"].strip()) == 0:
+                sg.popup(i18n("请选择pth文件"))  # "Please select a pth file"
+                return False
+            if len(values["index_path"].strip()) == 0:
+                sg.popup(i18n("请选择index文件"))  # "Please select an index file"
+                return False
+            pattern = re.compile("[^\x00-\x7F]+")
+            if pattern.findall(values["pth_path"]):
+                # "pth file path must not contain Chinese (non-ASCII) characters"
+                sg.popup(i18n("pth文件路径不可包含中文"))
+                return False
+            if pattern.findall(values["index_path"]):
+                # "index file path must not contain Chinese (non-ASCII) characters"
+                sg.popup(i18n("index文件路径不可包含中文"))
+                return False
+            self.set_devices(values["sg_input_device"], values["sg_output_device"])
+            self.config.pth_path = values["pth_path"]
+            self.config.index_path = values["index_path"]
+            self.config.threhold = values["threhold"]
+            self.config.pitch = values["pitch"]
+            self.config.block_time = values["block_time"]
+            self.config.crossfade_time = values["crossfade_length"]
+            self.config.extra_time = values["extra_time"]
+            self.config.I_noise_reduce = values["I_noise_reduce"]
+            self.config.O_noise_reduce = values["O_noise_reduce"]
+            self.config.index_rate = values["index_rate"]
+            self.config.n_cpu = values["n_cpu"]
+            self.config.f0method = ["pm", "harvest", "crepe"][
+                [values["pm"], values["harvest"], values["crepe"]].index(True)
+            ]
+            return True
+
+        def start_vc(self):
+            torch.cuda.empty_cache()
+            self.flag_vc = True
+            self.rvc = RVC(
+                self.config.pitch,
+                self.config.pth_path,
+                self.config.index_path,
+                self.config.index_rate,
+                self.config.n_cpu,
+                inp_q,
+                opt_q,
+                device,
+            )
+            self.config.samplerate = self.rvc.tgt_sr
+            self.config.crossfade_time = min(
+                self.config.crossfade_time, self.config.block_time
+            )
+            self.block_frame = int(self.config.block_time * self.config.samplerate)
+            self.crossfade_frame = int(
+                self.config.crossfade_time * self.config.samplerate
+            )
+            self.sola_search_frame = int(0.01 * self.config.samplerate)
+            self.extra_frame = int(self.config.extra_time * self.config.samplerate)
+            self.zc = self.rvc.tgt_sr // 100  # samples per 10 ms at the target rate
+            self.input_wav: np.ndarray = np.zeros(
+                int(
+                    np.ceil(
+                        (
+                            self.extra_frame
+                            + self.crossfade_frame
+                            + self.sola_search_frame
+                            + self.block_frame
+                        )
+                        / self.zc
+                    )
+                    * self.zc
+                ),
+                dtype="float32",
+            )
+            self.output_wav_cache: torch.Tensor = torch.zeros(
+                int(
+                    np.ceil(
+                        (
+                            self.extra_frame
+                            + self.crossfade_frame
+                            + self.sola_search_frame
+                            + self.block_frame
+                        )
+                        / self.zc
+                    )
+                    * self.zc
+                ),
+                device=device,
+                dtype=torch.float32,
+            )
+            self.pitch: np.ndarray = np.zeros(
+                self.input_wav.shape[0] // self.zc,
+                dtype="int32",
+            )
+            self.pitchf: np.ndarray = np.zeros(
+                self.input_wav.shape[0] // self.zc,
+                dtype="float64",
+            )
+            self.output_wav: torch.Tensor = torch.zeros(
+                self.block_frame, device=device, dtype=torch.float32
+            )
+            self.sola_buffer: torch.Tensor = torch.zeros(
+                self.crossfade_frame, device=device, dtype=torch.float32
+            )
+            self.fade_in_window: torch.Tensor = torch.linspace(
+                0.0, 1.0, steps=self.crossfade_frame, device=device, dtype=torch.float32
+            )
+            self.fade_out_window: torch.Tensor = 1 - self.fade_in_window
+            self.resampler = tat.Resample(
+                orig_freq=self.config.samplerate, new_freq=16000, dtype=torch.float32
+            )
+            thread_vc = threading.Thread(target=self.soundinput)
+            thread_vc.start()
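+
+        # Sizing example (block_time 1.0 s, crossfade_time 0.08 s, tgt_sr
+        # 40000): block_frame = 40000, crossfade_frame = 3200,
+        # sola_search_frame = 400, zc = 400 (one 10 ms hop). input_wav keeps
+        # extra + crossfade + search + block samples, rounded up to a
+        # multiple of zc so the buffer spans a whole number of 10 ms f0
+        # frames.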
+
+        def soundinput(self):
+            """
+            Receive audio input (runs until flag_vc is cleared).
+            """
+            with sd.Stream(
+                channels=2,
+                callback=self.audio_callback,
+                blocksize=self.block_frame,
+                samplerate=self.config.samplerate,
+                dtype="float32",
+            ):
+                while self.flag_vc:
+                    time.sleep(self.config.block_time)
+                    print("Audio block passed.")
+            print("ENDing VC")
+
+        def audio_callback(
+            self, indata: np.ndarray, outdata: np.ndarray, frames, times, status
+        ):
+            """
+            Audio processing: one block in, one converted block out.
+            """
+            start_time = time.perf_counter()
+            indata = librosa.to_mono(indata.T)
+            if self.config.I_noise_reduce:
+                indata[:] = nr.reduce_noise(y=indata, sr=self.config.samplerate)
+            # Noise gate: zero out 1024-sample hops whose RMS level falls
+            # below the configured dB threshold.
+            frame_length = 2048
+            hop_length = 1024
+            rms = librosa.feature.rms(
+                y=indata, frame_length=frame_length, hop_length=hop_length
+            )
+            if self.config.threhold > -60:
+                db_threhold = (
+                    librosa.amplitude_to_db(rms, ref=1.0)[0] < self.config.threhold
+                )
+                for i in range(db_threhold.shape[0]):
+                    if db_threhold[i]:
+                        indata[i * hop_length : (i + 1) * hop_length] = 0
+            self.input_wav[:] = np.append(self.input_wav[self.block_frame :], indata)
+            # infer
+            inp = torch.from_numpy(self.input_wav)
+            res1 = self.resampler(inp)
+            rate1 = self.block_frame / (
+                self.extra_frame
+                + self.crossfade_frame
+                + self.sola_search_frame
+                + self.block_frame
+            )
+            rate2 = (
+                self.crossfade_frame + self.sola_search_frame + self.block_frame
+            ) / (
+                self.extra_frame
+                + self.crossfade_frame
+                + self.sola_search_frame
+                + self.block_frame
+            )
+            res2 = self.rvc.infer(
+                res1,
+                res1[-self.block_frame :].numpy(),
+                rate1,
+                rate2,
+                self.pitch,
+                self.pitchf,
+                self.config.f0method,
+            )
+            self.output_wav_cache[-res2.shape[0] :] = res2
+            infer_wav = self.output_wav_cache[
+                -self.crossfade_frame - self.sola_search_frame - self.block_frame :
+            ].to(device)
+            # SOLA algorithm from https://github.com/yxlllc/DDSP-SVC :
+            # find the offset maximizing normalized cross-correlation between
+            # the new audio and the stashed tail of the previous block.
+            cor_nom = F.conv1d(
+                infer_wav[None, None, : self.crossfade_frame + self.sola_search_frame],
+                self.sola_buffer[None, None, :],
+            )
+            cor_den = torch.sqrt(
+                F.conv1d(
+                    infer_wav[None, None, : self.crossfade_frame + self.sola_search_frame]
+                    ** 2,
+                    torch.ones(1, 1, self.crossfade_frame, device=device),
+                )
+                + 1e-8
+            )
+            sola_offset = torch.argmax(cor_nom[0, 0] / cor_den[0, 0])
+            print("sola offset: " + str(int(sola_offset)))
+            self.output_wav[:] = infer_wav[sola_offset : sola_offset + self.block_frame]
+            self.output_wav[: self.crossfade_frame] *= self.fade_in_window
+            self.output_wav[: self.crossfade_frame] += self.sola_buffer[:]
+            # crossfade: stash the faded-out tail for the next block
+            if sola_offset < self.sola_search_frame:
+                self.sola_buffer[:] = (
+                    infer_wav[
+                        -self.sola_search_frame
+                        - self.crossfade_frame
+                        + sola_offset : -self.sola_search_frame
+                        + sola_offset
+                    ]
+                    * self.fade_out_window
+                )
+            else:
+                self.sola_buffer[:] = (
+                    infer_wav[-self.crossfade_frame :] * self.fade_out_window
+                )
+            if self.config.O_noise_reduce:
+                outdata[:] = np.tile(
+                    nr.reduce_noise(
+                        y=self.output_wav[:].cpu().numpy(), sr=self.config.samplerate
+                    ),
+                    (2, 1),
+                ).T
+            else:
+                outdata[:] = self.output_wav[:].repeat(2, 1).t().cpu().numpy()
+            total_time = time.perf_counter() - start_time
+            self.window["infer_time"].update(int(total_time * 1000))
+            print("infer time:" + str(total_time))
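+
+        # get_devices() re-initializes PortAudio (sd._terminate() /
+        # sd._initialize()) so devices attached after startup show up, and
+        # tags every entry with its host API name; input and output should
+        # be picked from the same driver type, as the device-frame title
+        # warns.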
+
+        def get_devices(self, update: bool = True):
+            """Get the input/output device lists."""
+            if update:
+                sd._terminate()
+                sd._initialize()
+            devices = sd.query_devices()
+            hostapis = sd.query_hostapis()
+            for hostapi in hostapis:
+                for device_idx in hostapi["devices"]:
+                    devices[device_idx]["hostapi_name"] = hostapi["name"]
+            input_devices = [
+                f"{d['name']} ({d['hostapi_name']})"
+                for d in devices
+                if d["max_input_channels"] > 0
+            ]
+            output_devices = [
+                f"{d['name']} ({d['hostapi_name']})"
+                for d in devices
+                if d["max_output_channels"] > 0
+            ]
+            input_devices_indices = [
+                d["index"] if "index" in d else d["name"]
+                for d in devices
+                if d["max_input_channels"] > 0
+            ]
+            output_devices_indices = [
+                d["index"] if "index" in d else d["name"]
+                for d in devices
+                if d["max_output_channels"] > 0
+            ]
+            return (
+                input_devices,
+                output_devices,
+                input_devices_indices,
+                output_devices_indices,
+            )
+
+        def set_devices(self, input_device, output_device):
+            """Set the default input and output devices."""
+            (
+                input_devices,
+                output_devices,
+                input_device_indices,
+                output_device_indices,
+            ) = self.get_devices()
+            sd.default.device[0] = input_device_indices[
+                input_devices.index(input_device)
+            ]
+            sd.default.device[1] = output_device_indices[
+                output_devices.index(output_device)
+            ]
+            print("input device:" + str(sd.default.device[0]) + ":" + str(input_device))
+            print("output device:" + str(sd.default.device[1]) + ":" + str(output_device))
+
+    gui = GUI()
diff --git a/rvc_for_realtime.py b/rvc_for_realtime.py
new file mode 100644
index 0000000..7a56d1a
--- /dev/null
+++ b/rvc_for_realtime.py
@@ -0,0 +1,257 @@
+import faiss, torch, traceback, parselmouth
+import numpy as np
+import torchcrepe
+import torch.nn as nn
+import pyworld
+from fairseq import checkpoint_utils
+from lib.infer_pack.models import (
+    SynthesizerTrnMs256NSFsid,
+    SynthesizerTrnMs256NSFsid_nono,
+    SynthesizerTrnMs768NSFsid,
+    SynthesizerTrnMs768NSFsid_nono,
+)
+import os, sys
+from time import time as ttime
+import torch.nn.functional as F
+import scipy.signal as signal
+
+now_dir = os.getcwd()
+sys.path.append(now_dir)
+from config import Config
+from multiprocessing import Manager as M
+
+mm = M()  # Manager provides the shared dicts handed to the Harvest workers
+config = Config()
+
+
+class RVC:
+    def __init__(
+        self, key, pth_path, index_path, index_rate, n_cpu, inp_q, opt_q, device
+    ) -> None:
+        """
+        Initialize the realtime conversion pipeline.
+        """
+        try:
+            global config
+            self.inp_q = inp_q
+            self.opt_q = opt_q
+            self.device = device
+            self.f0_up_key = key
+            self.time_step = 160 / 16000 * 1000
+            self.f0_min = 50
+            self.f0_max = 1100
+            self.f0_mel_min = 1127 * np.log(1 + self.f0_min / 700)
+            self.f0_mel_max = 1127 * np.log(1 + self.f0_max / 700)
+            self.sr = 16000
+            self.window = 160
+            self.n_cpu = n_cpu
+            if index_rate != 0:
+                self.index = faiss.read_index(index_path)
+                self.big_npy = self.index.reconstruct_n(0, self.index.ntotal)
+                print("index search enabled")
+            self.index_rate = index_rate
+            models, _, _ = checkpoint_utils.load_model_ensemble_and_task(
+                ["hubert_base.pt"],
+                suffix="",
+            )
+            hubert_model = models[0]
+            # HuBERT lives on config.device; net_g uses the device passed in.
+            hubert_model = hubert_model.to(config.device)
+            if config.is_half:
+                hubert_model = hubert_model.half()
+            else:
+                hubert_model = hubert_model.float()
+            hubert_model.eval()
+            self.model = hubert_model
+            cpt = torch.load(pth_path, map_location="cpu")
+            self.tgt_sr = cpt["config"][-1]
+            cpt["config"][-3] = cpt["weight"]["emb_g.weight"].shape[0]
+            self.if_f0 = cpt.get("f0", 1)
+            self.version = cpt.get("version", "v1")
+            if self.version == "v1":
+                if self.if_f0 == 1:
+                    self.net_g = SynthesizerTrnMs256NSFsid(
+                        *cpt["config"], is_half=config.is_half
+                    )
+                else:
+                    self.net_g = SynthesizerTrnMs256NSFsid_nono(*cpt["config"])
+            elif self.version == "v2":
+                if self.if_f0 == 1:
+                    self.net_g = SynthesizerTrnMs768NSFsid(
+                        *cpt["config"], is_half=config.is_half
+                    )
+                else:
+                    self.net_g = SynthesizerTrnMs768NSFsid_nono(*cpt["config"])
+            del self.net_g.enc_q
+            print(self.net_g.load_state_dict(cpt["weight"], strict=False))
+            self.net_g.eval().to(device)
+            if config.is_half:
+                self.net_g = self.net_g.half()
+            else:
+                self.net_g = self.net_g.float()
+        except Exception:
+            print(traceback.format_exc())
+
+    def get_f0_post(self, f0):
+        # Quantize f0 (Hz) to 255 coarse bins on a mel scale; bin 1 doubles
+        # as "unvoiced".
+        f0_min = self.f0_min
+        f0_max = self.f0_max
+        f0_mel_min = 1127 * np.log(1 + f0_min / 700)
+        f0_mel_max = 1127 * np.log(1 + f0_max / 700)
+        f0bak = f0.copy()
+        f0_mel = 1127 * np.log(1 + f0 / 700)
+        f0_mel[f0_mel > 0] = (f0_mel[f0_mel > 0] - f0_mel_min) * 254 / (
+            f0_mel_max - f0_mel_min
+        ) + 1
+        f0_mel[f0_mel <= 1] = 1
+        f0_mel[f0_mel > 255] = 255
+        f0_coarse = np.rint(f0_mel).astype(np.int64)
+        return f0_coarse, f0bak
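+
+    # f0 dispatch: "crepe" and "pm" run in-process; "harvest" with n_cpu > 1
+    # fans out to the Harvest workers started by gui_v1.py. The signal is
+    # cut into n_cpu chunks overlapping by 320 samples (two 10 ms frames at
+    # 16 kHz) at each seam, and the duplicated frames are trimmed when the
+    # per-chunk curves are stitched back into f0bak.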
+
+    def get_f0(self, x, f0_up_key, n_cpu, method="harvest"):
+        n_cpu = int(n_cpu)
+        if method == "crepe":
+            return self.get_f0_crepe(x, f0_up_key)
+        if method == "pm":
+            p_len = x.shape[0] // 160
+            f0 = (
+                parselmouth.Sound(x, 16000)
+                .to_pitch_ac(
+                    time_step=0.01,
+                    voicing_threshold=0.6,
+                    pitch_floor=50,
+                    pitch_ceiling=1100,
+                )
+                .selected_array["frequency"]
+            )
+            # Pad the pm contour to the expected frame count.
+            pad_size = (p_len - len(f0) + 1) // 2
+            if pad_size > 0 or p_len - len(f0) - pad_size > 0:
+                print(pad_size, p_len - len(f0) - pad_size)
+                f0 = np.pad(
+                    f0, [[pad_size, p_len - len(f0) - pad_size]], mode="constant"
+                )
+            f0 *= pow(2, f0_up_key / 12)
+            return self.get_f0_post(f0)
+        if n_cpu == 1:
+            f0, t = pyworld.harvest(
+                x.astype(np.double),
+                fs=16000,
+                f0_ceil=1100,
+                f0_floor=50,
+                frame_period=10,
+            )
+            f0 = signal.medfilt(f0, 3)
+            f0 *= pow(2, f0_up_key / 12)
+            return self.get_f0_post(f0)
+        f0bak = np.zeros(x.shape[0] // 160, dtype=np.float64)
+        length = len(x)
+        part_length = int(length / n_cpu / 160) * 160
+        ts = ttime()
+        res_f0 = mm.dict()
+        for idx in range(n_cpu):
+            tail = part_length * (idx + 1) + 320
+            if idx == 0:
+                self.inp_q.put((idx, x[:tail], res_f0, n_cpu, ts))
+            else:
+                self.inp_q.put(
+                    (idx, x[part_length * idx - 320 : tail], res_f0, n_cpu, ts)
+                )
+        while True:
+            res_ts = self.opt_q.get()
+            if res_ts == ts:
+                break
+        f0s = [i[1] for i in sorted(res_f0.items(), key=lambda x: x[0])]
+        for idx, f0 in enumerate(f0s):
+            if idx == 0:
+                f0 = f0[:-3]
+            elif idx != n_cpu - 1:
+                f0 = f0[2:-3]
+            else:
+                f0 = f0[2:-1]
+            f0bak[
+                part_length * idx // 160 : part_length * idx // 160 + f0.shape[0]
+            ] = f0
+        f0bak = signal.medfilt(f0bak, 3)
+        f0bak *= pow(2, f0_up_key / 12)
+        return self.get_f0_post(f0bak)
+
+    def get_f0_crepe(self, x, f0_up_key):
+        audio = torch.tensor(np.copy(x))[None].float()
+        f0, pd = torchcrepe.predict(
+            audio,
+            self.sr,
+            160,
+            self.f0_min,
+            self.f0_max,
+            "full",
+            batch_size=512,
+            device=self.device,
+            return_periodicity=True,
+        )
+        pd = torchcrepe.filter.median(pd, 3)
+        f0 = torchcrepe.filter.mean(f0, 3)
+        f0[pd < 0.1] = 0  # treat low-periodicity frames as unvoiced
+        f0 = f0[0].cpu().numpy()
+        f0 *= pow(2, f0_up_key / 12)
+        return self.get_f0_post(f0)
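+
+    # infer() pipeline: HuBERT features (output layer 9 for v1, 12 for v2)
+    # -> optional faiss blend over the newest rate1 fraction of frames -> 2x
+    # interpolation from 20 ms to 10 ms frames -> rolling pitch-cache update
+    # -> net_g synthesis. rate2 is forwarded to net_g.infer() (defined in
+    # lib.infer_pack.models), which presumably uses it to return only the
+    # newest part of the buffer, skipping the warm-up context.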
+
+    def infer(
+        self,
+        feats: torch.Tensor,
+        indata: np.ndarray,
+        rate1,
+        rate2,
+        cache_pitch,
+        cache_pitchf,
+        f0method,
+    ) -> np.ndarray:
+        feats = feats.view(1, -1)
+        if config.is_half:
+            feats = feats.half()
+        else:
+            feats = feats.float()
+        feats = feats.to(self.device)
+        t1 = ttime()
+        with torch.no_grad():
+            padding_mask = torch.BoolTensor(feats.shape).to(self.device).fill_(False)
+            inputs = {
+                "source": feats,
+                "padding_mask": padding_mask,
+                "output_layer": 9 if self.version == "v1" else 12,
+            }
+            logits = self.model.extract_features(**inputs)
+            feats = (
+                self.model.final_proj(logits[0]) if self.version == "v1" else logits[0]
+            )
+        t2 = ttime()
+        try:
+            if hasattr(self, "index") and self.index_rate != 0:
+                leng_replace_head = int(rate1 * feats[0].shape[0])
+                npy = feats[0][-leng_replace_head:].cpu().numpy().astype("float32")
+                score, ix = self.index.search(npy, k=8)
+                weight = np.square(1 / score)
+                weight /= weight.sum(axis=1, keepdims=True)
+                npy = np.sum(self.big_npy[ix] * np.expand_dims(weight, axis=2), axis=1)
+                if config.is_half:
+                    npy = npy.astype("float16")
+                feats[0][-leng_replace_head:] = (
+                    torch.from_numpy(npy).unsqueeze(0).to(self.device) * self.index_rate
+                    + (1 - self.index_rate) * feats[0][-leng_replace_head:]
+                )
+            else:
+                print("index search FAIL or disabled")
+        except Exception:
+            traceback.print_exc()
+            print("index search FAIL")
+        feats = F.interpolate(feats.permute(0, 2, 1), scale_factor=2).permute(0, 2, 1)
+        t3 = ttime()
+        if self.if_f0 == 1:
+            pitch, pitchf = self.get_f0(indata, self.f0_up_key, self.n_cpu, f0method)
+            cache_pitch[:] = np.append(cache_pitch[pitch[:-1].shape[0] :], pitch[:-1])
+            cache_pitchf[:] = np.append(
+                cache_pitchf[pitchf[:-1].shape[0] :], pitchf[:-1]
+            )
+            p_len = min(feats.shape[1], 13000, cache_pitch.shape[0])
+        else:
+            cache_pitch, cache_pitchf = None, None
+            p_len = min(feats.shape[1], 13000)
+        t4 = ttime()
+        feats = feats[:, :p_len, :]
+        if self.if_f0 == 1:
+            cache_pitch = cache_pitch[:p_len]
+            cache_pitchf = cache_pitchf[:p_len]
+            cache_pitch = torch.LongTensor(cache_pitch).unsqueeze(0).to(self.device)
+            cache_pitchf = torch.FloatTensor(cache_pitchf).unsqueeze(0).to(self.device)
+        p_len = torch.LongTensor([p_len]).to(self.device)
+        ii = 0  # sid
+        sid = torch.LongTensor([ii]).to(self.device)
+        with torch.no_grad():
+            if self.if_f0 == 1:
+                infered_audio = (
+                    self.net_g.infer(feats, p_len, cache_pitch, cache_pitchf, sid, rate2)[
+                        0
+                    ][0, 0]
+                    .data.cpu()
+                    .float()
+                )
+            else:
+                infered_audio = (
+                    self.net_g.infer(feats, p_len, sid, rate2)[0][0, 0]
+                    .data.cpu()
+                    .float()
+                )
+        t5 = ttime()
+        print("time->fea-index-f0-model:", t2 - t1, t3 - t2, t4 - t3, t5 - t4)
+        return infered_audio