diff --git a/config.py b/config.py index 1082d2a..c95feff 100644 --- a/config.py +++ b/config.py @@ -50,7 +50,6 @@ class Config: def device_config(self) -> tuple: if torch.cuda.is_available(): - self.gpu_name = torch.cuda.get_device_name(int(self.device.split(":")[-1])) i_device = int(self.device.split(":")[-1]) self.gpu_name = torch.cuda.get_device_name(i_device) if ( @@ -76,6 +75,8 @@ class Config: with open("trainset_preprocess_pipeline_print.py", "r+") as f: strr = f.read().replace("3.7", "3.0") f.write(strr) + else: + self.gpu_name = None elif torch.backends.mps.is_available(): print("没有发现支持的N卡, 使用MPS进行推理") self.device = "mps"