optimize(vc.modules): use general get_synthesizer

This commit is contained in:
源文雨 2024-06-05 15:27:20 +09:00
parent e80074e219
commit 0d2189fdeb

View File

@ -10,12 +10,7 @@ import torch
from io import BytesIO
from infer.lib.audio import load_audio, wav2
from infer.lib.infer_pack.models import (
SynthesizerTrnMs256NSFsid,
SynthesizerTrnMs256NSFsid_nono,
SynthesizerTrnMs768NSFsid,
SynthesizerTrnMs768NSFsid_nono,
)
from infer.lib.jit import get_synthesizer_ckpt, get_synthesizer
from .info import show_model_info
from .pipeline import Pipeline
from .utils import get_index_path_from_model, load_hubert
@ -67,22 +62,9 @@ class VC:
elif torch.backends.mps.is_available():
torch.mps.empty_cache()
###楼下不这么折腾清理不干净
self.net_g, self.cpt = get_synthesizer_ckpt(self.cpt, self.config.device)
self.if_f0 = self.cpt.get("f0", 1)
self.version = self.cpt.get("version", "v1")
if self.version == "v1":
if self.if_f0 == 1:
self.net_g = SynthesizerTrnMs256NSFsid(
*self.cpt["config"], is_half=self.config.is_half
)
else:
self.net_g = SynthesizerTrnMs256NSFsid_nono(*self.cpt["config"])
elif self.version == "v2":
if self.if_f0 == 1:
self.net_g = SynthesizerTrnMs768NSFsid(
*self.cpt["config"], is_half=self.config.is_half
)
else:
self.net_g = SynthesizerTrnMs768NSFsid_nono(*self.cpt["config"])
del self.net_g, self.cpt
if torch.cuda.is_available():
torch.cuda.empty_cache()
@ -100,36 +82,22 @@ class VC:
if to_return_protect
else {"visible": True, "maximum": 0, "__type__": "update"}
)
person = f'{os.getenv("weight_root")}/{sid}'
logger.info(f"Loading: {person}")
self.cpt = torch.load(person, map_location="cpu")
self.net_g, self.cpt = get_synthesizer(person, self.config.device)
self.tgt_sr = self.cpt["config"][-1]
self.cpt["config"][-3] = self.cpt["weight"]["emb_g.weight"].shape[0] # n_spk
self.if_f0 = self.cpt.get("f0", 1)
self.version = self.cpt.get("version", "v1")
synthesizer_class = {
("v1", 1): SynthesizerTrnMs256NSFsid,
("v1", 0): SynthesizerTrnMs256NSFsid_nono,
("v2", 1): SynthesizerTrnMs768NSFsid,
("v2", 0): SynthesizerTrnMs768NSFsid_nono,
}
self.net_g = synthesizer_class.get(
(self.version, self.if_f0), SynthesizerTrnMs256NSFsid
)(*self.cpt["config"], is_half=self.config.is_half)
del self.net_g.enc_q
self.net_g.load_state_dict(self.cpt["weight"], strict=False)
self.net_g.eval().to(self.config.device)
if self.config.is_half:
self.net_g = self.net_g.half()
else:
self.net_g = self.net_g.float()
self.pipeline = Pipeline(self.tgt_sr, self.config)
n_spk = self.cpt["config"][-3]
index = {"value": get_index_path_from_model(sid), "__type__": "update"}
logger.info("Select index: " + index["value"])