diff --git a/vc_infer_pipeline.py b/vc_infer_pipeline.py index 1db5e56..6ff5eb5 100644 --- a/vc_infer_pipeline.py +++ b/vc_infer_pipeline.py @@ -123,6 +123,7 @@ class VC(object): # _, I = index.search(npy, 1) # npy = big_npy[I.squeeze()] + #by github @nadare881 score, ix = index.search(npy, k=8) weight = np.square(1 / score) weight /= weight.sum(axis=1, keepdims=True)