diff --git a/lib/infer_pack/models.py b/lib/infer_pack/models.py index ed72341..f94f5b1 100644 --- a/lib/infer_pack/models.py +++ b/lib/infer_pack/models.py @@ -414,13 +414,19 @@ class SourceModuleHnNSF(torch.nn.Module): self.l_tanh = torch.nn.Tanh() def forward(self, x, upp=None): + if hasattr(self,"ddtype")==False: + self.ddtype=self.l_linear.weight.dtype sine_wavs, uv, _ = self.l_sin_gen(x, upp) - if self.is_half: - sine_wavs = sine_wavs.half() - sine_merge = self.l_tanh(self.l_linear(sine_wavs.to(x))) + # print(x.dtype,sine_wavs.dtype,self.l_linear.weight.dtype) + # if self.is_half: + # sine_wavs = sine_wavs.half() + # sine_merge = self.l_tanh(self.l_linear(sine_wavs.to(x))) + # print(sine_wavs.dtype,self.ddtype) + if(sine_wavs.dtype!=self.ddtype): + sine_wavs=sine_wavs.to(self.ddtype) + sine_merge = self.l_tanh(self.l_linear(sine_wavs)) return sine_merge, None, None # noise, uv - class GeneratorNSF(torch.nn.Module): def __init__( self,