diff --git a/cuda_malloc.py b/cuda_malloc.py index 41bd1368e..fae5d73cf 100644 --- a/cuda_malloc.py +++ b/cuda_malloc.py @@ -39,9 +39,9 @@ def get_gpu_names(): else: gpu_names = set() out = subprocess.check_output(['nvidia-smi', '-L']) - for l in out.split(b'\n'): - if len(l) > 0: - gpu_names.add(l.decode('utf-8').split(' (UUID')[0]) + for line in out.split(b'\n'): + if len(line) > 0: + gpu_names.add(line.decode('utf-8').split(' (UUID')[0]) return gpu_names blacklist = {"GeForce GTX TITAN X", "GeForce GTX 980", "GeForce GTX 970", "GeForce GTX 960", "GeForce GTX 950", "GeForce 945M", @@ -55,7 +55,7 @@ blacklist = {"GeForce GTX TITAN X", "GeForce GTX 980", "GeForce GTX 970", "GeFor def cuda_malloc_supported(): try: names = get_gpu_names() - except: + except Exception: names = set() for x in names: if "NVIDIA" in x: @@ -82,16 +82,16 @@ if not args.cuda_malloc: version = module.__version__ if int(version[0]) >= 2: #enable by default for torch version 2.0 and up args.cuda_malloc = cuda_malloc_supported() - except: + except Exception: pass +def init_cuda_malloc(): + if args.cuda_malloc and not args.disable_cuda_malloc: + env_var = os.environ.get('PYTORCH_CUDA_ALLOC_CONF', None) + if env_var is None: + env_var = "backend:cudaMallocAsync" + else: + env_var += ",backend:cudaMallocAsync" -if args.cuda_malloc and not args.disable_cuda_malloc: - env_var = os.environ.get('PYTORCH_CUDA_ALLOC_CONF', None) - if env_var is None: - env_var = "backend:cudaMallocAsync" - else: - env_var += ",backend:cudaMallocAsync" - - os.environ['PYTORCH_CUDA_ALLOC_CONF'] = env_var - print(f"Setup environment PYTORCH_CUDA_ALLOC_CONF={env_var}") + os.environ['PYTORCH_CUDA_ALLOC_CONF'] = env_var + print(f"Setup environment PYTORCH_CUDA_ALLOC_CONF={env_var}") diff --git a/webui.py b/webui.py index 0567668d8..aef977f7c 100644 --- a/webui.py +++ b/webui.py @@ -10,7 +10,8 @@ from modules import initialize startup_timer = timer.startup_timer startup_timer.record("launcher") -import cuda_malloc +from cuda_malloc import init_cuda_malloc +init_cuda_malloc() startup_timer.record("cuda_malloc") initialize.imports()