Merge pull request #14981 from wangshuai09/gpu_info_for_ascend

Add training support and change lspci for Ascend NPU
This commit is contained in:
AUTOMATIC1111 2024-03-04 20:06:54 +03:00 committed by GitHub
commit eee46a5094
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 4 additions and 3 deletions

View File

@ -95,6 +95,7 @@ class HypernetworkModule(torch.nn.Module):
zeros_(b) zeros_(b)
else: else:
raise KeyError(f"Key {weight_init} is not defined as initialization!") raise KeyError(f"Key {weight_init} is not defined as initialization!")
devices.torch_npu_set_device()
self.to(devices.device) self.to(devices.device)
def fix_old_state_dict(self, state_dict): def fix_old_state_dict(self, state_dict):

View File

@ -230,7 +230,7 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
for fixes in self.hijack.fixes: for fixes in self.hijack.fixes:
for _position, embedding in fixes: for _position, embedding in fixes:
used_embeddings[embedding.name] = embedding used_embeddings[embedding.name] = embedding
devices.torch_npu_set_device()
z = self.process_tokens(tokens, multipliers) z = self.process_tokens(tokens, multipliers)
zs.append(z) zs.append(z)

View File

@ -158,9 +158,9 @@ then
if echo "$gpu_info" | grep -q "AMD" && [[ -z "${TORCH_COMMAND}" ]] if echo "$gpu_info" | grep -q "AMD" && [[ -z "${TORCH_COMMAND}" ]]
then then
export TORCH_COMMAND="pip install torch==2.0.1+rocm5.4.2 torchvision==0.15.2+rocm5.4.2 --index-url https://download.pytorch.org/whl/rocm5.4.2" export TORCH_COMMAND="pip install torch==2.0.1+rocm5.4.2 torchvision==0.15.2+rocm5.4.2 --index-url https://download.pytorch.org/whl/rocm5.4.2"
elif echo "$gpu_info" | grep -q "Huawei" && [[ -z "${TORCH_COMMAND}" ]] elif npu-smi info 2>/dev/null
then then
export TORCH_COMMAND="pip install torch==2.1.0 torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu; pip install torch_npu" export TORCH_COMMAND="pip install torch==2.1.0 torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu; pip install torch_npu==2.1.0"
fi fi
fi fi