Format code (#932)

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
This commit is contained in:
github-actions[bot] 2023-08-03 10:25:05 +08:00 committed by GitHub
parent 296905983a
commit 9a20c3b28f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 29 additions and 19 deletions

View File

@ -234,16 +234,12 @@ def get_vc(model_path):
version = cpt.get("version", "v1") version = cpt.get("version", "v1")
if version == "v1": if version == "v1":
if if_f0 == 1: if if_f0 == 1:
net_g = SynthesizerTrnMs256NSFsid( net_g = SynthesizerTrnMs256NSFsid(*cpt["config"], is_half=is_half)
*cpt["config"], is_half=is_half
)
else: else:
net_g = SynthesizerTrnMs256NSFsid_nono(*cpt["config"]) net_g = SynthesizerTrnMs256NSFsid_nono(*cpt["config"])
elif version == "v2": elif version == "v2":
if if_f0 == 1: if if_f0 == 1:
net_g = SynthesizerTrnMs768NSFsid( net_g = SynthesizerTrnMs768NSFsid(*cpt["config"], is_half=is_half)
*cpt["config"], is_half=is_half
)
else: else:
net_g = SynthesizerTrnMs768NSFsid_nono(*cpt["config"]) net_g = SynthesizerTrnMs768NSFsid_nono(*cpt["config"])
del net_g.enc_q del net_g.enc_q

View File

@ -5,6 +5,7 @@ import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
def cal_cross_attn(to_q, to_k, to_v, rand_input): def cal_cross_attn(to_q, to_k, to_v, rand_input):
hidden_dim, embed_dim = to_q.shape hidden_dim, embed_dim = to_q.shape
attn_to_q = nn.Linear(hidden_dim, embed_dim, bias=False) attn_to_q = nn.Linear(hidden_dim, embed_dim, bias=False)
@ -16,21 +17,27 @@ def cal_cross_attn(to_q, to_k, to_v, rand_input):
return torch.einsum( return torch.einsum(
"ik, jk -> ik", "ik, jk -> ik",
F.softmax(torch.einsum("ij, kj -> ik", attn_to_q(rand_input), attn_to_k(rand_input)), dim=-1), F.softmax(
attn_to_v(rand_input) torch.einsum("ij, kj -> ik", attn_to_q(rand_input), attn_to_k(rand_input)),
dim=-1,
),
attn_to_v(rand_input),
) )
def model_hash(filename): def model_hash(filename):
try: try:
with open(filename, "rb") as file: with open(filename, "rb") as file:
import hashlib import hashlib
m = hashlib.sha256() m = hashlib.sha256()
file.seek(0x100000) file.seek(0x100000)
m.update(file.read(0x10000)) m.update(file.read(0x10000))
return m.hexdigest()[0:8] return m.hexdigest()[0:8]
except FileNotFoundError: except FileNotFoundError:
return 'NOFILE' return "NOFILE"
def eval(model, n, input): def eval(model, n, input):
qk = f"enc_p.encoder.attn_layers.{n}.conv_q.weight" qk = f"enc_p.encoder.attn_layers.{n}.conv_q.weight"
@ -41,6 +48,7 @@ def eval(model, n, input):
attn = cal_cross_attn(atoq, atok, atov, input) attn = cal_cross_attn(atoq, atok, atov, input)
return attn return attn
def main(path, root): def main(path, root):
torch.manual_seed(114514) torch.manual_seed(114514)
model_a = torch.load(path, map_location="cpu")["weight"] model_a = torch.load(path, map_location="cpu")["weight"]
@ -50,7 +58,9 @@ def main(path,root):
map_attn_a = {} map_attn_a = {}
map_rand_input = {} map_rand_input = {}
for n in range(6): for n in range(6):
hidden_dim, embed_dim,_ = model_a[f"enc_p.encoder.attn_layers.{n}.conv_v.weight"].shape hidden_dim, embed_dim, _ = model_a[
f"enc_p.encoder.attn_layers.{n}.conv_v.weight"
].shape
rand_input = torch.randn([embed_dim, hidden_dim]) rand_input = torch.randn([embed_dim, hidden_dim])
map_attn_a[n] = eval(model_a, n, rand_input) map_attn_a[n] = eval(model_a, n, rand_input)
@ -70,7 +80,11 @@ def main(path,root):
sim = torch.mean(torch.cosine_similarity(attn_a, attn_b)) sim = torch.mean(torch.cosine_similarity(attn_a, attn_b))
sims.append(sim) sims.append(sim)
print("reference:\t%s\t%s\t%s"%(path,model_hash(path),f"{torch.mean(torch.stack(sims)) * 1e2:.2f}%")) print(
"reference:\t%s\t%s\t%s"
% (path, model_hash(path), f"{torch.mean(torch.stack(sims)) * 1e2:.2f}%")
)
if __name__ == "__main__": if __name__ == "__main__":
query_path = r"weights\mi v3.pth" query_path = r"weights\mi v3.pth"