From ad6a6e59c29501f392f6c5173606049453b762f3 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Mon, 3 Jun 2024 08:13:55 +0000 Subject: [PATCH] chore(format): run black on dev --- infer-web.py | 43 ++++++++++++++++++++++++--------- infer/lib/train/process_ckpt.py | 20 +++++++++------ infer/lib/train/utils.py | 6 ++--- infer/modules/vc/modules.py | 3 ++- 4 files changed, 50 insertions(+), 22 deletions(-) diff --git a/infer-web.py b/infer-web.py index 5148fe0..14adfd6 100644 --- a/infer-web.py +++ b/infer-web.py @@ -569,7 +569,9 @@ def click_train( sort_keys=True, ) f.write("\n") - cmd = '"%s" infer/modules/train/train.py -e "%s" -sr %s -f0 %s -bs %s -te %s -se %s %s %s -l %s -c %s -sw %s -v %s -a "%s"' % ( + cmd = ( + '"%s" infer/modules/train/train.py -e "%s" -sr %s -f0 %s -bs %s -te %s -se %s %s %s -l %s -c %s -sw %s -v %s -a "%s"' + % ( config.python_cmd, exp_dir1, sr2, @@ -583,8 +585,9 @@ def click_train( 1 if if_cache_gpu17 == i18n("是") else 0, 1 if if_save_every_weights18 == i18n("是") else 0, version19, - author + author, ) + ) if gpus16: cmd += '-g "%s"' % (gpus16) @@ -713,7 +716,7 @@ def train1key( if_save_every_weights18, version19, gpus_rmvpe, - author + author, ): infos = [] @@ -751,7 +754,7 @@ def train1key( if_cache_gpu17, if_save_every_weights18, version19, - author + author, ) yield get_info_str( i18n("训练结束, 您可查看控制台训练日志或实验文件夹下的train.log") @@ -1245,7 +1248,9 @@ with gr.Blocks(title="RVC WebUI") as app: ) with gr.Column(): but2 = gr.Button(i18n("特征提取"), variant="primary") - info2 = gr.Textbox(label=i18n("输出信息"), value="", max_lines=8) + info2 = gr.Textbox( + label=i18n("输出信息"), value="", max_lines=8 + ) f0method8.change( fn=change_f0_method, inputs=[f0method8], @@ -1266,7 +1271,9 @@ with gr.Blocks(title="RVC WebUI") as app: api_name="train_extract_f0_feature", ) with gr.Group(): - gr.Markdown(value=i18n("### 第三步 开始训练\n填写训练设置, 开始训练模型和索引.")) + gr.Markdown( + value=i18n("### 第三步 开始训练\n填写训练设置, 开始训练模型和索引.") + ) with gr.Row(): with gr.Column(): save_epoch10 = gr.Slider( @@ -1348,7 +1355,7 @@ with gr.Blocks(title="RVC WebUI") as app: [if_f0_3, sr2, version19], [f0method8, gpus_rmvpe, pretrained_G14, pretrained_D15], ) - + but3 = gr.Button(i18n("训练模型"), variant="primary") but4 = gr.Button(i18n("训练特征索引"), variant="primary") but5 = gr.Button(i18n("一键训练"), variant="primary") @@ -1521,7 +1528,9 @@ with gr.Blocks(title="RVC WebUI") as app: ) with gr.Column(): but7 = gr.Button(i18n("修改"), variant="primary") - info5 = gr.Textbox(label=i18n("输出信息"), value="", max_lines=8) + info5 = gr.Textbox( + label=i18n("输出信息"), value="", max_lines=8 + ) but7.click( change_info, [ckpt_path0, info_, name_to_save1], @@ -1541,7 +1550,9 @@ with gr.Blocks(title="RVC WebUI") as app: ) but8 = gr.Button(i18n("查看"), variant="primary") with gr.Column(): - info6 = gr.Textbox(label=i18n("输出信息"), value="", max_lines=8) + info6 = gr.Textbox( + label=i18n("输出信息"), value="", max_lines=8 + ) but8.click(show_info, [ckpt_path1], info6, api_name="ckpt_show") with gr.Group(): gr.Markdown( @@ -1592,13 +1603,23 @@ with gr.Blocks(title="RVC WebUI") as app: ) with gr.Column(): but9 = gr.Button(i18n("提取"), variant="primary") - info7 = gr.Textbox(label=i18n("输出信息"), value="", max_lines=8) + info7 = gr.Textbox( + label=i18n("输出信息"), value="", max_lines=8 + ) ckpt_path2.change( change_info_, [ckpt_path2], [sr__, if_f0__, version_1] ) but9.click( extract_small_model, - [ckpt_path2, save_name, extauthor, sr__, if_f0__, info___, version_1], + [ + ckpt_path2, + save_name, + extauthor, + sr__, + if_f0__, + info___, + version_1, + ], info7, api_name="ckpt_extract", ) diff --git a/infer/lib/train/process_ckpt.py b/infer/lib/train/process_ckpt.py index 7c1b3d4..1dd37cb 100644 --- a/infer/lib/train/process_ckpt.py +++ b/infer/lib/train/process_ckpt.py @@ -43,7 +43,8 @@ def save_small_model(ckpt, sr, if_f0, name, epoch, version, hps): opt["info"] = "%sepoch" % epoch opt["name"] = name opt["timestamp"] = int(time()) - if hps.author: opt["author"] = hps.author + if hps.author: + opt["author"] = hps.author opt["sr"] = sr opt["f0"] = if_f0 opt["version"] = version @@ -179,7 +180,8 @@ def extract_small_model(path, name, author, sr, if_f0, info, version): opt["info"] = info opt["name"] = name opt["timestamp"] = int(time()) - if author: opt["author"] = author + if author: + opt["author"] = author opt["version"] = version opt["sr"] = sr opt["f0"] = int(if_f0) @@ -216,12 +218,15 @@ def merge(path1, path2, alpha1, sr, f0, info, name, version): continue opt["weight"][key] = a[key] return opt - + def authors(c1, c2): a1, a2 = c1.get("author", ""), c2.get("author", "") - if a1 == a2: return a1 - if not a1: a1 = "Unknown" - if not a2: a2 = "Unknown" + if a1 == a2: + return a1 + if not a1: + a1 = "Unknown" + if not a2: + a2 = "Unknown" return f"{a1} & {a2}" ckpt1 = torch.load(path1, map_location="cpu") @@ -260,7 +265,8 @@ def merge(path1, path2, alpha1, sr, f0, info, name, version): """ opt["name"] = name opt["timestamp"] = int(time()) - if author: opt["author"] = author + if author: + opt["author"] = author opt["sr"] = sr opt["f0"] = 1 if f0 == i18n("是") else 0 opt["version"] = version diff --git a/infer/lib/train/utils.py b/infer/lib/train/utils.py index e55a4df..d16756a 100644 --- a/infer/lib/train/utils.py +++ b/infer/lib/train/utils.py @@ -358,9 +358,7 @@ def get_hparams(init=True): required=True, help="if caching the dataset in GPU memory, 1 or 0", ) - parser.add_argument( - "-a", "--author", type=str, default="", help="Model author" - ) + parser.add_argument("-a", "--author", type=str, default="", help="Model author") args = parser.parse_args() name = args.experiment_dir @@ -389,6 +387,7 @@ def get_hparams(init=True): hparams.author = args.author return hparams + """ def get_hparams_from_dir(model_dir): config_save_path = os.path.join(model_dir, "config.json") @@ -435,6 +434,7 @@ def check_git_hash(model_dir): open(path, "w").write(cur_hash) """ + def get_logger(model_dir, filename="train.log"): global logger logger = logging.getLogger(os.path.basename(model_dir)) diff --git a/infer/modules/vc/modules.py b/infer/modules/vc/modules.py index c31e6eb..e4dd871 100644 --- a/infer/modules/vc/modules.py +++ b/infer/modules/vc/modules.py @@ -96,7 +96,8 @@ class VC: {"value": to_return_protect[2], "__type__": "update"}, {"value": to_return_protect[3], "__type__": "update"}, {"value": "", "__type__": "update"}, - ) if to_return_protect + ) + if to_return_protect else {"visible": True, "maximum": 0, "__type__": "update"} ) person = f'{os.getenv("weight_root")}/{sid}'