From 0b15d48f209a17862290137dbc0a12ce38340708 Mon Sep 17 00:00:00 2001
From: GratefulTony
Date: Thu, 27 Jul 2023 20:44:16 -0600
Subject: [PATCH] feat: unblock cpu training (#889)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* Update train_nsf_sim_cache_sid_load_pretrain.py

patch to unblock cpu training. CPU training took ~12 hours for me.

* Update train_nsf_sim_cache_sid_load_pretrain.py

Co-authored-by: Nato Boram

---------

Co-authored-by: 源文雨 <41315874+fumiama@users.noreply.github.com>
Co-authored-by: Nato Boram
---
 train_nsf_sim_cache_sid_load_pretrain.py | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/train_nsf_sim_cache_sid_load_pretrain.py b/train_nsf_sim_cache_sid_load_pretrain.py
index 42d2cc5..c1bdf11 100644
--- a/train_nsf_sim_cache_sid_load_pretrain.py
+++ b/train_nsf_sim_cache_sid_load_pretrain.py
@@ -67,8 +67,13 @@ class EpochRecorder:

 def main():
     n_gpus = torch.cuda.device_count()
+
     if torch.cuda.is_available() == False and torch.backends.mps.is_available() == True:
         n_gpus = 1
+    if n_gpus < 1:
+        # patch to unblock people without gpus. there is probably a better way.
+        print("NO GPU DETECTED: falling back to CPU - this may take a while")
+        n_gpus = 1
     os.environ["MASTER_ADDR"] = "localhost"
     os.environ["MASTER_PORT"] = str(randint(20000, 55555))
     children = []
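
For readers who just want the resulting behaviour, the following is a minimal, self-contained sketch of the device-selection fallback this patch introduces. The helper name pick_n_gpus is illustrative only and does not exist in the repository; the real logic lives inline in main() of train_nsf_sim_cache_sid_load_pretrain.py.

import torch

def pick_n_gpus() -> int:
    """Illustrative helper: decide how many training workers to spawn."""
    n_gpus = torch.cuda.device_count()
    # Apple-silicon machines report zero CUDA devices but can still use MPS.
    if not torch.cuda.is_available() and torch.backends.mps.is_available():
        n_gpus = 1
    if n_gpus < 1:
        # No accelerator at all: run a single CPU worker instead of aborting.
        print("NO GPU DETECTED: falling back to CPU - this may take a while")
        n_gpus = 1
    return n_gpus

if __name__ == "__main__":
    print(f"spawning {pick_n_gpus()} training process(es)")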