2024-01-27 17:21:32 +08:00
|
|
|
import importlib
|
|
|
|
import torch
|
|
|
|
|
|
|
|
from modules import shared
|
|
|
|
|
|
|
|
|
|
|
|
def check_for_npu():
    """Report whether the ``torch_npu`` backend is installed and an NPU is usable.

    Returns:
        False if the ``torch_npu`` package cannot be found.
        False if querying the NPU device count raises ``RuntimeError``.
        Otherwise, the result of ``torch.npu.is_available()``.
    """
    # ``import importlib`` alone does not guarantee that the ``importlib.util``
    # submodule is loaded and bound on the package, so import it explicitly
    # before calling ``find_spec``.
    import importlib.util

    if importlib.util.find_spec("torch_npu") is None:
        return False

    # Imported lazily so the module can be loaded on machines without torch_npu.
    import torch_npu

    try:
        # Will raise a RuntimeError if no NPU is found
        _ = torch_npu.npu.device_count()
        return torch.npu.is_available()
    except RuntimeError:
        return False
|
|
|
|
|
|
|
|
|
|
|
|
def get_npu_device_string():
    """Build the torch device string for the NPU to use.

    Uses ``shared.cmd_opts.device_id`` as the device index when it is set,
    and falls back to index 0 when it is ``None``.

    Returns:
        A string of the form ``"npu:<index>"``.
    """
    selected_id = shared.cmd_opts.device_id
    if selected_id is None:
        return "npu:0"
    return f"npu:{selected_id}"
|
|
|
|
|
|
|
|
|
|
|
|
def torch_npu_gc():
    """Call ``torch.npu.empty_cache()`` with the configured NPU as the current device.

    The device is selected via ``get_npu_device_string()``.
    """
    target_device = get_npu_device_string()
    with torch.npu.device(target_device):
        torch.npu.empty_cache()
|
|
|
|
|
|
|
|
|
|
|
|
# Evaluated once, when this module is imported. It is False when torch_npu is
# missing or reports no device; otherwise it is torch.npu.is_available().
has_npu: bool = check_for_npu()
|