From 296905983a3f87c3f90510f9efa3adee84a622dc Mon Sep 17 00:00:00 2001 From: Matej Tkac <105047985+matej-tkac@users.noreply.github.com> Date: Thu, 3 Aug 2023 03:23:20 +0100 Subject: [PATCH] Attempt to infer V2 models (#927) --- infer_cli.py | 50 ++++++++++++++++++++++++++++++++++---------------- 1 file changed, 34 insertions(+), 16 deletions(-) diff --git a/infer_cli.py b/infer_cli.py index 5eadea6..e382eff 100644 --- a/infer_cli.py +++ b/infer_cli.py @@ -1,13 +1,26 @@ -import os, sys, pdb, torch +from scipy.io import wavfile +from fairseq import checkpoint_utils +from lib.audio import load_audio +from lib.infer_pack.models import ( + SynthesizerTrnMs256NSFsid, + SynthesizerTrnMs256NSFsid_nono, + SynthesizerTrnMs768NSFsid, + SynthesizerTrnMs768NSFsid_nono, +) +from vc_infer_pipeline import VC +from multiprocessing import cpu_count +import numpy as np +import torch +import sys +import glob +import argparse +import os +import sys +import pdb +import torch now_dir = os.getcwd() sys.path.append(now_dir) -import argparse -import glob -import sys -import torch -import numpy as np -from multiprocessing import cpu_count #### # USAGE @@ -119,11 +132,6 @@ class Config: config = Config(device, is_half) now_dir = os.getcwd() sys.path.append(now_dir) -from vc_infer_pipeline import VC -from infer_pack.models import SynthesizerTrnMs256NSFsid, SynthesizerTrnMs256NSFsid_nono -from my_utils import load_audio -from fairseq import checkpoint_utils -from scipy.io import wavfile hubert_model = None @@ -224,10 +232,20 @@ def get_vc(model_path): cpt["config"][-3] = cpt["weight"]["emb_g.weight"].shape[0] # n_spk if_f0 = cpt.get("f0", 1) version = cpt.get("version", "v1") - if if_f0 == 1: - net_g = SynthesizerTrnMs256NSFsid(*cpt["config"], is_half=is_half) - else: - net_g = SynthesizerTrnMs256NSFsid_nono(*cpt["config"]) + if version == "v1": + if if_f0 == 1: + net_g = SynthesizerTrnMs256NSFsid( + *cpt["config"], is_half=is_half + ) + else: + net_g = SynthesizerTrnMs256NSFsid_nono(*cpt["config"]) + elif version == "v2": + if if_f0 == 1: + net_g = SynthesizerTrnMs768NSFsid( + *cpt["config"], is_half=is_half + ) + else: + net_g = SynthesizerTrnMs768NSFsid_nono(*cpt["config"]) del net_g.enc_q print(net_g.load_state_dict(cpt["weight"], strict=False)) net_g.eval().to(device)