Attempt to infer V2 models (#927)
This commit is contained in:
parent
064fecbd5d
commit
296905983a
44
infer_cli.py
44
infer_cli.py
|
@ -1,13 +1,26 @@
|
||||||
import os, sys, pdb, torch
|
from scipy.io import wavfile
|
||||||
|
from fairseq import checkpoint_utils
|
||||||
|
from lib.audio import load_audio
|
||||||
|
from lib.infer_pack.models import (
|
||||||
|
SynthesizerTrnMs256NSFsid,
|
||||||
|
SynthesizerTrnMs256NSFsid_nono,
|
||||||
|
SynthesizerTrnMs768NSFsid,
|
||||||
|
SynthesizerTrnMs768NSFsid_nono,
|
||||||
|
)
|
||||||
|
from vc_infer_pipeline import VC
|
||||||
|
from multiprocessing import cpu_count
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import sys
|
||||||
|
import glob
|
||||||
|
import argparse
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import pdb
|
||||||
|
import torch
|
||||||
|
|
||||||
now_dir = os.getcwd()
|
now_dir = os.getcwd()
|
||||||
sys.path.append(now_dir)
|
sys.path.append(now_dir)
|
||||||
import argparse
|
|
||||||
import glob
|
|
||||||
import sys
|
|
||||||
import torch
|
|
||||||
import numpy as np
|
|
||||||
from multiprocessing import cpu_count
|
|
||||||
|
|
||||||
####
|
####
|
||||||
# USAGE
|
# USAGE
|
||||||
|
@ -119,11 +132,6 @@ class Config:
|
||||||
config = Config(device, is_half)
|
config = Config(device, is_half)
|
||||||
now_dir = os.getcwd()
|
now_dir = os.getcwd()
|
||||||
sys.path.append(now_dir)
|
sys.path.append(now_dir)
|
||||||
from vc_infer_pipeline import VC
|
|
||||||
from infer_pack.models import SynthesizerTrnMs256NSFsid, SynthesizerTrnMs256NSFsid_nono
|
|
||||||
from my_utils import load_audio
|
|
||||||
from fairseq import checkpoint_utils
|
|
||||||
from scipy.io import wavfile
|
|
||||||
|
|
||||||
hubert_model = None
|
hubert_model = None
|
||||||
|
|
||||||
|
@ -224,10 +232,20 @@ def get_vc(model_path):
|
||||||
cpt["config"][-3] = cpt["weight"]["emb_g.weight"].shape[0] # n_spk
|
cpt["config"][-3] = cpt["weight"]["emb_g.weight"].shape[0] # n_spk
|
||||||
if_f0 = cpt.get("f0", 1)
|
if_f0 = cpt.get("f0", 1)
|
||||||
version = cpt.get("version", "v1")
|
version = cpt.get("version", "v1")
|
||||||
|
if version == "v1":
|
||||||
if if_f0 == 1:
|
if if_f0 == 1:
|
||||||
net_g = SynthesizerTrnMs256NSFsid(*cpt["config"], is_half=is_half)
|
net_g = SynthesizerTrnMs256NSFsid(
|
||||||
|
*cpt["config"], is_half=is_half
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
net_g = SynthesizerTrnMs256NSFsid_nono(*cpt["config"])
|
net_g = SynthesizerTrnMs256NSFsid_nono(*cpt["config"])
|
||||||
|
elif version == "v2":
|
||||||
|
if if_f0 == 1:
|
||||||
|
net_g = SynthesizerTrnMs768NSFsid(
|
||||||
|
*cpt["config"], is_half=is_half
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
net_g = SynthesizerTrnMs768NSFsid_nono(*cpt["config"])
|
||||||
del net_g.enc_q
|
del net_g.enc_q
|
||||||
print(net_g.load_state_dict(cpt["weight"], strict=False))
|
print(net_g.load_state_dict(cpt["weight"], strict=False))
|
||||||
net_g.eval().to(device)
|
net_g.eval().to(device)
|
||||||
|
|
Loading…
Reference in New Issue