diff --git a/infer/lib/jit/__init__.py b/infer/lib/jit/__init__.py index 2f2e04a..179fbe0 100644 --- a/infer/lib/jit/__init__.py +++ b/infer/lib/jit/__init__.py @@ -1,2 +1,2 @@ from .utils import load, rmvpe_jit_export, synthesizer_jit_export -from .synthesizer import get_synthesizer +from .synthesizer import get_synthesizer, get_synthesizer_ckpt diff --git a/infer/lib/jit/synthesizer.py b/infer/lib/jit/synthesizer.py index b8db4fa..5069fe9 100644 --- a/infer/lib/jit/synthesizer.py +++ b/infer/lib/jit/synthesizer.py @@ -1,7 +1,6 @@ import torch - -def get_synthesizer(pth_path, device=torch.device("cpu")): +def get_synthesizer_ckpt(cpt, device=torch.device("cpu")): from infer.lib.infer_pack.models import ( SynthesizerTrnMs256NSFsid, SynthesizerTrnMs256NSFsid_nono, @@ -9,7 +8,6 @@ def get_synthesizer(pth_path, device=torch.device("cpu")): SynthesizerTrnMs768NSFsid_nono, ) - cpt = torch.load(pth_path, map_location=torch.device("cpu")) # tgt_sr = cpt["config"][-1] cpt["config"][-3] = cpt["weight"]["emb_g.weight"].shape[0] if_f0 = cpt.get("f0", 1) @@ -36,3 +34,8 @@ def get_synthesizer(pth_path, device=torch.device("cpu")): net_g.eval().to(device) net_g.remove_weight_norm() return net_g, cpt + +def get_synthesizer(pth_path, device=torch.device("cpu")): + return get_synthesizer_ckpt( + torch.load(pth_path, map_location=torch.device("cpu")), device, + ) diff --git a/infer/modules/vc/hash.py b/infer/modules/vc/hash.py index 2234aab..0b29936 100644 --- a/infer/modules/vc/hash.py +++ b/infer/modules/vc/hash.py @@ -6,6 +6,7 @@ from scipy.fft import fft from pybase16384 import encode_to_string, decode_from_string from configs import CPUConfig, singleton_variable +from infer.lib.jit import get_synthesizer_ckpt from .pipeline import Pipeline from .utils import load_hubert @@ -123,38 +124,19 @@ def model_hash(config, tgt_sr, net_g, if_f0, version): n = -n audio_opt = audio_opt[n:-n] h = wave_hash(audio_opt) - del pipeline, audio, audio_opt + del pipeline, audio_opt return h def model_hash_ckpt(cpt): - from infer.lib.infer_pack.models import ( - SynthesizerTrnMs256NSFsid, - SynthesizerTrnMs256NSFsid_nono, - SynthesizerTrnMs768NSFsid, - SynthesizerTrnMs768NSFsid_nono, - ) - config = CPUConfig() with TorchSeedContext(114514): + net_g, cpt = get_synthesizer_ckpt(cpt, config.device) tgt_sr = cpt["config"][-1] if_f0 = cpt.get("f0", 1) version = cpt.get("version", "v1") - synthesizer_class = { - ("v1", 1): SynthesizerTrnMs256NSFsid, - ("v1", 0): SynthesizerTrnMs256NSFsid_nono, - ("v2", 1): SynthesizerTrnMs768NSFsid, - ("v2", 0): SynthesizerTrnMs768NSFsid_nono, - } - net_g = synthesizer_class.get((version, if_f0), SynthesizerTrnMs256NSFsid)( - *cpt["config"], is_half=config.is_half - ) - del net_g.enc_q - - net_g.load_state_dict(cpt["weight"], strict=False) - net_g.eval().to(config.device) if config.is_half: net_g = net_g.half() else: