From 709bbbac1a88ee18810907b2a7d2fb68c404ceca Mon Sep 17 00:00:00 2001
From: yxlllc <33565655+yxlllc@users.noreply.github.com>
Date: Sun, 24 Nov 2024 21:36:41 +0800
Subject: [PATCH] optimize nsf inference (#2387)

---
 infer/lib/infer_pack/models.py | 61 ++++++++++++----------------------
 1 file changed, 21 insertions(+), 40 deletions(-)

diff --git a/infer/lib/infer_pack/models.py b/infer/lib/infer_pack/models.py
index a1a27e2..a900048 100644
--- a/infer/lib/infer_pack/models.py
+++ b/infer/lib/infer_pack/models.py
@@ -349,7 +349,25 @@ class SineGen(torch.nn.Module):
         if uv.device.type == "privateuseone":  # for DirectML
             uv = uv.float()
         return uv
-
+    
+    def _f02sine(self, f0, upp):
+        """ f0: (batchsize, length, dim)
+            where dim indicates fundamental tone and overtones
+        """
+        a = torch.arange(1, upp + 1, dtype=f0.dtype, device=f0.device)
+        rad = f0 / self.sampling_rate * a
+        rad2 = torch.fmod(rad[:, :-1, -1:].float() + 0.5, 1.0) - 0.5
+        rad_acc = rad2.cumsum(dim=1).fmod(1.0).to(f0)
+        rad += F.pad(rad_acc, (0, 0, 1, 0), mode='constant')
+        rad = rad.reshape(f0.shape[0], -1, 1)
+        b = torch.arange(1, self.dim + 1, dtype=f0.dtype, device=f0.device).reshape(1, 1, -1)
+        rad *= b
+        rand_ini = torch.rand(1, 1, self.dim, device=f0.device)
+        rand_ini[..., 0] = 0
+        rad += rand_ini
+        sines = torch.sin(2 * np.pi * rad)
+        return sines
+        
     def forward(self, f0: torch.Tensor, upp: int):
         """sine_tensor, uv = forward(f0)
         input F0: tensor(batchsize=1, length, dim=1)
@@ -358,45 +376,8 @@ class SineGen(torch.nn.Module):
         output uv: tensor(batchsize=1, length, 1)
         """
         with torch.no_grad():
-            f0 = f0[:, None].transpose(1, 2)
-            f0_buf = torch.zeros(f0.shape[0], f0.shape[1], self.dim, device=f0.device)
-            # fundamental component
-            f0_buf[:, :, 0] = f0[:, :, 0]
-            for idx in range(self.harmonic_num):
-                f0_buf[:, :, idx + 1] = f0_buf[:, :, 0] * (
-                    idx + 2
-                )  # idx + 2: the (idx+1)-th overtone, (idx+2)-th harmonic
-            rad_values = (
-                f0_buf / self.sampling_rate
-            ) % 1  ###%1意味着n_har的乘积无法后处理优化
-            rand_ini = torch.rand(
-                f0_buf.shape[0], f0_buf.shape[2], device=f0_buf.device
-            )
-            rand_ini[:, 0] = 0
-            rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini
-            tmp_over_one = torch.cumsum(
-                rad_values, 1
-            )  # % 1  #####%1意味着后面的cumsum无法再优化
-            tmp_over_one *= upp
-            tmp_over_one = F.interpolate(
-                tmp_over_one.transpose(2, 1),
-                scale_factor=float(upp),
-                mode="linear",
-                align_corners=True,
-            ).transpose(2, 1)
-            rad_values = F.interpolate(
-                rad_values.transpose(2, 1), scale_factor=float(upp), mode="nearest"
-            ).transpose(
-                2, 1
-            )  #######
-            tmp_over_one %= 1
-            tmp_over_one_idx = (tmp_over_one[:, 1:, :] - tmp_over_one[:, :-1, :]) < 0
-            cumsum_shift = torch.zeros_like(rad_values)
-            cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0
-            sine_waves = torch.sin(
-                torch.cumsum(rad_values + cumsum_shift, dim=1) * 2 * torch.pi
-            )
-            sine_waves = sine_waves * self.sine_amp
+            f0 = f0.unsqueeze(-1)
+            sine_waves = self._f02sine(f0, upp) * self.sine_amp
             uv = self._f02uv(f0)
             uv = F.interpolate(
                 uv.transpose(2, 1), scale_factor=float(upp), mode="nearest"