2023-08-02 21:20:46 +08:00
# This code is adapted from https://huggingface.co/JosephusCheung/ASimilarityCalculatior/blob/main/qwerty.py
# Set the path of the model to query and the root directory of the reference models; the script then reports the similarity between the query model and every reference model.
2023-08-28 15:08:31 +08:00
import os
2023-09-01 15:18:08 +08:00
import logging
2023-09-02 11:50:52 +08:00
2023-09-01 15:18:08 +08:00
logger = logging . getLogger ( __name__ )
2023-08-28 15:08:31 +08:00
2023-08-02 21:20:46 +08:00
import torch
import torch . nn as nn
import torch . nn . functional as F
2023-08-03 10:25:05 +08:00
2023-08-02 21:20:46 +08:00
def cal_cross_attn(to_q, to_k, to_v, rand_input):
    """Project ``rand_input`` with the given q/k/v weights and combine the results.

    Each weight is a 2-D ``(out_features, in_features)`` matrix that is loaded
    into a bias-free ``nn.Linear``, so ``rand_input`` must have
    ``in_features`` columns. The query and key projections must have the same
    ``out_features``.

    The result is ``softmax(q @ k.T, dim=-1)`` with each column ``c`` scaled by
    the sum of column ``c`` of ``v``. This requires the number of rows of
    ``rand_input`` to equal the value projection's ``out_features``.

    NOTE(review): standard attention would combine these as ``attn @ v``
    (einsum ``"ij, jk -> ik"``). The ``"ik, jk -> ik"`` form is kept unchanged
    because the similarity scores built on this function are defined by it.

    Args:
        to_q: query weight matrix, shape ``(out_features, in_features)``.
        to_k: key weight matrix, shape ``(out_features, in_features)``.
        to_v: value weight matrix, shape ``(out_features, in_features)``.
        rand_input: probe input whose last dimension is ``in_features``.

    Returns:
        A ``(rows, rows)`` tensor, where ``rows`` is the number of rows of
        ``rand_input``, computed without autograd tracking.
    """

    def _linear(weight):
        # nn.Linear(in, out) stores its weight as (out, in). Build each layer
        # from the weight's own shape so non-square matrices load correctly.
        out_features, in_features = weight.shape
        layer = nn.Linear(in_features, out_features, bias=False)
        layer.load_state_dict({"weight": weight})
        return layer

    # Constructing nn.Linear runs PyTorch's default initialisation, which
    # draws from the global RNG. The modules are kept, rather than replaced
    # with a bare matmul, so seeded callers see the same random stream as
    # before. no_grad avoids building autograd graphs for these
    # inference-only products.
    with torch.no_grad():
        attn_to_q = _linear(to_q)
        attn_to_k = _linear(to_k)
        attn_to_v = _linear(to_v)
        scores = torch.einsum(
            "ij, kj -> ik", attn_to_q(rand_input), attn_to_k(rand_input)
        )
        return torch.einsum(
            "ik, jk -> ik",
            F.softmax(scores, dim=-1),
            attn_to_v(rand_input),
        )
2023-08-03 10:25:05 +08:00
2023-08-02 21:20:46 +08:00
def model_hash(filename):
    """Return a short fingerprint of ``filename``, or ``"NOFILE"`` if it is missing.

    Only a 0x10000-byte (64 KiB) window starting at offset 0x100000 (1 MiB) is
    read. Its SHA-256 hex digest is truncated to the first 8 characters. A file
    shorter than 1 MiB gives an empty read, so all such files share the digest
    of empty input.
    """
    import hashlib

    window_start = 0x100000
    window_size = 0x10000

    try:
        with open(filename, "rb") as handle:
            handle.seek(window_start)
            window = handle.read(window_size)
    except FileNotFoundError:
        return "NOFILE"

    digest = hashlib.sha256(window).hexdigest()
    return digest[:8]
2023-08-02 21:20:46 +08:00
def eval(model, n, input):
    """Run ``cal_cross_attn`` on ``input`` using attention layer ``n`` of ``model``.

    ``model`` is indexed by state-dict keys of the form
    ``enc_p.encoder.attn_layers.{n}.conv_{q,k,v}.weight``. Only index 0 of each
    weight's third axis is kept, which turns it into a 2-D matrix.

    NOTE(review): the name and the ``input`` parameter shadow Python builtins.
    They are kept because callers use them.
    """
    weight_keys = (
        f"enc_p.encoder.attn_layers.{n}.conv_q.weight",
        f"enc_p.encoder.attn_layers.{n}.conv_k.weight",
        f"enc_p.encoder.attn_layers.{n}.conv_v.weight",
    )
    # Presumably Conv1d-style (out_channels, in_channels, kernel) weights:
    # take the first kernel tap of each.
    to_q, to_k, to_v = [model[key][:, :, 0] for key in weight_keys]
    return cal_cross_attn(to_q, to_k, to_v, input)
2023-08-03 10:25:05 +08:00
def main(path, root):
    """Log the similarity between the model at ``path`` and every file in ``root``.

    The ``"weight"`` entry of each checkpoint is loaded. For each of the
    encoder attention layers 0-5, a seeded random probe is pushed through the
    query model via ``eval``. Each reference model is then scored by the mean
    cosine similarity of its outputs on the same probes. One line is logged
    for the query and one per reference file, the latter with a percentage
    score. Sub-directories in ``root`` are skipped; any other non-checkpoint
    file will make ``torch.load`` raise.

    Args:
        path: checkpoint file of the model to query.
        root: directory containing the reference checkpoints.
    """
    # Fixed seed so the random probe inputs, and hence the scores, are reproducible.
    torch.manual_seed(114514)
    # SECURITY NOTE(review): torch.load may unpickle arbitrary objects; only
    # point this script at trusted checkpoint files.
    model_a = torch.load(path, map_location="cpu")["weight"]
    logger.info("Query:\t\t%s\t%s", path, model_hash(path))

    n_layers = 6  # assumes every model has encoder attention layers 0..5
    map_attn_a = {}
    map_rand_input = {}
    for n in range(n_layers):
        hidden_dim, embed_dim, _ = model_a[
            f"enc_p.encoder.attn_layers.{n}.conv_v.weight"
        ].shape
        rand_input = torch.randn([embed_dim, hidden_dim])
        map_attn_a[n] = eval(model_a, n, rand_input)
        map_rand_input[n] = rand_input
    # Only the per-layer outputs are needed from here on; free the query weights.
    del model_a

    for name in sorted(os.listdir(root)):
        ref_path = os.path.join(root, name)
        # Skip sub-directories and other non-file entries, which torch.load cannot read.
        if not os.path.isfile(ref_path):
            continue
        model_b = torch.load(ref_path, map_location="cpu")["weight"]
        sims = []
        for n in range(n_layers):
            attn_b = eval(model_b, n, map_rand_input[n])
            sims.append(torch.mean(torch.cosine_similarity(map_attn_a[n], attn_b)))
        # Release this model before the next one is loaded, to keep peak memory down.
        del model_b
        logger.info(
            "Reference:\t%s\t%s\t%s",
            ref_path,
            model_hash(ref_path),
            f"{torch.mean(torch.stack(sims)) * 1e2:.2f}%",
        )
2023-08-02 21:20:46 +08:00
if __name__ == "__main__":
    import argparse

    # All results are reported through `logger.info`. Without a handler at
    # INFO level nothing would be printed, so enable plain console output.
    logging.basicConfig(level=logging.INFO, format="%(message)s")

    parser = argparse.ArgumentParser(
        description="Compare a model against every model in a reference directory."
    )
    # os.path.join keeps the default paths valid on every platform; the
    # previous hard-coded backslash paths only worked on Windows.
    parser.add_argument(
        "query_path",
        nargs="?",
        default=os.path.join("assets", "weights", "mi v3.pth"),
        help="checkpoint of the model to query",
    )
    parser.add_argument(
        "reference_root",
        nargs="?",
        default=os.path.join("assets", "weights"),
        help="directory containing the reference checkpoints",
    )
    args = parser.parse_args()
    main(args.query_path, args.reference_root)