Attempt to infer V2 models (#927)

This commit is contained in:
Matej Tkac 2023-08-03 03:23:20 +01:00 committed by GitHub
parent 064fecbd5d
commit 296905983a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 34 additions and 16 deletions

View File

@ -1,13 +1,26 @@
import os, sys, pdb, torch
from scipy.io import wavfile
from fairseq import checkpoint_utils
from lib.audio import load_audio
from lib.infer_pack.models import (
SynthesizerTrnMs256NSFsid,
SynthesizerTrnMs256NSFsid_nono,
SynthesizerTrnMs768NSFsid,
SynthesizerTrnMs768NSFsid_nono,
)
from vc_infer_pipeline import VC
from multiprocessing import cpu_count
import numpy as np
import torch
import sys
import glob
import argparse
import os
import sys
import pdb
import torch
now_dir = os.getcwd()
sys.path.append(now_dir)
import argparse
import glob
import sys
import torch
import numpy as np
from multiprocessing import cpu_count
####
# USAGE
@ -119,11 +132,6 @@ class Config:
config = Config(device, is_half)
now_dir = os.getcwd()
sys.path.append(now_dir)
from vc_infer_pipeline import VC
from infer_pack.models import SynthesizerTrnMs256NSFsid, SynthesizerTrnMs256NSFsid_nono
from my_utils import load_audio
from fairseq import checkpoint_utils
from scipy.io import wavfile
hubert_model = None
@ -224,10 +232,20 @@ def get_vc(model_path):
cpt["config"][-3] = cpt["weight"]["emb_g.weight"].shape[0] # n_spk
if_f0 = cpt.get("f0", 1)
version = cpt.get("version", "v1")
if if_f0 == 1:
net_g = SynthesizerTrnMs256NSFsid(*cpt["config"], is_half=is_half)
else:
net_g = SynthesizerTrnMs256NSFsid_nono(*cpt["config"])
if version == "v1":
if if_f0 == 1:
net_g = SynthesizerTrnMs256NSFsid(
*cpt["config"], is_half=is_half
)
else:
net_g = SynthesizerTrnMs256NSFsid_nono(*cpt["config"])
elif version == "v2":
if if_f0 == 1:
net_g = SynthesizerTrnMs768NSFsid(
*cpt["config"], is_half=is_half
)
else:
net_g = SynthesizerTrnMs768NSFsid_nono(*cpt["config"])
del net_g.enc_q
print(net_g.load_state_dict(cpt["weight"], strict=False))
net_g.eval().to(device)