diff --git a/infer/lib/infer_pack/models.py b/infer/lib/infer_pack/models.py index a1a27e2..a900048 100644 --- a/infer/lib/infer_pack/models.py +++ b/infer/lib/infer_pack/models.py @@ -349,7 +349,25 @@ class SineGen(torch.nn.Module): if uv.device.type == "privateuseone": # for DirectML uv = uv.float() return uv - + + def _f02sine(self, f0, upp): + """ f0: (batchsize, length, dim) + where dim indicates fundamental tone and overtones + """ + a = torch.arange(1, upp + 1, dtype=f0.dtype, device=f0.device) + rad = f0 / self.sampling_rate * a + rad2 = torch.fmod(rad[:, :-1, -1:].float() + 0.5, 1.0) - 0.5 + rad_acc = rad2.cumsum(dim=1).fmod(1.0).to(f0) + rad += F.pad(rad_acc, (0, 0, 1, 0), mode='constant') + rad = rad.reshape(f0.shape[0], -1, 1) + b = torch.arange(1, self.dim + 1, dtype=f0.dtype, device=f0.device).reshape(1, 1, -1) + rad *= b + rand_ini = torch.rand(1, 1, self.dim, device=f0.device) + rand_ini[..., 0] = 0 + rad += rand_ini + sines = torch.sin(2 * np.pi * rad) + return sines + def forward(self, f0: torch.Tensor, upp: int): """sine_tensor, uv = forward(f0) input F0: tensor(batchsize=1, length, dim=1) @@ -358,45 +376,8 @@ class SineGen(torch.nn.Module): output uv: tensor(batchsize=1, length, 1) """ with torch.no_grad(): - f0 = f0[:, None].transpose(1, 2) - f0_buf = torch.zeros(f0.shape[0], f0.shape[1], self.dim, device=f0.device) - # fundamental component - f0_buf[:, :, 0] = f0[:, :, 0] - for idx in range(self.harmonic_num): - f0_buf[:, :, idx + 1] = f0_buf[:, :, 0] * ( - idx + 2 - ) # idx + 2: the (idx+1)-th overtone, (idx+2)-th harmonic - rad_values = ( - f0_buf / self.sampling_rate - ) % 1 ###%1意味着n_har的乘积无法后处理优化 - rand_ini = torch.rand( - f0_buf.shape[0], f0_buf.shape[2], device=f0_buf.device - ) - rand_ini[:, 0] = 0 - rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini - tmp_over_one = torch.cumsum( - rad_values, 1 - ) # % 1 #####%1意味着后面的cumsum无法再优化 - tmp_over_one *= upp - tmp_over_one = F.interpolate( - tmp_over_one.transpose(2, 1), - scale_factor=float(upp), - mode="linear", - align_corners=True, - ).transpose(2, 1) - rad_values = F.interpolate( - rad_values.transpose(2, 1), scale_factor=float(upp), mode="nearest" - ).transpose( - 2, 1 - ) ####### - tmp_over_one %= 1 - tmp_over_one_idx = (tmp_over_one[:, 1:, :] - tmp_over_one[:, :-1, :]) < 0 - cumsum_shift = torch.zeros_like(rad_values) - cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0 - sine_waves = torch.sin( - torch.cumsum(rad_values + cumsum_shift, dim=1) * 2 * torch.pi - ) - sine_waves = sine_waves * self.sine_amp + f0 = f0.unsqueeze(-1) + sine_waves = self._f02sine(f0, upp) * self.sine_amp uv = self._f02uv(f0) uv = F.interpolate( uv.transpose(2, 1), scale_factor=float(upp), mode="nearest"