diff --git a/configs/config.py b/configs/config.py index c8aaff6..d9a9748 100644 --- a/configs/config.py +++ b/configs/config.py @@ -16,6 +16,10 @@ try: ipex_init() except Exception: # pylint: disable=broad-exception-caught pass +try: + import torch_npu # pylint: disable=unused-import +except Exception: + pass import logging logger = logging.getLogger(__name__) @@ -127,6 +131,13 @@ class Config: return True except Exception: return False + + @staticmethod + def has_npu() -> bool: + if hasattr(torch, "npu") and torch.npu.is_available(): + return True + else: + return False @staticmethod def has_xpu() -> bool: @@ -175,6 +186,21 @@ class Config: ) if self.gpu_mem <= 4: self.preprocess_per = 3.0 + elif self.has_npu(): + self.device = self.instead = "npu:0" + self.is_half = True + i_device = int(self.device.split(":")[-1]) + self.gpu_name = torch.npu.get_device_name(i_device) + logger.info("Found NPU %s", self.gpu_name) + self.gpu_mem = int( + torch.npu.get_device_properties(i_device).total_memory + / 1024 + / 1024 + / 1024 + + 0.4 + ) + if self.gpu_mem <= 4: + self.preprocess_per = 3.0 elif self.has_mps(): logger.info("No supported Nvidia GPU found") self.device = self.instead = "mps"