Add NPU Support

This commit is contained in:
wangshuai09 2024-01-27 17:21:32 +08:00
parent cf2772fab0
commit ec124607f4
7 changed files with 62 additions and 3 deletions

View File

@ -3,7 +3,7 @@ import contextlib
from functools import lru_cache from functools import lru_cache
import torch import torch
from modules import errors, shared from modules import errors, shared, npu_specific
if sys.platform == "darwin": if sys.platform == "darwin":
from modules import mac_specific from modules import mac_specific
@ -40,6 +40,9 @@ def get_optimal_device_name():
if has_xpu(): if has_xpu():
return xpu_specific.get_xpu_device_string() return xpu_specific.get_xpu_device_string()
if npu_specific.has_npu:
return npu_specific.get_npu_device_string()
return "cpu" return "cpu"
@ -67,6 +70,9 @@ def torch_gc():
if has_xpu(): if has_xpu():
xpu_specific.torch_xpu_gc() xpu_specific.torch_xpu_gc()
if npu_specific.has_npu:
npu_specific.torch_npu_gc()
def enable_tf32(): def enable_tf32():
if torch.cuda.is_available(): if torch.cuda.is_available():
@ -164,4 +170,3 @@ def first_time_calculation():
x = torch.zeros((1, 1, 3, 3)).to(device, dtype) x = torch.zeros((1, 1, 3, 3)).to(device, dtype)
conv2d = torch.nn.Conv2d(1, 1, (3, 3)).to(device, dtype) conv2d = torch.nn.Conv2d(1, 1, (3, 3)).to(device, dtype)
conv2d(x) conv2d(x)

View File

@ -143,13 +143,17 @@ def initialize_rest(*, reload_script_modules=False):
its optimization may be None because the list of optimizaers has neet been filled its optimization may be None because the list of optimizaers has neet been filled
by that time, so we apply optimization again. by that time, so we apply optimization again.
""" """
from modules import devices
# Work around due to bug in torch_npu, revert me after fixed, @see https://gitee.com/ascend/pytorch/issues/I8KECW?from=project-issue
if devices.npu_specific.has_npu:
import torch
torch.npu.set_device(0)
shared.sd_model # noqa: B018 shared.sd_model # noqa: B018
if sd_hijack.current_optimizer is None: if sd_hijack.current_optimizer is None:
sd_hijack.apply_optimizations() sd_hijack.apply_optimizations()
from modules import devices
devices.first_time_calculation() devices.first_time_calculation()
if not shared.cmd_opts.skip_load_model_at_start: if not shared.cmd_opts.skip_load_model_at_start:
Thread(target=load_model).start() Thread(target=load_model).start()

34
modules/npu_specific.py Normal file
View File

@ -0,0 +1,34 @@
import importlib
import torch
from modules import shared
def check_for_npu():
if importlib.util.find_spec("torch_npu") is None:
return False
import torch_npu
torch_npu.npu.set_device(0)
try:
# Will raise a RuntimeError if no NPU is found
_ = torch.npu.device_count()
return torch.npu.is_available()
except RuntimeError:
return False
def get_npu_device_string():
if shared.cmd_opts.device_id is not None:
return f"npu:{shared.cmd_opts.device_id}"
return "npu:0"
def torch_npu_gc():
# Work around due to bug in torch_npu, revert me after fixed, @see https://gitee.com/ascend/pytorch/issues/I8KECW?from=project-issue
torch.npu.set_device(0)
with torch.npu.device(get_npu_device_string()):
torch.npu.empty_cache()
has_npu = check_for_npu()

View File

@ -151,6 +151,10 @@ class EmbeddingDatabase:
return embedding return embedding
def get_expected_shape(self): def get_expected_shape(self):
# workaround
if devices.npu_specific.has_npu:
import torch
torch.npu.set_device(0)
vec = shared.sd_model.cond_stage_model.encode_embedding_init_text(",", 1) vec = shared.sd_model.cond_stage_model.encode_embedding_init_text(",", 1)
return vec.shape[1] return vec.shape[1]

View File

@ -5,6 +5,8 @@ accelerate
basicsr basicsr
blendmodes blendmodes
clean-fid clean-fid
cloudpickle
decorator
einops einops
fastapi>=0.90.1 fastapi>=0.90.1
gfpgan gfpgan
@ -26,9 +28,11 @@ resize-right
safetensors safetensors
scikit-image>=0.19 scikit-image>=0.19
synr==0.5.0
timm timm
tomesd tomesd
torch torch
torchdiffeq torchdiffeq
torchsde torchsde
tornado
transformers==4.30.2 transformers==4.30.2

View File

@ -4,6 +4,8 @@ accelerate==0.21.0
basicsr==1.4.2 basicsr==1.4.2
blendmodes==2022 blendmodes==2022
clean-fid==0.1.35 clean-fid==0.1.35
cloudpickle==3.0.0
decorator==5.1.1
einops==0.4.1 einops==0.4.1
fastapi==0.94.0 fastapi==0.94.0
gfpgan==1.3.8 gfpgan==1.3.8
@ -23,10 +25,12 @@ realesrgan==0.3.0
resize-right==0.0.2 resize-right==0.0.2
safetensors==0.3.1 safetensors==0.3.1
scikit-image==0.21.0 scikit-image==0.21.0
synr==0.5.0
timm==0.9.2 timm==0.9.2
tomesd==0.1.3 tomesd==0.1.3
torch torch
torchdiffeq==0.2.3 torchdiffeq==0.2.3
torchsde==0.2.6 torchsde==0.2.6
tornado==6.4
transformers==4.30.2 transformers==4.30.2
httpx==0.24.1 httpx==0.24.1

View File

@ -159,6 +159,10 @@ then
if echo "$gpu_info" | grep -q "AMD" && [[ -z "${TORCH_COMMAND}" ]] if echo "$gpu_info" | grep -q "AMD" && [[ -z "${TORCH_COMMAND}" ]]
then then
export TORCH_COMMAND="pip install torch==2.0.1+rocm5.4.2 torchvision==0.15.2+rocm5.4.2 --index-url https://download.pytorch.org/whl/rocm5.4.2" export TORCH_COMMAND="pip install torch==2.0.1+rocm5.4.2 torchvision==0.15.2+rocm5.4.2 --index-url https://download.pytorch.org/whl/rocm5.4.2"
elif echo "$gpu_info" | grep -q "Huawei" && [[ -z "${TORCH_COMMAND}" ]]
then
export TORCH_COMMAND="pip install torch==2.1.0 torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu; pip install torch_npu"
fi fi
fi fi