train index: automatically run k-means when the feature shape is too large

train index: automatically run k-means when the feature shape is too large
This commit is contained in:
RVC-Boss 2023-06-18 16:16:33 +08:00 committed by GitHub
parent 75264d09b6
commit e7f204b32e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed file with 47 additions and 12 deletions

View File

@ -35,8 +35,7 @@ from MDXNet import MDXNetDereverb
from my_utils import load_audio
from train.process_ckpt import change_info, extract_small_model, merge, show_info
from vc_infer_pipeline import VC
# from trainset_preprocess_pipeline import PreProcess
from sklearn.cluster import MiniBatchKMeans
logging.getLogger("numba").setLevel(logging.WARNING)
@ -653,9 +652,13 @@ def change_sr2(sr2, if_f0_3, version19):
"not exist, will not use pretrained model",
)
return (
"pretrained%s/%sG%s.pth" % (path_str, f0_str, sr2) if if_pretrained_generator_exist else "",
"pretrained%s/%sD%s.pth" % (path_str, f0_str, sr2) if if_pretrained_discriminator_exist else "",
{"visible": True, "__type__": "update"}
"pretrained%s/%sG%s.pth" % (path_str, f0_str, sr2)
if if_pretrained_generator_exist
else "",
"pretrained%s/%sD%s.pth" % (path_str, f0_str, sr2)
if if_pretrained_discriminator_exist
else "",
{"visible": True, "__type__": "update"},
)
@ -679,8 +682,12 @@ def change_version19(sr2, if_f0_3, version19):
"not exist, will not use pretrained model",
)
return (
"pretrained%s/%sG%s.pth" % (path_str, f0_str, sr2) if if_pretrained_generator_exist else "",
"pretrained%s/%sD%s.pth" % (path_str, f0_str, sr2) if if_pretrained_discriminator_exist else "",
"pretrained%s/%sG%s.pth" % (path_str, f0_str, sr2)
if if_pretrained_generator_exist
else "",
"pretrained%s/%sD%s.pth" % (path_str, f0_str, sr2)
if if_pretrained_discriminator_exist
else "",
)
@ -714,8 +721,12 @@ def change_f0(if_f0_3, sr2, version19): # f0method8,pretrained_G14,pretrained_D
)
return (
{"visible": False, "__type__": "update"},
"pretrained%s/G%s.pth" % (path_str, sr2) if if_pretrained_generator_exist else "",
"pretrained%s/D%s.pth" % (path_str, sr2) if if_pretrained_discriminator_exist else "",
("pretrained%s/G%s.pth" % (path_str, sr2))
if if_pretrained_generator_exist
else "",
("pretrained%s/D%s.pth" % (path_str, sr2))
if if_pretrained_discriminator_exist
else "",
)
@ -869,6 +880,7 @@ def train_index(exp_dir1, version19):
listdir_res = list(os.listdir(feature_dir))
if len(listdir_res) == 0:
return "请先进行特征提取!"
infos = []
npys = []
for name in sorted(listdir_res):
phone = np.load("%s/%s" % (feature_dir, name))
@ -877,10 +889,20 @@ def train_index(exp_dir1, version19):
big_npy_idx = np.arange(big_npy.shape[0])
np.random.shuffle(big_npy_idx)
big_npy = big_npy[big_npy_idx]
# if(big_npy.shape[0]>2e5):
if(1):
infos.append("Trying doing kmeans %s shape to 10k centers."%big_npy.shape[0])
yield "\n".join(infos)
try:
big_npy = MiniBatchKMeans(n_clusters=10000, verbose=True, batch_size=256 * config.n_cpu, compute_labels=False, init="random").fit(big_npy).cluster_centers_
except:
info=traceback.format_exc()
print(info)
infos.append(info)
yield "\n".join(infos)
np.save("%s/total_fea.npy" % exp_dir, big_npy)
# n_ivf = big_npy.shape[0] // 39
n_ivf = min(int(16 * np.sqrt(big_npy.shape[0])), big_npy.shape[0] // 39)
infos = []
infos.append("%s,%s" % (big_npy.shape, n_ivf))
yield "\n".join(infos)
index = faiss.index_factory(256 if version19 == "v1" else 768, "IVF%s,Flat" % n_ivf)
@ -1120,6 +1142,19 @@ def train1key(
big_npy_idx = np.arange(big_npy.shape[0])
np.random.shuffle(big_npy_idx)
big_npy = big_npy[big_npy_idx]
# if(big_npy.shape[0]>2e5):
if(1):
info="Trying doing kmeans %s shape to 10k centers."%big_npy.shape[0]
print(info)
yield get_info_str(info)
try:
big_npy = MiniBatchKMeans(n_clusters=10000, verbose=True, batch_size=256 * config.n_cpu, compute_labels=False, init="random").fit(big_npy).cluster_centers_
except:
info=traceback.format_exc()
print(info)
yield get_info_str(info)
np.save("%s/total_fea.npy" % model_log_dir, big_npy)
# n_ivf = big_npy.shape[0] // 39
@ -1565,7 +1600,7 @@ with gr.Blocks() as app:
maximum=config.n_cpu,
step=1,
label=i18n("提取音高和处理数据使用的CPU进程数"),
value=config.n_cpu,
value=int(np.ceil(config.n_cpu/1.5)),
interactive=True,
)
with gr.Group(): # 暂时单人的, 后面支持最多4人的#数据处理