mirror of
https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI.git
synced 2025-05-16 00:39:07 +08:00
Streaming noise reduction and other optimizations for real-time gui (#1188)
* loudness factor control and gpu-accelerated noise reduction * loudness factor control and gpu-accelerated noise reduction * loudness factor control and gpu-accelerated noise reduction * streaming noise reduction and other optimizations * streaming noise reduction and other optimizations
This commit is contained in:
parent
b09b6ad05c
commit
a669fee786
161
gui_v1.py
161
gui_v1.py
@ -5,7 +5,7 @@ from dotenv import load_dotenv
|
|||||||
|
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
|
|
||||||
os.environ["OMP_NUM_THREADS"] = "2"
|
os.environ["OMP_NUM_THREADS"] = "4"
|
||||||
if sys.platform == "darwin":
|
if sys.platform == "darwin":
|
||||||
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
|
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
|
||||||
|
|
||||||
@ -481,49 +481,21 @@ if __name__ == "__main__":
|
|||||||
self.rvc if hasattr(self, "rvc") else None
|
self.rvc if hasattr(self, "rvc") else None
|
||||||
)
|
)
|
||||||
self.config.samplerate = self.rvc.tgt_sr
|
self.config.samplerate = self.rvc.tgt_sr
|
||||||
self.config.crossfade_time = min(
|
|
||||||
self.config.crossfade_time, self.config.block_time
|
|
||||||
)
|
|
||||||
self.zc = self.rvc.tgt_sr // 100
|
self.zc = self.rvc.tgt_sr // 100
|
||||||
self.block_frame = int(np.round(self.config.block_time * self.config.samplerate / self.zc)) * self.zc
|
self.block_frame = int(np.round(self.config.block_time * self.config.samplerate / self.zc)) * self.zc
|
||||||
self.block_frame_16k = 160 * self.block_frame // self.zc
|
self.block_frame_16k = 160 * self.block_frame // self.zc
|
||||||
self.crossfade_frame = int(
|
self.crossfade_frame = int(np.round(self.config.crossfade_time * self.config.samplerate / self.zc)) * self.zc
|
||||||
self.config.crossfade_time * self.config.samplerate
|
self.sola_search_frame = self.zc
|
||||||
)
|
self.extra_frame = int(np.round(self.config.extra_time * self.config.samplerate / self.zc)) * self.zc
|
||||||
self.sola_search_frame = int(0.01 * self.config.samplerate)
|
self.input_wav: torch.Tensor = torch.zeros(
|
||||||
self.extra_frame = int(self.config.extra_time * self.config.samplerate)
|
|
||||||
self.input_wav: np.ndarray = np.zeros(
|
|
||||||
int(
|
|
||||||
np.ceil(
|
|
||||||
(
|
|
||||||
self.extra_frame
|
self.extra_frame
|
||||||
+ self.crossfade_frame
|
+ self.crossfade_frame
|
||||||
+ self.sola_search_frame
|
+ self.sola_search_frame
|
||||||
+ self.block_frame
|
+ self.block_frame,
|
||||||
)
|
|
||||||
/ self.zc
|
|
||||||
)
|
|
||||||
* self.zc
|
|
||||||
),
|
|
||||||
dtype="float32",
|
|
||||||
)
|
|
||||||
self.input_wav_res: torch.Tensor= torch.zeros(160 * len(self.input_wav) // self.zc, device=device,dtype=torch.float32)
|
|
||||||
self.output_wav_cache: torch.Tensor = torch.zeros(
|
|
||||||
int(
|
|
||||||
np.ceil(
|
|
||||||
(
|
|
||||||
self.extra_frame
|
|
||||||
+ self.crossfade_frame
|
|
||||||
+ self.sola_search_frame
|
|
||||||
+ self.block_frame
|
|
||||||
)
|
|
||||||
/ self.zc
|
|
||||||
)
|
|
||||||
* self.zc
|
|
||||||
),
|
|
||||||
device=device,
|
device=device,
|
||||||
dtype=torch.float32,
|
dtype=torch.float32,
|
||||||
)
|
)
|
||||||
|
self.input_wav_res: torch.Tensor= torch.zeros(160 * self.input_wav.shape[0] // self.zc, device=device,dtype=torch.float32)
|
||||||
self.pitch: np.ndarray = np.zeros(
|
self.pitch: np.ndarray = np.zeros(
|
||||||
self.input_wav.shape[0] // self.zc,
|
self.input_wav.shape[0] // self.zc,
|
||||||
dtype="int32",
|
dtype="int32",
|
||||||
@ -532,12 +504,13 @@ if __name__ == "__main__":
|
|||||||
self.input_wav.shape[0] // self.zc,
|
self.input_wav.shape[0] // self.zc,
|
||||||
dtype="float64",
|
dtype="float64",
|
||||||
)
|
)
|
||||||
self.output_wav: torch.Tensor = torch.zeros(
|
|
||||||
self.block_frame, device=device, dtype=torch.float32
|
|
||||||
)
|
|
||||||
self.sola_buffer: torch.Tensor = torch.zeros(
|
self.sola_buffer: torch.Tensor = torch.zeros(
|
||||||
self.crossfade_frame, device=device, dtype=torch.float32
|
self.crossfade_frame, device=device, dtype=torch.float32
|
||||||
)
|
)
|
||||||
|
self.nr_buffer: torch.Tensor = self.sola_buffer.clone()
|
||||||
|
self.output_buffer: torch.Tensor = self.input_wav.clone()
|
||||||
|
self.res_buffer: torch.Tensor = torch.zeros(2 * self.zc, device=device,dtype=torch.float32)
|
||||||
|
self.valid_rate = 1 - (self.extra_frame - 1) / self.input_wav.shape[0]
|
||||||
self.fade_in_window: torch.Tensor = (
|
self.fade_in_window: torch.Tensor = (
|
||||||
torch.sin(
|
torch.sin(
|
||||||
0.5
|
0.5
|
||||||
@ -556,8 +529,7 @@ if __name__ == "__main__":
|
|||||||
self.resampler = tat.Resample(
|
self.resampler = tat.Resample(
|
||||||
orig_freq=self.config.samplerate, new_freq=16000, dtype=torch.float32
|
orig_freq=self.config.samplerate, new_freq=16000, dtype=torch.float32
|
||||||
).to(device)
|
).to(device)
|
||||||
self.input_tg = TorchGate(sr=16000, nonstationary=True, n_fft=640).to(device)
|
self.tg = TorchGate(sr=self.config.samplerate, n_fft=4*self.zc, prop_decrease=0.9).to(device)
|
||||||
self.output_tg = TorchGate(sr=self.config.samplerate, nonstationary=True, n_fft=4*self.zc).to(device)
|
|
||||||
thread_vc = threading.Thread(target=self.soundinput)
|
thread_vc = threading.Thread(target=self.soundinput)
|
||||||
thread_vc.start()
|
thread_vc.start()
|
||||||
|
|
||||||
@ -586,114 +558,91 @@ if __name__ == "__main__":
|
|||||||
"""
|
"""
|
||||||
start_time = time.perf_counter()
|
start_time = time.perf_counter()
|
||||||
indata = librosa.to_mono(indata.T)
|
indata = librosa.to_mono(indata.T)
|
||||||
frame_length = 2048
|
|
||||||
hop_length = 1024
|
|
||||||
rms = librosa.feature.rms(
|
|
||||||
y=indata, frame_length=frame_length, hop_length=hop_length
|
|
||||||
)
|
|
||||||
if self.config.threhold > -60:
|
if self.config.threhold > -60:
|
||||||
|
rms = librosa.feature.rms(
|
||||||
|
y=indata, frame_length=4*self.zc, hop_length=self.zc
|
||||||
|
)
|
||||||
db_threhold = (
|
db_threhold = (
|
||||||
librosa.amplitude_to_db(rms, ref=1.0)[0] < self.config.threhold
|
librosa.amplitude_to_db(rms, ref=1.0)[0] < self.config.threhold
|
||||||
)
|
)
|
||||||
for i in range(db_threhold.shape[0]):
|
for i in range(db_threhold.shape[0]):
|
||||||
if db_threhold[i]:
|
if db_threhold[i]:
|
||||||
indata[i * hop_length : (i + 1) * hop_length] = 0
|
indata[i * self.zc : (i + 1) * self.zc] = 0
|
||||||
self.input_wav[: -self.block_frame] = self.input_wav[self.block_frame :]
|
self.input_wav[: -self.block_frame] = self.input_wav[self.block_frame :].clone()
|
||||||
self.input_wav[-self.block_frame: ] = indata
|
self.input_wav[-self.block_frame: ] = torch.from_numpy(indata).to(device)
|
||||||
# infer
|
|
||||||
inp = torch.from_numpy(self.input_wav[-self.block_frame-2*self.zc :]).to(device)
|
|
||||||
self.input_wav_res[ : -self.block_frame_16k] = self.input_wav_res[self.block_frame_16k :].clone()
|
self.input_wav_res[ : -self.block_frame_16k] = self.input_wav_res[self.block_frame_16k :].clone()
|
||||||
self.input_wav_res[-self.block_frame_16k-160 :] = self.resampler(inp)[160 :]
|
# input noise reduction and resampling
|
||||||
if self.config.I_noise_reduce:
|
if self.config.I_noise_reduce:
|
||||||
self.input_wav_res[-self.block_frame_16k-320 :] = self.input_tg(self.input_wav_res[None, -self.block_frame_16k-800 :])[0, 480 : ]
|
input_wav = self.input_wav[-self.crossfade_frame -self.block_frame-2*self.zc: ]
|
||||||
rate = (
|
input_wav = self.tg(input_wav.unsqueeze(0), self.input_wav.unsqueeze(0))[0, 2*self.zc:]
|
||||||
self.crossfade_frame + self.sola_search_frame + self.block_frame
|
input_wav[: self.crossfade_frame] *= self.fade_in_window
|
||||||
) / (
|
input_wav[: self.crossfade_frame] += self.nr_buffer * self.fade_out_window
|
||||||
self.extra_frame
|
self.nr_buffer[:] = input_wav[-self.crossfade_frame: ]
|
||||||
+ self.crossfade_frame
|
input_wav = torch.cat((self.res_buffer[:], input_wav[: self.block_frame]))
|
||||||
+ self.sola_search_frame
|
self.res_buffer[:] = input_wav[-2*self.zc: ]
|
||||||
+ self.block_frame
|
self.input_wav_res[-self.block_frame_16k-160: ] = self.resampler(input_wav)[160: ]
|
||||||
)
|
else:
|
||||||
|
self.input_wav_res[-self.block_frame_16k-160: ] = self.resampler(self.input_wav[-self.block_frame-2*self.zc: ])[160: ]
|
||||||
|
# infer
|
||||||
f0_extractor_frame = self.block_frame_16k + 800
|
f0_extractor_frame = self.block_frame_16k + 800
|
||||||
if self.config.f0method == 'rmvpe':
|
if self.config.f0method == 'rmvpe':
|
||||||
f0_extractor_frame = 5120 * ((f0_extractor_frame - 1) // 5120 + 1)
|
f0_extractor_frame = 5120 * ((f0_extractor_frame - 1) // 5120 + 1)
|
||||||
res2 = self.rvc.infer(
|
infer_wav = self.rvc.infer(
|
||||||
self.input_wav_res,
|
self.input_wav_res,
|
||||||
self.input_wav_res[-f0_extractor_frame :].cpu().numpy(),
|
self.input_wav_res[-f0_extractor_frame :].cpu().numpy(),
|
||||||
self.block_frame_16k,
|
self.block_frame_16k,
|
||||||
rate,
|
self.valid_rate,
|
||||||
self.pitch,
|
self.pitch,
|
||||||
self.pitchf,
|
self.pitchf,
|
||||||
self.config.f0method,
|
self.config.f0method,
|
||||||
)
|
)
|
||||||
self.output_wav_cache[-res2.shape[0] :] = res2
|
infer_wav = infer_wav[
|
||||||
infer_wav = self.output_wav_cache[
|
|
||||||
-self.crossfade_frame - self.sola_search_frame - self.block_frame :
|
-self.crossfade_frame - self.sola_search_frame - self.block_frame :
|
||||||
]
|
]
|
||||||
|
# output noise reduction
|
||||||
if self.config.O_noise_reduce:
|
if self.config.O_noise_reduce:
|
||||||
infer_wav = self.output_tg(infer_wav.unsqueeze(0)).squeeze(0)
|
self.output_buffer[: -self.block_frame] = self.output_buffer[self.block_frame :].clone()
|
||||||
|
self.output_buffer[-self.block_frame: ] = infer_wav[-self.block_frame:]
|
||||||
|
infer_wav = self.tg(infer_wav.unsqueeze(0), self.output_buffer.unsqueeze(0)).squeeze(0)
|
||||||
|
# volume envelop mixing
|
||||||
if self.config.rms_mix_rate < 1:
|
if self.config.rms_mix_rate < 1:
|
||||||
rms1 = librosa.feature.rms(
|
rms1 = librosa.feature.rms(
|
||||||
y=self.input_wav[-self.crossfade_frame - self.sola_search_frame - self.block_frame :],
|
y=self.input_wav_res[-160*infer_wav.shape[0]//self.zc :].cpu().numpy(),
|
||||||
frame_length=frame_length,
|
frame_length=640,
|
||||||
hop_length=hop_length
|
hop_length=160,
|
||||||
)
|
)
|
||||||
rms1 = torch.from_numpy(rms1).to(device)
|
rms1 = torch.from_numpy(rms1).to(device)
|
||||||
rms1 = F.interpolate(
|
rms1 = F.interpolate(
|
||||||
rms1.unsqueeze(0), size=infer_wav.shape[0], mode="linear"
|
rms1.unsqueeze(0), size=infer_wav.shape[0] + 1, mode="linear",align_corners=True,
|
||||||
).squeeze()
|
)[0,0,:-1]
|
||||||
rms2 = librosa.feature.rms(
|
rms2 = librosa.feature.rms(
|
||||||
y=infer_wav[:].cpu().numpy(), frame_length=frame_length, hop_length=hop_length
|
y=infer_wav[:].cpu().numpy(), frame_length=4*self.zc, hop_length=self.zc
|
||||||
)
|
)
|
||||||
rms2 = torch.from_numpy(rms2).to(device)
|
rms2 = torch.from_numpy(rms2).to(device)
|
||||||
rms2 = F.interpolate(
|
rms2 = F.interpolate(
|
||||||
rms2.unsqueeze(0), size=infer_wav.shape[0], mode="linear"
|
rms2.unsqueeze(0), size=infer_wav.shape[0] + 1, mode="linear",align_corners=True,
|
||||||
).squeeze()
|
)[0,0,:-1]
|
||||||
rms2 = torch.max(rms2, torch.zeros_like(rms2) + 1e-3)
|
rms2 = torch.max(rms2, torch.zeros_like(rms2) + 1e-3)
|
||||||
infer_wav *= torch.pow(rms1 / rms2, torch.tensor(1 - self.config.rms_mix_rate))
|
infer_wav *= torch.pow(rms1 / rms2, torch.tensor(1 - self.config.rms_mix_rate))
|
||||||
# SOLA algorithm from https://github.com/yxlllc/DDSP-SVC
|
# SOLA algorithm from https://github.com/yxlllc/DDSP-SVC
|
||||||
cor_nom = F.conv1d(
|
conv_input = infer_wav[None, None, : self.crossfade_frame + self.sola_search_frame]
|
||||||
infer_wav[None, None, : self.crossfade_frame + self.sola_search_frame],
|
cor_nom = F.conv1d(conv_input, self.sola_buffer[None, None, :])
|
||||||
self.sola_buffer[None, None, :],
|
|
||||||
)
|
|
||||||
cor_den = torch.sqrt(
|
cor_den = torch.sqrt(
|
||||||
F.conv1d(
|
F.conv1d(conv_input ** 2, torch.ones(1, 1, self.crossfade_frame, device=device)) + 1e-8)
|
||||||
infer_wav[
|
|
||||||
None, None, : self.crossfade_frame + self.sola_search_frame
|
|
||||||
]
|
|
||||||
** 2,
|
|
||||||
torch.ones(1, 1, self.crossfade_frame, device=device),
|
|
||||||
)
|
|
||||||
+ 1e-8
|
|
||||||
)
|
|
||||||
if sys.platform == "darwin":
|
if sys.platform == "darwin":
|
||||||
_, sola_offset = torch.max(cor_nom[0, 0] / cor_den[0, 0])
|
_, sola_offset = torch.max(cor_nom[0, 0] / cor_den[0, 0])
|
||||||
sola_offset = sola_offset.item()
|
sola_offset = sola_offset.item()
|
||||||
else:
|
else:
|
||||||
sola_offset = torch.argmax(cor_nom[0, 0] / cor_den[0, 0])
|
sola_offset = torch.argmax(cor_nom[0, 0] / cor_den[0, 0])
|
||||||
logger.debug("sola_offset = %d", int(sola_offset))
|
logger.debug("sola_offset = %d", int(sola_offset))
|
||||||
self.output_wav[:] = infer_wav[sola_offset : sola_offset + self.block_frame]
|
infer_wav = infer_wav[sola_offset: sola_offset + self.block_frame + self.crossfade_frame]
|
||||||
self.output_wav[: self.crossfade_frame] *= self.fade_in_window
|
infer_wav[: self.crossfade_frame] *= self.fade_in_window
|
||||||
self.output_wav[: self.crossfade_frame] += self.sola_buffer[:]
|
infer_wav[: self.crossfade_frame] += self.sola_buffer *self.fade_out_window
|
||||||
# crossfade
|
self.sola_buffer[:] = infer_wav[-self.crossfade_frame:]
|
||||||
if sola_offset < self.sola_search_frame:
|
|
||||||
self.sola_buffer[:] = (
|
|
||||||
infer_wav[
|
|
||||||
-self.sola_search_frame
|
|
||||||
- self.crossfade_frame
|
|
||||||
+ sola_offset : -self.sola_search_frame
|
|
||||||
+ sola_offset
|
|
||||||
]
|
|
||||||
* self.fade_out_window
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
self.sola_buffer[:] = (
|
|
||||||
infer_wav[-self.crossfade_frame :] * self.fade_out_window
|
|
||||||
)
|
|
||||||
if sys.platform == "darwin":
|
if sys.platform == "darwin":
|
||||||
outdata[:] = self.output_wav[:].cpu().numpy()[:, np.newaxis]
|
outdata[:] = infer_wav[:-self.crossfade_frame].cpu().numpy()[:, np.newaxis]
|
||||||
else:
|
else:
|
||||||
outdata[:] = self.output_wav[:].repeat(2, 1).t().cpu().numpy()
|
outdata[:] = infer_wav[:-self.crossfade_frame].repeat(2, 1).t().cpu().numpy()
|
||||||
total_time = time.perf_counter() - start_time
|
total_time = time.perf_counter() - start_time
|
||||||
self.window["infer_time"].update(int(total_time * 1000))
|
self.window["infer_time"].update(int(total_time * 1000))
|
||||||
logger.info("Infer time: %.2f", total_time)
|
logger.info("Infer time: %.2f", total_time)
|
||||||
|
@ -91,7 +91,7 @@ class RVC:
|
|||||||
suffix="",
|
suffix="",
|
||||||
)
|
)
|
||||||
hubert_model = models[0]
|
hubert_model = models[0]
|
||||||
hubert_model = hubert_model.to(config.device)
|
hubert_model = hubert_model.to(device)
|
||||||
if config.is_half:
|
if config.is_half:
|
||||||
hubert_model = hubert_model.half()
|
hubert_model = hubert_model.half()
|
||||||
else:
|
else:
|
||||||
@ -309,6 +309,7 @@ class RVC:
|
|||||||
feats = (
|
feats = (
|
||||||
self.model.final_proj(logits[0]) if self.version == "v1" else logits[0]
|
self.model.final_proj(logits[0]) if self.version == "v1" else logits[0]
|
||||||
)
|
)
|
||||||
|
feats = F.pad(feats, (0, 0, 1, 0))
|
||||||
t2 = ttime()
|
t2 = ttime()
|
||||||
try:
|
try:
|
||||||
if hasattr(self, "index") and self.index_rate != 0:
|
if hasattr(self, "index") and self.index_rate != 0:
|
||||||
@ -360,13 +361,13 @@ class RVC:
|
|||||||
self.net_g.infer(
|
self.net_g.infer(
|
||||||
feats, p_len, cache_pitch, cache_pitchf, sid, rate
|
feats, p_len, cache_pitch, cache_pitchf, sid, rate
|
||||||
)[0][0, 0]
|
)[0][0, 0]
|
||||||
.data.cpu()
|
.data
|
||||||
.float()
|
.float()
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
infered_audio = (
|
infered_audio = (
|
||||||
self.net_g.infer(feats, p_len, sid, rate)[0][0, 0]
|
self.net_g.infer(feats, p_len, sid, rate)[0][0, 0]
|
||||||
.data.cpu()
|
.data
|
||||||
.float()
|
.float()
|
||||||
)
|
)
|
||||||
t5 = ttime()
|
t5 = ttime()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user