Better workaround for Navi1, removing --pre for Navi3

This commit is contained in:
DGdev91 2024-03-12 00:09:07 +01:00
parent 3e0146f9bd
commit 8262cd71c4

View File

@ -129,11 +129,11 @@ case "$gpu_info" in
export HSA_OVERRIDE_GFX_VERSION=10.3.0 export HSA_OVERRIDE_GFX_VERSION=10.3.0
if [[ -z "${TORCH_COMMAND}" ]] if [[ -z "${TORCH_COMMAND}" ]]
then then
pyv="$(${python_cmd} -c 'import sys; print(".".join(map(str, sys.version_info[0:2])))')" pyv="$(${python_cmd} -c 'import sys; print(float(".".join(map(str, sys.version_info[0:2]))) <= 3.10)')"
if [[ $(bc <<< "$pyv <= 3.10") -eq 1 ]] if [[ $pyv == "True" ]]
then then
# Navi users will still use torch 1.13 because 2.0 does not seem to work. # Using an old nightly compiled against rocm 5.2 for Navi1, see https://github.com/pytorch/pytorch/issues/106728#issuecomment-1749511711
export TORCH_COMMAND="pip install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/rocm5.6" export TORCH_COMMAND="pip install https://download.pytorch.org/whl/nightly/rocm5.2/torch-2.0.0.dev20230209%2Brocm5.2-cp310-cp310-linux_x86_64.whl https://download.pytorch.org/whl/nightly/rocm5.2/torchvision-0.15.0.dev20230209%2Brocm5.2-cp310-cp310-linux_x86_64.whl"
else else
printf "\e[1m\e[31mERROR: RX 5000 series GPUs must be using at max python 3.10, aborting...\e[0m" printf "\e[1m\e[31mERROR: RX 5000 series GPUs must be using at max python 3.10, aborting...\e[0m"
exit 1 exit 1
@ -143,7 +143,7 @@ case "$gpu_info" in
*"Navi 2"*) export HSA_OVERRIDE_GFX_VERSION=10.3.0 *"Navi 2"*) export HSA_OVERRIDE_GFX_VERSION=10.3.0
;; ;;
*"Navi 3"*) [[ -z "${TORCH_COMMAND}" ]] && \ *"Navi 3"*) [[ -z "${TORCH_COMMAND}" ]] && \
export TORCH_COMMAND="pip install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/rocm5.7" export TORCH_COMMAND="pip install torch torchvision --index-url https://download.pytorch.org/whl/nightly/rocm5.7"
;; ;;
*"Renoir"*) export HSA_OVERRIDE_GFX_VERSION=9.0.0 *"Renoir"*) export HSA_OVERRIDE_GFX_VERSION=9.0.0
printf "\n%s\n" "${delimiter}" printf "\n%s\n" "${delimiter}"