replace lib

This commit is contained in:
Ftps 2023-08-19 20:00:56 +09:00
parent c25bb6c5d5
commit 6721b81dcf
8 changed files with 2381 additions and 0 deletions

21
infer/lib/audio.py Normal file
View File

@ -0,0 +1,21 @@
import ffmpeg
import numpy as np
def load_audio(file, sr):
    """Decode an audio file into a mono float32 waveform sampled at ``sr``.

    Decoding is delegated to the ffmpeg CLI through the ``ffmpeg-python``
    package, which down-mixes to one channel and resamples as necessary.

    Args:
        file: Path of the audio file. Surrounding spaces, double quotes and
            newlines are stripped before the path is handed to ffmpeg.
        sr: Target sample rate, passed to ffmpeg as the ``ar`` option.

    Returns:
        A 1-D ``np.float32`` array holding the decoded samples.

    Raises:
        RuntimeError: If cleaning the path or running ffmpeg fails. The
            original exception is attached as ``__cause__``.
    """
    try:
        # https://github.com/openai/whisper/blob/main/whisper/audio.py#L26
        # This launches a subprocess to decode audio while down-mixing and resampling as necessary.
        # Requires the ffmpeg CLI and `ffmpeg-python` package to be installed.
        # Guard against copy-pasted paths that carry stray leading/trailing
        # spaces, double quotes or newlines.
        file = file.strip(" ").strip('"').strip("\n").strip('"').strip(" ")
        out, _ = (
            ffmpeg.input(file, threads=0)
            .output("-", format="f32le", acodec="pcm_f32le", ac=1, ar=sr)
            .run(cmd=["ffmpeg", "-nostdin"], capture_stdout=True, capture_stderr=True)
        )
    except Exception as e:
        # Chain explicitly so the underlying ffmpeg error stays inspectable.
        raise RuntimeError(f"Failed to load audio: {e}") from e
    return np.frombuffer(out, np.float32).flatten()

654
infer/lib/rmvpe.py Normal file
View File

@ -0,0 +1,654 @@
import torch, numpy as np,pdb
import torch.nn as nn
import torch.nn.functional as F
import torch,pdb
import numpy as np
import torch.nn.functional as F
from scipy.signal import get_window
from librosa.util import pad_center, tiny,normalize
###stft codes from https://github.com/pseeth/torch-stft/blob/master/torch_stft/util.py
def window_sumsquare(window, n_frames, hop_length=200, win_length=800,
                     n_fft=800, dtype=np.float32, norm=None):
    """
    # from librosa 0.6
    Compute the sum-square envelope of a window function at a given hop length.
    This is used to estimate modulation effects induced by windowing
    observations in short-time fourier transforms.

    Parameters
    ----------
    window : string, tuple, number, callable, or list-like
        Window specification, as in `get_window`
    n_frames : int > 0
        The number of analysis frames
    hop_length : int > 0
        The number of samples to advance between frames
    win_length : [optional]
        The length of the window function. When `None` is passed
        explicitly, this falls back to `n_fft`.
    n_fft : int > 0
        The length of each analysis frame.
    dtype : np.dtype
        The data type of the output
    norm : [optional]
        Forwarded to `normalize` before the window is squared.

    Returns
    -------
    wss : np.ndarray, shape=`(n_fft + hop_length * (n_frames - 1))`
        The sum-squared envelope of the window function
    """
    if win_length is None:
        win_length = n_fft
    n = n_fft + hop_length * (n_frames - 1)
    x = np.zeros(n, dtype=dtype)
    # Compute the squared window at the desired length
    win_sq = get_window(window, win_length, fftbins=True)
    win_sq = normalize(win_sq, norm=norm) ** 2
    # `size` is passed by keyword, matching the pad_center call in
    # STFT.__init__; newer librosa releases reject it as a positional argument.
    win_sq = pad_center(win_sq, size=n_fft)
    # Fill the envelope by overlap-adding the squared window once per frame
    for i in range(n_frames):
        sample = i * hop_length
        x[sample:min(n, sample + n_fft)] += win_sq[:max(0, min(n_fft, n - sample))]
    return x
class STFT(torch.nn.Module):
    """STFT / inverse STFT built from 1D convolutions and transposed convolutions.

    The Fourier bases are precomputed with numpy, multiplied by the analysis
    window, and registered as the buffers ``forward_basis`` and
    ``inverse_basis``.
    """

    def __init__(self, filter_length=1024, hop_length=512, win_length=None,
                 window='hann'):
        """
        This module implements an STFT using 1D convolution and 1D transpose convolutions.
        This is a bit tricky so there are some cases that probably won't work as working
        out the same sizes before and after in all overlap add setups is tough. Right now,
        this code should work with hop lengths that are half the filter length (50% overlap
        between frames).

        Keyword Arguments:
            filter_length {int} -- Length of filters used (default: {1024})
            hop_length {int} -- Hop length of STFT (restrict to 50% overlap between frames) (default: {512})
            win_length {[type]} -- Length of the window function applied to each frame (if not specified, it
                equals the filter length). (default: {None})
            window {str} -- Type of window to use (options are bartlett, hann, hamming, blackman, blackmanharris)
                (default: {'hann'})
        """
        super(STFT, self).__init__()
        self.filter_length = filter_length
        self.hop_length = hop_length
        self.win_length = win_length if win_length else filter_length
        self.window = window
        self.forward_transform = None
        self.pad_amount = int(self.filter_length / 2)
        scale = self.filter_length / self.hop_length
        fourier_basis = np.fft.fft(np.eye(self.filter_length))
        # Keep only the non-redundant half of the spectrum; the real rows are
        # stacked on top of the imaginary rows.
        cutoff = int((self.filter_length / 2 + 1))
        fourier_basis = np.vstack(
            [np.real(fourier_basis[:cutoff, :]), np.imag(fourier_basis[:cutoff, :])]
        )
        forward_basis = torch.FloatTensor(fourier_basis[:, None, :])
        inverse_basis = torch.FloatTensor(
            np.linalg.pinv(scale * fourier_basis).T[:, None, :])
        assert (filter_length >= self.win_length)
        # get window and zero center pad it to filter_length
        fft_window = get_window(window, self.win_length, fftbins=True)
        fft_window = pad_center(fft_window, size=filter_length)
        fft_window = torch.from_numpy(fft_window).float()
        # window the bases
        forward_basis *= fft_window
        inverse_basis *= fft_window
        self.register_buffer('forward_basis', forward_basis.float())
        self.register_buffer('inverse_basis', inverse_basis.float())

    def transform(self, input_data, return_phase=False):
        """Take input data (audio) to STFT domain.

        Arguments:
            input_data {tensor} -- Tensor of floats, with shape (num_batch, num_samples)
            return_phase {bool} -- When True, also return the phase (default: {False})

        Returns:
            magnitude {tensor} -- Magnitude of STFT with shape (num_batch,
                num_frequencies, num_frames)
            phase {tensor} -- Only when ``return_phase`` is True: phase of STFT
                with shape (num_batch, num_frequencies, num_frames)
        """
        num_batches = input_data.shape[0]
        num_samples = input_data.shape[-1]
        # Remembered so that inverse() can trim back to the input length.
        self.num_samples = num_samples
        # similar to librosa, reflect-pad the input
        input_data = input_data.view(num_batches, 1, num_samples)
        input_data = F.pad(
            input_data.unsqueeze(1),
            (self.pad_amount, self.pad_amount, 0, 0, 0, 0),
            mode='reflect',
        ).squeeze(1)
        forward_transform = F.conv1d(
            input_data,
            self.forward_basis,
            stride=self.hop_length,
            padding=0)
        cutoff = int((self.filter_length / 2) + 1)
        real_part = forward_transform[:, :cutoff, :]
        imag_part = forward_transform[:, cutoff:, :]
        magnitude = torch.sqrt(real_part ** 2 + imag_part ** 2)
        if return_phase:
            phase = torch.atan2(imag_part, real_part)
            return magnitude, phase
        return magnitude

    def inverse(self, magnitude, phase):
        """Call the inverse STFT (iSTFT), given magnitude and phase tensors produced
        by the ```transform``` function.

        Arguments:
            magnitude {tensor} -- Magnitude of STFT with shape (num_batch,
                num_frequencies, num_frames)
            phase {tensor} -- Phase of STFT with shape (num_batch,
                num_frequencies, num_frames)

        Returns:
            inverse_transform {tensor} -- Reconstructed audio given magnitude and phase. Of
                shape (num_batch, num_samples)
        """
        recombine_magnitude_phase = torch.cat(
            [magnitude * torch.cos(phase), magnitude * torch.sin(phase)], dim=1)
        inverse_transform = F.conv_transpose1d(
            recombine_magnitude_phase,
            self.inverse_basis,
            stride=self.hop_length,
            padding=0)
        if self.window is not None:
            window_sum = window_sumsquare(
                self.window, magnitude.size(-1), hop_length=self.hop_length,
                win_length=self.win_length, n_fft=self.filter_length,
                dtype=np.float32)
            # remove modulation effects
            approx_nonzero_indices = torch.from_numpy(
                np.where(window_sum > tiny(window_sum))[0]
            ).to(inverse_transform.device)
            window_sum = torch.from_numpy(window_sum).to(inverse_transform.device)
            inverse_transform[:, :, approx_nonzero_indices] /= window_sum[approx_nonzero_indices]
            # scale by hop ratio
            inverse_transform *= float(self.filter_length) / self.hop_length
        # Drop the reflect padding added by transform() and trim to the
        # original number of samples.
        inverse_transform = inverse_transform[..., self.pad_amount:]
        inverse_transform = inverse_transform[..., :self.num_samples]
        inverse_transform = inverse_transform.squeeze(1)
        return inverse_transform

    def forward(self, input_data):
        """Take input data (audio) to STFT domain and then back to audio.

        Arguments:
            input_data {tensor} -- Tensor of floats, with shape (num_batch, num_samples)

        Returns:
            reconstruction {tensor} -- Reconstructed audio given magnitude and phase. Of
                shape (num_batch, num_samples)
        """
        # transform() returns only the magnitude by default, so the phase has
        # to be requested explicitly for the round trip.
        self.magnitude, self.phase = self.transform(input_data, return_phase=True)
        reconstruction = self.inverse(self.magnitude, self.phase)
        return reconstruction
from time import time as ttime
class BiGRU(nn.Module):
    """Batch-first bidirectional GRU that returns only the output sequence."""

    def __init__(self, input_features, hidden_features, num_layers):
        super().__init__()
        recurrent_options = dict(
            num_layers=num_layers,
            batch_first=True,
            bidirectional=True,
        )
        self.gru = nn.GRU(input_features, hidden_features, **recurrent_options)

    def forward(self, x):
        """Run the GRU on ``x`` and discard the final hidden state."""
        sequence_out, _final_hidden = self.gru(x)
        return sequence_out
class ConvBlockRes(nn.Module):
    """Two 3x3 conv + batch-norm + ReLU stages with a residual connection.

    When the channel count changes, the residual path goes through a 1x1
    convolution (``shortcut``); otherwise the input is added back unchanged.
    """

    def __init__(self, in_channels, out_channels, momentum=0.01):
        super().__init__()

        def conv3x3(n_in):
            # Bias-free 3x3 convolution that keeps the spatial size.
            return nn.Conv2d(
                in_channels=n_in,
                out_channels=out_channels,
                kernel_size=(3, 3),
                stride=(1, 1),
                padding=(1, 1),
                bias=False,
            )

        self.conv = nn.Sequential(
            conv3x3(in_channels),
            nn.BatchNorm2d(out_channels, momentum=momentum),
            nn.ReLU(),
            conv3x3(out_channels),
            nn.BatchNorm2d(out_channels, momentum=momentum),
            nn.ReLU(),
        )
        self.is_shortcut = in_channels != out_channels
        if self.is_shortcut:
            self.shortcut = nn.Conv2d(in_channels, out_channels, (1, 1))

    def forward(self, x):
        """Return ``conv(x)`` plus the (possibly projected) input."""
        residual = self.shortcut(x) if self.is_shortcut else x
        return self.conv(x) + residual
class Encoder(nn.Module):
    """Stack of ``ResEncoderBlock`` stages preceded by a batch norm.

    Every stage doubles the channel count and halves the tracked size.
    ``latent_channels`` records ``[out_channels, in_size]`` for each stage;
    ``out_channel`` and ``out_size`` hold the values that a further stage
    would have used.
    """

    def __init__(
        self,
        in_channels,
        in_size,
        n_encoders,
        kernel_size,
        n_blocks,
        out_channels=16,
        momentum=0.01,
    ):
        super().__init__()
        self.n_encoders = n_encoders
        self.bn = nn.BatchNorm2d(in_channels, momentum=momentum)
        self.layers = nn.ModuleList()
        self.latent_channels = []
        stage_in, stage_out, stage_size = in_channels, out_channels, in_size
        for _ in range(self.n_encoders):
            self.layers.append(
                ResEncoderBlock(
                    stage_in, stage_out, kernel_size, n_blocks, momentum=momentum
                )
            )
            self.latent_channels.append([stage_out, stage_size])
            stage_in, stage_out, stage_size = stage_out, stage_out * 2, stage_size // 2
        self.out_size = stage_size
        self.out_channel = stage_out

    def forward(self, x):
        """Return the encoded tensor and the per-stage pre-pooling outputs."""
        skip_tensors = []
        x = self.bn(x)
        for stage_idx in range(self.n_encoders):
            skip, x = self.layers[stage_idx](x)
            skip_tensors.append(skip)
        return x, skip_tensors
class ResEncoderBlock(nn.Module):
    """Chain of ``ConvBlockRes`` modules, optionally followed by average pooling.

    With ``kernel_size=None`` no pooling layer is created and ``forward``
    returns a single tensor; otherwise it returns ``(features, pooled)``.
    """

    def __init__(
        self, in_channels, out_channels, kernel_size, n_blocks=1, momentum=0.01
    ):
        super().__init__()
        self.n_blocks = n_blocks
        self.conv = nn.ModuleList()
        # The first block adapts the channel count; any further blocks keep it.
        block_inputs = [in_channels] + [out_channels] * (n_blocks - 1)
        for width in block_inputs:
            self.conv.append(ConvBlockRes(width, out_channels, momentum))
        self.kernel_size = kernel_size
        if self.kernel_size is not None:
            self.pool = nn.AvgPool2d(kernel_size=kernel_size)

    def forward(self, x):
        """Apply the first ``n_blocks`` conv blocks, then pool if configured."""
        for block_idx in range(self.n_blocks):
            x = self.conv[block_idx](x)
        if self.kernel_size is None:
            return x
        return x, self.pool(x)
class Intermediate(nn.Module):
    """Sequence of pooling-free ``ResEncoderBlock`` layers."""

    def __init__(self, in_channels, out_channels, n_inters, n_blocks, momentum=0.01):
        super().__init__()
        self.n_inters = n_inters
        self.layers = nn.ModuleList()
        # Only the first layer changes the channel count.
        layer_inputs = [in_channels] + [out_channels] * (self.n_inters - 1)
        for width in layer_inputs:
            self.layers.append(
                ResEncoderBlock(width, out_channels, None, n_blocks, momentum)
            )

    def forward(self, x):
        """Pass ``x`` through the first ``n_inters`` layers in order."""
        for layer_idx in range(self.n_inters):
            x = self.layers[layer_idx](x)
        return x
class ResDecoderBlock(nn.Module):
    """Transposed-conv upsampling stage followed by residual conv blocks.

    ``forward`` concatenates the upsampled tensor with ``concat_tensor`` along
    the channel dimension, which is why the first residual block takes
    ``out_channels * 2`` input channels.
    """

    def __init__(self, in_channels, out_channels, stride, n_blocks=1, momentum=0.01):
        super().__init__()
        if stride == (1, 2):
            out_padding = (0, 1)
        else:
            out_padding = (1, 1)
        self.n_blocks = n_blocks
        upsample = nn.ConvTranspose2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=(3, 3),
            stride=stride,
            padding=(1, 1),
            output_padding=out_padding,
            bias=False,
        )
        self.conv1 = nn.Sequential(
            upsample,
            nn.BatchNorm2d(out_channels, momentum=momentum),
            nn.ReLU(),
        )
        self.conv2 = nn.ModuleList()
        block_inputs = [out_channels * 2] + [out_channels] * (n_blocks - 1)
        for width in block_inputs:
            self.conv2.append(ConvBlockRes(width, out_channels, momentum))

    def forward(self, x, concat_tensor):
        """Upsample ``x``, merge it with ``concat_tensor`` and refine."""
        merged = torch.cat((self.conv1(x), concat_tensor), dim=1)
        for block_idx in range(self.n_blocks):
            merged = self.conv2[block_idx](merged)
        return merged
class Decoder(nn.Module):
    """Stack of ``ResDecoderBlock`` stages, each halving the channel count."""

    def __init__(self, in_channels, n_decoders, stride, n_blocks, momentum=0.01):
        super().__init__()
        self.layers = nn.ModuleList()
        self.n_decoders = n_decoders
        stage_in = in_channels
        for _ in range(self.n_decoders):
            stage_out = stage_in // 2
            self.layers.append(
                ResDecoderBlock(stage_in, stage_out, stride, n_blocks, momentum)
            )
            stage_in = stage_out

    def forward(self, x, concat_tensors):
        """Decode ``x``, consuming ``concat_tensors`` from last to first."""
        for depth in range(self.n_decoders):
            skip = concat_tensors[-1 - depth]
            x = self.layers[depth](x, skip)
        return x
class DeepUnet(nn.Module):
    """Encoder, intermediate stack and decoder wired together U-Net style.

    The encoder is built with a fixed ``in_size`` of 128. The intermediate
    stack goes from half of ``encoder.out_channel`` up to
    ``encoder.out_channel``, which is also the decoder's input width.
    """

    def __init__(
        self,
        kernel_size,
        n_blocks,
        en_de_layers=5,
        inter_layers=4,
        in_channels=1,
        en_out_channels=16,
    ):
        super().__init__()
        self.encoder = Encoder(
            in_channels, 128, en_de_layers, kernel_size, n_blocks, en_out_channels
        )
        bottleneck_channels = self.encoder.out_channel
        self.intermediate = Intermediate(
            bottleneck_channels // 2,
            bottleneck_channels,
            inter_layers,
            n_blocks,
        )
        self.decoder = Decoder(
            bottleneck_channels, en_de_layers, kernel_size, n_blocks
        )

    def forward(self, x):
        """Encode, transform and decode ``x`` using the encoder's skip tensors."""
        encoded, skip_tensors = self.encoder(x)
        bottleneck = self.intermediate(encoded)
        return self.decoder(bottleneck, skip_tensors)
class E2E(nn.Module):
    """U-Net front end followed by a 3-channel conv and a per-frame classifier.

    The classifier head maps ``3 * 128`` features per frame to 360 sigmoid
    outputs, either through a ``BiGRU`` plus linear layer (``n_gru`` truthy)
    or through a single linear layer.
    """

    def __init__(
        self,
        n_blocks,
        n_gru,
        kernel_size,
        en_de_layers=5,
        inter_layers=4,
        in_channels=1,
        en_out_channels=16,
    ):
        super(E2E, self).__init__()
        # Sizes shared by both heads. The GRU-less branch used to read
        # `nn.N_MELS` / `nn.N_CLASS`, which do not exist in torch.nn and
        # raised AttributeError; they are the same 128 / 360 used by the
        # GRU branch.
        n_mels = 128
        n_class = 360
        self.unet = DeepUnet(
            kernel_size,
            n_blocks,
            en_de_layers,
            inter_layers,
            in_channels,
            en_out_channels,
        )
        self.cnn = nn.Conv2d(en_out_channels, 3, (3, 3), padding=(1, 1))
        if n_gru:
            self.fc = nn.Sequential(
                BiGRU(3 * n_mels, 256, n_gru),
                nn.Linear(512, n_class),
                nn.Dropout(0.25),
                nn.Sigmoid(),
            )
        else:
            self.fc = nn.Sequential(
                nn.Linear(3 * n_mels, n_class), nn.Dropout(0.25), nn.Sigmoid()
            )

    def forward(self, mel):
        """Map a mel spectrogram to per-frame sigmoid activations.

        The last two dimensions of ``mel`` are swapped and a channel
        dimension is inserted before the U-Net; afterwards the channel and
        feature dimensions are flattened for the classifier head.
        """
        mel = mel.transpose(-1, -2).unsqueeze(1)
        x = self.cnn(self.unet(mel)).transpose(1, 2).flatten(-2)
        x = self.fc(x)
        return x
from librosa.filters import mel
class MelSpectrogram(torch.nn.Module):
    """Log-mel spectrogram extractor built on the convolutional ``STFT`` module.

    The mel filter bank is created once with ``librosa.filters.mel``
    (``htk=True``) and registered as the buffer ``mel_basis``. The STFT module
    is created lazily on the first call to ``forward``.
    """

    def __init__(
        self,
        is_half,
        n_mel_channels,
        sampling_rate,
        win_length,
        hop_length,
        n_fft=None,
        mel_fmin=0,
        mel_fmax=None,
        clamp=1e-5,
    ):
        """
        Args:
            is_half: When truthy, the mel output is cast to half precision.
            n_mel_channels: Number of mel bands.
            sampling_rate: Sample rate used to build the mel filter bank.
            win_length: Analysis window length in samples.
            hop_length: Hop length in samples.
            n_fft: FFT size; defaults to ``win_length`` when None.
            mel_fmin: Lowest frequency of the mel filter bank.
            mel_fmax: Highest frequency of the mel filter bank.
            clamp: Lower bound applied to the mel values before the log.
        """
        super().__init__()
        n_fft = win_length if n_fft is None else n_fft
        self.hann_window = {}
        mel_basis = mel(
            sr=sampling_rate,
            n_fft=n_fft,
            n_mels=n_mel_channels,
            fmin=mel_fmin,
            fmax=mel_fmax,
            htk=True,
        )
        mel_basis = torch.from_numpy(mel_basis).float()
        self.register_buffer("mel_basis", mel_basis)
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.win_length = win_length
        self.sampling_rate = sampling_rate
        self.n_mel_channels = n_mel_channels
        self.clamp = clamp
        self.is_half = is_half

    def forward(self, audio, keyshift=0, speed=1, center=True):
        """Return the log-mel spectrogram of ``audio``.

        Args:
            audio: Waveform tensor handed to ``STFT.transform``.
            keyshift: Pitch shift in semitones; scales the FFT and window
                lengths by ``2 ** (keyshift / 12)``.
            speed: Factor applied to the hop length.
            center: Accepted for interface compatibility; the STFT module
                always reflect-pads its input, so this flag is not consulted.

        Returns:
            ``log(clamp(mel_basis @ magnitude, min=self.clamp))``, in half
            precision when ``is_half`` is set.
        """
        factor = 2 ** (keyshift / 12)
        n_fft_new = int(np.round(self.n_fft * factor))
        win_length_new = int(np.round(self.win_length * factor))
        hop_length_new = int(np.round(self.hop_length * speed))
        keyshift_key = str(keyshift) + "_" + str(audio.device)
        if keyshift_key not in self.hann_window:
            # Kept for the disabled torch.stft code path (not supported by
            # pytorch_dml); the convolutional STFT builds its own window.
            self.hann_window[keyshift_key] = torch.hann_window(win_length_new).to(
                audio.device
            )
        # Build the STFT lazily and rebuild it whenever the requested
        # geometry differs from the cached one. Caching it unconditionally
        # would silently reuse stale sizes after keyshift or speed changes.
        wanted = (n_fft_new, hop_length_new, win_length_new)
        cached = getattr(self, "stft", None)
        if cached is None or (
            cached.filter_length,
            cached.hop_length,
            cached.win_length,
        ) != wanted:
            self.stft = STFT(
                filter_length=n_fft_new,
                hop_length=hop_length_new,
                win_length=win_length_new,
                window='hann'
            ).to(audio.device)
        elif cached.forward_basis.device != audio.device:
            cached.to(audio.device)
        magnitude = self.stft.transform(audio)
        if keyshift != 0:
            # Bring the shifted spectrum back to the nominal number of bins
            # and compensate for the changed window length.
            size = self.n_fft // 2 + 1
            resize = magnitude.size(1)
            if resize < size:
                magnitude = F.pad(magnitude, (0, 0, 0, size - resize))
            magnitude = magnitude[:, :size, :] * self.win_length / win_length_new
        mel_output = torch.matmul(self.mel_basis, magnitude)
        if self.is_half:
            mel_output = mel_output.half()
        log_mel_spec = torch.log(torch.clamp(mel_output, min=self.clamp))
        return log_mel_spec
class RMVPE:
    """F0 (pitch) estimator wrapping the ``E2E`` network or an ONNX session.

    Audio is converted to a log-mel spectrogram, passed through the model to
    get a per-frame salience over 360 bins, and decoded to Hz with a local
    weighted average around the strongest bin.
    """

    def __init__(self, model_path, is_half, device=None):
        """
        Args:
            model_path: Checkpoint loaded with ``torch.load`` on the PyTorch
                branch.
            is_half: When True, the PyTorch model is cast to half precision.
            device: Target device; defaults to ``"cuda"`` when available,
                otherwise ``"cpu"``.
        """
        self.resample_kernel = {}
        self.is_half = is_half
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = device
        self.mel_extractor = MelSpectrogram(
            is_half, 128, 16000, 1024, 160, None, 30, 8000
        ).to(device)
        if "privateuseone" in str(device):
            import onnxruntime as ort

            # NOTE(review): this branch loads the hard-coded relative file
            # "rmvpe.onnx" and does not use `model_path` - confirm intended.
            ort_session = ort.InferenceSession(
                "rmvpe.onnx", providers=["DmlExecutionProvider"]
            )
            self.model = ort_session
        else:
            model = E2E(4, 1, (2, 2))
            # NOTE: torch.load unpickles the file; only load trusted checkpoints.
            ckpt = torch.load(model_path, map_location="cpu")
            model.load_state_dict(ckpt)
            model.eval()
            if is_half == True:
                model = model.half()
            self.model = model
            self.model = self.model.to(device)
        cents_mapping = 20 * np.arange(360) + 1997.3794084376191
        # Zero-pad by 4 on each side so that a 9-bin window around any bin
        # stays in range: 368 entries.
        self.cents_mapping = np.pad(cents_mapping, (4, 4))

    def mel2hidden(self, mel):
        """Run the model on ``mel`` and return the salience for the real frames.

        The frame axis is reflect-padded up to a multiple of 32 before the
        model is called; the padding frames are dropped from the result.
        """
        with torch.no_grad():
            n_frames = mel.shape[-1]
            mel = F.pad(
                mel, (0, 32 * ((n_frames - 1) // 32 + 1) - n_frames), mode="reflect"
            )
            if "privateuseone" in str(self.device):
                onnx_input_name = self.model.get_inputs()[0].name
                onnx_outputs_names = self.model.get_outputs()[0].name
                hidden = self.model.run(
                    [onnx_outputs_names],
                    input_feed={onnx_input_name: mel.cpu().numpy()},
                )[0]
            else:
                hidden = self.model(mel)
            return hidden[:, :n_frames]

    def decode(self, hidden, thred=0.03):
        """Convert a salience array to f0 in Hz; unvoiced frames become 0."""
        cents_pred = self.to_local_average_cents(hidden, thred=thred)
        f0 = 10 * (2 ** (cents_pred / 1200))
        # A cents value of 0 marks an unvoiced frame and maps to exactly 10.
        f0[f0 == 10] = 0
        return f0

    def infer_from_audio(self, audio, thred=0.03):
        """Estimate f0 for a numpy waveform.

        Args:
            audio: Numpy array converted with ``torch.from_numpy``; the mel
                extractor is built for 16000 Hz.
            thred: Salience threshold below which a frame is unvoiced.

        Returns:
            Numpy array with one f0 value per frame.
        """
        mel = self.mel_extractor(
            torch.from_numpy(audio).float().to(self.device).unsqueeze(0), center=True
        )
        hidden = self.mel2hidden(mel)
        if "privateuseone" not in str(self.device):
            hidden = hidden.squeeze(0).cpu().numpy()
        else:
            hidden = hidden[0]
        if self.is_half == True:
            hidden = hidden.astype("float32")
        f0 = self.decode(hidden, thred=thred)
        return f0

    def to_local_average_cents(self, salience, thred=0.05):
        """Return per-frame pitch in cents from a ``(n_frames, 360)`` salience.

        For every frame, the 9 bins centred on the arg-max are averaged in
        cents, weighted by their salience. Frames whose maximum salience is
        ``<= thred`` are set to 0.
        """
        center = np.argmax(salience, axis=1)  # (n_frames,) index of the peak
        salience = np.pad(salience, ((0, 0), (4, 4)))  # (n_frames, 368)
        center += 4
        starts = center - 4
        # Gather the 9-bin window of every frame in one shot instead of
        # looping over frames in Python.
        window = starts[:, None] + np.arange(9)[None, :]  # (n_frames, 9)
        todo_salience = np.take_along_axis(salience, window, axis=1)
        todo_cents_mapping = self.cents_mapping[window]
        product_sum = np.sum(todo_salience * todo_cents_mapping, 1)
        weight_sum = np.sum(todo_salience, 1)  # (n_frames,)
        # All-zero frames divide 0 by 0; they are overwritten by the
        # threshold below, so only the warning is suppressed here.
        with np.errstate(divide="ignore", invalid="ignore"):
            devided = product_sum / weight_sum  # (n_frames,)
        maxx = np.max(salience, axis=1)  # (n_frames,)
        devided[maxx <= thred] = 0
        return devided
if __name__ == "__main__":
    # Manual smoke test: run RMVPE once on a local clip and print the shape
    # of the f0 track together with the elapsed time. The paths below are
    # machine-specific.
    import librosa
    import soundfile as sf

    audio, sampling_rate = sf.read(r"C:\Users\liujing04\Desktop\Z\冬之花clip1.wav")
    if audio.ndim > 1:
        # Down-mix multi-channel audio to mono.
        audio = librosa.to_mono(audio.transpose(1, 0))
    audio_bak = audio.copy()
    if sampling_rate != 16000:
        audio = librosa.resample(audio, orig_sr=sampling_rate, target_sr=16000)
    model_path = r"D:\BaiduNetdiskDownload\RVC-beta-v2-0727AMD_realtime\rmvpe.pt"
    thred = 0.03  # 0.01
    device = "cuda" if torch.cuda.is_available() else "cpu"
    rmvpe = RMVPE(model_path, is_half=False, device=device)
    t0 = ttime()
    f0 = rmvpe.infer_from_audio(audio, thred=thred)
    t1 = ttime()
    print(f0.shape, t1 - t0)

260
infer/lib/slicer2.py Normal file
View File

@ -0,0 +1,260 @@
import numpy as np
# This function is obtained from librosa.
def get_rms(
    y,
    frame_length=2048,
    hop_length=512,
    pad_mode="constant",
):
    """Frame-wise root-mean-square of ``y`` (adapted from librosa).

    ``y`` is padded by ``frame_length // 2`` on both sides with ``pad_mode``,
    cut into windows of ``frame_length`` samples every ``hop_length`` samples
    along the last axis, and the RMS of every window is returned. For 1-D
    input the result has shape ``(1, n_frames)``.
    """
    half_frame = int(frame_length // 2)
    padded = np.pad(y, (half_frame, half_frame), mode=pad_mode)

    # Strided view with one window per start position on the last axis; the
    # within-window axis is appended at the end.
    n_windows = padded.shape[-1] - (frame_length - 1)
    window_shape = tuple(padded.shape[:-1]) + (n_windows, frame_length)
    window_strides = padded.strides + (padded.strides[-1],)
    windows = np.lib.stride_tricks.as_strided(
        padded, shape=window_shape, strides=window_strides
    )

    # Move the within-window axis in front of the frame axis, then keep every
    # hop_length-th frame.
    frames = np.moveaxis(windows, -1, -2)[..., ::hop_length]

    mean_square = np.mean(np.abs(frames) ** 2, axis=-2, keepdims=True)
    return np.sqrt(mean_square)
class Slicer:
    """Split a waveform into clips at sufficiently long silent stretches.

    All duration arguments are given in milliseconds and converted in
    ``__init__``: ``hop_size`` and ``win_size`` end up in samples, while
    ``min_length``, ``min_interval`` and ``max_sil_kept`` end up in frames
    (hops).
    """

    def __init__(
        self,
        sr: int,
        threshold: float = -40.0,
        min_length: int = 5000,
        min_interval: int = 300,
        hop_size: int = 20,
        max_sil_kept: int = 5000,
    ):
        """
        Args:
            sr: Sample rate of the audio to be sliced.
            threshold: Silence threshold in dB, converted to linear amplitude.
            min_length: Minimum clip length in milliseconds.
            min_interval: Minimum silence length in milliseconds for a cut.
            hop_size: RMS frame hop in milliseconds.
            max_sil_kept: Maximum silence in milliseconds kept around a clip.

        Raises:
            ValueError: If ``min_length >= min_interval >= hop_size`` or
                ``max_sil_kept >= hop_size`` does not hold.
        """
        if not min_length >= min_interval >= hop_size:
            raise ValueError(
                "The following condition must be satisfied: min_length >= min_interval >= hop_size"
            )
        if not max_sil_kept >= hop_size:
            raise ValueError(
                "The following condition must be satisfied: max_sil_kept >= hop_size"
            )
        min_interval = sr * min_interval / 1000
        self.threshold = 10 ** (threshold / 20.0)
        self.hop_size = round(sr * hop_size / 1000)
        self.win_size = min(round(min_interval), 4 * self.hop_size)
        self.min_length = round(sr * min_length / 1000 / self.hop_size)
        self.min_interval = round(min_interval / self.hop_size)
        self.max_sil_kept = round(sr * max_sil_kept / 1000 / self.hop_size)

    def _apply_slice(self, waveform, begin, end):
        """Return the samples between frames ``begin`` and ``end``.

        Frame indices are converted to sample indices with ``hop_size``; the
        end is clamped to the waveform length. For 2-D input the slice is
        taken along axis 1.
        """
        if len(waveform.shape) > 1:
            return waveform[
                :, begin * self.hop_size : min(waveform.shape[1], end * self.hop_size)
            ]
        else:
            return waveform[
                begin * self.hop_size : min(waveform.shape[0], end * self.hop_size)
            ]

    # @timeit
    def slice(self, waveform):
        """Slice ``waveform`` at silences and return the list of clips.

        Args:
            waveform: 1-D array, or 2-D array whose axis 0 is averaged to
                obtain the signal used for silence detection.

        Returns:
            A list of clips cut from ``waveform``. The input itself is
            returned as the only element when it is no longer than
            ``min_length`` or when no cut point is found.
        """
        if len(waveform.shape) > 1:
            samples = waveform.mean(axis=0)
        else:
            samples = waveform
        # self.min_length is expressed in frames, so compare against the
        # number of frames (rounded up) rather than the raw sample count.
        if (samples.shape[0] + self.hop_size - 1) // self.hop_size <= self.min_length:
            return [waveform]
        rms_list = get_rms(
            y=samples, frame_length=self.win_size, hop_length=self.hop_size
        ).squeeze(0)
        sil_tags = []
        silence_start = None
        clip_start = 0
        for i, rms in enumerate(rms_list):
            # Keep looping while frame is silent.
            if rms < self.threshold:
                # Record start of silent frames.
                if silence_start is None:
                    silence_start = i
                continue
            # Keep looping while frame is not silent and silence start has not been recorded.
            if silence_start is None:
                continue
            # Clear recorded silence start if interval is not enough or clip is too short
            is_leading_silence = silence_start == 0 and i > self.max_sil_kept
            need_slice_middle = (
                i - silence_start >= self.min_interval
                and i - clip_start >= self.min_length
            )
            if not is_leading_silence and not need_slice_middle:
                silence_start = None
                continue
            # Need slicing. Record the range of silent frames to be removed.
            if i - silence_start <= self.max_sil_kept:
                # Short silence: cut once at its quietest frame.
                pos = rms_list[silence_start : i + 1].argmin() + silence_start
                if silence_start == 0:
                    sil_tags.append((0, pos))
                else:
                    sil_tags.append((pos, pos))
                clip_start = pos
            elif i - silence_start <= self.max_sil_kept * 2:
                # Medium silence: the kept regions on both sides overlap.
                pos = rms_list[
                    i - self.max_sil_kept : silence_start + self.max_sil_kept + 1
                ].argmin()
                pos += i - self.max_sil_kept
                pos_l = (
                    rms_list[
                        silence_start : silence_start + self.max_sil_kept + 1
                    ].argmin()
                    + silence_start
                )
                pos_r = (
                    rms_list[i - self.max_sil_kept : i + 1].argmin()
                    + i
                    - self.max_sil_kept
                )
                if silence_start == 0:
                    sil_tags.append((0, pos_r))
                    clip_start = pos_r
                else:
                    sil_tags.append((min(pos_l, pos), max(pos_r, pos)))
                    clip_start = max(pos_r, pos)
            else:
                # Long silence: keep at most max_sil_kept frames on each side.
                pos_l = (
                    rms_list[
                        silence_start : silence_start + self.max_sil_kept + 1
                    ].argmin()
                    + silence_start
                )
                pos_r = (
                    rms_list[i - self.max_sil_kept : i + 1].argmin()
                    + i
                    - self.max_sil_kept
                )
                if silence_start == 0:
                    sil_tags.append((0, pos_r))
                else:
                    sil_tags.append((pos_l, pos_r))
                clip_start = pos_r
            silence_start = None
        # Deal with trailing silence.
        total_frames = rms_list.shape[0]
        if (
            silence_start is not None
            and total_frames - silence_start >= self.min_interval
        ):
            silence_end = min(total_frames, silence_start + self.max_sil_kept)
            pos = rms_list[silence_start : silence_end + 1].argmin() + silence_start
            sil_tags.append((pos, total_frames + 1))
        # Apply and return slices.
        if len(sil_tags) == 0:
            return [waveform]
        else:
            chunks = []
            if sil_tags[0][0] > 0:
                chunks.append(self._apply_slice(waveform, 0, sil_tags[0][0]))
            for i in range(len(sil_tags) - 1):
                chunks.append(
                    self._apply_slice(waveform, sil_tags[i][1], sil_tags[i + 1][0])
                )
            if sil_tags[-1][1] < total_frames:
                chunks.append(
                    self._apply_slice(waveform, sil_tags[-1][1], total_frames)
                )
            return chunks
def main():
    """Command-line entry point: slice one audio file and write the clips.

    The clips are written as ``<input stem>_<index>.wav`` into ``--out``, or
    next to the input file when ``--out`` is not given.
    """
    import os.path
    from argparse import ArgumentParser

    import librosa
    import soundfile

    parser = ArgumentParser()
    parser.add_argument("audio", type=str, help="The audio to be sliced")
    parser.add_argument(
        "--out", type=str, help="Output directory of the sliced audio clips"
    )
    parser.add_argument(
        "--db_thresh",
        type=float,
        required=False,
        default=-40,
        help="The dB threshold for silence detection",
    )
    parser.add_argument(
        "--min_length",
        type=int,
        required=False,
        default=5000,
        help="The minimum milliseconds required for each sliced audio clip",
    )
    parser.add_argument(
        "--min_interval",
        type=int,
        required=False,
        default=300,
        help="The minimum milliseconds for a silence part to be sliced",
    )
    parser.add_argument(
        "--hop_size",
        type=int,
        required=False,
        default=10,
        help="Frame length in milliseconds",
    )
    parser.add_argument(
        "--max_sil_kept",
        type=int,
        required=False,
        default=500,
        help="The maximum silence length kept around the sliced clip, presented in milliseconds",
    )
    args = parser.parse_args()
    out = args.out
    if out is None:
        out = os.path.dirname(os.path.abspath(args.audio))
    audio, sr = librosa.load(args.audio, sr=None, mono=False)
    slicer = Slicer(
        sr=sr,
        threshold=args.db_thresh,
        min_length=args.min_length,
        min_interval=args.min_interval,
        hop_size=args.hop_size,
        max_sil_kept=args.max_sil_kept,
    )
    chunks = slicer.slice(audio)
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(out, exist_ok=True)
    stem = os.path.basename(args.audio).rsplit(".", maxsplit=1)[0]
    for i, chunk in enumerate(chunks):
        if len(chunk.shape) > 1:
            # Transposed before writing: multi-channel chunks are sliced
            # along axis 1, so axis 0 holds the channels here.
            chunk = chunk.T
        soundfile.write(os.path.join(out, f"{stem}_{i}.wav"), chunk, sr)


if __name__ == "__main__":
    main()

View File

@ -0,0 +1,512 @@
import os, traceback
import numpy as np
import torch
import torch.utils.data
from infer.lib.train.mel_processing import spectrogram_torch
from infer.lib.train.utils import load_wav_to_torch, load_filepaths_and_text
class TextAudioLoaderMultiNSFsid(torch.utils.data.Dataset):
    """Dataset yielding ``(spec, wav, phone, pitch, pitchf, sid)`` tuples.

    Every entry of the file list holds five fields: audio path, phone feature
    path, pitch path, pitchf path and speaker id. Spectrograms are cached
    next to the audio files as ``*.spec.pt``.
    """

    def __init__(self, audiopaths_and_text, hparams):
        """
        Args:
            audiopaths_and_text: File list passed to ``load_filepaths_and_text``.
            hparams: Object providing ``max_wav_value``, ``sampling_rate``,
                ``filter_length``, ``hop_length`` and ``win_length``, and
                optionally ``min_text_len`` / ``max_text_len``.
        """
        self.audiopaths_and_text = load_filepaths_and_text(audiopaths_and_text)
        self.max_wav_value = hparams.max_wav_value
        self.sampling_rate = hparams.sampling_rate
        self.filter_length = hparams.filter_length
        self.hop_length = hparams.hop_length
        self.win_length = hparams.win_length
        self.min_text_len = getattr(hparams, "min_text_len", 1)
        self.max_text_len = getattr(hparams, "max_text_len", 5000)
        self._filter()

    def _filter(self):
        """
        Filter text & store spec lengths
        """
        # Store spectrogram lengths for Bucketing
        # wav_length ~= file_size / (wav_channels * Bytes per dim) = file_size / (1 * 2)
        # spec_length = wav_length // hop_length
        # NOTE(review): the code divides by 3 * hop_length, which does not
        # match the 2-bytes-per-sample estimate above - confirm the intent.
        audiopaths_and_text_new = []
        lengths = []
        for audiopath, text, pitch, pitchf, dv in self.audiopaths_and_text:
            if self.min_text_len <= len(text) and len(text) <= self.max_text_len:
                audiopaths_and_text_new.append([audiopath, text, pitch, pitchf, dv])
                lengths.append(os.path.getsize(audiopath) // (3 * self.hop_length))
        self.audiopaths_and_text = audiopaths_and_text_new
        self.lengths = lengths

    def get_sid(self, sid):
        """Return the speaker id as a one-element ``LongTensor``."""
        sid = torch.LongTensor([int(sid)])
        return sid

    def get_audio_text_pair(self, audiopath_and_text):
        """Load one sample and trim all streams to a common frame count."""
        # separate filename and text
        file = audiopath_and_text[0]
        phone = audiopath_and_text[1]
        pitch = audiopath_and_text[2]
        pitchf = audiopath_and_text[3]
        dv = audiopath_and_text[4]
        phone, pitch, pitchf = self.get_labels(phone, pitch, pitchf)
        spec, wav = self.get_audio(file)
        dv = self.get_sid(dv)
        len_phone = phone.size()[0]
        len_spec = spec.size()[-1]
        if len_phone != len_spec:
            # Trim everything to the shorter of the two frame counts; the
            # waveform is cut to the matching number of samples.
            len_min = min(len_phone, len_spec)
            len_wav = len_min * self.hop_length
            spec = spec[:, :len_min]
            wav = wav[:, :len_wav]
            phone = phone[:len_min, :]
            pitch = pitch[:len_min]
            pitchf = pitchf[:len_min]
        return (spec, wav, phone, pitch, pitchf, dv)

    def get_labels(self, phone, pitch, pitchf):
        """Load phone, pitch and pitchf arrays from ``.npy`` files.

        The phone features are repeated twice along axis 0, and all three
        streams are truncated to at most 900 frames.
        """
        phone = np.load(phone)
        phone = np.repeat(phone, 2, axis=0)
        pitch = np.load(pitch)
        pitchf = np.load(pitchf)
        n_num = min(phone.shape[0], 900)  # DistributedBucketSampler
        phone = phone[:n_num, :]
        pitch = pitch[:n_num]
        pitchf = pitchf[:n_num]
        phone = torch.FloatTensor(phone)
        pitch = torch.LongTensor(pitch)
        pitchf = torch.FloatTensor(pitchf)
        return phone, pitch, pitchf

    def _compute_spec(self, audio_norm):
        """Compute the spectrogram of ``audio_norm`` with the batch dim removed."""
        spec = spectrogram_torch(
            audio_norm,
            self.filter_length,
            self.sampling_rate,
            self.hop_length,
            self.win_length,
            center=False,
        )
        return torch.squeeze(spec, 0)

    def get_audio(self, filename):
        """Load a wav file and return ``(spec, audio)``.

        The spectrogram is read from the ``.spec.pt`` cache file when it
        exists and loads cleanly; otherwise it is recomputed and the cache
        file is (re)written.

        Raises:
            ValueError: If the file's sample rate differs from
                ``self.sampling_rate``.
        """
        audio, sampling_rate = load_wav_to_torch(filename)
        if sampling_rate != self.sampling_rate:
            raise ValueError(
                "{} SR doesn't match target {} SR".format(
                    sampling_rate, self.sampling_rate
                )
            )
        audio_norm = audio.unsqueeze(0)
        spec_filename = filename.replace(".wav", ".spec.pt")
        spec = None
        if os.path.exists(spec_filename):
            try:
                spec = torch.load(spec_filename)
            except Exception:
                # An unreadable cache is reported and rebuilt below.
                # KeyboardInterrupt / SystemExit are no longer swallowed.
                print(spec_filename, traceback.format_exc())
        if spec is None:
            spec = self._compute_spec(audio_norm)
            torch.save(spec, spec_filename, _use_new_zipfile_serialization=False)
        return spec, audio_norm

    def __getitem__(self, index):
        return self.get_audio_text_pair(self.audiopaths_and_text[index])

    def __len__(self):
        return len(self.audiopaths_and_text)
class TextAudioCollateMultiNSFsid:
    """Zero-pads model inputs and targets"""

    def __init__(self, return_ids=False):
        self.return_ids = return_ids

    def __call__(self, batch):
        """Collate a training batch of dataset samples.

        PARAMS
        ------
        batch: list of ``(spec, wav, phone, pitch, pitchf, sid)`` tuples

        Samples are ordered by decreasing spectrogram length and every stream
        is right-padded with zeros to the longest item in the batch.
        """
        n_items = len(batch)
        spec_frames = torch.LongTensor([item[0].size(1) for item in batch])
        _, order = torch.sort(spec_frames, dim=0, descending=True)

        longest_spec = max(item[0].size(1) for item in batch)
        longest_wave = max(item[1].size(1) for item in batch)
        longest_phone = max(item[2].size(0) for item in batch)

        # Length vectors are allocated uninitialised; every slot is written
        # in the loop below.
        spec_lengths = torch.LongTensor(n_items)
        wave_lengths = torch.LongTensor(n_items)
        phone_lengths = torch.LongTensor(n_items)
        sid = torch.LongTensor(n_items)

        spec_padded = torch.FloatTensor(
            n_items, batch[0][0].size(0), longest_spec
        ).zero_()
        wave_padded = torch.FloatTensor(n_items, 1, longest_wave).zero_()
        phone_padded = torch.FloatTensor(
            n_items, longest_phone, batch[0][2].shape[1]
        ).zero_()
        pitch_padded = torch.LongTensor(n_items, longest_phone).zero_()
        pitchf_padded = torch.FloatTensor(n_items, longest_phone).zero_()

        for slot, source_idx in enumerate(order):
            spec, wave, phone, pitch, pitchf, speaker = batch[source_idx][:6]

            spec_padded[slot, :, : spec.size(1)] = spec
            spec_lengths[slot] = spec.size(1)

            wave_padded[slot, :, : wave.size(1)] = wave
            wave_lengths[slot] = wave.size(1)

            phone_padded[slot, : phone.size(0), :] = phone
            phone_lengths[slot] = phone.size(0)

            pitch_padded[slot, : pitch.size(0)] = pitch
            pitchf_padded[slot, : pitchf.size(0)] = pitchf

            sid[slot] = speaker

        return (
            phone_padded,
            phone_lengths,
            pitch_padded,
            pitchf_padded,
            spec_padded,
            spec_lengths,
            wave_padded,
            wave_lengths,
            sid,
        )
class TextAudioLoader(torch.utils.data.Dataset):
    """
    Dataset yielding ``(spec, wav, phone, sid)`` samples.

    1) loads the file list; each entry is ``[audio path, phone feature path, speaker id]``
    2) loads phone features from the ``.npy`` file and repeats every frame twice
    3) computes spectrograms from audio files, caching them next to the audio
       as ``*.spec.pt``
    """

    def __init__(self, audiopaths_and_text, hparams):
        """Read the file list and the audio hyper-parameters, then filter entries.

        PARAMS
        ------
        audiopaths_and_text: path of the file list handed to ``load_filepaths_and_text``
        hparams: object exposing ``max_wav_value``, ``sampling_rate``,
            ``filter_length``, ``hop_length``, ``win_length`` and optionally
            ``min_text_len`` / ``max_text_len``
        """
        self.audiopaths_and_text = load_filepaths_and_text(audiopaths_and_text)
        self.max_wav_value = hparams.max_wav_value
        self.sampling_rate = hparams.sampling_rate
        self.filter_length = hparams.filter_length
        self.hop_length = hparams.hop_length
        self.win_length = hparams.win_length
        self.min_text_len = getattr(hparams, "min_text_len", 1)
        self.max_text_len = getattr(hparams, "max_text_len", 5000)
        self._filter()

    def _filter(self):
        """
        Filter entries & store approximate spec lengths (used for bucketing).

        An entry is kept when the length of its second field lies within
        ``[min_text_len, max_text_len]``.
        """
        # NOTE(review): the inherited comment derived
        #   wav_length ~= file_size / (wav_channels * bytes per sample) = file_size / (1 * 2)
        #   spec_length = wav_length // hop_length
        # but the code divides the file size by 3 * hop_length; confirm which
        # divisor is intended before changing it.
        kept = []
        lengths = []
        for audiopath, text, dv in self.audiopaths_and_text:
            if self.min_text_len <= len(text) <= self.max_text_len:
                kept.append([audiopath, text, dv])
                lengths.append(os.path.getsize(audiopath) // (3 * self.hop_length))
        self.audiopaths_and_text = kept
        self.lengths = lengths

    def get_sid(self, sid):
        """Return the speaker id as a one-element ``LongTensor``."""
        sid = torch.LongTensor([int(sid)])
        return sid

    def get_audio_text_pair(self, audiopath_and_text):
        """Load one sample and trim all parts to a common number of frames.

        Returns ``(spec, wav, phone, dv)``. When the phone and spectrogram
        frame counts differ, spec and phone are cut to the shorter count and
        the waveform to ``frames * hop_length`` samples.
        """
        # separate filename, phone feature path and speaker id
        file = audiopath_and_text[0]
        phone = audiopath_and_text[1]
        dv = audiopath_and_text[2]

        phone = self.get_labels(phone)
        spec, wav = self.get_audio(file)
        dv = self.get_sid(dv)

        len_phone = phone.size()[0]
        len_spec = spec.size()[-1]
        if len_phone != len_spec:
            len_min = min(len_phone, len_spec)
            len_wav = len_min * self.hop_length
            spec = spec[:, :len_min]
            wav = wav[:, :len_wav]
            phone = phone[:len_min, :]
        return (spec, wav, phone, dv)

    def get_labels(self, phone):
        """Load phone features from ``phone`` (a ``.npy`` path) as a FloatTensor.

        Every frame is repeated twice along axis 0 and the result is capped
        at 900 frames.
        """
        phone = np.load(phone)
        phone = np.repeat(phone, 2, axis=0)
        # The cap was annotated "DistributedBucketSampler" in the original;
        # presumably it matches the sampler's largest boundary - TODO confirm.
        n_num = min(phone.shape[0], 900)
        phone = phone[:n_num, :]
        phone = torch.FloatTensor(phone)
        return phone

    def _compute_and_cache_spec(self, audio_norm, spec_filename):
        """Compute the spectrogram of ``audio_norm`` and save it to ``spec_filename``."""
        spec = spectrogram_torch(
            audio_norm,
            self.filter_length,
            self.sampling_rate,
            self.hop_length,
            self.win_length,
            center=False,
        )
        spec = torch.squeeze(spec, 0)
        torch.save(spec, spec_filename, _use_new_zipfile_serialization=False)
        return spec

    def get_audio(self, filename):
        """Return ``(spec, audio)`` for ``filename``.

        The waveform is used as loaded (no rescaling) with a leading axis
        added. The spectrogram is read from the ``.spec.pt`` cache when it
        exists and loads cleanly; otherwise it is recomputed and re-cached.

        Raises ``ValueError`` when the file's sampling rate differs from
        ``self.sampling_rate``.
        """
        audio, sampling_rate = load_wav_to_torch(filename)
        if sampling_rate != self.sampling_rate:
            raise ValueError(
                "{} SR doesn't match target {} SR".format(
                    sampling_rate, self.sampling_rate
                )
            )
        audio_norm = audio.unsqueeze(0)

        spec_filename = filename.replace(".wav", ".spec.pt")
        if os.path.exists(spec_filename):
            try:
                spec = torch.load(spec_filename)
            except Exception:
                # Unreadable cache: report it and rebuild. KeyboardInterrupt /
                # SystemExit are deliberately not swallowed here.
                print(spec_filename, traceback.format_exc())
                spec = self._compute_and_cache_spec(audio_norm, spec_filename)
        else:
            spec = self._compute_and_cache_spec(audio_norm, spec_filename)
        return spec, audio_norm

    def __getitem__(self, index):
        return self.get_audio_text_pair(self.audiopaths_and_text[index])

    def __len__(self):
        return len(self.audiopaths_and_text)
class TextAudioCollate:
    """Collate callable that zero-pads samples (no pitch inputs) into batch tensors.

    Every sample is read positionally as ``(spec, wav, phone, sid)``.
    """

    def __init__(self, return_ids=False):
        # Only stored; nothing in this class reads the flag back.
        self.return_ids = return_ids

    def __call__(self, batch):
        """Build one right-zero-padded training batch.

        PARAMS
        ------
        batch: list of samples ``(spec, wav, phone, sid)``

        RETURNS
        -------
        ``(phone_padded, phone_lengths, spec_padded, spec_lengths,
        wave_padded, wave_lengths, sid)``; rows are ordered by decreasing
        spectrogram length.
        """
        n_items = len(batch)

        # Output row order: longest spectrogram first.
        spec_sizes = torch.LongTensor([sample[0].size(1) for sample in batch])
        _, ids_sorted_decreasing = torch.sort(spec_sizes, dim=0, descending=True)

        # Widest extent of each padded field across the batch.
        max_spec_len = max(sample[0].size(1) for sample in batch)
        max_wave_len = max(sample[1].size(1) for sample in batch)
        max_phone_len = max(sample[2].size(0) for sample in batch)

        # Zero-filled destinations; samples are copied into the left part of
        # their row.
        spec_padded = torch.FloatTensor(
            n_items, batch[0][0].size(0), max_spec_len
        ).zero_()
        wave_padded = torch.FloatTensor(n_items, 1, max_wave_len).zero_()
        phone_padded = torch.FloatTensor(
            n_items, max_phone_len, batch[0][2].shape[1]
        ).zero_()

        # Per-row bookkeeping; every element is overwritten in the loop below.
        spec_lengths = torch.LongTensor(n_items)
        wave_lengths = torch.LongTensor(n_items)
        phone_lengths = torch.LongTensor(n_items)
        sid = torch.LongTensor(n_items)

        for dst, src in enumerate(ids_sorted_decreasing.tolist()):
            row = batch[src]

            spec = row[0]
            spec_padded[dst, :, : spec.size(1)] = spec
            spec_lengths[dst] = spec.size(1)

            wave = row[1]
            wave_padded[dst, :, : wave.size(1)] = wave
            wave_lengths[dst] = wave.size(1)

            phone = row[2]
            phone_padded[dst, : phone.size(0), :] = phone
            phone_lengths[dst] = phone.size(0)

            sid[dst] = row[3]

        return (
            phone_padded,
            phone_lengths,
            spec_padded,
            spec_lengths,
            wave_padded,
            wave_lengths,
            sid,
        )
class DistributedBucketSampler(torch.utils.data.distributed.DistributedSampler):
    """
    Maintain similar input lengths in a batch.
    Length groups are specified by boundaries.
    Ex) boundaries = [b1, b2, b3] -> any batch is included either {x | b1 < length(x) <=b2} or {x | b2 < length(x) <= b3}.

    It removes samples which are not included in the boundaries.
    Ex) boundaries = [b1, b2, b3] -> any x s.t. length(x) <= b1 or length(x) > b3 are discarded.

    Iterating the sampler yields lists of dataset indices, one list per batch.
    """

    def __init__(
        self,
        dataset,
        batch_size,
        boundaries,
        num_replicas=None,
        rank=None,
        shuffle=True,
    ):
        """
        PARAMS
        ------
        dataset: dataset exposing ``lengths`` (one length per sample)
        batch_size: number of indices per yielded batch
        boundaries: increasing length boundaries delimiting the buckets
        num_replicas, rank, shuffle: forwarded to ``DistributedSampler``
        """
        super().__init__(dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle)
        self.lengths = dataset.lengths
        self.batch_size = batch_size
        # Private copy: _create_buckets() drops the boundaries of empty
        # buckets, which must not alter the list owned by the caller.
        self.boundaries = list(boundaries)

        self.buckets, self.num_samples_per_bucket = self._create_buckets()
        self.total_size = sum(self.num_samples_per_bucket)
        self.num_samples = self.total_size // self.num_replicas

    def _create_buckets(self):
        """Assign sample indices to buckets and compute each bucket's padded size.

        Returns ``(buckets, num_samples_per_bucket)``; empty buckets are
        removed together with their upper boundary.
        """
        buckets = [[] for _ in range(len(self.boundaries) - 1)]
        for i, length in enumerate(self.lengths):
            idx_bucket = self._bisect(length)
            if idx_bucket != -1:
                buckets[idx_bucket].append(i)

        # Walk backwards so that popping does not shift pending indices.
        for i in range(len(buckets) - 1, -1, -1):
            if len(buckets[i]) == 0:
                buckets.pop(i)
                self.boundaries.pop(i + 1)

        # Round every bucket up to a multiple of the global batch size.
        total_batch_size = self.num_replicas * self.batch_size
        num_samples_per_bucket = []
        for bucket in buckets:
            len_bucket = len(bucket)
            rem = (
                total_batch_size - (len_bucket % total_batch_size)
            ) % total_batch_size
            num_samples_per_bucket.append(len_bucket + rem)
        return buckets, num_samples_per_bucket

    def __iter__(self):
        # deterministically shuffle based on epoch
        g = torch.Generator()
        g.manual_seed(self.epoch)

        indices = []
        if self.shuffle:
            for bucket in self.buckets:
                indices.append(torch.randperm(len(bucket), generator=g).tolist())
        else:
            for bucket in self.buckets:
                indices.append(list(range(len(bucket))))

        batches = []
        for i in range(len(self.buckets)):
            bucket = self.buckets[i]
            len_bucket = len(bucket)
            ids_bucket = indices[i]
            num_samples_bucket = self.num_samples_per_bucket[i]

            # add extra samples (repeating from the start) to make it evenly divisible
            rem = num_samples_bucket - len_bucket
            ids_bucket = (
                ids_bucket
                + ids_bucket * (rem // len_bucket)
                + ids_bucket[: (rem % len_bucket)]
            )

            # subsample: this replica takes every num_replicas-th entry
            ids_bucket = ids_bucket[self.rank :: self.num_replicas]

            # batching
            for j in range(len(ids_bucket) // self.batch_size):
                batch = [
                    bucket[idx]
                    for idx in ids_bucket[
                        j * self.batch_size : (j + 1) * self.batch_size
                    ]
                ]
                batches.append(batch)

        if self.shuffle:
            batch_ids = torch.randperm(len(batches), generator=g).tolist()
            batches = [batches[i] for i in batch_ids]
        self.batches = batches

        assert len(self.batches) * self.batch_size == self.num_samples
        return iter(self.batches)

    def _bisect(self, x, lo=0, hi=None):
        """Return the bucket index ``i`` with ``boundaries[i] < x <= boundaries[i + 1]``, or -1."""
        if hi is None:
            hi = len(self.boundaries) - 1

        if hi > lo:
            mid = (hi + lo) // 2
            if self.boundaries[mid] < x and x <= self.boundaries[mid + 1]:
                return mid
            elif x <= self.boundaries[mid]:
                return self._bisect(x, lo, mid)
            else:
                return self._bisect(x, mid + 1, hi)
        else:
            return -1

    def __len__(self):
        """Number of batches yielded per epoch on this replica."""
        return self.num_samples // self.batch_size

58
infer/lib/train/losses.py Normal file
View File

@ -0,0 +1,58 @@
import torch
def feature_loss(fmap_r, fmap_g):
    """Twice the sum of mean absolute differences between paired feature maps.

    ``fmap_r`` and ``fmap_g`` are iterated in lockstep at two levels; the
    ``fmap_r`` entries are detached so no gradient flows into them.
    Returns the int ``0`` when there is nothing to compare.
    """
    total = 0
    for maps_r, maps_g in zip(fmap_r, fmap_g):
        for map_r, map_g in zip(maps_r, maps_g):
            gap = map_r.float().detach() - map_g.float()
            total = total + gap.abs().mean()
    return total * 2
def discriminator_loss(disc_real_outputs, disc_generated_outputs):
    """Least-squares discriminator loss over paired output tensors.

    For each pair the real term is ``mean((1 - dr) ** 2)`` and the generated
    term is ``mean(dg ** 2)``.

    Returns ``(loss, r_losses, g_losses)`` where ``loss`` is the sum of all
    terms and the two lists hold the per-pair terms as Python floats.
    """
    total = 0
    real_terms = []
    fake_terms = []
    for real_out, fake_out in zip(disc_real_outputs, disc_generated_outputs):
        real_term = torch.mean((1 - real_out.float()) ** 2)
        fake_term = torch.mean(fake_out.float() ** 2)
        total = total + (real_term + fake_term)
        real_terms.append(real_term.item())
        fake_terms.append(fake_term.item())
    return total, real_terms, fake_terms
def generator_loss(disc_outputs):
    """Least-squares generator loss: sum of ``mean((1 - dg) ** 2)`` over outputs.

    Returns ``(loss, gen_losses)``; ``gen_losses`` keeps the per-output terms
    as tensors.
    """
    total = 0
    per_output = []
    for out in disc_outputs:
        term = torch.mean((1 - out.float()) ** 2)
        per_output.append(term)
        total = total + term
    return total, per_output
def kl_loss(z_p, logs_q, m_p, logs_p, z_mask):
    """
    Masked mean of the element-wise KL term.

    z_p, logs_q: [b, h, t_t]
    m_p, logs_p: [b, h, t_t]

    All inputs are cast to float; the summed masked term is divided by
    ``sum(z_mask)``.
    """
    z_p, logs_q = z_p.float(), logs_q.float()
    m_p, logs_p = m_p.float(), logs_p.float()
    z_mask = z_mask.float()

    log_ratio = logs_p - logs_q - 0.5
    quadratic = 0.5 * ((z_p - m_p) ** 2) * torch.exp(-2.0 * logs_p)
    masked_sum = torch.sum((log_ratio + quadratic) * z_mask)
    return masked_sum / torch.sum(z_mask)

View File

@ -0,0 +1,130 @@
import torch
import torch.utils.data
from librosa.filters import mel as librosa_mel_fn
# 2 ** 15; presumably the full-scale magnitude of int16 PCM audio - TODO confirm.
# NOTE(review): no function visible in this part of the file reads this constant.
MAX_WAV_VALUE = 32768.0
def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
    """
    Log-compress ``x`` after flooring it at ``clip_val``.

    PARAMS
    ------
    x: tensor to compress
    C: compression factor
    clip_val: lower bound applied to ``x`` before scaling and taking the log
    """
    floored = torch.clamp(x, min=clip_val)
    return torch.log(floored * C)
def dynamic_range_decompression_torch(x, C=1):
    """
    Invert the log compression: ``exp(x) / C``.

    PARAMS
    ------
    x: log-compressed tensor
    C: compression factor used to compress
    """
    expanded = torch.exp(x)
    return expanded / C
def spectral_normalize_torch(magnitudes):
    """Apply ``dynamic_range_compression_torch`` with its default settings."""
    compressed = dynamic_range_compression_torch(magnitudes)
    return compressed
def spectral_de_normalize_torch(magnitudes):
    """Apply ``dynamic_range_decompression_torch`` with its default settings."""
    restored = dynamic_range_decompression_torch(magnitudes)
    return restored
# Reusable banks: module-level caches filled lazily by the functions below.
mel_basis = {}
hann_window = {}


def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False):
    """Convert waveform into Linear-frequency Linear-amplitude spectrogram.

    Args:
        y             :: (B, T) - Audio waveforms
        n_fft         - FFT size
        sampling_rate - accepted for signature symmetry; not used by the computation
        hop_size      - hop between frames, in samples
        win_size      - Hann window length, in samples
        center        - forwarded to ``torch.stft``
    Returns:
        :: (B, Freq, Frame) - Linear-frequency Linear-amplitude spectrogram
    """
    # Validation: only reports samples outside [-1.07, 1.07], never raises.
    if torch.min(y) < -1.07:
        print("min value is ", torch.min(y))
    if torch.max(y) > 1.07:
        print("max value is ", torch.max(y))

    # Window - Cache if needed (one window per size/dtype/device)
    global hann_window
    dtype_device = str(y.dtype) + "_" + str(y.device)
    wnsize_dtype_device = str(win_size) + "_" + dtype_device
    if wnsize_dtype_device not in hann_window:
        hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(
            dtype=y.dtype, device=y.device
        )

    # Padding: reflect-pad both ends by (n_fft - hop_size) / 2 samples.
    pad = int((n_fft - hop_size) / 2)
    y = torch.nn.functional.pad(y.unsqueeze(1), (pad, pad), mode="reflect")
    y = y.squeeze(1)

    # Complex Spectrogram :: (B, T) -> (B, Freq, Frame) complex.
    # return_complex=False is deprecated in torch.stft; request the complex
    # tensor and view it as (..., 2) reals, which is the same data layout.
    spec = torch.stft(
        y,
        n_fft,
        hop_length=hop_size,
        win_length=win_size,
        window=hann_window[wnsize_dtype_device],
        center=center,
        pad_mode="reflect",
        normalized=False,
        onesided=True,
        return_complex=True,
    )
    spec = torch.view_as_real(spec)

    # Linear-frequency Linear-amplitude spectrogram :: (B, Freq, Frame, RealComplex=2) -> (B, Freq, Frame)
    # The 1e-6 keeps the square root away from zero.
    spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
    return spec
def spec_to_mel_torch(spec, n_fft, num_mels, sampling_rate, fmin, fmax):
    """Project a linear spectrogram onto a mel filter bank and log-compress it.

    Args:
        spec :: (B, Freq, Frame) - Linear-frequency Linear-amplitude spectrogram
        n_fft, num_mels, sampling_rate, fmin, fmax - mel filter bank parameters
    Returns:
        melspec :: (B, Freq=num_mels, Frame) - Mel-frequency Log-amplitude spectrogram
    """
    # MelBasis - Cache if needed.
    # The key covers every parameter the filter bank depends on, so a call
    # with a different n_fft / num_mels / sampling_rate / fmin can no longer
    # pick up a bank that was built for other settings.
    global mel_basis
    basis_key = "_".join(
        str(part)
        for part in (
            n_fft,
            num_mels,
            sampling_rate,
            fmin,
            fmax,
            spec.dtype,
            spec.device,
        )
    )
    if basis_key not in mel_basis:
        mel = librosa_mel_fn(
            sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax
        )
        mel_basis[basis_key] = torch.from_numpy(mel).to(
            dtype=spec.dtype, device=spec.device
        )

    # Mel-frequency Log-amplitude spectrogram :: (B, Freq=num_mels, Frame)
    melspec = torch.matmul(mel_basis[basis_key], spec)
    melspec = spectral_normalize_torch(melspec)
    return melspec
def mel_spectrogram_torch(
    y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False
):
    """Convert waveform into Mel-frequency Log-amplitude spectrogram.

    Chains ``spectrogram_torch`` and ``spec_to_mel_torch``.

    Args:
        y       :: (B, T)           - Waveforms
    Returns:
        melspec :: (B, Freq, Frame) - Mel-frequency Log-amplitude spectrogram
    """
    # (B, T) -> (B, Freq, Frame) linear spectrogram, then
    # (B, Freq, Frame) -> (B, Freq=num_mels, Frame) log-mel spectrogram.
    linear = spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center)
    return spec_to_mel_torch(linear, n_fft, num_mels, sampling_rate, fmin, fmax)

View File

@ -0,0 +1,259 @@
import torch, traceback, os, sys
# Put the current working directory on the import path before the
# project-local import below; presumably the script is expected to be
# launched from the repository root - TODO confirm.
now_dir = os.getcwd()
sys.path.append(now_dir)
from collections import OrderedDict
from i18n.i18n import I18nAuto
# Module-level translator instance; merge() calls it to compare its f0 argument.
i18n = I18nAuto()
def savee(ckpt, sr, if_f0, name, epoch, version, hps, i18n):
    """Write a half-precision export of ``ckpt`` to ``weights/<name>.pth``.

    PARAMS
    ------
    ckpt: mapping of parameter name -> tensor; keys containing "enc_q" are skipped
    sr, if_f0, version: stored verbatim under "sr", "f0" and "version"
    name: base name of the output file
    epoch: stored as "<epoch>epoch" under "info"
    hps: hyper-parameters; ``hps.data`` and ``hps.model`` supply the "config" list
    i18n: accepted for interface compatibility; not used

    RETURNS
    -------
    "Success." on success, otherwise the formatted traceback as a string.
    """
    try:
        opt = OrderedDict()
        opt["weight"] = {
            key: tensor.half() for key, tensor in ckpt.items() if "enc_q" not in key
        }
        opt["config"] = [
            hps.data.filter_length // 2 + 1,
            32,
            hps.model.inter_channels,
            hps.model.hidden_channels,
            hps.model.filter_channels,
            hps.model.n_heads,
            hps.model.n_layers,
            hps.model.kernel_size,
            hps.model.p_dropout,
            hps.model.resblock,
            hps.model.resblock_kernel_sizes,
            hps.model.resblock_dilation_sizes,
            hps.model.upsample_rates,
            hps.model.upsample_initial_channel,
            hps.model.upsample_kernel_sizes,
            hps.model.spk_embed_dim,
            hps.model.gin_channels,
            hps.data.sampling_rate,
        ]
        opt["info"] = "%sepoch" % epoch
        opt["sr"] = sr
        opt["f0"] = if_f0
        opt["version"] = version
        torch.save(opt, "weights/%s.pth" % name)
        return "Success."
    except Exception:
        # Errors are reported to the caller as text; KeyboardInterrupt and
        # SystemExit are allowed to propagate.
        return traceback.format_exc()
def show_info(path):
    """Return a text summary of the "info", "sr", "f0" and "version" entries of a model file.

    Missing entries are shown as "None". On any error the formatted
    traceback is returned as a string instead of raising.
    """
    try:
        # NOTE(review): torch.load unpickles the file; only open model files
        # that come from a trusted source.
        a = torch.load(path, map_location="cpu")
        return "模型信息:%s\n采样率:%s\n模型是否输入音高引导:%s\n版本:%s" % (
            a.get("info", "None"),
            a.get("sr", "None"),
            a.get("f0", "None"),
            a.get("version", "None"),
        )
    except Exception:
        # Reported as text; KeyboardInterrupt / SystemExit propagate.
        return traceback.format_exc()
def _small_model_config(sr, version):
    """Return the 18-entry "config" list for ``sr``/``version``, or None for an unknown ``sr``."""
    is_v1 = version == "v1"
    if sr == "40k":
        # 40k uses the same layout for every version.
        spec_channels, rates, kernels, hz = 1025, [10, 10, 2, 2], [16, 16, 4, 4], 40000
    elif sr == "48k":
        if is_v1:
            spec_channels, rates, kernels, hz = (
                1025,
                [10, 6, 2, 2, 2],
                [16, 16, 4, 4, 4],
                48000,
            )
        else:
            spec_channels, rates, kernels, hz = (
                1025,
                [12, 10, 2, 2],
                [24, 20, 4, 4],
                48000,
            )
    elif sr == "32k":
        if is_v1:
            spec_channels, rates, kernels, hz = (
                513,
                [10, 4, 2, 2, 2],
                [16, 16, 4, 4, 4],
                32000,
            )
        else:
            spec_channels, rates, kernels, hz = (
                513,
                [10, 8, 2, 2],
                [20, 16, 4, 4],
                32000,
            )
    else:
        return None
    # Only the first entry, the two upsample lists and the last entry vary
    # between presets; everything else is shared.
    return [
        spec_channels,
        32,
        192,
        192,
        768,
        2,
        6,
        3,
        0,
        "1",
        [3, 7, 11],
        [[1, 3, 5] for _ in range(3)],
        rates,
        512,
        kernels,
        109,
        256,
        hz,
    ]


def extract_small_model(path, name, sr, if_f0, info, version):
    """Export a half-precision copy of a checkpoint to ``weights/<name>.pth``.

    PARAMS
    ------
    path: checkpoint file; when it holds a "model" entry that entry is used
    name: base name of the output file
    sr: "32k", "40k" or "48k"; selects the stored "config" preset. Any other
        value leaves "config" out of the output (unchanged legacy behaviour).
    if_f0: stored as ``int(if_f0)`` under "f0"
    info: free text stored under "info"; "" becomes "Extracted model."
    version: "v1" selects the v1 presets for 32k/48k, anything else the newer ones

    RETURNS
    -------
    "Success." on success, otherwise the formatted traceback as a string.
    """
    try:
        # NOTE(review): torch.load unpickles the file; only open trusted checkpoints.
        ckpt = torch.load(path, map_location="cpu")
        if "model" in ckpt:
            ckpt = ckpt["model"]
        opt = OrderedDict()
        # Keys containing "enc_q" are not exported.
        opt["weight"] = {
            key: tensor.half() for key, tensor in ckpt.items() if "enc_q" not in key
        }
        preset = _small_model_config(sr, version)
        if preset is not None:
            opt["config"] = preset
        if info == "":
            info = "Extracted model."
        opt["info"] = info
        opt["version"] = version
        opt["sr"] = sr
        opt["f0"] = int(if_f0)
        torch.save(opt, "weights/%s.pth" % name)
        return "Success."
    except Exception:
        # Reported as text; KeyboardInterrupt / SystemExit propagate.
        return traceback.format_exc()
def change_info(path, info, name):
    """Rewrite the "info" entry of a model file and save it under ``weights/``.

    PARAMS
    ------
    path: model file to read
    info: new value for the "info" entry
    name: output file name inside ``weights/``; "" reuses the base name of ``path``

    RETURNS
    -------
    "Success." on success, otherwise the formatted traceback as a string.
    """
    try:
        # NOTE(review): torch.load unpickles the file; only open trusted files.
        ckpt = torch.load(path, map_location="cpu")
        ckpt["info"] = info
        if name == "":
            name = os.path.basename(path)
        torch.save(ckpt, "weights/%s" % name)
        return "Success."
    except Exception:
        # Reported as text; KeyboardInterrupt / SystemExit propagate.
        return traceback.format_exc()
def merge(path1, path2, alpha1, sr, f0, info, name, version):
    """Blend the weights of two models and save the result to ``weights/<name>.pth``.

    Every tensor is computed as ``alpha1 * w1 + (1 - alpha1) * w2`` in float
    and stored in half precision. "emb_g.weight" tensors of different shape
    are first cut to the smaller size along axis 0.

    PARAMS
    ------
    path1, path2: model files; each holds either a flat "weight" mapping or a
        "model" mapping (from which keys containing "enc_q" are dropped)
    alpha1: weight given to the first model
    sr, info, version: stored verbatim in the output
    f0: compared against the module-level ``i18n("")`` to derive the stored 0/1 flag
    name: base name of the output file

    RETURNS
    -------
    "Success." on success, a failure message when the two models do not have
    the same parameter names, otherwise the formatted traceback as a string.
    """
    try:

        def extract(ckpt):
            # Flat weight mapping of a "model"-style checkpoint. It must have
            # the same shape as the `ckpt["weight"]` branch below, so the
            # mapping itself is returned rather than a {"weight": ...} wrapper.
            a = ckpt["model"]
            return {key: a[key] for key in a.keys() if "enc_q" not in key}

        # NOTE(review): torch.load unpickles the files; only open trusted models.
        ckpt1 = torch.load(path1, map_location="cpu")
        ckpt2 = torch.load(path2, map_location="cpu")
        # The output reuses the first model's config, so path1 must carry one.
        cfg = ckpt1["config"]
        if "model" in ckpt1:
            ckpt1 = extract(ckpt1)
        else:
            ckpt1 = ckpt1["weight"]
        if "model" in ckpt2:
            ckpt2 = extract(ckpt2)
        else:
            ckpt2 = ckpt2["weight"]
        if sorted(ckpt1.keys()) != sorted(ckpt2.keys()):
            return "Fail to merge the models. The model architectures are not the same."
        opt = OrderedDict()
        opt["weight"] = {}
        for key in ckpt1.keys():
            if key == "emb_g.weight" and ckpt1[key].shape != ckpt2[key].shape:
                min_shape0 = min(ckpt1[key].shape[0], ckpt2[key].shape[0])
                opt["weight"][key] = (
                    alpha1 * (ckpt1[key][:min_shape0].float())
                    + (1 - alpha1) * (ckpt2[key][:min_shape0].float())
                ).half()
            else:
                opt["weight"][key] = (
                    alpha1 * (ckpt1[key].float()) + (1 - alpha1) * (ckpt2[key].float())
                ).half()
        opt["config"] = cfg
        opt["sr"] = sr
        # NOTE(review): the literal handed to i18n is the empty string here;
        # confirm it matches the value the UI actually passes for "yes".
        opt["f0"] = 1 if f0 == i18n("") else 0
        opt["version"] = version
        opt["info"] = info
        torch.save(opt, "weights/%s.pth" % name)
        return "Success."
    except Exception:
        # Reported as text; KeyboardInterrupt / SystemExit propagate.
        return traceback.format_exc()

487
infer/lib/train/utils.py Normal file
View File

@ -0,0 +1,487 @@
import os, traceback
import glob
import sys
import argparse
import logging
import json
import subprocess
import numpy as np
from scipy.io.wavfile import read
import torch
# Set to True by the plotting helpers once they have selected the "Agg" backend.
MATPLOTLIB_FLAG = False
# Root logging configuration: DEBUG and above goes to stdout.
logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
# Until get_logger() rebinds this name, "logger" is the logging module itself.
logger = logging
def load_checkpoint_d(checkpoint_path, combd, sbd, optimizer=None, load_opt=1):
    """Restore two sub-models ("combd" and "sbd") and optionally the optimizer.

    Parameters that are missing from the checkpoint, or whose saved shape
    differs from the model's, keep the model's current value.

    PARAMS
    ------
    checkpoint_path: file written by ``save_checkpoint_d``
    combd, sbd: models (optionally wrapped, exposing ``.module``)
    optimizer: optimizer to restore, or None
    load_opt: the optimizer state is loaded only when this equals 1

    RETURNS
    -------
    ``(sbd, optimizer, learning_rate, iteration)``
    """
    assert os.path.isfile(checkpoint_path)
    checkpoint_dict = torch.load(checkpoint_path, map_location="cpu")

    def go(model, bkey):
        saved_state_dict = checkpoint_dict[bkey]
        target = model.module if hasattr(model, "module") else model
        state_dict = target.state_dict()
        new_state_dict = {}
        for k, v in state_dict.items():  # v carries the shape the model needs
            saved = saved_state_dict.get(k)
            saved_shape = getattr(saved, "shape", None)
            if saved is not None and saved_shape != v.shape:
                print(
                    "shape-%s-mismatch|need-%s|get-%s" % (k, v.shape, saved_shape)
                )
                saved = None
            if saved is None:
                # Absent from the (pretrained) checkpoint or unusable:
                # keep the model's own freshly initialised value.
                logger.info("%s is not in the checkpoint" % k)
                new_state_dict[k] = v
            else:
                new_state_dict[k] = saved
        target.load_state_dict(new_state_dict, strict=False)
        return model

    go(combd, "combd")
    model = go(sbd, "sbd")
    logger.info("Loaded model weights")

    iteration = checkpoint_dict["iteration"]
    learning_rate = checkpoint_dict["learning_rate"]
    if optimizer is not None and load_opt == 1:
        # Deliberately not wrapped in try/except: if the saved state cannot be
        # loaded (e.g. it is empty), re-initialising here could also disturb
        # the LR schedule, so the error is caught at the outermost level of
        # the training script instead.
        optimizer.load_state_dict(checkpoint_dict["optimizer"])
    logger.info("Loaded checkpoint '{}' (epoch {})".format(checkpoint_path, iteration))
    return model, optimizer, learning_rate, iteration
# def load_checkpoint(checkpoint_path, model, optimizer=None):
# assert os.path.isfile(checkpoint_path)
# checkpoint_dict = torch.load(checkpoint_path, map_location='cpu')
# iteration = checkpoint_dict['iteration']
# learning_rate = checkpoint_dict['learning_rate']
# if optimizer is not None:
# optimizer.load_state_dict(checkpoint_dict['optimizer'])
# # print(1111)
# saved_state_dict = checkpoint_dict['model']
# # print(1111)
#
# if hasattr(model, 'module'):
# state_dict = model.module.state_dict()
# else:
# state_dict = model.state_dict()
# new_state_dict= {}
# for k, v in state_dict.items():
# try:
# new_state_dict[k] = saved_state_dict[k]
# except:
# logger.info("%s is not in the checkpoint" % k)
# new_state_dict[k] = v
# if hasattr(model, 'module'):
# model.module.load_state_dict(new_state_dict)
# else:
# model.load_state_dict(new_state_dict)
# logger.info("Loaded checkpoint '{}' (epoch {})" .format(
# checkpoint_path, iteration))
# return model, optimizer, learning_rate, iteration
def load_checkpoint(checkpoint_path, model, optimizer=None, load_opt=1):
    """Restore ``model`` (and optionally ``optimizer``) from a checkpoint file.

    Parameters that are missing from the checkpoint, or whose saved shape
    differs from the model's, keep the model's current value.

    PARAMS
    ------
    checkpoint_path: file written by ``save_checkpoint``
    model: model (optionally wrapped, exposing ``.module``)
    optimizer: optimizer to restore, or None
    load_opt: the optimizer state is loaded only when this equals 1

    RETURNS
    -------
    ``(model, optimizer, learning_rate, iteration)``
    """
    assert os.path.isfile(checkpoint_path)
    checkpoint_dict = torch.load(checkpoint_path, map_location="cpu")

    saved_state_dict = checkpoint_dict["model"]
    target = model.module if hasattr(model, "module") else model
    state_dict = target.state_dict()
    new_state_dict = {}
    for k, v in state_dict.items():  # v carries the shape the model needs
        saved = saved_state_dict.get(k)
        saved_shape = getattr(saved, "shape", None)
        if saved is not None and saved_shape != v.shape:
            print("shape-%s-mismatch|need-%s|get-%s" % (k, v.shape, saved_shape))
            saved = None
        if saved is None:
            # Absent from the (pretrained) checkpoint or unusable:
            # keep the model's own freshly initialised value.
            logger.info("%s is not in the checkpoint" % k)
            new_state_dict[k] = v
        else:
            new_state_dict[k] = saved
    target.load_state_dict(new_state_dict, strict=False)
    logger.info("Loaded model weights")

    iteration = checkpoint_dict["iteration"]
    learning_rate = checkpoint_dict["learning_rate"]
    if optimizer is not None and load_opt == 1:
        # Deliberately not wrapped in try/except: if the saved state cannot be
        # loaded (e.g. it is empty), re-initialising here could also disturb
        # the LR schedule, so the error is caught at the outermost level of
        # the training script instead.
        optimizer.load_state_dict(checkpoint_dict["optimizer"])
    logger.info("Loaded checkpoint '{}' (epoch {})".format(checkpoint_path, iteration))
    return model, optimizer, learning_rate, iteration
def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path):
    """Write model weights, optimizer state, learning rate and iteration to ``checkpoint_path``.

    A wrapped model (one exposing ``.module``) is unwrapped before its state
    dict is taken.
    """
    logger.info(
        "Saving model and optimizer state at epoch {} to {}".format(
            iteration, checkpoint_path
        )
    )
    unwrapped = model.module if hasattr(model, "module") else model
    payload = {
        "model": unwrapped.state_dict(),
        "iteration": iteration,
        "optimizer": optimizer.state_dict(),
        "learning_rate": learning_rate,
    }
    torch.save(payload, checkpoint_path)
def save_checkpoint_d(combd, sbd, optimizer, learning_rate, iteration, checkpoint_path):
    """Write both sub-models, optimizer state, learning rate and iteration to ``checkpoint_path``.

    Wrapped models (exposing ``.module``) are unwrapped before their state
    dicts are taken; they are stored under "combd" and "sbd".
    """
    logger.info(
        "Saving model and optimizer state at epoch {} to {}".format(
            iteration, checkpoint_path
        )
    )
    combd_core = combd.module if hasattr(combd, "module") else combd
    state_dict_combd = combd_core.state_dict()
    sbd_core = sbd.module if hasattr(sbd, "module") else sbd
    state_dict_sbd = sbd_core.state_dict()
    payload = {
        "combd": state_dict_combd,
        "sbd": state_dict_sbd,
        "iteration": iteration,
        "optimizer": optimizer.state_dict(),
        "learning_rate": learning_rate,
    }
    torch.save(payload, checkpoint_path)
def summarize(
    writer,
    global_step,
    scalars=None,
    histograms=None,
    images=None,
    audios=None,
    audio_sampling_rate=22050,
):
    """Forward tagged values to a summary writer at ``global_step``.

    PARAMS
    ------
    writer: object exposing ``add_scalar``, ``add_histogram``, ``add_image``
        and ``add_audio``
    global_step: step passed along with every value
    scalars, histograms, images, audios: mappings of tag -> value; ``None``
        (the default) means "nothing to log" for that kind
    audio_sampling_rate: sampling rate passed to ``add_audio``

    Images are written with ``dataformats="HWC"``.
    """
    # None defaults replace the former shared ``{}`` default arguments.
    for k, v in (scalars or {}).items():
        writer.add_scalar(k, v, global_step)
    for k, v in (histograms or {}).items():
        writer.add_histogram(k, v, global_step)
    for k, v in (images or {}).items():
        writer.add_image(k, v, global_step, dataformats="HWC")
    for k, v in (audios or {}).items():
        writer.add_audio(k, v, global_step, audio_sampling_rate)
def latest_checkpoint_path(dir_path, regex="G_*.pth"):
    """Return (and print) the matching file whose path digits form the largest integer.

    ``regex`` is a glob pattern evaluated inside ``dir_path``. All digit
    characters of each matched path are concatenated and compared as one
    integer.
    """

    def digits_as_int(path):
        return int("".join(ch for ch in path if ch.isdigit()))

    candidates = glob.glob(os.path.join(dir_path, regex))
    ordered = sorted(candidates, key=digits_as_int)
    newest = ordered[-1]
    print(newest)
    return newest
def plot_spectrogram_to_numpy(spectrogram):
    """Render ``spectrogram`` as an image and return it as a ``(H, W, 3)`` uint8 array.

    On first use the matplotlib backend is switched to "Agg" and the
    "matplotlib" logger is limited to WARNING.
    """
    global MATPLOTLIB_FLAG
    if not MATPLOTLIB_FLAG:
        import matplotlib

        matplotlib.use("Agg")
        MATPLOTLIB_FLAG = True
        mpl_logger = logging.getLogger("matplotlib")
        mpl_logger.setLevel(logging.WARNING)
    import matplotlib.pylab as plt
    import numpy as np

    fig, ax = plt.subplots(figsize=(10, 2))
    im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none")
    plt.colorbar(im, ax=ax)
    plt.xlabel("Frames")
    plt.ylabel("Channels")
    plt.tight_layout()

    fig.canvas.draw()
    # np.fromstring in binary mode (sep="") is deprecated; np.frombuffer reads
    # the same bytes. The trailing copy keeps the result writable, as before.
    data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
    data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)).copy()
    # Close exactly the figure created above.
    plt.close(fig)
    return data
def plot_alignment_to_numpy(alignment, info=None):
    """Render ``alignment`` (transposed) as an image and return a ``(H, W, 3)`` uint8 array.

    ``info``, when given, is appended to the x-axis label after two newlines.
    On first use the matplotlib backend is switched to "Agg" and the
    "matplotlib" logger is limited to WARNING.
    """
    global MATPLOTLIB_FLAG
    if not MATPLOTLIB_FLAG:
        import matplotlib

        matplotlib.use("Agg")
        MATPLOTLIB_FLAG = True
        mpl_logger = logging.getLogger("matplotlib")
        mpl_logger.setLevel(logging.WARNING)
    import matplotlib.pylab as plt
    import numpy as np

    fig, ax = plt.subplots(figsize=(6, 4))
    im = ax.imshow(
        alignment.transpose(), aspect="auto", origin="lower", interpolation="none"
    )
    fig.colorbar(im, ax=ax)
    xlabel = "Decoder timestep"
    if info is not None:
        xlabel += "\n\n" + info
    plt.xlabel(xlabel)
    plt.ylabel("Encoder timestep")
    plt.tight_layout()

    fig.canvas.draw()
    # np.fromstring in binary mode (sep="") is deprecated; np.frombuffer reads
    # the same bytes. The trailing copy keeps the result writable, as before.
    data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
    data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)).copy()
    # Close exactly the figure created above.
    plt.close(fig)
    return data
def load_wav_to_torch(full_path):
    """Read a WAV file and return ``(samples, sampling_rate)``.

    The samples are converted to float32 without rescaling and returned as a
    ``torch.FloatTensor``.
    """
    rate, samples = read(full_path)
    as_float = samples.astype(np.float32)
    return torch.FloatTensor(as_float), rate
def load_filepaths_and_text(filename, split="|"):
    """Read a UTF-8 text file and return one list of fields per line.

    Each line is stripped of surrounding whitespace and then split on ``split``.
    """
    rows = []
    with open(filename, encoding="utf-8") as handle:
        for raw_line in handle:
            rows.append(raw_line.strip().split(split))
    return rows
def get_hparams(init=True):
    """
    Parse the training command line and build the hyper-parameter object.

    The JSON config is chosen from the sample rate and version
    (``configs/<sr>.json`` for v1 or 40k, ``configs/<sr>_v2.json`` otherwise).
    With ``init`` true it is copied into ``./logs/<experiment_dir>/config.json``;
    otherwise that saved copy is read back. Command-line values are then
    attached to the returned ``HParams``.

    Historical todo list (all items done):
      trailing command-line arguments:
        save frequency, total epochs
        batch size
        pretrainG, pretrainD
        GPU ids: os.environ["CUDA_VISIBLE_DEVICES"]
        if_latest
        model: if_f0
        sample rate: config chosen automatically
        whether to cache the dataset on the GPU: if_cache_data_in_gpu
      -m:
        training_files path is decided automatically, replacing
        hps.data.training_files in train_nsf_load_pretrain.py
      -c: no longer used for the config path
    """
    parser = argparse.ArgumentParser()
    # parser.add_argument('-c', '--config', type=str, default="configs/40k.json",help='JSON file for configuration')
    parser.add_argument(
        "-se",
        "--save_every_epoch",
        type=int,
        required=True,
        help="checkpoint save frequency (epoch)",
    )
    parser.add_argument(
        "-te", "--total_epoch", type=int, required=True, help="total_epoch"
    )
    # The help texts of the next two options were swapped in the original.
    parser.add_argument(
        "-pg", "--pretrainG", type=str, default="", help="Pretrained Generator path"
    )
    parser.add_argument(
        "-pd", "--pretrainD", type=str, default="", help="Pretrained Discriminator path"
    )
    parser.add_argument("-g", "--gpus", type=str, default="0", help="split by -")
    parser.add_argument(
        "-bs", "--batch_size", type=int, required=True, help="batch size"
    )
    parser.add_argument(
        "-e", "--experiment_dir", type=str, required=True, help="experiment dir"
    )  # -m
    parser.add_argument(
        "-sr", "--sample_rate", type=str, required=True, help="sample rate, 32k/40k/48k"
    )
    parser.add_argument(
        "-sw",
        "--save_every_weights",
        type=str,
        default="0",
        help="save the extracted model in weights directory when saving checkpoints",
    )
    parser.add_argument(
        "-v", "--version", type=str, required=True, help="model version"
    )
    parser.add_argument(
        "-f0",
        "--if_f0",
        type=int,
        required=True,
        help="use f0 as one of the inputs of the model, 1 or 0",
    )
    parser.add_argument(
        "-l",
        "--if_latest",
        type=int,
        required=True,
        help="if only save the latest G/D pth file, 1 or 0",
    )
    parser.add_argument(
        "-c",
        "--if_cache_data_in_gpu",
        type=int,
        required=True,
        help="if caching the dataset in GPU memory, 1 or 0",
    )

    args = parser.parse_args()
    name = args.experiment_dir
    experiment_dir = os.path.join("./logs", args.experiment_dir)

    # exist_ok avoids the check-then-create race when several processes
    # reach this point at the same time.
    os.makedirs(experiment_dir, exist_ok=True)

    if args.version == "v1" or args.sample_rate == "40k":
        config_path = "configs/%s.json" % args.sample_rate
    else:
        config_path = "configs/%s_v2.json" % args.sample_rate
    config_save_path = os.path.join(experiment_dir, "config.json")
    if init:
        with open(config_path, "r") as f:
            data = f.read()
        with open(config_save_path, "w") as f:
            f.write(data)
    else:
        with open(config_save_path, "r") as f:
            data = f.read()
    config = json.loads(data)

    hparams = HParams(**config)
    hparams.model_dir = hparams.experiment_dir = experiment_dir
    hparams.save_every_epoch = args.save_every_epoch
    hparams.name = name
    hparams.total_epoch = args.total_epoch
    hparams.pretrainG = args.pretrainG
    hparams.pretrainD = args.pretrainD
    hparams.version = args.version
    hparams.gpus = args.gpus
    hparams.train.batch_size = args.batch_size
    hparams.sample_rate = args.sample_rate
    hparams.if_f0 = args.if_f0
    hparams.if_latest = args.if_latest
    hparams.save_every_weights = args.save_every_weights
    hparams.if_cache_data_in_gpu = args.if_cache_data_in_gpu
    hparams.data.training_files = "%s/filelist.txt" % experiment_dir
    return hparams
def get_hparams_from_dir(model_dir):
    """Build ``HParams`` from ``<model_dir>/config.json`` and record ``model_dir`` on it."""
    with open(os.path.join(model_dir, "config.json"), "r") as handle:
        raw = handle.read()
    hparams = HParams(**json.loads(raw))
    hparams.model_dir = model_dir
    return hparams
def get_hparams_from_file(config_path):
    """Build ``HParams`` from the JSON file at ``config_path``."""
    with open(config_path, "r") as handle:
        raw = handle.read()
    return HParams(**json.loads(raw))
def check_git_hash(model_dir):
    """Compare the current git commit with the one recorded in ``<model_dir>/githash``.

    When the directory containing this source file has no ``.git`` entry a
    warning is logged and nothing else happens. Otherwise the current hash is
    written on first use, and a warning is logged on later runs if it differs
    from the recorded one.
    """
    source_dir = os.path.dirname(os.path.realpath(__file__))
    if not os.path.exists(os.path.join(source_dir, ".git")):
        # .warning() is the documented spelling; .warn() is a deprecated alias.
        logger.warning(
            "{} is not a git repository, therefore hash value comparison will be ignored.".format(
                source_dir
            )
        )
        return

    cur_hash = subprocess.getoutput("git rev-parse HEAD")

    path = os.path.join(model_dir, "githash")
    if os.path.exists(path):
        # Context managers close the handle instead of leaving it to the GC.
        with open(path) as f:
            saved_hash = f.read()
        if saved_hash != cur_hash:
            logger.warning(
                "git hash values are different. {}(saved) != {}(current)".format(
                    saved_hash[:8], cur_hash[:8]
                )
            )
    else:
        with open(path, "w") as f:
            f.write(cur_hash)
def get_logger(model_dir, filename="train.log"):
    """Create a DEBUG-level file logger for ``model_dir`` and install it as the module logger.

    The logger is named after the last path component of ``model_dir`` and
    writes tab-separated records to ``<model_dir>/<filename>``; the directory
    is created when missing. Every call attaches one more file handler.
    """
    global logger
    logger = logging.getLogger(os.path.basename(model_dir))
    logger.setLevel(logging.DEBUG)

    if not os.path.exists(model_dir):
        os.makedirs(model_dir)

    file_handler = logging.FileHandler(os.path.join(model_dir, filename))
    file_handler.setLevel(logging.DEBUG)
    file_handler.setFormatter(
        logging.Formatter("%(asctime)s\t%(name)s\t%(levelname)s\t%(message)s")
    )
    logger.addHandler(file_handler)
    return logger
class HParams:
    """Attribute-style container for (nested) hyper-parameters.

    Keyword arguments become attributes; dict values are converted to nested
    ``HParams`` recursively. Items are also reachable with ``obj[key]``, and
    ``keys`` / ``items`` / ``values`` / ``len`` / ``in`` mirror the instance
    ``__dict__``.
    """

    def __init__(self, **kwargs):
        for k, v in kwargs.items():
            # isinstance also converts dict subclasses such as OrderedDict.
            if isinstance(v, dict):
                v = HParams(**v)
            self[k] = v

    def keys(self):
        return self.__dict__.keys()

    def items(self):
        return self.__dict__.items()

    def values(self):
        return self.__dict__.values()

    def __len__(self):
        return len(self.__dict__)

    def __getitem__(self, key):
        return getattr(self, key)

    def __setitem__(self, key, value):
        return setattr(self, key, value)

    def __contains__(self, key):
        return key in self.__dict__

    def __repr__(self):
        return self.__dict__.__repr__()