feat(web): add hash_similarity calculating

This commit is contained in:
源文雨 2024-06-03 00:20:47 +09:00
parent e4fff618bf
commit 2140875115
3 changed files with 29 additions and 7 deletions

View File

@ -10,7 +10,7 @@ load_dotenv("sha256.env")
if sys.platform == "darwin":
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
from infer.modules.vc import VC, show_info
from infer.modules.vc import VC, show_info, hash_similarity
from infer.modules.uvr5.modules import uvr
from infer.lib.train.process_ckpt import (
change_info,
@ -1417,7 +1417,29 @@ with gr.Blocks(title="RVC WebUI") as app:
with gr.TabItem(i18n("ckpt处理")):
with gr.Group():
gr.Markdown(value=i18n("模型融合, 可用于测试音色融合"))
gr.Markdown(value=i18n("### 模型比较\n> 模型ID(长)请于下方`查看模型信息`中获得\n\n可用于比较两模型推理相似度"))
with gr.Row():
with gr.Column():
id_a = gr.Textbox(
label=i18n("A模型ID(长)"), value=""
)
id_b = gr.Textbox(
label=i18n("B模型ID(长)"), value=""
)
with gr.Column():
butmodelcmp = gr.Button(i18n("计算"), variant="primary")
infomodelcmp = gr.Textbox(label=i18n("相似度(0到1)"), value="", max_lines=8)
butmodelcmp.click(
hash_similarity,
[
id_a,
id_b,
],
infomodelcmp,
api_name="ckpt_merge",
)
with gr.Group():
gr.Markdown(value=i18n("### 模型融合\n可用于测试音色融合"))
with gr.Row():
ckpt_a = gr.Textbox(
label=i18n("A模型路径"), value="", interactive=True
@ -1483,7 +1505,7 @@ with gr.Blocks(title="RVC WebUI") as app:
) # def merge(path1,path2,alpha1,sr,f0,info):
with gr.Group():
gr.Markdown(
value=i18n("修改模型信息(仅支持weights文件夹下提取的小模型文件)")
value=i18n("### 修改模型信息\n> 仅支持weights文件夹下提取的小模型文件")
)
with gr.Row():
ckpt_path0 = gr.Textbox(
@ -1512,7 +1534,7 @@ with gr.Blocks(title="RVC WebUI") as app:
)
with gr.Group():
gr.Markdown(
value=i18n("查看模型信息(仅支持weights文件夹下提取的小模型文件)")
value=i18n("### 查看模型信息\n> 仅支持weights文件夹下提取的小模型文件")
)
with gr.Row():
ckpt_path1 = gr.Textbox(
@ -1524,7 +1546,7 @@ with gr.Blocks(title="RVC WebUI") as app:
with gr.Group():
gr.Markdown(
value=i18n(
"模型提取(输入logs文件夹下大文件模型路径),适用于训一半不想训了模型没有自动提取保存小文件模型,或者想测试中间模型的情况"
"### 模型提取\n> 输入logs文件夹下大文件模型路径\n\n适用于训一半不想训了模型没有自动提取保存小文件模型, 或者想测试中间模型的情况"
)
)
with gr.Row():

View File

@ -2,4 +2,4 @@ from .pipeline import Pipeline
from .modules import VC
from .utils import get_index_path_from_model, load_hubert
from .info import show_info
from .hash import model_hash_ckpt, hash_id
from .hash import model_hash_ckpt, hash_id, hash_similarity

View File

@ -189,7 +189,7 @@ def _extend_difference(n, a, b):
return n
def hash_similarity(h1: str, h2: str) -> int:
def hash_similarity(h1: str, h2: str) -> float:
h1b, h2b = decode_from_string(h1), decode_from_string(h2)
if len(h1b) != half_hash_len * 2 or len(h2b) != half_hash_len * 2:
raise Exception("invalid hash length")