From 2140875115fc75486d698b41d375a7137caca604 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=BA=90=E6=96=87=E9=9B=A8?= <41315874+fumiama@users.noreply.github.com> Date: Mon, 3 Jun 2024 00:20:47 +0900 Subject: [PATCH] feat(web): add hash_similarity calculating --- infer-web.py | 32 +++++++++++++++++++++++++++----- infer/modules/vc/__init__.py | 2 +- infer/modules/vc/hash.py | 2 +- 3 files changed, 29 insertions(+), 7 deletions(-) diff --git a/infer-web.py b/infer-web.py index 2760abe..02c492e 100644 --- a/infer-web.py +++ b/infer-web.py @@ -10,7 +10,7 @@ load_dotenv("sha256.env") if sys.platform == "darwin": os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" -from infer.modules.vc import VC, show_info +from infer.modules.vc import VC, show_info, hash_similarity from infer.modules.uvr5.modules import uvr from infer.lib.train.process_ckpt import ( change_info, @@ -1417,7 +1417,29 @@ with gr.Blocks(title="RVC WebUI") as app: with gr.TabItem(i18n("ckpt处理")): with gr.Group(): - gr.Markdown(value=i18n("模型融合, 可用于测试音色融合")) + gr.Markdown(value=i18n("### 模型比较\n> 模型ID(长)请于下方`查看模型信息`中获得\n\n可用于比较两模型推理相似度")) + with gr.Row(): + with gr.Column(): + id_a = gr.Textbox( + label=i18n("A模型ID(长)"), value="" + ) + id_b = gr.Textbox( + label=i18n("B模型ID(长)"), value="" + ) + with gr.Column(): + butmodelcmp = gr.Button(i18n("计算"), variant="primary") + infomodelcmp = gr.Textbox(label=i18n("相似度(0到1)"), value="", max_lines=8) + butmodelcmp.click( + hash_similarity, + [ + id_a, + id_b, + ], + infomodelcmp, + api_name="ckpt_merge", + ) + with gr.Group(): + gr.Markdown(value=i18n("### 模型融合\n可用于测试音色融合")) with gr.Row(): ckpt_a = gr.Textbox( label=i18n("A模型路径"), value="", interactive=True @@ -1483,7 +1505,7 @@ with gr.Blocks(title="RVC WebUI") as app: ) # def merge(path1,path2,alpha1,sr,f0,info): with gr.Group(): gr.Markdown( - value=i18n("修改模型信息(仅支持weights文件夹下提取的小模型文件)") + value=i18n("### 修改模型信息\n> 仅支持weights文件夹下提取的小模型文件") ) with gr.Row(): ckpt_path0 = gr.Textbox( @@ -1512,7 +1534,7 @@ with gr.Blocks(title="RVC WebUI") as app: ) with gr.Group(): gr.Markdown( - value=i18n("查看模型信息(仅支持weights文件夹下提取的小模型文件)") + value=i18n("### 查看模型信息\n> 仅支持weights文件夹下提取的小模型文件") ) with gr.Row(): ckpt_path1 = gr.Textbox( @@ -1524,7 +1546,7 @@ with gr.Blocks(title="RVC WebUI") as app: with gr.Group(): gr.Markdown( value=i18n( - "模型提取(输入logs文件夹下大文件模型路径),适用于训一半不想训了模型没有自动提取保存小文件模型,或者想测试中间模型的情况" + "### 模型提取\n> 输入logs文件夹下大文件模型路径\n\n适用于训一半不想训了模型没有自动提取保存小文件模型, 或者想测试中间模型的情况" ) ) with gr.Row(): diff --git a/infer/modules/vc/__init__.py b/infer/modules/vc/__init__.py index 01141ef..06b6084 100644 --- a/infer/modules/vc/__init__.py +++ b/infer/modules/vc/__init__.py @@ -2,4 +2,4 @@ from .pipeline import Pipeline from .modules import VC from .utils import get_index_path_from_model, load_hubert from .info import show_info -from .hash import model_hash_ckpt, hash_id +from .hash import model_hash_ckpt, hash_id, hash_similarity diff --git a/infer/modules/vc/hash.py b/infer/modules/vc/hash.py index 325a15f..ecef1f3 100644 --- a/infer/modules/vc/hash.py +++ b/infer/modules/vc/hash.py @@ -189,7 +189,7 @@ def _extend_difference(n, a, b): return n -def hash_similarity(h1: str, h2: str) -> int: +def hash_similarity(h1: str, h2: str) -> float: h1b, h2b = decode_from_string(h1), decode_from_string(h2) if len(h1b) != half_hash_len * 2 or len(h2b) != half_hash_len * 2: raise Exception("invalid hash length")