From 4e0d399cba10d43f6494abdc6509504c44b6b767 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=BA=90=E6=96=87=E9=9B=A8?= <41315874+fumiama@users.noreply.github.com> Date: Sat, 24 Jun 2023 13:56:09 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BC=98=E5=8C=96=20config.py?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- config.py | 27 +++++++++++++++++++-------- 1 file changed, 19 insertions(+), 8 deletions(-) diff --git a/config.py b/config.py index 48187f5..a59f738 100644 --- a/config.py +++ b/config.py @@ -3,7 +3,7 @@ import torch from multiprocessing import cpu_count -def config_file_change_fp32(): +def use_fp32_config(): for config_file in ["32k.json", "40k.json", "48k.json"]: with open(f"configs/{config_file}", "r") as f: strr = f.read().replace("true", "false") @@ -58,6 +58,17 @@ class Config: cmd_opts.noparallel, cmd_opts.noautoopen, ) + + # has_mps is only available in nightly pytorch (for now) and MasOS 12.3+. + # check `getattr` and try it for compatibility + @staticmethod + def has_mps() -> bool: + if not torch.backends.mps.is_available(): return False + try: + torch.zeros(1).to(torch.device("mps")) + return True + except Exception: + return False def device_config(self) -> tuple: if torch.cuda.is_available(): @@ -70,9 +81,9 @@ class Config: or "1070" in self.gpu_name or "1080" in self.gpu_name ): - print("16系/10系显卡和P40强制单精度") + print("16|10|P40 series, force to fp32") self.is_half = False - config_file_change_fp32() + use_fp32_config() else: self.gpu_name = None self.gpu_mem = int( @@ -87,16 +98,16 @@ class Config: strr = f.read().replace("3.7", "3.0") with open("trainset_preprocess_pipeline_print.py", "w") as f: f.write(strr) - elif torch.backends.mps.is_available(): - print("没有发现支持的N卡, 使用MPS进行推理") + elif self.has_mps(): + print("No supported Nvidia GPU, use MPS instead") self.device = "mps" self.is_half = False - config_file_change_fp32() + use_fp32_config() else: - print("没有发现支持的N卡, 使用CPU进行推理") + print("No supported Nvidia GPU, use CPU instead") self.device = "cpu" self.is_half = False - config_file_change_fp32() + use_fp32_config() if self.n_cpu == 0: self.n_cpu = cpu_count()