optimize(jit): package hierarchy

This commit is contained in:
源文雨 2024-06-03 15:41:25 +09:00
parent ae04d4f9cc
commit 17e703a9ad
7 changed files with 172 additions and 171 deletions

View File

@ -1,163 +1,2 @@
from io import BytesIO from .utils import load, rmvpe_jit_export, synthesizer_jit_export
import pickle from .synthesizer import get_synthesizer
import time
import torch
from tqdm import tqdm
from collections import OrderedDict
def load_inputs(path, device, is_half=False):
parm = torch.load(path, map_location=torch.device("cpu"))
for key in parm.keys():
parm[key] = parm[key].to(device)
if is_half and parm[key].dtype == torch.float32:
parm[key] = parm[key].half()
elif not is_half and parm[key].dtype == torch.float16:
parm[key] = parm[key].float()
return parm
def benchmark(
model, inputs_path, device=torch.device("cpu"), epoch=1000, is_half=False
):
parm = load_inputs(inputs_path, device, is_half)
total_ts = 0.0
bar = tqdm(range(epoch))
for i in bar:
start_time = time.perf_counter()
o = model(**parm)
total_ts += time.perf_counter() - start_time
print(f"num_epoch: {epoch} | avg time(ms): {(total_ts*1000)/epoch}")
def jit_warm_up(model, inputs_path, device=torch.device("cpu"), epoch=5, is_half=False):
benchmark(model, inputs_path, device, epoch=epoch, is_half=is_half)
def to_jit_model(
model_path,
model_type: str,
mode: str = "trace",
inputs_path: str = None,
device=torch.device("cpu"),
is_half=False,
):
model = None
if model_type.lower() == "synthesizer":
from .get_synthesizer import get_synthesizer
model, _ = get_synthesizer(model_path, device)
model.forward = model.infer
elif model_type.lower() == "rmvpe":
from .get_rmvpe import get_rmvpe
model = get_rmvpe(model_path, device)
elif model_type.lower() == "hubert":
from .get_hubert import get_hubert_model
model = get_hubert_model(model_path, device)
model.forward = model.infer
else:
raise ValueError(f"No model type named {model_type}")
model = model.eval()
model = model.half() if is_half else model.float()
if mode == "trace":
assert not inputs_path
inputs = load_inputs(inputs_path, device, is_half)
model_jit = torch.jit.trace(model, example_kwarg_inputs=inputs)
elif mode == "script":
model_jit = torch.jit.script(model)
model_jit.to(device)
model_jit = model_jit.half() if is_half else model_jit.float()
# model = model.half() if is_half else model.float()
return (model, model_jit)
def export(
model: torch.nn.Module,
mode: str = "trace",
inputs: dict = None,
device=torch.device("cpu"),
is_half: bool = False,
) -> dict:
model = model.half() if is_half else model.float()
model.eval()
if mode == "trace":
assert inputs is not None
model_jit = torch.jit.trace(model, example_kwarg_inputs=inputs)
elif mode == "script":
model_jit = torch.jit.script(model)
model_jit.to(device)
model_jit = model_jit.half() if is_half else model_jit.float()
buffer = BytesIO()
# model_jit=model_jit.cpu()
torch.jit.save(model_jit, buffer)
del model_jit
cpt = OrderedDict()
cpt["model"] = buffer.getvalue()
cpt["is_half"] = is_half
return cpt
def load(path: str):
with open(path, "rb") as f:
return pickle.load(f)
def save(ckpt: dict, save_path: str):
with open(save_path, "wb") as f:
pickle.dump(ckpt, f)
def rmvpe_jit_export(
model_path: str,
mode: str = "script",
inputs_path: str = None,
save_path: str = None,
device=torch.device("cpu"),
is_half=False,
):
if not save_path:
save_path = model_path.rstrip(".pth")
save_path += ".half.jit" if is_half else ".jit"
if "cuda" in str(device) and ":" not in str(device):
device = torch.device("cuda:0")
from .get_rmvpe import get_rmvpe
model = get_rmvpe(model_path, device)
inputs = None
if mode == "trace":
inputs = load_inputs(inputs_path, device, is_half)
ckpt = export(model, mode, inputs, device, is_half)
ckpt["device"] = str(device)
save(ckpt, save_path)
return ckpt
def synthesizer_jit_export(
model_path: str,
mode: str = "script",
inputs_path: str = None,
save_path: str = None,
device=torch.device("cpu"),
is_half=False,
):
if not save_path:
save_path = model_path.rstrip(".pth")
save_path += ".half.jit" if is_half else ".jit"
if "cuda" in str(device) and ":" not in str(device):
device = torch.device("cuda:0")
from .get_synthesizer import get_synthesizer
model, cpt = get_synthesizer(model_path, device)
assert isinstance(cpt, dict)
model.forward = model.infer
inputs = None
if mode == "trace":
inputs = load_inputs(inputs_path, device, is_half)
ckpt = export(model, mode, inputs, device, is_half)
cpt.pop("weight")
cpt["model"] = ckpt["model"]
cpt["device"] = device
save(cpt, save_path)
return cpt

163
infer/lib/jit/utils.py Normal file
View File

@ -0,0 +1,163 @@
from io import BytesIO
import pickle
import time
import torch
from tqdm import tqdm
from collections import OrderedDict
def load_inputs(path, device, is_half=False):
parm = torch.load(path, map_location=torch.device("cpu"))
for key in parm.keys():
parm[key] = parm[key].to(device)
if is_half and parm[key].dtype == torch.float32:
parm[key] = parm[key].half()
elif not is_half and parm[key].dtype == torch.float16:
parm[key] = parm[key].float()
return parm
def benchmark(
model, inputs_path, device=torch.device("cpu"), epoch=1000, is_half=False
):
parm = load_inputs(inputs_path, device, is_half)
total_ts = 0.0
bar = tqdm(range(epoch))
for i in bar:
start_time = time.perf_counter()
o = model(**parm)
total_ts += time.perf_counter() - start_time
print(f"num_epoch: {epoch} | avg time(ms): {(total_ts*1000)/epoch}")
def jit_warm_up(model, inputs_path, device=torch.device("cpu"), epoch=5, is_half=False):
benchmark(model, inputs_path, device, epoch=epoch, is_half=is_half)
def to_jit_model(
model_path,
model_type: str,
mode: str = "trace",
inputs_path: str = None,
device=torch.device("cpu"),
is_half=False,
):
model = None
if model_type.lower() == "synthesizer":
from .synthesizer import get_synthesizer
model, _ = get_synthesizer(model_path, device)
model.forward = model.infer
elif model_type.lower() == "rmvpe":
from .rmvpe import get_rmvpe
model = get_rmvpe(model_path, device)
elif model_type.lower() == "hubert":
from .hubert import get_hubert_model
model = get_hubert_model(model_path, device)
model.forward = model.infer
else:
raise ValueError(f"No model type named {model_type}")
model = model.eval()
model = model.half() if is_half else model.float()
if mode == "trace":
assert not inputs_path
inputs = load_inputs(inputs_path, device, is_half)
model_jit = torch.jit.trace(model, example_kwarg_inputs=inputs)
elif mode == "script":
model_jit = torch.jit.script(model)
model_jit.to(device)
model_jit = model_jit.half() if is_half else model_jit.float()
# model = model.half() if is_half else model.float()
return (model, model_jit)
def export(
model: torch.nn.Module,
mode: str = "trace",
inputs: dict = None,
device=torch.device("cpu"),
is_half: bool = False,
) -> dict:
model = model.half() if is_half else model.float()
model.eval()
if mode == "trace":
assert inputs is not None
model_jit = torch.jit.trace(model, example_kwarg_inputs=inputs)
elif mode == "script":
model_jit = torch.jit.script(model)
model_jit.to(device)
model_jit = model_jit.half() if is_half else model_jit.float()
buffer = BytesIO()
# model_jit=model_jit.cpu()
torch.jit.save(model_jit, buffer)
del model_jit
cpt = OrderedDict()
cpt["model"] = buffer.getvalue()
cpt["is_half"] = is_half
return cpt
def load(path: str):
with open(path, "rb") as f:
return pickle.load(f)
def save(ckpt: dict, save_path: str):
with open(save_path, "wb") as f:
pickle.dump(ckpt, f)
def rmvpe_jit_export(
model_path: str,
mode: str = "script",
inputs_path: str = None,
save_path: str = None,
device=torch.device("cpu"),
is_half=False,
):
if not save_path:
save_path = model_path.rstrip(".pth")
save_path += ".half.jit" if is_half else ".jit"
if "cuda" in str(device) and ":" not in str(device):
device = torch.device("cuda:0")
from .rmvpe import get_rmvpe
model = get_rmvpe(model_path, device)
inputs = None
if mode == "trace":
inputs = load_inputs(inputs_path, device, is_half)
ckpt = export(model, mode, inputs, device, is_half)
ckpt["device"] = str(device)
save(ckpt, save_path)
return ckpt
def synthesizer_jit_export(
model_path: str,
mode: str = "script",
inputs_path: str = None,
save_path: str = None,
device=torch.device("cpu"),
is_half=False,
):
if not save_path:
save_path = model_path.rstrip(".pth")
save_path += ".half.jit" if is_half else ".jit"
if "cuda" in str(device) and ":" not in str(device):
device = torch.device("cuda:0")
from .synthesizer import get_synthesizer
model, cpt = get_synthesizer(model_path, device)
assert isinstance(cpt, dict)
model.forward = model.infer
inputs = None
if mode == "trace":
inputs = load_inputs(inputs_path, device, is_half)
ckpt = export(model, mode, inputs, device, is_half)
cpt.pop("weight")
cpt["model"] = ckpt["model"]
cpt["device"] = device
save(cpt, save_path)
return cpt

View File

@ -3,7 +3,6 @@ import os
import sys import sys
import traceback import traceback
from infer.lib import jit from infer.lib import jit
from infer.lib.jit.get_synthesizer import get_synthesizer
from time import time as ttime from time import time as ttime
import fairseq import fairseq
import faiss import faiss
@ -114,7 +113,7 @@ class RVC:
self.net_g: nn.Module = None self.net_g: nn.Module = None
def set_default_model(): def set_default_model():
self.net_g, cpt = get_synthesizer(self.pth_path, self.device) self.net_g, cpt = jit.get_synthesizer(self.pth_path, self.device)
self.tgt_sr = cpt["config"][-1] self.tgt_sr = cpt["config"][-1]
cpt["config"][-3] = cpt["weight"]["emb_g.weight"].shape[0] cpt["config"][-3] = cpt["weight"]["emb_g.weight"].shape[0]
self.if_f0 = cpt.get("f0", 1) self.if_f0 = cpt.get("f0", 1)

View File

@ -6,7 +6,7 @@ import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
from functools import lru_cache from functools import lru_cache
from time import time as ttime from time import time
import faiss import faiss
import librosa import librosa
@ -235,7 +235,7 @@ class Pipeline(object):
"padding_mask": padding_mask, "padding_mask": padding_mask,
"output_layer": 9 if version == "v1" else 12, "output_layer": 9 if version == "v1" else 12,
} }
t0 = ttime() t0 = time()
with torch.no_grad(): with torch.no_grad():
logits = model.extract_features(**inputs) logits = model.extract_features(**inputs)
feats = model.final_proj(logits[0]) if version == "v1" else logits[0] feats = model.final_proj(logits[0]) if version == "v1" else logits[0]
@ -270,7 +270,7 @@ class Pipeline(object):
feats0 = F.interpolate(feats0.permute(0, 2, 1), scale_factor=2).permute( feats0 = F.interpolate(feats0.permute(0, 2, 1), scale_factor=2).permute(
0, 2, 1 0, 2, 1
) )
t1 = ttime() t1 = time()
p_len = audio0.shape[0] // self.window p_len = audio0.shape[0] // self.window
if feats.shape[1] < p_len: if feats.shape[1] < p_len:
p_len = feats.shape[1] p_len = feats.shape[1]
@ -296,7 +296,7 @@ class Pipeline(object):
torch.cuda.empty_cache() torch.cuda.empty_cache()
elif torch.backends.mps.is_available(): elif torch.backends.mps.is_available():
torch.mps.empty_cache() torch.mps.empty_cache()
t2 = ttime() t2 = time()
times[0] += t1 - t0 times[0] += t1 - t0
times[2] += t2 - t1 times[2] += t2 - t1
return audio1 return audio1
@ -356,7 +356,7 @@ class Pipeline(object):
s = 0 s = 0
audio_opt = [] audio_opt = []
t = None t = None
t1 = ttime() t1 = time()
audio_pad = np.pad(audio, (self.t_pad, self.t_pad), mode="reflect") audio_pad = np.pad(audio, (self.t_pad, self.t_pad), mode="reflect")
p_len = audio_pad.shape[0] // self.window p_len = audio_pad.shape[0] // self.window
inp_f0 = None inp_f0 = None
@ -387,7 +387,7 @@ class Pipeline(object):
pitchf = pitchf.astype(np.float32) pitchf = pitchf.astype(np.float32)
pitch = torch.tensor(pitch, device=self.device).unsqueeze(0).long() pitch = torch.tensor(pitch, device=self.device).unsqueeze(0).long()
pitchf = torch.tensor(pitchf, device=self.device).unsqueeze(0).float() pitchf = torch.tensor(pitchf, device=self.device).unsqueeze(0).float()
t2 = ttime() t2 = time()
times[1] += t2 - t1 times[1] += t2 - t1
for t in opt_ts: for t in opt_ts:
t = t // self.window * self.window t = t // self.window * self.window