diff --git a/infer/lib/rmvpe.py b/infer/lib/rmvpe.py index 25dcb8c..e5fa613 100644 --- a/infer/lib/rmvpe.py +++ b/infer/lib/rmvpe.py @@ -1,14 +1,23 @@ -import torch, numpy as np,pdb +import torch, numpy as np, pdb import torch.nn as nn import torch.nn.functional as F -import torch,pdb +import torch, pdb import numpy as np import torch.nn.functional as F from scipy.signal import get_window -from librosa.util import pad_center, tiny,normalize +from librosa.util import pad_center, tiny, normalize + + ###stft codes from https://github.com/pseeth/torch-stft/blob/master/torch_stft/util.py -def window_sumsquare(window, n_frames, hop_length=200, win_length=800, - n_fft=800, dtype=np.float32, norm=None): +def window_sumsquare( + window, + n_frames, + hop_length=200, + win_length=800, + n_fft=800, + dtype=np.float32, + norm=None, +): """ # from librosa 0.6 Compute the sum-square envelope of a window function at a given hop length. @@ -41,18 +50,20 @@ def window_sumsquare(window, n_frames, hop_length=200, win_length=800, # Compute the squared window at the desired length win_sq = get_window(window, win_length, fftbins=True) - win_sq = normalize(win_sq, norm=norm)**2 + win_sq = normalize(win_sq, norm=norm) ** 2 win_sq = pad_center(win_sq, n_fft) # Fill the envelope for i in range(n_frames): sample = i * hop_length - x[sample:min(n, sample + n_fft)] += win_sq[:max(0, min(n_fft, n - sample))] + x[sample : min(n, sample + n_fft)] += win_sq[: max(0, min(n_fft, n - sample))] return x + class STFT(torch.nn.Module): - def __init__(self, filter_length=1024, hop_length=512, win_length=None, - window='hann'): + def __init__( + self, filter_length=1024, hop_length=512, win_length=None, window="hann" + ): """ This module implements an STFT using 1D convolution and 1D transpose convolutions. This is a bit tricky so there are some cases that probably won't work as working @@ -79,12 +90,15 @@ class STFT(torch.nn.Module): fourier_basis = np.fft.fft(np.eye(self.filter_length)) cutoff = int((self.filter_length / 2 + 1)) - fourier_basis = np.vstack([np.real(fourier_basis[:cutoff, :]),np.imag(fourier_basis[:cutoff, :])]) + fourier_basis = np.vstack( + [np.real(fourier_basis[:cutoff, :]), np.imag(fourier_basis[:cutoff, :])] + ) forward_basis = torch.FloatTensor(fourier_basis[:, None, :]) inverse_basis = torch.FloatTensor( - np.linalg.pinv(scale * fourier_basis).T[:, None, :]) + np.linalg.pinv(scale * fourier_basis).T[:, None, :] + ) - assert (filter_length >= self.win_length) + assert filter_length >= self.win_length # get window and zero center pad it to filter_length fft_window = get_window(window, self.win_length, fftbins=True) fft_window = pad_center(fft_window, size=filter_length) @@ -94,8 +108,8 @@ class STFT(torch.nn.Module): forward_basis *= fft_window inverse_basis *= fft_window - self.register_buffer('forward_basis', forward_basis.float()) - self.register_buffer('inverse_basis', inverse_basis.float()) + self.register_buffer("forward_basis", forward_basis.float()) + self.register_buffer("inverse_basis", inverse_basis.float()) def transform(self, input_data): """Take input data (audio) to STFT domain. @@ -117,23 +131,25 @@ class STFT(torch.nn.Module): # similar to librosa, reflect-pad the input input_data = input_data.view(num_batches, 1, num_samples) # print(1234,input_data.shape) - input_data = F.pad(input_data.unsqueeze(1),(self.pad_amount, self.pad_amount, 0, 0,0,0),mode='reflect').squeeze(1) + input_data = F.pad( + input_data.unsqueeze(1), + (self.pad_amount, self.pad_amount, 0, 0, 0, 0), + mode="reflect", + ).squeeze(1) # print(2333,input_data.shape,self.forward_basis.shape,self.hop_length) # pdb.set_trace() forward_transform = F.conv1d( - input_data, - self.forward_basis, - stride=self.hop_length, - padding=0) + input_data, self.forward_basis, stride=self.hop_length, padding=0 + ) cutoff = int((self.filter_length / 2) + 1) real_part = forward_transform[:, :cutoff, :] imag_part = forward_transform[:, cutoff:, :] - magnitude = torch.sqrt(real_part ** 2 + imag_part ** 2) + magnitude = torch.sqrt(real_part**2 + imag_part**2) # phase = torch.atan2(imag_part.data, real_part.data) - return magnitude#, phase + return magnitude # , phase def inverse(self, magnitude, phase): """Call the inverse STFT (iSTFT), given magnitude and phase tensors produced @@ -150,30 +166,39 @@ class STFT(torch.nn.Module): shape (num_batch, num_samples) """ recombine_magnitude_phase = torch.cat( - [magnitude * torch.cos(phase), magnitude * torch.sin(phase)], dim=1) + [magnitude * torch.cos(phase), magnitude * torch.sin(phase)], dim=1 + ) inverse_transform = F.conv_transpose1d( recombine_magnitude_phase, self.inverse_basis, stride=self.hop_length, - padding=0) + padding=0, + ) if self.window is not None: window_sum = window_sumsquare( - self.window, magnitude.size(-1), hop_length=self.hop_length, - win_length=self.win_length, n_fft=self.filter_length, - dtype=np.float32) + self.window, + magnitude.size(-1), + hop_length=self.hop_length, + win_length=self.win_length, + n_fft=self.filter_length, + dtype=np.float32, + ) # remove modulation effects approx_nonzero_indices = torch.from_numpy( - np.where(window_sum > tiny(window_sum))[0]) + np.where(window_sum > tiny(window_sum))[0] + ) window_sum = torch.from_numpy(window_sum).to(inverse_transform.device) - inverse_transform[:, :, approx_nonzero_indices] /= window_sum[approx_nonzero_indices] + inverse_transform[:, :, approx_nonzero_indices] /= window_sum[ + approx_nonzero_indices + ] # scale by hop ratio inverse_transform *= float(self.filter_length) / self.hop_length - inverse_transform = inverse_transform[..., self.pad_amount:] - inverse_transform = inverse_transform[..., :self.num_samples] + inverse_transform = inverse_transform[..., self.pad_amount :] + inverse_transform = inverse_transform[..., : self.num_samples] inverse_transform = inverse_transform.squeeze(1) return inverse_transform @@ -191,7 +216,11 @@ class STFT(torch.nn.Module): self.magnitude, self.phase = self.transform(input_data) reconstruction = self.inverse(self.magnitude, self.phase) return reconstruction + + from time import time as ttime + + class BiGRU(nn.Module): def __init__(self, input_features, hidden_features, num_layers): super(BiGRU, self).__init__() @@ -509,14 +538,14 @@ class MelSpectrogram(torch.nn.Module): # print(1111111111) # print(222222222222222,audio.device,self.is_half) if hasattr(self, "stft") == False: - # print(n_fft_new,hop_length_new,win_length_new,audio.shape) - self.stft=STFT( + # print(n_fft_new,hop_length_new,win_length_new,audio.shape) + self.stft = STFT( filter_length=n_fft_new, hop_length=hop_length_new, win_length=win_length_new, - window='hann' + window="hann", ).to(audio.device) - magnitude = self.stft.transform(audio)#phase + magnitude = self.stft.transform(audio) # phase # if (audio.device.type == "privateuseone"): # magnitude=magnitude.to(audio.device) if keyshift != 0: @@ -544,10 +573,13 @@ class RMVPE: self.mel_extractor = MelSpectrogram( is_half, 128, 16000, 1024, 160, None, 30, 8000 ).to(device) - if ("privateuseone" in str(device)): + if "privateuseone" in str(device): import onnxruntime as ort - ort_session = ort.InferenceSession("rmvpe.onnx", providers=["DmlExecutionProvider"]) - self.model=ort_session + + ort_session = ort.InferenceSession( + "rmvpe.onnx", providers=["DmlExecutionProvider"] + ) + self.model = ort_session else: model = E2E(4, 1, (2, 2)) ckpt = torch.load(model_path, map_location="cpu") @@ -566,10 +598,13 @@ class RMVPE: mel = F.pad( mel, (0, 32 * ((n_frames - 1) // 32 + 1) - n_frames), mode="reflect" ) - if("privateuseone" in str(self.device) ): + if "privateuseone" in str(self.device): onnx_input_name = self.model.get_inputs()[0].name onnx_outputs_names = self.model.get_outputs()[0].name - hidden = self.model.run([onnx_outputs_names], input_feed={onnx_input_name: mel.cpu().numpy()})[0] + hidden = self.model.run( + [onnx_outputs_names], + input_feed={onnx_input_name: mel.cpu().numpy()}, + )[0] else: hidden = self.model(mel) return hidden[:, :n_frames] @@ -583,25 +618,27 @@ class RMVPE: def infer_from_audio(self, audio, thred=0.03): # torch.cuda.synchronize() - t0=ttime() - mel = self.mel_extractor(torch.from_numpy(audio).float().to(self.device).unsqueeze(0), center=True) + t0 = ttime() + mel = self.mel_extractor( + torch.from_numpy(audio).float().to(self.device).unsqueeze(0), center=True + ) # print(123123123,mel.device.type) # torch.cuda.synchronize() - t1=ttime() + t1 = ttime() hidden = self.mel2hidden(mel) # torch.cuda.synchronize() - t2=ttime() + t2 = ttime() # print(234234,hidden.device.type) - if("privateuseone" not in str(self.device)): + if "privateuseone" not in str(self.device): hidden = hidden.squeeze(0).cpu().numpy() else: - hidden=hidden[0] + hidden = hidden[0] if self.is_half == True: hidden = hidden.astype("float32") f0 = self.decode(hidden, thred=thred) # torch.cuda.synchronize() - t3=ttime() + t3 = ttime() # print("hmvpe:%s\t%s\t%s\t%s"%(t1-t0,t2-t1,t3-t2,t3-t0)) return f0 @@ -632,8 +669,9 @@ class RMVPE: return devided -if __name__ == '__main__': +if __name__ == "__main__": import soundfile as sf, librosa + audio, sampling_rate = sf.read(r"C:\Users\liujing04\Desktop\Z\冬之花clip1.wav") if len(audio.shape) > 1: audio = librosa.to_mono(audio.transpose(1, 0)) @@ -642,13 +680,13 @@ if __name__ == '__main__': audio = librosa.resample(audio, orig_sr=sampling_rate, target_sr=16000) model_path = r"D:\BaiduNetdiskDownload\RVC-beta-v2-0727AMD_realtime\rmvpe.pt" thred = 0.03 # 0.01 - device = 'cuda' if torch.cuda.is_available() else 'cpu' - rmvpe = RMVPE(model_path,is_half=False, device=device) - t0=ttime() + device = "cuda" if torch.cuda.is_available() else "cpu" + rmvpe = RMVPE(model_path, is_half=False, device=device) + t0 = ttime() f0 = rmvpe.infer_from_audio(audio, thred=thred) # f0 = rmvpe.infer_from_audio(audio, thred=thred) # f0 = rmvpe.infer_from_audio(audio, thred=thred) # f0 = rmvpe.infer_from_audio(audio, thred=thred) # f0 = rmvpe.infer_from_audio(audio, thred=thred) - t1=ttime() - print(f0.shape,t1-t0) + t1 = ttime() + print(f0.shape, t1 - t0) diff --git a/infer/modules/vc/utils.py b/infer/modules/vc/utils.py index bc98989..98497e2 100644 --- a/infer/modules/vc/utils.py +++ b/infer/modules/vc/utils.py @@ -31,4 +31,3 @@ def load_hubert(config): else: hubert_model = hubert_model.float() return hubert_model.eval() -