diff --git a/config.py b/config.py index 7ea4c06..8e6fa22 100644 --- a/config.py +++ b/config.py @@ -2,6 +2,16 @@ import argparse import torch from multiprocessing import cpu_count +def config_file_change_fp32(): + for config_file in ["32k.json", "40k.json", "48k.json"]: + with open(f"configs/{config_file}", "r") as f: + strr = f.read().replace("true", "false") + with open(f"configs/{config_file}", "w") as f: + f.write(strr) + with open("trainset_preprocess_pipeline_print.py", "r") as f: + strr = f.read().replace("3.7", "3.0") + with open("trainset_preprocess_pipeline_print.py", "w") as f: + f.write(strr) class Config: def __init__(self): @@ -60,15 +70,7 @@ class Config: ): print("16系/10系显卡和P40强制单精度") self.is_half = False - for config_file in ["32k.json", "40k.json", "48k.json"]: - with open(f"configs/{config_file}", "r") as f: - strr = f.read().replace("true", "false") - with open(f"configs/{config_file}", "w") as f: - f.write(strr) - with open("trainset_preprocess_pipeline_print.py", "r") as f: - strr = f.read().replace("3.7", "3.0") - with open("trainset_preprocess_pipeline_print.py", "w") as f: - f.write(strr) + config_file_change_fp32() else: self.gpu_name = None self.gpu_mem = int( @@ -87,10 +89,12 @@ class Config: print("没有发现支持的N卡, 使用MPS进行推理") self.device = "mps" self.is_half = False + config_file_change_fp32() else: print("没有发现支持的N卡, 使用CPU进行推理") self.device = "cpu" self.is_half = False + config_file_change_fp32() if self.n_cpu == 0: self.n_cpu = cpu_count()