Parameter hot update (#1148)

This commit is contained in:
yxlllc 2023-09-01 10:16:49 +08:00 committed by GitHub
parent a05c72ec6f
commit 0c75454ddf
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 50 additions and 7 deletions

View File

@ -195,6 +195,7 @@ if __name__ == "__main__":
resolution=1, resolution=1,
orientation="h", orientation="h",
default_value=data.get("threhold", ""), default_value=data.get("threhold", ""),
enable_events=True,
), ),
], ],
[ [
@ -205,6 +206,7 @@ if __name__ == "__main__":
resolution=1, resolution=1,
orientation="h", orientation="h",
default_value=data.get("pitch", ""), default_value=data.get("pitch", ""),
enable_events=True,
), ),
], ],
[ [
@ -215,6 +217,7 @@ if __name__ == "__main__":
resolution=0.01, resolution=0.01,
orientation="h", orientation="h",
default_value=data.get("index_rate", ""), default_value=data.get("index_rate", ""),
enable_events=True,
), ),
], ],
[ [
@ -224,24 +227,28 @@ if __name__ == "__main__":
"f0method", "f0method",
key="pm", key="pm",
default=data.get("pm", "") == True, default=data.get("pm", "") == True,
enable_events=True,
), ),
sg.Radio( sg.Radio(
"harvest", "harvest",
"f0method", "f0method",
key="harvest", key="harvest",
default=data.get("harvest", "") == True, default=data.get("harvest", "") == True,
enable_events=True,
), ),
sg.Radio( sg.Radio(
"crepe", "crepe",
"f0method", "f0method",
key="crepe", key="crepe",
default=data.get("crepe", "") == True, default=data.get("crepe", "") == True,
enable_events=True,
), ),
sg.Radio( sg.Radio(
"rmvpe", "rmvpe",
"f0method", "f0method",
key="rmvpe", key="rmvpe",
default=data.get("rmvpe", "") == True, default=data.get("rmvpe", "") == True,
enable_events=True,
), ),
], ],
], ],
@ -257,6 +264,7 @@ if __name__ == "__main__":
resolution=0.03, resolution=0.03,
orientation="h", orientation="h",
default_value=data.get("block_time", ""), default_value=data.get("block_time", ""),
enable_events=True,
), ),
], ],
[ [
@ -269,6 +277,7 @@ if __name__ == "__main__":
default_value=data.get( default_value=data.get(
"n_cpu", min(self.config.n_cpu, n_cpu) "n_cpu", min(self.config.n_cpu, n_cpu)
), ),
enable_events=True,
), ),
], ],
[ [
@ -279,6 +288,7 @@ if __name__ == "__main__":
resolution=0.01, resolution=0.01,
orientation="h", orientation="h",
default_value=data.get("crossfade_length", ""), default_value=data.get("crossfade_length", ""),
enable_events=True,
), ),
], ],
[ [
@ -289,11 +299,12 @@ if __name__ == "__main__":
resolution=0.01, resolution=0.01,
orientation="h", orientation="h",
default_value=data.get("extra_time", ""), default_value=data.get("extra_time", ""),
enable_events=True,
), ),
], ],
[ [
sg.Checkbox(i18n("输入降噪"), key="I_noise_reduce"), sg.Checkbox(i18n("输入降噪"), key="I_noise_reduce", enable_events=True),
sg.Checkbox(i18n("输出降噪"), key="O_noise_reduce"), sg.Checkbox(i18n("输出降噪"), key="O_noise_reduce", enable_events=True),
], ],
], ],
title=i18n("性能设置"), title=i18n("性能设置"),
@ -306,7 +317,7 @@ if __name__ == "__main__":
sg.Text("0", key="infer_time"), sg.Text("0", key="infer_time"),
], ],
] ]
self.window = sg.Window("RVC - GUI", layout=layout) self.window = sg.Window("RVC - GUI", layout=layout, finalize=True)
self.event_handler() self.event_handler()
def event_handler(self): def event_handler(self):
@ -364,7 +375,28 @@ if __name__ == "__main__":
json.dump(settings, j) json.dump(settings, j)
if event == "stop_vc" and self.flag_vc == True: if event == "stop_vc" and self.flag_vc == True:
self.flag_vc = False self.flag_vc = False
# Parameter hot update
if event == 'threhold':
self.config.threhold = values["threhold"]
elif event == "pitch":
self.config.pitch = values["pitch"]
if hasattr(self, "rvc"):
self.rvc.change_key(values["pitch"])
elif event == "index_rate":
self.config.index_rate = values["index_rate"]
if hasattr(self, "rvc"):
self.rvc.change_index_rate(values["index_rate"])
elif event in ["pm", "harvest", "crepe", "rmvpe"]:
self.config.f0method = event
elif event == 'I_noise_reduce':
self.config.I_noise_reduce = values["I_noise_reduce"]
elif event == 'O_noise_reduce':
self.config.O_noise_reduce = values["O_noise_reduce"]
elif event != "start_vc" and self.flag_vc == True:
# Other parameters do not support hot update
self.flag_vc = False
def set_values(self, values): def set_values(self, values):
if len(values["pth_path"].strip()) == 0: if len(values["pth_path"].strip()) == 0:
sg.popup(i18n("请选择pth文件")) sg.popup(i18n("请选择pth文件"))
@ -470,9 +502,9 @@ if __name__ == "__main__":
self.sola_buffer: torch.Tensor = torch.zeros( self.sola_buffer: torch.Tensor = torch.zeros(
self.crossfade_frame, device=device, dtype=torch.float32 self.crossfade_frame, device=device, dtype=torch.float32
) )
self.fade_in_window: torch.Tensor = torch.linspace( self.fade_in_window: torch.Tensor = torch.sin(0.5 * np.pi * torch.linspace(
0.0, 1.0, steps=self.crossfade_frame, device=device, dtype=torch.float32 0.0, 1.0, steps=self.crossfade_frame, device=device, dtype=torch.float32
) )) ** 2
self.fade_out_window: torch.Tensor = 1 - self.fade_in_window self.fade_out_window: torch.Tensor = 1 - self.fade_in_window
self.resampler = tat.Resample( self.resampler = tat.Resample(
orig_freq=self.config.samplerate, new_freq=16000, dtype=torch.float32 orig_freq=self.config.samplerate, new_freq=16000, dtype=torch.float32

View File

@ -68,6 +68,7 @@ class RVC:
self.index = faiss.read_index(index_path) self.index = faiss.read_index(index_path)
self.big_npy = self.index.reconstruct_n(0, self.index.ntotal) self.big_npy = self.index.reconstruct_n(0, self.index.ntotal)
print("index search enabled") print("index search enabled")
self.index_path = index_path
self.index_rate = index_rate self.index_rate = index_rate
models, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task( models, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task(
["assets/hubert/hubert_base.pt"], ["assets/hubert/hubert_base.pt"],
@ -111,7 +112,17 @@ class RVC:
self.is_half = config.is_half self.is_half = config.is_half
except: except:
print(traceback.format_exc()) print(traceback.format_exc())
def change_key(self, new_key):
self.f0_up_key = new_key
def change_index_rate(self, new_index_rate):
if new_index_rate != 0 and self.index_rate == 0:
self.index = faiss.read_index(self.index_path)
self.big_npy = self.index.reconstruct_n(0, self.index.ntotal)
print("index search enabled")
self.index_rate = new_index_rate
def get_f0_post(self, f0): def get_f0_post(self, f0):
f0_min = self.f0_min f0_min = self.f0_min
f0_max = self.f0_max f0_max = self.f0_max