mirror of
https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI.git
synced 2025-05-06 20:01:37 +08:00
optimize(vc.hash): use general get_synthesizer
This commit is contained in:
parent
bb675b1e45
commit
e80074e219
@ -1,2 +1,2 @@
|
|||||||
from .utils import load, rmvpe_jit_export, synthesizer_jit_export
|
from .utils import load, rmvpe_jit_export, synthesizer_jit_export
|
||||||
from .synthesizer import get_synthesizer
|
from .synthesizer import get_synthesizer, get_synthesizer_ckpt
|
||||||
|
@ -1,7 +1,6 @@
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
def get_synthesizer_ckpt(cpt, device=torch.device("cpu")):
|
||||||
def get_synthesizer(pth_path, device=torch.device("cpu")):
|
|
||||||
from infer.lib.infer_pack.models import (
|
from infer.lib.infer_pack.models import (
|
||||||
SynthesizerTrnMs256NSFsid,
|
SynthesizerTrnMs256NSFsid,
|
||||||
SynthesizerTrnMs256NSFsid_nono,
|
SynthesizerTrnMs256NSFsid_nono,
|
||||||
@ -9,7 +8,6 @@ def get_synthesizer(pth_path, device=torch.device("cpu")):
|
|||||||
SynthesizerTrnMs768NSFsid_nono,
|
SynthesizerTrnMs768NSFsid_nono,
|
||||||
)
|
)
|
||||||
|
|
||||||
cpt = torch.load(pth_path, map_location=torch.device("cpu"))
|
|
||||||
# tgt_sr = cpt["config"][-1]
|
# tgt_sr = cpt["config"][-1]
|
||||||
cpt["config"][-3] = cpt["weight"]["emb_g.weight"].shape[0]
|
cpt["config"][-3] = cpt["weight"]["emb_g.weight"].shape[0]
|
||||||
if_f0 = cpt.get("f0", 1)
|
if_f0 = cpt.get("f0", 1)
|
||||||
@ -36,3 +34,8 @@ def get_synthesizer(pth_path, device=torch.device("cpu")):
|
|||||||
net_g.eval().to(device)
|
net_g.eval().to(device)
|
||||||
net_g.remove_weight_norm()
|
net_g.remove_weight_norm()
|
||||||
return net_g, cpt
|
return net_g, cpt
|
||||||
|
|
||||||
|
def get_synthesizer(pth_path, device=torch.device("cpu")):
|
||||||
|
return get_synthesizer_ckpt(
|
||||||
|
torch.load(pth_path, map_location=torch.device("cpu")), device,
|
||||||
|
)
|
||||||
|
@ -6,6 +6,7 @@ from scipy.fft import fft
|
|||||||
from pybase16384 import encode_to_string, decode_from_string
|
from pybase16384 import encode_to_string, decode_from_string
|
||||||
|
|
||||||
from configs import CPUConfig, singleton_variable
|
from configs import CPUConfig, singleton_variable
|
||||||
|
from infer.lib.jit import get_synthesizer_ckpt
|
||||||
|
|
||||||
from .pipeline import Pipeline
|
from .pipeline import Pipeline
|
||||||
from .utils import load_hubert
|
from .utils import load_hubert
|
||||||
@ -123,38 +124,19 @@ def model_hash(config, tgt_sr, net_g, if_f0, version):
|
|||||||
n = -n
|
n = -n
|
||||||
audio_opt = audio_opt[n:-n]
|
audio_opt = audio_opt[n:-n]
|
||||||
h = wave_hash(audio_opt)
|
h = wave_hash(audio_opt)
|
||||||
del pipeline, audio, audio_opt
|
del pipeline, audio_opt
|
||||||
return h
|
return h
|
||||||
|
|
||||||
|
|
||||||
def model_hash_ckpt(cpt):
|
def model_hash_ckpt(cpt):
|
||||||
from infer.lib.infer_pack.models import (
|
|
||||||
SynthesizerTrnMs256NSFsid,
|
|
||||||
SynthesizerTrnMs256NSFsid_nono,
|
|
||||||
SynthesizerTrnMs768NSFsid,
|
|
||||||
SynthesizerTrnMs768NSFsid_nono,
|
|
||||||
)
|
|
||||||
|
|
||||||
config = CPUConfig()
|
config = CPUConfig()
|
||||||
|
|
||||||
with TorchSeedContext(114514):
|
with TorchSeedContext(114514):
|
||||||
|
net_g, cpt = get_synthesizer_ckpt(cpt, config.device)
|
||||||
tgt_sr = cpt["config"][-1]
|
tgt_sr = cpt["config"][-1]
|
||||||
if_f0 = cpt.get("f0", 1)
|
if_f0 = cpt.get("f0", 1)
|
||||||
version = cpt.get("version", "v1")
|
version = cpt.get("version", "v1")
|
||||||
synthesizer_class = {
|
|
||||||
("v1", 1): SynthesizerTrnMs256NSFsid,
|
|
||||||
("v1", 0): SynthesizerTrnMs256NSFsid_nono,
|
|
||||||
("v2", 1): SynthesizerTrnMs768NSFsid,
|
|
||||||
("v2", 0): SynthesizerTrnMs768NSFsid_nono,
|
|
||||||
}
|
|
||||||
net_g = synthesizer_class.get((version, if_f0), SynthesizerTrnMs256NSFsid)(
|
|
||||||
*cpt["config"], is_half=config.is_half
|
|
||||||
)
|
|
||||||
|
|
||||||
del net_g.enc_q
|
|
||||||
|
|
||||||
net_g.load_state_dict(cpt["weight"], strict=False)
|
|
||||||
net_g.eval().to(config.device)
|
|
||||||
if config.is_half:
|
if config.is_half:
|
||||||
net_g = net_g.half()
|
net_g = net_g.half()
|
||||||
else:
|
else:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user