Merge pull request #1659 from yxlllc/dev

optimize stream logic
This commit is contained in:
RVC-Boss 2023-12-28 18:14:22 +08:00 committed by GitHub
commit d766298065
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

218
gui_v1.py
View File

@ -12,8 +12,7 @@ now_dir = os.getcwd()
sys.path.append(now_dir) sys.path.append(now_dir)
import multiprocessing import multiprocessing
stream_latency = -1 flag_vc = False
def printt(strr, *args): def printt(strr, *args):
if len(args) == 0: if len(args) == 0:
@ -113,32 +112,36 @@ if __name__ == "__main__":
self.pth_path: str = "" self.pth_path: str = ""
self.index_path: str = "" self.index_path: str = ""
self.pitch: int = 0 self.pitch: int = 0
self.samplerate: int = 40000 self.sr_type: str = "sr_model"
self.block_time: float = 1.0 # s self.block_time: float = 0.25 # s
self.buffer_num: int = 1
self.threhold: int = -60 self.threhold: int = -60
self.crossfade_time: float = 0.05 self.crossfade_time: float = 0.05
self.extra_time: float = 2.5 self.extra_time: float = 2.5
self.I_noise_reduce = False self.I_noise_reduce: bool = False
self.O_noise_reduce = False self.O_noise_reduce: bool = False
self.rms_mix_rate = 0.0 self.use_pv: bool = False
self.index_rate = 0.3 self.rms_mix_rate: float = 0.0
self.n_cpu = min(n_cpu, 6) self.index_rate: float = 0.0
self.f0method = "harvest" self.n_cpu: int = min(n_cpu, 4)
self.sg_input_device = "" self.f0method: str = "fcpe"
self.sg_output_device = "" self.sg_input_device: str = ""
self.sg_output_device: str = ""
class GUI: class GUI:
def __init__(self) -> None: def __init__(self) -> None:
self.gui_config = GUIConfig() self.gui_config = GUIConfig()
self.config = Config() self.config = Config()
self.flag_vc = False
self.function = "vc" self.function = "vc"
self.delay_time = 0 self.delay_time = 0
self.input_devices = None
self.output_devices = None
self.input_devices_indices = None
self.output_devices_indices = None
self.stream = None
self.update_devices()
self.launcher() self.launcher()
def load(self): def load(self):
input_devices, output_devices, _, _ = self.get_devices()
try: try:
with open("configs/config.json", "r") as j: with open("configs/config.json", "r") as j:
data = json.load(j) data = json.load(j)
@ -149,25 +152,26 @@ if __name__ == "__main__":
data["crepe"] = data["f0method"] == "crepe" data["crepe"] = data["f0method"] == "crepe"
data["rmvpe"] = data["f0method"] == "rmvpe" data["rmvpe"] = data["f0method"] == "rmvpe"
data["fcpe"] = data["f0method"] == "fcpe" data["fcpe"] = data["f0method"] == "fcpe"
if data["sg_input_device"] not in input_devices: if data["sg_input_device"] not in self.input_devices:
data["sg_input_device"] = input_devices[sd.default.device[0]] data["sg_input_device"] = self.input_devices[self.input_devices_indices.index(sd.default.device[0])]
if data["sg_output_device"] not in output_devices: if data["sg_output_device"] not in self.output_devices:
data["sg_output_device"] = output_devices[sd.default.device[1]] data["sg_output_device"] = self.output_devices[self.output_devices_indices.index(sd.default.device[1])]
except: except:
with open("configs/config.json", "w") as j: with open("configs/config.json", "w") as j:
data = { data = {
"pth_path": " ", "pth_path": "",
"index_path": " ", "index_path": "",
"sg_input_device": input_devices[sd.default.device[0]], "sg_input_device": self.input_devices[self.input_devices_indices.index(sd.default.device[0])],
"sg_output_device": output_devices[sd.default.device[1]], "sg_output_device": self.output_devices[self.output_devices_indices.index(sd.default.device[1])],
"sr_type": "sr_model", "sr_type": "sr_model",
"threhold": "-60", "threhold": -60,
"pitch": "0", "pitch": 0,
"index_rate": "0", "index_rate": 0,
"rms_mix_rate": "0", "rms_mix_rate": 0,
"block_time": "0.25", "block_time": 0.25,
"crossfade_length": "0.05", "crossfade_length": 0.05,
"extra_time": "2.5", "extra_time": 2.5,
"n_cpu": 4,
"f0method": "rmvpe", "f0method": "rmvpe",
"use_jit": False, "use_jit": False,
"use_pv": False, "use_pv": False,
@ -185,7 +189,6 @@ if __name__ == "__main__":
data = self.load() data = self.load()
self.config.use_jit = False # data.get("use_jit", self.config.use_jit) self.config.use_jit = False # data.get("use_jit", self.config.use_jit)
sg.theme("LightBlue3") sg.theme("LightBlue3")
input_devices, output_devices, _, _ = self.get_devices()
layout = [ layout = [
[ [
sg.Frame( sg.Frame(
@ -224,7 +227,7 @@ if __name__ == "__main__":
[ [
sg.Text(i18n("输入设备")), sg.Text(i18n("输入设备")),
sg.Combo( sg.Combo(
input_devices, self.input_devices,
key="sg_input_device", key="sg_input_device",
default_value=data.get("sg_input_device", ""), default_value=data.get("sg_input_device", ""),
), ),
@ -232,7 +235,7 @@ if __name__ == "__main__":
[ [
sg.Text(i18n("输出设备")), sg.Text(i18n("输出设备")),
sg.Combo( sg.Combo(
output_devices, self.output_devices,
key="sg_output_device", key="sg_output_device",
default_value=data.get("sg_output_device", ""), default_value=data.get("sg_output_device", ""),
), ),
@ -463,32 +466,27 @@ if __name__ == "__main__":
self.event_handler() self.event_handler()
def event_handler(self): def event_handler(self):
global flag_vc
while True: while True:
event, values = self.window.read() event, values = self.window.read()
if event == sg.WINDOW_CLOSED: if event == sg.WINDOW_CLOSED:
self.flag_vc = False self.stop_stream()
exit() exit()
if event == "reload_devices": if event == "reload_devices":
prev_input = self.window["sg_input_device"].get() self.update_devices()
prev_output = self.window["sg_output_device"].get() if self.gui_config.sg_input_device not in self.input_devices:
input_devices, output_devices, _, _ = self.get_devices(update=True) self.gui_config.sg_input_device = self.input_devices[0]
if prev_input not in input_devices: self.window["sg_input_device"].Update(values=self.input_devices)
self.gui_config.sg_input_device = input_devices[0]
else:
self.gui_config.sg_input_device = prev_input
self.window["sg_input_device"].Update(values=input_devices)
self.window["sg_input_device"].Update( self.window["sg_input_device"].Update(
value=self.gui_config.sg_input_device value=self.gui_config.sg_input_device
) )
if prev_output not in output_devices: if self.gui_config.sg_output_device not in self.output_devices:
self.gui_config.sg_output_device = output_devices[0] self.gui_config.sg_output_device = self.output_devices[0]
else: self.window["sg_output_device"].Update(values=self.output_devices)
self.gui_config.sg_output_device = prev_output
self.window["sg_output_device"].Update(values=output_devices)
self.window["sg_output_device"].Update( self.window["sg_output_device"].Update(
value=self.gui_config.sg_output_device value=self.gui_config.sg_output_device
) )
if event == "start_vc" and self.flag_vc == False: if event == "start_vc" and not flag_vc:
if self.set_values(values) == True: if self.set_values(values) == True:
printt("cuda_is_available: %s", torch.cuda.is_available()) printt("cuda_is_available: %s", torch.cuda.is_available())
self.start_vc() self.start_vc()
@ -527,22 +525,17 @@ if __name__ == "__main__":
} }
with open("configs/config.json", "w") as j: with open("configs/config.json", "w") as j:
json.dump(settings, j) json.dump(settings, j)
global stream_latency if self.stream is not None:
while stream_latency < 0: self.delay_time = (
time.sleep(0.01) self.stream.latency[-1]
self.delay_time = ( + values["block_time"]
stream_latency + values["crossfade_length"]
+ values["block_time"] + 0.01
+ values["crossfade_length"] )
+ 0.01
)
if values["I_noise_reduce"]: if values["I_noise_reduce"]:
self.delay_time += min(values["crossfade_length"], 0.04) self.delay_time += min(values["crossfade_length"], 0.04)
self.window["sr_stream"].update(self.gui_config.samplerate) self.window["sr_stream"].update(self.gui_config.samplerate)
self.window["delay_time"].update(int(self.delay_time * 1000)) self.window["delay_time"].update(int(self.delay_time * 1000))
if event == "stop_vc" and self.flag_vc == True:
self.flag_vc = False
stream_latency = -1
# Parameter hot update # Parameter hot update
if event == "threhold": if event == "threhold":
self.gui_config.threhold = values["threhold"] self.gui_config.threhold = values["threhold"]
@ -560,7 +553,7 @@ if __name__ == "__main__":
self.gui_config.f0method = event self.gui_config.f0method = event
elif event == "I_noise_reduce": elif event == "I_noise_reduce":
self.gui_config.I_noise_reduce = values["I_noise_reduce"] self.gui_config.I_noise_reduce = values["I_noise_reduce"]
if stream_latency > 0: if self.stream is not None:
self.delay_time += ( self.delay_time += (
1 if values["I_noise_reduce"] else -1 1 if values["I_noise_reduce"] else -1
) * min(values["crossfade_length"], 0.04) ) * min(values["crossfade_length"], 0.04)
@ -571,11 +564,10 @@ if __name__ == "__main__":
self.gui_config.use_pv = values["use_pv"] self.gui_config.use_pv = values["use_pv"]
elif event in ["vc", "im"]: elif event in ["vc", "im"]:
self.function = event self.function = event
elif event != "start_vc" and self.flag_vc == True: elif event == "stop_vc" or event != "start_vc":
# Other parameters do not support hot update # Other parameters do not support hot update
self.flag_vc = False self.stop_stream()
stream_latency = -1
def set_values(self, values): def set_values(self, values):
if len(values["pth_path"].strip()) == 0: if len(values["pth_path"].strip()) == 0:
sg.popup(i18n("请选择pth文件")) sg.popup(i18n("请选择pth文件"))
@ -593,6 +585,8 @@ if __name__ == "__main__":
self.set_devices(values["sg_input_device"], values["sg_output_device"]) self.set_devices(values["sg_input_device"], values["sg_output_device"])
self.config.use_jit = False # values["use_jit"] self.config.use_jit = False # values["use_jit"]
# self.device_latency = values["device_latency"] # self.device_latency = values["device_latency"]
self.gui_config.sg_input_device = values["sg_input_device"]
self.gui_config.sg_output_device = values["sg_output_device"]
self.gui_config.pth_path = values["pth_path"] self.gui_config.pth_path = values["pth_path"]
self.gui_config.index_path = values["index_path"] self.gui_config.index_path = values["index_path"]
self.gui_config.sr_type = ["sr_model", "sr_device"][ self.gui_config.sr_type = ["sr_model", "sr_device"][
@ -625,7 +619,6 @@ if __name__ == "__main__":
def start_vc(self): def start_vc(self):
torch.cuda.empty_cache() torch.cuda.empty_cache()
self.flag_vc = True
self.rvc = rvc_for_realtime.RVC( self.rvc = rvc_for_realtime.RVC(
self.gui_config.pitch, self.gui_config.pitch,
self.gui_config.pth_path, self.gui_config.pth_path,
@ -732,34 +725,37 @@ if __name__ == "__main__":
self.tg = TorchGate( self.tg = TorchGate(
sr=self.gui_config.samplerate, n_fft=4 * self.zc, prop_decrease=0.9 sr=self.gui_config.samplerate, n_fft=4 * self.zc, prop_decrease=0.9
).to(self.config.device) ).to(self.config.device)
thread_vc = threading.Thread(target=self.soundinput) self.start_stream()
thread_vc.start()
def start_stream(self):
def soundinput(self): global flag_vc
""" if not flag_vc:
接受音频输入 flag_vc = True
""" channels = 1 if sys.platform == "darwin" else 2
channels = 1 if sys.platform == "darwin" else 2 self.stream = sd.Stream(
with sd.Stream( channels=channels,
channels=channels, callback=self.audio_callback,
callback=self.audio_callback, blocksize=self.block_frame,
blocksize=self.block_frame, samplerate=self.gui_config.samplerate,
samplerate=self.gui_config.samplerate, dtype="float32")
dtype="float32", self.stream.start()
) as stream:
global stream_latency
stream_latency = stream.latency[-1]
while self.flag_vc:
time.sleep(self.gui_config.block_time)
printt("Audio block passed.")
printt("ENDing VC")
def stop_stream(self):
global flag_vc
if flag_vc:
flag_vc = False
if self.stream is not None:
self.stream.stop()
self.stream.close()
self.stream = None
def audio_callback( def audio_callback(
self, indata: np.ndarray, outdata: np.ndarray, frames, times, status self, indata: np.ndarray, outdata: np.ndarray, frames, times, status
): ):
""" """
音频处理 音频处理
""" """
global flag_vc
start_time = time.perf_counter() start_time = time.perf_counter()
indata = librosa.to_mono(indata.T) indata = librosa.to_mono(indata.T)
if self.gui_config.threhold > -60: if self.gui_config.threhold > -60:
@ -908,66 +904,52 @@ if __name__ == "__main__":
infer_wav[: self.block_frame].repeat(2, 1).t().cpu().numpy() infer_wav[: self.block_frame].repeat(2, 1).t().cpu().numpy()
) )
total_time = time.perf_counter() - start_time total_time = time.perf_counter() - start_time
self.window["infer_time"].update(int(total_time * 1000)) if flag_vc:
self.window["infer_time"].update(int(total_time * 1000))
printt("Infer time: %.2f", total_time) printt("Infer time: %.2f", total_time)
def get_devices(self, update: bool = True): def update_devices(self):
"""获取设备列表""" """获取设备列表"""
if update: sd._terminate()
sd._terminate() sd._initialize()
sd._initialize()
devices = sd.query_devices() devices = sd.query_devices()
hostapis = sd.query_hostapis() hostapis = sd.query_hostapis()
for hostapi in hostapis: for hostapi in hostapis:
for device_idx in hostapi["devices"]: for device_idx in hostapi["devices"]:
devices[device_idx]["hostapi_name"] = hostapi["name"] devices[device_idx]["hostapi_name"] = hostapi["name"]
input_devices = [ self.input_devices = [
f"{d['name']} ({d['hostapi_name']})" f"{d['name']} ({d['hostapi_name']})"
for d in devices for d in devices
if d["max_input_channels"] > 0 if d["max_input_channels"] > 0
] ]
output_devices = [ self.output_devices = [
f"{d['name']} ({d['hostapi_name']})" f"{d['name']} ({d['hostapi_name']})"
for d in devices for d in devices
if d["max_output_channels"] > 0 if d["max_output_channels"] > 0
] ]
input_devices_indices = [ self.input_devices_indices = [
d["index"] if "index" in d else d["name"] d["index"] if "index" in d else d["name"]
for d in devices for d in devices
if d["max_input_channels"] > 0 if d["max_input_channels"] > 0
] ]
output_devices_indices = [ self.output_devices_indices = [
d["index"] if "index" in d else d["name"] d["index"] if "index" in d else d["name"]
for d in devices for d in devices
if d["max_output_channels"] > 0 if d["max_output_channels"] > 0
] ]
return (
input_devices,
output_devices,
input_devices_indices,
output_devices_indices,
)
def set_devices(self, input_device, output_device): def set_devices(self, input_device, output_device):
"""设置输出设备""" """设置输出设备"""
( sd.default.device[0] = self.input_devices_indices[
input_devices, self.input_devices.index(input_device)
output_devices,
input_device_indices,
output_device_indices,
) = self.get_devices()
sd.default.device[0] = input_device_indices[
input_devices.index(input_device)
] ]
sd.default.device[1] = output_device_indices[ sd.default.device[1] = self.output_devices_indices[
output_devices.index(output_device) self.output_devices.index(output_device)
] ]
printt("Input device: %s:%s", str(sd.default.device[0]), input_device) printt("Input device: %s:%s", str(sd.default.device[0]), input_device)
printt("Output device: %s:%s", str(sd.default.device[1]), output_device) printt("Output device: %s:%s", str(sd.default.device[1]), output_device)
def get_device_samplerate(self): def get_device_samplerate(self):
return int( return int(sd.query_devices(device=sd.default.device[0])['default_samplerate'])
sd.query_devices(device=sd.default.device[0])["default_samplerate"]
)
gui = GUI() gui = GUI()