diff --git a/modules/launch_utils.py b/modules/launch_utils.py index 1aceb0dbf..c62d4f642 100644 --- a/modules/launch_utils.py +++ b/modules/launch_utils.py @@ -313,9 +313,42 @@ def requirements_met(requirements_file): return True +def get_cuda_comp_cap(): + """ + Returns float of CUDA Compute Capability using nvidia-smi + Returns 0.0 on error + CUDA Compute Capability + ref https://developer.nvidia.com/cuda-gpus + ref https://en.wikipedia.org/wiki/CUDA + Blackwell consumer GPUs should return 12.0 data-center GPUs should return 10.0 + """ + try: + return max(map(float, subprocess.check_output(['nvidia-smi', '--query-gpu=compute_cap', '--format=noheader,csv'], text=True).splitlines())) + except Exception as _: + return 0.0 + + +def early_access_blackwell_wheels(): + """For Blackwell GPUs, use Early Access PyTorch Wheels provided by Nvidia""" + if all([ + os.environ.get('TORCH_INDEX_URL') is None, + sys.version_info.major == 3, + sys.version_info.minor in (10, 11, 12), + platform.system() == "Windows", + get_cuda_comp_cap() >= 10, # Blackwell + ]): + base_repo = 'https://huggingface.co/w-e-w/torch-2.6.0-cu128.nv/resolve/main/' + ea_whl = { + 10: f'{base_repo}torch-2.6.0+cu128.nv-cp310-cp310-win_amd64.whl#sha256=fef3de7ce8f4642e405576008f384304ad0e44f7b06cc1aa45e0ab4b6e70490d {base_repo}torchvision-0.20.0a0+cu128.nv-cp310-cp310-win_amd64.whl#sha256=50841254f59f1db750e7348b90a8f4cd6befec217ab53cbb03780490b225abef', + 11: f'{base_repo}torch-2.6.0+cu128.nv-cp311-cp311-win_amd64.whl#sha256=6665c36e6a7e79e7a2cb42bec190d376be9ca2859732ed29dd5b7b5a612d0d26 {base_repo}torchvision-0.20.0a0+cu128.nv-cp311-cp311-win_amd64.whl#sha256=bbc0ee4938e35fe5a30de3613bfcd2d8ef4eae334cf8d49db860668f0bb47083', + 12: f'{base_repo}torch-2.6.0+cu128.nv-cp312-cp312-win_amd64.whl#sha256=a3197f72379d34b08c4a4bcf49ea262544a484e8702b8c46cbcd66356c89def6 {base_repo}torchvision-0.20.0a0+cu128.nv-cp312-cp312-win_amd64.whl#sha256=235e7be71ac4e75b0f8e817bae4796d7bac8a67146d2037ab96394f2bdc63e6c' + } + return f'pip install {ea_whl.get(sys.version_info.minor)}' + + def prepare_environment(): torch_index_url = os.environ.get('TORCH_INDEX_URL', "https://download.pytorch.org/whl/cu121") - torch_command = os.environ.get('TORCH_COMMAND', f"pip install torch==2.1.2 torchvision==0.16.2 --extra-index-url {torch_index_url}") + torch_command = os.environ.get('TORCH_COMMAND', early_access_blackwell_wheels() or f"pip install torch==2.1.2 torchvision==0.16.2 --extra-index-url {torch_index_url}") if args.use_ipex: if platform.system() == "Windows": # The "Nuullll/intel-extension-for-pytorch" wheels were built from IPEX source for Intel Arc GPU: https://github.com/intel/intel-extension-for-pytorch/tree/xpu-main