From 1457169e7a912ac3e00b105090764a574caf71f5 Mon Sep 17 00:00:00 2001 From: yxlllc <33565655+yxlllc@users.noreply.github.com> Date: Sun, 3 Sep 2023 13:57:31 +0800 Subject: [PATCH] Update real-time gui (#1174) * loudness factor control and gpu-accelerated noise reduction * loudness factor control and gpu-accelerated noise reduction * loudness factor control and gpu-accelerated noise reduction --- configs/config.json | 13 +- gui_v1.py | 132 ++++++++++-------- i18n/locale/en_US.json | 1 + i18n/locale/es_ES.json | 1 + i18n/locale/it_IT.json | 1 + i18n/locale/ja_JP.json | 1 + i18n/locale/ru_RU.json | 1 + i18n/locale/tr_TR.json | 1 + i18n/locale/zh_CN.json | 1 + i18n/locale/zh_HK.json | 1 + i18n/locale/zh_SG.json | 1 + i18n/locale/zh_TW.json | 1 + tools/torchgate/__init__.py | 12 ++ tools/torchgate/torchgate.py | 264 +++++++++++++++++++++++++++++++++++ tools/torchgate/utils.py | 66 +++++++++ 15 files changed, 434 insertions(+), 63 deletions(-) create mode 100644 tools/torchgate/__init__.py create mode 100644 tools/torchgate/torchgate.py create mode 100644 tools/torchgate/utils.py diff --git a/configs/config.json b/configs/config.json index d8a6936..8e9c176 100644 --- a/configs/config.json +++ b/configs/config.json @@ -4,11 +4,12 @@ "sg_input_device": "VoiceMeeter Output (VB-Audio Vo (MME)", "sg_output_device": "VoiceMeeter Aux Input (VB-Audio (MME)", "threhold": -45.0, - "pitch": 0.0, - "index_rate": 1.0, - "block_time": 0.09, - "crossfade_length": 0.15, - "extra_time": 5.0, - "n_cpu": 8.0, + "pitch": 12.0, + "index_rate": 0.0, + "rms_mix_rate": 0.0, + "block_time": 0.25, + "crossfade_length": 0.04, + "extra_time": 2.0, + "n_cpu": 6.0, "f0method": "rmvpe" } diff --git a/gui_v1.py b/gui_v1.py index e5d39ff..7e0ad28 100644 --- a/gui_v1.py +++ b/gui_v1.py @@ -51,7 +51,7 @@ if __name__ == "__main__": from queue import Empty import librosa - import noisereduce as nr + from tools.torchgate import TorchGate import numpy as np import PySimpleGUI as sg import sounddevice as sd 
@@ -80,15 +80,16 @@ if __name__ == "__main__": def __init__(self) -> None: self.pth_path: str = "" self.index_path: str = "" - self.pitch: int = 12 + self.pitch: int = 0 self.samplerate: int = 40000 self.block_time: float = 1.0 # s self.buffer_num: int = 1 - self.threhold: int = -30 - self.crossfade_time: float = 0.08 - self.extra_time: float = 0.04 + self.threhold: int = -60 + self.crossfade_time: float = 0.04 + self.extra_time: float = 2.0 self.I_noise_reduce = False self.O_noise_reduce = False + self.rms_mix_rate = 0.0 self.index_rate = 0.3 self.n_cpu = min(n_cpu, 6) self.f0method = "harvest" @@ -118,14 +119,19 @@ if __name__ == "__main__": "index_path": " ", "sg_input_device": input_devices[sd.default.device[0]], "sg_output_device": output_devices[sd.default.device[1]], - "threhold": "-45", + "threhold": "-60", "pitch": "0", "index_rate": "0", - "block_time": "1", + "rms_mix_rate": "0", + "block_time": "0.25", "crossfade_length": "0.04", - "extra_time": "1", + "extra_time": "2", "f0method": "rmvpe", } + data["pm"] = data["f0method"] == "pm" + data["harvest"] = data["f0method"] == "harvest" + data["crepe"] = data["f0method"] == "crepe" + data["rmvpe"] = data["f0method"] == "rmvpe" return data def launcher(self): @@ -198,7 +204,7 @@ if __name__ == "__main__": key="threhold", resolution=1, orientation="h", - default_value=data.get("threhold", ""), + default_value=data.get("threhold", "-60"), enable_events=True, ), ], @@ -209,7 +215,7 @@ if __name__ == "__main__": key="pitch", resolution=1, orientation="h", - default_value=data.get("pitch", ""), + default_value=data.get("pitch", "0"), enable_events=True, ), ], @@ -220,7 +226,18 @@ if __name__ == "__main__": key="index_rate", resolution=0.01, orientation="h", - default_value=data.get("index_rate", ""), + default_value=data.get("index_rate", "0"), + enable_events=True, + ), + ], + [ + sg.Text(i18n("响度因子")), + sg.Slider( + range=(0.0, 1.0), + key="rms_mix_rate", + resolution=0.01, + orientation="h", + 
default_value=data.get("rms_mix_rate", "0"), enable_events=True, ), ], @@ -267,7 +284,7 @@ if __name__ == "__main__": key="block_time", resolution=0.01, orientation="h", - default_value=data.get("block_time", ""), + default_value=data.get("block_time", "0.25"), enable_events=True, ), ], @@ -291,7 +308,7 @@ if __name__ == "__main__": key="crossfade_length", resolution=0.01, orientation="h", - default_value=data.get("crossfade_length", ""), + default_value=data.get("crossfade_length", "0.04"), enable_events=True, ), ], @@ -302,7 +319,7 @@ if __name__ == "__main__": key="extra_time", resolution=0.01, orientation="h", - default_value=data.get("extra_time", ""), + default_value=data.get("extra_time", "2.0"), enable_events=True, ), ], @@ -369,6 +386,7 @@ if __name__ == "__main__": "sg_output_device": values["sg_output_device"], "threhold": values["threhold"], "pitch": values["pitch"], + "rms_mix_rate": values["rms_mix_rate"], "index_rate": values["index_rate"], "block_time": values["block_time"], "crossfade_length": values["crossfade_length"], @@ -399,6 +417,8 @@ if __name__ == "__main__": self.config.index_rate = values["index_rate"] if hasattr(self, "rvc"): self.rvc.change_index_rate(values["index_rate"]) + elif event == "rms_mix_rate": + self.config.rms_mix_rate = values["rms_mix_rate"] elif event in ["pm", "harvest", "crepe", "rmvpe"]: self.config.f0method = event elif event == "I_noise_reduce": @@ -433,6 +453,7 @@ if __name__ == "__main__": self.config.extra_time = values["extra_time"] self.config.I_noise_reduce = values["I_noise_reduce"] self.config.O_noise_reduce = values["O_noise_reduce"] + self.config.rms_mix_rate = values["rms_mix_rate"] self.config.index_rate = values["index_rate"] self.config.n_cpu = values["n_cpu"] self.config.f0method = ["pm", "harvest", "crepe", "rmvpe"][ @@ -457,17 +478,14 @@ if __name__ == "__main__": inp_q, opt_q, device, - self.rvc if hasattr(self, "rvc") else None, + self.rvc if hasattr(self, "rvc") else None ) self.config.samplerate 
= self.rvc.tgt_sr self.config.crossfade_time = min( self.config.crossfade_time, self.config.block_time ) self.zc = self.rvc.tgt_sr // 100 - self.block_frame = ( - int(np.round(self.config.block_time * self.config.samplerate / self.zc)) - * self.zc - ) + self.block_frame = int(np.round(self.config.block_time * self.config.samplerate / self.zc)) * self.zc self.block_frame_16k = 160 * self.block_frame // self.zc self.crossfade_frame = int( self.config.crossfade_time * self.config.samplerate @@ -489,9 +507,7 @@ if __name__ == "__main__": ), dtype="float32", ) - self.input_wav_res: torch.Tensor = torch.zeros( - 160 * len(self.input_wav) // self.zc - ) + self.input_wav_res: torch.Tensor= torch.zeros(160 * len(self.input_wav) // self.zc, device=device,dtype=torch.float32) self.output_wav_cache: torch.Tensor = torch.zeros( int( np.ceil( @@ -540,6 +556,8 @@ if __name__ == "__main__": self.resampler = tat.Resample( orig_freq=self.config.samplerate, new_freq=16000, dtype=torch.float32 ).to(device) + self.input_tg = TorchGate(sr=16000, nonstationary=True, n_fft=640).to(device) + self.output_tg = TorchGate(sr=self.config.samplerate, nonstationary=True, n_fft=4*self.zc).to(device) thread_vc = threading.Thread(target=self.soundinput) thread_vc.start() @@ -568,9 +586,6 @@ if __name__ == "__main__": """ start_time = time.perf_counter() indata = librosa.to_mono(indata.T) - if self.config.I_noise_reduce: - indata[:] = nr.reduce_noise(y=indata, sr=self.config.samplerate) - """noise gate""" frame_length = 2048 hop_length = 1024 rms = librosa.feature.rms( @@ -584,18 +599,13 @@ if __name__ == "__main__": if db_threhold[i]: indata[i * hop_length : (i + 1) * hop_length] = 0 self.input_wav[: -self.block_frame] = self.input_wav[self.block_frame :] - self.input_wav[-self.block_frame :] = indata - + self.input_wav[-self.block_frame: ] = indata # infer - inp = torch.from_numpy( - self.input_wav[-self.block_frame - 2 * self.zc :] - ).to(device) - self.input_wav_res[: -self.block_frame_16k] = 
self.input_wav_res[ - self.block_frame_16k : - ].clone() - self.input_wav_res[-self.block_frame_16k - 160 :] = self.resampler(inp)[ - 160: - ] + inp = torch.from_numpy(self.input_wav[-self.block_frame-2*self.zc :]).to(device) + self.input_wav_res[ : -self.block_frame_16k] = self.input_wav_res[self.block_frame_16k :].clone() + self.input_wav_res[-self.block_frame_16k-160 :] = self.resampler(inp)[160 :] + if self.config.I_noise_reduce: + self.input_wav_res[-self.block_frame_16k-320 :] = self.input_tg(self.input_wav_res[None, -self.block_frame_16k-800 :])[0, 480 : ] rate = ( self.crossfade_frame + self.sola_search_frame + self.block_frame ) / ( @@ -605,11 +615,11 @@ if __name__ == "__main__": + self.block_frame ) f0_extractor_frame = self.block_frame_16k + 800 - if self.config.f0method == "rmvpe": + if self.config.f0method == 'rmvpe': f0_extractor_frame = 5120 * ((f0_extractor_frame - 1) // 5120 + 1) res2 = self.rvc.infer( self.input_wav_res, - self.input_wav_res[-f0_extractor_frame:].cpu().numpy(), + self.input_wav_res[-f0_extractor_frame :].cpu().numpy(), self.block_frame_16k, rate, self.pitch, @@ -620,6 +630,27 @@ if __name__ == "__main__": infer_wav = self.output_wav_cache[ -self.crossfade_frame - self.sola_search_frame - self.block_frame : ] + if self.config.O_noise_reduce: + infer_wav = self.output_tg(infer_wav.unsqueeze(0)).squeeze(0) + if self.config.rms_mix_rate < 1: + rms1 = librosa.feature.rms( + y=self.input_wav[-self.crossfade_frame - self.sola_search_frame - self.block_frame :], + frame_length=frame_length, + hop_length=hop_length + ) + rms1 = torch.from_numpy(rms1).to(device) + rms1 = F.interpolate( + rms1.unsqueeze(0), size=infer_wav.shape[0], mode="linear" + ).squeeze() + rms2 = librosa.feature.rms( + y=infer_wav[:].cpu().numpy(), frame_length=frame_length, hop_length=hop_length + ) + rms2 = torch.from_numpy(rms2).to(device) + rms2 = F.interpolate( + rms2.unsqueeze(0), size=infer_wav.shape[0], mode="linear" + ).squeeze() + rms2 = torch.max(rms2, 
torch.zeros_like(rms2) + 1e-3) + infer_wav *= torch.pow(rms1 / rms2, torch.tensor(1 - self.config.rms_mix_rate)) # SOLA algorithm from https://github.com/yxlllc/DDSP-SVC cor_nom = F.conv1d( infer_wav[None, None, : self.crossfade_frame + self.sola_search_frame], @@ -659,25 +690,10 @@ if __name__ == "__main__": self.sola_buffer[:] = ( infer_wav[-self.crossfade_frame :] * self.fade_out_window ) - if self.config.O_noise_reduce: - if sys.platform == "darwin": - noise_reduced_signal = nr.reduce_noise( - y=self.output_wav[:].cpu().numpy(), sr=self.config.samplerate - ) - outdata[:] = noise_reduced_signal[:, np.newaxis] - else: - outdata[:] = np.tile( - nr.reduce_noise( - y=self.output_wav[:].cpu().numpy(), - sr=self.config.samplerate, - ), - (2, 1), - ).T + if sys.platform == "darwin": + outdata[:] = self.output_wav[:].cpu().numpy()[:, np.newaxis] else: - if sys.platform == "darwin": - outdata[:] = self.output_wav[:].cpu().numpy()[:, np.newaxis] - else: - outdata[:] = self.output_wav[:].repeat(2, 1).t().cpu().numpy() + outdata[:] = self.output_wav[:].repeat(2, 1).t().cpu().numpy() total_time = time.perf_counter() - start_time self.window["infer_time"].update(int(total_time * 1000)) logger.info("Infer time: %.2f", total_time) @@ -733,7 +749,9 @@ if __name__ == "__main__": sd.default.device[1] = output_device_indices[ output_devices.index(output_device) ] - logger.info("Input device: %s:%d", str(sd.default.device[0]), input_device) + logger.info( + "Input device: %s:%d", str(sd.default.device[0]), input_device + ) logger.info( "Output device: %s:%d", str(sd.default.device[1]), output_device ) diff --git a/i18n/locale/en_US.json b/i18n/locale/en_US.json index bd0f19c..c862a9d 100644 --- a/i18n/locale/en_US.json +++ b/i18n/locale/en_US.json @@ -43,6 +43,7 @@ "后处理重采样至最终采样率,0为不进行重采样": "Resample the output audio in post-processing to the final sample rate. 
Set to 0 for no resampling:", "否": "No", "响应阈值": "Response threshold", + "响度因子": "loudness factor", "处理数据": "Process data", "导出Onnx模型": "Export Onnx Model", "导出文件格式": "Export file format", diff --git a/i18n/locale/es_ES.json b/i18n/locale/es_ES.json index 3fcdff4..6341c75 100644 --- a/i18n/locale/es_ES.json +++ b/i18n/locale/es_ES.json @@ -43,6 +43,7 @@ "后处理重采样至最终采样率,0为不进行重采样": "Remuestreo posterior al proceso a la tasa de muestreo final, 0 significa no remuestrear", "否": "No", "响应阈值": "Umbral de respuesta", + "响度因子": "factor de sonoridad", "处理数据": "Procesar datos", "导出Onnx模型": "Exportar modelo Onnx", "导出文件格式": "Formato de archivo de exportación", diff --git a/i18n/locale/it_IT.json b/i18n/locale/it_IT.json index 56dcd33..a77e82a 100644 --- a/i18n/locale/it_IT.json +++ b/i18n/locale/it_IT.json @@ -43,6 +43,7 @@ "后处理重采样至最终采样率,0为不进行重采样": "Ricampiona l'audio di output in post-elaborazione alla frequenza di campionamento finale. ", "否": "NO", "响应阈值": "Soglia di risposta", + "响度因子": "fattore di sonorità", "处理数据": "Processa dati", "导出Onnx模型": "Esporta modello Onnx", "导出文件格式": "Formato file di esportazione", diff --git a/i18n/locale/ja_JP.json b/i18n/locale/ja_JP.json index ac14826..d33472f 100644 --- a/i18n/locale/ja_JP.json +++ b/i18n/locale/ja_JP.json @@ -43,6 +43,7 @@ "后处理重采样至最终采样率,0为不进行重采样": "最終的なサンプリングレートへのポストプロセッシングのリサンプリング リサンプリングしない場合は0", "否": "いいえ", "响应阈值": "反応閾値", + "响度因子": "ラウドネス係数", "处理数据": "データ処理", "导出Onnx模型": "Onnxに変換", "导出文件格式": "エクスポート形式", diff --git a/i18n/locale/ru_RU.json b/i18n/locale/ru_RU.json index ec7c949..46af1f0 100644 --- a/i18n/locale/ru_RU.json +++ b/i18n/locale/ru_RU.json @@ -43,6 +43,7 @@ "后处理重采样至最终采样率,0为不进行重采样": "Изменить частоту дискретизации в выходном файле на финальную. 
Поставьте 0, чтобы ничего не изменялось:", "否": "Нет", "响应阈值": "Порог ответа", + "响度因子": "коэффициент громкости", "处理数据": "Обработать данные", "导出Onnx模型": "Экспортировать модель", "导出文件格式": "Формат выходных файлов", diff --git a/i18n/locale/tr_TR.json b/i18n/locale/tr_TR.json index 8ebc306..566e999 100644 --- a/i18n/locale/tr_TR.json +++ b/i18n/locale/tr_TR.json @@ -43,6 +43,7 @@ "后处理重采样至最终采样率,0为不进行重采样": "Son işleme aşamasında çıktı sesini son örnekleme hızına yeniden örnekle. 0 değeri için yeniden örnekleme yapılmaz:", "否": "Hayır", "响应阈值": "Tepki eşiği", + "响度因子": "ses yüksekliği faktörü", "处理数据": "Verileri işle", "导出Onnx模型": "Onnx Modeli Dışa Aktar", "导出文件格式": "Dışa aktarma dosya formatı", diff --git a/i18n/locale/zh_CN.json b/i18n/locale/zh_CN.json index be0b318..a65cc47 100644 --- a/i18n/locale/zh_CN.json +++ b/i18n/locale/zh_CN.json @@ -43,6 +43,7 @@ "后处理重采样至最终采样率,0为不进行重采样": "后处理重采样至最终采样率,0为不进行重采样", "否": "否", "响应阈值": "响应阈值", + "响度因子": "响度因子", "处理数据": "处理数据", "导出Onnx模型": "导出Onnx模型", "导出文件格式": "导出文件格式", diff --git a/i18n/locale/zh_HK.json b/i18n/locale/zh_HK.json index 667a1c3..47ed97c 100644 --- a/i18n/locale/zh_HK.json +++ b/i18n/locale/zh_HK.json @@ -43,6 +43,7 @@ "后处理重采样至最终采样率,0为不进行重采样": "後處理重採樣至最終採樣率,0為不進行重採樣", "否": "否", "响应阈值": "響應閾值", + "响度因子": "響度因子", "处理数据": "處理資料", "导出Onnx模型": "导出Onnx模型", "导出文件格式": "導出檔格式", diff --git a/i18n/locale/zh_SG.json b/i18n/locale/zh_SG.json index 667a1c3..47ed97c 100644 --- a/i18n/locale/zh_SG.json +++ b/i18n/locale/zh_SG.json @@ -43,6 +43,7 @@ "后处理重采样至最终采样率,0为不进行重采样": "後處理重採樣至最終採樣率,0為不進行重採樣", "否": "否", "响应阈值": "響應閾值", + "响度因子": "響度因子", "处理数据": "處理資料", "导出Onnx模型": "导出Onnx模型", "导出文件格式": "導出檔格式", diff --git a/i18n/locale/zh_TW.json b/i18n/locale/zh_TW.json index 667a1c3..47ed97c 100644 --- a/i18n/locale/zh_TW.json +++ b/i18n/locale/zh_TW.json @@ -43,6 +43,7 @@ "后处理重采样至最终采样率,0为不进行重采样": "後處理重採樣至最終採樣率,0為不進行重採樣", "否": "否", "响应阈值": "響應閾值", + "响度因子": "響度因子", "处理数据": "處理資料", "导出Onnx模型": "导出Onnx模型", "导出文件格式": "導出檔格式", diff --git 
a/tools/torchgate/__init__.py b/tools/torchgate/__init__.py new file mode 100644 index 0000000..b4a1267 --- /dev/null +++ b/tools/torchgate/__init__.py @@ -0,0 +1,12 @@ +""" +TorchGating is a PyTorch-based implementation of Spectral Gating +================================================ +Author: Asaf Zorea + +Contents +-------- +torchgate imports all the functions from PyTorch, and in addition provides: + TorchGating --- A PyTorch module that applies a spectral gate to an input signal + +""" +from .torchgate import TorchGate diff --git a/tools/torchgate/torchgate.py b/tools/torchgate/torchgate.py new file mode 100644 index 0000000..086f2ab --- /dev/null +++ b/tools/torchgate/torchgate.py @@ -0,0 +1,264 @@ +import torch +from torch.nn.functional import conv1d, conv2d +from typing import Union, Optional +from .utils import linspace, temperature_sigmoid, amp_to_db + + +class TorchGate(torch.nn.Module): + """ + A PyTorch module that applies a spectral gate to an input signal. + + Arguments: + sr {int} -- Sample rate of the input signal. + nonstationary {bool} -- Whether to use non-stationary or stationary masking (default: {False}). + n_std_thresh_stationary {float} -- Number of standard deviations above mean to threshold noise for + stationary masking (default: {1.5}). + n_thresh_nonstationary {float} -- Number of multiplies above smoothed magnitude spectrogram. for + non-stationary masking (default: {1.3}). + temp_coeff_nonstationary {float} -- Temperature coefficient for non-stationary masking (default: {0.1}). + n_movemean_nonstationary {int} -- Number of samples for moving average smoothing in non-stationary masking + (default: {20}). + prop_decrease {float} -- Proportion to decrease signal by where the mask is zero (default: {1.0}). + n_fft {int} -- Size of FFT for STFT (default: {1024}). + win_length {[int]} -- Window length for STFT. If None, defaults to `n_fft` (default: {None}). + hop_length {[int]} -- Hop length for STFT. 
If None, defaults to `win_length` // 4 (default: {None}). + freq_mask_smooth_hz {float} -- Frequency smoothing width for mask (in Hz). If None, no smoothing is applied + (default: {500}). + time_mask_smooth_ms {float} -- Time smoothing width for mask (in ms). If None, no smoothing is applied + (default: {50}). + """ + + @torch.no_grad() + def __init__( + self, + sr: int, + nonstationary: bool = False, + n_std_thresh_stationary: float = 1.5, + n_thresh_nonstationary: float = 1.3, + temp_coeff_nonstationary: float = 0.1, + n_movemean_nonstationary: int = 20, + prop_decrease: float = 1.0, + n_fft: int = 1024, + win_length: bool = None, + hop_length: int = None, + freq_mask_smooth_hz: float = 500, + time_mask_smooth_ms: float = 50, + ): + super().__init__() + + # General Params + self.sr = sr + self.nonstationary = nonstationary + assert 0.0 <= prop_decrease <= 1.0 + self.prop_decrease = prop_decrease + + # STFT Params + self.n_fft = n_fft + self.win_length = self.n_fft if win_length is None else win_length + self.hop_length = self.win_length // 4 if hop_length is None else hop_length + + # Stationary Params + self.n_std_thresh_stationary = n_std_thresh_stationary + + # Non-Stationary Params + self.temp_coeff_nonstationary = temp_coeff_nonstationary + self.n_movemean_nonstationary = n_movemean_nonstationary + self.n_thresh_nonstationary = n_thresh_nonstationary + + # Smooth Mask Params + self.freq_mask_smooth_hz = freq_mask_smooth_hz + self.time_mask_smooth_ms = time_mask_smooth_ms + self.register_buffer("smoothing_filter", self._generate_mask_smoothing_filter()) + + @torch.no_grad() + def _generate_mask_smoothing_filter(self) -> Union[torch.Tensor, None]: + """ + Build the normalized 2D kernel used to smooth the signal mask across frequency and time.
+ + Returns: + smoothing_filter (torch.Tensor): a 2D tensor representing the smoothing filter, + with shape (n_grad_freq, n_grad_time), where n_grad_freq is the number of frequency + bins to smooth and n_grad_time is the number of time frames to smooth. + If both self.freq_mask_smooth_hz and self.time_mask_smooth_ms are None, returns None. + """ + if self.freq_mask_smooth_hz is None and self.time_mask_smooth_ms is None: + return None + + n_grad_freq = ( + 1 + if self.freq_mask_smooth_hz is None + else int(self.freq_mask_smooth_hz / (self.sr / (self.n_fft / 2))) + ) + if n_grad_freq < 1: + raise ValueError( + f"freq_mask_smooth_hz needs to be at least {int((self.sr / (self.n_fft / 2)))} Hz" + ) + + n_grad_time = ( + 1 + if self.time_mask_smooth_ms is None + else int(self.time_mask_smooth_ms / ((self.hop_length / self.sr) * 1000)) + ) + if n_grad_time < 1: + raise ValueError( + f"time_mask_smooth_ms needs to be at least {int((self.hop_length / self.sr) * 1000)} ms" + ) + + if n_grad_time == 1 and n_grad_freq == 1: + return None + + v_f = torch.cat( + [ + linspace(0, 1, n_grad_freq + 1, endpoint=False), + linspace(1, 0, n_grad_freq + 2), + ] + )[1:-1] + v_t = torch.cat( + [ + linspace(0, 1, n_grad_time + 1, endpoint=False), + linspace(1, 0, n_grad_time + 2), + ] + )[1:-1] + smoothing_filter = torch.outer(v_f, v_t).unsqueeze(0).unsqueeze(0) + + return smoothing_filter / smoothing_filter.sum() + + @torch.no_grad() + def _stationary_mask( + self, X_db: torch.Tensor, xn: Optional[torch.Tensor] = None + ) -> torch.Tensor: + """ + Computes a stationary binary mask to filter out noise in a log-magnitude spectrogram. + + Arguments: + X_db (torch.Tensor): 2D tensor of shape (frames, freq_bins) containing the log-magnitude spectrogram. + xn (torch.Tensor): 1D tensor containing the audio signal corresponding to X_db. + + Returns: + sig_mask (torch.Tensor): Binary mask of the same shape as X_db, where values greater than the threshold + are set to 1, and the rest are set to 0.
+ """ + if xn is not None: + XN = torch.stft( + xn, + n_fft=self.n_fft, + hop_length=self.hop_length, + win_length=self.win_length, + return_complex=True, + pad_mode="constant", + center=True, + window=torch.hann_window(self.win_length).to(xn.device), + ) + + XN_db = amp_to_db(XN).to(dtype=X_db.dtype) + else: + XN_db = X_db + + # calculate mean and standard deviation along the frequency axis + std_freq_noise, mean_freq_noise = torch.std_mean(XN_db, dim=-1) + + # compute noise threshold + noise_thresh = mean_freq_noise + std_freq_noise * self.n_std_thresh_stationary + + # create binary mask by thresholding the spectrogram + sig_mask = X_db > noise_thresh.unsqueeze(2) + return sig_mask + + @torch.no_grad() + def _nonstationary_mask(self, X_abs: torch.Tensor) -> torch.Tensor: + """ + Computes a non-stationary binary mask to filter out noise in a log-magnitude spectrogram. + + Arguments: + X_abs (torch.Tensor): 2D tensor of shape (frames, freq_bins) containing the magnitude spectrogram. + + Returns: + sig_mask (torch.Tensor): Binary mask of the same shape as X_abs, where values greater than the threshold + are set to 1, and the rest are set to 0. + """ + X_smoothed = ( + conv1d( + X_abs.reshape(-1, 1, X_abs.shape[-1]), + torch.ones( + self.n_movemean_nonstationary, + dtype=X_abs.dtype, + device=X_abs.device, + ).view(1, 1, -1), + padding="same", + ).view(X_abs.shape) + / self.n_movemean_nonstationary + ) + + # Compute slowness ratio and apply temperature sigmoid + slowness_ratio = (X_abs - X_smoothed) / (X_smoothed + 1e-6) + sig_mask = temperature_sigmoid( + slowness_ratio, self.n_thresh_nonstationary, self.temp_coeff_nonstationary + ) + + return sig_mask + + def forward( + self, x: torch.Tensor, xn: Optional[torch.Tensor] = None + ) -> torch.Tensor: + """ + Apply the proposed algorithm to the input signal. + + Arguments: + x (torch.Tensor): The input audio signal, with shape (batch_size, signal_length). 
+ xn (Optional[torch.Tensor]): The noise signal used for stationary noise reduction. If `None`, the input + signal is used as the noise signal. Default: `None`. + + Returns: + torch.Tensor: The denoised audio signal, with the same shape as the input signal. + """ + assert x.ndim == 2 + if x.shape[-1] < self.win_length * 2: + raise Exception(f"x must be bigger than {self.win_length * 2}") + + assert xn is None or xn.ndim == 1 or xn.ndim == 2 + if xn is not None and xn.shape[-1] < self.win_length * 2: + raise Exception(f"xn must be bigger than {self.win_length * 2}") + + # Compute short-time Fourier transform (STFT) + X = torch.stft( + x, + n_fft=self.n_fft, + hop_length=self.hop_length, + win_length=self.win_length, + return_complex=True, + pad_mode="constant", + center=True, + window=torch.hann_window(self.win_length).to(x.device), + ) + + # Compute signal mask based on stationary or nonstationary assumptions + if self.nonstationary: + sig_mask = self._nonstationary_mask(X.abs()) + else: + sig_mask = self._stationary_mask(amp_to_db(X), xn) + + # Propagate decrease in signal power + sig_mask = self.prop_decrease * (sig_mask * 1.0 - 1.0) + 1.0 + + # Smooth signal mask with 2D convolution + if self.smoothing_filter is not None: + sig_mask = conv2d( + sig_mask.unsqueeze(1), + self.smoothing_filter.to(sig_mask.dtype), + padding="same", + ) + + # Apply signal mask to STFT magnitude and phase components + Y = X * sig_mask.squeeze(1) + + # Inverse STFT to obtain time-domain signal + y = torch.istft( + Y, + n_fft=self.n_fft, + hop_length=self.hop_length, + win_length=self.win_length, + center=True, + window=torch.hann_window(self.win_length).to(Y.device), + ) + + return y.to(dtype=x.dtype) diff --git a/tools/torchgate/utils.py b/tools/torchgate/utils.py new file mode 100644 index 0000000..dc97d45 --- /dev/null +++ b/tools/torchgate/utils.py @@ -0,0 +1,66 @@ +import torch +from torch.types import Number + + +@torch.no_grad() +def amp_to_db(x: torch.Tensor, 
eps=torch.finfo(torch.float64).eps, top_db=40) -> torch.Tensor: + """ + Convert the input tensor from amplitude to decibel scale. + + Arguments: + x {[torch.Tensor]} -- [Input tensor.] + + Keyword Arguments: + eps {[float]} -- [Small value to avoid numerical instability.] + (default: {torch.finfo(torch.float64).eps}) + top_db {[float]} -- [threshold the output at ``top_db`` below the peak] + (default: {40}) + + Returns: + [torch.Tensor] -- [Output tensor in decibel scale.] + """ + x_db = 20 * torch.log10(x.abs() + eps) + return torch.max(x_db, (x_db.max(-1).values - top_db).unsqueeze(-1)) + + +@torch.no_grad() +def temperature_sigmoid(x: torch.Tensor, x0: float, temp_coeff: float) -> torch.Tensor: + """ + Apply a sigmoid function with temperature scaling. + + Arguments: + x {[torch.Tensor]} -- [Input tensor.] + x0 {[float]} -- [Parameter that controls the threshold of the sigmoid.] + temp_coeff {[float]} -- [Parameter that controls the slope of the sigmoid.] + + Returns: + [torch.Tensor] -- [Output tensor after applying the sigmoid with temperature scaling.] + """ + return torch.sigmoid((x - x0) / temp_coeff) + + +@torch.no_grad() +def linspace(start: Number, stop: Number, num: int = 50, endpoint: bool = True, **kwargs) -> torch.Tensor: + """ + Generate a linearly spaced 1-D tensor. + + Arguments: + start {[Number]} -- [The starting value of the sequence.] + stop {[Number]} -- [The end value of the sequence, unless `endpoint` is set to False. + In that case, the sequence consists of all but the last of ``num + 1`` + evenly spaced samples, so that `stop` is excluded. Note that the step + size changes when `endpoint` is False.] + + Keyword Arguments: + num {[int]} -- [Number of samples to generate. Default is 50. Must be non-negative.] + endpoint {[bool]} -- [If True, `stop` is the last sample. Otherwise, it is not included. + Default is True.] + **kwargs -- [Additional arguments to be passed to the underlying PyTorch `linspace` function.]
+ + Returns: + [torch.Tensor] -- [1-D tensor of `num` equally spaced samples from `start` to `stop`.] + """ + if endpoint: + return torch.linspace(start, stop, num, **kwargs) + else: + return torch.linspace(start, stop, num + 1, **kwargs)[:-1]