optimize(vc.hash): use general get_synthesizer

This commit is contained in:
源文雨 2024-06-05 15:19:05 +09:00
parent bb675b1e45
commit e80074e219
3 changed files with 10 additions and 25 deletions

View File

@ -1,2 +1,2 @@
from .utils import load, rmvpe_jit_export, synthesizer_jit_export
from .synthesizer import get_synthesizer
from .synthesizer import get_synthesizer, get_synthesizer_ckpt

View File

@ -1,7 +1,6 @@
import torch
def get_synthesizer(pth_path, device=torch.device("cpu")):
def get_synthesizer_ckpt(cpt, device=torch.device("cpu")):
from infer.lib.infer_pack.models import (
SynthesizerTrnMs256NSFsid,
SynthesizerTrnMs256NSFsid_nono,
@ -9,7 +8,6 @@ def get_synthesizer(pth_path, device=torch.device("cpu")):
SynthesizerTrnMs768NSFsid_nono,
)
cpt = torch.load(pth_path, map_location=torch.device("cpu"))
# tgt_sr = cpt["config"][-1]
cpt["config"][-3] = cpt["weight"]["emb_g.weight"].shape[0]
if_f0 = cpt.get("f0", 1)
@ -36,3 +34,8 @@ def get_synthesizer(pth_path, device=torch.device("cpu")):
net_g.eval().to(device)
net_g.remove_weight_norm()
return net_g, cpt
def get_synthesizer(pth_path, device=torch.device("cpu")):
return get_synthesizer_ckpt(
torch.load(pth_path, map_location=torch.device("cpu")), device,
)

View File

@ -6,6 +6,7 @@ from scipy.fft import fft
from pybase16384 import encode_to_string, decode_from_string
from configs import CPUConfig, singleton_variable
from infer.lib.jit import get_synthesizer_ckpt
from .pipeline import Pipeline
from .utils import load_hubert
@ -123,38 +124,19 @@ def model_hash(config, tgt_sr, net_g, if_f0, version):
n = -n
audio_opt = audio_opt[n:-n]
h = wave_hash(audio_opt)
del pipeline, audio, audio_opt
del pipeline, audio_opt
return h
def model_hash_ckpt(cpt):
from infer.lib.infer_pack.models import (
SynthesizerTrnMs256NSFsid,
SynthesizerTrnMs256NSFsid_nono,
SynthesizerTrnMs768NSFsid,
SynthesizerTrnMs768NSFsid_nono,
)
config = CPUConfig()
with TorchSeedContext(114514):
net_g, cpt = get_synthesizer_ckpt(cpt, config.device)
tgt_sr = cpt["config"][-1]
if_f0 = cpt.get("f0", 1)
version = cpt.get("version", "v1")
synthesizer_class = {
("v1", 1): SynthesizerTrnMs256NSFsid,
("v1", 0): SynthesizerTrnMs256NSFsid_nono,
("v2", 1): SynthesizerTrnMs768NSFsid,
("v2", 0): SynthesizerTrnMs768NSFsid_nono,
}
net_g = synthesizer_class.get((version, if_f0), SynthesizerTrnMs256NSFsid)(
*cpt["config"], is_half=config.is_half
)
del net_g.enc_q
net_g.load_state_dict(cpt["weight"], strict=False)
net_g.eval().to(config.device)
if config.is_half:
net_g = net_g.half()
else: