diff --git a/infer/lib/jit/__init__.py b/infer/lib/jit/__init__.py index d7f41dd..2f2e04a 100644 --- a/infer/lib/jit/__init__.py +++ b/infer/lib/jit/__init__.py @@ -1,163 +1,2 @@ -from io import BytesIO -import pickle -import time -import torch -from tqdm import tqdm -from collections import OrderedDict - - -def load_inputs(path, device, is_half=False): - parm = torch.load(path, map_location=torch.device("cpu")) - for key in parm.keys(): - parm[key] = parm[key].to(device) - if is_half and parm[key].dtype == torch.float32: - parm[key] = parm[key].half() - elif not is_half and parm[key].dtype == torch.float16: - parm[key] = parm[key].float() - return parm - - -def benchmark( - model, inputs_path, device=torch.device("cpu"), epoch=1000, is_half=False -): - parm = load_inputs(inputs_path, device, is_half) - total_ts = 0.0 - bar = tqdm(range(epoch)) - for i in bar: - start_time = time.perf_counter() - o = model(**parm) - total_ts += time.perf_counter() - start_time - print(f"num_epoch: {epoch} | avg time(ms): {(total_ts*1000)/epoch}") - - -def jit_warm_up(model, inputs_path, device=torch.device("cpu"), epoch=5, is_half=False): - benchmark(model, inputs_path, device, epoch=epoch, is_half=is_half) - - -def to_jit_model( - model_path, - model_type: str, - mode: str = "trace", - inputs_path: str = None, - device=torch.device("cpu"), - is_half=False, -): - model = None - if model_type.lower() == "synthesizer": - from .get_synthesizer import get_synthesizer - - model, _ = get_synthesizer(model_path, device) - model.forward = model.infer - elif model_type.lower() == "rmvpe": - from .get_rmvpe import get_rmvpe - - model = get_rmvpe(model_path, device) - elif model_type.lower() == "hubert": - from .get_hubert import get_hubert_model - - model = get_hubert_model(model_path, device) - model.forward = model.infer - else: - raise ValueError(f"No model type named {model_type}") - model = model.eval() - model = model.half() if is_half else model.float() - if mode == "trace": - assert not inputs_path - inputs = load_inputs(inputs_path, device, is_half) - model_jit = torch.jit.trace(model, example_kwarg_inputs=inputs) - elif mode == "script": - model_jit = torch.jit.script(model) - model_jit.to(device) - model_jit = model_jit.half() if is_half else model_jit.float() - # model = model.half() if is_half else model.float() - return (model, model_jit) - - -def export( - model: torch.nn.Module, - mode: str = "trace", - inputs: dict = None, - device=torch.device("cpu"), - is_half: bool = False, -) -> dict: - model = model.half() if is_half else model.float() - model.eval() - if mode == "trace": - assert inputs is not None - model_jit = torch.jit.trace(model, example_kwarg_inputs=inputs) - elif mode == "script": - model_jit = torch.jit.script(model) - model_jit.to(device) - model_jit = model_jit.half() if is_half else model_jit.float() - buffer = BytesIO() - # model_jit=model_jit.cpu() - torch.jit.save(model_jit, buffer) - del model_jit - cpt = OrderedDict() - cpt["model"] = buffer.getvalue() - cpt["is_half"] = is_half - return cpt - - -def load(path: str): - with open(path, "rb") as f: - return pickle.load(f) - - -def save(ckpt: dict, save_path: str): - with open(save_path, "wb") as f: - pickle.dump(ckpt, f) - - -def rmvpe_jit_export( - model_path: str, - mode: str = "script", - inputs_path: str = None, - save_path: str = None, - device=torch.device("cpu"), - is_half=False, -): - if not save_path: - save_path = model_path.rstrip(".pth") - save_path += ".half.jit" if is_half else ".jit" - if "cuda" in str(device) and ":" not in str(device): - device = torch.device("cuda:0") - from .get_rmvpe import get_rmvpe - - model = get_rmvpe(model_path, device) - inputs = None - if mode == "trace": - inputs = load_inputs(inputs_path, device, is_half) - ckpt = export(model, mode, inputs, device, is_half) - ckpt["device"] = str(device) - save(ckpt, save_path) - return ckpt - - -def synthesizer_jit_export( - model_path: str, - mode: str = "script", - inputs_path: str = None, - save_path: str = None, - device=torch.device("cpu"), - is_half=False, -): - if not save_path: - save_path = model_path.rstrip(".pth") - save_path += ".half.jit" if is_half else ".jit" - if "cuda" in str(device) and ":" not in str(device): - device = torch.device("cuda:0") - from .get_synthesizer import get_synthesizer - - model, cpt = get_synthesizer(model_path, device) - assert isinstance(cpt, dict) - model.forward = model.infer - inputs = None - if mode == "trace": - inputs = load_inputs(inputs_path, device, is_half) - ckpt = export(model, mode, inputs, device, is_half) - cpt.pop("weight") - cpt["model"] = ckpt["model"] - cpt["device"] = device - save(cpt, save_path) - return cpt +from .utils import load, rmvpe_jit_export, synthesizer_jit_export +from .synthesizer import get_synthesizer diff --git a/infer/lib/jit/get_hubert.py b/infer/lib/jit/hubert.py similarity index 100% rename from infer/lib/jit/get_hubert.py rename to infer/lib/jit/hubert.py diff --git a/infer/lib/jit/get_rmvpe.py b/infer/lib/jit/rmvpe.py similarity index 100% rename from infer/lib/jit/get_rmvpe.py rename to infer/lib/jit/rmvpe.py diff --git a/infer/lib/jit/get_synthesizer.py b/infer/lib/jit/synthesizer.py similarity index 100% rename from infer/lib/jit/get_synthesizer.py rename to infer/lib/jit/synthesizer.py diff --git a/infer/lib/jit/utils.py b/infer/lib/jit/utils.py new file mode 100644 index 0000000..ea73ba1 --- /dev/null +++ b/infer/lib/jit/utils.py @@ -0,0 +1,163 @@ +from io import BytesIO +import pickle +import time +import torch +from tqdm import tqdm +from collections import OrderedDict + + +def load_inputs(path, device, is_half=False): + parm = torch.load(path, map_location=torch.device("cpu")) + for key in parm.keys(): + parm[key] = parm[key].to(device) + if is_half and parm[key].dtype == torch.float32: + parm[key] = parm[key].half() + elif not is_half and parm[key].dtype == torch.float16: + parm[key] = parm[key].float() + return parm + + +def benchmark( + model, inputs_path, device=torch.device("cpu"), epoch=1000, is_half=False +): + parm = load_inputs(inputs_path, device, is_half) + total_ts = 0.0 + bar = tqdm(range(epoch)) + for i in bar: + start_time = time.perf_counter() + o = model(**parm) + total_ts += time.perf_counter() - start_time + print(f"num_epoch: {epoch} | avg time(ms): {(total_ts*1000)/epoch}") + + +def jit_warm_up(model, inputs_path, device=torch.device("cpu"), epoch=5, is_half=False): + benchmark(model, inputs_path, device, epoch=epoch, is_half=is_half) + + +def to_jit_model( + model_path, + model_type: str, + mode: str = "trace", + inputs_path: str = None, + device=torch.device("cpu"), + is_half=False, +): + model = None + if model_type.lower() == "synthesizer": + from .synthesizer import get_synthesizer + + model, _ = get_synthesizer(model_path, device) + model.forward = model.infer + elif model_type.lower() == "rmvpe": + from .rmvpe import get_rmvpe + + model = get_rmvpe(model_path, device) + elif model_type.lower() == "hubert": + from .hubert import get_hubert_model + + model = get_hubert_model(model_path, device) + model.forward = model.infer + else: + raise ValueError(f"No model type named {model_type}") + model = model.eval() + model = model.half() if is_half else model.float() + if mode == "trace": + assert not inputs_path + inputs = load_inputs(inputs_path, device, is_half) + model_jit = torch.jit.trace(model, example_kwarg_inputs=inputs) + elif mode == "script": + model_jit = torch.jit.script(model) + model_jit.to(device) + model_jit = model_jit.half() if is_half else model_jit.float() + # model = model.half() if is_half else model.float() + return (model, model_jit) + + +def export( + model: torch.nn.Module, + mode: str = "trace", + inputs: dict = None, + device=torch.device("cpu"), + is_half: bool = False, +) -> dict: + model = model.half() if is_half else model.float() + model.eval() + if mode == "trace": + assert inputs is not None + model_jit = torch.jit.trace(model, example_kwarg_inputs=inputs) + elif mode == "script": + model_jit = torch.jit.script(model) + model_jit.to(device) + model_jit = model_jit.half() if is_half else model_jit.float() + buffer = BytesIO() + # model_jit=model_jit.cpu() + torch.jit.save(model_jit, buffer) + del model_jit + cpt = OrderedDict() + cpt["model"] = buffer.getvalue() + cpt["is_half"] = is_half + return cpt + + +def load(path: str): + with open(path, "rb") as f: + return pickle.load(f) + + +def save(ckpt: dict, save_path: str): + with open(save_path, "wb") as f: + pickle.dump(ckpt, f) + + +def rmvpe_jit_export( + model_path: str, + mode: str = "script", + inputs_path: str = None, + save_path: str = None, + device=torch.device("cpu"), + is_half=False, +): + if not save_path: + save_path = model_path.rstrip(".pth") + save_path += ".half.jit" if is_half else ".jit" + if "cuda" in str(device) and ":" not in str(device): + device = torch.device("cuda:0") + from .rmvpe import get_rmvpe + + model = get_rmvpe(model_path, device) + inputs = None + if mode == "trace": + inputs = load_inputs(inputs_path, device, is_half) + ckpt = export(model, mode, inputs, device, is_half) + ckpt["device"] = str(device) + save(ckpt, save_path) + return ckpt + + +def synthesizer_jit_export( + model_path: str, + mode: str = "script", + inputs_path: str = None, + save_path: str = None, + device=torch.device("cpu"), + is_half=False, +): + if not save_path: + save_path = model_path.rstrip(".pth") + save_path += ".half.jit" if is_half else ".jit" + if "cuda" in str(device) and ":" not in str(device): + device = torch.device("cuda:0") + from .synthesizer import get_synthesizer + + model, cpt = get_synthesizer(model_path, device) + assert isinstance(cpt, dict) + model.forward = model.infer + inputs = None + if mode == "trace": + inputs = load_inputs(inputs_path, device, is_half) + ckpt = export(model, mode, inputs, device, is_half) + cpt.pop("weight") + cpt["model"] = ckpt["model"] + cpt["device"] = device + save(cpt, save_path) + return cpt diff --git a/infer/lib/rtrvc.py b/infer/lib/rtrvc.py index b74e938..49f6140 100644 --- a/infer/lib/rtrvc.py +++ b/infer/lib/rtrvc.py @@ -3,7 +3,6 @@ import os import sys import traceback from infer.lib import jit -from infer.lib.jit.get_synthesizer import get_synthesizer from time import time as ttime import fairseq import faiss @@ -114,7 +113,7 @@ class RVC: self.net_g: nn.Module = None def set_default_model(): - self.net_g, cpt = get_synthesizer(self.pth_path, self.device) + self.net_g, cpt = jit.get_synthesizer(self.pth_path, self.device) self.tgt_sr = cpt["config"][-1] cpt["config"][-3] = cpt["weight"]["emb_g.weight"].shape[0] self.if_f0 = cpt.get("f0", 1) diff --git a/infer/modules/vc/pipeline.py b/infer/modules/vc/pipeline.py index 2c93236..6b33022 100644 --- a/infer/modules/vc/pipeline.py +++ b/infer/modules/vc/pipeline.py @@ -6,7 +6,7 @@ import logging logger = logging.getLogger(__name__) from functools import lru_cache -from time import time as ttime +from time import time import faiss import librosa @@ -235,7 +235,7 @@ class Pipeline(object): "padding_mask": padding_mask, "output_layer": 9 if version == "v1" else 12, } - t0 = ttime() + t0 = time() with torch.no_grad(): logits = model.extract_features(**inputs) feats = model.final_proj(logits[0]) if version == "v1" else logits[0] @@ -270,7 +270,7 @@ class Pipeline(object): feats0 = F.interpolate(feats0.permute(0, 2, 1), scale_factor=2).permute( 0, 2, 1 ) - t1 = ttime() + t1 = time() p_len = audio0.shape[0] // self.window if feats.shape[1] < p_len: p_len = feats.shape[1] @@ -296,7 +296,7 @@ class Pipeline(object): torch.cuda.empty_cache() elif torch.backends.mps.is_available(): torch.mps.empty_cache() - t2 = ttime() + t2 = time() times[0] += t1 - t0 times[2] += t2 - t1 return audio1 @@ -356,7 +356,7 @@ class Pipeline(object): s = 0 audio_opt = [] t = None - t1 = ttime() + t1 = time() audio_pad = np.pad(audio, (self.t_pad, self.t_pad), mode="reflect") p_len = audio_pad.shape[0] // self.window inp_f0 = None @@ -387,7 +387,7 @@ class Pipeline(object): pitchf = pitchf.astype(np.float32) pitch = torch.tensor(pitch, device=self.device).unsqueeze(0).long() pitchf = torch.tensor(pitchf, device=self.device).unsqueeze(0).float() - t2 = ttime() + t2 = time() times[1] += t2 - t1 for t in opt_ts: t = t // self.window * self.window