diff --git a/infer-web.py b/infer-web.py
index d9ac429..e13e77e 100644
--- a/infer-web.py
+++ b/infer-web.py
@@ -35,8 +35,7 @@ from MDXNet import MDXNetDereverb
 from my_utils import load_audio
 from train.process_ckpt import change_info, extract_small_model, merge, show_info
 from vc_infer_pipeline import VC
-
-# from trainset_preprocess_pipeline import PreProcess
+from sklearn.cluster import MiniBatchKMeans
 
 logging.getLogger("numba").setLevel(logging.WARNING)
 
@@ -653,9 +652,13 @@ def change_sr2(sr2, if_f0_3, version19):
             "not exist, will not use pretrained model",
         )
     return (
-        "pretrained%s/%sG%s.pth" % (path_str, f0_str, sr2) if if_pretrained_generator_exist else "",
-        "pretrained%s/%sD%s.pth" % (path_str, f0_str, sr2) if if_pretrained_discriminator_exist else "",
-        {"visible": True, "__type__": "update"}
+        "pretrained%s/%sG%s.pth" % (path_str, f0_str, sr2)
+        if if_pretrained_generator_exist
+        else "",
+        "pretrained%s/%sD%s.pth" % (path_str, f0_str, sr2)
+        if if_pretrained_discriminator_exist
+        else "",
+        {"visible": True, "__type__": "update"},
     )
 
 
@@ -679,8 +682,12 @@ def change_version19(sr2, if_f0_3, version19):
             "not exist, will not use pretrained model",
        )
     return (
-        "pretrained%s/%sG%s.pth" % (path_str, f0_str, sr2) if if_pretrained_generator_exist else "",
-        "pretrained%s/%sD%s.pth" % (path_str, f0_str, sr2) if if_pretrained_discriminator_exist else "",
+        "pretrained%s/%sG%s.pth" % (path_str, f0_str, sr2)
+        if if_pretrained_generator_exist
+        else "",
+        "pretrained%s/%sD%s.pth" % (path_str, f0_str, sr2)
+        if if_pretrained_discriminator_exist
+        else "",
     )
 
 
@@ -714,8 +721,12 @@ def change_f0(if_f0_3, sr2, version19):  # f0method8,pretrained_G14,pretrained_D
         )
     return (
         {"visible": False, "__type__": "update"},
-        "pretrained%s/G%s.pth" % (path_str, sr2) if if_pretrained_generator_exist else "",
-        "pretrained%s/D%s.pth" % (path_str, sr2) if if_pretrained_discriminator_exist else "",
+        ("pretrained%s/G%s.pth" % (path_str, sr2))
+        if if_pretrained_generator_exist
+        else "",
+        ("pretrained%s/D%s.pth" % (path_str, sr2))
+        if if_pretrained_discriminator_exist
+        else "",
     )
 
 
@@ -869,6 +880,7 @@ def train_index(exp_dir1, version19):
     listdir_res = list(os.listdir(feature_dir))
     if len(listdir_res) == 0:
         return "请先进行特征提取!"
+    infos = []
     npys = []
     for name in sorted(listdir_res):
         phone = np.load("%s/%s" % (feature_dir, name))
@@ -877,10 +889,20 @@ def train_index(exp_dir1, version19):
     big_npy_idx = np.arange(big_npy.shape[0])
     np.random.shuffle(big_npy_idx)
     big_npy = big_npy[big_npy_idx]
+    # if(big_npy.shape[0]>2e5):
+    if(1):
+        infos.append("Trying doing kmeans %s shape to 10k centers." % big_npy.shape[0])
+        yield "\n".join(infos)
+        try:
+            big_npy = MiniBatchKMeans(n_clusters=10000, verbose=True, batch_size=256 * config.n_cpu, compute_labels=False, init="random").fit(big_npy).cluster_centers_
+        except:
+            info = traceback.format_exc()
+            print(info)
+            infos.append(info)
+            yield "\n".join(infos)
+
     np.save("%s/total_fea.npy" % exp_dir, big_npy)
-    # n_ivf = big_npy.shape[0] // 39
     n_ivf = min(int(16 * np.sqrt(big_npy.shape[0])), big_npy.shape[0] // 39)
-    infos = []
     infos.append("%s,%s" % (big_npy.shape, n_ivf))
     yield "\n".join(infos)
     index = faiss.index_factory(256 if version19 == "v1" else 768, "IVF%s,Flat" % n_ivf)
@@ -1120,6 +1142,19 @@ def train1key(
     big_npy_idx = np.arange(big_npy.shape[0])
     np.random.shuffle(big_npy_idx)
     big_npy = big_npy[big_npy_idx]
+
+    # if(big_npy.shape[0]>2e5):
+    if(1):
+        info = "Trying doing kmeans %s shape to 10k centers." % big_npy.shape[0]
+        print(info)
+        yield get_info_str(info)
+        try:
+            big_npy = MiniBatchKMeans(n_clusters=10000, verbose=True, batch_size=256 * config.n_cpu, compute_labels=False, init="random").fit(big_npy).cluster_centers_
+        except:
+            info = traceback.format_exc()
+            print(info)
+            yield get_info_str(info)
+
     np.save("%s/total_fea.npy" % model_log_dir, big_npy)
     # n_ivf = big_npy.shape[0] // 39
     n_ivf = min(int(16 * np.sqrt(big_npy.shape[0])), big_npy.shape[0] // 39)
@@ -1565,7 +1600,7 @@ with gr.Blocks() as app:
                     maximum=config.n_cpu,
                     step=1,
                     label=i18n("提取音高和处理数据使用的CPU进程数"),
-                    value=config.n_cpu,
+                    value=int(np.ceil(config.n_cpu / 1.5)),
                     interactive=True,
                 )
         with gr.Group():  # 暂时单人的, 后面支持最多4人的#数据处理
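
Note on the core change: both `train_index` and `train1key` now shrink the shuffled, concatenated HuBERT features (`big_npy`) to 10,000 k-means centroids before the faiss index is trained, and the `if(1):` guard makes this unconditional rather than only for datasets above 2e5 frames as the commented-out check suggests. A minimal standalone sketch of that step, assuming scikit-learn is available and using toy data and a hard-coded worker count in place of the patch's `big_npy` and `config.n_cpu`:

```python
import numpy as np
from sklearn.cluster import MiniBatchKMeans

n_cpu = 4  # stand-in for config.n_cpu
# Toy (N, 768) v2 feature matrix; the real big_npy is the shuffled
# concatenation of every per-utterance feature file in 3_feature768.
big_npy = np.random.rand(50_000, 768).astype(np.float32)

# Same call the patch makes: keep only the 10k cluster centers so the
# downstream IVF index stays a fixed size however large the dataset is.
big_npy = (
    MiniBatchKMeans(
        n_clusters=10000,
        verbose=True,
        batch_size=256 * n_cpu,
        compute_labels=False,  # the labels are never used, only the centers
        init="random",
    )
    .fit(big_npy)
    .cluster_centers_
)
print(big_npy.shape)  # (10000, 768)
```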
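The IVF sizing heuristic also changes, from `big_npy.shape[0] // 39` to `min(int(16 * np.sqrt(n)), n // 39)`. The square-root bound only wins for very large inputs (16·√n < n/39 once n exceeds roughly 3.9e5), so after the 10k-centroid reduction above the `n // 39` term always applies. A quick worked check:

```python
import numpy as np

def n_ivf(n):
    # the patch's heuristic for the number of IVF lists
    return min(int(16 * np.sqrt(n)), n // 39)

print(n_ivf(10_000))   # 256    -> 10k kmeans centers: n // 39 wins
print(n_ivf(500_000))  # 11313  -> large raw set: 16 * sqrt(n) wins
```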
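Finally, the last hunk lowers the default for the preprocessing/pitch-extraction CPU-process slider from `config.n_cpu` to `int(np.ceil(config.n_cpu / 1.5))`, e.g. 6 of 8 cores (ceil(8 / 1.5) = ceil(5.33) = 6), while `config.n_cpu` remains the slider's maximum.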