lint, add init_cuda_malloc()

This commit is contained in:
Won-Kyu Park 2024-10-12 22:52:40 +09:00
parent e78be27e75
commit 0cc81464bb
No known key found for this signature in database
GPG Key ID: 53AA79C8C9535D15
2 changed files with 16 additions and 15 deletions

@ -39,9 +39,9 @@ def get_gpu_names():
else:
gpu_names = set()
out = subprocess.check_output(['nvidia-smi', '-L'])
for l in out.split(b'\n'):
if len(l) > 0:
gpu_names.add(l.decode('utf-8').split(' (UUID')[0])
for line in out.split(b'\n'):
if len(line) > 0:
gpu_names.add(line.decode('utf-8').split(' (UUID')[0])
return gpu_names
blacklist = {"GeForce GTX TITAN X", "GeForce GTX 980", "GeForce GTX 970", "GeForce GTX 960", "GeForce GTX 950", "GeForce 945M",
@ -55,7 +55,7 @@ blacklist = {"GeForce GTX TITAN X", "GeForce GTX 980", "GeForce GTX 970", "GeFor
def cuda_malloc_supported():
try:
names = get_gpu_names()
except:
except Exception:
names = set()
for x in names:
if "NVIDIA" in x:
@ -82,16 +82,16 @@ if not args.cuda_malloc:
version = module.__version__
if int(version[0]) >= 2: #enable by default for torch version 2.0 and up
args.cuda_malloc = cuda_malloc_supported()
except:
except Exception:
pass
def init_cuda_malloc():
    """Add the cudaMallocAsync backend option to PYTORCH_CUDA_ALLOC_CONF.

    Does nothing unless ``args.cuda_malloc`` is set and
    ``args.disable_cuda_malloc`` is not. Otherwise the option
    ``backend:cudaMallocAsync`` is added to the ``PYTORCH_CUDA_ALLOC_CONF``
    environment variable, keeping any options already present, and the
    resulting value is printed.

    The call is idempotent: if the option is already in the variable it is
    not appended a second time.
    """
    if not args.cuda_malloc or args.disable_cuda_malloc:
        return

    backend_option = "backend:cudaMallocAsync"
    env_var = os.environ.get('PYTORCH_CUDA_ALLOC_CONF', None)
    if not env_var:
        # Unset or empty: start fresh so the value never begins with a comma.
        env_var = backend_option
    elif backend_option not in (opt.strip() for opt in env_var.split(',')):
        # Keep the user's existing allocator options and add ours at the end.
        env_var += "," + backend_option

    os.environ['PYTORCH_CUDA_ALLOC_CONF'] = env_var
    print(f"Setup environment PYTORCH_CUDA_ALLOC_CONF={env_var}")

@ -10,7 +10,8 @@ from modules import initialize
startup_timer = timer.startup_timer
startup_timer.record("launcher")
import cuda_malloc
from cuda_malloc import init_cuda_malloc
init_cuda_malloc()
startup_timer.record("cuda_malloc")
initialize.imports()