diff --git a/infer-web.py b/infer-web.py
index b027f0e..a4a2c29 100644
--- a/infer-web.py
+++ b/infer-web.py
@@ -72,7 +72,7 @@ from config import (
     noautoopen,
 )
 from infer_uvr5 import _audio_pre_
-from my_utils import load_audio
+from my_utils import load_audio, train_index
 from train.process_ckpt import show_info, change_info, merge, extract_small_model
 
 # from trainset_preprocess_pipeline import PreProcess
@@ -614,47 +614,6 @@ def click_train(
     return "训练结束, 您可查看控制台训练日志或实验文件夹下的train.log"
 
 
-# but4.click(train_index, [exp_dir1], info3)
-def train_index(exp_dir1):
-    exp_dir = "%s/logs/%s" % (now_dir, exp_dir1)
-    os.makedirs(exp_dir, exist_ok=True)
-    feature_dir = "%s/3_feature256" % (exp_dir)
-    if os.path.exists(feature_dir) == False:
-        return "请先进行特征提取!"
-    listdir_res = list(os.listdir(feature_dir))
-    if len(listdir_res) == 0:
-        return "请先进行特征提取!"
-    npys = []
-    for name in sorted(listdir_res):
-        phone = np.load("%s/%s" % (feature_dir, name))
-        npys.append(phone)
-    big_npy = np.concatenate(npys, 0)
-    np.save("%s/total_fea.npy" % exp_dir, big_npy)
-    n_ivf = big_npy.shape[0] // 39
-    infos = []
-    infos.append("%s,%s" % (big_npy.shape, n_ivf))
-    yield "\n".join(infos)
-    index = faiss.index_factory(256, "IVF%s,Flat" % n_ivf)
-    infos.append("training")
-    yield "\n".join(infos)
-    index_ivf = faiss.extract_index_ivf(index)  #
-    index_ivf.nprobe = int(np.power(n_ivf, 0.3))
-    index.train(big_npy)
-    faiss.write_index(
-        index,
-        "%s/trained_IVF%s_Flat_nprobe_%s.index" % (exp_dir, n_ivf, index_ivf.nprobe),
-    )
-    infos.append("adding")
-    yield "\n".join(infos)
-    index.add(big_npy)
-    faiss.write_index(
-        index,
-        "%s/added_IVF%s_Flat_nprobe_%s.index" % (exp_dir, n_ivf, index_ivf.nprobe),
-    )
-    infos.append("成功构建索引, added_IVF%s_Flat_nprobe_%s.index" % (n_ivf, index_ivf.nprobe))
-    yield "\n".join(infos)
-
-
 # but5.click(train1key, [exp_dir1, sr2, if_f0_3, trainset_dir4, spk_id5, gpus6, np7, f0method8, save_epoch10, total_epoch11, batch_size12, if_save_latest13, pretrained_G14, pretrained_D15, gpus16, if_cache_gpu17], info3)
 def train1key(
     exp_dir1,
@@ -835,34 +794,7 @@ def train1key(
     p.wait()
     yield get_info_str("训练结束, 您可查看控制台训练日志或实验文件夹下的train.log")
     #######step3b:训练索引
-    feature_dir = "%s/3_feature256" % (exp_dir)
-    npys = []
-    listdir_res = list(os.listdir(feature_dir))
-    for name in sorted(listdir_res):
-        phone = np.load("%s/%s" % (feature_dir, name))
-        npys.append(phone)
-    big_npy = np.concatenate(npys, 0)
-    np.save("%s/total_fea.npy" % exp_dir, big_npy)
-    n_ivf = big_npy.shape[0] // 39
-    yield get_info_str("%s,%s" % (big_npy.shape, n_ivf))
-    index = faiss.index_factory(256, "IVF%s,Flat" % n_ivf)
-    yield get_info_str("training index")
-    index_ivf = faiss.extract_index_ivf(index)  #
-    index_ivf.nprobe = int(np.power(n_ivf, 0.3))
-    index.train(big_npy)
-    faiss.write_index(
-        index,
-        "%s/trained_IVF%s_Flat_nprobe_%s.index" % (exp_dir, n_ivf, index_ivf.nprobe),
-    )
-    yield get_info_str("adding index")
-    index.add(big_npy)
-    faiss.write_index(
-        index,
-        "%s/added_IVF%s_Flat_nprobe_%s.index" % (exp_dir, n_ivf, index_ivf.nprobe),
-    )
-    yield get_info_str(
-        "成功构建索引, added_IVF%s_Flat_nprobe_%s.index" % (n_ivf, index_ivf.nprobe)
-    )
+    yield from train_index(exp_dir1)
     yield get_info_str("全流程结束!")
 
 
diff --git a/my_utils.py b/my_utils.py
index 89a1527..c74776e 100644
--- a/my_utils.py
+++ b/my_utils.py
@@ -1,7 +1,83 @@
+import os
+
+import faiss
 import ffmpeg
 import numpy as np
 
 
+def train_index(exp_dir1):
+    """Train and save a faiss index, yielding a progress log.
+
+    This is a generator: each yielded value is the newline-joined log
+    accumulated so far.
+
+    Args:
+        exp_dir1(string): Experiment name; features are read from and
+            index files are written under ``<cwd>/logs/<exp_dir1>``.
+
+    Yields:
+        string: The progress log, or a message asking for feature
+            extraction when ``3_feature256`` is missing or empty.
+    """
+    exp_dir = "%s/logs/%s" % (os.getcwd(), exp_dir1)
+    os.makedirs(exp_dir, exist_ok=True)
+    feature_dir = "%s/3_feature256" % (exp_dir)
+
+    # A ``return "msg"`` inside a generator never reaches the code that
+    # iterates it, so the message is yielded before stopping.
+    if not os.path.exists(feature_dir):
+        yield "请先进行特征提取!"
+        return
+    listdir_res = list(os.listdir(feature_dir))
+    if not listdir_res:
+        yield "请先进行特征提取!"
+        return
+
+    npys = []
+    for name in sorted(listdir_res):
+        phone = np.load("%s/%s" % (feature_dir, name))
+        npys.append(phone)
+    big_npy = np.concatenate(npys, 0)
+    np.save("%s/total_fea.npy" % exp_dir, big_npy)
+
+    # use recommended parameter in https://github.com/facebookresearch/faiss/wiki/Guidelines-to-choose-an-index
+    N = big_npy.shape[0]
+    dim = big_npy.shape[1]
+    if 4 * np.rint(np.sqrt(N)) * 30 > N:
+        # With fewer than 30 vectors N // 30 is 0, which would ask for "IVF0".
+        n_ivf = max(1, N // 30)
+    else:
+        n_ivf = -(-N // 256)
+        for x in range(4, 18, 2):
+            K = x * np.rint(np.sqrt(N)).astype(int)
+            if 30 * K <= N <= 256 * K:
+                n_ivf = K
+                break
+    index_string = "IVF%s,PQ%sx4fs,RFlat" % (n_ivf, -(-dim // 2))
+    index_name = index_string.replace(",", "_")
+
+    infos = []
+    infos.append("%s,%s" % (big_npy.shape, n_ivf))
+    yield "\n".join(infos)
+    index = faiss.index_factory(dim, index_string)
+    infos.append("training")
+    yield "\n".join(infos)
+    index.train(big_npy)
+    faiss.write_index(
+        index,
+        "%s/trained_%s.index" % (exp_dir, index_name),
+    )
+    infos.append("adding")
+    yield "\n".join(infos)
+    index.add(big_npy)
+    faiss.write_index(
+        index,
+        "%s/added_%s.index" % (exp_dir, index_name),
+    )
+    infos.append("成功构建索引, added_%s.index" % (index_name))
+    yield "\n".join(infos)
+
+
 def load_audio(file, sr):
     try:
         # https://github.com/openai/whisper/blob/main/whisper/audio.py#L26