mirror of
https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI.git
synced 2025-01-01 20:45:04 +08:00
Format code (#932)
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
This commit is contained in:
parent
296905983a
commit
9a20c3b28f
@ -234,16 +234,12 @@ def get_vc(model_path):
|
|||||||
version = cpt.get("version", "v1")
|
version = cpt.get("version", "v1")
|
||||||
if version == "v1":
|
if version == "v1":
|
||||||
if if_f0 == 1:
|
if if_f0 == 1:
|
||||||
net_g = SynthesizerTrnMs256NSFsid(
|
net_g = SynthesizerTrnMs256NSFsid(*cpt["config"], is_half=is_half)
|
||||||
*cpt["config"], is_half=is_half
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
net_g = SynthesizerTrnMs256NSFsid_nono(*cpt["config"])
|
net_g = SynthesizerTrnMs256NSFsid_nono(*cpt["config"])
|
||||||
elif version == "v2":
|
elif version == "v2":
|
||||||
if if_f0 == 1:
|
if if_f0 == 1:
|
||||||
net_g = SynthesizerTrnMs768NSFsid(
|
net_g = SynthesizerTrnMs768NSFsid(*cpt["config"], is_half=is_half)
|
||||||
*cpt["config"], is_half=is_half
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
net_g = SynthesizerTrnMs768NSFsid_nono(*cpt["config"])
|
net_g = SynthesizerTrnMs768NSFsid_nono(*cpt["config"])
|
||||||
del net_g.enc_q
|
del net_g.enc_q
|
||||||
|
@ -1,10 +1,11 @@
|
|||||||
# This code references https://huggingface.co/JosephusCheung/ASimilarityCalculatior/blob/main/qwerty.py
|
# This code references https://huggingface.co/JosephusCheung/ASimilarityCalculatior/blob/main/qwerty.py
|
||||||
# Fill in the path of the model to be queried and the root directory of the reference models, and this script will return the similarity between the model to be queried and all reference models.
|
# Fill in the path of the model to be queried and the root directory of the reference models, and this script will return the similarity between the model to be queried and all reference models.
|
||||||
import sys,os
|
import sys, os
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
|
||||||
def cal_cross_attn(to_q, to_k, to_v, rand_input):
|
def cal_cross_attn(to_q, to_k, to_v, rand_input):
|
||||||
hidden_dim, embed_dim = to_q.shape
|
hidden_dim, embed_dim = to_q.shape
|
||||||
attn_to_q = nn.Linear(hidden_dim, embed_dim, bias=False)
|
attn_to_q = nn.Linear(hidden_dim, embed_dim, bias=False)
|
||||||
@ -16,41 +17,50 @@ def cal_cross_attn(to_q, to_k, to_v, rand_input):
|
|||||||
|
|
||||||
return torch.einsum(
|
return torch.einsum(
|
||||||
"ik, jk -> ik",
|
"ik, jk -> ik",
|
||||||
F.softmax(torch.einsum("ij, kj -> ik", attn_to_q(rand_input), attn_to_k(rand_input)), dim=-1),
|
F.softmax(
|
||||||
attn_to_v(rand_input)
|
torch.einsum("ij, kj -> ik", attn_to_q(rand_input), attn_to_k(rand_input)),
|
||||||
|
dim=-1,
|
||||||
|
),
|
||||||
|
attn_to_v(rand_input),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def model_hash(filename):
|
def model_hash(filename):
|
||||||
try:
|
try:
|
||||||
with open(filename, "rb") as file:
|
with open(filename, "rb") as file:
|
||||||
import hashlib
|
import hashlib
|
||||||
|
|
||||||
m = hashlib.sha256()
|
m = hashlib.sha256()
|
||||||
|
|
||||||
file.seek(0x100000)
|
file.seek(0x100000)
|
||||||
m.update(file.read(0x10000))
|
m.update(file.read(0x10000))
|
||||||
return m.hexdigest()[0:8]
|
return m.hexdigest()[0:8]
|
||||||
except FileNotFoundError:
|
except FileNotFoundError:
|
||||||
return 'NOFILE'
|
return "NOFILE"
|
||||||
|
|
||||||
|
|
||||||
def eval(model, n, input):
|
def eval(model, n, input):
|
||||||
qk = f"enc_p.encoder.attn_layers.{n}.conv_q.weight"
|
qk = f"enc_p.encoder.attn_layers.{n}.conv_q.weight"
|
||||||
uk = f"enc_p.encoder.attn_layers.{n}.conv_k.weight"
|
uk = f"enc_p.encoder.attn_layers.{n}.conv_k.weight"
|
||||||
vk = f"enc_p.encoder.attn_layers.{n}.conv_v.weight"
|
vk = f"enc_p.encoder.attn_layers.{n}.conv_v.weight"
|
||||||
atoq, atok, atov = model[qk][:,:,0], model[uk][:,:,0], model[vk][:,:,0]
|
atoq, atok, atov = model[qk][:, :, 0], model[uk][:, :, 0], model[vk][:, :, 0]
|
||||||
|
|
||||||
attn = cal_cross_attn(atoq, atok, atov, input)
|
attn = cal_cross_attn(atoq, atok, atov, input)
|
||||||
return attn
|
return attn
|
||||||
|
|
||||||
def main(path,root):
|
|
||||||
|
def main(path, root):
|
||||||
torch.manual_seed(114514)
|
torch.manual_seed(114514)
|
||||||
model_a = torch.load(path, map_location="cpu")["weight"]
|
model_a = torch.load(path, map_location="cpu")["weight"]
|
||||||
|
|
||||||
print("query:\t\t%s\t%s"%(path,model_hash(path)))
|
print("query:\t\t%s\t%s" % (path, model_hash(path)))
|
||||||
|
|
||||||
map_attn_a = {}
|
map_attn_a = {}
|
||||||
map_rand_input = {}
|
map_rand_input = {}
|
||||||
for n in range(6):
|
for n in range(6):
|
||||||
hidden_dim, embed_dim,_ = model_a[f"enc_p.encoder.attn_layers.{n}.conv_v.weight"].shape
|
hidden_dim, embed_dim, _ = model_a[
|
||||||
|
f"enc_p.encoder.attn_layers.{n}.conv_v.weight"
|
||||||
|
].shape
|
||||||
rand_input = torch.randn([embed_dim, hidden_dim])
|
rand_input = torch.randn([embed_dim, hidden_dim])
|
||||||
|
|
||||||
map_attn_a[n] = eval(model_a, n, rand_input)
|
map_attn_a[n] = eval(model_a, n, rand_input)
|
||||||
@ -59,7 +69,7 @@ def main(path,root):
|
|||||||
del model_a
|
del model_a
|
||||||
|
|
||||||
for name in sorted(list(os.listdir(root))):
|
for name in sorted(list(os.listdir(root))):
|
||||||
path="%s/%s"%(root,name)
|
path = "%s/%s" % (root, name)
|
||||||
model_b = torch.load(path, map_location="cpu")["weight"]
|
model_b = torch.load(path, map_location="cpu")["weight"]
|
||||||
|
|
||||||
sims = []
|
sims = []
|
||||||
@ -70,9 +80,13 @@ def main(path,root):
|
|||||||
sim = torch.mean(torch.cosine_similarity(attn_a, attn_b))
|
sim = torch.mean(torch.cosine_similarity(attn_a, attn_b))
|
||||||
sims.append(sim)
|
sims.append(sim)
|
||||||
|
|
||||||
print("reference:\t%s\t%s\t%s"%(path,model_hash(path),f"{torch.mean(torch.stack(sims)) * 1e2:.2f}%"))
|
print(
|
||||||
|
"reference:\t%s\t%s\t%s"
|
||||||
|
% (path, model_hash(path), f"{torch.mean(torch.stack(sims)) * 1e2:.2f}%")
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
query_path=r"weights\mi v3.pth"
|
query_path = r"weights\mi v3.pth"
|
||||||
reference_root=r"weights"
|
reference_root = r"weights"
|
||||||
main(query_path,reference_root)
|
main(query_path, reference_root)
|
||||||
|
Loading…
Reference in New Issue
Block a user