mirror of
https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI.git
synced 2025-05-07 04:09:06 +08:00
optimize(jit): package hierarchy
This commit is contained in:
parent
ae04d4f9cc
commit
17e703a9ad
@ -1,163 +1,2 @@
|
|||||||
from io import BytesIO
|
from .utils import load, rmvpe_jit_export, synthesizer_jit_export
|
||||||
import pickle
|
from .synthesizer import get_synthesizer
|
||||||
import time
|
|
||||||
import torch
|
|
||||||
from tqdm import tqdm
|
|
||||||
from collections import OrderedDict
|
|
||||||
|
|
||||||
|
|
||||||
def load_inputs(path, device, is_half=False):
|
|
||||||
parm = torch.load(path, map_location=torch.device("cpu"))
|
|
||||||
for key in parm.keys():
|
|
||||||
parm[key] = parm[key].to(device)
|
|
||||||
if is_half and parm[key].dtype == torch.float32:
|
|
||||||
parm[key] = parm[key].half()
|
|
||||||
elif not is_half and parm[key].dtype == torch.float16:
|
|
||||||
parm[key] = parm[key].float()
|
|
||||||
return parm
|
|
||||||
|
|
||||||
|
|
||||||
def benchmark(
|
|
||||||
model, inputs_path, device=torch.device("cpu"), epoch=1000, is_half=False
|
|
||||||
):
|
|
||||||
parm = load_inputs(inputs_path, device, is_half)
|
|
||||||
total_ts = 0.0
|
|
||||||
bar = tqdm(range(epoch))
|
|
||||||
for i in bar:
|
|
||||||
start_time = time.perf_counter()
|
|
||||||
o = model(**parm)
|
|
||||||
total_ts += time.perf_counter() - start_time
|
|
||||||
print(f"num_epoch: {epoch} | avg time(ms): {(total_ts*1000)/epoch}")
|
|
||||||
|
|
||||||
|
|
||||||
def jit_warm_up(model, inputs_path, device=torch.device("cpu"), epoch=5, is_half=False):
|
|
||||||
benchmark(model, inputs_path, device, epoch=epoch, is_half=is_half)
|
|
||||||
|
|
||||||
|
|
||||||
def to_jit_model(
|
|
||||||
model_path,
|
|
||||||
model_type: str,
|
|
||||||
mode: str = "trace",
|
|
||||||
inputs_path: str = None,
|
|
||||||
device=torch.device("cpu"),
|
|
||||||
is_half=False,
|
|
||||||
):
|
|
||||||
model = None
|
|
||||||
if model_type.lower() == "synthesizer":
|
|
||||||
from .get_synthesizer import get_synthesizer
|
|
||||||
|
|
||||||
model, _ = get_synthesizer(model_path, device)
|
|
||||||
model.forward = model.infer
|
|
||||||
elif model_type.lower() == "rmvpe":
|
|
||||||
from .get_rmvpe import get_rmvpe
|
|
||||||
|
|
||||||
model = get_rmvpe(model_path, device)
|
|
||||||
elif model_type.lower() == "hubert":
|
|
||||||
from .get_hubert import get_hubert_model
|
|
||||||
|
|
||||||
model = get_hubert_model(model_path, device)
|
|
||||||
model.forward = model.infer
|
|
||||||
else:
|
|
||||||
raise ValueError(f"No model type named {model_type}")
|
|
||||||
model = model.eval()
|
|
||||||
model = model.half() if is_half else model.float()
|
|
||||||
if mode == "trace":
|
|
||||||
assert not inputs_path
|
|
||||||
inputs = load_inputs(inputs_path, device, is_half)
|
|
||||||
model_jit = torch.jit.trace(model, example_kwarg_inputs=inputs)
|
|
||||||
elif mode == "script":
|
|
||||||
model_jit = torch.jit.script(model)
|
|
||||||
model_jit.to(device)
|
|
||||||
model_jit = model_jit.half() if is_half else model_jit.float()
|
|
||||||
# model = model.half() if is_half else model.float()
|
|
||||||
return (model, model_jit)
|
|
||||||
|
|
||||||
|
|
||||||
def export(
|
|
||||||
model: torch.nn.Module,
|
|
||||||
mode: str = "trace",
|
|
||||||
inputs: dict = None,
|
|
||||||
device=torch.device("cpu"),
|
|
||||||
is_half: bool = False,
|
|
||||||
) -> dict:
|
|
||||||
model = model.half() if is_half else model.float()
|
|
||||||
model.eval()
|
|
||||||
if mode == "trace":
|
|
||||||
assert inputs is not None
|
|
||||||
model_jit = torch.jit.trace(model, example_kwarg_inputs=inputs)
|
|
||||||
elif mode == "script":
|
|
||||||
model_jit = torch.jit.script(model)
|
|
||||||
model_jit.to(device)
|
|
||||||
model_jit = model_jit.half() if is_half else model_jit.float()
|
|
||||||
buffer = BytesIO()
|
|
||||||
# model_jit=model_jit.cpu()
|
|
||||||
torch.jit.save(model_jit, buffer)
|
|
||||||
del model_jit
|
|
||||||
cpt = OrderedDict()
|
|
||||||
cpt["model"] = buffer.getvalue()
|
|
||||||
cpt["is_half"] = is_half
|
|
||||||
return cpt
|
|
||||||
|
|
||||||
|
|
||||||
def load(path: str):
|
|
||||||
with open(path, "rb") as f:
|
|
||||||
return pickle.load(f)
|
|
||||||
|
|
||||||
|
|
||||||
def save(ckpt: dict, save_path: str):
|
|
||||||
with open(save_path, "wb") as f:
|
|
||||||
pickle.dump(ckpt, f)
|
|
||||||
|
|
||||||
|
|
||||||
def rmvpe_jit_export(
|
|
||||||
model_path: str,
|
|
||||||
mode: str = "script",
|
|
||||||
inputs_path: str = None,
|
|
||||||
save_path: str = None,
|
|
||||||
device=torch.device("cpu"),
|
|
||||||
is_half=False,
|
|
||||||
):
|
|
||||||
if not save_path:
|
|
||||||
save_path = model_path.rstrip(".pth")
|
|
||||||
save_path += ".half.jit" if is_half else ".jit"
|
|
||||||
if "cuda" in str(device) and ":" not in str(device):
|
|
||||||
device = torch.device("cuda:0")
|
|
||||||
from .get_rmvpe import get_rmvpe
|
|
||||||
|
|
||||||
model = get_rmvpe(model_path, device)
|
|
||||||
inputs = None
|
|
||||||
if mode == "trace":
|
|
||||||
inputs = load_inputs(inputs_path, device, is_half)
|
|
||||||
ckpt = export(model, mode, inputs, device, is_half)
|
|
||||||
ckpt["device"] = str(device)
|
|
||||||
save(ckpt, save_path)
|
|
||||||
return ckpt
|
|
||||||
|
|
||||||
|
|
||||||
def synthesizer_jit_export(
|
|
||||||
model_path: str,
|
|
||||||
mode: str = "script",
|
|
||||||
inputs_path: str = None,
|
|
||||||
save_path: str = None,
|
|
||||||
device=torch.device("cpu"),
|
|
||||||
is_half=False,
|
|
||||||
):
|
|
||||||
if not save_path:
|
|
||||||
save_path = model_path.rstrip(".pth")
|
|
||||||
save_path += ".half.jit" if is_half else ".jit"
|
|
||||||
if "cuda" in str(device) and ":" not in str(device):
|
|
||||||
device = torch.device("cuda:0")
|
|
||||||
from .get_synthesizer import get_synthesizer
|
|
||||||
|
|
||||||
model, cpt = get_synthesizer(model_path, device)
|
|
||||||
assert isinstance(cpt, dict)
|
|
||||||
model.forward = model.infer
|
|
||||||
inputs = None
|
|
||||||
if mode == "trace":
|
|
||||||
inputs = load_inputs(inputs_path, device, is_half)
|
|
||||||
ckpt = export(model, mode, inputs, device, is_half)
|
|
||||||
cpt.pop("weight")
|
|
||||||
cpt["model"] = ckpt["model"]
|
|
||||||
cpt["device"] = device
|
|
||||||
save(cpt, save_path)
|
|
||||||
return cpt
|
|
||||||
|
163
infer/lib/jit/utils.py
Normal file
163
infer/lib/jit/utils.py
Normal file
@ -0,0 +1,163 @@
|
|||||||
|
from io import BytesIO
|
||||||
|
import pickle
|
||||||
|
import time
|
||||||
|
import torch
|
||||||
|
from tqdm import tqdm
|
||||||
|
from collections import OrderedDict
|
||||||
|
|
||||||
|
|
||||||
|
def load_inputs(path, device, is_half=False):
|
||||||
|
parm = torch.load(path, map_location=torch.device("cpu"))
|
||||||
|
for key in parm.keys():
|
||||||
|
parm[key] = parm[key].to(device)
|
||||||
|
if is_half and parm[key].dtype == torch.float32:
|
||||||
|
parm[key] = parm[key].half()
|
||||||
|
elif not is_half and parm[key].dtype == torch.float16:
|
||||||
|
parm[key] = parm[key].float()
|
||||||
|
return parm
|
||||||
|
|
||||||
|
|
||||||
|
def benchmark(
|
||||||
|
model, inputs_path, device=torch.device("cpu"), epoch=1000, is_half=False
|
||||||
|
):
|
||||||
|
parm = load_inputs(inputs_path, device, is_half)
|
||||||
|
total_ts = 0.0
|
||||||
|
bar = tqdm(range(epoch))
|
||||||
|
for i in bar:
|
||||||
|
start_time = time.perf_counter()
|
||||||
|
o = model(**parm)
|
||||||
|
total_ts += time.perf_counter() - start_time
|
||||||
|
print(f"num_epoch: {epoch} | avg time(ms): {(total_ts*1000)/epoch}")
|
||||||
|
|
||||||
|
|
||||||
|
def jit_warm_up(model, inputs_path, device=torch.device("cpu"), epoch=5, is_half=False):
|
||||||
|
benchmark(model, inputs_path, device, epoch=epoch, is_half=is_half)
|
||||||
|
|
||||||
|
|
||||||
|
def to_jit_model(
|
||||||
|
model_path,
|
||||||
|
model_type: str,
|
||||||
|
mode: str = "trace",
|
||||||
|
inputs_path: str = None,
|
||||||
|
device=torch.device("cpu"),
|
||||||
|
is_half=False,
|
||||||
|
):
|
||||||
|
model = None
|
||||||
|
if model_type.lower() == "synthesizer":
|
||||||
|
from .synthesizer import get_synthesizer
|
||||||
|
|
||||||
|
model, _ = get_synthesizer(model_path, device)
|
||||||
|
model.forward = model.infer
|
||||||
|
elif model_type.lower() == "rmvpe":
|
||||||
|
from .rmvpe import get_rmvpe
|
||||||
|
|
||||||
|
model = get_rmvpe(model_path, device)
|
||||||
|
elif model_type.lower() == "hubert":
|
||||||
|
from .hubert import get_hubert_model
|
||||||
|
|
||||||
|
model = get_hubert_model(model_path, device)
|
||||||
|
model.forward = model.infer
|
||||||
|
else:
|
||||||
|
raise ValueError(f"No model type named {model_type}")
|
||||||
|
model = model.eval()
|
||||||
|
model = model.half() if is_half else model.float()
|
||||||
|
if mode == "trace":
|
||||||
|
assert not inputs_path
|
||||||
|
inputs = load_inputs(inputs_path, device, is_half)
|
||||||
|
model_jit = torch.jit.trace(model, example_kwarg_inputs=inputs)
|
||||||
|
elif mode == "script":
|
||||||
|
model_jit = torch.jit.script(model)
|
||||||
|
model_jit.to(device)
|
||||||
|
model_jit = model_jit.half() if is_half else model_jit.float()
|
||||||
|
# model = model.half() if is_half else model.float()
|
||||||
|
return (model, model_jit)
|
||||||
|
|
||||||
|
|
||||||
|
def export(
|
||||||
|
model: torch.nn.Module,
|
||||||
|
mode: str = "trace",
|
||||||
|
inputs: dict = None,
|
||||||
|
device=torch.device("cpu"),
|
||||||
|
is_half: bool = False,
|
||||||
|
) -> dict:
|
||||||
|
model = model.half() if is_half else model.float()
|
||||||
|
model.eval()
|
||||||
|
if mode == "trace":
|
||||||
|
assert inputs is not None
|
||||||
|
model_jit = torch.jit.trace(model, example_kwarg_inputs=inputs)
|
||||||
|
elif mode == "script":
|
||||||
|
model_jit = torch.jit.script(model)
|
||||||
|
model_jit.to(device)
|
||||||
|
model_jit = model_jit.half() if is_half else model_jit.float()
|
||||||
|
buffer = BytesIO()
|
||||||
|
# model_jit=model_jit.cpu()
|
||||||
|
torch.jit.save(model_jit, buffer)
|
||||||
|
del model_jit
|
||||||
|
cpt = OrderedDict()
|
||||||
|
cpt["model"] = buffer.getvalue()
|
||||||
|
cpt["is_half"] = is_half
|
||||||
|
return cpt
|
||||||
|
|
||||||
|
|
||||||
|
def load(path: str):
|
||||||
|
with open(path, "rb") as f:
|
||||||
|
return pickle.load(f)
|
||||||
|
|
||||||
|
|
||||||
|
def save(ckpt: dict, save_path: str):
|
||||||
|
with open(save_path, "wb") as f:
|
||||||
|
pickle.dump(ckpt, f)
|
||||||
|
|
||||||
|
|
||||||
|
def rmvpe_jit_export(
|
||||||
|
model_path: str,
|
||||||
|
mode: str = "script",
|
||||||
|
inputs_path: str = None,
|
||||||
|
save_path: str = None,
|
||||||
|
device=torch.device("cpu"),
|
||||||
|
is_half=False,
|
||||||
|
):
|
||||||
|
if not save_path:
|
||||||
|
save_path = model_path.rstrip(".pth")
|
||||||
|
save_path += ".half.jit" if is_half else ".jit"
|
||||||
|
if "cuda" in str(device) and ":" not in str(device):
|
||||||
|
device = torch.device("cuda:0")
|
||||||
|
from .rmvpe import get_rmvpe
|
||||||
|
|
||||||
|
model = get_rmvpe(model_path, device)
|
||||||
|
inputs = None
|
||||||
|
if mode == "trace":
|
||||||
|
inputs = load_inputs(inputs_path, device, is_half)
|
||||||
|
ckpt = export(model, mode, inputs, device, is_half)
|
||||||
|
ckpt["device"] = str(device)
|
||||||
|
save(ckpt, save_path)
|
||||||
|
return ckpt
|
||||||
|
|
||||||
|
|
||||||
|
def synthesizer_jit_export(
|
||||||
|
model_path: str,
|
||||||
|
mode: str = "script",
|
||||||
|
inputs_path: str = None,
|
||||||
|
save_path: str = None,
|
||||||
|
device=torch.device("cpu"),
|
||||||
|
is_half=False,
|
||||||
|
):
|
||||||
|
if not save_path:
|
||||||
|
save_path = model_path.rstrip(".pth")
|
||||||
|
save_path += ".half.jit" if is_half else ".jit"
|
||||||
|
if "cuda" in str(device) and ":" not in str(device):
|
||||||
|
device = torch.device("cuda:0")
|
||||||
|
from .synthesizer import get_synthesizer
|
||||||
|
|
||||||
|
model, cpt = get_synthesizer(model_path, device)
|
||||||
|
assert isinstance(cpt, dict)
|
||||||
|
model.forward = model.infer
|
||||||
|
inputs = None
|
||||||
|
if mode == "trace":
|
||||||
|
inputs = load_inputs(inputs_path, device, is_half)
|
||||||
|
ckpt = export(model, mode, inputs, device, is_half)
|
||||||
|
cpt.pop("weight")
|
||||||
|
cpt["model"] = ckpt["model"]
|
||||||
|
cpt["device"] = device
|
||||||
|
save(cpt, save_path)
|
||||||
|
return cpt
|
@ -3,7 +3,6 @@ import os
|
|||||||
import sys
|
import sys
|
||||||
import traceback
|
import traceback
|
||||||
from infer.lib import jit
|
from infer.lib import jit
|
||||||
from infer.lib.jit.get_synthesizer import get_synthesizer
|
|
||||||
from time import time as ttime
|
from time import time as ttime
|
||||||
import fairseq
|
import fairseq
|
||||||
import faiss
|
import faiss
|
||||||
@ -114,7 +113,7 @@ class RVC:
|
|||||||
self.net_g: nn.Module = None
|
self.net_g: nn.Module = None
|
||||||
|
|
||||||
def set_default_model():
|
def set_default_model():
|
||||||
self.net_g, cpt = get_synthesizer(self.pth_path, self.device)
|
self.net_g, cpt = jit.get_synthesizer(self.pth_path, self.device)
|
||||||
self.tgt_sr = cpt["config"][-1]
|
self.tgt_sr = cpt["config"][-1]
|
||||||
cpt["config"][-3] = cpt["weight"]["emb_g.weight"].shape[0]
|
cpt["config"][-3] = cpt["weight"]["emb_g.weight"].shape[0]
|
||||||
self.if_f0 = cpt.get("f0", 1)
|
self.if_f0 = cpt.get("f0", 1)
|
||||||
|
@ -6,7 +6,7 @@ import logging
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
from time import time as ttime
|
from time import time
|
||||||
|
|
||||||
import faiss
|
import faiss
|
||||||
import librosa
|
import librosa
|
||||||
@ -235,7 +235,7 @@ class Pipeline(object):
|
|||||||
"padding_mask": padding_mask,
|
"padding_mask": padding_mask,
|
||||||
"output_layer": 9 if version == "v1" else 12,
|
"output_layer": 9 if version == "v1" else 12,
|
||||||
}
|
}
|
||||||
t0 = ttime()
|
t0 = time()
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
logits = model.extract_features(**inputs)
|
logits = model.extract_features(**inputs)
|
||||||
feats = model.final_proj(logits[0]) if version == "v1" else logits[0]
|
feats = model.final_proj(logits[0]) if version == "v1" else logits[0]
|
||||||
@ -270,7 +270,7 @@ class Pipeline(object):
|
|||||||
feats0 = F.interpolate(feats0.permute(0, 2, 1), scale_factor=2).permute(
|
feats0 = F.interpolate(feats0.permute(0, 2, 1), scale_factor=2).permute(
|
||||||
0, 2, 1
|
0, 2, 1
|
||||||
)
|
)
|
||||||
t1 = ttime()
|
t1 = time()
|
||||||
p_len = audio0.shape[0] // self.window
|
p_len = audio0.shape[0] // self.window
|
||||||
if feats.shape[1] < p_len:
|
if feats.shape[1] < p_len:
|
||||||
p_len = feats.shape[1]
|
p_len = feats.shape[1]
|
||||||
@ -296,7 +296,7 @@ class Pipeline(object):
|
|||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
elif torch.backends.mps.is_available():
|
elif torch.backends.mps.is_available():
|
||||||
torch.mps.empty_cache()
|
torch.mps.empty_cache()
|
||||||
t2 = ttime()
|
t2 = time()
|
||||||
times[0] += t1 - t0
|
times[0] += t1 - t0
|
||||||
times[2] += t2 - t1
|
times[2] += t2 - t1
|
||||||
return audio1
|
return audio1
|
||||||
@ -356,7 +356,7 @@ class Pipeline(object):
|
|||||||
s = 0
|
s = 0
|
||||||
audio_opt = []
|
audio_opt = []
|
||||||
t = None
|
t = None
|
||||||
t1 = ttime()
|
t1 = time()
|
||||||
audio_pad = np.pad(audio, (self.t_pad, self.t_pad), mode="reflect")
|
audio_pad = np.pad(audio, (self.t_pad, self.t_pad), mode="reflect")
|
||||||
p_len = audio_pad.shape[0] // self.window
|
p_len = audio_pad.shape[0] // self.window
|
||||||
inp_f0 = None
|
inp_f0 = None
|
||||||
@ -387,7 +387,7 @@ class Pipeline(object):
|
|||||||
pitchf = pitchf.astype(np.float32)
|
pitchf = pitchf.astype(np.float32)
|
||||||
pitch = torch.tensor(pitch, device=self.device).unsqueeze(0).long()
|
pitch = torch.tensor(pitch, device=self.device).unsqueeze(0).long()
|
||||||
pitchf = torch.tensor(pitchf, device=self.device).unsqueeze(0).float()
|
pitchf = torch.tensor(pitchf, device=self.device).unsqueeze(0).float()
|
||||||
t2 = ttime()
|
t2 = time()
|
||||||
times[1] += t2 - t1
|
times[1] += t2 - t1
|
||||||
for t in opt_ts:
|
for t in opt_ts:
|
||||||
t = t // self.window * self.window
|
t = t // self.window * self.window
|
||||||
|
Loading…
x
Reference in New Issue
Block a user