Format code (#221)

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
This commit is contained in:
github-actions[bot] 2023-05-05 13:13:41 +08:00 committed by GitHub
parent ccf6e6bbd2
commit 6726af00cf
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 11 additions and 11 deletions

8
gui.py
View File

@ -133,7 +133,9 @@ class RVC:
score, ix = index.search(npy, k=8) score, ix = index.search(npy, k=8)
weight = np.square(1 / score) weight = np.square(1 / score)
weight /= weight.sum(axis=1, keepdims=True) weight /= weight.sum(axis=1, keepdims=True)
npy = np.sum(big_npy[ix] * np.expand_dims(weight, axis=2), axis=1).astype("float16") npy = np.sum(big_npy[ix] * np.expand_dims(weight, axis=2), axis=1).astype(
"float16"
)
feats = ( feats = (
torch.from_numpy(npy).unsqueeze(0).to(device) * self.index_rate torch.from_numpy(npy).unsqueeze(0).to(device) * self.index_rate
@ -211,9 +213,7 @@ class GUI:
title=i18n("加载模型"), title=i18n("加载模型"),
layout=[ layout=[
[ [
sg.Input( sg.Input(default_text="hubert_base.pt", key="hubert_path"),
default_text="hubert_base.pt", key="hubert_path"
),
sg.FileBrowse(i18n("Hubert模型")), sg.FileBrowse(i18n("Hubert模型")),
], ],
[ [

View File

@ -694,9 +694,9 @@ def train_index(exp_dir1):
# faiss.write_index(index, '%s/trained_IVF%s_Flat_FastScan.index'%(exp_dir,n_ivf)) # faiss.write_index(index, '%s/trained_IVF%s_Flat_FastScan.index'%(exp_dir,n_ivf))
infos.append("adding") infos.append("adding")
yield "\n".join(infos) yield "\n".join(infos)
batch_size_add=8192 batch_size_add = 8192
for i in range(0,big_npy.shape[0],batch_size_add): for i in range(0, big_npy.shape[0], batch_size_add):
index.add(big_npy[i:i+batch_size_add]) index.add(big_npy[i : i + batch_size_add])
faiss.write_index( faiss.write_index(
index, index,
"%s/added_IVF%s_Flat_nprobe_%s.index" % (exp_dir, n_ivf, index_ivf.nprobe), "%s/added_IVF%s_Flat_nprobe_%s.index" % (exp_dir, n_ivf, index_ivf.nprobe),
@ -915,9 +915,9 @@ def train1key(
"%s/trained_IVF%s_Flat_nprobe_%s.index" % (exp_dir, n_ivf, index_ivf.nprobe), "%s/trained_IVF%s_Flat_nprobe_%s.index" % (exp_dir, n_ivf, index_ivf.nprobe),
) )
yield get_info_str("adding index") yield get_info_str("adding index")
batch_size_add=8192 batch_size_add = 8192
for i in range(0,big_npy.shape[0],batch_size_add): for i in range(0, big_npy.shape[0], batch_size_add):
index.add(big_npy[i:i+batch_size_add]) index.add(big_npy[i : i + batch_size_add])
faiss.write_index( faiss.write_index(
index, index,
"%s/added_IVF%s_Flat_nprobe_%s.index" % (exp_dir, n_ivf, index_ivf.nprobe), "%s/added_IVF%s_Flat_nprobe_%s.index" % (exp_dir, n_ivf, index_ivf.nprobe),