diff --git a/train_nsf_sim_cache_sid_load_pretrain.py b/train_nsf_sim_cache_sid_load_pretrain.py index 42d2cc5..c1bdf11 100644 --- a/train_nsf_sim_cache_sid_load_pretrain.py +++ b/train_nsf_sim_cache_sid_load_pretrain.py @@ -67,8 +67,13 @@ class EpochRecorder: def main(): n_gpus = torch.cuda.device_count() + if torch.cuda.is_available() == False and torch.backends.mps.is_available() == True: n_gpus = 1 + if n_gpus < 1: + # patch to unblock people without gpus. there is probably a better way. + print("NO GPU DETECTED: falling back to CPU - this may take a while") + n_gpus = 1 os.environ["MASTER_ADDR"] = "localhost" os.environ["MASTER_PORT"] = str(randint(20000, 55555)) children = []