From f27a991794c3507d7760d6857fbddd70f403128f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=BA=90=E6=96=87=E9=9B=A8?= <41315874+fumiama@users.noreply.github.com> Date: Sat, 1 Apr 2023 17:05:48 +0800 Subject: [PATCH] fix: extract freture cannot run on pure cpu --- extract_feature_print.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/extract_feature_print.py b/extract_feature_print.py index a92f272..2e5eeb3 100644 --- a/extract_feature_print.py +++ b/extract_feature_print.py @@ -51,7 +51,8 @@ models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task( ) model = models[0] model = model.to(device) -model = model.half() +if torch.cuda.is_available(): + model = model.half() model.eval() todo=sorted(list(os.listdir(wavPath)))[i_part::n_part] @@ -70,7 +71,7 @@ else: feats = readwave(wav_path, normalize=saved_cfg.task.normalize) padding_mask = torch.BoolTensor(feats.shape).fill_(False) inputs = { - "source": feats.half().to(device), + "source": feats.half().to(device) if torch.cuda.is_available() else feats.to(device), "padding_mask": padding_mask.to(device), "output_layer": 9, # layer 9 }