chore(format): run black on dev (#1731)

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
This commit is contained in:
github-actions[bot] 2024-01-16 20:32:37 +09:00 committed by GitHub
parent 4e8e235024
commit 8d5c34601b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 56 additions and 36 deletions

View File

@ -158,7 +158,10 @@ if __name__ == "__main__":
data["fcpe"] = data["f0method"] == "fcpe" data["fcpe"] = data["f0method"] == "fcpe"
if data["sg_hostapi"] in self.hostapis: if data["sg_hostapi"] in self.hostapis:
self.update_devices(hostapi_name=data["sg_hostapi"]) self.update_devices(hostapi_name=data["sg_hostapi"])
if data["sg_input_device"] not in self.input_devices or data["sg_output_device"] not in self.output_devices: if (
data["sg_input_device"] not in self.input_devices
or data["sg_output_device"] not in self.output_devices
):
self.update_devices() self.update_devices()
data["sg_hostapi"] = self.hostapis[0] data["sg_hostapi"] = self.hostapis[0]
data["sg_input_device"] = self.input_devices[ data["sg_input_device"] = self.input_devices[
@ -256,7 +259,7 @@ if __name__ == "__main__":
key="sg_hostapi", key="sg_hostapi",
default_value=data.get("sg_hostapi", ""), default_value=data.get("sg_hostapi", ""),
enable_events=True, enable_events=True,
size=(20, 1) size=(20, 1),
), ),
sg.Checkbox( sg.Checkbox(
i18n("独占 WASAPI 设备"), i18n("独占 WASAPI 设备"),
@ -523,9 +526,7 @@ if __name__ == "__main__":
if self.gui_config.sg_hostapi not in self.hostapis: if self.gui_config.sg_hostapi not in self.hostapis:
self.gui_config.sg_hostapi = self.hostapis[0] self.gui_config.sg_hostapi = self.hostapis[0]
self.window["sg_hostapi"].Update(values=self.hostapis) self.window["sg_hostapi"].Update(values=self.hostapis)
self.window["sg_hostapi"].Update( self.window["sg_hostapi"].Update(value=self.gui_config.sg_hostapi)
value=self.gui_config.sg_hostapi
)
if self.gui_config.sg_input_device not in self.input_devices: if self.gui_config.sg_input_device not in self.input_devices:
self.gui_config.sg_input_device = self.input_devices[0] self.gui_config.sg_input_device = self.input_devices[0]
self.window["sg_input_device"].Update(values=self.input_devices) self.window["sg_input_device"].Update(values=self.input_devices)
@ -589,7 +590,9 @@ if __name__ == "__main__":
if values["I_noise_reduce"]: if values["I_noise_reduce"]:
self.delay_time += min(values["crossfade_length"], 0.04) self.delay_time += min(values["crossfade_length"], 0.04)
self.window["sr_stream"].update(self.gui_config.samplerate) self.window["sr_stream"].update(self.gui_config.samplerate)
self.window["delay_time"].update(int(np.round(self.delay_time * 1000))) self.window["delay_time"].update(
int(np.round(self.delay_time * 1000))
)
# Parameter hot update # Parameter hot update
if event == "threhold": if event == "threhold":
self.gui_config.threhold = values["threhold"] self.gui_config.threhold = values["threhold"]
@ -611,7 +614,9 @@ if __name__ == "__main__":
self.delay_time += ( self.delay_time += (
1 if values["I_noise_reduce"] else -1 1 if values["I_noise_reduce"] else -1
) * min(values["crossfade_length"], 0.04) ) * min(values["crossfade_length"], 0.04)
self.window["delay_time"].update(int(np.round(self.delay_time * 1000))) self.window["delay_time"].update(
int(np.round(self.delay_time * 1000))
)
elif event == "O_noise_reduce": elif event == "O_noise_reduce":
self.gui_config.O_noise_reduce = values["O_noise_reduce"] self.gui_config.O_noise_reduce = values["O_noise_reduce"]
elif event == "use_pv": elif event == "use_pv":
@ -787,7 +792,10 @@ if __name__ == "__main__":
global flag_vc global flag_vc
if not flag_vc: if not flag_vc:
flag_vc = True flag_vc = True
if 'WASAPI' in self.gui_config.sg_hostapi and self.gui_config.sg_wasapi_exclusive: if (
"WASAPI" in self.gui_config.sg_hostapi
and self.gui_config.sg_wasapi_exclusive
):
extra_settings = sd.WasapiSettings(exclusive=True) extra_settings = sd.WasapiSettings(exclusive=True)
else: else:
extra_settings = None extra_settings = None
@ -847,9 +855,7 @@ if __name__ == "__main__":
self.input_wav_denoise[: -self.block_frame] = self.input_wav_denoise[ self.input_wav_denoise[: -self.block_frame] = self.input_wav_denoise[
self.block_frame : self.block_frame :
].clone() ].clone()
input_wav = self.input_wav[ input_wav = self.input_wav[-self.sola_buffer_frame - self.block_frame :]
-self.sola_buffer_frame - self.block_frame:
]
input_wav = self.tg( input_wav = self.tg(
input_wav.unsqueeze(0), self.input_wav.unsqueeze(0) input_wav.unsqueeze(0), self.input_wav.unsqueeze(0)
).squeeze(0) ).squeeze(0)
@ -857,15 +863,19 @@ if __name__ == "__main__":
input_wav[: self.sola_buffer_frame] += ( input_wav[: self.sola_buffer_frame] += (
self.nr_buffer * self.fade_out_window self.nr_buffer * self.fade_out_window
) )
self.input_wav_denoise[-self.block_frame :] = input_wav[: self.block_frame] self.input_wav_denoise[-self.block_frame :] = input_wav[
: self.block_frame
]
self.nr_buffer[:] = input_wav[self.block_frame :] self.nr_buffer[:] = input_wav[self.block_frame :]
self.input_wav_res[-self.block_frame_16k - 160 :] = self.resampler( self.input_wav_res[-self.block_frame_16k - 160 :] = self.resampler(
self.input_wav_denoise[-self.block_frame - 2 * self.zc :] self.input_wav_denoise[-self.block_frame - 2 * self.zc :]
)[160:] )[160:]
else: else:
self.input_wav_res[-160 * (indata.shape[0] // self.zc + 1) :] = self.resampler( self.input_wav_res[
self.input_wav[-indata.shape[0] - 2 * self.zc :] -160 * (indata.shape[0] // self.zc + 1) :
)[160:] ] = self.resampler(self.input_wav[-indata.shape[0] - 2 * self.zc :])[
160:
]
# infer # infer
if self.function == "vc": if self.function == "vc":
infer_wav = self.rvc.infer( infer_wav = self.rvc.infer(
@ -959,13 +969,17 @@ if __name__ == "__main__":
self.block_frame : self.block_frame + self.sola_buffer_frame self.block_frame : self.block_frame + self.sola_buffer_frame
] ]
outdata[:] = ( outdata[:] = (
infer_wav[: self.block_frame].repeat(self.gui_config.channels, 1).t().cpu().numpy() infer_wav[: self.block_frame]
.repeat(self.gui_config.channels, 1)
.t()
.cpu()
.numpy()
) )
total_time = time.perf_counter() - start_time total_time = time.perf_counter() - start_time
if flag_vc: if flag_vc:
self.window["infer_time"].update(int(total_time * 1000)) self.window["infer_time"].update(int(total_time * 1000))
printt("Infer time: %.2f", total_time) printt("Infer time: %.2f", total_time)
def update_devices(self, hostapi_name=None): def update_devices(self, hostapi_name=None):
"""获取设备列表""" """获取设备列表"""
global flag_vc global flag_vc
@ -981,22 +995,24 @@ if __name__ == "__main__":
if hostapi_name not in self.hostapis: if hostapi_name not in self.hostapis:
hostapi_name = self.hostapis[0] hostapi_name = self.hostapis[0]
self.input_devices = [ self.input_devices = [
d['name'] for d in devices d["name"]
if d["max_input_channels"] > 0 and d['hostapi_name'] == hostapi_name for d in devices
if d["max_input_channels"] > 0 and d["hostapi_name"] == hostapi_name
] ]
self.output_devices = [ self.output_devices = [
d['name'] for d in devices d["name"]
if d["max_output_channels"] > 0 and d['hostapi_name'] == hostapi_name for d in devices
if d["max_output_channels"] > 0 and d["hostapi_name"] == hostapi_name
] ]
self.input_devices_indices = [ self.input_devices_indices = [
d["index"] if "index" in d else d["name"] d["index"] if "index" in d else d["name"]
for d in devices for d in devices
if d["max_input_channels"] > 0 and d['hostapi_name'] == hostapi_name if d["max_input_channels"] > 0 and d["hostapi_name"] == hostapi_name
] ]
self.output_devices_indices = [ self.output_devices_indices = [
d["index"] if "index" in d else d["name"] d["index"] if "index" in d else d["name"]
for d in devices for d in devices
if d["max_output_channels"] > 0 and d['hostapi_name'] == hostapi_name if d["max_output_channels"] > 0 and d["hostapi_name"] == hostapi_name
] ]
def set_devices(self, input_device, output_device): def set_devices(self, input_device, output_device):
@ -1014,10 +1030,14 @@ if __name__ == "__main__":
return int( return int(
sd.query_devices(device=sd.default.device[0])["default_samplerate"] sd.query_devices(device=sd.default.device[0])["default_samplerate"]
) )
def get_device_channels(self): def get_device_channels(self):
max_input_channels = sd.query_devices(device=sd.default.device[0])["max_input_channels"] max_input_channels = sd.query_devices(device=sd.default.device[0])[
max_output_channels = sd.query_devices(device=sd.default.device[1])["max_output_channels"] "max_input_channels"
]
max_output_channels = sd.query_devices(device=sd.default.device[1])[
"max_output_channels"
]
return min(max_input_channels, max_output_channels, 2) return min(max_input_channels, max_output_channels, 2)
gui = GUI() gui = GUI()

View File

@ -117,12 +117,7 @@ def main():
children[i].join() children[i].join()
def run( def run(rank, n_gpus, hps, logger: logging.Logger):
rank,
n_gpus,
hps,
logger: logging.Logger
):
global global_step global global_step
if rank == 0: if rank == 0:
# logger = utils.get_logger(hps.model_dir) # logger = utils.get_logger(hps.model_dir)

View File

@ -91,8 +91,12 @@ class RVC:
self.pth_path: str = pth_path self.pth_path: str = pth_path
self.index_path = index_path self.index_path = index_path
self.index_rate = index_rate self.index_rate = index_rate
self.cache_pitch: torch.Tensor = torch.zeros(1024, device=self.device, dtype=torch.long) self.cache_pitch: torch.Tensor = torch.zeros(
self.cache_pitchf = torch.zeros(1024, device=self.device, dtype=torch.float32) 1024, device=self.device, dtype=torch.long
)
self.cache_pitchf = torch.zeros(
1024, device=self.device, dtype=torch.float32
)
if last_rvc is None: if last_rvc is None:
models, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task( models, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task(
@ -307,6 +311,7 @@ class RVC:
def get_f0_rmvpe(self, x, f0_up_key): def get_f0_rmvpe(self, x, f0_up_key):
if hasattr(self, "model_rmvpe") == False: if hasattr(self, "model_rmvpe") == False:
from infer.lib.rmvpe import RMVPE from infer.lib.rmvpe import RMVPE
printt("Loading rmvpe model") printt("Loading rmvpe model")
self.model_rmvpe = RMVPE( self.model_rmvpe = RMVPE(
"assets/rmvpe/rmvpe.pt", "assets/rmvpe/rmvpe.pt",
@ -391,8 +396,8 @@ class RVC:
input_wav[-f0_extractor_frame:], self.f0_up_key, self.n_cpu, f0method input_wav[-f0_extractor_frame:], self.f0_up_key, self.n_cpu, f0method
) )
shift = block_frame_16k // 160 shift = block_frame_16k // 160
self.cache_pitch[: -shift] = self.cache_pitch[shift :].clone() self.cache_pitch[:-shift] = self.cache_pitch[shift:].clone()
self.cache_pitchf[: -shift] = self.cache_pitchf[shift :].clone() self.cache_pitchf[:-shift] = self.cache_pitchf[shift:].clone()
self.cache_pitch[4 - pitch.shape[0] :] = pitch[3:-1] self.cache_pitch[4 - pitch.shape[0] :] = pitch[3:-1]
self.cache_pitchf[4 - pitch.shape[0] :] = pitchf[3:-1] self.cache_pitchf[4 - pitch.shape[0] :] = pitchf[3:-1]
cache_pitch = self.cache_pitch[None, -p_len:] cache_pitch = self.cache_pitch[None, -p_len:]