diff --git a/infer/lib/audio.py b/infer/lib/audio.py new file mode 100644 index 0000000..776939d --- /dev/null +++ b/infer/lib/audio.py @@ -0,0 +1,21 @@ +import ffmpeg +import numpy as np + + +def load_audio(file, sr): + try: + # https://github.com/openai/whisper/blob/main/whisper/audio.py#L26 + # This launches a subprocess to decode audio while down-mixing and resampling as necessary. + # Requires the ffmpeg CLI and `ffmpeg-python` package to be installed. + file = ( + file.strip(" ").strip('"').strip("\n").strip('"').strip(" ") + ) # 防止小白拷路径头尾带了空格和"和回车 + out, _ = ( + ffmpeg.input(file, threads=0) + .output("-", format="f32le", acodec="pcm_f32le", ac=1, ar=sr) + .run(cmd=["ffmpeg", "-nostdin"], capture_stdout=True, capture_stderr=True) + ) + except Exception as e: + raise RuntimeError(f"Failed to load audio: {e}") + + return np.frombuffer(out, np.float32).flatten() diff --git a/infer/lib/rmvpe.py b/infer/lib/rmvpe.py new file mode 100644 index 0000000..25dcb8c --- /dev/null +++ b/infer/lib/rmvpe.py @@ -0,0 +1,654 @@ +import torch, numpy as np,pdb +import torch.nn as nn +import torch.nn.functional as F +import torch,pdb +import numpy as np +import torch.nn.functional as F +from scipy.signal import get_window +from librosa.util import pad_center, tiny,normalize +###stft codes from https://github.com/pseeth/torch-stft/blob/master/torch_stft/util.py +def window_sumsquare(window, n_frames, hop_length=200, win_length=800, + n_fft=800, dtype=np.float32, norm=None): + """ + # from librosa 0.6 + Compute the sum-square envelope of a window function at a given hop length. + This is used to estimate modulation effects induced by windowing + observations in short-time fourier transforms. + Parameters + ---------- + window : string, tuple, number, callable, or list-like + Window specification, as in `get_window` + n_frames : int > 0 + The number of analysis frames + hop_length : int > 0 + The number of samples to advance between frames + win_length : [optional] + The length of the window function. By default, this matches `n_fft`. + n_fft : int > 0 + The length of each analysis frame. + dtype : np.dtype + The data type of the output + Returns + ------- + wss : np.ndarray, shape=`(n_fft + hop_length * (n_frames - 1))` + The sum-squared envelope of the window function + """ + if win_length is None: + win_length = n_fft + + n = n_fft + hop_length * (n_frames - 1) + x = np.zeros(n, dtype=dtype) + + # Compute the squared window at the desired length + win_sq = get_window(window, win_length, fftbins=True) + win_sq = normalize(win_sq, norm=norm)**2 + win_sq = pad_center(win_sq, n_fft) + + # Fill the envelope + for i in range(n_frames): + sample = i * hop_length + x[sample:min(n, sample + n_fft)] += win_sq[:max(0, min(n_fft, n - sample))] + return x + +class STFT(torch.nn.Module): + def __init__(self, filter_length=1024, hop_length=512, win_length=None, + window='hann'): + """ + This module implements an STFT using 1D convolution and 1D transpose convolutions. + This is a bit tricky so there are some cases that probably won't work as working + out the same sizes before and after in all overlap add setups is tough. Right now, + this code should work with hop lengths that are half the filter length (50% overlap + between frames). + + Keyword Arguments: + filter_length {int} -- Length of filters used (default: {1024}) + hop_length {int} -- Hop length of STFT (restrict to 50% overlap between frames) (default: {512}) + win_length {[type]} -- Length of the window function applied to each frame (if not specified, it + equals the filter length). (default: {None}) + window {str} -- Type of window to use (options are bartlett, hann, hamming, blackman, blackmanharris) + (default: {'hann'}) + """ + super(STFT, self).__init__() + self.filter_length = filter_length + self.hop_length = hop_length + self.win_length = win_length if win_length else filter_length + self.window = window + self.forward_transform = None + self.pad_amount = int(self.filter_length / 2) + scale = self.filter_length / self.hop_length + fourier_basis = np.fft.fft(np.eye(self.filter_length)) + + cutoff = int((self.filter_length / 2 + 1)) + fourier_basis = np.vstack([np.real(fourier_basis[:cutoff, :]),np.imag(fourier_basis[:cutoff, :])]) + forward_basis = torch.FloatTensor(fourier_basis[:, None, :]) + inverse_basis = torch.FloatTensor( + np.linalg.pinv(scale * fourier_basis).T[:, None, :]) + + assert (filter_length >= self.win_length) + # get window and zero center pad it to filter_length + fft_window = get_window(window, self.win_length, fftbins=True) + fft_window = pad_center(fft_window, size=filter_length) + fft_window = torch.from_numpy(fft_window).float() + + # window the bases + forward_basis *= fft_window + inverse_basis *= fft_window + + self.register_buffer('forward_basis', forward_basis.float()) + self.register_buffer('inverse_basis', inverse_basis.float()) + + def transform(self, input_data): + """Take input data (audio) to STFT domain. + + Arguments: + input_data {tensor} -- Tensor of floats, with shape (num_batch, num_samples) + + Returns: + magnitude {tensor} -- Magnitude of STFT with shape (num_batch, + num_frequencies, num_frames) + phase {tensor} -- Phase of STFT with shape (num_batch, + num_frequencies, num_frames) + """ + num_batches = input_data.shape[0] + num_samples = input_data.shape[-1] + + self.num_samples = num_samples + + # similar to librosa, reflect-pad the input + input_data = input_data.view(num_batches, 1, num_samples) + # print(1234,input_data.shape) + input_data = F.pad(input_data.unsqueeze(1),(self.pad_amount, self.pad_amount, 0, 0,0,0),mode='reflect').squeeze(1) + # print(2333,input_data.shape,self.forward_basis.shape,self.hop_length) + # pdb.set_trace() + forward_transform = F.conv1d( + input_data, + self.forward_basis, + stride=self.hop_length, + padding=0) + + cutoff = int((self.filter_length / 2) + 1) + real_part = forward_transform[:, :cutoff, :] + imag_part = forward_transform[:, cutoff:, :] + + magnitude = torch.sqrt(real_part ** 2 + imag_part ** 2) + # phase = torch.atan2(imag_part.data, real_part.data) + + return magnitude#, phase + + def inverse(self, magnitude, phase): + """Call the inverse STFT (iSTFT), given magnitude and phase tensors produced + by the ```transform``` function. + + Arguments: + magnitude {tensor} -- Magnitude of STFT with shape (num_batch, + num_frequencies, num_frames) + phase {tensor} -- Phase of STFT with shape (num_batch, + num_frequencies, num_frames) + + Returns: + inverse_transform {tensor} -- Reconstructed audio given magnitude and phase. Of + shape (num_batch, num_samples) + """ + recombine_magnitude_phase = torch.cat( + [magnitude * torch.cos(phase), magnitude * torch.sin(phase)], dim=1) + + inverse_transform = F.conv_transpose1d( + recombine_magnitude_phase, + self.inverse_basis, + stride=self.hop_length, + padding=0) + + if self.window is not None: + window_sum = window_sumsquare( + self.window, magnitude.size(-1), hop_length=self.hop_length, + win_length=self.win_length, n_fft=self.filter_length, + dtype=np.float32) + # remove modulation effects + approx_nonzero_indices = torch.from_numpy( + np.where(window_sum > tiny(window_sum))[0]) + window_sum = torch.from_numpy(window_sum).to(inverse_transform.device) + inverse_transform[:, :, approx_nonzero_indices] /= window_sum[approx_nonzero_indices] + + # scale by hop ratio + inverse_transform *= float(self.filter_length) / self.hop_length + + inverse_transform = inverse_transform[..., self.pad_amount:] + inverse_transform = inverse_transform[..., :self.num_samples] + inverse_transform = inverse_transform.squeeze(1) + + return inverse_transform + + def forward(self, input_data): + """Take input data (audio) to STFT domain and then back to audio. + + Arguments: + input_data {tensor} -- Tensor of floats, with shape (num_batch, num_samples) + + Returns: + reconstruction {tensor} -- Reconstructed audio given magnitude and phase. Of + shape (num_batch, num_samples) + """ + self.magnitude, self.phase = self.transform(input_data) + reconstruction = self.inverse(self.magnitude, self.phase) + return reconstruction +from time import time as ttime +class BiGRU(nn.Module): + def __init__(self, input_features, hidden_features, num_layers): + super(BiGRU, self).__init__() + self.gru = nn.GRU( + input_features, + hidden_features, + num_layers=num_layers, + batch_first=True, + bidirectional=True, + ) + + def forward(self, x): + return self.gru(x)[0] + + +class ConvBlockRes(nn.Module): + def __init__(self, in_channels, out_channels, momentum=0.01): + super(ConvBlockRes, self).__init__() + self.conv = nn.Sequential( + nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=(3, 3), + stride=(1, 1), + padding=(1, 1), + bias=False, + ), + nn.BatchNorm2d(out_channels, momentum=momentum), + nn.ReLU(), + nn.Conv2d( + in_channels=out_channels, + out_channels=out_channels, + kernel_size=(3, 3), + stride=(1, 1), + padding=(1, 1), + bias=False, + ), + nn.BatchNorm2d(out_channels, momentum=momentum), + nn.ReLU(), + ) + if in_channels != out_channels: + self.shortcut = nn.Conv2d(in_channels, out_channels, (1, 1)) + self.is_shortcut = True + else: + self.is_shortcut = False + + def forward(self, x): + if self.is_shortcut: + return self.conv(x) + self.shortcut(x) + else: + return self.conv(x) + x + + +class Encoder(nn.Module): + def __init__( + self, + in_channels, + in_size, + n_encoders, + kernel_size, + n_blocks, + out_channels=16, + momentum=0.01, + ): + super(Encoder, self).__init__() + self.n_encoders = n_encoders + self.bn = nn.BatchNorm2d(in_channels, momentum=momentum) + self.layers = nn.ModuleList() + self.latent_channels = [] + for i in range(self.n_encoders): + self.layers.append( + ResEncoderBlock( + in_channels, out_channels, kernel_size, n_blocks, momentum=momentum + ) + ) + self.latent_channels.append([out_channels, in_size]) + in_channels = out_channels + out_channels *= 2 + in_size //= 2 + self.out_size = in_size + self.out_channel = out_channels + + def forward(self, x): + concat_tensors = [] + x = self.bn(x) + for i in range(self.n_encoders): + _, x = self.layers[i](x) + concat_tensors.append(_) + return x, concat_tensors + + +class ResEncoderBlock(nn.Module): + def __init__( + self, in_channels, out_channels, kernel_size, n_blocks=1, momentum=0.01 + ): + super(ResEncoderBlock, self).__init__() + self.n_blocks = n_blocks + self.conv = nn.ModuleList() + self.conv.append(ConvBlockRes(in_channels, out_channels, momentum)) + for i in range(n_blocks - 1): + self.conv.append(ConvBlockRes(out_channels, out_channels, momentum)) + self.kernel_size = kernel_size + if self.kernel_size is not None: + self.pool = nn.AvgPool2d(kernel_size=kernel_size) + + def forward(self, x): + for i in range(self.n_blocks): + x = self.conv[i](x) + if self.kernel_size is not None: + return x, self.pool(x) + else: + return x + + +class Intermediate(nn.Module): # + def __init__(self, in_channels, out_channels, n_inters, n_blocks, momentum=0.01): + super(Intermediate, self).__init__() + self.n_inters = n_inters + self.layers = nn.ModuleList() + self.layers.append( + ResEncoderBlock(in_channels, out_channels, None, n_blocks, momentum) + ) + for i in range(self.n_inters - 1): + self.layers.append( + ResEncoderBlock(out_channels, out_channels, None, n_blocks, momentum) + ) + + def forward(self, x): + for i in range(self.n_inters): + x = self.layers[i](x) + return x + + +class ResDecoderBlock(nn.Module): + def __init__(self, in_channels, out_channels, stride, n_blocks=1, momentum=0.01): + super(ResDecoderBlock, self).__init__() + out_padding = (0, 1) if stride == (1, 2) else (1, 1) + self.n_blocks = n_blocks + self.conv1 = nn.Sequential( + nn.ConvTranspose2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=(3, 3), + stride=stride, + padding=(1, 1), + output_padding=out_padding, + bias=False, + ), + nn.BatchNorm2d(out_channels, momentum=momentum), + nn.ReLU(), + ) + self.conv2 = nn.ModuleList() + self.conv2.append(ConvBlockRes(out_channels * 2, out_channels, momentum)) + for i in range(n_blocks - 1): + self.conv2.append(ConvBlockRes(out_channels, out_channels, momentum)) + + def forward(self, x, concat_tensor): + x = self.conv1(x) + x = torch.cat((x, concat_tensor), dim=1) + for i in range(self.n_blocks): + x = self.conv2[i](x) + return x + + +class Decoder(nn.Module): + def __init__(self, in_channels, n_decoders, stride, n_blocks, momentum=0.01): + super(Decoder, self).__init__() + self.layers = nn.ModuleList() + self.n_decoders = n_decoders + for i in range(self.n_decoders): + out_channels = in_channels // 2 + self.layers.append( + ResDecoderBlock(in_channels, out_channels, stride, n_blocks, momentum) + ) + in_channels = out_channels + + def forward(self, x, concat_tensors): + for i in range(self.n_decoders): + x = self.layers[i](x, concat_tensors[-1 - i]) + return x + + +class DeepUnet(nn.Module): + def __init__( + self, + kernel_size, + n_blocks, + en_de_layers=5, + inter_layers=4, + in_channels=1, + en_out_channels=16, + ): + super(DeepUnet, self).__init__() + self.encoder = Encoder( + in_channels, 128, en_de_layers, kernel_size, n_blocks, en_out_channels + ) + self.intermediate = Intermediate( + self.encoder.out_channel // 2, + self.encoder.out_channel, + inter_layers, + n_blocks, + ) + self.decoder = Decoder( + self.encoder.out_channel, en_de_layers, kernel_size, n_blocks + ) + + def forward(self, x): + x, concat_tensors = self.encoder(x) + x = self.intermediate(x) + x = self.decoder(x, concat_tensors) + return x + + +class E2E(nn.Module): + def __init__( + self, + n_blocks, + n_gru, + kernel_size, + en_de_layers=5, + inter_layers=4, + in_channels=1, + en_out_channels=16, + ): + super(E2E, self).__init__() + self.unet = DeepUnet( + kernel_size, + n_blocks, + en_de_layers, + inter_layers, + in_channels, + en_out_channels, + ) + self.cnn = nn.Conv2d(en_out_channels, 3, (3, 3), padding=(1, 1)) + if n_gru: + self.fc = nn.Sequential( + BiGRU(3 * 128, 256, n_gru), + nn.Linear(512, 360), + nn.Dropout(0.25), + nn.Sigmoid(), + ) + else: + self.fc = nn.Sequential( + nn.Linear(3 * nn.N_MELS, nn.N_CLASS), nn.Dropout(0.25), nn.Sigmoid() + ) + + def forward(self, mel): + # print(mel.shape) + mel = mel.transpose(-1, -2).unsqueeze(1) + x = self.cnn(self.unet(mel)).transpose(1, 2).flatten(-2) + x = self.fc(x) + # print(x.shape) + return x + + +from librosa.filters import mel + + +class MelSpectrogram(torch.nn.Module): + def __init__( + self, + is_half, + n_mel_channels, + sampling_rate, + win_length, + hop_length, + n_fft=None, + mel_fmin=0, + mel_fmax=None, + clamp=1e-5, + ): + super().__init__() + n_fft = win_length if n_fft is None else n_fft + self.hann_window = {} + mel_basis = mel( + sr=sampling_rate, + n_fft=n_fft, + n_mels=n_mel_channels, + fmin=mel_fmin, + fmax=mel_fmax, + htk=True, + ) + mel_basis = torch.from_numpy(mel_basis).float() + self.register_buffer("mel_basis", mel_basis) + self.n_fft = win_length if n_fft is None else n_fft + self.hop_length = hop_length + self.win_length = win_length + self.sampling_rate = sampling_rate + self.n_mel_channels = n_mel_channels + self.clamp = clamp + self.is_half = is_half + + def forward(self, audio, keyshift=0, speed=1, center=True): + factor = 2 ** (keyshift / 12) + n_fft_new = int(np.round(self.n_fft * factor)) + win_length_new = int(np.round(self.win_length * factor)) + hop_length_new = int(np.round(self.hop_length * speed)) + keyshift_key = str(keyshift) + "_" + str(audio.device) + if keyshift_key not in self.hann_window: + self.hann_window[keyshift_key] = torch.hann_window(win_length_new).to( + # "cpu"if(audio.device.type=="privateuseone") else audio.device + audio.device + ) + # fft = torch.stft(#doesn't support pytorch_dml + # # audio.cpu() if(audio.device.type=="privateuseone")else audio, + # audio, + # n_fft=n_fft_new, + # hop_length=hop_length_new, + # win_length=win_length_new, + # window=self.hann_window[keyshift_key], + # center=center, + # return_complex=True, + # ) + # magnitude = torch.sqrt(fft.real.pow(2) + fft.imag.pow(2)) + # print(1111111111) + # print(222222222222222,audio.device,self.is_half) + if hasattr(self, "stft") == False: + # print(n_fft_new,hop_length_new,win_length_new,audio.shape) + self.stft=STFT( + filter_length=n_fft_new, + hop_length=hop_length_new, + win_length=win_length_new, + window='hann' + ).to(audio.device) + magnitude = self.stft.transform(audio)#phase + # if (audio.device.type == "privateuseone"): + # magnitude=magnitude.to(audio.device) + if keyshift != 0: + size = self.n_fft // 2 + 1 + resize = magnitude.size(1) + if resize < size: + magnitude = F.pad(magnitude, (0, 0, 0, size - resize)) + magnitude = magnitude[:, :size, :] * self.win_length / win_length_new + mel_output = torch.matmul(self.mel_basis, magnitude) + if self.is_half == True: + mel_output = mel_output.half() + log_mel_spec = torch.log(torch.clamp(mel_output, min=self.clamp)) + # print(log_mel_spec.device.type) + return log_mel_spec + + +class RMVPE: + def __init__(self, model_path, is_half, device=None): + self.resample_kernel = {} + self.resample_kernel = {} + self.is_half = is_half + if device is None: + device = "cuda" if torch.cuda.is_available() else "cpu" + self.device = device + self.mel_extractor = MelSpectrogram( + is_half, 128, 16000, 1024, 160, None, 30, 8000 + ).to(device) + if ("privateuseone" in str(device)): + import onnxruntime as ort + ort_session = ort.InferenceSession("rmvpe.onnx", providers=["DmlExecutionProvider"]) + self.model=ort_session + else: + model = E2E(4, 1, (2, 2)) + ckpt = torch.load(model_path, map_location="cpu") + model.load_state_dict(ckpt) + model.eval() + if is_half == True: + model = model.half() + self.model = model + self.model = self.model.to(device) + cents_mapping = 20 * np.arange(360) + 1997.3794084376191 + self.cents_mapping = np.pad(cents_mapping, (4, 4)) # 368 + + def mel2hidden(self, mel): + with torch.no_grad(): + n_frames = mel.shape[-1] + mel = F.pad( + mel, (0, 32 * ((n_frames - 1) // 32 + 1) - n_frames), mode="reflect" + ) + if("privateuseone" in str(self.device) ): + onnx_input_name = self.model.get_inputs()[0].name + onnx_outputs_names = self.model.get_outputs()[0].name + hidden = self.model.run([onnx_outputs_names], input_feed={onnx_input_name: mel.cpu().numpy()})[0] + else: + hidden = self.model(mel) + return hidden[:, :n_frames] + + def decode(self, hidden, thred=0.03): + cents_pred = self.to_local_average_cents(hidden, thred=thred) + f0 = 10 * (2 ** (cents_pred / 1200)) + f0[f0 == 10] = 0 + # f0 = np.array([10 * (2 ** (cent_pred / 1200)) if cent_pred else 0 for cent_pred in cents_pred]) + return f0 + + def infer_from_audio(self, audio, thred=0.03): + # torch.cuda.synchronize() + t0=ttime() + mel = self.mel_extractor(torch.from_numpy(audio).float().to(self.device).unsqueeze(0), center=True) + # print(123123123,mel.device.type) + # torch.cuda.synchronize() + t1=ttime() + hidden = self.mel2hidden(mel) + # torch.cuda.synchronize() + t2=ttime() + # print(234234,hidden.device.type) + if("privateuseone" not in str(self.device)): + hidden = hidden.squeeze(0).cpu().numpy() + else: + hidden=hidden[0] + if self.is_half == True: + hidden = hidden.astype("float32") + + f0 = self.decode(hidden, thred=thred) + # torch.cuda.synchronize() + t3=ttime() + # print("hmvpe:%s\t%s\t%s\t%s"%(t1-t0,t2-t1,t3-t2,t3-t0)) + return f0 + + def to_local_average_cents(self, salience, thred=0.05): + # t0 = ttime() + center = np.argmax(salience, axis=1) # 帧长#index + salience = np.pad(salience, ((0, 0), (4, 4))) # 帧长,368 + # t1 = ttime() + center += 4 + todo_salience = [] + todo_cents_mapping = [] + starts = center - 4 + ends = center + 5 + for idx in range(salience.shape[0]): + todo_salience.append(salience[:, starts[idx] : ends[idx]][idx]) + todo_cents_mapping.append(self.cents_mapping[starts[idx] : ends[idx]]) + # t2 = ttime() + todo_salience = np.array(todo_salience) # 帧长,9 + todo_cents_mapping = np.array(todo_cents_mapping) # 帧长,9 + product_sum = np.sum(todo_salience * todo_cents_mapping, 1) + weight_sum = np.sum(todo_salience, 1) # 帧长 + devided = product_sum / weight_sum # 帧长 + # t3 = ttime() + maxx = np.max(salience, axis=1) # 帧长 + devided[maxx <= thred] = 0 + # t4 = ttime() + # print("decode:%s\t%s\t%s\t%s" % (t1 - t0, t2 - t1, t3 - t2, t4 - t3)) + return devided + + +if __name__ == '__main__': + import soundfile as sf, librosa + audio, sampling_rate = sf.read(r"C:\Users\liujing04\Desktop\Z\冬之花clip1.wav") + if len(audio.shape) > 1: + audio = librosa.to_mono(audio.transpose(1, 0)) + audio_bak = audio.copy() + if sampling_rate != 16000: + audio = librosa.resample(audio, orig_sr=sampling_rate, target_sr=16000) + model_path = r"D:\BaiduNetdiskDownload\RVC-beta-v2-0727AMD_realtime\rmvpe.pt" + thred = 0.03 # 0.01 + device = 'cuda' if torch.cuda.is_available() else 'cpu' + rmvpe = RMVPE(model_path,is_half=False, device=device) + t0=ttime() + f0 = rmvpe.infer_from_audio(audio, thred=thred) + # f0 = rmvpe.infer_from_audio(audio, thred=thred) + # f0 = rmvpe.infer_from_audio(audio, thred=thred) + # f0 = rmvpe.infer_from_audio(audio, thred=thred) + # f0 = rmvpe.infer_from_audio(audio, thred=thred) + t1=ttime() + print(f0.shape,t1-t0) diff --git a/infer/lib/slicer2.py b/infer/lib/slicer2.py new file mode 100644 index 0000000..7d9d16d --- /dev/null +++ b/infer/lib/slicer2.py @@ -0,0 +1,260 @@ +import numpy as np + + +# This function is obtained from librosa. +def get_rms( + y, + frame_length=2048, + hop_length=512, + pad_mode="constant", +): + padding = (int(frame_length // 2), int(frame_length // 2)) + y = np.pad(y, padding, mode=pad_mode) + + axis = -1 + # put our new within-frame axis at the end for now + out_strides = y.strides + tuple([y.strides[axis]]) + # Reduce the shape on the framing axis + x_shape_trimmed = list(y.shape) + x_shape_trimmed[axis] -= frame_length - 1 + out_shape = tuple(x_shape_trimmed) + tuple([frame_length]) + xw = np.lib.stride_tricks.as_strided(y, shape=out_shape, strides=out_strides) + if axis < 0: + target_axis = axis - 1 + else: + target_axis = axis + 1 + xw = np.moveaxis(xw, -1, target_axis) + # Downsample along the target axis + slices = [slice(None)] * xw.ndim + slices[axis] = slice(0, None, hop_length) + x = xw[tuple(slices)] + + # Calculate power + power = np.mean(np.abs(x) ** 2, axis=-2, keepdims=True) + + return np.sqrt(power) + + +class Slicer: + def __init__( + self, + sr: int, + threshold: float = -40.0, + min_length: int = 5000, + min_interval: int = 300, + hop_size: int = 20, + max_sil_kept: int = 5000, + ): + if not min_length >= min_interval >= hop_size: + raise ValueError( + "The following condition must be satisfied: min_length >= min_interval >= hop_size" + ) + if not max_sil_kept >= hop_size: + raise ValueError( + "The following condition must be satisfied: max_sil_kept >= hop_size" + ) + min_interval = sr * min_interval / 1000 + self.threshold = 10 ** (threshold / 20.0) + self.hop_size = round(sr * hop_size / 1000) + self.win_size = min(round(min_interval), 4 * self.hop_size) + self.min_length = round(sr * min_length / 1000 / self.hop_size) + self.min_interval = round(min_interval / self.hop_size) + self.max_sil_kept = round(sr * max_sil_kept / 1000 / self.hop_size) + + def _apply_slice(self, waveform, begin, end): + if len(waveform.shape) > 1: + return waveform[ + :, begin * self.hop_size : min(waveform.shape[1], end * self.hop_size) + ] + else: + return waveform[ + begin * self.hop_size : min(waveform.shape[0], end * self.hop_size) + ] + + # @timeit + def slice(self, waveform): + if len(waveform.shape) > 1: + samples = waveform.mean(axis=0) + else: + samples = waveform + if samples.shape[0] <= self.min_length: + return [waveform] + rms_list = get_rms( + y=samples, frame_length=self.win_size, hop_length=self.hop_size + ).squeeze(0) + sil_tags = [] + silence_start = None + clip_start = 0 + for i, rms in enumerate(rms_list): + # Keep looping while frame is silent. + if rms < self.threshold: + # Record start of silent frames. + if silence_start is None: + silence_start = i + continue + # Keep looping while frame is not silent and silence start has not been recorded. + if silence_start is None: + continue + # Clear recorded silence start if interval is not enough or clip is too short + is_leading_silence = silence_start == 0 and i > self.max_sil_kept + need_slice_middle = ( + i - silence_start >= self.min_interval + and i - clip_start >= self.min_length + ) + if not is_leading_silence and not need_slice_middle: + silence_start = None + continue + # Need slicing. Record the range of silent frames to be removed. + if i - silence_start <= self.max_sil_kept: + pos = rms_list[silence_start : i + 1].argmin() + silence_start + if silence_start == 0: + sil_tags.append((0, pos)) + else: + sil_tags.append((pos, pos)) + clip_start = pos + elif i - silence_start <= self.max_sil_kept * 2: + pos = rms_list[ + i - self.max_sil_kept : silence_start + self.max_sil_kept + 1 + ].argmin() + pos += i - self.max_sil_kept + pos_l = ( + rms_list[ + silence_start : silence_start + self.max_sil_kept + 1 + ].argmin() + + silence_start + ) + pos_r = ( + rms_list[i - self.max_sil_kept : i + 1].argmin() + + i + - self.max_sil_kept + ) + if silence_start == 0: + sil_tags.append((0, pos_r)) + clip_start = pos_r + else: + sil_tags.append((min(pos_l, pos), max(pos_r, pos))) + clip_start = max(pos_r, pos) + else: + pos_l = ( + rms_list[ + silence_start : silence_start + self.max_sil_kept + 1 + ].argmin() + + silence_start + ) + pos_r = ( + rms_list[i - self.max_sil_kept : i + 1].argmin() + + i + - self.max_sil_kept + ) + if silence_start == 0: + sil_tags.append((0, pos_r)) + else: + sil_tags.append((pos_l, pos_r)) + clip_start = pos_r + silence_start = None + # Deal with trailing silence. + total_frames = rms_list.shape[0] + if ( + silence_start is not None + and total_frames - silence_start >= self.min_interval + ): + silence_end = min(total_frames, silence_start + self.max_sil_kept) + pos = rms_list[silence_start : silence_end + 1].argmin() + silence_start + sil_tags.append((pos, total_frames + 1)) + # Apply and return slices. + if len(sil_tags) == 0: + return [waveform] + else: + chunks = [] + if sil_tags[0][0] > 0: + chunks.append(self._apply_slice(waveform, 0, sil_tags[0][0])) + for i in range(len(sil_tags) - 1): + chunks.append( + self._apply_slice(waveform, sil_tags[i][1], sil_tags[i + 1][0]) + ) + if sil_tags[-1][1] < total_frames: + chunks.append( + self._apply_slice(waveform, sil_tags[-1][1], total_frames) + ) + return chunks + + +def main(): + import os.path + from argparse import ArgumentParser + + import librosa + import soundfile + + parser = ArgumentParser() + parser.add_argument("audio", type=str, help="The audio to be sliced") + parser.add_argument( + "--out", type=str, help="Output directory of the sliced audio clips" + ) + parser.add_argument( + "--db_thresh", + type=float, + required=False, + default=-40, + help="The dB threshold for silence detection", + ) + parser.add_argument( + "--min_length", + type=int, + required=False, + default=5000, + help="The minimum milliseconds required for each sliced audio clip", + ) + parser.add_argument( + "--min_interval", + type=int, + required=False, + default=300, + help="The minimum milliseconds for a silence part to be sliced", + ) + parser.add_argument( + "--hop_size", + type=int, + required=False, + default=10, + help="Frame length in milliseconds", + ) + parser.add_argument( + "--max_sil_kept", + type=int, + required=False, + default=500, + help="The maximum silence length kept around the sliced clip, presented in milliseconds", + ) + args = parser.parse_args() + out = args.out + if out is None: + out = os.path.dirname(os.path.abspath(args.audio)) + audio, sr = librosa.load(args.audio, sr=None, mono=False) + slicer = Slicer( + sr=sr, + threshold=args.db_thresh, + min_length=args.min_length, + min_interval=args.min_interval, + hop_size=args.hop_size, + max_sil_kept=args.max_sil_kept, + ) + chunks = slicer.slice(audio) + if not os.path.exists(out): + os.makedirs(out) + for i, chunk in enumerate(chunks): + if len(chunk.shape) > 1: + chunk = chunk.T + soundfile.write( + os.path.join( + out, + f"%s_%d.wav" + % (os.path.basename(args.audio).rsplit(".", maxsplit=1)[0], i), + ), + chunk, + sr, + ) + + +if __name__ == "__main__": + main() diff --git a/infer/lib/train/data_utils.py b/infer/lib/train/data_utils.py new file mode 100644 index 0000000..7793f15 --- /dev/null +++ b/infer/lib/train/data_utils.py @@ -0,0 +1,512 @@ +import os, traceback +import numpy as np +import torch +import torch.utils.data + +from infer.lib.train.mel_processing import spectrogram_torch +from infer.lib.train.utils import load_wav_to_torch, load_filepaths_and_text + + +class TextAudioLoaderMultiNSFsid(torch.utils.data.Dataset): + """ + 1) loads audio, text pairs + 2) normalizes text and converts them to sequences of integers + 3) computes spectrograms from audio files. + """ + + def __init__(self, audiopaths_and_text, hparams): + self.audiopaths_and_text = load_filepaths_and_text(audiopaths_and_text) + self.max_wav_value = hparams.max_wav_value + self.sampling_rate = hparams.sampling_rate + self.filter_length = hparams.filter_length + self.hop_length = hparams.hop_length + self.win_length = hparams.win_length + self.sampling_rate = hparams.sampling_rate + self.min_text_len = getattr(hparams, "min_text_len", 1) + self.max_text_len = getattr(hparams, "max_text_len", 5000) + self._filter() + + def _filter(self): + """ + Filter text & store spec lengths + """ + # Store spectrogram lengths for Bucketing + # wav_length ~= file_size / (wav_channels * Bytes per dim) = file_size / (1 * 2) + # spec_length = wav_length // hop_length + audiopaths_and_text_new = [] + lengths = [] + for audiopath, text, pitch, pitchf, dv in self.audiopaths_and_text: + if self.min_text_len <= len(text) and len(text) <= self.max_text_len: + audiopaths_and_text_new.append([audiopath, text, pitch, pitchf, dv]) + lengths.append(os.path.getsize(audiopath) // (3 * self.hop_length)) + self.audiopaths_and_text = audiopaths_and_text_new + self.lengths = lengths + + def get_sid(self, sid): + sid = torch.LongTensor([int(sid)]) + return sid + + def get_audio_text_pair(self, audiopath_and_text): + # separate filename and text + file = audiopath_and_text[0] + phone = audiopath_and_text[1] + pitch = audiopath_and_text[2] + pitchf = audiopath_and_text[3] + dv = audiopath_and_text[4] + + phone, pitch, pitchf = self.get_labels(phone, pitch, pitchf) + spec, wav = self.get_audio(file) + dv = self.get_sid(dv) + + len_phone = phone.size()[0] + len_spec = spec.size()[-1] + # print(123,phone.shape,pitch.shape,spec.shape) + if len_phone != len_spec: + len_min = min(len_phone, len_spec) + # amor + len_wav = len_min * self.hop_length + + spec = spec[:, :len_min] + wav = wav[:, :len_wav] + + phone = phone[:len_min, :] + pitch = pitch[:len_min] + pitchf = pitchf[:len_min] + + return (spec, wav, phone, pitch, pitchf, dv) + + def get_labels(self, phone, pitch, pitchf): + phone = np.load(phone) + phone = np.repeat(phone, 2, axis=0) + pitch = np.load(pitch) + pitchf = np.load(pitchf) + n_num = min(phone.shape[0], 900) # DistributedBucketSampler + # print(234,phone.shape,pitch.shape) + phone = phone[:n_num, :] + pitch = pitch[:n_num] + pitchf = pitchf[:n_num] + phone = torch.FloatTensor(phone) + pitch = torch.LongTensor(pitch) + pitchf = torch.FloatTensor(pitchf) + return phone, pitch, pitchf + + def get_audio(self, filename): + audio, sampling_rate = load_wav_to_torch(filename) + if sampling_rate != self.sampling_rate: + raise ValueError( + "{} SR doesn't match target {} SR".format( + sampling_rate, self.sampling_rate + ) + ) + audio_norm = audio + # audio_norm = audio / self.max_wav_value + # audio_norm = audio / np.abs(audio).max() + + audio_norm = audio_norm.unsqueeze(0) + spec_filename = filename.replace(".wav", ".spec.pt") + if os.path.exists(spec_filename): + try: + spec = torch.load(spec_filename) + except: + print(spec_filename, traceback.format_exc()) + spec = spectrogram_torch( + audio_norm, + self.filter_length, + self.sampling_rate, + self.hop_length, + self.win_length, + center=False, + ) + spec = torch.squeeze(spec, 0) + torch.save(spec, spec_filename, _use_new_zipfile_serialization=False) + else: + spec = spectrogram_torch( + audio_norm, + self.filter_length, + self.sampling_rate, + self.hop_length, + self.win_length, + center=False, + ) + spec = torch.squeeze(spec, 0) + torch.save(spec, spec_filename, _use_new_zipfile_serialization=False) + return spec, audio_norm + + def __getitem__(self, index): + return self.get_audio_text_pair(self.audiopaths_and_text[index]) + + def __len__(self): + return len(self.audiopaths_and_text) + + +class TextAudioCollateMultiNSFsid: + """Zero-pads model inputs and targets""" + + def __init__(self, return_ids=False): + self.return_ids = return_ids + + def __call__(self, batch): + """Collate's training batch from normalized text and aduio + PARAMS + ------ + batch: [text_normalized, spec_normalized, wav_normalized] + """ + # Right zero-pad all one-hot text sequences to max input length + _, ids_sorted_decreasing = torch.sort( + torch.LongTensor([x[0].size(1) for x in batch]), dim=0, descending=True + ) + + max_spec_len = max([x[0].size(1) for x in batch]) + max_wave_len = max([x[1].size(1) for x in batch]) + spec_lengths = torch.LongTensor(len(batch)) + wave_lengths = torch.LongTensor(len(batch)) + spec_padded = torch.FloatTensor(len(batch), batch[0][0].size(0), max_spec_len) + wave_padded = torch.FloatTensor(len(batch), 1, max_wave_len) + spec_padded.zero_() + wave_padded.zero_() + + max_phone_len = max([x[2].size(0) for x in batch]) + phone_lengths = torch.LongTensor(len(batch)) + phone_padded = torch.FloatTensor( + len(batch), max_phone_len, batch[0][2].shape[1] + ) # (spec, wav, phone, pitch) + pitch_padded = torch.LongTensor(len(batch), max_phone_len) + pitchf_padded = torch.FloatTensor(len(batch), max_phone_len) + phone_padded.zero_() + pitch_padded.zero_() + pitchf_padded.zero_() + # dv = torch.FloatTensor(len(batch), 256)#gin=256 + sid = torch.LongTensor(len(batch)) + + for i in range(len(ids_sorted_decreasing)): + row = batch[ids_sorted_decreasing[i]] + + spec = row[0] + spec_padded[i, :, : spec.size(1)] = spec + spec_lengths[i] = spec.size(1) + + wave = row[1] + wave_padded[i, :, : wave.size(1)] = wave + wave_lengths[i] = wave.size(1) + + phone = row[2] + phone_padded[i, : phone.size(0), :] = phone + phone_lengths[i] = phone.size(0) + + pitch = row[3] + pitch_padded[i, : pitch.size(0)] = pitch + pitchf = row[4] + pitchf_padded[i, : pitchf.size(0)] = pitchf + + # dv[i] = row[5] + sid[i] = row[5] + + return ( + phone_padded, + phone_lengths, + pitch_padded, + pitchf_padded, + spec_padded, + spec_lengths, + wave_padded, + wave_lengths, + # dv + sid, + ) + + +class TextAudioLoader(torch.utils.data.Dataset): + """ + 1) loads audio, text pairs + 2) normalizes text and converts them to sequences of integers + 3) computes spectrograms from audio files. + """ + + def __init__(self, audiopaths_and_text, hparams): + self.audiopaths_and_text = load_filepaths_and_text(audiopaths_and_text) + self.max_wav_value = hparams.max_wav_value + self.sampling_rate = hparams.sampling_rate + self.filter_length = hparams.filter_length + self.hop_length = hparams.hop_length + self.win_length = hparams.win_length + self.sampling_rate = hparams.sampling_rate + self.min_text_len = getattr(hparams, "min_text_len", 1) + self.max_text_len = getattr(hparams, "max_text_len", 5000) + self._filter() + + def _filter(self): + """ + Filter text & store spec lengths + """ + # Store spectrogram lengths for Bucketing + # wav_length ~= file_size / (wav_channels * Bytes per dim) = file_size / (1 * 2) + # spec_length = wav_length // hop_length + audiopaths_and_text_new = [] + lengths = [] + for audiopath, text, dv in self.audiopaths_and_text: + if self.min_text_len <= len(text) and len(text) <= self.max_text_len: + audiopaths_and_text_new.append([audiopath, text, dv]) + lengths.append(os.path.getsize(audiopath) // (3 * self.hop_length)) + self.audiopaths_and_text = audiopaths_and_text_new + self.lengths = lengths + + def get_sid(self, sid): + sid = torch.LongTensor([int(sid)]) + return sid + + def get_audio_text_pair(self, audiopath_and_text): + # separate filename and text + file = audiopath_and_text[0] + phone = audiopath_and_text[1] + dv = audiopath_and_text[2] + + phone = self.get_labels(phone) + spec, wav = self.get_audio(file) + dv = self.get_sid(dv) + + len_phone = phone.size()[0] + len_spec = spec.size()[-1] + if len_phone != len_spec: + len_min = min(len_phone, len_spec) + len_wav = len_min * self.hop_length + spec = spec[:, :len_min] + wav = wav[:, :len_wav] + phone = phone[:len_min, :] + return (spec, wav, phone, dv) + + def get_labels(self, phone): + phone = np.load(phone) + phone = np.repeat(phone, 2, axis=0) + n_num = min(phone.shape[0], 900) # DistributedBucketSampler + phone = phone[:n_num, :] + phone = torch.FloatTensor(phone) + return phone + + def get_audio(self, filename): + audio, sampling_rate = load_wav_to_torch(filename) + if sampling_rate != self.sampling_rate: + raise ValueError( + "{} SR doesn't match target {} SR".format( + sampling_rate, self.sampling_rate + ) + ) + audio_norm = audio + # audio_norm = audio / self.max_wav_value + # audio_norm = audio / np.abs(audio).max() + + audio_norm = audio_norm.unsqueeze(0) + spec_filename = filename.replace(".wav", ".spec.pt") + if os.path.exists(spec_filename): + try: + spec = torch.load(spec_filename) + except: + print(spec_filename, traceback.format_exc()) + spec = spectrogram_torch( + audio_norm, + self.filter_length, + self.sampling_rate, + self.hop_length, + self.win_length, + center=False, + ) + spec = torch.squeeze(spec, 0) + torch.save(spec, spec_filename, _use_new_zipfile_serialization=False) + else: + spec = spectrogram_torch( + audio_norm, + self.filter_length, + self.sampling_rate, + self.hop_length, + self.win_length, + center=False, + ) + spec = torch.squeeze(spec, 0) + torch.save(spec, spec_filename, _use_new_zipfile_serialization=False) + return spec, audio_norm + + def __getitem__(self, index): + return self.get_audio_text_pair(self.audiopaths_and_text[index]) + + def __len__(self): + return len(self.audiopaths_and_text) + + +class TextAudioCollate: + """Zero-pads model inputs and targets""" + + def __init__(self, return_ids=False): + self.return_ids = return_ids + + def __call__(self, batch): + """Collate's training batch from normalized text and aduio + PARAMS + ------ + batch: [text_normalized, spec_normalized, wav_normalized] + """ + # Right zero-pad all one-hot text sequences to max input length + _, ids_sorted_decreasing = torch.sort( + torch.LongTensor([x[0].size(1) for x in batch]), dim=0, descending=True + ) + + max_spec_len = max([x[0].size(1) for x in batch]) + max_wave_len = max([x[1].size(1) for x in batch]) + spec_lengths = torch.LongTensor(len(batch)) + wave_lengths = torch.LongTensor(len(batch)) + spec_padded = torch.FloatTensor(len(batch), batch[0][0].size(0), max_spec_len) + wave_padded = torch.FloatTensor(len(batch), 1, max_wave_len) + spec_padded.zero_() + wave_padded.zero_() + + max_phone_len = max([x[2].size(0) for x in batch]) + phone_lengths = torch.LongTensor(len(batch)) + phone_padded = torch.FloatTensor( + len(batch), max_phone_len, batch[0][2].shape[1] + ) + phone_padded.zero_() + sid = torch.LongTensor(len(batch)) + + for i in range(len(ids_sorted_decreasing)): + row = batch[ids_sorted_decreasing[i]] + + spec = row[0] + spec_padded[i, :, : spec.size(1)] = spec + spec_lengths[i] = spec.size(1) + + wave = row[1] + wave_padded[i, :, : wave.size(1)] = wave + wave_lengths[i] = wave.size(1) + + phone = row[2] + phone_padded[i, : phone.size(0), :] = phone + phone_lengths[i] = phone.size(0) + + sid[i] = row[3] + + return ( + phone_padded, + phone_lengths, + spec_padded, + spec_lengths, + wave_padded, + wave_lengths, + sid, + ) + + +class DistributedBucketSampler(torch.utils.data.distributed.DistributedSampler): + """ + Maintain similar input lengths in a batch. + Length groups are specified by boundaries. + Ex) boundaries = [b1, b2, b3] -> any batch is included either {x | b1 < length(x) <=b2} or {x | b2 < length(x) <= b3}. + + It removes samples which are not included in the boundaries. + Ex) boundaries = [b1, b2, b3] -> any x s.t. length(x) <= b1 or length(x) > b3 are discarded. + """ + + def __init__( + self, + dataset, + batch_size, + boundaries, + num_replicas=None, + rank=None, + shuffle=True, + ): + super().__init__(dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle) + self.lengths = dataset.lengths + self.batch_size = batch_size + self.boundaries = boundaries + + self.buckets, self.num_samples_per_bucket = self._create_buckets() + self.total_size = sum(self.num_samples_per_bucket) + self.num_samples = self.total_size // self.num_replicas + + def _create_buckets(self): + buckets = [[] for _ in range(len(self.boundaries) - 1)] + for i in range(len(self.lengths)): + length = self.lengths[i] + idx_bucket = self._bisect(length) + if idx_bucket != -1: + buckets[idx_bucket].append(i) + + for i in range(len(buckets) - 1, -1, -1): # + if len(buckets[i]) == 0: + buckets.pop(i) + self.boundaries.pop(i + 1) + + num_samples_per_bucket = [] + for i in range(len(buckets)): + len_bucket = len(buckets[i]) + total_batch_size = self.num_replicas * self.batch_size + rem = ( + total_batch_size - (len_bucket % total_batch_size) + ) % total_batch_size + num_samples_per_bucket.append(len_bucket + rem) + return buckets, num_samples_per_bucket + + def __iter__(self): + # deterministically shuffle based on epoch + g = torch.Generator() + g.manual_seed(self.epoch) + + indices = [] + if self.shuffle: + for bucket in self.buckets: + indices.append(torch.randperm(len(bucket), generator=g).tolist()) + else: + for bucket in self.buckets: + indices.append(list(range(len(bucket)))) + + batches = [] + for i in range(len(self.buckets)): + bucket = self.buckets[i] + len_bucket = len(bucket) + ids_bucket = indices[i] + num_samples_bucket = self.num_samples_per_bucket[i] + + # add extra samples to make it evenly divisible + rem = num_samples_bucket - len_bucket + ids_bucket = ( + ids_bucket + + ids_bucket * (rem // len_bucket) + + ids_bucket[: (rem % len_bucket)] + ) + + # subsample + ids_bucket = ids_bucket[self.rank :: self.num_replicas] + + # batching + for j in range(len(ids_bucket) // self.batch_size): + batch = [ + bucket[idx] + for idx in ids_bucket[ + j * self.batch_size : (j + 1) * self.batch_size + ] + ] + batches.append(batch) + + if self.shuffle: + batch_ids = torch.randperm(len(batches), generator=g).tolist() + batches = [batches[i] for i in batch_ids] + self.batches = batches + + assert len(self.batches) * self.batch_size == self.num_samples + return iter(self.batches) + + def _bisect(self, x, lo=0, hi=None): + if hi is None: + hi = len(self.boundaries) - 1 + + if hi > lo: + mid = (hi + lo) // 2 + if self.boundaries[mid] < x and x <= self.boundaries[mid + 1]: + return mid + elif x <= self.boundaries[mid]: + return self._bisect(x, lo, mid) + else: + return self._bisect(x, mid + 1, hi) + else: + return -1 + + def __len__(self): + return self.num_samples // self.batch_size diff --git a/infer/lib/train/losses.py b/infer/lib/train/losses.py new file mode 100644 index 0000000..aa7bd81 --- /dev/null +++ b/infer/lib/train/losses.py @@ -0,0 +1,58 @@ +import torch + + +def feature_loss(fmap_r, fmap_g): + loss = 0 + for dr, dg in zip(fmap_r, fmap_g): + for rl, gl in zip(dr, dg): + rl = rl.float().detach() + gl = gl.float() + loss += torch.mean(torch.abs(rl - gl)) + + return loss * 2 + + +def discriminator_loss(disc_real_outputs, disc_generated_outputs): + loss = 0 + r_losses = [] + g_losses = [] + for dr, dg in zip(disc_real_outputs, disc_generated_outputs): + dr = dr.float() + dg = dg.float() + r_loss = torch.mean((1 - dr) ** 2) + g_loss = torch.mean(dg**2) + loss += r_loss + g_loss + r_losses.append(r_loss.item()) + g_losses.append(g_loss.item()) + + return loss, r_losses, g_losses + + +def generator_loss(disc_outputs): + loss = 0 + gen_losses = [] + for dg in disc_outputs: + dg = dg.float() + l = torch.mean((1 - dg) ** 2) + gen_losses.append(l) + loss += l + + return loss, gen_losses + + +def kl_loss(z_p, logs_q, m_p, logs_p, z_mask): + """ + z_p, logs_q: [b, h, t_t] + m_p, logs_p: [b, h, t_t] + """ + z_p = z_p.float() + logs_q = logs_q.float() + m_p = m_p.float() + logs_p = logs_p.float() + z_mask = z_mask.float() + + kl = logs_p - logs_q - 0.5 + kl += 0.5 * ((z_p - m_p) ** 2) * torch.exp(-2.0 * logs_p) + kl = torch.sum(kl * z_mask) + l = kl / torch.sum(z_mask) + return l diff --git a/infer/lib/train/mel_processing.py b/infer/lib/train/mel_processing.py new file mode 100644 index 0000000..3cc3687 --- /dev/null +++ b/infer/lib/train/mel_processing.py @@ -0,0 +1,130 @@ +import torch +import torch.utils.data +from librosa.filters import mel as librosa_mel_fn + + +MAX_WAV_VALUE = 32768.0 + + +def dynamic_range_compression_torch(x, C=1, clip_val=1e-5): + """ + PARAMS + ------ + C: compression factor + """ + return torch.log(torch.clamp(x, min=clip_val) * C) + + +def dynamic_range_decompression_torch(x, C=1): + """ + PARAMS + ------ + C: compression factor used to compress + """ + return torch.exp(x) / C + + +def spectral_normalize_torch(magnitudes): + return dynamic_range_compression_torch(magnitudes) + + +def spectral_de_normalize_torch(magnitudes): + return dynamic_range_decompression_torch(magnitudes) + + +# Reusable banks +mel_basis = {} +hann_window = {} + + +def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False): + """Convert waveform into Linear-frequency Linear-amplitude spectrogram. + + Args: + y :: (B, T) - Audio waveforms + n_fft + sampling_rate + hop_size + win_size + center + Returns: + :: (B, Freq, Frame) - Linear-frequency Linear-amplitude spectrogram + """ + # Validation + if torch.min(y) < -1.07: + print("min value is ", torch.min(y)) + if torch.max(y) > 1.07: + print("max value is ", torch.max(y)) + + # Window - Cache if needed + global hann_window + dtype_device = str(y.dtype) + "_" + str(y.device) + wnsize_dtype_device = str(win_size) + "_" + dtype_device + if wnsize_dtype_device not in hann_window: + hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to( + dtype=y.dtype, device=y.device + ) + + # Padding + y = torch.nn.functional.pad( + y.unsqueeze(1), + (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), + mode="reflect", + ) + y = y.squeeze(1) + + # Complex Spectrogram :: (B, T) -> (B, Freq, Frame, RealComplex=2) + spec = torch.stft( + y, + n_fft, + hop_length=hop_size, + win_length=win_size, + window=hann_window[wnsize_dtype_device], + center=center, + pad_mode="reflect", + normalized=False, + onesided=True, + return_complex=False, + ) + + # Linear-frequency Linear-amplitude spectrogram :: (B, Freq, Frame, RealComplex=2) -> (B, Freq, Frame) + spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) + return spec + + +def spec_to_mel_torch(spec, n_fft, num_mels, sampling_rate, fmin, fmax): + # MelBasis - Cache if needed + global mel_basis + dtype_device = str(spec.dtype) + "_" + str(spec.device) + fmax_dtype_device = str(fmax) + "_" + dtype_device + if fmax_dtype_device not in mel_basis: + mel = librosa_mel_fn( + sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax + ) + mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to( + dtype=spec.dtype, device=spec.device + ) + + # Mel-frequency Log-amplitude spectrogram :: (B, Freq=num_mels, Frame) + melspec = torch.matmul(mel_basis[fmax_dtype_device], spec) + melspec = spectral_normalize_torch(melspec) + return melspec + + +def mel_spectrogram_torch( + y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False +): + """Convert waveform into Mel-frequency Log-amplitude spectrogram. + + Args: + y :: (B, T) - Waveforms + Returns: + melspec :: (B, Freq, Frame) - Mel-frequency Log-amplitude spectrogram + """ + # Linear-frequency Linear-amplitude spectrogram :: (B, T) -> (B, Freq, Frame) + spec = spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center) + + # Mel-frequency Log-amplitude spectrogram :: (B, Freq, Frame) -> (B, Freq=num_mels, Frame) + melspec = spec_to_mel_torch(spec, n_fft, num_mels, sampling_rate, fmin, fmax) + + return melspec diff --git a/infer/lib/train/process_ckpt.py b/infer/lib/train/process_ckpt.py new file mode 100644 index 0000000..a48ca61 --- /dev/null +++ b/infer/lib/train/process_ckpt.py @@ -0,0 +1,259 @@ +import torch, traceback, os, sys + +now_dir = os.getcwd() +sys.path.append(now_dir) +from collections import OrderedDict +from i18n.i18n import I18nAuto + +i18n = I18nAuto() + + +def savee(ckpt, sr, if_f0, name, epoch, version, hps, i18n): + try: + opt = OrderedDict() + opt["weight"] = {} + for key in ckpt.keys(): + if "enc_q" in key: + continue + opt["weight"][key] = ckpt[key].half() + opt["config"] = [ + hps.data.filter_length // 2 + 1, + 32, + hps.model.inter_channels, + hps.model.hidden_channels, + hps.model.filter_channels, + hps.model.n_heads, + hps.model.n_layers, + hps.model.kernel_size, + hps.model.p_dropout, + hps.model.resblock, + hps.model.resblock_kernel_sizes, + hps.model.resblock_dilation_sizes, + hps.model.upsample_rates, + hps.model.upsample_initial_channel, + hps.model.upsample_kernel_sizes, + hps.model.spk_embed_dim, + hps.model.gin_channels, + hps.data.sampling_rate, + ] + opt["info"] = "%sepoch" % epoch + opt["sr"] = sr + opt["f0"] = if_f0 + opt["version"] = version + torch.save(opt, "weights/%s.pth" % name) + return "Success." + except: + return traceback.format_exc() + + +def show_info(path): + try: + a = torch.load(path, map_location="cpu") + return "模型信息:%s\n采样率:%s\n模型是否输入音高引导:%s\n版本:%s" % ( + a.get("info", "None"), + a.get("sr", "None"), + a.get("f0", "None"), + a.get("version", "None"), + ) + except: + return traceback.format_exc() + + +def extract_small_model(path, name, sr, if_f0, info, version): + try: + ckpt = torch.load(path, map_location="cpu") + if "model" in ckpt: + ckpt = ckpt["model"] + opt = OrderedDict() + opt["weight"] = {} + for key in ckpt.keys(): + if "enc_q" in key: + continue + opt["weight"][key] = ckpt[key].half() + if sr == "40k": + opt["config"] = [ + 1025, + 32, + 192, + 192, + 768, + 2, + 6, + 3, + 0, + "1", + [3, 7, 11], + [[1, 3, 5], [1, 3, 5], [1, 3, 5]], + [10, 10, 2, 2], + 512, + [16, 16, 4, 4], + 109, + 256, + 40000, + ] + elif sr == "48k": + if version == "v1": + opt["config"] = [ + 1025, + 32, + 192, + 192, + 768, + 2, + 6, + 3, + 0, + "1", + [3, 7, 11], + [[1, 3, 5], [1, 3, 5], [1, 3, 5]], + [10, 6, 2, 2, 2], + 512, + [16, 16, 4, 4, 4], + 109, + 256, + 48000, + ] + else: + opt["config"] = [ + 1025, + 32, + 192, + 192, + 768, + 2, + 6, + 3, + 0, + "1", + [3, 7, 11], + [[1, 3, 5], [1, 3, 5], [1, 3, 5]], + [12, 10, 2, 2], + 512, + [24, 20, 4, 4], + 109, + 256, + 48000, + ] + elif sr == "32k": + if version == "v1": + opt["config"] = [ + 513, + 32, + 192, + 192, + 768, + 2, + 6, + 3, + 0, + "1", + [3, 7, 11], + [[1, 3, 5], [1, 3, 5], [1, 3, 5]], + [10, 4, 2, 2, 2], + 512, + [16, 16, 4, 4, 4], + 109, + 256, + 32000, + ] + else: + opt["config"] = [ + 513, + 32, + 192, + 192, + 768, + 2, + 6, + 3, + 0, + "1", + [3, 7, 11], + [[1, 3, 5], [1, 3, 5], [1, 3, 5]], + [10, 8, 2, 2], + 512, + [20, 16, 4, 4], + 109, + 256, + 32000, + ] + if info == "": + info = "Extracted model." + opt["info"] = info + opt["version"] = version + opt["sr"] = sr + opt["f0"] = int(if_f0) + torch.save(opt, "weights/%s.pth" % name) + return "Success." + except: + return traceback.format_exc() + + +def change_info(path, info, name): + try: + ckpt = torch.load(path, map_location="cpu") + ckpt["info"] = info + if name == "": + name = os.path.basename(path) + torch.save(ckpt, "weights/%s" % name) + return "Success." + except: + return traceback.format_exc() + + +def merge(path1, path2, alpha1, sr, f0, info, name, version): + try: + + def extract(ckpt): + a = ckpt["model"] + opt = OrderedDict() + opt["weight"] = {} + for key in a.keys(): + if "enc_q" in key: + continue + opt["weight"][key] = a[key] + return opt + + ckpt1 = torch.load(path1, map_location="cpu") + ckpt2 = torch.load(path2, map_location="cpu") + cfg = ckpt1["config"] + if "model" in ckpt1: + ckpt1 = extract(ckpt1) + else: + ckpt1 = ckpt1["weight"] + if "model" in ckpt2: + ckpt2 = extract(ckpt2) + else: + ckpt2 = ckpt2["weight"] + if sorted(list(ckpt1.keys())) != sorted(list(ckpt2.keys())): + return "Fail to merge the models. The model architectures are not the same." + opt = OrderedDict() + opt["weight"] = {} + for key in ckpt1.keys(): + # try: + if key == "emb_g.weight" and ckpt1[key].shape != ckpt2[key].shape: + min_shape0 = min(ckpt1[key].shape[0], ckpt2[key].shape[0]) + opt["weight"][key] = ( + alpha1 * (ckpt1[key][:min_shape0].float()) + + (1 - alpha1) * (ckpt2[key][:min_shape0].float()) + ).half() + else: + opt["weight"][key] = ( + alpha1 * (ckpt1[key].float()) + (1 - alpha1) * (ckpt2[key].float()) + ).half() + # except: + # pdb.set_trace() + opt["config"] = cfg + """ + if(sr=="40k"):opt["config"] = [1025, 32, 192, 192, 768, 2, 6, 3, 0, "1", [3, 7, 11], [[1, 3, 5], [1, 3, 5], [1, 3, 5]], [10, 10, 2, 2], 512, [16, 16, 4, 4,4], 109, 256, 40000] + elif(sr=="48k"):opt["config"] = [1025, 32, 192, 192, 768, 2, 6, 3, 0, "1", [3, 7, 11], [[1, 3, 5], [1, 3, 5], [1, 3, 5]], [10,6,2,2,2], 512, [16, 16, 4, 4], 109, 256, 48000] + elif(sr=="32k"):opt["config"] = [513, 32, 192, 192, 768, 2, 6, 3, 0, "1", [3, 7, 11], [[1, 3, 5], [1, 3, 5], [1, 3, 5]], [10, 4, 2, 2, 2], 512, [16, 16, 4, 4,4], 109, 256, 32000] + """ + opt["sr"] = sr + opt["f0"] = 1 if f0 == i18n("是") else 0 + opt["version"] = version + opt["info"] = info + torch.save(opt, "weights/%s.pth" % name) + return "Success." + except: + return traceback.format_exc() diff --git a/infer/lib/train/utils.py b/infer/lib/train/utils.py new file mode 100644 index 0000000..9c0fb5c --- /dev/null +++ b/infer/lib/train/utils.py @@ -0,0 +1,487 @@ +import os, traceback +import glob +import sys +import argparse +import logging +import json +import subprocess +import numpy as np +from scipy.io.wavfile import read +import torch + +MATPLOTLIB_FLAG = False + +logging.basicConfig(stream=sys.stdout, level=logging.DEBUG) +logger = logging + + +def load_checkpoint_d(checkpoint_path, combd, sbd, optimizer=None, load_opt=1): + assert os.path.isfile(checkpoint_path) + checkpoint_dict = torch.load(checkpoint_path, map_location="cpu") + + ################## + def go(model, bkey): + saved_state_dict = checkpoint_dict[bkey] + if hasattr(model, "module"): + state_dict = model.module.state_dict() + else: + state_dict = model.state_dict() + new_state_dict = {} + for k, v in state_dict.items(): # 模型需要的shape + try: + new_state_dict[k] = saved_state_dict[k] + if saved_state_dict[k].shape != state_dict[k].shape: + print( + "shape-%s-mismatch|need-%s|get-%s" + % (k, state_dict[k].shape, saved_state_dict[k].shape) + ) # + raise KeyError + except: + # logger.info(traceback.format_exc()) + logger.info("%s is not in the checkpoint" % k) # pretrain缺失的 + new_state_dict[k] = v # 模型自带的随机值 + if hasattr(model, "module"): + model.module.load_state_dict(new_state_dict, strict=False) + else: + model.load_state_dict(new_state_dict, strict=False) + return model + + go(combd, "combd") + model = go(sbd, "sbd") + ############# + logger.info("Loaded model weights") + + iteration = checkpoint_dict["iteration"] + learning_rate = checkpoint_dict["learning_rate"] + if ( + optimizer is not None and load_opt == 1 + ): ###加载不了,如果是空的的话,重新初始化,可能还会影响lr时间表的更新,因此在train文件最外围catch + # try: + optimizer.load_state_dict(checkpoint_dict["optimizer"]) + # except: + # traceback.print_exc() + logger.info("Loaded checkpoint '{}' (epoch {})".format(checkpoint_path, iteration)) + return model, optimizer, learning_rate, iteration + + +# def load_checkpoint(checkpoint_path, model, optimizer=None): +# assert os.path.isfile(checkpoint_path) +# checkpoint_dict = torch.load(checkpoint_path, map_location='cpu') +# iteration = checkpoint_dict['iteration'] +# learning_rate = checkpoint_dict['learning_rate'] +# if optimizer is not None: +# optimizer.load_state_dict(checkpoint_dict['optimizer']) +# # print(1111) +# saved_state_dict = checkpoint_dict['model'] +# # print(1111) +# +# if hasattr(model, 'module'): +# state_dict = model.module.state_dict() +# else: +# state_dict = model.state_dict() +# new_state_dict= {} +# for k, v in state_dict.items(): +# try: +# new_state_dict[k] = saved_state_dict[k] +# except: +# logger.info("%s is not in the checkpoint" % k) +# new_state_dict[k] = v +# if hasattr(model, 'module'): +# model.module.load_state_dict(new_state_dict) +# else: +# model.load_state_dict(new_state_dict) +# logger.info("Loaded checkpoint '{}' (epoch {})" .format( +# checkpoint_path, iteration)) +# return model, optimizer, learning_rate, iteration +def load_checkpoint(checkpoint_path, model, optimizer=None, load_opt=1): + assert os.path.isfile(checkpoint_path) + checkpoint_dict = torch.load(checkpoint_path, map_location="cpu") + + saved_state_dict = checkpoint_dict["model"] + if hasattr(model, "module"): + state_dict = model.module.state_dict() + else: + state_dict = model.state_dict() + new_state_dict = {} + for k, v in state_dict.items(): # 模型需要的shape + try: + new_state_dict[k] = saved_state_dict[k] + if saved_state_dict[k].shape != state_dict[k].shape: + print( + "shape-%s-mismatch|need-%s|get-%s" + % (k, state_dict[k].shape, saved_state_dict[k].shape) + ) # + raise KeyError + except: + # logger.info(traceback.format_exc()) + logger.info("%s is not in the checkpoint" % k) # pretrain缺失的 + new_state_dict[k] = v # 模型自带的随机值 + if hasattr(model, "module"): + model.module.load_state_dict(new_state_dict, strict=False) + else: + model.load_state_dict(new_state_dict, strict=False) + logger.info("Loaded model weights") + + iteration = checkpoint_dict["iteration"] + learning_rate = checkpoint_dict["learning_rate"] + if ( + optimizer is not None and load_opt == 1 + ): ###加载不了,如果是空的的话,重新初始化,可能还会影响lr时间表的更新,因此在train文件最外围catch + # try: + optimizer.load_state_dict(checkpoint_dict["optimizer"]) + # except: + # traceback.print_exc() + logger.info("Loaded checkpoint '{}' (epoch {})".format(checkpoint_path, iteration)) + return model, optimizer, learning_rate, iteration + + +def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path): + logger.info( + "Saving model and optimizer state at epoch {} to {}".format( + iteration, checkpoint_path + ) + ) + if hasattr(model, "module"): + state_dict = model.module.state_dict() + else: + state_dict = model.state_dict() + torch.save( + { + "model": state_dict, + "iteration": iteration, + "optimizer": optimizer.state_dict(), + "learning_rate": learning_rate, + }, + checkpoint_path, + ) + + +def save_checkpoint_d(combd, sbd, optimizer, learning_rate, iteration, checkpoint_path): + logger.info( + "Saving model and optimizer state at epoch {} to {}".format( + iteration, checkpoint_path + ) + ) + if hasattr(combd, "module"): + state_dict_combd = combd.module.state_dict() + else: + state_dict_combd = combd.state_dict() + if hasattr(sbd, "module"): + state_dict_sbd = sbd.module.state_dict() + else: + state_dict_sbd = sbd.state_dict() + torch.save( + { + "combd": state_dict_combd, + "sbd": state_dict_sbd, + "iteration": iteration, + "optimizer": optimizer.state_dict(), + "learning_rate": learning_rate, + }, + checkpoint_path, + ) + + +def summarize( + writer, + global_step, + scalars={}, + histograms={}, + images={}, + audios={}, + audio_sampling_rate=22050, +): + for k, v in scalars.items(): + writer.add_scalar(k, v, global_step) + for k, v in histograms.items(): + writer.add_histogram(k, v, global_step) + for k, v in images.items(): + writer.add_image(k, v, global_step, dataformats="HWC") + for k, v in audios.items(): + writer.add_audio(k, v, global_step, audio_sampling_rate) + + +def latest_checkpoint_path(dir_path, regex="G_*.pth"): + f_list = glob.glob(os.path.join(dir_path, regex)) + f_list.sort(key=lambda f: int("".join(filter(str.isdigit, f)))) + x = f_list[-1] + print(x) + return x + + +def plot_spectrogram_to_numpy(spectrogram): + global MATPLOTLIB_FLAG + if not MATPLOTLIB_FLAG: + import matplotlib + + matplotlib.use("Agg") + MATPLOTLIB_FLAG = True + mpl_logger = logging.getLogger("matplotlib") + mpl_logger.setLevel(logging.WARNING) + import matplotlib.pylab as plt + import numpy as np + + fig, ax = plt.subplots(figsize=(10, 2)) + im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none") + plt.colorbar(im, ax=ax) + plt.xlabel("Frames") + plt.ylabel("Channels") + plt.tight_layout() + + fig.canvas.draw() + data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="") + data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) + plt.close() + return data + + +def plot_alignment_to_numpy(alignment, info=None): + global MATPLOTLIB_FLAG + if not MATPLOTLIB_FLAG: + import matplotlib + + matplotlib.use("Agg") + MATPLOTLIB_FLAG = True + mpl_logger = logging.getLogger("matplotlib") + mpl_logger.setLevel(logging.WARNING) + import matplotlib.pylab as plt + import numpy as np + + fig, ax = plt.subplots(figsize=(6, 4)) + im = ax.imshow( + alignment.transpose(), aspect="auto", origin="lower", interpolation="none" + ) + fig.colorbar(im, ax=ax) + xlabel = "Decoder timestep" + if info is not None: + xlabel += "\n\n" + info + plt.xlabel(xlabel) + plt.ylabel("Encoder timestep") + plt.tight_layout() + + fig.canvas.draw() + data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="") + data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) + plt.close() + return data + + +def load_wav_to_torch(full_path): + sampling_rate, data = read(full_path) + return torch.FloatTensor(data.astype(np.float32)), sampling_rate + + +def load_filepaths_and_text(filename, split="|"): + with open(filename, encoding="utf-8") as f: + filepaths_and_text = [line.strip().split(split) for line in f] + return filepaths_and_text + + +def get_hparams(init=True): + """ + todo: + 结尾七人组: + 保存频率、总epoch done + bs done + pretrainG、pretrainD done + 卡号:os.en["CUDA_VISIBLE_DEVICES"] done + if_latest done + 模型:if_f0 done + 采样率:自动选择config done + 是否缓存数据集进GPU:if_cache_data_in_gpu done + + -m: + 自动决定training_files路径,改掉train_nsf_load_pretrain.py里的hps.data.training_files done + -c不要了 + """ + parser = argparse.ArgumentParser() + # parser.add_argument('-c', '--config', type=str, default="configs/40k.json",help='JSON file for configuration') + parser.add_argument( + "-se", + "--save_every_epoch", + type=int, + required=True, + help="checkpoint save frequency (epoch)", + ) + parser.add_argument( + "-te", "--total_epoch", type=int, required=True, help="total_epoch" + ) + parser.add_argument( + "-pg", "--pretrainG", type=str, default="", help="Pretrained Discriminator path" + ) + parser.add_argument( + "-pd", "--pretrainD", type=str, default="", help="Pretrained Generator path" + ) + parser.add_argument("-g", "--gpus", type=str, default="0", help="split by -") + parser.add_argument( + "-bs", "--batch_size", type=int, required=True, help="batch size" + ) + parser.add_argument( + "-e", "--experiment_dir", type=str, required=True, help="experiment dir" + ) # -m + parser.add_argument( + "-sr", "--sample_rate", type=str, required=True, help="sample rate, 32k/40k/48k" + ) + parser.add_argument( + "-sw", + "--save_every_weights", + type=str, + default="0", + help="save the extracted model in weights directory when saving checkpoints", + ) + parser.add_argument( + "-v", "--version", type=str, required=True, help="model version" + ) + parser.add_argument( + "-f0", + "--if_f0", + type=int, + required=True, + help="use f0 as one of the inputs of the model, 1 or 0", + ) + parser.add_argument( + "-l", + "--if_latest", + type=int, + required=True, + help="if only save the latest G/D pth file, 1 or 0", + ) + parser.add_argument( + "-c", + "--if_cache_data_in_gpu", + type=int, + required=True, + help="if caching the dataset in GPU memory, 1 or 0", + ) + + args = parser.parse_args() + name = args.experiment_dir + experiment_dir = os.path.join("./logs", args.experiment_dir) + + if not os.path.exists(experiment_dir): + os.makedirs(experiment_dir) + + if args.version == "v1" or args.sample_rate == "40k": + config_path = "configs/%s.json" % args.sample_rate + else: + config_path = "configs/%s_v2.json" % args.sample_rate + config_save_path = os.path.join(experiment_dir, "config.json") + if init: + with open(config_path, "r") as f: + data = f.read() + with open(config_save_path, "w") as f: + f.write(data) + else: + with open(config_save_path, "r") as f: + data = f.read() + config = json.loads(data) + + hparams = HParams(**config) + hparams.model_dir = hparams.experiment_dir = experiment_dir + hparams.save_every_epoch = args.save_every_epoch + hparams.name = name + hparams.total_epoch = args.total_epoch + hparams.pretrainG = args.pretrainG + hparams.pretrainD = args.pretrainD + hparams.version = args.version + hparams.gpus = args.gpus + hparams.train.batch_size = args.batch_size + hparams.sample_rate = args.sample_rate + hparams.if_f0 = args.if_f0 + hparams.if_latest = args.if_latest + hparams.save_every_weights = args.save_every_weights + hparams.if_cache_data_in_gpu = args.if_cache_data_in_gpu + hparams.data.training_files = "%s/filelist.txt" % experiment_dir + return hparams + + +def get_hparams_from_dir(model_dir): + config_save_path = os.path.join(model_dir, "config.json") + with open(config_save_path, "r") as f: + data = f.read() + config = json.loads(data) + + hparams = HParams(**config) + hparams.model_dir = model_dir + return hparams + + +def get_hparams_from_file(config_path): + with open(config_path, "r") as f: + data = f.read() + config = json.loads(data) + + hparams = HParams(**config) + return hparams + + +def check_git_hash(model_dir): + source_dir = os.path.dirname(os.path.realpath(__file__)) + if not os.path.exists(os.path.join(source_dir, ".git")): + logger.warn( + "{} is not a git repository, therefore hash value comparison will be ignored.".format( + source_dir + ) + ) + return + + cur_hash = subprocess.getoutput("git rev-parse HEAD") + + path = os.path.join(model_dir, "githash") + if os.path.exists(path): + saved_hash = open(path).read() + if saved_hash != cur_hash: + logger.warn( + "git hash values are different. {}(saved) != {}(current)".format( + saved_hash[:8], cur_hash[:8] + ) + ) + else: + open(path, "w").write(cur_hash) + + +def get_logger(model_dir, filename="train.log"): + global logger + logger = logging.getLogger(os.path.basename(model_dir)) + logger.setLevel(logging.DEBUG) + + formatter = logging.Formatter("%(asctime)s\t%(name)s\t%(levelname)s\t%(message)s") + if not os.path.exists(model_dir): + os.makedirs(model_dir) + h = logging.FileHandler(os.path.join(model_dir, filename)) + h.setLevel(logging.DEBUG) + h.setFormatter(formatter) + logger.addHandler(h) + return logger + + +class HParams: + def __init__(self, **kwargs): + for k, v in kwargs.items(): + if type(v) == dict: + v = HParams(**v) + self[k] = v + + def keys(self): + return self.__dict__.keys() + + def items(self): + return self.__dict__.items() + + def values(self): + return self.__dict__.values() + + def __len__(self): + return len(self.__dict__) + + def __getitem__(self, key): + return getattr(self, key) + + def __setitem__(self, key, value): + return setattr(self, key, value) + + def __contains__(self, key): + return key in self.__dict__ + + def __repr__(self): + return self.__dict__.__repr__()