mirror of
https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI.git
synced 2024-12-29 19:15:04 +08:00
optimize real-time vc
This commit is contained in:
parent
d62e80fb83
commit
d7fb651f7c
20
gui_v1.py
20
gui_v1.py
@ -681,14 +681,6 @@ if __name__ == "__main__":
|
|||||||
device=self.config.device,
|
device=self.config.device,
|
||||||
dtype=torch.float32,
|
dtype=torch.float32,
|
||||||
)
|
)
|
||||||
self.pitch: np.ndarray = np.zeros(
|
|
||||||
self.input_wav.shape[0] // self.zc,
|
|
||||||
dtype="int32",
|
|
||||||
)
|
|
||||||
self.pitchf: np.ndarray = np.zeros(
|
|
||||||
self.input_wav.shape[0] // self.zc,
|
|
||||||
dtype="float64",
|
|
||||||
)
|
|
||||||
self.sola_buffer: torch.Tensor = torch.zeros(
|
self.sola_buffer: torch.Tensor = torch.zeros(
|
||||||
self.sola_buffer_frame, device=self.config.device, dtype=torch.float32
|
self.sola_buffer_frame, device=self.config.device, dtype=torch.float32
|
||||||
)
|
)
|
||||||
@ -698,6 +690,7 @@ if __name__ == "__main__":
|
|||||||
2 * self.zc, device=self.config.device, dtype=torch.float32
|
2 * self.zc, device=self.config.device, dtype=torch.float32
|
||||||
)
|
)
|
||||||
self.skip_head = self.extra_frame // self.zc
|
self.skip_head = self.extra_frame // self.zc
|
||||||
|
self.return_length = (self.block_frame + self.sola_buffer_frame + self.sola_search_frame) // self.zc
|
||||||
self.fade_in_window: torch.Tensor = (
|
self.fade_in_window: torch.Tensor = (
|
||||||
torch.sin(
|
torch.sin(
|
||||||
0.5
|
0.5
|
||||||
@ -808,8 +801,7 @@ if __name__ == "__main__":
|
|||||||
self.input_wav_res,
|
self.input_wav_res,
|
||||||
self.block_frame_16k,
|
self.block_frame_16k,
|
||||||
self.skip_head,
|
self.skip_head,
|
||||||
self.pitch,
|
self.return_length,
|
||||||
self.pitchf,
|
|
||||||
self.gui_config.f0method,
|
self.gui_config.f0method,
|
||||||
)
|
)
|
||||||
if self.resampler2 is not None:
|
if self.resampler2 is not None:
|
||||||
@ -879,9 +871,7 @@ if __name__ == "__main__":
|
|||||||
else:
|
else:
|
||||||
sola_offset = torch.argmax(cor_nom[0, 0] / cor_den[0, 0])
|
sola_offset = torch.argmax(cor_nom[0, 0] / cor_den[0, 0])
|
||||||
printt("sola_offset = %d", int(sola_offset))
|
printt("sola_offset = %d", int(sola_offset))
|
||||||
infer_wav = infer_wav[
|
infer_wav = infer_wav[sola_offset :]
|
||||||
sola_offset : sola_offset + self.block_frame + self.crossfade_frame
|
|
||||||
]
|
|
||||||
if "privateuseone" in str(self.config.device) or not self.gui_config.use_pv:
|
if "privateuseone" in str(self.config.device) or not self.gui_config.use_pv:
|
||||||
infer_wav[: self.sola_buffer_frame] *= self.fade_in_window
|
infer_wav[: self.sola_buffer_frame] *= self.fade_in_window
|
||||||
infer_wav[: self.sola_buffer_frame] += self.sola_buffer * self.fade_out_window
|
infer_wav[: self.sola_buffer_frame] += self.sola_buffer * self.fade_out_window
|
||||||
@ -894,11 +884,11 @@ if __name__ == "__main__":
|
|||||||
self.sola_buffer[:] = infer_wav[self.block_frame : self.block_frame + self.sola_buffer_frame]
|
self.sola_buffer[:] = infer_wav[self.block_frame : self.block_frame + self.sola_buffer_frame]
|
||||||
if sys.platform == "darwin":
|
if sys.platform == "darwin":
|
||||||
outdata[:] = (
|
outdata[:] = (
|
||||||
infer_wav[: -self.crossfade_frame].cpu().numpy()[:, np.newaxis]
|
infer_wav[: self.block_frame].cpu().numpy()[:, np.newaxis]
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
outdata[:] = (
|
outdata[:] = (
|
||||||
infer_wav[: -self.crossfade_frame].repeat(2, 1).t().cpu().numpy()
|
infer_wav[: self.block_frame].repeat(2, 1).t().cpu().numpy()
|
||||||
)
|
)
|
||||||
total_time = time.perf_counter() - start_time
|
total_time = time.perf_counter() - start_time
|
||||||
self.window["infer_time"].update(int(total_time * 1000))
|
self.window["infer_time"].update(int(total_time * 1000))
|
||||||
|
@ -785,16 +785,19 @@ class SynthesizerTrnMs256NSFsid(nn.Module):
|
|||||||
nsff0: torch.Tensor,
|
nsff0: torch.Tensor,
|
||||||
sid: torch.Tensor,
|
sid: torch.Tensor,
|
||||||
skip_head: Optional[torch.Tensor] = None,
|
skip_head: Optional[torch.Tensor] = None,
|
||||||
|
return_length: Optional[torch.Tensor] = None,
|
||||||
):
|
):
|
||||||
g = self.emb_g(sid).unsqueeze(-1)
|
g = self.emb_g(sid).unsqueeze(-1)
|
||||||
m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths)
|
m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths)
|
||||||
z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask
|
z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask
|
||||||
if skip_head is not None:
|
if skip_head is not None and return_length is not None:
|
||||||
assert isinstance(skip_head, torch.Tensor)
|
assert isinstance(skip_head, torch.Tensor)
|
||||||
|
assert isinstance(return_length, torch.Tensor)
|
||||||
head = int(skip_head.item())
|
head = int(skip_head.item())
|
||||||
z_p = z_p[:, :, head:]
|
length = int(return_length.item())
|
||||||
x_mask = x_mask[:, :, head:]
|
z_p = z_p[:, :, head: head + length]
|
||||||
nsff0 = nsff0[:, head:]
|
x_mask = x_mask[:, :, head: head + length]
|
||||||
|
nsff0 = nsff0[:, head: head + length]
|
||||||
z = self.flow(z_p, x_mask, g=g, reverse=True)
|
z = self.flow(z_p, x_mask, g=g, reverse=True)
|
||||||
o = self.dec(z * x_mask, nsff0, g=g)
|
o = self.dec(z * x_mask, nsff0, g=g)
|
||||||
return o, x_mask, (z, z_p, m_p, logs_p)
|
return o, x_mask, (z, z_p, m_p, logs_p)
|
||||||
@ -944,16 +947,19 @@ class SynthesizerTrnMs768NSFsid(nn.Module):
|
|||||||
nsff0: torch.Tensor,
|
nsff0: torch.Tensor,
|
||||||
sid: torch.Tensor,
|
sid: torch.Tensor,
|
||||||
skip_head: Optional[torch.Tensor] = None,
|
skip_head: Optional[torch.Tensor] = None,
|
||||||
|
return_length: Optional[torch.Tensor] = None,
|
||||||
):
|
):
|
||||||
g = self.emb_g(sid).unsqueeze(-1)
|
g = self.emb_g(sid).unsqueeze(-1)
|
||||||
m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths)
|
m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths)
|
||||||
z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask
|
z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask
|
||||||
if skip_head is not None:
|
if skip_head is not None and return_length is not None:
|
||||||
assert isinstance(skip_head, torch.Tensor)
|
assert isinstance(skip_head, torch.Tensor)
|
||||||
|
assert isinstance(return_length, torch.Tensor)
|
||||||
head = int(skip_head.item())
|
head = int(skip_head.item())
|
||||||
z_p = z_p[:, :, head:]
|
length = int(return_length.item())
|
||||||
x_mask = x_mask[:, :, head:]
|
z_p = z_p[:, :, head: head + length]
|
||||||
nsff0 = nsff0[:, head:]
|
x_mask = x_mask[:, :, head: head + length]
|
||||||
|
nsff0 = nsff0[:, head: head + length]
|
||||||
z = self.flow(z_p, x_mask, g=g, reverse=True)
|
z = self.flow(z_p, x_mask, g=g, reverse=True)
|
||||||
o = self.dec(z * x_mask, nsff0, g=g)
|
o = self.dec(z * x_mask, nsff0, g=g)
|
||||||
return o, x_mask, (z, z_p, m_p, logs_p)
|
return o, x_mask, (z, z_p, m_p, logs_p)
|
||||||
@ -1092,15 +1098,18 @@ class SynthesizerTrnMs256NSFsid_nono(nn.Module):
|
|||||||
phone_lengths: torch.Tensor,
|
phone_lengths: torch.Tensor,
|
||||||
sid: torch.Tensor,
|
sid: torch.Tensor,
|
||||||
skip_head: Optional[torch.Tensor] = None,
|
skip_head: Optional[torch.Tensor] = None,
|
||||||
|
return_length: Optional[torch.Tensor] = None,
|
||||||
):
|
):
|
||||||
g = self.emb_g(sid).unsqueeze(-1)
|
g = self.emb_g(sid).unsqueeze(-1)
|
||||||
m_p, logs_p, x_mask = self.enc_p(phone, None, phone_lengths)
|
m_p, logs_p, x_mask = self.enc_p(phone, None, phone_lengths)
|
||||||
z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask
|
z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask
|
||||||
if skip_head is not None:
|
if skip_head is not None and return_length is not None:
|
||||||
assert isinstance(skip_head, torch.Tensor)
|
assert isinstance(skip_head, torch.Tensor)
|
||||||
|
assert isinstance(return_length, torch.Tensor)
|
||||||
head = int(skip_head.item())
|
head = int(skip_head.item())
|
||||||
z_p = z_p[:, :, head:]
|
length = int(return_length.item())
|
||||||
x_mask = x_mask[:, :, head:]
|
z_p = z_p[:, :, head: head + length]
|
||||||
|
x_mask = x_mask[:, :, head: head + length]
|
||||||
z = self.flow(z_p, x_mask, g=g, reverse=True)
|
z = self.flow(z_p, x_mask, g=g, reverse=True)
|
||||||
o = self.dec(z * x_mask, g=g)
|
o = self.dec(z * x_mask, g=g)
|
||||||
return o, x_mask, (z, z_p, m_p, logs_p)
|
return o, x_mask, (z, z_p, m_p, logs_p)
|
||||||
@ -1239,15 +1248,18 @@ class SynthesizerTrnMs768NSFsid_nono(nn.Module):
|
|||||||
phone_lengths: torch.Tensor,
|
phone_lengths: torch.Tensor,
|
||||||
sid: torch.Tensor,
|
sid: torch.Tensor,
|
||||||
skip_head: Optional[torch.Tensor] = None,
|
skip_head: Optional[torch.Tensor] = None,
|
||||||
|
return_length: Optional[torch.Tensor] = None,
|
||||||
):
|
):
|
||||||
g = self.emb_g(sid).unsqueeze(-1)
|
g = self.emb_g(sid).unsqueeze(-1)
|
||||||
m_p, logs_p, x_mask = self.enc_p(phone, None, phone_lengths)
|
m_p, logs_p, x_mask = self.enc_p(phone, None, phone_lengths)
|
||||||
z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask
|
z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask
|
||||||
if skip_head is not None:
|
if skip_head is not None and return_length is not None:
|
||||||
assert isinstance(skip_head, torch.Tensor)
|
assert isinstance(skip_head, torch.Tensor)
|
||||||
|
assert isinstance(return_length, torch.Tensor)
|
||||||
head = int(skip_head.item())
|
head = int(skip_head.item())
|
||||||
z_p = z_p[:, :, head:]
|
length = int(return_length.item())
|
||||||
x_mask = x_mask[:, :, head:]
|
z_p = z_p[:, :, head: head + length]
|
||||||
|
x_mask = x_mask[:, :, head: head + length]
|
||||||
z = self.flow(z_p, x_mask, g=g, reverse=True)
|
z = self.flow(z_p, x_mask, g=g, reverse=True)
|
||||||
o = self.dec(z * x_mask, g=g)
|
o = self.dec(z * x_mask, g=g)
|
||||||
return o, x_mask, (z, z_p, m_p, logs_p)
|
return o, x_mask, (z, z_p, m_p, logs_p)
|
||||||
|
@ -90,7 +90,9 @@ class RVC:
|
|||||||
self.pth_path: str = pth_path
|
self.pth_path: str = pth_path
|
||||||
self.index_path = index_path
|
self.index_path = index_path
|
||||||
self.index_rate = index_rate
|
self.index_rate = index_rate
|
||||||
|
self.cache_pitch: np.ndarray = np.zeros(1024, dtype="int32")
|
||||||
|
self.cache_pitchf = np.zeros(1024, dtype="float32")
|
||||||
|
|
||||||
if last_rvc is None:
|
if last_rvc is None:
|
||||||
models, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task(
|
models, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task(
|
||||||
["assets/hubert/hubert_base.pt"],
|
["assets/hubert/hubert_base.pt"],
|
||||||
@ -329,8 +331,9 @@ class RVC:
|
|||||||
sr=16000,
|
sr=16000,
|
||||||
decoder_mode='local_argmax',
|
decoder_mode='local_argmax',
|
||||||
threshold=0.006,
|
threshold=0.006,
|
||||||
).squeeze().cpu().numpy()
|
)
|
||||||
f0 *= pow(2, f0_up_key / 12)
|
f0 *= pow(2, f0_up_key / 12)
|
||||||
|
f0 = f0.squeeze().cpu().numpy()
|
||||||
return self.get_f0_post(f0)
|
return self.get_f0_post(f0)
|
||||||
|
|
||||||
def infer(
|
def infer(
|
||||||
@ -338,8 +341,7 @@ class RVC:
|
|||||||
input_wav: torch.Tensor,
|
input_wav: torch.Tensor,
|
||||||
block_frame_16k,
|
block_frame_16k,
|
||||||
skip_head,
|
skip_head,
|
||||||
cache_pitch,
|
return_length,
|
||||||
cache_pitchf,
|
|
||||||
f0method,
|
f0method,
|
||||||
) -> np.ndarray:
|
) -> np.ndarray:
|
||||||
t1 = ttime()
|
t1 = ttime()
|
||||||
@ -362,24 +364,22 @@ class RVC:
|
|||||||
t2 = ttime()
|
t2 = ttime()
|
||||||
try:
|
try:
|
||||||
if hasattr(self, "index") and self.index_rate != 0:
|
if hasattr(self, "index") and self.index_rate != 0:
|
||||||
leng_replace_head = int(rate * feats[0].shape[0])
|
npy = feats[0][skip_head // 2:].cpu().numpy().astype("float32")
|
||||||
npy = feats[0][-leng_replace_head:].cpu().numpy().astype("float32")
|
|
||||||
score, ix = self.index.search(npy, k=8)
|
score, ix = self.index.search(npy, k=8)
|
||||||
weight = np.square(1 / score)
|
weight = np.square(1 / score)
|
||||||
weight /= weight.sum(axis=1, keepdims=True)
|
weight /= weight.sum(axis=1, keepdims=True)
|
||||||
npy = np.sum(self.big_npy[ix] * np.expand_dims(weight, axis=2), axis=1)
|
npy = np.sum(self.big_npy[ix] * np.expand_dims(weight, axis=2), axis=1)
|
||||||
if self.config.is_half:
|
if self.config.is_half:
|
||||||
npy = npy.astype("float16")
|
npy = npy.astype("float16")
|
||||||
feats[0][-leng_replace_head:] = (
|
feats[0][skip_head // 2:] = (
|
||||||
torch.from_numpy(npy).unsqueeze(0).to(self.device) * self.index_rate
|
torch.from_numpy(npy).unsqueeze(0).to(self.device) * self.index_rate
|
||||||
+ (1 - self.index_rate) * feats[0][-leng_replace_head:]
|
+ (1 - self.index_rate) * feats[0][skip_head // 2:]
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
printt("Index search FAILED or disabled")
|
printt("Index search FAILED or disabled")
|
||||||
except:
|
except:
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
printt("Index search FAILED")
|
printt("Index search FAILED")
|
||||||
feats = F.interpolate(feats.permute(0, 2, 1), scale_factor=2).permute(0, 2, 1)
|
|
||||||
t3 = ttime()
|
t3 = ttime()
|
||||||
if self.if_f0 == 1:
|
if self.if_f0 == 1:
|
||||||
f0_extractor_frame = block_frame_16k + 800
|
f0_extractor_frame = block_frame_16k + 800
|
||||||
@ -387,40 +387,39 @@ class RVC:
|
|||||||
f0_extractor_frame = (
|
f0_extractor_frame = (
|
||||||
5120 * ((f0_extractor_frame - 1) // 5120 + 1) - 160
|
5120 * ((f0_extractor_frame - 1) // 5120 + 1) - 160
|
||||||
)
|
)
|
||||||
input_wav = input_wav[-f0_extractor_frame:]
|
pitch, pitchf = self.get_f0(input_wav[-f0_extractor_frame: ], self.f0_up_key, self.n_cpu, f0method)
|
||||||
pitch, pitchf = self.get_f0(input_wav, self.f0_up_key, self.n_cpu, f0method)
|
|
||||||
start_frame = block_frame_16k // 160
|
start_frame = block_frame_16k // 160
|
||||||
end_frame = len(cache_pitch) - (pitch.shape[0] - 4) + start_frame
|
end_frame = len(self.cache_pitch) - (pitch.shape[0] - 4) + start_frame
|
||||||
cache_pitch[:] = np.append(cache_pitch[start_frame:end_frame], pitch[3:-1])
|
self.cache_pitch[:] = np.append(self.cache_pitch[start_frame: end_frame], pitch[3:-1])
|
||||||
cache_pitchf[:] = np.append(
|
self.cache_pitchf[:] = np.append(
|
||||||
cache_pitchf[start_frame:end_frame], pitchf[3:-1]
|
self.cache_pitchf[start_frame: end_frame], pitchf[3:-1]
|
||||||
)
|
)
|
||||||
p_len = min(feats.shape[1], 13000, cache_pitch.shape[0])
|
|
||||||
else:
|
|
||||||
cache_pitch, cache_pitchf = None, None
|
|
||||||
p_len = min(feats.shape[1], 13000)
|
|
||||||
t4 = ttime()
|
t4 = ttime()
|
||||||
feats = feats[:, :p_len, :]
|
p_len = input_wav.shape[0] // 160
|
||||||
if self.if_f0 == 1:
|
if self.if_f0 == 1:
|
||||||
cache_pitch = torch.LongTensor(cache_pitch[:p_len]).to(self.device).unsqueeze(0)
|
cache_pitch = torch.LongTensor(self.cache_pitch[-p_len: ]).to(self.device).unsqueeze(0)
|
||||||
cache_pitchf = torch.FloatTensor(cache_pitchf[:p_len]).to(self.device).unsqueeze(0)
|
cache_pitchf = torch.FloatTensor(self.cache_pitchf[-p_len: ]).to(self.device).unsqueeze(0)
|
||||||
|
feats = F.interpolate(feats.permute(0, 2, 1), scale_factor=2).permute(0, 2, 1)
|
||||||
|
feats = feats[:, :p_len, :]
|
||||||
p_len = torch.LongTensor([p_len]).to(self.device)
|
p_len = torch.LongTensor([p_len]).to(self.device)
|
||||||
sid = torch.LongTensor([0]).to(self.device)
|
sid = torch.LongTensor([0]).to(self.device)
|
||||||
skip_head = torch.LongTensor([skip_head])
|
skip_head = torch.LongTensor([skip_head])
|
||||||
|
return_length = torch.LongTensor([return_length])
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
if self.if_f0 == 1:
|
if self.if_f0 == 1:
|
||||||
infered_audio = self.net_g.infer(
|
infered_audio, _, _ = self.net_g.infer(
|
||||||
feats,
|
feats,
|
||||||
p_len,
|
p_len,
|
||||||
cache_pitch,
|
cache_pitch,
|
||||||
cache_pitchf,
|
cache_pitchf,
|
||||||
sid,
|
sid,
|
||||||
skip_head,
|
skip_head,
|
||||||
)[0][0, 0].data.float()
|
return_length,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
infered_audio = self.net_g.infer(
|
infered_audio, _, _ = self.net_g.infer(
|
||||||
feats, p_len, sid, skip_head
|
feats, p_len, sid, skip_head, return_length
|
||||||
)[0][0, 0].data.float()
|
)
|
||||||
t5 = ttime()
|
t5 = ttime()
|
||||||
printt(
|
printt(
|
||||||
"Spent time: fea = %.3fs, index = %.3fs, f0 = %.3fs, model = %.3fs",
|
"Spent time: fea = %.3fs, index = %.3fs, f0 = %.3fs, model = %.3fs",
|
||||||
@ -429,4 +428,4 @@ class RVC:
|
|||||||
t4 - t3,
|
t4 - t3,
|
||||||
t5 - t4,
|
t5 - t4,
|
||||||
)
|
)
|
||||||
return infered_audio
|
return infered_audio.squeeze().float()
|
||||||
|
Loading…
Reference in New Issue
Block a user