From 0d2189fdeb089ef4be609cddb14df4486857c211 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=BA=90=E6=96=87=E9=9B=A8?= <41315874+fumiama@users.noreply.github.com> Date: Wed, 5 Jun 2024 15:27:20 +0900 Subject: [PATCH] optimize(vc.modules): use general get_synthesizer --- infer/modules/vc/modules.py | 42 +++++-------------------------------- 1 file changed, 5 insertions(+), 37 deletions(-) diff --git a/infer/modules/vc/modules.py b/infer/modules/vc/modules.py index e4dd871..fe71180 100644 --- a/infer/modules/vc/modules.py +++ b/infer/modules/vc/modules.py @@ -10,12 +10,7 @@ import torch from io import BytesIO from infer.lib.audio import load_audio, wav2 -from infer.lib.infer_pack.models import ( - SynthesizerTrnMs256NSFsid, - SynthesizerTrnMs256NSFsid_nono, - SynthesizerTrnMs768NSFsid, - SynthesizerTrnMs768NSFsid_nono, -) +from infer.lib.jit import get_synthesizer_ckpt, get_synthesizer from .info import show_model_info from .pipeline import Pipeline from .utils import get_index_path_from_model, load_hubert @@ -67,22 +62,9 @@ class VC: elif torch.backends.mps.is_available(): torch.mps.empty_cache() ###楼下不这么折腾清理不干净 + self.net_g, self.cpt = get_synthesizer_ckpt(self.cpt, self.config.device) self.if_f0 = self.cpt.get("f0", 1) self.version = self.cpt.get("version", "v1") - if self.version == "v1": - if self.if_f0 == 1: - self.net_g = SynthesizerTrnMs256NSFsid( - *self.cpt["config"], is_half=self.config.is_half - ) - else: - self.net_g = SynthesizerTrnMs256NSFsid_nono(*self.cpt["config"]) - elif self.version == "v2": - if self.if_f0 == 1: - self.net_g = SynthesizerTrnMs768NSFsid( - *self.cpt["config"], is_half=self.config.is_half - ) - else: - self.net_g = SynthesizerTrnMs768NSFsid_nono(*self.cpt["config"]) del self.net_g, self.cpt if torch.cuda.is_available(): torch.cuda.empty_cache() @@ -100,36 +82,22 @@ class VC: if to_return_protect else {"visible": True, "maximum": 0, "__type__": "update"} ) + person = f'{os.getenv("weight_root")}/{sid}' logger.info(f"Loading: {person}") - self.cpt = torch.load(person, map_location="cpu") + self.net_g, self.cpt = get_synthesizer(person, self.config.device) self.tgt_sr = self.cpt["config"][-1] self.cpt["config"][-3] = self.cpt["weight"]["emb_g.weight"].shape[0] # n_spk self.if_f0 = self.cpt.get("f0", 1) self.version = self.cpt.get("version", "v1") - synthesizer_class = { - ("v1", 1): SynthesizerTrnMs256NSFsid, - ("v1", 0): SynthesizerTrnMs256NSFsid_nono, - ("v2", 1): SynthesizerTrnMs768NSFsid, - ("v2", 0): SynthesizerTrnMs768NSFsid_nono, - } - - self.net_g = synthesizer_class.get( - (self.version, self.if_f0), SynthesizerTrnMs256NSFsid - )(*self.cpt["config"], is_half=self.config.is_half) - - del self.net_g.enc_q - - self.net_g.load_state_dict(self.cpt["weight"], strict=False) - self.net_g.eval().to(self.config.device) if self.config.is_half: self.net_g = self.net_g.half() else: self.net_g = self.net_g.float() - self.pipeline = Pipeline(self.tgt_sr, self.config) + n_spk = self.cpt["config"][-3] index = {"value": get_index_path_from_model(sid), "__type__": "update"} logger.info("Select index: " + index["value"])