mirror of
https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI.git
synced 2025-05-07 04:09:06 +08:00
feat: add support for ascend npu
This commit is contained in:
parent
b0fca77ea0
commit
1d1ccb6e4e
@ -16,6 +16,10 @@ try:
|
|||||||
ipex_init()
|
ipex_init()
|
||||||
except Exception: # pylint: disable=broad-exception-caught
|
except Exception: # pylint: disable=broad-exception-caught
|
||||||
pass
|
pass
|
||||||
|
try:
|
||||||
|
import torch_npu # pylint: disable=unused-import
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@ -128,6 +132,13 @@ class Config:
|
|||||||
except Exception:
|
except Exception:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def has_npu() -> bool:
|
||||||
|
if hasattr(torch, "npu") and torch.npu.is_available():
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def has_xpu() -> bool:
|
def has_xpu() -> bool:
|
||||||
if hasattr(torch, "xpu") and torch.xpu.is_available():
|
if hasattr(torch, "xpu") and torch.xpu.is_available():
|
||||||
@ -175,6 +186,21 @@ class Config:
|
|||||||
)
|
)
|
||||||
if self.gpu_mem <= 4:
|
if self.gpu_mem <= 4:
|
||||||
self.preprocess_per = 3.0
|
self.preprocess_per = 3.0
|
||||||
|
elif self.has_npu():
|
||||||
|
self.device = self.instead = "npu:0"
|
||||||
|
self.is_half = True
|
||||||
|
i_device = int(self.device.split(":")[-1])
|
||||||
|
self.gpu_name = torch.npu.get_device_name(i_device)
|
||||||
|
logger.info("Found NPU %s", self.gpu_name)
|
||||||
|
self.gpu_mem = int(
|
||||||
|
torch.npu.get_device_properties(i_device).total_memory
|
||||||
|
/ 1024
|
||||||
|
/ 1024
|
||||||
|
/ 1024
|
||||||
|
+ 0.4
|
||||||
|
)
|
||||||
|
if self.gpu_mem <= 4:
|
||||||
|
self.preprocess_per = 3.0
|
||||||
elif self.has_mps():
|
elif self.has_mps():
|
||||||
logger.info("No supported Nvidia GPU found")
|
logger.info("No supported Nvidia GPU found")
|
||||||
self.device = self.instead = "mps"
|
self.device = self.instead = "mps"
|
||||||
|
Loading…
x
Reference in New Issue
Block a user