Retrieval-based-Voice-Conve.../tools/torchgate/torchgate.py
github-actions[bot] e9dd11bddb
chore(sync): merge dev into main (#1379)
* Optimize latency (#1259)

* add attribute:   configs/config.py
	Optimize latency:   tools/rvc_for_realtime.py

* new file:   assets/Synthesizer_inputs.pth

* fix:   configs/config.py
	fix:   tools/rvc_for_realtime.py

* fix bug:   infer/lib/infer_pack/models.py

* new file:   assets/hubert_inputs.pth
	new file:   assets/rmvpe_inputs.pth
	modified:   configs/config.py
	new features:   infer/lib/rmvpe.py
	new features:   tools/jit_export/__init__.py
	new features:   tools/jit_export/get_hubert.py
	new features:   tools/jit_export/get_rmvpe.py
	new features:   tools/jit_export/get_synthesizer.py
	optimize:   tools/rvc_for_realtime.py

* optimize:   tools/jit_export/get_synthesizer.py
	fix bug:   tools/jit_export/__init__.py

* Fixed a bug caused by using half on the CPU:   infer/lib/rmvpe.py
	Fixed a bug caused by using half on the CPU:   tools/jit_export/__init__.py
	Fixed CIRCULAR IMPORT:   tools/jit_export/get_rmvpe.py
	Fixed CIRCULAR IMPORT:   tools/jit_export/get_synthesizer.py
	Fixed a bug caused by using half on the CPU:   tools/rvc_for_realtime.py

* Remove useless code:   infer/lib/rmvpe.py

* Delete gui_v1 copy.py

* Delete .vscode/launch.json

* Delete jit_export_test.py

* Delete tools/rvc_for_realtime copy.py

* Delete configs/config.json

* Delete .gitignore

* Fix exceptions caused by switching inference devices:   infer/lib/rmvpe.py
	Fix exceptions caused by switching inference devices:   tools/jit_export/__init__.py
	Fix exceptions caused by switching inference devices:   tools/rvc_for_realtime.py

* restore

* replace(you can undo this commit)

* remove debug_print

---------

Co-authored-by: Ftps <ftpsflandre@gmail.com>

* Fixed some bugs when exporting ONNX model (#1254)

* fix import (#1280)

* fix import

* lint

* 🎨 Sync locale (#1242)

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>

* Fix jit load and import issue (#1282)

* fix jit model loading :   infer/lib/rmvpe.py

* modified:   assets/hubert/.gitignore
	move file:    assets/hubert_inputs.pth -> assets/hubert/hubert_inputs.pth
	modified:   assets/rmvpe/.gitignore
	move file:    assets/rmvpe_inputs.pth -> assets/rmvpe/rmvpe_inputs.pth
	fix import:   gui_v1.py

* feat(workflow): trigger on dev

* feat(workflow): add close-pr on non-dev branch

* Add input wav and delay time monitor for real-time gui (#1293)

* feat(workflow): trigger on dev

* feat(workflow): add close-pr on non-dev branch

* 🎨 Sync locale (#1289)

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>

* feat: edit PR template

* add input wav and delay time monitor

---------

Co-authored-by: 源文雨 <41315874+fumiama@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: RVC-Boss <129054828+RVC-Boss@users.noreply.github.com>

* Optimize latency using scripted jit (#1291)

* feat(workflow): trigger on dev

* feat(workflow): add close-pr on non-dev branch

* 🎨 Sync locale (#1289)

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>

* feat: edit PR template

* Optimize-latency-using-scripted:   configs/config.py
	Optimize-latency-using-scripted:   infer/lib/infer_pack/attentions.py
	Optimize-latency-using-scripted:   infer/lib/infer_pack/commons.py
	Optimize-latency-using-scripted:   infer/lib/infer_pack/models.py
	Optimize-latency-using-scripted:   infer/lib/infer_pack/modules.py
	Optimize-latency-using-scripted:   infer/lib/jit/__init__.py
	Optimize-latency-using-scripted:   infer/lib/jit/get_hubert.py
	Optimize-latency-using-scripted:   infer/lib/jit/get_rmvpe.py
	Optimize-latency-using-scripted:   infer/lib/jit/get_synthesizer.py
	Optimize-latency-using-scripted:   infer/lib/rmvpe.py
	Optimize-latency-using-scripted:   tools/rvc_for_realtime.py

* modified:   infer/lib/infer_pack/models.py

* fix some bug:   configs/config.py
	fix some bug:   infer/lib/infer_pack/models.py
	fix some bug:   infer/lib/rmvpe.py

* Fixed abnormal reference of logger in multiprocessing:   infer/modules/train/train.py

---------

Co-authored-by: 源文雨 <41315874+fumiama@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>

* Format code (#1298)

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>

* 🎨 Sync locale (#1299)

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>

* feat: optimize actions

* feat(workflow): add sync dev

* feat: optimize actions

* feat: optimize actions

* feat: optimize actions

* feat: optimize actions

* feat: add jit options (#1303)

Delete useless code:   infer/lib/jit/get_synthesizer.py
	Optimized code:   tools/rvc_for_realtime.py

* Code refactor + re-design inference ui (#1304)

* Code refacor + re-design inference ui

* Fix tabname

* i18n jp

---------

Co-authored-by: Ftps <ftpsflandre@gmail.com>

* feat: optimize actions

* feat: optimize actions

* Update README & en_US locale file (#1309)

* critical: some bug fixes (#1322)

* JIT acceleration switch does not support hot update

* fix padding bug of rmvpe in torch-directml

* fix padding bug of rmvpe in torch-directml

* Fix STFT under torch_directml (#1330)

* chore(format): run black on dev (#1318)

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>

* chore(i18n): sync locale on dev (#1317)

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>

* feat: allow for tta to be passed to uvr (#1361)

* chore(format): run black on dev (#1373)

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>

* Added script for automatically download all needed models at install (#1366)

* Delete modules.py

* Add files via upload

* Add files via upload

* Add files via upload

* Add files via upload

* chore(i18n): sync locale on dev (#1377)

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>

* chore(format): run black on dev (#1376)

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>

* Update IPEX library (#1362)

* Update IPEX library

* Update ipex index

* chore(format): run black on dev (#1378)

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>

---------

Co-authored-by: Chengjia Jiang <46401978+ChasonJiang@users.noreply.github.com>
Co-authored-by: Ftps <ftpsflandre@gmail.com>
Co-authored-by: shizuku_nia <102004222+ShizukuNia@users.noreply.github.com>
Co-authored-by: Ftps <63702646+Tps-F@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: 源文雨 <41315874+fumiama@users.noreply.github.com>
Co-authored-by: yxlllc <33565655+yxlllc@users.noreply.github.com>
Co-authored-by: RVC-Boss <129054828+RVC-Boss@users.noreply.github.com>
Co-authored-by: Blaise <133521603+blaise-tk@users.noreply.github.com>
Co-authored-by: Rice Cake <gak141808@gmail.com>
Co-authored-by: AWAS666 <33494149+AWAS666@users.noreply.github.com>
Co-authored-by: Dmitry <nda2911@yandex.ru>
Co-authored-by: Disty0 <47277141+Disty0@users.noreply.github.com>
2023-10-06 17:14:33 +08:00

281 lines
11 KiB
Python

import torch
from infer.lib.rmvpe import STFT
from torch.nn.functional import conv1d, conv2d
from typing import Union, Optional
from .utils import linspace, temperature_sigmoid, amp_to_db


class TorchGate(torch.nn.Module):
    """
    A PyTorch module that applies a spectral gate to an input signal.

    Arguments:
        sr {int} -- Sample rate of the input signal.
        nonstationary {bool} -- Whether to use non-stationary or stationary masking (default: {False}).
        n_std_thresh_stationary {float} -- Number of standard deviations above the mean to threshold noise for
            stationary masking (default: {1.5}).
        n_thresh_nonstationary {float} -- Number of multiples above the smoothed magnitude spectrogram for
            non-stationary masking (default: {1.3}).
        temp_coeff_nonstationary {float} -- Temperature coefficient for non-stationary masking (default: {0.1}).
        n_movemean_nonstationary {int} -- Number of time frames for moving-average smoothing in non-stationary
            masking (default: {20}).
        prop_decrease {float} -- Proportion to decrease the signal by where the mask is zero (default: {1.0}).
        n_fft {int} -- Size of FFT for STFT (default: {1024}).
        win_length {[int]} -- Window length for STFT. If None, defaults to `n_fft` (default: {None}).
        hop_length {[int]} -- Hop length for STFT. If None, defaults to `win_length` // 4 (default: {None}).
        freq_mask_smooth_hz {float} -- Frequency smoothing width for the mask (in Hz). If None, no smoothing is
            applied along frequency (default: {500}).
        time_mask_smooth_ms {float} -- Time smoothing width for the mask (in ms). If None, no smoothing is
            applied along time (default: {50}).
    """

    @torch.no_grad()
    def __init__(
        self,
        sr: int,
        nonstationary: bool = False,
        n_std_thresh_stationary: float = 1.5,
        n_thresh_nonstationary: float = 1.3,
        temp_coeff_nonstationary: float = 0.1,
        n_movemean_nonstationary: int = 20,
        prop_decrease: float = 1.0,
        n_fft: int = 1024,
        win_length: Optional[int] = None,
        hop_length: Optional[int] = None,
        freq_mask_smooth_hz: Optional[float] = 500,
        time_mask_smooth_ms: Optional[float] = 50,
    ):
        super().__init__()

        # General Params
        self.sr = sr
        self.nonstationary = nonstationary
        assert 0.0 <= prop_decrease <= 1.0
        self.prop_decrease = prop_decrease

        # STFT Params
        self.n_fft = n_fft
        self.win_length = self.n_fft if win_length is None else win_length
        self.hop_length = self.win_length // 4 if hop_length is None else hop_length

        # Stationary Params
        self.n_std_thresh_stationary = n_std_thresh_stationary

        # Non-Stationary Params
        self.temp_coeff_nonstationary = temp_coeff_nonstationary
        self.n_movemean_nonstationary = n_movemean_nonstationary
        self.n_thresh_nonstationary = n_thresh_nonstationary

        # Smooth Mask Params
        self.freq_mask_smooth_hz = freq_mask_smooth_hz
        self.time_mask_smooth_ms = time_mask_smooth_ms
        self.register_buffer("smoothing_filter", self._generate_mask_smoothing_filter())

    @torch.no_grad()
    def _generate_mask_smoothing_filter(self) -> Union[torch.Tensor, None]:
        """
        Generates the 2D triangular filter used to smooth the signal mask over frequency and time.

        Returns:
            smoothing_filter (torch.Tensor): a 4D tensor of shape
                (1, 1, 2 * n_grad_freq + 1, 2 * n_grad_time + 1), normalized to sum to 1, where n_grad_freq is
                the number of frequency bins to smooth and n_grad_time is the number of time frames to smooth.
            Returns None if both self.freq_mask_smooth_hz and self.time_mask_smooth_ms are None, or if
            n_grad_freq and n_grad_time are both 1.
        """
        if self.freq_mask_smooth_hz is None and self.time_mask_smooth_ms is None:
            return None

        n_grad_freq = (
            1
            if self.freq_mask_smooth_hz is None
            else int(self.freq_mask_smooth_hz / (self.sr / (self.n_fft / 2)))
        )
        if n_grad_freq < 1:
            raise ValueError(
                f"freq_mask_smooth_hz needs to be at least {int((self.sr / (self.n_fft / 2)))} Hz"
            )

        n_grad_time = (
            1
            if self.time_mask_smooth_ms is None
            else int(self.time_mask_smooth_ms / ((self.hop_length / self.sr) * 1000))
        )
        if n_grad_time < 1:
            raise ValueError(
                f"time_mask_smooth_ms needs to be at least {int((self.hop_length / self.sr) * 1000)} ms"
            )

        if n_grad_time == 1 and n_grad_freq == 1:
            return None

        # Triangular windows of length 2 * n_grad + 1 over frequency and time
        v_f = torch.cat(
            [
                linspace(0, 1, n_grad_freq + 1, endpoint=False),
                linspace(1, 0, n_grad_freq + 2),
            ]
        )[1:-1]
        v_t = torch.cat(
            [
                linspace(0, 1, n_grad_time + 1, endpoint=False),
                linspace(1, 0, n_grad_time + 2),
            ]
        )[1:-1]
        # Outer product, reshaped to (out_channels=1, in_channels=1, H, W) for conv2d, normalized to sum to 1
        smoothing_filter = torch.outer(v_f, v_t).unsqueeze(0).unsqueeze(0)

        return smoothing_filter / smoothing_filter.sum()

    @torch.no_grad()
    def _stationary_mask(
        self, X_db: torch.Tensor, xn: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Computes a stationary binary mask to filter out noise in a log-magnitude spectrogram.

        Arguments:
            X_db (torch.Tensor): 3D tensor of shape (batch, freq_bins, frames) containing the log-magnitude
                spectrogram.
            xn (Optional[torch.Tensor]): Noise signal of shape (batch, signal_length) used to estimate the noise
                statistics. If None, X_db itself is used.

        Returns:
            sig_mask (torch.Tensor): Boolean mask of the same shape as X_db that is True where X_db exceeds the
                per-frequency noise threshold and False elsewhere.
        """
        if xn is not None:
            if "privateuseone" in str(xn.device):
                # torch-directml device: use the STFT module from infer.lib.rmvpe instead of torch.stft
                if not hasattr(self, "stft"):
                    self.stft = STFT(
                        filter_length=self.n_fft,
                        hop_length=self.hop_length,
                        win_length=self.win_length,
                        window="hann",
                    ).to(xn.device)
                XN = self.stft.transform(xn)
            else:
                XN = torch.stft(
                    xn,
                    n_fft=self.n_fft,
                    hop_length=self.hop_length,
                    win_length=self.win_length,
                    return_complex=True,
                    pad_mode="constant",
                    center=True,
                    window=torch.hann_window(self.win_length).to(xn.device),
                )
            XN_db = amp_to_db(XN).to(dtype=X_db.dtype)
        else:
            XN_db = X_db

        # For each frequency bin, compute the standard deviation and mean over time frames
        std_freq_noise, mean_freq_noise = torch.std_mean(XN_db, dim=-1)

        # Per-bin noise threshold: mean + n_std_thresh_stationary * std
        noise_thresh = mean_freq_noise + std_freq_noise * self.n_std_thresh_stationary

        # Binary mask: True where the spectrogram exceeds its bin's threshold
        sig_mask = X_db > noise_thresh.unsqueeze(2)
        return sig_mask

    @torch.no_grad()
    def _nonstationary_mask(self, X_abs: torch.Tensor) -> torch.Tensor:
        """
        Computes a soft non-stationary mask from a magnitude spectrogram.

        Arguments:
            X_abs (torch.Tensor): 3D tensor of shape (batch, freq_bins, frames) containing the magnitude
                spectrogram.

        Returns:
            sig_mask (torch.Tensor): Mask of the same shape as X_abs, obtained by passing the relative excess of
                X_abs over its moving average (along time) through a temperature sigmoid.
        """
        # Moving average of the magnitude along the time axis over n_movemean_nonstationary frames
        X_smoothed = (
            conv1d(
                X_abs.reshape(-1, 1, X_abs.shape[-1]),
                torch.ones(
                    self.n_movemean_nonstationary,
                    dtype=X_abs.dtype,
                    device=X_abs.device,
                ).view(1, 1, -1),
                padding="same",
            ).view(X_abs.shape)
            / self.n_movemean_nonstationary
        )

        # Compute slowness ratio and apply temperature sigmoid
        slowness_ratio = (X_abs - X_smoothed) / (X_smoothed + 1e-6)
        sig_mask = temperature_sigmoid(
            slowness_ratio, self.n_thresh_nonstationary, self.temp_coeff_nonstationary
        )

        return sig_mask

    def forward(
        self, x: torch.Tensor, xn: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Apply the spectral gate to the input signal.

        Arguments:
            x (torch.Tensor): The input audio signal, with shape (batch_size, signal_length).
            xn (Optional[torch.Tensor]): The noise signal used for stationary noise reduction. If `None`, the input
                signal is used as the noise signal. Default: `None`.

        Returns:
            torch.Tensor: The denoised audio signal, with shape (batch_size, signal_length) and the dtype of `x`.
                On the torch.istft path the output may be up to hop_length - 1 samples shorter than the input,
                because `length` is not passed to torch.istft.
        """
        # Compute short-time Fourier transform (STFT)
        if "privateuseone" in str(x.device):
            # torch-directml device: the STFT module from infer.lib.rmvpe returns magnitude and phase separately
            if not hasattr(self, "stft"):
                self.stft = STFT(
                    filter_length=self.n_fft,
                    hop_length=self.hop_length,
                    win_length=self.win_length,
                    window="hann",
                ).to(x.device)
            X, phase = self.stft.transform(x, return_phase=True)
        else:
            X = torch.stft(
                x,
                n_fft=self.n_fft,
                hop_length=self.hop_length,
                win_length=self.win_length,
                return_complex=True,
                pad_mode="constant",
                center=True,
                window=torch.hann_window(self.win_length).to(x.device),
            )

        # Compute signal mask based on stationary or nonstationary assumptions
        if self.nonstationary:
            sig_mask = self._nonstationary_mask(X.abs())
        else:
            sig_mask = self._stationary_mask(amp_to_db(X), xn)

        # Scale by prop_decrease: where the mask is 0 the gain becomes 1 - prop_decrease, where it is 1 it stays 1
        sig_mask = self.prop_decrease * (sig_mask.float() - 1.0) + 1.0

        # Smooth signal mask with 2D convolution
        if self.smoothing_filter is not None:
            sig_mask = conv2d(
                sig_mask.unsqueeze(1),
                self.smoothing_filter.to(sig_mask.dtype),
                padding="same",
            )

        # Apply the mask to the STFT: magnitudes are scaled, phase is left unchanged
        Y = X * sig_mask.squeeze(1)

        # Inverse STFT to obtain time-domain signal
        if "privateuseone" in str(Y.device):
            y = self.stft.inverse(Y, phase)
        else:
            y = torch.istft(
                Y,
                n_fft=self.n_fft,
                hop_length=self.hop_length,
                win_length=self.win_length,
                center=True,
                window=torch.hann_window(self.win_length).to(Y.device),
            )

        return y.to(dtype=x.dtype)
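`_generate_mask_smoothing_filter` converts `freq_mask_smooth_hz` and `time_mask_smooth_ms` into half-widths `n_grad_freq` (bins) and `n_grad_time` (frames), then builds the kernel as the outer product of two triangular windows normalized to sum to 1. The plain-Python sketch below mirrors that arithmetic so the kernel can be inspected without PyTorch. `triangular_window` and `smoothing_kernel` are hypothetical helpers; the window assumes the project's `linspace(..., endpoint=False)` helper behaves like `numpy.linspace`, and the `ValueError` checks and `None` returns are omitted.

```python
from __future__ import annotations


def triangular_window(n_grad: int) -> list[float]:
    """Window of length 2 * n_grad + 1 that peaks at 1.0 in the centre.

    Assumed equivalent of
    cat(linspace(0, 1, n_grad + 1, endpoint=False), linspace(1, 0, n_grad + 2))[1:-1].
    """
    return [1.0 - abs(k - n_grad) / (n_grad + 1) for k in range(2 * n_grad + 1)]


def smoothing_kernel(
    sr: int,
    n_fft: int,
    hop_length: int,
    freq_mask_smooth_hz: float,
    time_mask_smooth_ms: float,
) -> list[list[float]]:
    """Hypothetical helper: mask-smoothing kernel (rows = frequency, columns = time)."""
    # Same width formulas as TorchGate._generate_mask_smoothing_filter
    n_grad_freq = int(freq_mask_smooth_hz / (sr / (n_fft / 2)))
    n_grad_time = int(time_mask_smooth_ms / ((hop_length / sr) * 1000))
    v_f = triangular_window(n_grad_freq)
    v_t = triangular_window(n_grad_time)
    # Outer product of the two windows, then normalize so the weights sum to 1
    kernel = [[a * b for b in v_t] for a in v_f]
    total = sum(sum(row) for row in kernel)
    return [[w / total for w in row] for row in kernel]


kernel = smoothing_kernel(
    sr=16000, n_fft=1024, hop_length=256, freq_mask_smooth_hz=500, time_mask_smooth_ms=50
)
print(len(kernel), len(kernel[0]))  # 33 7
```

With `sr=16000`, `n_fft=1024` and `hop_length=256`, the default 500 Hz / 50 ms widths give `n_grad_freq = 16` and `n_grad_time = 3`, i.e. a 33 x 7 kernel whose largest weight sits at the centre.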
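In the stationary branch each frequency bin gets its own threshold: the mean of the noise spectrogram (in dB) over time plus `n_std_thresh_stationary` standard deviations, and the mask is True wherever the signal spectrogram exceeds it. A minimal sketch for a single batch item, using nested lists (rows are frequency bins, columns are time frames) in place of tensors; `stationary_mask` here is a hypothetical stand-alone function, not the class method.

```python
from __future__ import annotations

import statistics


def stationary_mask(
    X_db: list[list[float]],
    XN_db: list[list[float]],
    n_std_thresh_stationary: float = 1.5,
) -> list[list[bool]]:
    mask = []
    for row, noise_row in zip(X_db, XN_db):
        # Per-bin noise statistics over time; stdev is the sample std (like torch.std_mean's default)
        mean = statistics.mean(noise_row)
        std = statistics.stdev(noise_row)
        thresh = mean + n_std_thresh_stationary * std
        # True where the signal exceeds this bin's threshold
        mask.append([v > thresh for v in row])
    return mask


noise_db = [[-60, -58, -62, -60], [-70, -70, -70, -70]]
signal_db = [[-59, -20, -61, -30], [-70, -69.5, -80, -40]]
print(stationary_mask(signal_db, noise_db))
# [[False, True, False, True], [False, True, False, True]]
```

The first bin's threshold is about -57.55 dB (mean -60, std about 1.63); the second bin's noise is constant, so its threshold is exactly -70 dB.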
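The non-stationary branch compares each magnitude with its moving average along time and squashes the relative excess through a temperature sigmoid; `forward()` then rescales the mask with `prop_decrease`. The sketch below follows one frequency bin of one batch item. It assumes that `temperature_sigmoid(x, x0, t)` in `.utils` computes `sigmoid((x - x0) / t)` (that module is not shown here), uses an odd window so the zero padding of `conv1d(..., padding="same")` is symmetric, and skips the 2D mask smoothing. All function names are hypothetical.

```python
from __future__ import annotations

import math


def moving_average_same(row: list[float], n: int) -> list[float]:
    """Zero-padded moving average, like conv1d(ones(n), padding="same") / n for odd n."""
    half = n // 2
    return [sum(row[max(0, t - half): t + half + 1]) / n for t in range(len(row))]


def temperature_sigmoid(x: float, x0: float, temp_coeff: float) -> float:
    # ASSUMED form of the project's helper: sigmoid((x - x0) / temp_coeff)
    return 1.0 / (1.0 + math.exp(-(x - x0) / temp_coeff))


def nonstationary_gain(
    X_abs_row: list[float],
    n_thresh: float = 1.3,
    temp_coeff: float = 0.1,
    n_movemean: int = 3,
    prop_decrease: float = 1.0,
) -> list[float]:
    smoothed = moving_average_same(X_abs_row, n_movemean)
    gains = []
    for x, s in zip(X_abs_row, smoothed):
        ratio = (x - s) / (s + 1e-6)  # relative excess over the moving average
        m = temperature_sigmoid(ratio, n_thresh, temp_coeff)
        # prop_decrease scaling as in forward(): m = 0 -> 1 - prop_decrease, m = 1 -> 1
        gains.append(prop_decrease * (m - 1.0) + 1.0)
    return gains


row = [1.0, 1.0, 10.0, 1.0, 1.0]
gains = nonstationary_gain(row)
half = nonstationary_gain(row, prop_decrease=0.5)
```

Only the spike at index 2 keeps a gain near 1 (about 0.88); the other frames are driven close to 0. With `prop_decrease=0.5` every gain stays between 0.5 and 1, so suppressed bins are attenuated by at most half.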
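`forward()` calls `torch.istft` without a `length` argument. For an even `n_fft` (the default 1024) and `center=True`, `torch.stft` pads `n_fft // 2` samples on each side and `torch.istft` trims the same amount back off, which gives the arithmetic below. `stft_num_frames` and `istft_output_length` are hypothetical helpers written for illustration.

```python
def stft_num_frames(signal_length: int, hop_length: int) -> int:
    # torch.stft(center=True), even n_fft: padded length is L + n_fft, so 1 + L // hop frames
    return 1 + signal_length // hop_length


def istft_output_length(signal_length: int, hop_length: int) -> int:
    # torch.istft(center=True, length=None), even n_fft: hop_length * (n_frames - 1) samples
    return hop_length * (stft_num_frames(signal_length, hop_length) - 1)


print(istft_output_length(16000, 256))  # 15872 (128 samples short)
print(istft_output_length(16384, 256))  # 16384
```

The output therefore matches the input length only when `signal_length` is a multiple of `hop_length`; passing `length=x.shape[-1]` to `torch.istft` is one way to make the lengths agree if a caller needs that.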