diff --git a/gui.py b/gui.py index 63ab100..7079c9f 100644 --- a/gui.py +++ b/gui.py @@ -150,8 +150,12 @@ class RVC: assert feats.dim() == 1, feats.dim() feats = feats.view(1, -1) padding_mask = torch.BoolTensor(feats.shape).fill_(False) + if Config.is_half: + feats = feats.half() + else: + feats = feats.float() inputs = { - "source": feats.half().to(device), + "source": feats.to(device), "padding_mask": padding_mask.to(device), "output_layer": 9 if self.version == "v1" else 12, }