diff --git a/extract_feature_print.py b/extract_feature_print.py index a92f272..2e5eeb3 100644 --- a/extract_feature_print.py +++ b/extract_feature_print.py @@ -51,7 +51,8 @@ models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task( ) model = models[0] model = model.to(device) -model = model.half() +if torch.cuda.is_available(): + model = model.half() model.eval() todo=sorted(list(os.listdir(wavPath)))[i_part::n_part] @@ -70,7 +71,7 @@ else: feats = readwave(wav_path, normalize=saved_cfg.task.normalize) padding_mask = torch.BoolTensor(feats.shape).fill_(False) inputs = { - "source": feats.half().to(device), + "source": feats.half().to(device) if torch.cuda.is_available() else feats.to(device), "padding_mask": padding_mask.to(device), "output_layer": 9, # layer 9 }