From 78f03e7dc0563e438307ff62c76a062b46083ec4 Mon Sep 17 00:00:00 2001 From: Blaise <133521603+blaise-tk@users.noreply.github.com> Date: Fri, 22 Dec 2023 02:35:51 +0100 Subject: [PATCH] Fix return_complex warning on training (#1627) * Fix return_complex warning on training * remove unused prints --- infer/lib/train/mel_processing.py | 19 ++++++------------- 1 file changed, 6 insertions(+), 13 deletions(-) diff --git a/infer/lib/train/mel_processing.py b/infer/lib/train/mel_processing.py index 04a11f1..14a960f 100644 --- a/infer/lib/train/mel_processing.py +++ b/infer/lib/train/mel_processing.py @@ -38,7 +38,6 @@ def spectral_de_normalize_torch(magnitudes): mel_basis = {} hann_window = {} - def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False): """Convert waveform into Linear-frequency Linear-amplitude spectrogram. @@ -52,12 +51,7 @@ def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False) Returns: :: (B, Freq, Frame) - Linear-frequency Linear-amplitude spectrogram """ - # Validation - if torch.min(y) < -1.07: - logger.debug("min value is %s", str(torch.min(y))) - if torch.max(y) > 1.07: - logger.debug("max value is %s", str(torch.max(y))) - + # Window - Cache if needed global hann_window dtype_device = str(y.dtype) + "_" + str(y.device) @@ -66,7 +60,7 @@ def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False) hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to( dtype=y.dtype, device=y.device ) - + # Padding y = torch.nn.functional.pad( y.unsqueeze(1), @@ -74,7 +68,7 @@ def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False) mode="reflect", ) y = y.squeeze(1) - + # Complex Spectrogram :: (B, T) -> (B, Freq, Frame, RealComplex=2) spec = torch.stft( y, @@ -86,14 +80,13 @@ def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False) pad_mode="reflect", normalized=False, onesided=True, - return_complex=False, + return_complex=True, ) - + # Linear-frequency Linear-amplitude spectrogram :: (B, Freq, Frame, RealComplex=2) -> (B, Freq, Frame) - spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) + spec = torch.sqrt(spec.real.pow(2) + spec.imag.pow(2) + 1e-6) return spec - def spec_to_mel_torch(spec, n_fft, num_mels, sampling_rate, fmin, fmax): # MelBasis - Cache if needed global mel_basis