optimize nsf inference (#2387)

This commit is contained in:
yxlllc 2024-11-24 21:36:41 +08:00 committed by GitHub
parent 1376ce739d
commit 709bbbac1a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -349,7 +349,25 @@ class SineGen(torch.nn.Module):
if uv.device.type == "privateuseone": # for DirectML if uv.device.type == "privateuseone": # for DirectML
uv = uv.float() uv = uv.float()
return uv return uv
def _f02sine(self, f0, upp):
    """Turn frame-level F0 into sample-level sine waves.

    Args:
        f0: fundamental frequency per frame. The indexing and reshape
            below require a 3-D tensor whose last axis broadcasts against
            ``upp``; ``forward`` passes ``(batchsize, length, 1)``.
        upp: number of output samples generated for every input frame.

    Returns:
        Tensor of shape ``(batchsize, length * upp, self.dim)`` holding
        ``sin(2 * pi * phase)``. Channel 0 is the fundamental and channel
        ``k`` is the ``(k + 1)``-th harmonic with a random initial phase.
        The amplitude is not applied here.
    """
    # Phase (in cycles) reached after 1..upp samples inside each frame.
    steps = torch.arange(1, upp + 1, dtype=f0.dtype, device=f0.device)
    rad = f0 / self.sampling_rate * steps
    # Phase at the end of every frame but the last, wrapped by whole cycles
    # so the float32 running sum across frames stays small.
    frame_end = torch.fmod(rad[:, :-1, -1:].float() + 0.5, 1.0) - 0.5
    rad_acc = frame_end.cumsum(dim=1).fmod(1.0).to(f0)
    # Frame i starts from the phase accumulated by frames 0..i-1; the first
    # frame starts from zero, hence the single zero padded at the front.
    rad = rad + F.pad(rad_acc, (0, 0, 1, 0), mode="constant")
    rad = rad.reshape(f0.shape[0], -1, 1)
    harmonics = torch.arange(
        1, self.dim + 1, dtype=f0.dtype, device=f0.device
    ).reshape(1, 1, -1)
    # Out-of-place on purpose: (B, N, 1) * (1, 1, dim) has to grow to
    # (B, N, dim), which an in-place ``*=`` cannot do when dim > 1.
    rad = rad * harmonics
    rand_ini = torch.rand(1, 1, self.dim, device=f0.device)
    rand_ini[..., 0] = 0  # the fundamental keeps a zero initial phase
    # Cast so the result keeps the working dtype instead of being promoted.
    rad = rad + rand_ini.to(rad.dtype)
    return torch.sin(2 * np.pi * rad)
def forward(self, f0: torch.Tensor, upp: int): def forward(self, f0: torch.Tensor, upp: int):
"""sine_tensor, uv = forward(f0) """sine_tensor, uv = forward(f0)
input F0: tensor(batchsize=1, length, dim=1) input F0: tensor(batchsize=1, length, dim=1)
@ -358,45 +376,8 @@ class SineGen(torch.nn.Module):
output uv: tensor(batchsize=1, length, 1) output uv: tensor(batchsize=1, length, 1)
""" """
with torch.no_grad(): with torch.no_grad():
f0 = f0[:, None].transpose(1, 2) f0 = f0.unsqueeze(-1)
f0_buf = torch.zeros(f0.shape[0], f0.shape[1], self.dim, device=f0.device) sine_waves = self._f02sine(f0, upp) * self.sine_amp
# fundamental component
f0_buf[:, :, 0] = f0[:, :, 0]
for idx in range(self.harmonic_num):
f0_buf[:, :, idx + 1] = f0_buf[:, :, 0] * (
idx + 2
) # idx + 2: the (idx+1)-th overtone, (idx+2)-th harmonic
rad_values = (
f0_buf / self.sampling_rate
) % 1 ###%1意味着n_har的乘积无法后处理优化
rand_ini = torch.rand(
f0_buf.shape[0], f0_buf.shape[2], device=f0_buf.device
)
rand_ini[:, 0] = 0
rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini
tmp_over_one = torch.cumsum(
rad_values, 1
) # % 1 #####%1意味着后面的cumsum无法再优化
tmp_over_one *= upp
tmp_over_one = F.interpolate(
tmp_over_one.transpose(2, 1),
scale_factor=float(upp),
mode="linear",
align_corners=True,
).transpose(2, 1)
rad_values = F.interpolate(
rad_values.transpose(2, 1), scale_factor=float(upp), mode="nearest"
).transpose(
2, 1
) #######
tmp_over_one %= 1
tmp_over_one_idx = (tmp_over_one[:, 1:, :] - tmp_over_one[:, :-1, :]) < 0
cumsum_shift = torch.zeros_like(rad_values)
cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0
sine_waves = torch.sin(
torch.cumsum(rad_values + cumsum_shift, dim=1) * 2 * torch.pi
)
sine_waves = sine_waves * self.sine_amp
uv = self._f02uv(f0) uv = self._f02uv(f0)
uv = F.interpolate( uv = F.interpolate(
uv.transpose(2, 1), scale_factor=float(upp), mode="nearest" uv.transpose(2, 1), scale_factor=float(upp), mode="nearest"