Fix return_complex warning on training (#1627)

* Fix return_complex warning on training

* remove unused prints
This commit is contained in:
Blaise 2023-12-22 02:35:51 +01:00 committed by GitHub
parent 0f8a5facd9
commit 78f03e7dc0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -38,7 +38,6 @@ def spectral_de_normalize_torch(magnitudes):
mel_basis = {}
hann_window = {}
def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False):
"""Convert waveform into Linear-frequency Linear-amplitude spectrogram.
@ -52,12 +51,7 @@ def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False)
Returns:
:: (B, Freq, Frame) - Linear-frequency Linear-amplitude spectrogram
"""
# Validation
if torch.min(y) < -1.07:
logger.debug("min value is %s", str(torch.min(y)))
if torch.max(y) > 1.07:
logger.debug("max value is %s", str(torch.max(y)))
# Window - Cache if needed
global hann_window
dtype_device = str(y.dtype) + "_" + str(y.device)
@ -66,7 +60,7 @@ def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False)
hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(
dtype=y.dtype, device=y.device
)
# Padding
y = torch.nn.functional.pad(
y.unsqueeze(1),
@ -74,7 +68,7 @@ def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False)
mode="reflect",
)
y = y.squeeze(1)
# Complex Spectrogram :: (B, T) -> (B, Freq, Frame, RealComplex=2)
spec = torch.stft(
y,
@ -86,14 +80,13 @@ def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False)
pad_mode="reflect",
normalized=False,
onesided=True,
return_complex=False,
return_complex=True,
)
# Linear-frequency Linear-amplitude spectrogram :: (B, Freq, Frame, RealComplex=2) -> (B, Freq, Frame)
spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
spec = torch.sqrt(spec.real.pow(2) + spec.imag.pow(2) + 1e-6)
return spec
def spec_to_mel_torch(spec, n_fft, num_mels, sampling_rate, fmin, fmax):
# MelBasis - Cache if needed
global mel_basis