drop: overwrite of config jsons & feat: read index from assets/indices

This commit is contained in:
源文雨 2024-01-22 17:18:59 +09:00
parent 3a80032e74
commit 6141253fba
8 changed files with 53 additions and 12 deletions

1
.env
View File

@ -5,4 +5,5 @@ no_proxy = localhost, 127.0.0.1, ::1
weight_root = assets/weights weight_root = assets/weights
weight_uvr5_root = assets/uvr5_weights weight_uvr5_root = assets/uvr5_weights
index_root = logs index_root = logs
outside_index_root = assets/indices
rmvpe_root = assets/rmvpe rmvpe_root = assets/rmvpe

2
assets/indices/.gitignore vendored Normal file
View File

@ -0,0 +1,2 @@
*
!.gitignore

View File

@ -2,6 +2,7 @@ import argparse
import os import os
import sys import sys
import json import json
import shutil
from multiprocessing import cpu_count from multiprocessing import cpu_count
import torch import torch
@ -65,7 +66,10 @@ class Config:
def load_config_json() -> dict: def load_config_json() -> dict:
d = {} d = {}
for config_file in version_config_list: for config_file in version_config_list:
with open(f"configs/{config_file}", "r") as f: p = f"configs/inuse/{config_file}"
if not os.path.exists(p):
shutil.copy(f"configs/{config_file}", p)
with open(f"configs/inuse/{config_file}", "r") as f:
d[config_file] = json.load(f) d[config_file] = json.load(f)
return d return d
@ -124,12 +128,13 @@ class Config:
def use_fp32_config(self): def use_fp32_config(self):
for config_file in version_config_list: for config_file in version_config_list:
self.json_config[config_file]["train"]["fp16_run"] = False self.json_config[config_file]["train"]["fp16_run"] = False
with open(f"configs/{config_file}", "r") as f: with open(f"configs/inuse/{config_file}", "r") as f:
strr = f.read().replace("true", "false") strr = f.read().replace("true", "false")
with open(f"configs/{config_file}", "w") as f: with open(f"configs/inuse/{config_file}", "w") as f:
f.write(strr) f.write(strr)
logger.info("overwrite "+config_file)
self.preprocess_per = 3.0 self.preprocess_per = 3.0
logger.info("overwrite configs.json") logger.info("overwrite preprocess_per to %d" % (self.preprocess_per))
def device_config(self) -> tuple: def device_config(self) -> tuple:
if torch.cuda.is_available(): if torch.cuda.is_available():

4
configs/inuse/.gitignore vendored Normal file
View File

@ -0,0 +1,4 @@
*
!.gitignore
!v1
!v2

2
configs/inuse/v1/.gitignore vendored Normal file
View File

@ -0,0 +1,2 @@
*
!.gitignore

2
configs/inuse/v2/.gitignore vendored Normal file
View File

@ -0,0 +1,2 @@
*
!.gitignore

View File

@ -1,6 +1,7 @@
import os import os
import sys import sys
from dotenv import load_dotenv from dotenv import load_dotenv
import shutil
load_dotenv() load_dotenv()
@ -147,7 +148,9 @@ if __name__ == "__main__":
def load(self): def load(self):
try: try:
with open("configs/config.json", "r") as j: if not os.path.exists("configs/inuse/config.json"):
shutil.copy("configs/config.json", "configs/inuse/config.json")
with open("configs/inuse/config.json", "r") as j:
data = json.load(j) data = json.load(j)
data["sr_model"] = data["sr_type"] == "sr_model" data["sr_model"] = data["sr_type"] == "sr_model"
data["sr_device"] = data["sr_type"] == "sr_device" data["sr_device"] = data["sr_type"] == "sr_device"
@ -179,7 +182,7 @@ if __name__ == "__main__":
self.output_devices_indices.index(sd.default.device[1]) self.output_devices_indices.index(sd.default.device[1])
] ]
except: except:
with open("configs/config.json", "w") as j: with open("configs/inuse/config.json", "w") as j:
data = { data = {
"pth_path": "", "pth_path": "",
"index_path": "", "index_path": "",
@ -578,7 +581,7 @@ if __name__ == "__main__":
].index(True) ].index(True)
], ],
} }
with open("configs/config.json", "w") as j: with open("configs/inuse/config.json", "w") as j:
json.dump(settings, j) json.dump(settings, j)
if self.stream is not None: if self.stream is not None:
self.delay_time = ( self.delay_time = (

View File

@ -131,16 +131,21 @@ class ToolButton(gr.Button, gr.components.FormComponent):
weight_root = os.getenv("weight_root") weight_root = os.getenv("weight_root")
weight_uvr5_root = os.getenv("weight_uvr5_root") weight_uvr5_root = os.getenv("weight_uvr5_root")
index_root = os.getenv("index_root") index_root = os.getenv("index_root")
outside_index_root = os.getenv("outside_index_root")
names = [] names = []
for name in os.listdir(weight_root): for name in os.listdir(weight_root):
if name.endswith(".pth"): if name.endswith(".pth"):
names.append(name) names.append(name)
index_paths = [] index_paths = []
for root, dirs, files in os.walk(index_root, topdown=False): def lookup_indices(index_root):
global index_paths
for root, dirs, files in os.walk(index_root, topdown=False):
for name in files: for name in files:
if name.endswith(".index") and "trained" not in name: if name.endswith(".index") and "trained" not in name:
index_paths.append("%s/%s" % (root, name)) index_paths.append("%s/%s" % (root, name))
lookup_indices(index_root)
lookup_indices(outside_index_root)
uvr5_names = [] uvr5_names = []
for name in os.listdir(weight_uvr5_root): for name in os.listdir(weight_uvr5_root):
if name.endswith(".pth") or "onnx" in name: if name.endswith(".pth") or "onnx" in name:
@ -658,6 +663,23 @@ def train_index(exp_dir1, version19):
"%s/trained_IVF%s_Flat_nprobe_%s_%s_%s.index" "%s/trained_IVF%s_Flat_nprobe_%s_%s_%s.index"
% (exp_dir, n_ivf, index_ivf.nprobe, exp_dir1, version19), % (exp_dir, n_ivf, index_ivf.nprobe, exp_dir1, version19),
) )
try:
os.link(
"%s/trained_IVF%s_Flat_nprobe_%s_%s_%s.index"
% (exp_dir, n_ivf, index_ivf.nprobe, exp_dir1, version19),
"%s/%s_IVF%s_Flat_nprobe_%s_%s_%s.index"
% (outside_index_root, exp_dir, n_ivf, index_ivf.nprobe, exp_dir1, version19),
)
infos.append(
"链接索引到%s"
% (outside_index_root)
)
except:
infos.append(
"链接索引到%s失败"
% (outside_index_root)
)
infos.append("adding") infos.append("adding")
yield "\n".join(infos) yield "\n".join(infos)
@ -670,7 +692,7 @@ def train_index(exp_dir1, version19):
% (exp_dir, n_ivf, index_ivf.nprobe, exp_dir1, version19), % (exp_dir, n_ivf, index_ivf.nprobe, exp_dir1, version19),
) )
infos.append( infos.append(
"成功构建索引added_IVF%s_Flat_nprobe_%s_%s_%s.index" "成功构建索引 added_IVF%s_Flat_nprobe_%s_%s_%s.index"
% (n_ivf, index_ivf.nprobe, exp_dir1, version19) % (n_ivf, index_ivf.nprobe, exp_dir1, version19)
) )
# faiss.write_index(index, '%s/added_IVF%s_Flat_FastScan_%s.index'%(exp_dir,n_ivf,version19)) # faiss.write_index(index, '%s/added_IVF%s_Flat_FastScan_%s.index'%(exp_dir,n_ivf,version19))