optimize the streaming f0 exatrators (#1168)

This commit is contained in:
yxlllc 2023-09-02 15:45:50 +08:00 committed by GitHub
parent ad85b02ed9
commit 0fc160c03e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 91 additions and 71 deletions

View File

@ -261,9 +261,9 @@ if __name__ == "__main__":
[ [
sg.Text(i18n("采样长度")), sg.Text(i18n("采样长度")),
sg.Slider( sg.Slider(
range=(0.09, 2.4), range=(0.05, 2.4),
key="block_time", key="block_time",
resolution=0.03, resolution=0.01,
orientation="h", orientation="h",
default_value=data.get("block_time", ""), default_value=data.get("block_time", ""),
enable_events=True, enable_events=True,
@ -455,18 +455,20 @@ if __name__ == "__main__":
inp_q, inp_q,
opt_q, opt_q,
device, device,
self.rvc if hasattr(self, "rvc") else None
) )
self.config.samplerate = self.rvc.tgt_sr self.config.samplerate = self.rvc.tgt_sr
self.config.crossfade_time = min( self.config.crossfade_time = min(
self.config.crossfade_time, self.config.block_time self.config.crossfade_time, self.config.block_time
) )
self.block_frame = int(self.config.block_time * self.config.samplerate) self.zc = self.rvc.tgt_sr // 100
self.block_frame = int(np.round(self.config.block_time * self.config.samplerate / self.zc)) * self.zc
self.block_frame_16k = 160 * self.block_frame // self.zc
self.crossfade_frame = int( self.crossfade_frame = int(
self.config.crossfade_time * self.config.samplerate self.config.crossfade_time * self.config.samplerate
) )
self.sola_search_frame = int(0.01 * self.config.samplerate) self.sola_search_frame = int(0.01 * self.config.samplerate)
self.extra_frame = int(self.config.extra_time * self.config.samplerate) self.extra_frame = int(self.config.extra_time * self.config.samplerate)
self.zc = self.rvc.tgt_sr // 100
self.input_wav: np.ndarray = np.zeros( self.input_wav: np.ndarray = np.zeros(
int( int(
np.ceil( np.ceil(
@ -482,6 +484,7 @@ if __name__ == "__main__":
), ),
dtype="float32", dtype="float32",
) )
self.input_wav_res: torch.Tensor= torch.zeros(160 * len(self.input_wav) // self.zc)
self.output_wav_cache: torch.Tensor = torch.zeros( self.output_wav_cache: torch.Tensor = torch.zeros(
int( int(
np.ceil( np.ceil(
@ -573,18 +576,14 @@ if __name__ == "__main__":
for i in range(db_threhold.shape[0]): for i in range(db_threhold.shape[0]):
if db_threhold[i]: if db_threhold[i]:
indata[i * hop_length : (i + 1) * hop_length] = 0 indata[i * hop_length : (i + 1) * hop_length] = 0
self.input_wav[:] = np.append(self.input_wav[self.block_frame :], indata) self.input_wav[: -self.block_frame] = self.input_wav[self.block_frame :]
self.input_wav[-self.block_frame: ] = indata
# infer # infer
inp = torch.from_numpy(self.input_wav).to(device) inp = torch.from_numpy(self.input_wav[-self.block_frame-2*self.zc :]).to(device)
res1 = self.resampler(inp) self.input_wav_res[ : -self.block_frame_16k] = self.input_wav_res[self.block_frame_16k :].clone()
###55% self.input_wav_res[-self.block_frame_16k-160 :] = self.resampler(inp)[160 :]
rate1 = self.block_frame / ( rate = (
self.extra_frame
+ self.crossfade_frame
+ self.sola_search_frame
+ self.block_frame
)
rate2 = (
self.crossfade_frame + self.sola_search_frame + self.block_frame self.crossfade_frame + self.sola_search_frame + self.block_frame
) / ( ) / (
self.extra_frame self.extra_frame
@ -592,11 +591,14 @@ if __name__ == "__main__":
+ self.sola_search_frame + self.sola_search_frame
+ self.block_frame + self.block_frame
) )
f0_extractor_frame = self.block_frame_16k + 800
if self.config.f0method == 'rmvpe':
f0_extractor_frame = 5120 * ((f0_extractor_frame - 1) // 5120 + 1)
res2 = self.rvc.infer( res2 = self.rvc.infer(
res1, self.input_wav_res,
res1[-self.block_frame :].cpu().numpy(), self.input_wav_res[-f0_extractor_frame :].cpu().numpy(),
rate1, self.block_frame_16k,
rate2, rate,
self.pitch, self.pitch,
self.pitchf, self.pitchf,
self.config.f0method, self.config.f0method,

View File

@ -601,7 +601,7 @@ class RMVPE:
with torch.no_grad(): with torch.no_grad():
n_frames = mel.shape[-1] n_frames = mel.shape[-1]
mel = F.pad( mel = F.pad(
mel, (0, 32 * ((n_frames - 1) // 32 + 1) - n_frames), mode="reflect" mel, (0, 32 * ((n_frames - 1) // 32 + 1) - n_frames), mode="constant"
) )
if "privateuseone" in str(self.device): if "privateuseone" in str(self.device):
onnx_input_name = self.model.get_inputs()[0].name onnx_input_name = self.model.get_inputs()[0].name

View File

@ -2,7 +2,6 @@ import os
import sys import sys
import traceback import traceback
import logging import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
from time import time as ttime from time import time as ttime
@ -48,7 +47,7 @@ if config.dml == True:
# config.is_half=False########强制cpu测试 # config.is_half=False########强制cpu测试
class RVC: class RVC:
def __init__( def __init__(
self, key, pth_path, index_path, index_rate, n_cpu, inp_q, opt_q, device self, key, pth_path, index_path, index_rate, n_cpu, inp_q, opt_q, device, last_rvc=None,
) -> None: ) -> None:
""" """
初始化 初始化
@ -72,48 +71,64 @@ class RVC:
self.index = faiss.read_index(index_path) self.index = faiss.read_index(index_path)
self.big_npy = self.index.reconstruct_n(0, self.index.ntotal) self.big_npy = self.index.reconstruct_n(0, self.index.ntotal)
logger.info("Index search enabled") logger.info("Index search enabled")
self.pth_path = pth_path
self.index_path = index_path self.index_path = index_path
self.index_rate = index_rate self.index_rate = index_rate
models, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task(
["assets/hubert/hubert_base.pt"], if last_rvc is None:
suffix="", models, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task(
) ["assets/hubert/hubert_base.pt"],
hubert_model = models[0] suffix="",
hubert_model = hubert_model.to(config.device) )
if config.is_half: hubert_model = models[0]
hubert_model = hubert_model.half() hubert_model = hubert_model.to(config.device)
else: if config.is_half:
hubert_model = hubert_model.float() hubert_model = hubert_model.half()
hubert_model.eval()
self.model = hubert_model
cpt = torch.load(pth_path, map_location="cpu")
self.tgt_sr = cpt["config"][-1]
cpt["config"][-3] = cpt["weight"]["emb_g.weight"].shape[0]
self.if_f0 = cpt.get("f0", 1)
self.version = cpt.get("version", "v1")
if self.version == "v1":
if self.if_f0 == 1:
self.net_g = SynthesizerTrnMs256NSFsid(
*cpt["config"], is_half=config.is_half
)
else: else:
self.net_g = SynthesizerTrnMs256NSFsid_nono(*cpt["config"]) hubert_model = hubert_model.float()
elif self.version == "v2": hubert_model.eval()
if self.if_f0 == 1: self.model = hubert_model
self.net_g = SynthesizerTrnMs768NSFsid(
*cpt["config"], is_half=config.is_half
)
else:
self.net_g = SynthesizerTrnMs768NSFsid_nono(*cpt["config"])
del self.net_g.enc_q
logger.debug(self.net_g.load_state_dict(cpt["weight"], strict=False))
self.net_g.eval().to(device)
# print(2333333333,device,config.device,self.device)#net_g是devicehubert是config.device
if config.is_half:
self.net_g = self.net_g.half()
else: else:
self.net_g = self.net_g.float() self.model = last_rvc.model
self.is_half = config.is_half
if last_rvc is None or last_rvc.pth_path != self.pth_path:
cpt = torch.load(self.pth_path, map_location="cpu")
self.tgt_sr = cpt["config"][-1]
cpt["config"][-3] = cpt["weight"]["emb_g.weight"].shape[0]
self.if_f0 = cpt.get("f0", 1)
self.version = cpt.get("version", "v1")
if self.version == "v1":
if self.if_f0 == 1:
self.net_g = SynthesizerTrnMs256NSFsid(
*cpt["config"], is_half=config.is_half
)
else:
self.net_g = SynthesizerTrnMs256NSFsid_nono(*cpt["config"])
elif self.version == "v2":
if self.if_f0 == 1:
self.net_g = SynthesizerTrnMs768NSFsid(
*cpt["config"], is_half=config.is_half
)
else:
self.net_g = SynthesizerTrnMs768NSFsid_nono(*cpt["config"])
del self.net_g.enc_q
logger.debug(self.net_g.load_state_dict(cpt["weight"], strict=False))
self.net_g.eval().to(device)
# print(2333333333,device,config.device,self.device)#net_g是devicehubert是config.device
if config.is_half:
self.net_g = self.net_g.half()
else:
self.net_g = self.net_g.float()
self.is_half = config.is_half
else:
self.tgt_sr = last_rvc.tgt_sr
self.if_f0 = last_rvc.if_f0
self.version = last_rvc.version
self.net_g = last_rvc.net_g
self.is_half = last_rvc.is_half
if last_rvc is not None and hasattr(last_rvc, "model_rmvpe"):
self.model_rmvpe = last_rvc.model_rmvpe
except: except:
logger.warn(traceback.format_exc()) logger.warn(traceback.format_exc())
@ -149,7 +164,7 @@ class RVC:
if method == "rmvpe": if method == "rmvpe":
return self.get_f0_rmvpe(x, f0_up_key) return self.get_f0_rmvpe(x, f0_up_key)
if method == "pm": if method == "pm":
p_len = x.shape[0] // 160 p_len = x.shape[0] // 160 + 1
f0 = ( f0 = (
parselmouth.Sound(x, 16000) parselmouth.Sound(x, 16000)
.to_pitch_ac( .to_pitch_ac(
@ -181,9 +196,10 @@ class RVC:
f0 = signal.medfilt(f0, 3) f0 = signal.medfilt(f0, 3)
f0 *= pow(2, f0_up_key / 12) f0 *= pow(2, f0_up_key / 12)
return self.get_f0_post(f0) return self.get_f0_post(f0)
f0bak = np.zeros(x.shape[0] // 160, dtype=np.float64) f0bak = np.zeros(x.shape[0] // 160 + 1, dtype=np.float64)
length = len(x) length = len(x)
part_length = int(length / n_cpu / 160) * 160 part_length = 160 * ((length // 160 - 1) // n_cpu + 1)
n_cpu = (length // 160 - 1) // (part_length // 160) + 1
ts = ttime() ts = ttime()
res_f0 = mm.dict() res_f0 = mm.dict()
for idx in range(n_cpu): for idx in range(n_cpu):
@ -205,7 +221,7 @@ class RVC:
elif idx != n_cpu - 1: elif idx != n_cpu - 1:
f0 = f0[2:-3] f0 = f0[2:-3]
else: else:
f0 = f0[2:-1] f0 = f0[2:]
f0bak[ f0bak[
part_length * idx // 160 : part_length * idx // 160 + f0.shape[0] part_length * idx // 160 : part_length * idx // 160 + f0.shape[0]
] = f0 ] = f0
@ -259,8 +275,8 @@ class RVC:
self, self,
feats: torch.Tensor, feats: torch.Tensor,
indata: np.ndarray, indata: np.ndarray,
rate1, block_frame_16k,
rate2, rate,
cache_pitch, cache_pitch,
cache_pitchf, cache_pitchf,
f0method, f0method,
@ -286,7 +302,7 @@ class RVC:
t2 = ttime() t2 = ttime()
try: try:
if hasattr(self, "index") and self.index_rate != 0: if hasattr(self, "index") and self.index_rate != 0:
leng_replace_head = int(rate1 * feats[0].shape[0]) leng_replace_head = int(rate * feats[0].shape[0])
npy = feats[0][-leng_replace_head:].cpu().numpy().astype("float32") npy = feats[0][-leng_replace_head:].cpu().numpy().astype("float32")
score, ix = self.index.search(npy, k=8) score, ix = self.index.search(npy, k=8)
weight = np.square(1 / score) weight = np.square(1 / score)
@ -307,9 +323,11 @@ class RVC:
t3 = ttime() t3 = ttime()
if self.if_f0 == 1: if self.if_f0 == 1:
pitch, pitchf = self.get_f0(indata, self.f0_up_key, self.n_cpu, f0method) pitch, pitchf = self.get_f0(indata, self.f0_up_key, self.n_cpu, f0method)
cache_pitch[:] = np.append(cache_pitch[pitch[:-1].shape[0] :], pitch[:-1]) start_frame = block_frame_16k // 160
end_frame = len(cache_pitch) - (pitch.shape[0] - 4) + start_frame
cache_pitch[:] = np.append(cache_pitch[start_frame : end_frame], pitch[3:-1])
cache_pitchf[:] = np.append( cache_pitchf[:] = np.append(
cache_pitchf[pitchf[:-1].shape[0] :], pitchf[:-1] cache_pitchf[start_frame : end_frame], pitchf[3:-1]
) )
p_len = min(feats.shape[1], 13000, cache_pitch.shape[0]) p_len = min(feats.shape[1], 13000, cache_pitch.shape[0])
else: else:
@ -330,14 +348,14 @@ class RVC:
# print(12222222222,feats.device,p_len.device,cache_pitch.device,cache_pitchf.device,sid.device,rate2) # print(12222222222,feats.device,p_len.device,cache_pitch.device,cache_pitchf.device,sid.device,rate2)
infered_audio = ( infered_audio = (
self.net_g.infer( self.net_g.infer(
feats, p_len, cache_pitch, cache_pitchf, sid, rate2 feats, p_len, cache_pitch, cache_pitchf, sid, rate
)[0][0, 0] )[0][0, 0]
.data.cpu() .data.cpu()
.float() .float()
) )
else: else:
infered_audio = ( infered_audio = (
self.net_g.infer(feats, p_len, sid, rate2)[0][0, 0] self.net_g.infer(feats, p_len, sid, rate)[0][0, 0]
.data.cpu() .data.cpu()
.float() .float()
) )