mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-02-01 03:03:00 +08:00
commit
96b550430a
@ -3,7 +3,7 @@ import contextlib
|
|||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from modules import errors, shared
|
from modules import errors, shared, npu_specific
|
||||||
|
|
||||||
if sys.platform == "darwin":
|
if sys.platform == "darwin":
|
||||||
from modules import mac_specific
|
from modules import mac_specific
|
||||||
@ -57,6 +57,9 @@ def get_optimal_device_name():
|
|||||||
if has_xpu():
|
if has_xpu():
|
||||||
return xpu_specific.get_xpu_device_string()
|
return xpu_specific.get_xpu_device_string()
|
||||||
|
|
||||||
|
if npu_specific.has_npu:
|
||||||
|
return npu_specific.get_npu_device_string()
|
||||||
|
|
||||||
return "cpu"
|
return "cpu"
|
||||||
|
|
||||||
|
|
||||||
@ -84,6 +87,16 @@ def torch_gc():
|
|||||||
if has_xpu():
|
if has_xpu():
|
||||||
xpu_specific.torch_xpu_gc()
|
xpu_specific.torch_xpu_gc()
|
||||||
|
|
||||||
|
if npu_specific.has_npu:
|
||||||
|
torch_npu_set_device()
|
||||||
|
npu_specific.torch_npu_gc()
|
||||||
|
|
||||||
|
|
||||||
|
def torch_npu_set_device():
|
||||||
|
# Work around due to bug in torch_npu, revert me after fixed, @see https://gitee.com/ascend/pytorch/issues/I8KECW?from=project-issue
|
||||||
|
if npu_specific.has_npu:
|
||||||
|
torch.npu.set_device(0)
|
||||||
|
|
||||||
|
|
||||||
def enable_tf32():
|
def enable_tf32():
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
@ -256,4 +269,3 @@ def first_time_calculation():
|
|||||||
x = torch.zeros((1, 1, 3, 3)).to(device, dtype)
|
x = torch.zeros((1, 1, 3, 3)).to(device, dtype)
|
||||||
conv2d = torch.nn.Conv2d(1, 1, (3, 3)).to(device, dtype)
|
conv2d = torch.nn.Conv2d(1, 1, (3, 3)).to(device, dtype)
|
||||||
conv2d(x)
|
conv2d(x)
|
||||||
|
|
||||||
|
@ -142,13 +142,14 @@ def initialize_rest(*, reload_script_modules=False):
|
|||||||
its optimization may be None because the list of optimizaers has neet been filled
|
its optimization may be None because the list of optimizaers has neet been filled
|
||||||
by that time, so we apply optimization again.
|
by that time, so we apply optimization again.
|
||||||
"""
|
"""
|
||||||
|
from modules import devices
|
||||||
|
devices.torch_npu_set_device()
|
||||||
|
|
||||||
shared.sd_model # noqa: B018
|
shared.sd_model # noqa: B018
|
||||||
|
|
||||||
if sd_hijack.current_optimizer is None:
|
if sd_hijack.current_optimizer is None:
|
||||||
sd_hijack.apply_optimizations()
|
sd_hijack.apply_optimizations()
|
||||||
|
|
||||||
from modules import devices
|
|
||||||
devices.first_time_calculation()
|
devices.first_time_calculation()
|
||||||
if not shared.cmd_opts.skip_load_model_at_start:
|
if not shared.cmd_opts.skip_load_model_at_start:
|
||||||
Thread(target=load_model).start()
|
Thread(target=load_model).start()
|
||||||
|
@ -338,6 +338,7 @@ def prepare_environment():
|
|||||||
torch_index_url = os.environ.get('TORCH_INDEX_URL', "https://pytorch-extension.intel.com/release-whl/stable/xpu/us/")
|
torch_index_url = os.environ.get('TORCH_INDEX_URL', "https://pytorch-extension.intel.com/release-whl/stable/xpu/us/")
|
||||||
torch_command = os.environ.get('TORCH_COMMAND', f"pip install torch==2.0.0a0 intel-extension-for-pytorch==2.0.110+gitba7f6c1 --extra-index-url {torch_index_url}")
|
torch_command = os.environ.get('TORCH_COMMAND', f"pip install torch==2.0.0a0 intel-extension-for-pytorch==2.0.110+gitba7f6c1 --extra-index-url {torch_index_url}")
|
||||||
requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt")
|
requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt")
|
||||||
|
requirements_file_for_npu = os.environ.get('REQS_FILE_FOR_NPU', "requirements_npu.txt")
|
||||||
|
|
||||||
xformers_package = os.environ.get('XFORMERS_PACKAGE', 'xformers==0.0.23.post1')
|
xformers_package = os.environ.get('XFORMERS_PACKAGE', 'xformers==0.0.23.post1')
|
||||||
clip_package = os.environ.get('CLIP_PACKAGE', "https://github.com/openai/CLIP/archive/d50d76daa670286dd6cacf3bcd80b5e4823fc8e1.zip")
|
clip_package = os.environ.get('CLIP_PACKAGE', "https://github.com/openai/CLIP/archive/d50d76daa670286dd6cacf3bcd80b5e4823fc8e1.zip")
|
||||||
@ -421,6 +422,13 @@ def prepare_environment():
|
|||||||
run_pip(f"install -r \"{requirements_file}\"", "requirements")
|
run_pip(f"install -r \"{requirements_file}\"", "requirements")
|
||||||
startup_timer.record("install requirements")
|
startup_timer.record("install requirements")
|
||||||
|
|
||||||
|
if not os.path.isfile(requirements_file_for_npu):
|
||||||
|
requirements_file_for_npu = os.path.join(script_path, requirements_file_for_npu)
|
||||||
|
|
||||||
|
if "torch_npu" in torch_command and not requirements_met(requirements_file_for_npu):
|
||||||
|
run_pip(f"install -r \"{requirements_file_for_npu}\"", "requirements_for_npu")
|
||||||
|
startup_timer.record("install requirements_for_npu")
|
||||||
|
|
||||||
if not args.skip_install:
|
if not args.skip_install:
|
||||||
run_extensions_installers(settings_file=args.ui_settings_file)
|
run_extensions_installers(settings_file=args.ui_settings_file)
|
||||||
|
|
||||||
|
31
modules/npu_specific.py
Normal file
31
modules/npu_specific.py
Normal file
@ -0,0 +1,31 @@
|
|||||||
|
import importlib
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from modules import shared
|
||||||
|
|
||||||
|
|
||||||
|
def check_for_npu():
|
||||||
|
if importlib.util.find_spec("torch_npu") is None:
|
||||||
|
return False
|
||||||
|
import torch_npu
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Will raise a RuntimeError if no NPU is found
|
||||||
|
_ = torch_npu.npu.device_count()
|
||||||
|
return torch.npu.is_available()
|
||||||
|
except RuntimeError:
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def get_npu_device_string():
|
||||||
|
if shared.cmd_opts.device_id is not None:
|
||||||
|
return f"npu:{shared.cmd_opts.device_id}"
|
||||||
|
return "npu:0"
|
||||||
|
|
||||||
|
|
||||||
|
def torch_npu_gc():
|
||||||
|
with torch.npu.device(get_npu_device_string()):
|
||||||
|
torch.npu.empty_cache()
|
||||||
|
|
||||||
|
|
||||||
|
has_npu = check_for_npu()
|
@ -150,6 +150,7 @@ class EmbeddingDatabase:
|
|||||||
return embedding
|
return embedding
|
||||||
|
|
||||||
def get_expected_shape(self):
|
def get_expected_shape(self):
|
||||||
|
devices.torch_npu_set_device()
|
||||||
vec = shared.sd_model.cond_stage_model.encode_embedding_init_text(",", 1)
|
vec = shared.sd_model.cond_stage_model.encode_embedding_init_text(",", 1)
|
||||||
return vec.shape[1]
|
return vec.shape[1]
|
||||||
|
|
||||||
|
4
requirements_npu.txt
Normal file
4
requirements_npu.txt
Normal file
@ -0,0 +1,4 @@
|
|||||||
|
cloudpickle
|
||||||
|
decorator
|
||||||
|
synr==0.5.0
|
||||||
|
tornado
|
4
webui.sh
4
webui.sh
@ -158,6 +158,10 @@ then
|
|||||||
if echo "$gpu_info" | grep -q "AMD" && [[ -z "${TORCH_COMMAND}" ]]
|
if echo "$gpu_info" | grep -q "AMD" && [[ -z "${TORCH_COMMAND}" ]]
|
||||||
then
|
then
|
||||||
export TORCH_COMMAND="pip install torch==2.0.1+rocm5.4.2 torchvision==0.15.2+rocm5.4.2 --index-url https://download.pytorch.org/whl/rocm5.4.2"
|
export TORCH_COMMAND="pip install torch==2.0.1+rocm5.4.2 torchvision==0.15.2+rocm5.4.2 --index-url https://download.pytorch.org/whl/rocm5.4.2"
|
||||||
|
elif echo "$gpu_info" | grep -q "Huawei" && [[ -z "${TORCH_COMMAND}" ]]
|
||||||
|
then
|
||||||
|
export TORCH_COMMAND="pip install torch==2.1.0 torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu; pip install torch_npu"
|
||||||
|
|
||||||
fi
|
fi
|
||||||
fi
|
fi
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user