update index training script v2 (#643)

* update index training script v2

* Apply Code Formatter Change

---------

Co-authored-by: gak123 <gak123@users.noreply.github.com>
This commit is contained in:
Rice Cake 2023-06-28 13:48:06 +08:00 committed by GitHub
parent fad31f24f5
commit 7fc6642c04
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -2,9 +2,15 @@
格式直接cid为自带的index位aid放不下了通过字典来查反正就5w个
"""
import faiss, numpy as np, os
from sklearn.cluster import MiniBatchKMeans
import traceback
from multiprocessing import cpu_count
# ###########如果是原始特征要先写save
inp_root = r"./logs/nene/3_feature768"
n_cpu = 0
if n_cpu == 0:
n_cpu = cpu_count()
inp_root = r"./logs/anz/3_feature768"
npys = []
listdir_res = list(os.listdir(inp_root))
for name in sorted(listdir_res):
@ -15,7 +21,27 @@ big_npy_idx = np.arange(big_npy.shape[0])
np.random.shuffle(big_npy_idx)
big_npy = big_npy[big_npy_idx]
print(big_npy.shape) # (6196072, 192)#fp32#4.43G
np.save("infer/big_src_feature_mi.npy", big_npy)
if big_npy.shape[0] > 2e5:
# if(1):
info = "Trying doing kmeans %s shape to 10k centers." % big_npy.shape[0]
print(info)
try:
big_npy = (
MiniBatchKMeans(
n_clusters=10000,
verbose=True,
batch_size=256 * n_cpu,
compute_labels=False,
init="random",
)
.fit(big_npy)
.cluster_centers_
)
except:
info = traceback.format_exc()
print(info)
np.save("tools/infer/big_src_feature_mi.npy", big_npy)
##################train+add
# big_npy=np.load("/bili-coeus/jupyter/jupyterhub-liujing04/vits_ch/inference_f0/big_src_feature_mi.npy")
@ -26,13 +52,15 @@ index_ivf = faiss.extract_index_ivf(index) #
index_ivf.nprobe = 1
index.train(big_npy)
faiss.write_index(
index, "infer/trained_IVF%s_Flat_baseline_src_feat_v2.index" % (n_ivf)
index, "tools/infer/trained_IVF%s_Flat_baseline_src_feat_v2.index" % (n_ivf)
)
print("adding")
batch_size_add = 8192
for i in range(0, big_npy.shape[0], batch_size_add):
index.add(big_npy[i : i + batch_size_add])
faiss.write_index(index, "infer/added_IVF%s_Flat_mi_baseline_src_feat.index" % (n_ivf))
faiss.write_index(
index, "tools/infer/added_IVF%s_Flat_mi_baseline_src_feat.index" % (n_ivf)
)
"""
大小都是FP32
big_src_feature 2.95G