From 8767e11cf1bb60a923a71a12665b30ac76ab8e8e Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Sun, 2 Jun 2024 22:49:59 +0900 Subject: [PATCH] chore(format): run black on dev (#2090) Co-authored-by: github-actions[bot] --- infer-web.py | 60 ++++++++-------- infer/lib/train/process_ckpt.py | 1 + infer/modules/vc/hash.py | 119 ++++++++++++++++++++++---------- infer/modules/vc/info.py | 28 ++++++-- infer/modules/vc/modules.py | 5 +- infer/modules/vc/pipeline.py | 1 + 6 files changed, 143 insertions(+), 71 deletions(-) diff --git a/infer-web.py b/infer-web.py index 4905990..2760abe 100644 --- a/infer-web.py +++ b/infer-web.py @@ -847,10 +847,7 @@ with gr.Blocks(title="RVC WebUI") as app: value=0, ) input_audio0 = gr.File( - label=i18n( - "待处理音频文件" - ), - file_types=["audio"] + label=i18n("待处理音频文件"), file_types=["audio"] ) file_index2 = gr.Dropdown( label=i18n("自动检测index路径,下拉式选择(dropdown)"), @@ -937,28 +934,28 @@ with gr.Blocks(title="RVC WebUI") as app: api_name="infer_refresh", ) with gr.Group(): - vc_output1 = gr.Textbox(label=i18n("输出信息")) + vc_output1 = gr.Textbox(label=i18n("输出信息")) - but0.click( - vc.vc_single, - [ - spk_item, - input_audio0, - vc_transform0, - f0_file, - f0method0, - file_index1, - file_index2, - # file_big_npy1, - index_rate1, - filter_radius0, - resample_sr0, - rms_mix_rate0, - protect0, - ], - [vc_output1, vc_output2], - api_name="infer_convert", - ) + but0.click( + vc.vc_single, + [ + spk_item, + input_audio0, + vc_transform0, + f0_file, + f0method0, + file_index1, + file_index2, + # file_big_npy1, + index_rate1, + filter_radius0, + resample_sr0, + rms_mix_rate0, + protect0, + ], + [vc_output1, vc_output2], + api_name="infer_convert", + ) with gr.TabItem(i18n("批量推理")): gr.Markdown( value=i18n( @@ -990,9 +987,7 @@ with gr.Blocks(title="RVC WebUI") as app: interactive=True, ) file_index3 = gr.File( - label=i18n( - "特征检索库文件路径,为空则使用下拉的选择结果" - ), + label=i18n("特征检索库文件路径,为空则使用下拉的选择结果"), ) refresh_button.click( @@ -1099,7 +1094,14 @@ with gr.Blocks(title="RVC WebUI") as app: sid0.change( fn=vc.get_vc, inputs=[sid0, protect0, protect1], - outputs=[spk_item, protect0, protect1, file_index2, file_index4, modelinfo], + outputs=[ + spk_item, + protect0, + protect1, + file_index2, + file_index4, + modelinfo, + ], api_name="infer_change_voice", ) with gr.TabItem(i18n("伴奏人声分离&去混响&去回声")): diff --git a/infer/lib/train/process_ckpt.py b/infer/lib/train/process_ckpt.py index 23377f6..47a0bc7 100644 --- a/infer/lib/train/process_ckpt.py +++ b/infer/lib/train/process_ckpt.py @@ -10,6 +10,7 @@ from infer.modules.vc import model_hash_ckpt, hash_id i18n = I18nAuto() + # add author sign def save_small_model(ckpt, sr, if_f0, name, epoch, version, hps): try: diff --git a/infer/modules/vc/hash.py b/infer/modules/vc/hash.py index 8e8334d..325a15f 100644 --- a/infer/modules/vc/hash.py +++ b/infer/modules/vc/hash.py @@ -7,6 +7,7 @@ from pybase16384 import encode_to_string, decode_from_string if __name__ == "__main__": import os, sys + now_dir = os.getcwd() sys.path.append(now_dir) @@ -17,6 +18,7 @@ from .utils import load_hubert from infer.lib.audio import load_audio + class TorchSeedContext: def __init__(self, seed): self.seed = seed @@ -29,27 +31,38 @@ class TorchSeedContext: def __exit__(self, type, value, traceback): torch.random.set_rng_state(self.state) + half_hash_len = 512 -expand_factor = 65536*8 +expand_factor = 65536 * 8 + @singleton_variable def original_audio_time_minus(): - __original_audio = load_audio(str(pathlib.Path(__file__).parent / "lgdsng.mp3"), 16000) + __original_audio = load_audio( + str(pathlib.Path(__file__).parent / "lgdsng.mp3"), 16000 + ) np.divide(__original_audio, np.abs(__original_audio).max(), __original_audio) return -__original_audio + @singleton_variable def original_audio_freq_minus(): - __original_audio = load_audio(str(pathlib.Path(__file__).parent / "lgdsng.mp3"), 16000) + __original_audio = load_audio( + str(pathlib.Path(__file__).parent / "lgdsng.mp3"), 16000 + ) np.divide(__original_audio, np.abs(__original_audio).max(), __original_audio) __original_audio = fft(__original_audio) return -__original_audio + def _cut_u16(n): - if n > 16384: n = 16384 + 16384*(1-np.exp((16384-n)/expand_factor)) - elif n < -16384: n = -16384 - 16384*(1-np.exp((n+16384)/expand_factor)) + if n > 16384: + n = 16384 + 16384 * (1 - np.exp((16384 - n) / expand_factor)) + elif n < -16384: + n = -16384 - 16384 * (1 - np.exp((n + 16384) / expand_factor)) return n + # wave_hash will change time_field, use carefully def wave_hash(time_field): np.divide(time_field, np.abs(time_field).max(), time_field) @@ -60,35 +73,56 @@ def wave_hash(time_field): raise Exception("freq not hashable") np.add(time_field, original_audio_time_minus(), out=time_field) np.add(freq_field, original_audio_freq_minus(), out=freq_field) - hash = np.zeros(half_hash_len//2*2, dtype='>i2') + hash = np.zeros(half_hash_len // 2 * 2, dtype=">i2") d = 375 * 512 // half_hash_len - for i in range(half_hash_len//4): - a = i*2 - b = a+1 - x = a + half_hash_len//2 - y = x+1 - s = np.average(freq_field[i*d:(i+1)*d]) - hash[a] = np.int16(_cut_u16(round(32768*np.real(s)))) - hash[b] = np.int16(_cut_u16(round(32768*np.imag(s)))) - hash[x] = np.int16(_cut_u16(round(32768*np.sum(time_field[i*d:i*d+d//2])))) - hash[y] = np.int16(_cut_u16(round(32768*np.sum(time_field[i*d+d//2:(i+1)*d])))) + for i in range(half_hash_len // 4): + a = i * 2 + b = a + 1 + x = a + half_hash_len // 2 + y = x + 1 + s = np.average(freq_field[i * d : (i + 1) * d]) + hash[a] = np.int16(_cut_u16(round(32768 * np.real(s)))) + hash[b] = np.int16(_cut_u16(round(32768 * np.imag(s)))) + hash[x] = np.int16( + _cut_u16(round(32768 * np.sum(time_field[i * d : i * d + d // 2]))) + ) + hash[y] = np.int16( + _cut_u16(round(32768 * np.sum(time_field[i * d + d // 2 : (i + 1) * d]))) + ) return encode_to_string(hash.tobytes()) + def audio_hash(file): return wave_hash(load_audio(file, 16000)) + def model_hash(config, tgt_sr, net_g, if_f0, version): pipeline = Pipeline(tgt_sr, config) audio = load_audio(str(pathlib.Path(__file__).parent / "lgdsng.mp3"), 16000) audio_max = np.abs(audio).max() / 0.95 if audio_max > 1: np.divide(audio, audio_max, audio) - audio_opt = pipeline.pipeline(load_hubert(config.device, config.is_half), net_g, 0, audio, - [0, 0, 0], 6, "rmvpe", "", 0, if_f0, 3, tgt_sr, 16000, 0.25, - version, 0.33) + audio_opt = pipeline.pipeline( + load_hubert(config.device, config.is_half), + net_g, + 0, + audio, + [0, 0, 0], + 6, + "rmvpe", + "", + 0, + if_f0, + 3, + tgt_sr, + 16000, + 0.25, + version, + 0.33, + ) opt_len = len(audio_opt) diff = 48000 - opt_len - n = diff//2 + n = diff // 2 if n > 0: audio_opt = np.pad(audio_opt, (n, n)) elif n < 0: @@ -98,6 +132,7 @@ def model_hash(config, tgt_sr, net_g, if_f0, version): del pipeline, audio, audio_opt return h + def model_hash_ckpt(cpt): from infer.lib.infer_pack.models import ( SynthesizerTrnMs256NSFsid, @@ -105,6 +140,7 @@ def model_hash_ckpt(cpt): SynthesizerTrnMs768NSFsid, SynthesizerTrnMs768NSFsid_nono, ) + config = Config() with TorchSeedContext(114514): tgt_sr = cpt["config"][-1] @@ -116,9 +152,9 @@ def model_hash_ckpt(cpt): ("v2", 1): SynthesizerTrnMs768NSFsid, ("v2", 0): SynthesizerTrnMs768NSFsid_nono, } - net_g = synthesizer_class.get( - (version, if_f0), SynthesizerTrnMs256NSFsid - )(*cpt["config"], is_half=config.is_half) + net_g = synthesizer_class.get((version, if_f0), SynthesizerTrnMs256NSFsid)( + *cpt["config"], is_half=config.is_half + ) del net_g.enc_q @@ -128,43 +164,54 @@ def model_hash_ckpt(cpt): net_g = net_g.half() else: net_g = net_g.float() - + h = model_hash(config, tgt_sr, net_g, if_f0, version) del net_g return h + def model_hash_from(path): cpt = torch.load(path, map_location="cpu") h = model_hash_ckpt(cpt) del cpt return h + def _extend_difference(n, a, b): - if n < a: n = a - elif n > b: n = b + if n < a: + n = a + elif n > b: + n = b n -= a - n /= (b-a) + n /= b - a return n + def hash_similarity(h1: str, h2: str) -> int: h1b, h2b = decode_from_string(h1), decode_from_string(h2) - if len(h1b) != half_hash_len*2 or len(h2b) != half_hash_len*2: + if len(h1b) != half_hash_len * 2 or len(h2b) != half_hash_len * 2: raise Exception("invalid hash length") - h1n, h2n = np.frombuffer(h1b, dtype='>i2'), np.frombuffer(h2b, dtype='>i2') + h1n, h2n = np.frombuffer(h1b, dtype=">i2"), np.frombuffer(h2b, dtype=">i2") d = 0 - for i in range(half_hash_len//4): - a = i*2 - b = a+1 + for i in range(half_hash_len // 4): + a = i * 2 + b = a + 1 ax = complex(h1n[a], h1n[b]) bx = complex(h2n[a], h2n[b]) - if abs(ax) == 0 or abs(bx) == 0: continue + if abs(ax) == 0 or abs(bx) == 0: + continue d += np.abs(ax - bx) - frac = (np.linalg.norm(h1n) * np.linalg.norm(h2n)) - cosine = np.dot(h1n.astype(np.float32), h2n.astype(np.float32)) / frac if frac != 0 else 1.0 - distance = _extend_difference(np.exp(-d/expand_factor), 0.5, 1.0) + frac = np.linalg.norm(h1n) * np.linalg.norm(h2n) + cosine = ( + np.dot(h1n.astype(np.float32), h2n.astype(np.float32)) / frac + if frac != 0 + else 1.0 + ) + distance = _extend_difference(np.exp(-d / expand_factor), 0.5, 1.0) return round((abs(cosine) + distance) / 2, 6) + def hash_id(h: str) -> str: return encode_to_string(hashlib.md5(decode_from_string(h)).digest())[:-1] diff --git a/infer/modules/vc/info.py b/infer/modules/vc/info.py index f2bf320..a7647bb 100644 --- a/infer/modules/vc/info.py +++ b/infer/modules/vc/info.py @@ -7,6 +7,7 @@ from .hash import model_hash_ckpt, hash_id i18n = I18nAuto() + def show_model_info(cpt, show_long_id=False): try: h = model_hash_ckpt(cpt) @@ -14,10 +15,27 @@ def show_model_info(cpt, show_long_id=False): idread = cpt.get("id", "None") hread = cpt.get("hash", "None") if id != idread: - id += "("+i18n("实际计算")+"), "+idread+"("+i18n("从模型中读取")+")" - if not show_long_id: h = i18n("不显示") + id += ( + "(" + + i18n("实际计算") + + "), " + + idread + + "(" + + i18n("从模型中读取") + + ")" + ) + if not show_long_id: + h = i18n("不显示") elif h != hread: - h += "("+i18n("实际计算")+"), "+hread+"("+i18n("从模型中读取")+")" + h += ( + "(" + + i18n("实际计算") + + "), " + + hread + + "(" + + i18n("从模型中读取") + + ")" + ) txt = f"""{i18n("模型名")}: %s {i18n("封装时间")}: %s {i18n("信息")}: %s @@ -32,13 +50,15 @@ def show_model_info(cpt, show_long_id=False): cpt.get("sr", "None"), i18n("有") if cpt.get("f0", 0) == 1 else i18n("无"), cpt.get("version", "None"), - id, h + id, + h, ) except: txt = traceback.format_exc() return txt + def show_info(path): try: a = torch.load(path, map_location="cpu") diff --git a/infer/modules/vc/modules.py b/infer/modules/vc/modules.py index 0bc35be..5db22e7 100644 --- a/infer/modules/vc/modules.py +++ b/infer/modules/vc/modules.py @@ -136,7 +136,7 @@ class VC: to_return_protect1, index, index, - show_model_info(self.cpt) + show_model_info(self.cpt), ) if to_return_protect else {"visible": True, "maximum": n_spk, "__type__": "update"} @@ -173,7 +173,8 @@ class VC: self.hubert_model = load_hubert(self.config.device, self.config.is_half) if file_index: - if hasattr(file_index, "name"): file_index = str(file_index.name) + if hasattr(file_index, "name"): + file_index = str(file_index.name) file_index = ( file_index.strip(" ") .strip('"') diff --git a/infer/modules/vc/pipeline.py b/infer/modules/vc/pipeline.py index b5baf06..a0cde7d 100644 --- a/infer/modules/vc/pipeline.py +++ b/infer/modules/vc/pipeline.py @@ -114,6 +114,7 @@ class Pipeline(object): ) elif f0_method == "harvest": from hashlib import md5 + f0_cache_key = md5(x.tobytes()).digest() input_audio_path2wav[f0_cache_key] = x.astype(np.double) f0 = cache_harvest_f0(f0_cache_key, self.sr, f0_max, f0_min, 10)