fix Config, GUIConfig and self (#340)

Co-authored-by: RVC-Boss <129054828+RVC-Boss@users.noreply.github.com>
This commit is contained in:
Ftps 2023-05-26 20:32:19 +09:00 committed by GitHub
parent 0729c9d6f2
commit a2ef4cca76
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 9 additions and 10 deletions

19
gui.py
View File

@ -16,7 +16,7 @@ now_dir = os.getcwd()
sys.path.append(now_dir) sys.path.append(now_dir)
from config import Config as MyConfig from config import Config as MyConfig
is_half = MyConfig().is_half Config = Config()
import PySimpleGUI as sg import PySimpleGUI as sg
import sounddevice as sd import sounddevice as sd
import noisereduce as nr import noisereduce as nr
@ -71,7 +71,7 @@ class RVC:
) )
self.model = models[0] self.model = models[0]
self.model = self.model.to(device) self.model = self.model.to(device)
if is_half == True: if Config.is_half:
self.model = self.model.half() self.model = self.model.half()
else: else:
self.model = self.model.float() self.model = self.model.float()
@ -81,25 +81,24 @@ class RVC:
cpt["config"][-3] = cpt["weight"]["emb_g.weight"].shape[0] # n_spk cpt["config"][-3] = cpt["weight"]["emb_g.weight"].shape[0] # n_spk
self.if_f0 = cpt.get("f0", 1) self.if_f0 = cpt.get("f0", 1)
self.version = cpt.get("version", "v1") self.version = cpt.get("version", "v1")
if self.version == "v1": if self.version == "v1":
if self.if_f0 == 1: if self.if_f0 == 1:
self.net_g = SynthesizerTrnMs256NSFsid( self.net_g = SynthesizerTrnMs256NSFsid(
*cpt["config"], is_half=self.config.is_half *cpt["config"], is_half=Config.is_half
) )
else: else:
self.net_g = SynthesizerTrnMs256NSFsid_nono(*cpt["config"]) self.net_g = SynthesizerTrnMs256NSFsid_nono(*cpt["config"])
elif self.version == "v2": elif self.version == "v2":
if self.if_f0 == 1: if self.if_f0 == 1:
self.net_g = SynthesizerTrnMs768NSFsid( self.net_g = SynthesizerTrnMs768NSFsid(
*cpt["config"], is_half=self.config.is_half *cpt["config"], is_half=Config.is_half
) )
else: else:
self.net_g = SynthesizerTrnMs768NSFsid_nono(*cpt["config"]) self.net_g = SynthesizerTrnMs768NSFsid_nono(*cpt["config"])
del self.net_g.enc_q del self.net_g.enc_q
print(self.net_g.load_state_dict(cpt["weight"], strict=False)) print(self.net_g.load_state_dict(cpt["weight"], strict=False))
self.net_g.eval().to(device) self.net_g.eval().to(device)
if is_half == True: if Config.is_half:
self.net_g = self.net_g.half() self.net_g = self.net_g.half()
else: else:
self.net_g = self.net_g.float() self.net_g = self.net_g.float()
@ -160,7 +159,7 @@ class RVC:
torch.cuda.synchronize() torch.cuda.synchronize()
with torch.no_grad(): with torch.no_grad():
logits = self.model.extract_features(**inputs) logits = self.model.extract_features(**inputs)
feats = model.final_proj(logits[0]) if self.version == "v1" else logits[0] feats = self.model.final_proj(logits[0]) if self.version == "v1" else logits[0]
####索引优化 ####索引优化
try: try:
@ -174,7 +173,7 @@ class RVC:
weight = np.square(1 / score) weight = np.square(1 / score)
weight /= weight.sum(axis=1, keepdims=True) weight /= weight.sum(axis=1, keepdims=True)
npy = np.sum(self.big_npy[ix] * np.expand_dims(weight, axis=2), axis=1) npy = np.sum(self.big_npy[ix] * np.expand_dims(weight, axis=2), axis=1)
if is_half == True: if Config.is_half:
npy = npy.astype("float16") npy = npy.astype("float16")
feats = ( feats = (
torch.from_numpy(npy).unsqueeze(0).to(device) * self.index_rate torch.from_numpy(npy).unsqueeze(0).to(device) * self.index_rate
@ -220,7 +219,7 @@ class RVC:
return infered_audio return infered_audio
class Config: class GUIConfig:
def __init__(self) -> None: def __init__(self) -> None:
self.hubert_path: str = "" self.hubert_path: str = ""
self.pth_path: str = "" self.pth_path: str = ""
@ -240,7 +239,7 @@ class Config:
class GUI: class GUI:
def __init__(self) -> None: def __init__(self) -> None:
self.config = Config() self.config = GUIConfig()
self.flag_vc = False self.flag_vc = False
self.launcher() self.launcher()