feat: add support for ascend npu

This commit is contained in:
statelesshz 2024-05-14 09:07:11 +08:00
parent b0fca77ea0
commit 1d1ccb6e4e

View File

@ -16,6 +16,10 @@ try:
ipex_init() ipex_init()
except Exception: # pylint: disable=broad-exception-caught except Exception: # pylint: disable=broad-exception-caught
pass pass
try:
import torch_npu # pylint: disable=unused-import
except Exception:
pass
import logging import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -127,6 +131,13 @@ class Config:
return True return True
except Exception: except Exception:
return False return False
@staticmethod
def has_npu() -> bool:
if hasattr(torch, "npu") and torch.npu.is_available():
return True
else:
return False
@staticmethod @staticmethod
def has_xpu() -> bool: def has_xpu() -> bool:
@ -175,6 +186,21 @@ class Config:
) )
if self.gpu_mem <= 4: if self.gpu_mem <= 4:
self.preprocess_per = 3.0 self.preprocess_per = 3.0
elif self.has_npu():
self.device = self.instead = "npu:0"
self.is_half = True
i_device = int(self.device.split(":")[-1])
self.gpu_name = torch.npu.get_device_name(i_device)
logger.info("Found NPU %s", self.gpu_name)
self.gpu_mem = int(
torch.npu.get_device_properties(i_device).total_memory
/ 1024
/ 1024
/ 1024
+ 0.4
)
if self.gpu_mem <= 4:
self.preprocess_per = 3.0
elif self.has_mps(): elif self.has_mps():
logger.info("No supported Nvidia GPU found") logger.info("No supported Nvidia GPU found")
self.device = self.instead = "mps" self.device = self.instead = "mps"