Update real-time gui (#1174)

* loudness factor control and gpu-accelerated noise reduction

* loudness factor control and gpu-accelerated noise reduction

* loudness factor control and gpu-accelerated noise reduction
This commit is contained in:
yxlllc 2023-09-03 13:57:31 +08:00 committed by GitHub
parent b5050fbf0d
commit 1457169e7a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
15 changed files with 434 additions and 63 deletions

View File

@ -4,11 +4,12 @@
"sg_input_device": "VoiceMeeter Output (VB-Audio Vo (MME)", "sg_input_device": "VoiceMeeter Output (VB-Audio Vo (MME)",
"sg_output_device": "VoiceMeeter Aux Input (VB-Audio (MME)", "sg_output_device": "VoiceMeeter Aux Input (VB-Audio (MME)",
"threhold": -45.0, "threhold": -45.0,
"pitch": 0.0, "pitch": 12.0,
"index_rate": 1.0, "index_rate": 0.0,
"block_time": 0.09, "rms_mix_rate": 0.0,
"crossfade_length": 0.15, "block_time": 0.25,
"extra_time": 5.0, "crossfade_length": 0.04,
"n_cpu": 8.0, "extra_time": 2.0,
"n_cpu": 6.0,
"f0method": "rmvpe" "f0method": "rmvpe"
} }

132
gui_v1.py
View File

@ -51,7 +51,7 @@ if __name__ == "__main__":
from queue import Empty from queue import Empty
import librosa import librosa
import noisereduce as nr from tools.torchgate import TorchGate
import numpy as np import numpy as np
import PySimpleGUI as sg import PySimpleGUI as sg
import sounddevice as sd import sounddevice as sd
@ -80,15 +80,16 @@ if __name__ == "__main__":
def __init__(self) -> None: def __init__(self) -> None:
self.pth_path: str = "" self.pth_path: str = ""
self.index_path: str = "" self.index_path: str = ""
self.pitch: int = 12 self.pitch: int = 0
self.samplerate: int = 40000 self.samplerate: int = 40000
self.block_time: float = 1.0 # s self.block_time: float = 1.0 # s
self.buffer_num: int = 1 self.buffer_num: int = 1
self.threhold: int = -30 self.threhold: int = -60
self.crossfade_time: float = 0.08 self.crossfade_time: float = 0.04
self.extra_time: float = 0.04 self.extra_time: float = 2.0
self.I_noise_reduce = False self.I_noise_reduce = False
self.O_noise_reduce = False self.O_noise_reduce = False
self.rms_mix_rate = 0.0
self.index_rate = 0.3 self.index_rate = 0.3
self.n_cpu = min(n_cpu, 6) self.n_cpu = min(n_cpu, 6)
self.f0method = "harvest" self.f0method = "harvest"
@ -118,14 +119,19 @@ if __name__ == "__main__":
"index_path": " ", "index_path": " ",
"sg_input_device": input_devices[sd.default.device[0]], "sg_input_device": input_devices[sd.default.device[0]],
"sg_output_device": output_devices[sd.default.device[1]], "sg_output_device": output_devices[sd.default.device[1]],
"threhold": "-45", "threhold": "-60",
"pitch": "0", "pitch": "0",
"index_rate": "0", "index_rate": "0",
"block_time": "1", "rms_mix_rate": "0",
"block_time": "0.25",
"crossfade_length": "0.04", "crossfade_length": "0.04",
"extra_time": "1", "extra_time": "2",
"f0method": "rmvpe", "f0method": "rmvpe",
} }
data["pm"] = data["f0method"] == "pm"
data["harvest"] = data["f0method"] == "harvest"
data["crepe"] = data["f0method"] == "crepe"
data["rmvpe"] = data["f0method"] == "rmvpe"
return data return data
def launcher(self): def launcher(self):
@ -198,7 +204,7 @@ if __name__ == "__main__":
key="threhold", key="threhold",
resolution=1, resolution=1,
orientation="h", orientation="h",
default_value=data.get("threhold", ""), default_value=data.get("threhold", "-60"),
enable_events=True, enable_events=True,
), ),
], ],
@ -209,7 +215,7 @@ if __name__ == "__main__":
key="pitch", key="pitch",
resolution=1, resolution=1,
orientation="h", orientation="h",
default_value=data.get("pitch", ""), default_value=data.get("pitch", "0"),
enable_events=True, enable_events=True,
), ),
], ],
@ -220,7 +226,18 @@ if __name__ == "__main__":
key="index_rate", key="index_rate",
resolution=0.01, resolution=0.01,
orientation="h", orientation="h",
default_value=data.get("index_rate", ""), default_value=data.get("index_rate", "0"),
enable_events=True,
),
],
[
sg.Text(i18n("响度因子")),
sg.Slider(
range=(0.0, 1.0),
key="rms_mix_rate",
resolution=0.01,
orientation="h",
default_value=data.get("rms_mix_rate", "0"),
enable_events=True, enable_events=True,
), ),
], ],
@ -267,7 +284,7 @@ if __name__ == "__main__":
key="block_time", key="block_time",
resolution=0.01, resolution=0.01,
orientation="h", orientation="h",
default_value=data.get("block_time", ""), default_value=data.get("block_time", "0.25"),
enable_events=True, enable_events=True,
), ),
], ],
@ -291,7 +308,7 @@ if __name__ == "__main__":
key="crossfade_length", key="crossfade_length",
resolution=0.01, resolution=0.01,
orientation="h", orientation="h",
default_value=data.get("crossfade_length", ""), default_value=data.get("crossfade_length", "0.04"),
enable_events=True, enable_events=True,
), ),
], ],
@ -302,7 +319,7 @@ if __name__ == "__main__":
key="extra_time", key="extra_time",
resolution=0.01, resolution=0.01,
orientation="h", orientation="h",
default_value=data.get("extra_time", ""), default_value=data.get("extra_time", "2.0"),
enable_events=True, enable_events=True,
), ),
], ],
@ -369,6 +386,7 @@ if __name__ == "__main__":
"sg_output_device": values["sg_output_device"], "sg_output_device": values["sg_output_device"],
"threhold": values["threhold"], "threhold": values["threhold"],
"pitch": values["pitch"], "pitch": values["pitch"],
"rms_mix_rate": values["rms_mix_rate"],
"index_rate": values["index_rate"], "index_rate": values["index_rate"],
"block_time": values["block_time"], "block_time": values["block_time"],
"crossfade_length": values["crossfade_length"], "crossfade_length": values["crossfade_length"],
@ -399,6 +417,8 @@ if __name__ == "__main__":
self.config.index_rate = values["index_rate"] self.config.index_rate = values["index_rate"]
if hasattr(self, "rvc"): if hasattr(self, "rvc"):
self.rvc.change_index_rate(values["index_rate"]) self.rvc.change_index_rate(values["index_rate"])
elif event == "rms_mix_rate":
self.config.rms_mix_rate = values["rms_mix_rate"]
elif event in ["pm", "harvest", "crepe", "rmvpe"]: elif event in ["pm", "harvest", "crepe", "rmvpe"]:
self.config.f0method = event self.config.f0method = event
elif event == "I_noise_reduce": elif event == "I_noise_reduce":
@ -433,6 +453,7 @@ if __name__ == "__main__":
self.config.extra_time = values["extra_time"] self.config.extra_time = values["extra_time"]
self.config.I_noise_reduce = values["I_noise_reduce"] self.config.I_noise_reduce = values["I_noise_reduce"]
self.config.O_noise_reduce = values["O_noise_reduce"] self.config.O_noise_reduce = values["O_noise_reduce"]
self.config.rms_mix_rate = values["rms_mix_rate"]
self.config.index_rate = values["index_rate"] self.config.index_rate = values["index_rate"]
self.config.n_cpu = values["n_cpu"] self.config.n_cpu = values["n_cpu"]
self.config.f0method = ["pm", "harvest", "crepe", "rmvpe"][ self.config.f0method = ["pm", "harvest", "crepe", "rmvpe"][
@ -457,17 +478,14 @@ if __name__ == "__main__":
inp_q, inp_q,
opt_q, opt_q,
device, device,
self.rvc if hasattr(self, "rvc") else None, self.rvc if hasattr(self, "rvc") else None
) )
self.config.samplerate = self.rvc.tgt_sr self.config.samplerate = self.rvc.tgt_sr
self.config.crossfade_time = min( self.config.crossfade_time = min(
self.config.crossfade_time, self.config.block_time self.config.crossfade_time, self.config.block_time
) )
self.zc = self.rvc.tgt_sr // 100 self.zc = self.rvc.tgt_sr // 100
self.block_frame = ( self.block_frame = int(np.round(self.config.block_time * self.config.samplerate / self.zc)) * self.zc
int(np.round(self.config.block_time * self.config.samplerate / self.zc))
* self.zc
)
self.block_frame_16k = 160 * self.block_frame // self.zc self.block_frame_16k = 160 * self.block_frame // self.zc
self.crossfade_frame = int( self.crossfade_frame = int(
self.config.crossfade_time * self.config.samplerate self.config.crossfade_time * self.config.samplerate
@ -489,9 +507,7 @@ if __name__ == "__main__":
), ),
dtype="float32", dtype="float32",
) )
self.input_wav_res: torch.Tensor = torch.zeros( self.input_wav_res: torch.Tensor= torch.zeros(160 * len(self.input_wav) // self.zc, device=device,dtype=torch.float32)
160 * len(self.input_wav) // self.zc
)
self.output_wav_cache: torch.Tensor = torch.zeros( self.output_wav_cache: torch.Tensor = torch.zeros(
int( int(
np.ceil( np.ceil(
@ -540,6 +556,8 @@ if __name__ == "__main__":
self.resampler = tat.Resample( self.resampler = tat.Resample(
orig_freq=self.config.samplerate, new_freq=16000, dtype=torch.float32 orig_freq=self.config.samplerate, new_freq=16000, dtype=torch.float32
).to(device) ).to(device)
self.input_tg = TorchGate(sr=16000, nonstationary=True, n_fft=640).to(device)
self.output_tg = TorchGate(sr=self.config.samplerate, nonstationary=True, n_fft=4*self.zc).to(device)
thread_vc = threading.Thread(target=self.soundinput) thread_vc = threading.Thread(target=self.soundinput)
thread_vc.start() thread_vc.start()
@ -568,9 +586,6 @@ if __name__ == "__main__":
""" """
start_time = time.perf_counter() start_time = time.perf_counter()
indata = librosa.to_mono(indata.T) indata = librosa.to_mono(indata.T)
if self.config.I_noise_reduce:
indata[:] = nr.reduce_noise(y=indata, sr=self.config.samplerate)
"""noise gate"""
frame_length = 2048 frame_length = 2048
hop_length = 1024 hop_length = 1024
rms = librosa.feature.rms( rms = librosa.feature.rms(
@ -584,18 +599,13 @@ if __name__ == "__main__":
if db_threhold[i]: if db_threhold[i]:
indata[i * hop_length : (i + 1) * hop_length] = 0 indata[i * hop_length : (i + 1) * hop_length] = 0
self.input_wav[: -self.block_frame] = self.input_wav[self.block_frame :] self.input_wav[: -self.block_frame] = self.input_wav[self.block_frame :]
self.input_wav[-self.block_frame :] = indata self.input_wav[-self.block_frame: ] = indata
# infer # infer
inp = torch.from_numpy( inp = torch.from_numpy(self.input_wav[-self.block_frame-2*self.zc :]).to(device)
self.input_wav[-self.block_frame - 2 * self.zc :] self.input_wav_res[ : -self.block_frame_16k] = self.input_wav_res[self.block_frame_16k :].clone()
).to(device) self.input_wav_res[-self.block_frame_16k-160 :] = self.resampler(inp)[160 :]
self.input_wav_res[: -self.block_frame_16k] = self.input_wav_res[ if self.config.I_noise_reduce:
self.block_frame_16k : self.input_wav_res[-self.block_frame_16k-320 :] = self.input_tg(self.input_wav_res[None, -self.block_frame_16k-800 :])[0, 480 : ]
].clone()
self.input_wav_res[-self.block_frame_16k - 160 :] = self.resampler(inp)[
160:
]
rate = ( rate = (
self.crossfade_frame + self.sola_search_frame + self.block_frame self.crossfade_frame + self.sola_search_frame + self.block_frame
) / ( ) / (
@ -605,11 +615,11 @@ if __name__ == "__main__":
+ self.block_frame + self.block_frame
) )
f0_extractor_frame = self.block_frame_16k + 800 f0_extractor_frame = self.block_frame_16k + 800
if self.config.f0method == "rmvpe": if self.config.f0method == 'rmvpe':
f0_extractor_frame = 5120 * ((f0_extractor_frame - 1) // 5120 + 1) f0_extractor_frame = 5120 * ((f0_extractor_frame - 1) // 5120 + 1)
res2 = self.rvc.infer( res2 = self.rvc.infer(
self.input_wav_res, self.input_wav_res,
self.input_wav_res[-f0_extractor_frame:].cpu().numpy(), self.input_wav_res[-f0_extractor_frame :].cpu().numpy(),
self.block_frame_16k, self.block_frame_16k,
rate, rate,
self.pitch, self.pitch,
@ -620,6 +630,27 @@ if __name__ == "__main__":
infer_wav = self.output_wav_cache[ infer_wav = self.output_wav_cache[
-self.crossfade_frame - self.sola_search_frame - self.block_frame : -self.crossfade_frame - self.sola_search_frame - self.block_frame :
] ]
if self.config.O_noise_reduce:
infer_wav = self.output_tg(infer_wav.unsqueeze(0)).squeeze(0)
if self.config.rms_mix_rate < 1:
rms1 = librosa.feature.rms(
y=self.input_wav[-self.crossfade_frame - self.sola_search_frame - self.block_frame :],
frame_length=frame_length,
hop_length=hop_length
)
rms1 = torch.from_numpy(rms1).to(device)
rms1 = F.interpolate(
rms1.unsqueeze(0), size=infer_wav.shape[0], mode="linear"
).squeeze()
rms2 = librosa.feature.rms(
y=infer_wav[:].cpu().numpy(), frame_length=frame_length, hop_length=hop_length
)
rms2 = torch.from_numpy(rms2).to(device)
rms2 = F.interpolate(
rms2.unsqueeze(0), size=infer_wav.shape[0], mode="linear"
).squeeze()
rms2 = torch.max(rms2, torch.zeros_like(rms2) + 1e-3)
infer_wav *= torch.pow(rms1 / rms2, torch.tensor(1 - self.config.rms_mix_rate))
# SOLA algorithm from https://github.com/yxlllc/DDSP-SVC # SOLA algorithm from https://github.com/yxlllc/DDSP-SVC
cor_nom = F.conv1d( cor_nom = F.conv1d(
infer_wav[None, None, : self.crossfade_frame + self.sola_search_frame], infer_wav[None, None, : self.crossfade_frame + self.sola_search_frame],
@ -659,25 +690,10 @@ if __name__ == "__main__":
self.sola_buffer[:] = ( self.sola_buffer[:] = (
infer_wav[-self.crossfade_frame :] * self.fade_out_window infer_wav[-self.crossfade_frame :] * self.fade_out_window
) )
if self.config.O_noise_reduce: if sys.platform == "darwin":
if sys.platform == "darwin": outdata[:] = self.output_wav[:].cpu().numpy()[:, np.newaxis]
noise_reduced_signal = nr.reduce_noise(
y=self.output_wav[:].cpu().numpy(), sr=self.config.samplerate
)
outdata[:] = noise_reduced_signal[:, np.newaxis]
else:
outdata[:] = np.tile(
nr.reduce_noise(
y=self.output_wav[:].cpu().numpy(),
sr=self.config.samplerate,
),
(2, 1),
).T
else: else:
if sys.platform == "darwin": outdata[:] = self.output_wav[:].repeat(2, 1).t().cpu().numpy()
outdata[:] = self.output_wav[:].cpu().numpy()[:, np.newaxis]
else:
outdata[:] = self.output_wav[:].repeat(2, 1).t().cpu().numpy()
total_time = time.perf_counter() - start_time total_time = time.perf_counter() - start_time
self.window["infer_time"].update(int(total_time * 1000)) self.window["infer_time"].update(int(total_time * 1000))
logger.info("Infer time: %.2f", total_time) logger.info("Infer time: %.2f", total_time)
@ -733,7 +749,9 @@ if __name__ == "__main__":
sd.default.device[1] = output_device_indices[ sd.default.device[1] = output_device_indices[
output_devices.index(output_device) output_devices.index(output_device)
] ]
logger.info("Input device: %s:%d", str(sd.default.device[0]), input_device) logger.info(
"Input device: %s:%d", str(sd.default.device[0]), input_device
)
logger.info( logger.info(
"Output device: %s:%d", str(sd.default.device[1]), output_device "Output device: %s:%d", str(sd.default.device[1]), output_device
) )

View File

@ -43,6 +43,7 @@
"后处理重采样至最终采样率0为不进行重采样": "Resample the output audio in post-processing to the final sample rate. Set to 0 for no resampling:", "后处理重采样至最终采样率0为不进行重采样": "Resample the output audio in post-processing to the final sample rate. Set to 0 for no resampling:",
"否": "No", "否": "No",
"响应阈值": "Response threshold", "响应阈值": "Response threshold",
"响度因子": "loudness factor",
"处理数据": "Process data", "处理数据": "Process data",
"导出Onnx模型": "Export Onnx Model", "导出Onnx模型": "Export Onnx Model",
"导出文件格式": "Export file format", "导出文件格式": "Export file format",

View File

@ -43,6 +43,7 @@
"后处理重采样至最终采样率0为不进行重采样": "Remuestreo posterior al proceso a la tasa de muestreo final, 0 significa no remuestrear", "后处理重采样至最终采样率0为不进行重采样": "Remuestreo posterior al proceso a la tasa de muestreo final, 0 significa no remuestrear",
"否": "No", "否": "No",
"响应阈值": "Umbral de respuesta", "响应阈值": "Umbral de respuesta",
"响度因子": "factor de sonoridad",
"处理数据": "Procesar datos", "处理数据": "Procesar datos",
"导出Onnx模型": "Exportar modelo Onnx", "导出Onnx模型": "Exportar modelo Onnx",
"导出文件格式": "Formato de archivo de exportación", "导出文件格式": "Formato de archivo de exportación",

View File

@ -43,6 +43,7 @@
"后处理重采样至最终采样率0为不进行重采样": "Ricampiona l'audio di output in post-elaborazione alla frequenza di campionamento finale. ", "后处理重采样至最终采样率0为不进行重采样": "Ricampiona l'audio di output in post-elaborazione alla frequenza di campionamento finale. ",
"否": "NO", "否": "NO",
"响应阈值": "Soglia di risposta", "响应阈值": "Soglia di risposta",
"响度因子": "fattore di sonorità",
"处理数据": "Processa dati", "处理数据": "Processa dati",
"导出Onnx模型": "Esporta modello Onnx", "导出Onnx模型": "Esporta modello Onnx",
"导出文件格式": "Formato file di esportazione", "导出文件格式": "Formato file di esportazione",

View File

@ -43,6 +43,7 @@
"后处理重采样至最终采样率0为不进行重采样": "最終的なサンプリングレートへのポストプロセッシングのリサンプリング リサンプリングしない場合は0", "后处理重采样至最终采样率0为不进行重采样": "最終的なサンプリングレートへのポストプロセッシングのリサンプリング リサンプリングしない場合は0",
"否": "いいえ", "否": "いいえ",
"响应阈值": "反応閾値", "响应阈值": "反応閾値",
"响度因子": "ラウドネス係数",
"处理数据": "データ処理", "处理数据": "データ処理",
"导出Onnx模型": "Onnxに変換", "导出Onnx模型": "Onnxに変換",
"导出文件格式": "エクスポート形式", "导出文件格式": "エクスポート形式",

View File

@ -43,6 +43,7 @@
"后处理重采样至最终采样率0为不进行重采样": "Изменить частоту дискретизации в выходном файле на финальную. Поставьте 0, чтобы ничего не изменялось:", "后处理重采样至最终采样率0为不进行重采样": "Изменить частоту дискретизации в выходном файле на финальную. Поставьте 0, чтобы ничего не изменялось:",
"否": "Нет", "否": "Нет",
"响应阈值": "Порог ответа", "响应阈值": "Порог ответа",
"响度因子": "коэффициент громкости",
"处理数据": "Обработать данные", "处理数据": "Обработать данные",
"导出Onnx模型": "Экспортировать модель", "导出Onnx模型": "Экспортировать модель",
"导出文件格式": "Формат выходных файлов", "导出文件格式": "Формат выходных файлов",

View File

@ -43,6 +43,7 @@
"后处理重采样至最终采样率0为不进行重采样": "Son işleme aşamasında çıktı sesini son örnekleme hızına yeniden örnekle. 0 değeri için yeniden örnekleme yapılmaz:", "后处理重采样至最终采样率0为不进行重采样": "Son işleme aşamasında çıktı sesini son örnekleme hızına yeniden örnekle. 0 değeri için yeniden örnekleme yapılmaz:",
"否": "Hayır", "否": "Hayır",
"响应阈值": "Tepki eşiği", "响应阈值": "Tepki eşiği",
"响度因子": "ses yüksekliği faktörü",
"处理数据": "Verileri işle", "处理数据": "Verileri işle",
"导出Onnx模型": "Onnx Modeli Dışa Aktar", "导出Onnx模型": "Onnx Modeli Dışa Aktar",
"导出文件格式": "Dışa aktarma dosya formatı", "导出文件格式": "Dışa aktarma dosya formatı",

View File

@ -43,6 +43,7 @@
"后处理重采样至最终采样率0为不进行重采样": "后处理重采样至最终采样率0为不进行重采样", "后处理重采样至最终采样率0为不进行重采样": "后处理重采样至最终采样率0为不进行重采样",
"否": "否", "否": "否",
"响应阈值": "响应阈值", "响应阈值": "响应阈值",
"响度因子": "响度因子",
"处理数据": "处理数据", "处理数据": "处理数据",
"导出Onnx模型": "导出Onnx模型", "导出Onnx模型": "导出Onnx模型",
"导出文件格式": "导出文件格式", "导出文件格式": "导出文件格式",

View File

@ -43,6 +43,7 @@
"后处理重采样至最终采样率0为不进行重采样": "後處理重採樣至最終採樣率0為不進行重採樣", "后处理重采样至最终采样率0为不进行重采样": "後處理重採樣至最終採樣率0為不進行重採樣",
"否": "否", "否": "否",
"响应阈值": "響應閾值", "响应阈值": "響應閾值",
"响度因子": "響度因子",
"处理数据": "處理資料", "处理数据": "處理資料",
"导出Onnx模型": "导出Onnx模型", "导出Onnx模型": "导出Onnx模型",
"导出文件格式": "導出檔格式", "导出文件格式": "導出檔格式",

View File

@ -43,6 +43,7 @@
"后处理重采样至最终采样率0为不进行重采样": "後處理重採樣至最終採樣率0為不進行重採樣", "后处理重采样至最终采样率0为不进行重采样": "後處理重採樣至最終採樣率0為不進行重採樣",
"否": "否", "否": "否",
"响应阈值": "響應閾值", "响应阈值": "響應閾值",
"响度因子": "響度因子",
"处理数据": "處理資料", "处理数据": "處理資料",
"导出Onnx模型": "导出Onnx模型", "导出Onnx模型": "导出Onnx模型",
"导出文件格式": "導出檔格式", "导出文件格式": "導出檔格式",

View File

@ -43,6 +43,7 @@
"后处理重采样至最终采样率0为不进行重采样": "後處理重採樣至最終採樣率0為不進行重採樣", "后处理重采样至最终采样率0为不进行重采样": "後處理重採樣至最終採樣率0為不進行重採樣",
"否": "否", "否": "否",
"响应阈值": "響應閾值", "响应阈值": "響應閾值",
"响度因子": "響度因子",
"处理数据": "處理資料", "处理数据": "處理資料",
"导出Onnx模型": "导出Onnx模型", "导出Onnx模型": "导出Onnx模型",
"导出文件格式": "導出檔格式", "导出文件格式": "導出檔格式",

View File

@ -0,0 +1,12 @@
"""
TorchGating is a PyTorch-based implementation of Spectral Gating
================================================
Author: Asaf Zorea

Contents
--------
This package re-exports a single public name:

    TorchGate --- A PyTorch module that applies a spectral gate to an input signal
"""
from .torchgate import TorchGate

View File

@ -0,0 +1,264 @@
import torch
from torch.nn.functional import conv1d, conv2d
from typing import Union, Optional
from .utils import linspace, temperature_sigmoid, amp_to_db
class TorchGate(torch.nn.Module):
    """
    A PyTorch module that applies a spectral gate to an input signal.

    Arguments:
        sr {int} -- Sample rate of the input signal.
        nonstationary {bool} -- Whether to use non-stationary or stationary masking (default: {False}).
        n_std_thresh_stationary {float} -- Number of standard deviations above mean to threshold noise for
                                           stationary masking (default: {1.5}).
        n_thresh_nonstationary {float} -- Number of multiplies above smoothed magnitude spectrogram. for
                                          non-stationary masking (default: {1.3}).
        temp_coeff_nonstationary {float} -- Temperature coefficient for non-stationary masking (default: {0.1}).
        n_movemean_nonstationary {int} -- Number of samples for moving average smoothing in non-stationary masking
                                          (default: {20}).
        prop_decrease {float} -- Proportion to decrease signal by where the mask is zero (default: {1.0}).
        n_fft {int} -- Size of FFT for STFT (default: {1024}).
        win_length {[int]} -- Window length for STFT. If None, defaults to `n_fft` (default: {None}).
        hop_length {[int]} -- Hop length for STFT. If None, defaults to `win_length` // 4 (default: {None}).
        freq_mask_smooth_hz {float} -- Frequency smoothing width for mask (in Hz). If None, no smoothing is applied
                                       (default: {500}).
        time_mask_smooth_ms {float} -- Time smoothing width for mask (in ms). If None, no smoothing is applied
                                       (default: {50}).
    """

    @torch.no_grad()
    def __init__(
        self,
        sr: int,
        nonstationary: bool = False,
        n_std_thresh_stationary: float = 1.5,
        n_thresh_nonstationary: float = 1.3,
        temp_coeff_nonstationary: float = 0.1,
        n_movemean_nonstationary: int = 20,
        prop_decrease: float = 1.0,
        n_fft: int = 1024,
        win_length: Optional[int] = None,
        hop_length: Optional[int] = None,
        freq_mask_smooth_hz: Optional[float] = 500,
        time_mask_smooth_ms: Optional[float] = 50,
    ):
        super().__init__()

        # General Params
        self.sr = sr
        self.nonstationary = nonstationary
        assert 0.0 <= prop_decrease <= 1.0
        self.prop_decrease = prop_decrease

        # STFT Params
        self.n_fft = n_fft
        self.win_length = self.n_fft if win_length is None else win_length
        self.hop_length = self.win_length // 4 if hop_length is None else hop_length

        # Stationary Params
        self.n_std_thresh_stationary = n_std_thresh_stationary

        # Non-Stationary Params
        self.temp_coeff_nonstationary = temp_coeff_nonstationary
        self.n_movemean_nonstationary = n_movemean_nonstationary
        self.n_thresh_nonstationary = n_thresh_nonstationary

        # Smooth Mask Params
        self.freq_mask_smooth_hz = freq_mask_smooth_hz
        self.time_mask_smooth_ms = time_mask_smooth_ms
        # Registered as a buffer so it follows the module across `.to(device)`;
        # the value is None when no smoothing is configured.
        self.register_buffer("smoothing_filter", self._generate_mask_smoothing_filter())

    @torch.no_grad()
    def _generate_mask_smoothing_filter(self) -> Union[torch.Tensor, None]:
        """
        Build the normalized 2D kernel used to smooth the signal mask.

        Returns:
            smoothing_filter (torch.Tensor): a 4D tensor of shape (1, 1, F, T) holding the outer product of
            two triangular ramps (one over frequency bins, one over time frames), divided by its sum so the
            kernel sums to 1. F and T are derived from `freq_mask_smooth_hz` and `time_mask_smooth_ms`.
            Returns None if both smoothing widths are None, or if both resolve to a width of 1.

        Raises:
            ValueError: if a configured smoothing width is smaller than one frequency bin / one hop.
        """
        if self.freq_mask_smooth_hz is None and self.time_mask_smooth_ms is None:
            return None

        n_grad_freq = (
            1
            if self.freq_mask_smooth_hz is None
            else int(self.freq_mask_smooth_hz / (self.sr / (self.n_fft / 2)))
        )
        if n_grad_freq < 1:
            # Fixed: the message previously read the non-existent `self._n_fft`,
            # which raised AttributeError instead of this ValueError.
            raise ValueError(
                f"freq_mask_smooth_hz needs to be at least {int((self.sr / (self.n_fft / 2)))} Hz"
            )

        n_grad_time = (
            1
            if self.time_mask_smooth_ms is None
            else int(self.time_mask_smooth_ms / ((self.hop_length / self.sr) * 1000))
        )
        if n_grad_time < 1:
            raise ValueError(
                f"time_mask_smooth_ms needs to be at least {int((self.hop_length / self.sr) * 1000)} ms"
            )

        if n_grad_time == 1 and n_grad_freq == 1:
            return None

        # Rising ramp followed by falling ramp; the [1:-1] slice drops the
        # zero-valued end points of the resulting triangle.
        v_f = torch.cat(
            [
                linspace(0, 1, n_grad_freq + 1, endpoint=False),
                linspace(1, 0, n_grad_freq + 2),
            ]
        )[1:-1]
        v_t = torch.cat(
            [
                linspace(0, 1, n_grad_time + 1, endpoint=False),
                linspace(1, 0, n_grad_time + 2),
            ]
        )[1:-1]
        smoothing_filter = torch.outer(v_f, v_t).unsqueeze(0).unsqueeze(0)

        return smoothing_filter / smoothing_filter.sum()

    @torch.no_grad()
    def _stationary_mask(
        self, X_db: torch.Tensor, xn: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Computes a stationary binary mask to filter out noise in a log-magnitude spectrogram.

        Arguments:
            X_db (torch.Tensor): log-magnitude spectrogram of the signal. `forward` passes the dB-converted
                STFT of a 2D (batch, samples) input here.
                NOTE(review): presumably laid out as (batch, freq_bins, frames) -- confirm against torch.stft.
            xn (torch.Tensor): optional 1D or 2D audio signal used to estimate the noise statistics. When None,
                the statistics are taken from X_db itself.

        Returns:
            sig_mask (torch.Tensor): Boolean mask of the same shape as X_db, True where X_db exceeds the
            noise threshold.
        """
        if xn is not None:
            XN = torch.stft(
                xn,
                n_fft=self.n_fft,
                hop_length=self.hop_length,
                win_length=self.win_length,
                return_complex=True,
                pad_mode="constant",
                center=True,
                window=torch.hann_window(self.win_length).to(xn.device),
            )
            XN_db = amp_to_db(XN).to(dtype=X_db.dtype)
        else:
            XN_db = X_db

        # calculate mean and standard deviation along the last axis
        std_freq_noise, mean_freq_noise = torch.std_mean(XN_db, dim=-1)

        # compute noise threshold
        noise_thresh = mean_freq_noise + std_freq_noise * self.n_std_thresh_stationary

        # create binary mask by thresholding the spectrogram.
        # unsqueeze(-1) equals the former unsqueeze(2) for batched statistics and
        # additionally broadcasts correctly when `xn` was 1D (which `forward`
        # permits); unsqueeze(2) raised for that case.
        sig_mask = X_db > noise_thresh.unsqueeze(-1)
        return sig_mask

    @torch.no_grad()
    def _nonstationary_mask(self, X_abs: torch.Tensor) -> torch.Tensor:
        """
        Computes a non-stationary soft mask to filter out noise in a magnitude spectrogram.

        Arguments:
            X_abs (torch.Tensor): magnitude spectrogram (`forward` passes `X.abs()` of the STFT).

        Returns:
            sig_mask (torch.Tensor): Mask of the same shape as X_abs with values produced by
            `temperature_sigmoid`, higher where the magnitude rises above its moving average
            along the last axis.
        """
        # Moving average over the last axis: every leading dim is folded into the
        # batch dim so a single-channel conv1d with a box kernel can be used.
        X_smoothed = (
            conv1d(
                X_abs.reshape(-1, 1, X_abs.shape[-1]),
                torch.ones(
                    self.n_movemean_nonstationary,
                    dtype=X_abs.dtype,
                    device=X_abs.device,
                ).view(1, 1, -1),
                padding="same",
            ).view(X_abs.shape)
            / self.n_movemean_nonstationary
        )

        # Compute slowness ratio and apply temperature sigmoid
        # (1e-6 guards against division by zero in silent regions)
        slowness_ratio = (X_abs - X_smoothed) / (X_smoothed + 1e-6)
        sig_mask = temperature_sigmoid(
            slowness_ratio, self.n_thresh_nonstationary, self.temp_coeff_nonstationary
        )

        return sig_mask

    def forward(
        self, x: torch.Tensor, xn: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Apply the proposed algorithm to the input signal.

        Arguments:
            x (torch.Tensor): The input audio signal, with shape (batch_size, signal_length).
            xn (Optional[torch.Tensor]): The noise signal used for stationary noise reduction. If `None`, the input
                                         signal is used as the noise signal. Default: `None`.

        Returns:
            torch.Tensor: The denoised audio signal, cast back to the dtype of `x`.
            NOTE(review): no `length` is passed to `torch.istft`, so the output length may differ from
            `signal_length` when it is not a multiple of `hop_length` -- confirm against callers.

        Raises:
            AssertionError: if `x` is not 2D, or `xn` is given and is neither 1D nor 2D.
            Exception: if `x` (or `xn`) is shorter than `2 * win_length` samples.
        """
        assert x.ndim == 2
        if x.shape[-1] < self.win_length * 2:
            raise Exception(f"x must be bigger than {self.win_length * 2}")

        assert xn is None or xn.ndim == 1 or xn.ndim == 2
        if xn is not None and xn.shape[-1] < self.win_length * 2:
            raise Exception(f"xn must be bigger than {self.win_length * 2}")

        # Compute short-time Fourier transform (STFT)
        X = torch.stft(
            x,
            n_fft=self.n_fft,
            hop_length=self.hop_length,
            win_length=self.win_length,
            return_complex=True,
            pad_mode="constant",
            center=True,
            window=torch.hann_window(self.win_length).to(x.device),
        )

        # Compute signal mask based on stationary or nonstationary assumptions
        if self.nonstationary:
            sig_mask = self._nonstationary_mask(X.abs())
        else:
            sig_mask = self._stationary_mask(amp_to_db(X), xn)

        # Propagate decrease in signal power:
        # maps mask value 1 -> 1 and mask value 0 -> (1 - prop_decrease)
        sig_mask = self.prop_decrease * (sig_mask * 1.0 - 1.0) + 1.0

        # Smooth signal mask with 2D convolution
        if self.smoothing_filter is not None:
            sig_mask = conv2d(
                sig_mask.unsqueeze(1),
                self.smoothing_filter.to(sig_mask.dtype),
                padding="same",
            )

        # Apply signal mask to STFT magnitude and phase components
        Y = X * sig_mask.squeeze(1)

        # Inverse STFT to obtain time-domain signal
        y = torch.istft(
            Y,
            n_fft=self.n_fft,
            hop_length=self.hop_length,
            win_length=self.win_length,
            center=True,
            window=torch.hann_window(self.win_length).to(Y.device),
        )

        return y.to(dtype=x.dtype)

66
tools/torchgate/utils.py Normal file
View File

@ -0,0 +1,66 @@
import torch
from torch.types import Number
@torch.no_grad()
def amp_to_db(x: torch.Tensor, eps=torch.finfo(torch.float64).eps, top_db=40) -> torch.Tensor:
    """
    Convert amplitudes to decibels and clamp the dynamic range.

    Arguments:
        x {[torch.Tensor]} -- [Input tensor; its absolute value is taken first.]

    Keyword Arguments:
        eps {[float]} -- [Small value added to the magnitude before the logarithm.]
                         (default: {torch.finfo(torch.float64).eps})
        top_db {[float]} -- [Values are floored at ``top_db`` below the peak taken along the last axis.]
                            (default: {40})

    Returns:
        [torch.Tensor] -- [Output tensor in decibel scale, same shape as ``x``.]
    """
    magnitude = x.abs() + eps
    decibels = torch.log10(magnitude) * 20
    # Per-row floor: peak along the last axis minus the allowed dynamic range.
    floor = decibels.max(dim=-1, keepdim=True).values - top_db
    return torch.max(decibels, floor)
@torch.no_grad()
def temperature_sigmoid(x: torch.Tensor, x0: float, temp_coeff: float) -> torch.Tensor:
    """
    Evaluate a shifted, temperature-scaled logistic function element-wise.

    Arguments:
        x {[torch.Tensor]} -- [Input tensor.]
        x0 {[float]} -- [Offset subtracted from ``x``; the output is 0.5 where ``x == x0``.]
        temp_coeff {[float]} -- [Divisor applied after the shift; smaller values give a steeper slope.]

    Returns:
        [torch.Tensor] -- [``sigmoid((x - x0) / temp_coeff)``, same shape as ``x``.]
    """
    shifted = x - x0
    scaled = shifted / temp_coeff
    return torch.sigmoid(scaled)
@torch.no_grad()
def linspace(start: Number, stop: Number, num: int = 50, endpoint: bool = True, **kwargs) -> torch.Tensor:
    """
    Generate a linearly spaced 1-D tensor, optionally excluding the end point.

    Arguments:
        start {[Number]} -- [The starting value of the sequence.]
        stop {[Number]} -- [The end value of the sequence. When `endpoint` is False, ``num + 1`` evenly
                            spaced samples are generated and the last one (`stop`) is dropped, so the
                            step size differs from the `endpoint=True` case.]

    Keyword Arguments:
        num {[int]} -- [Number of samples to return. Default is 50.]
        endpoint {[bool]} -- [If True, `stop` is the last sample. Otherwise, it is not included.
                              Default is True.]
        **kwargs -- [Additional arguments forwarded to `torch.linspace`.]

    Returns:
        [torch.Tensor] -- [1-D tensor of `num` equally spaced samples.]
    """
    if not endpoint:
        # Generate one extra point so that dropping `stop` still leaves `num` samples.
        with_stop = torch.linspace(start, stop, num + 1, **kwargs)
        return with_stop[:-1]
    return torch.linspace(start, stop, num, **kwargs)