gui json update (#479)

* Fix gui.py

There seemed to be some conflicts between #338 and #340, so I corrected them.

* Update gui.py

* Update gui.py

---------

Co-authored-by: Ftps <63702646+Tps-F@users.noreply.github.com>
Co-authored-by: RVC-Boss <129054828+RVC-Boss@users.noreply.github.com>
This commit is contained in:
yuuukiasuna 2023-06-08 16:53:51 +08:00 committed by GitHub
parent b28f98fed3
commit 1b307a4222
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 73 additions and 17 deletions

90
gui.py
View File

@ -12,6 +12,8 @@
""" """
import os, sys, traceback import os, sys, traceback
import json
now_dir = os.getcwd() now_dir = os.getcwd()
sys.path.append(now_dir) sys.path.append(now_dir)
from config import Config from config import Config
@ -27,6 +29,7 @@ import torch.nn.functional as F
import torchaudio.transforms as tat import torchaudio.transforms as tat
import scipy.signal as signal import scipy.signal as signal
# import matplotlib.pyplot as plt # import matplotlib.pyplot as plt
from infer_pack.models import ( from infer_pack.models import (
SynthesizerTrnMs256NSFsid, SynthesizerTrnMs256NSFsid,
@ -38,6 +41,7 @@ from i18n import I18nAuto
i18n = I18nAuto() i18n = I18nAuto()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
current_dir = os.getcwd()
class RVC: class RVC:
@ -249,7 +253,33 @@ class GUI:
self.launcher() self.launcher()
def load(self):
input_devices, output_devices, _, _ = self.get_devices()
try:
with open('values1.json', 'r') as j:
data = json.load(j)
except:
with open('values1.json', 'w') as j:
data = {
"pth_path":" ",
"index_path":" ",
"sg_input_device": input_devices[sd.default.device[0]],
"sg_output_device": output_devices[sd.default.device[1]],
"threshold": '-60',
"pitch": '12',
"index_rate": '0.5',
"block_time": '1',
"crossfade_length": '0.15',
"extra_time": '1.1', }
return data
def launcher(self): def launcher(self):
data = self.load()
sg.theme("LightBlue3") sg.theme("LightBlue3")
input_devices, output_devices, _, _ = self.get_devices() input_devices, output_devices, _, _ = self.get_devices()
layout = [ layout = [
@ -258,17 +288,18 @@ class GUI:
title=i18n("加载模型"), title=i18n("加载模型"),
layout=[ layout=[
[ [
sg.Input(default_text="hubert_base.pt", key="hubert_path"), sg.Input(default_text="hubert_base.pt", key="hubert_path",disabled=True),
sg.FileBrowse(i18n("Hubert模型"), initial_folder=os.path.join(os.getcwd()),file_types=((". pt"),)), sg.FileBrowse(i18n("Hubert模型"), initial_folder=os.path.join(os.getcwd()),file_types=((". pt"),)),
], ],
[ [
sg.Input(default_text="TEMP\\atri.pth", key="pth_path"), sg.Input(default_text=data.get("pth_path", ''), key="pth_path",),
sg.FileBrowse(i18n("选择.pth文件"),initial_folder=os.path.join(os.getcwd(), "weights"), file_types=((". pth"),)), sg.FileBrowse(i18n("选择.pth文件"),initial_folder=os.path.join(os.getcwd(), "weights"), file_types=((". pth"),))
], ],
[ [
sg.Input( sg.Input(
default_text="TEMP\\added_IVF512_Flat_atri_baseline_src_feat.index", default_text=data.get('index_path', ''),
key="index_path", key="index_path",
), ),
sg.FileBrowse(i18n("选择.index文件"), initial_folder=os.path.join(os.getcwd(), "logs"),file_types=((". index"),)), sg.FileBrowse(i18n("选择.index文件"), initial_folder=os.path.join(os.getcwd(), "logs"),file_types=((". index"),)),
], ],
@ -276,6 +307,7 @@ class GUI:
sg.Input( sg.Input(
default_text="你不需要填写这个You don't need write this.", default_text="你不需要填写这个You don't need write this.",
key="npy_path", key="npy_path",
disabled=True
), ),
sg.FileBrowse(i18n("选择.npy文件"), initial_folder=os.path.join(os.getcwd(), "logs"),file_types=((". npy"),)), sg.FileBrowse(i18n("选择.npy文件"), initial_folder=os.path.join(os.getcwd(), "logs"),file_types=((". npy"),)),
], ],
@ -290,7 +322,7 @@ class GUI:
sg.Combo( sg.Combo(
input_devices, input_devices,
key="sg_input_device", key="sg_input_device",
default_value=input_devices[sd.default.device[0]], default_value=data.get('sg_input_device', ''),
), ),
], ],
[ [
@ -298,7 +330,7 @@ class GUI:
sg.Combo( sg.Combo(
output_devices, output_devices,
key="sg_output_device", key="sg_output_device",
default_value=output_devices[sd.default.device[1]], default_value=data.get('sg_output_device', ''),
), ),
], ],
], ],
@ -315,7 +347,7 @@ class GUI:
key="threhold", key="threhold",
resolution=1, resolution=1,
orientation="h", orientation="h",
default_value=-30, default_value=data.get('threhold', '')
), ),
], ],
[ [
@ -325,7 +357,7 @@ class GUI:
key="pitch", key="pitch",
resolution=1, resolution=1,
orientation="h", orientation="h",
default_value=12, default_value=data.get('pitch', ''),
), ),
], ],
[ [
@ -335,7 +367,7 @@ class GUI:
key="index_rate", key="index_rate",
resolution=0.01, resolution=0.01,
orientation="h", orientation="h",
default_value=0.5, default_value=data.get('index_rate', ''),
), ),
], ],
], ],
@ -350,7 +382,7 @@ class GUI:
key="block_time", key="block_time",
resolution=0.1, resolution=0.1,
orientation="h", orientation="h",
default_value=1.0, default_value=data.get('block_time', ''),
), ),
], ],
[ [
@ -360,7 +392,7 @@ class GUI:
key="crossfade_length", key="crossfade_length",
resolution=0.01, resolution=0.01,
orientation="h", orientation="h",
default_value=0.08, default_value=data.get('crossfade_length', ''),
), ),
], ],
[ [
@ -370,7 +402,7 @@ class GUI:
key="extra_time", key="extra_time",
resolution=0.01, resolution=0.01,
orientation="h", orientation="h",
default_value=0.05, default_value=data.get('extra_time', ''),
), ),
], ],
[ [
@ -388,7 +420,6 @@ class GUI:
sg.Text("0", key="infer_time"), sg.Text("0", key="infer_time"),
], ],
] ]
self.window = sg.Window("RVC - GUI", layout=layout) self.window = sg.Window("RVC - GUI", layout=layout)
self.event_handler() self.event_handler()
@ -400,15 +431,41 @@ class GUI:
exit() exit()
if event == "start_vc" and self.flag_vc == False: if event == "start_vc" and self.flag_vc == False:
self.set_values(values) self.set_values(values)
print(str(self.config.__dict__))
print("using_cuda:" + str(torch.cuda.is_available())) print("using_cuda:" + str(torch.cuda.is_available()))
self.start_vc() self.start_vc()
settings = {
"pth_path":values["pth_path"],
"index_path":values["index_path"],
"sg_input_device": values["sg_input_device"],
"sg_output_device": values["sg_output_device"],
"threhold": values["threhold"],
"pitch": values["pitch"],
"index_rate": values["index_rate"],
"block_time": values["block_time"],
"crossfade_length": values["crossfade_length"],
"extra_time": values["extra_time"],
}
with open('values1.json', 'w') as j:
json.dump(settings, j)
if event == "stop_vc" and self.flag_vc == True: if event == "stop_vc" and self.flag_vc == True:
self.flag_vc = False self.flag_vc = False
def set_values(self, values): def set_values(self, values):
self.set_devices(values["sg_input_device"], values["sg_output_device"]) self.set_devices(values["sg_input_device"], values["sg_output_device"])
self.config.hubert_path = values["hubert_path"] self.config.hubert_path = os.path.join(current_dir, 'hubert_base.pt')
self.config.pth_path = values["pth_path"] self.config.pth_path = values["pth_path"]
self.config.index_path = values["index_path"] self.config.index_path = values["index_path"]
self.config.npy_path = values["npy_path"] self.config.npy_path = values["npy_path"]
@ -615,5 +672,4 @@ class GUI:
print("input device:" + str(sd.default.device[0]) + ":" + str(input_device)) print("input device:" + str(sd.default.device[0]) + ":" + str(input_device))
print("output device:" + str(sd.default.device[1]) + ":" + str(output_device)) print("output device:" + str(sd.default.device[1]) + ":" + str(output_device))
gui = GUI() gui = GUI()