fix: extract freture cannot run on pure cpu

This commit is contained in:
源文雨 2023-04-01 17:05:48 +08:00
parent 9e59375311
commit f27a991794
1 changed files with 3 additions and 2 deletions

View File

@ -51,7 +51,8 @@ models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task(
)
model = models[0]
model = model.to(device)
model = model.half()
if torch.cuda.is_available():
model = model.half()
model.eval()
todo=sorted(list(os.listdir(wavPath)))[i_part::n_part]
@ -70,7 +71,7 @@ else:
feats = readwave(wav_path, normalize=saved_cfg.task.normalize)
padding_mask = torch.BoolTensor(feats.shape).fill_(False)
inputs = {
"source": feats.half().to(device),
"source": feats.half().to(device) if torch.cuda.is_available() else feats.to(device),
"padding_mask": padding_mask.to(device),
"output_layer": 9, # layer 9
}