diff --git a/gui.py b/gui.py index 058f974..e1c5952 100644 --- a/gui.py +++ b/gui.py @@ -3,7 +3,7 @@ import sounddevice as sd import noisereduce as nr import numpy as np from fairseq import checkpoint_utils -import librosa,torch,parselmouth,faiss,time,threading,math +import librosa,torch,parselmouth,faiss,time,threading import torch.nn.functional as F import torchaudio.transforms as tat @@ -15,7 +15,7 @@ i18n = I18nAuto() device = torch.device("cuda" if torch.cuda.is_available() else "cpu") class RVC: - def __init__(self,key,pth_path,index_path,npy_path) -> None: + def __init__(self,key,hubert_path,pth_path,index_path,npy_path,index_rate) -> None: ''' 初始化 ''' @@ -26,8 +26,10 @@ class RVC: self.f0_mel_min = 1127 * np.log(1 + self.f0_min / 700) self.f0_mel_max = 1127 * np.log(1 + self.f0_max / 700) self.index=faiss.read_index(index_path) + self.index_rate=index_rate + '''NOT YET USED''' self.big_npy=np.load(npy_path) - model_path = "TEMP\\hubert_base.pt" + model_path = hubert_path print("load model(s) from {}".format(model_path)) models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task( [model_path], @@ -75,25 +77,11 @@ class RVC: return f0_coarse, f0bak - def infer(self,audio:np.ndarray,sampling_rate:int) -> np.ndarray: + def infer(self,feats:torch.Tensor) -> np.ndarray: ''' - 推理函数。 - :param audio: ndarray(n,2) - :sampling_rate: 采样率 + 推理函数 ''' - - # f0_up_key=12 - if len(audio.shape) > 1: - audio = librosa.to_mono(audio.transpose(1, 0)) - if sampling_rate != 16000: - audio = librosa.resample(audio, orig_sr=sampling_rate, target_sr=16000) - #print('test:audio:'+str(audio.shape)) - '''padding''' - - - feats = torch.from_numpy(audio).float() - if feats.dim() == 2: # double channels - feats = feats.mean(-1) + audio=feats.clone().cpu().numpy() assert feats.dim() == 1, feats.dim() feats = feats.view(1, -1) padding_mask = torch.BoolTensor(feats.shape).fill_(False) @@ -108,17 +96,17 @@ class RVC: feats = self.model.final_proj(logits[0]) ####索引优化 - npy = feats[0].cpu().numpy().astype("float32") - D, I = self.index.search(npy, 1) - # feats = torch.from_numpy(big_npy[I.squeeze()].astype("float16")).unsqueeze(0).to(device) - index_rate=0.5 - feats = torch.from_numpy(npy).unsqueeze(0).to(device) * index_rate + (1 - index_rate) * feats - feats=feats.half() + if(isinstance(self.index,type(None))==False and isinstance(self.big_npy,type(None))==False and self.index_rate!=0): + npy = feats[0].cpu().numpy().astype("float32") + _, I = self.index.search(npy, 1) + npy=self.big_npy[I.squeeze()].astype("float16") + feats = torch.from_numpy(npy).unsqueeze(0).to(device)*self.index_rate + (1-self.index_rate)*feats feats=F.interpolate(feats.permute(0,2,1),scale_factor=2).permute(0,2,1) torch.cuda.synchronize() # p_len = min(feats.shape[1],10000,pitch.shape[0])#太大了爆显存 p_len = min(feats.shape[1],12000)# + print(feats.shape) pitch, pitchf = self.get_f0(audio, p_len,self.f0_up_key) p_len = min(feats.shape[1],12000,pitch.shape[0])#太大了爆显存 torch.cuda.synchronize() @@ -132,13 +120,14 @@ class RVC: ii=0#sid sid=torch.LongTensor([ii]).to(device) with torch.no_grad(): - audio = self.net_g.infer(feats, p_len,pitch,pitchf,sid)[0][0, 0].data.cpu().float().numpy()#nsf + infered_audio = self.net_g.infer(feats, p_len,pitch,pitchf,sid)[0][0, 0].data.cpu().float()#nsf torch.cuda.synchronize() - return audio + return infered_audio class Config: def __init__(self) -> None: + self.hubert_path:str='' self.pth_path:str='' self.index_path:str='' self.npy_path:str='' @@ -151,6 +140,7 @@ class Config: self.extra_time:float=0.04 self.I_noise_reduce=False self.O_noise_reduce=False + self.index_rate=0.3 class GUI: def __init__(self) -> None: @@ -180,8 +170,8 @@ class GUI: [ sg.Frame(layout=[ [sg.Text(i18n("响应阈值")),sg.Slider(range=(-60,0),key='threhold',resolution=1,orientation='h',default_value=-30)], - [sg.Text(i18n("音调设置")),sg.Slider(range=(-24,24),key='pitch',resolution=1,orientation='h',default_value=12)] - + [sg.Text(i18n("音调设置")),sg.Slider(range=(-24,24),key='pitch',resolution=1,orientation='h',default_value=12)], + [sg.Text(i18n('Index Rate')),sg.Slider(range=(0.0,1.0),key='index_rate',resolution=0.01,orientation='h',default_value=0.5)] ],title=i18n("常规设置")), sg.Frame(layout=[ [sg.Text(i18n("采样长度")),sg.Slider(range=(0.1,3.0),key='block_time',resolution=0.1,orientation='h',default_value=1.0)], @@ -204,9 +194,7 @@ class GUI: exit() if event == 'start_vc' and self.flag_vc==False: self.set_values(values) - print('pth_path:'+self.config.pth_path) - print('index_path:'+self.config.index_path) - print('npy_path:'+self.config.npy_path) + print(str(self.config.__dict__)) print('using_cuda:'+str(torch.cuda.is_available())) self.start_vc() if event=='stop_vc'and self.flag_vc==True: @@ -215,6 +203,7 @@ class GUI: def set_values(self,values): self.set_devices(values["sg_input_device"],values['sg_output_device']) + self.config.hubert_path=values['hubert_path'] self.config.pth_path=values['pth_path'] self.config.index_path=values['index_path'] self.config.npy_path=values['npy_path'] @@ -225,27 +214,25 @@ class GUI: self.config.extra_time=values['extra_time'] self.config.I_noise_reduce=values['I_noise_reduce'] self.config.O_noise_reduce=values['O_noise_reduce'] + self.config.index_rate=values['index_rate'] def start_vc(self): torch.cuda.empty_cache() self.flag_vc=True - self.RMS_threhold=math.e**(float(self.config.threhold)/10) self.block_frame=int(self.config.block_time*self.config.samplerate) self.crossfade_frame=int(self.config.crossfade_time*self.config.samplerate) self.sola_search_frame=int(0.012*self.config.samplerate) self.delay_frame=int(0.02*self.config.samplerate)#往前预留0.02s self.extra_frame=int(self.config.extra_time*self.config.samplerate)#往后预留0.04s self.rvc=None - self.rvc=RVC(self.config.pitch,self.config.pth_path,self.config.index_path,self.config.npy_path) + self.rvc=RVC(self.config.pitch,self.config.hubert_path,self.config.pth_path,self.config.index_path,self.config.npy_path,self.config.index_rate) self.input_wav:np.ndarray=np.zeros(self.extra_frame+self.crossfade_frame+self.sola_search_frame+self.block_frame,dtype='float32') self.output_wav:torch.Tensor=torch.zeros(self.block_frame,device=device,dtype=torch.float32) - #self.sola_buffer:np.ndarray=np.zeros(self.crossfade_frame,dtype='float32') self.sola_buffer:torch.Tensor=torch.zeros(self.crossfade_frame,device=device,dtype=torch.float32) - #self.fade_in_window:np.ndarray = np.linspace(0, 1, self.crossfade_frame) self.fade_in_window:torch.Tensor=torch.linspace(0.0,1.0,steps=self.crossfade_frame,device=device,dtype=torch.float32) self.fade_out_window:torch.Tensor = 1 - self.fade_in_window - self.resampler=tat.Resample(orig_freq=40000,new_freq=self.config.samplerate,dtype=torch.float32) - self.RMS=lambda y:torch.sqrt(torch.mean(torch.square(y))).item()#RMS calculator + self.resampler1=tat.Resample(orig_freq=self.config.samplerate,new_freq=16000,dtype=torch.float32) + self.resampler2=tat.Resample(orig_freq=40000,new_freq=self.config.samplerate,dtype=torch.float32) thread_vc=threading.Thread(target=self.soundinput) thread_vc.start() @@ -284,7 +271,7 @@ class GUI: #infer print('input_wav:'+str(self.input_wav.shape)) #print('infered_wav:'+str(infer_wav.shape)) - infer_wav:torch.Tensor=self.resampler(torch.from_numpy(self.rvc.infer(self.input_wav,self.config.samplerate)))[-self.crossfade_frame-self.sola_search_frame-self.block_frame:].to(device) + infer_wav:torch.Tensor=self.resampler2(self.rvc.infer(self.resampler1(torch.from_numpy(self.input_wav))))[-self.crossfade_frame-self.sola_search_frame-self.block_frame:].to(device) print('infer_wav:'+str(infer_wav.shape)) # SOLA algorithm from https://github.com/yxlllc/DDSP-SVC diff --git a/locale/zh_CN.json b/locale/zh_CN.json index 4b1b672..2cf6fd4 100644 --- a/locale/zh_CN.json +++ b/locale/zh_CN.json @@ -94,5 +94,6 @@ "性能设置": "性能设置", "开始音频转换": "开始音频转换", "停止音频转换": "停止音频转换", - "推理时间(ms):": "推理时间(ms):" + "Infer Time(ms):":"推理时间(ms):", + "Index Rate":"索引权重" } \ No newline at end of file