From 20fb86acfce4f94f33bae121afefce28aa315565 Mon Sep 17 00:00:00 2001
From: RVC-Boss <129054828+RVC-Boss@users.noreply.github.com>
Date: Sat, 12 Aug 2023 22:59:30 +0800
Subject: [PATCH] Add files via upload

---
 lib/rmvpe.py | 315 +++++++++++++++++++++++++++++++++++++++++++--------
 1 file changed, 269 insertions(+), 46 deletions(-)

diff --git a/lib/rmvpe.py b/lib/rmvpe.py
index fc254cf..25dcb8c 100644
--- a/lib/rmvpe.py
+++ b/lib/rmvpe.py
@@ -1,8 +1,197 @@
-import torch, numpy as np
+import torch, numpy as np, pdb
 import torch.nn as nn
 import torch.nn.functional as F
+from scipy.signal import get_window
+from librosa.util import pad_center, tiny, normalize
+
+### STFT code adapted from https://github.com/pseeth/torch-stft/blob/master/torch_stft/util.py
+def window_sumsquare(window, n_frames, hop_length=200, win_length=800,
+                     n_fft=800, dtype=np.float32, norm=None):
+    """
+    # from librosa 0.6
+    Compute the sum-square envelope of a window function at a given hop length.
+
+    This is used to estimate modulation effects induced by windowing
+    observations in short-time Fourier transforms.
+
+    Parameters
+    ----------
+    window : string, tuple, number, callable, or list-like
+        Window specification, as in `get_window`
+    n_frames : int > 0
+        The number of analysis frames
+    hop_length : int > 0
+        The number of samples to advance between frames
+    win_length : [optional]
+        The length of the window function. By default, this matches `n_fft`.
+    n_fft : int > 0
+        The length of each analysis frame.
+    dtype : np.dtype
+        The data type of the output
+
+    Returns
+    -------
+    wss : np.ndarray, shape=`(n_fft + hop_length * (n_frames - 1))`
+        The sum-squared envelope of the window function
+    """
+    if win_length is None:
+        win_length = n_fft
+    n = n_fft + hop_length * (n_frames - 1)
+    x = np.zeros(n, dtype=dtype)
+
+    # Compute the squared window at the desired length
+    win_sq = get_window(window, win_length, fftbins=True)
+    win_sq = normalize(win_sq, norm=norm) ** 2
+    win_sq = pad_center(win_sq, size=n_fft)
+
+    # Fill the envelope by overlap-adding the squared window at every hop
+    for i in range(n_frames):
+        sample = i * hop_length
+        x[sample:min(n, sample + n_fft)] += win_sq[:max(0, min(n_fft, n - sample))]
+    return x
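A minimal usage sketch (illustrative, not part of the patch): the envelope computed here is exactly what STFT.inverse() below divides by to undo the analysis/synthesis windowing before overlap-add.

import numpy as np
from lib.rmvpe import window_sumsquare

env = window_sumsquare('hann', n_frames=10, hop_length=512,
                       win_length=1024, n_fft=1024)
print(env.shape)             # (n_fft + hop_length * (n_frames - 1),) == (5632,)
print(env.min(), env.max())  # near-zero edge values are masked via tiny() in inverse()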
+
+class STFT(torch.nn.Module):
+    def __init__(self, filter_length=1024, hop_length=512, win_length=None,
+                 window='hann'):
+        """
+        This module implements an STFT using 1D convolution and 1D transposed
+        convolution. Matching input and output sizes across every overlap-add
+        configuration is hard to guarantee, so not all combinations are
+        supported; the code is known to work when the hop length is half the
+        filter length (50% overlap between frames).
+
+        Keyword Arguments:
+            filter_length {int} -- Length of filters used (default: {1024})
+            hop_length {int} -- Hop length of STFT (restrict to 50% overlap between frames) (default: {512})
+            win_length {[type]} -- Length of the window function applied to each frame (if not specified, it
+                equals the filter length). (default: {None})
+            window {str} -- Type of window to use (options are bartlett, hann, hamming, blackman, blackmanharris)
+                (default: {'hann'})
+        """
+        super(STFT, self).__init__()
+        self.filter_length = filter_length
+        self.hop_length = hop_length
+        self.win_length = win_length if win_length else filter_length
+        self.window = window
+        self.forward_transform = None
+        self.pad_amount = int(self.filter_length / 2)
+        scale = self.filter_length / self.hop_length
+        fourier_basis = np.fft.fft(np.eye(self.filter_length))
+
+        cutoff = int((self.filter_length / 2 + 1))
+        fourier_basis = np.vstack([np.real(fourier_basis[:cutoff, :]),
+                                   np.imag(fourier_basis[:cutoff, :])])
+        forward_basis = torch.FloatTensor(fourier_basis[:, None, :])
+        inverse_basis = torch.FloatTensor(
+            np.linalg.pinv(scale * fourier_basis).T[:, None, :])
+
+        assert filter_length >= self.win_length
+        # get window and zero-center pad it to filter_length
+        fft_window = get_window(window, self.win_length, fftbins=True)
+        fft_window = pad_center(fft_window, size=filter_length)
+        fft_window = torch.from_numpy(fft_window).float()
+
+        # window the bases
+        forward_basis *= fft_window
+        inverse_basis *= fft_window
+
+        self.register_buffer('forward_basis', forward_basis.float())
+        self.register_buffer('inverse_basis', inverse_basis.float())
+
+    def transform(self, input_data):
+        """Take input data (audio) to the STFT domain.
+
+        Arguments:
+            input_data {tensor} -- Tensor of floats, with shape (num_batch, num_samples)
+
+        Returns:
+            magnitude {tensor} -- Magnitude of STFT with shape (num_batch,
+                num_frequencies, num_frames)
+            phase {tensor} -- Phase of STFT with shape (num_batch,
+                num_frequencies, num_frames)
+        """
+        num_batches = input_data.shape[0]
+        num_samples = input_data.shape[-1]
+
+        self.num_samples = num_samples
+
+        # similar to librosa, reflect-pad the input
+        input_data = input_data.view(num_batches, 1, num_samples)
+        input_data = F.pad(input_data.unsqueeze(1),
+                           (self.pad_amount, self.pad_amount, 0, 0, 0, 0),
+                           mode='reflect').squeeze(1)
+        forward_transform = F.conv1d(
+            input_data,
+            self.forward_basis,
+            stride=self.hop_length,
+            padding=0)
+
+        cutoff = int((self.filter_length / 2) + 1)
+        real_part = forward_transform[:, :cutoff, :]
+        imag_part = forward_transform[:, cutoff:, :]
+
+        magnitude = torch.sqrt(real_part ** 2 + imag_part ** 2)
+        # phase is required by inverse()/forward(); callers that only need the
+        # magnitude (e.g. MelSpectrogram below) can discard it
+        phase = torch.atan2(imag_part.data, real_part.data)
+
+        return magnitude, phase
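For reference, a minimal round-trip sketch of this module (illustrative, not from the patch), using the 50% overlap the docstring recommends:

import torch
from lib.rmvpe import STFT

stft = STFT(filter_length=1024, hop_length=512, win_length=1024, window='hann')
audio = torch.randn(1, 16000)             # (num_batch, num_samples)
magnitude, phase = stft.transform(audio)  # each (1, 513, num_frames)
recon = stft.inverse(magnitude, phase)    # (1, 16000), approximate reconstruction
print(magnitude.shape, recon.shape)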
+
+    def inverse(self, magnitude, phase):
+        """Call the inverse STFT (iSTFT), given magnitude and phase tensors produced
+        by the ``transform`` function.
+
+        Arguments:
+            magnitude {tensor} -- Magnitude of STFT with shape (num_batch,
+                num_frequencies, num_frames)
+            phase {tensor} -- Phase of STFT with shape (num_batch,
+                num_frequencies, num_frames)
+
+        Returns:
+            inverse_transform {tensor} -- Reconstructed audio given magnitude and phase. Of
+                shape (num_batch, num_samples)
+        """
+        recombine_magnitude_phase = torch.cat(
+            [magnitude * torch.cos(phase), magnitude * torch.sin(phase)], dim=1)
+
+        inverse_transform = F.conv_transpose1d(
+            recombine_magnitude_phase,
+            self.inverse_basis,
+            stride=self.hop_length,
+            padding=0)
+
+        if self.window is not None:
+            window_sum = window_sumsquare(
+                self.window, magnitude.size(-1), hop_length=self.hop_length,
+                win_length=self.win_length, n_fft=self.filter_length,
+                dtype=np.float32)
+            # remove modulation effects
+            approx_nonzero_indices = torch.from_numpy(
+                np.where(window_sum > tiny(window_sum))[0])
+            window_sum = torch.from_numpy(window_sum).to(inverse_transform.device)
+            inverse_transform[:, :, approx_nonzero_indices] /= window_sum[approx_nonzero_indices]
+
+            # scale by hop ratio
+            inverse_transform *= float(self.filter_length) / self.hop_length
+
+        inverse_transform = inverse_transform[..., self.pad_amount:]
+        inverse_transform = inverse_transform[..., :self.num_samples]
+        inverse_transform = inverse_transform.squeeze(1)
+
+        return inverse_transform
+
+    def forward(self, input_data):
+        """Take input data (audio) to the STFT domain and then back to audio.
+
+        Arguments:
+            input_data {tensor} -- Tensor of floats, with shape (num_batch, num_samples)
+
+        Returns:
+            reconstruction {tensor} -- Reconstructed audio given magnitude and phase. Of
+                shape (num_batch, num_samples)
+        """
+        self.magnitude, self.phase = self.transform(input_data)
+        reconstruction = self.inverse(self.magnitude, self.phase)
+        return reconstruction
+
+from time import time as ttime
 class BiGRU(nn.Module):
     def __init__(self, input_features, hidden_features, num_layers):
         super(BiGRU, self).__init__()
@@ -250,9 +439,11 @@ class E2E(nn.Module):
         )
 
     def forward(self, mel):
+        # mel: (B, n_mels, n_frames) -> (B, 1, n_frames, n_mels)
        mel = mel.transpose(-1, -2).unsqueeze(1)
         x = self.cnn(self.unet(mel)).transpose(1, 2).flatten(-2)
         x = self.fc(x)
+        # x: (B, n_frames, 360) pitch-bin salience
         return x
 
 
@@ -301,18 +492,33 @@ class MelSpectrogram(torch.nn.Module):
         keyshift_key = str(keyshift) + "_" + str(audio.device)
         if keyshift_key not in self.hann_window:
             self.hann_window[keyshift_key] = torch.hann_window(win_length_new).to(
                 audio.device
             )
-        fft = torch.stft(
-            audio,
-            n_fft=n_fft_new,
-            hop_length=hop_length_new,
-            win_length=win_length_new,
-            window=self.hann_window[keyshift_key],
-            center=center,
-            return_complex=True,
-        )
-        magnitude = torch.sqrt(fft.real.pow(2) + fft.imag.pow(2))
+        # torch.stft is not supported by pytorch_dml (DirectML), so the
+        # conv1d-based STFT module above is used instead.
+        if not hasattr(self, "stft"):
+            # NOTE: created once with the sizes of the first call; a later
+            # keyshift that changes n_fft/hop/win is not reflected here.
+            self.stft = STFT(
+                filter_length=n_fft_new,
+                hop_length=hop_length_new,
+                win_length=win_length_new,
+                window='hann'
+            ).to(audio.device)
+        magnitude, _ = self.stft.transform(audio)  # phase is discarded
         if keyshift != 0:
             size = self.n_fft // 2 + 1
             resize = magnitude.size(1)
@@ -323,19 +529,13 @@ class MelSpectrogram(torch.nn.Module):
         if self.is_half == True:
             mel_output = mel_output.half()
         log_mel_spec = torch.log(torch.clamp(mel_output, min=self.clamp))
         return log_mel_spec
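A minimal standalone sketch of the extractor (illustrative; the full MelSpectrogram constructor lies outside this hunk, so the positional arguments below simply mirror the RMVPE.__init__ call that follows):

import torch
from lib.rmvpe import MelSpectrogram

# Same arguments RMVPE.__init__ uses: 128 mel bins, 16 kHz input,
# 1024-sample window, 160-sample hop, 30-8000 Hz mel range.
mel_extractor = MelSpectrogram(False, 128, 16000, 1024, 160, None, 30, 8000)
audio = torch.randn(1, 16000)                # one second of 16 kHz audio
log_mel = mel_extractor(audio, center=True)  # (1, 128, n_frames)
print(log_mel.shape)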
 
 
 class RMVPE:
     def __init__(self, model_path, is_half, device=None):
         self.resample_kernel = {}
-        model = E2E(4, 1, (2, 2))
-        ckpt = torch.load(model_path, map_location="cpu")
-        model.load_state_dict(ckpt)
-        model.eval()
-        if is_half == True:
-            model = model.half()
-        self.model = model
         self.resample_kernel = {}
         self.is_half = is_half
         if device is None:
@@ -344,7 +544,19 @@ class RMVPE:
         self.mel_extractor = MelSpectrogram(
             is_half, 128, 16000, 1024, 160, None, 30, 8000
         ).to(device)
-        self.model = self.model.to(device)
+        if "privateuseone" in str(device):
+            # DirectML device: run inference through onnxruntime instead of torch
+            import onnxruntime as ort
+            ort_session = ort.InferenceSession("rmvpe.onnx", providers=["DmlExecutionProvider"])
+            self.model = ort_session
+        else:
+            model = E2E(4, 1, (2, 2))
+            ckpt = torch.load(model_path, map_location="cpu")
+            model.load_state_dict(ckpt)
+            model.eval()
+            if is_half == True:
+                model = model.half()
+            self.model = model.to(device)
         cents_mapping = 20 * np.arange(360) + 1997.3794084376191
         self.cents_mapping = np.pad(cents_mapping, (4, 4))  # 368
 
@@ -354,7 +566,12 @@ class RMVPE:
         mel = F.pad(
             mel, (0, 32 * ((n_frames - 1) // 32 + 1) - n_frames), mode="reflect"
         )
-        hidden = self.model(mel)
+        if "privateuseone" in str(self.device):
+            onnx_input_name = self.model.get_inputs()[0].name
+            onnx_output_name = self.model.get_outputs()[0].name
+            hidden = self.model.run(
+                [onnx_output_name], input_feed={onnx_input_name: mel.cpu().numpy()}
+            )[0]
+        else:
+            hidden = self.model(mel)
         return hidden[:, :n_frames]
 
     def decode(self, hidden, thred=0.03):
@@ -365,21 +582,26 @@ class RMVPE:
         return f0
 
     def infer_from_audio(self, audio, thred=0.03):
-        audio = torch.from_numpy(audio).float().to(self.device).unsqueeze(0)
         # torch.cuda.synchronize()
-        # t0=ttime()
-        mel = self.mel_extractor(audio, center=True)
+        t0 = ttime()
+        mel = self.mel_extractor(
+            torch.from_numpy(audio).float().to(self.device).unsqueeze(0), center=True
+        )
         # torch.cuda.synchronize()
-        # t1=ttime()
+        t1 = ttime()
         hidden = self.mel2hidden(mel)
         # torch.cuda.synchronize()
-        # t2=ttime()
-        hidden = hidden.squeeze(0).cpu().numpy()
+        t2 = ttime()
+        if "privateuseone" not in str(self.device):
+            hidden = hidden.squeeze(0).cpu().numpy()
+        else:
+            # onnxruntime already returned a numpy array; drop the batch axis
+            hidden = hidden[0]
         if self.is_half == True:
             hidden = hidden.astype("float32")
         f0 = self.decode(hidden, thred=thred)
         # torch.cuda.synchronize()
-        # t3=ttime()
+        t3 = ttime()
         # print("hmvpe:%s\t%s\t%s\t%s"%(t1-t0,t2-t1,t3-t2,t3-t0))
         return f0
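For orientation, the cents_mapping set up above defines the model's pitch grid: 360 bins spaced 20 cents apart. A worked sketch (illustrative; it assumes the usual RMVPE conversion f0 = 10 * 2**(cents / 1200), which decode() applies via to_local_average_cents outside this hunk):

import numpy as np

cents_mapping = 20 * np.arange(360) + 1997.3794084376191
f0_grid = 10 * 2 ** (cents_mapping / 1200)
print(f0_grid[0])   # ~31.7 Hz, the lowest pitch on the grid
print(f0_grid[-1])  # ~2005.5 Hz, the highest pitch on the grid
# np.pad(cents_mapping, (4, 4)) widens the vector to 368 entries so a
# weighted local average can be taken around the argmax salience bin.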
sampling_rate = sf.read(r"C:\Users\liujing04\Desktop\Z\冬之花clip1.wav") + if len(audio.shape) > 1: + audio = librosa.to_mono(audio.transpose(1, 0)) + audio_bak = audio.copy() + if sampling_rate != 16000: + audio = librosa.resample(audio, orig_sr=sampling_rate, target_sr=16000) + model_path = r"D:\BaiduNetdiskDownload\RVC-beta-v2-0727AMD_realtime\rmvpe.pt" + thred = 0.03 # 0.01 + device = 'cuda' if torch.cuda.is_available() else 'cpu' + rmvpe = RMVPE(model_path,is_half=False, device=device) + t0=ttime() + f0 = rmvpe.infer_from_audio(audio, thred=thred) + # f0 = rmvpe.infer_from_audio(audio, thred=thred) + # f0 = rmvpe.infer_from_audio(audio, thred=thred) + # f0 = rmvpe.infer_from_audio(audio, thred=thred) + # f0 = rmvpe.infer_from_audio(audio, thred=thred) + t1=ttime() + print(f0.shape,t1-t0)