mirror of
https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI.git
synced 2025-04-05 12:18:58 +08:00
Format code (#932)
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
This commit is contained in:
parent
296905983a
commit
9a20c3b28f
@ -234,16 +234,12 @@ def get_vc(model_path):
|
||||
version = cpt.get("version", "v1")
|
||||
if version == "v1":
|
||||
if if_f0 == 1:
|
||||
net_g = SynthesizerTrnMs256NSFsid(
|
||||
*cpt["config"], is_half=is_half
|
||||
)
|
||||
net_g = SynthesizerTrnMs256NSFsid(*cpt["config"], is_half=is_half)
|
||||
else:
|
||||
net_g = SynthesizerTrnMs256NSFsid_nono(*cpt["config"])
|
||||
elif version == "v2":
|
||||
if if_f0 == 1:
|
||||
net_g = SynthesizerTrnMs768NSFsid(
|
||||
*cpt["config"], is_half=is_half
|
||||
)
|
||||
net_g = SynthesizerTrnMs768NSFsid(*cpt["config"], is_half=is_half)
|
||||
else:
|
||||
net_g = SynthesizerTrnMs768NSFsid_nono(*cpt["config"])
|
||||
del net_g.enc_q
|
||||
|
@ -5,6 +5,7 @@ import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
def cal_cross_attn(to_q, to_k, to_v, rand_input):
|
||||
hidden_dim, embed_dim = to_q.shape
|
||||
attn_to_q = nn.Linear(hidden_dim, embed_dim, bias=False)
|
||||
@ -16,21 +17,27 @@ def cal_cross_attn(to_q, to_k, to_v, rand_input):
|
||||
|
||||
return torch.einsum(
|
||||
"ik, jk -> ik",
|
||||
F.softmax(torch.einsum("ij, kj -> ik", attn_to_q(rand_input), attn_to_k(rand_input)), dim=-1),
|
||||
attn_to_v(rand_input)
|
||||
F.softmax(
|
||||
torch.einsum("ij, kj -> ik", attn_to_q(rand_input), attn_to_k(rand_input)),
|
||||
dim=-1,
|
||||
),
|
||||
attn_to_v(rand_input),
|
||||
)
|
||||
|
||||
|
||||
def model_hash(filename):
|
||||
try:
|
||||
with open(filename, "rb") as file:
|
||||
import hashlib
|
||||
|
||||
m = hashlib.sha256()
|
||||
|
||||
file.seek(0x100000)
|
||||
m.update(file.read(0x10000))
|
||||
return m.hexdigest()[0:8]
|
||||
except FileNotFoundError:
|
||||
return 'NOFILE'
|
||||
return "NOFILE"
|
||||
|
||||
|
||||
def eval(model, n, input):
|
||||
qk = f"enc_p.encoder.attn_layers.{n}.conv_q.weight"
|
||||
@ -41,6 +48,7 @@ def eval(model, n, input):
|
||||
attn = cal_cross_attn(atoq, atok, atov, input)
|
||||
return attn
|
||||
|
||||
|
||||
def main(path, root):
|
||||
torch.manual_seed(114514)
|
||||
model_a = torch.load(path, map_location="cpu")["weight"]
|
||||
@ -50,7 +58,9 @@ def main(path,root):
|
||||
map_attn_a = {}
|
||||
map_rand_input = {}
|
||||
for n in range(6):
|
||||
hidden_dim, embed_dim,_ = model_a[f"enc_p.encoder.attn_layers.{n}.conv_v.weight"].shape
|
||||
hidden_dim, embed_dim, _ = model_a[
|
||||
f"enc_p.encoder.attn_layers.{n}.conv_v.weight"
|
||||
].shape
|
||||
rand_input = torch.randn([embed_dim, hidden_dim])
|
||||
|
||||
map_attn_a[n] = eval(model_a, n, rand_input)
|
||||
@ -70,7 +80,11 @@ def main(path,root):
|
||||
sim = torch.mean(torch.cosine_similarity(attn_a, attn_b))
|
||||
sims.append(sim)
|
||||
|
||||
print("reference:\t%s\t%s\t%s"%(path,model_hash(path),f"{torch.mean(torch.stack(sims)) * 1e2:.2f}%"))
|
||||
print(
|
||||
"reference:\t%s\t%s\t%s"
|
||||
% (path, model_hash(path), f"{torch.mean(torch.stack(sims)) * 1e2:.2f}%")
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
query_path = r"weights\mi v3.pth"
|
||||
|
Loading…
x
Reference in New Issue
Block a user