Fix ONNX export (#1963)

* Add files via upload

* Add files via upload

* Add files via upload
Ναρουσέ·μ·γιουμεμί·Χινακάννα 2024-04-14 13:55:15 +08:00 committed by GitHub
parent 9cae044bf2
commit 597e0a97d0
3 changed files with 173 additions and 119 deletions

View File

@@ -1,5 +1,6 @@
 import math
 import logging
+from typing import Optional
 
 logger = logging.getLogger(__name__)
@@ -13,10 +14,13 @@ from torch.nn.utils import remove_weight_norm, spectral_norm, weight_norm
 from infer.lib.infer_pack import attentions, commons, modules
 from infer.lib.infer_pack.commons import get_padding, init_weights
 
+has_xpu = bool(hasattr(torch, "xpu") and torch.xpu.is_available())
+
 
-class TextEncoder256(nn.Module):
+class TextEncoder(nn.Module):
     def __init__(
         self,
+        in_channels,
         out_channels,
         hidden_channels,
         filter_channels,
@@ -26,25 +30,36 @@ class TextEncoder256(nn.Module):
         p_dropout,
         f0=True,
     ):
-        super().__init__()
+        super(TextEncoder, self).__init__()
         self.out_channels = out_channels
         self.hidden_channels = hidden_channels
         self.filter_channels = filter_channels
         self.n_heads = n_heads
         self.n_layers = n_layers
         self.kernel_size = kernel_size
-        self.p_dropout = p_dropout
-        self.emb_phone = nn.Linear(256, hidden_channels)
+        self.p_dropout = float(p_dropout)
+        self.emb_phone = nn.Linear(in_channels, hidden_channels)
         self.lrelu = nn.LeakyReLU(0.1, inplace=True)
         if f0 == True:
             self.emb_pitch = nn.Embedding(256, hidden_channels)  # pitch 256
         self.encoder = attentions.Encoder(
-            hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
+            hidden_channels,
+            filter_channels,
+            n_heads,
+            n_layers,
+            kernel_size,
+            float(p_dropout),
         )
         self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
 
-    def forward(self, phone, pitch, lengths):
-        if pitch == None:
+    def forward(
+        self,
+        phone: torch.Tensor,
+        pitch: torch.Tensor,
+        lengths: torch.Tensor,
+        skip_head: Optional[torch.Tensor] = None,
+    ):
+        if pitch is None:
             x = self.emb_phone(phone)
         else:
             x = self.emb_phone(phone) + self.emb_pitch(pitch)
@@ -56,54 +71,6 @@ class TextEncoder256(nn.Module):
         )
         x = self.encoder(x * x_mask, x_mask)
         stats = self.proj(x) * x_mask
-        m, logs = torch.split(stats, self.out_channels, dim=1)
-        return m, logs, x_mask
-
-
-class TextEncoder768(nn.Module):
-    def __init__(
-        self,
-        out_channels,
-        hidden_channels,
-        filter_channels,
-        n_heads,
-        n_layers,
-        kernel_size,
-        p_dropout,
-        f0=True,
-    ):
-        super().__init__()
-        self.out_channels = out_channels
-        self.hidden_channels = hidden_channels
-        self.filter_channels = filter_channels
-        self.n_heads = n_heads
-        self.n_layers = n_layers
-        self.kernel_size = kernel_size
-        self.p_dropout = p_dropout
-        self.emb_phone = nn.Linear(768, hidden_channels)
-        self.lrelu = nn.LeakyReLU(0.1, inplace=True)
-        if f0 == True:
-            self.emb_pitch = nn.Embedding(256, hidden_channels)  # pitch 256
-        self.encoder = attentions.Encoder(
-            hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
-        )
-        self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
-
-    def forward(self, phone, pitch, lengths):
-        if pitch == None:
-            x = self.emb_phone(phone)
-        else:
-            x = self.emb_phone(phone) + self.emb_pitch(pitch)
-        x = x * math.sqrt(self.hidden_channels)  # [b, t, h]
-        x = self.lrelu(x)
-        x = torch.transpose(x, 1, -1)  # [b, h, t]
-        x_mask = torch.unsqueeze(commons.sequence_mask(lengths, x.size(2)), 1).to(
-            x.dtype
-        )
-        x = self.encoder(x * x_mask, x_mask)
-        stats = self.proj(x) * x_mask
         m, logs = torch.split(stats, self.out_channels, dim=1)
         return m, logs, x_mask
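The two fixed-width encoders above collapse into a single TextEncoder whose first argument is the feature dimension. A hypothetical instantiation sketch (the trailing hyperparameter values are illustrative, not taken from this diff):

# Hypothetical usage of the merged class; 192/768/2/6/3/0.0 are placeholder hyperparameters.
enc_v1 = TextEncoder(256, 192, 192, 768, 2, 6, 3, 0.0)  # replaces TextEncoder256
enc_v2 = TextEncoder(768, 192, 192, 768, 2, 6, 3, 0.0)  # replaces TextEncoder768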
@@ -119,7 +86,7 @@ class ResidualCouplingBlock(nn.Module):
         n_flows=4,
         gin_channels=0,
     ):
-        super().__init__()
+        super(ResidualCouplingBlock, self).__init__()
         self.channels = channels
         self.hidden_channels = hidden_channels
         self.kernel_size = kernel_size
@@ -143,19 +110,36 @@ class ResidualCouplingBlock(nn.Module):
             )
             self.flows.append(modules.Flip())
 
-    def forward(self, x, x_mask, g=None, reverse=False):
+    def forward(
+        self,
+        x: torch.Tensor,
+        x_mask: torch.Tensor,
+        g: Optional[torch.Tensor] = None,
+        reverse: bool = False,
+    ):
         if not reverse:
             for flow in self.flows:
                 x, _ = flow(x, x_mask, g=g, reverse=reverse)
         else:
             for flow in reversed(self.flows):
-                x, _ = flow(x, x_mask, g=g, reverse=reverse)
+                x, _ = flow.forward(x, x_mask, g=g, reverse=reverse)
         return x
 
     def remove_weight_norm(self):
         for i in range(self.n_flows):
             self.flows[i * 2].remove_weight_norm()
 
+    def __prepare_scriptable__(self):
+        for i in range(self.n_flows):
+            for hook in self.flows[i * 2]._forward_pre_hooks.values():
+                if (
+                    hook.__module__ == "torch.nn.utils.weight_norm"
+                    and hook.__class__.__name__ == "WeightNorm"
+                ):
+                    torch.nn.utils.remove_weight_norm(self.flows[i * 2])
+        return self
+
 
 class PosteriorEncoder(nn.Module):
     def __init__(
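Each __prepare_scriptable__ added in this commit follows the same pattern: find the WeightNorm forward-pre-hook by module path and class name (the class is shadowed by the weight_norm function, so isinstance is awkward) and strip it before torch.jit.script() compiles the module. A self-contained sketch of that pattern on a bare Conv1d:

import torch
from torch import nn

conv = nn.utils.weight_norm(nn.Conv1d(4, 4, 3))
for hook in list(conv._forward_pre_hooks.values()):
    # Match by name: torch.nn.utils.weight_norm (the function) shadows the WeightNorm class.
    if (
        hook.__module__ == "torch.nn.utils.weight_norm"
        and hook.__class__.__name__ == "WeightNorm"
    ):
        torch.nn.utils.remove_weight_norm(conv)
scripted = torch.jit.script(conv)  # scripts cleanly once the hook is gone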
@@ -168,7 +152,7 @@ class PosteriorEncoder(nn.Module):
         n_layers,
         gin_channels=0,
     ):
-        super().__init__()
+        super(PosteriorEncoder, self).__init__()
         self.in_channels = in_channels
         self.out_channels = out_channels
         self.hidden_channels = hidden_channels
@@ -187,7 +171,9 @@ class PosteriorEncoder(nn.Module):
         )
         self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
 
-    def forward(self, x, x_lengths, g=None):
+    def forward(
+        self, x: torch.Tensor, x_lengths: torch.Tensor, g: Optional[torch.Tensor] = None
+    ):
         x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(
             x.dtype
         )
@@ -201,6 +187,15 @@ class PosteriorEncoder(nn.Module):
     def remove_weight_norm(self):
         self.enc.remove_weight_norm()
 
+    def __prepare_scriptable__(self):
+        for hook in self.enc._forward_pre_hooks.values():
+            if (
+                hook.__module__ == "torch.nn.utils.weight_norm"
+                and hook.__class__.__name__ == "WeightNorm"
+            ):
+                torch.nn.utils.remove_weight_norm(self.enc)
+        return self
+
 
 class Generator(torch.nn.Module):
     def __init__(
@@ -250,7 +245,7 @@ class Generator(torch.nn.Module):
         if gin_channels != 0:
             self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)
 
-    def forward(self, x, g=None):
+    def forward(self, x: torch.Tensor, g: Optional[torch.Tensor] = None):
         x = self.conv_pre(x)
         if g is not None:
             x = x + self.cond(g)
@@ -271,6 +266,28 @@ class Generator(torch.nn.Module):
 
         return x
 
+    def __prepare_scriptable__(self):
+        for l in self.ups:
+            for hook in l._forward_pre_hooks.values():
+                # The hook we want to remove is an instance of WeightNorm class, so
+                # normally we would do `if isinstance(...)` but this class is not accessible
+                # because of shadowing, so we check the module name directly.
+                # https://github.com/pytorch/pytorch/blob/be0ca00c5ce260eb5bcec3237357f7a30cc08983/torch/nn/utils/__init__.py#L3
+                if (
+                    hook.__module__ == "torch.nn.utils.weight_norm"
+                    and hook.__class__.__name__ == "WeightNorm"
+                ):
+                    torch.nn.utils.remove_weight_norm(l)
+        for l in self.resblocks:
+            for hook in l._forward_pre_hooks.values():
+                if (
+                    hook.__module__ == "torch.nn.utils.weight_norm"
+                    and hook.__class__.__name__ == "WeightNorm"
+                ):
+                    torch.nn.utils.remove_weight_norm(l)
+        return self
+
     def remove_weight_norm(self):
         for l in self.ups:
             remove_weight_norm(l)
@@ -291,7 +308,7 @@ class SineGen(torch.nn.Module):
     voiced_thoreshold: F0 threshold for U/V classification (default 0)
     flag_for_pulse: this SinGen is used inside PulseGen (default False)
     Note: when flag_for_pulse is True, the first time step of a voiced
-    segment is always sin(np.pi) or cos(0)
+    segment is always sin(torch.pi) or cos(0)
     """
 
     def __init__(
@@ -315,9 +332,11 @@ class SineGen(torch.nn.Module):
         # generate uv signal
         uv = torch.ones_like(f0)
         uv = uv * (f0 > self.voiced_threshold)
+        if uv.device.type == "privateuseone":  # for DirectML
+            uv = uv.float()
         return uv
 
-    def forward(self, f0, upp):
+    def forward(self, f0: torch.Tensor, upp: int):
         """sine_tensor, uv = forward(f0)
         input F0: tensor(batchsize=1, length, dim=1)
         f0 for unvoiced steps should be 0
@@ -329,7 +348,7 @@ class SineGen(torch.nn.Module):
             f0_buf = torch.zeros(f0.shape[0], f0.shape[1], self.dim, device=f0.device)
             # fundamental component
             f0_buf[:, :, 0] = f0[:, :, 0]
-            for idx in np.arange(self.harmonic_num):
+            for idx in range(self.harmonic_num):
                 f0_buf[:, :, idx + 1] = f0_buf[:, :, 0] * (
                     idx + 2
                 )  # idx + 2: the (idx+1)-th overtone, (idx+2)-th harmonic
@@ -347,12 +366,12 @@ class SineGen(torch.nn.Module):
             tmp_over_one *= upp
             tmp_over_one = F.interpolate(
                 tmp_over_one.transpose(2, 1),
-                scale_factor=upp,
+                scale_factor=float(upp),
                 mode="linear",
                 align_corners=True,
             ).transpose(2, 1)
             rad_values = F.interpolate(
-                rad_values.transpose(2, 1), scale_factor=upp, mode="nearest"
+                rad_values.transpose(2, 1), scale_factor=float(upp), mode="nearest"
             ).transpose(
                 2, 1
             )  #######
@@ -361,12 +380,12 @@ class SineGen(torch.nn.Module):
             cumsum_shift = torch.zeros_like(rad_values)
             cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0
             sine_waves = torch.sin(
-                torch.cumsum(rad_values + cumsum_shift, dim=1) * 2 * np.pi
+                torch.cumsum(rad_values + cumsum_shift, dim=1) * 2 * torch.pi
             )
             sine_waves = sine_waves * self.sine_amp
             uv = self._f02uv(f0)
             uv = F.interpolate(
-                uv.transpose(2, 1), scale_factor=upp, mode="nearest"
+                uv.transpose(2, 1), scale_factor=float(upp), mode="nearest"
             ).transpose(2, 1)
             noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
             noise = noise_amp * torch.randn_like(sine_waves)
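The float(upp) casts are load-bearing: TorchScript types F.interpolate's scale_factor as float, so passing the integer upsampling factor directly can fail to script or export. A minimal sketch with an arbitrary example factor:

import torch
import torch.nn.functional as F

upp = 480  # example integer upsampling factor, standing in for self.upp
x = torch.randn(1, 1, 200)
# The explicit float() keeps scripting/ONNX export happy.
y = F.interpolate(x, scale_factor=float(upp), mode="nearest")
print(y.shape)  # torch.Size([1, 1, 96000])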
@@ -414,11 +433,19 @@ class SourceModuleHnNSF(torch.nn.Module):
         # to merge source harmonics into a single excitation
         self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
         self.l_tanh = torch.nn.Tanh()
+        # self.ddtype:int = -1
 
-    def forward(self, x, upp=None):
+    def forward(self, x: torch.Tensor, upp: int = 1):
+        # if self.ddtype ==-1:
+        #     self.ddtype = self.l_linear.weight.dtype
         sine_wavs, uv, _ = self.l_sin_gen(x, upp)
-        if self.is_half:
-            sine_wavs = sine_wavs.half()
+        # print(x.dtype,sine_wavs.dtype,self.l_linear.weight.dtype)
+        # if self.is_half:
+        #     sine_wavs = sine_wavs.half()
+        # sine_merge = self.l_tanh(self.l_linear(sine_wavs.to(x)))
+        # print(sine_wavs.dtype,self.ddtype)
+        # if sine_wavs.dtype != self.l_linear.weight.dtype:
+        sine_wavs = sine_wavs.to(dtype=self.l_linear.weight.dtype)
         sine_merge = self.l_tanh(self.l_linear(sine_wavs))
         return sine_merge, None, None  # noise, uv
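The old is_half branch (kept above as comments) is replaced by one unconditional cast: sine_wavs simply follows whatever dtype l_linear was loaded in, so the module needs no precision flag. In isolation:

import torch

l_linear = torch.nn.Linear(1, 1).half()  # stands in for a half-precision model head
sine_wavs = torch.randn(1, 200, 1)       # SineGen output arrives as float32
sine_wavs = sine_wavs.to(dtype=l_linear.weight.dtype)
print(sine_wavs.dtype)  # torch.float16 here; torch.float32 for an fp32 model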
@@ -441,7 +468,7 @@ class GeneratorNSF(torch.nn.Module):
         self.num_kernels = len(resblock_kernel_sizes)
         self.num_upsamples = len(upsample_rates)
 
-        self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(upsample_rates))
+        self.f0_upsamp = torch.nn.Upsample(scale_factor=math.prod(upsample_rates))
         self.m_source = SourceModuleHnNSF(
             sampling_rate=sr, harmonic_num=0, is_half=is_half
         )
@@ -466,7 +493,7 @@ class GeneratorNSF(torch.nn.Module):
                 )
             )
             if i + 1 < len(upsample_rates):
-                stride_f0 = np.prod(upsample_rates[i + 1 :])
+                stride_f0 = math.prod(upsample_rates[i + 1 :])
                 self.noise_convs.append(
                     Conv1d(
                         1,
@@ -493,26 +520,35 @@ class GeneratorNSF(torch.nn.Module):
         if gin_channels != 0:
             self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)
 
-        self.upp = np.prod(upsample_rates)
+        self.upp = math.prod(upsample_rates)
+
+        self.lrelu_slope = modules.LRELU_SLOPE
 
-    def forward(self, x, f0, g=None):
+    def forward(self, x, f0, g: Optional[torch.Tensor] = None):
         har_source, noi_source, uv = self.m_source(f0, self.upp)
         har_source = har_source.transpose(1, 2)
         x = self.conv_pre(x)
         if g is not None:
             x = x + self.cond(g)
-        for i in range(self.num_upsamples):
-            x = F.leaky_relu(x, modules.LRELU_SLOPE)
-            x = self.ups[i](x)
-            x_source = self.noise_convs[i](har_source)
-            x = x + x_source
-            xs = None
-            for j in range(self.num_kernels):
-                if xs is None:
-                    xs = self.resblocks[i * self.num_kernels + j](x)
-                else:
-                    xs += self.resblocks[i * self.num_kernels + j](x)
-            x = xs / self.num_kernels
+        # torch.jit.script() does not support direct indexing of torch modules
+        # That's why I wrote this
+        for i, (ups, noise_convs) in enumerate(zip(self.ups, self.noise_convs)):
+            if i < self.num_upsamples:
+                x = F.leaky_relu(x, self.lrelu_slope)
+                x = ups(x)
+                x_source = noise_convs(har_source)
+                x = x + x_source
+                xs: Optional[torch.Tensor] = None
+                l = [i * self.num_kernels + j for j in range(self.num_kernels)]
+                for j, resblock in enumerate(self.resblocks):
+                    if j in l:
+                        if xs is None:
+                            xs = resblock(x)
+                        else:
+                            xs += resblock(x)
+                # This assertion cannot be ignored! \
+                # If ignored, it will cause torch.jit.script() compilation errors
+                assert isinstance(xs, torch.Tensor)
+                x = xs / self.num_kernels
         x = F.leaky_relu(x)
         x = self.conv_post(x)
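torch.jit.script() rejects indexing a ModuleList with a loop variable (self.ups[i], self.resblocks[i * k + j]), which is what the rewritten loop works around by iterating the lists and filtering on index. The same pattern reduced to a hypothetical toy module:

import torch
from torch import nn
from typing import Optional

class Toy(nn.Module):
    # Hypothetical module mirroring the scriptable iteration pattern above.
    def __init__(self):
        super().__init__()
        self.ups = nn.ModuleList([nn.Identity() for _ in range(2)])
        self.resblocks = nn.ModuleList([nn.Identity() for _ in range(6)])
        self.num_kernels = 3

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for i, ups in enumerate(self.ups):
            x = ups(x)
            xs: Optional[torch.Tensor] = None
            wanted = [i * self.num_kernels + j for j in range(self.num_kernels)]
            for j, resblock in enumerate(self.resblocks):
                if j in wanted:
                    if xs is None:
                        xs = resblock(x)
                    else:
                        xs = xs + resblock(x)
            # The assert refines Optional[torch.Tensor] to Tensor for the script compiler.
            assert isinstance(xs, torch.Tensor)
            x = xs / self.num_kernels
        return x

torch.jit.script(Toy())  # compiles; self.resblocks[i * 3 + j] inside forward would not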
@@ -525,6 +561,27 @@ class GeneratorNSF(torch.nn.Module):
         for l in self.resblocks:
             l.remove_weight_norm()
 
+    def __prepare_scriptable__(self):
+        for l in self.ups:
+            for hook in l._forward_pre_hooks.values():
+                # The hook we want to remove is an instance of WeightNorm class, so
+                # normally we would do `if isinstance(...)` but this class is not accessible
+                # because of shadowing, so we check the module name directly.
+                # https://github.com/pytorch/pytorch/blob/be0ca00c5ce260eb5bcec3237357f7a30cc08983/torch/nn/utils/__init__.py#L3
+                if (
+                    hook.__module__ == "torch.nn.utils.weight_norm"
+                    and hook.__class__.__name__ == "WeightNorm"
+                ):
+                    torch.nn.utils.remove_weight_norm(l)
+        for l in self.resblocks:
+            for hook in l._forward_pre_hooks.values():
+                if (
+                    hook.__module__ == "torch.nn.utils.weight_norm"
+                    and hook.__class__.__name__ == "WeightNorm"
+                ):
+                    torch.nn.utils.remove_weight_norm(l)
+        return self
+
 
 sr2sr = {
     "32k": 32000,
@@ -554,11 +611,11 @@ class SynthesizerTrnMsNSFsidM(nn.Module):
         spk_embed_dim,
         gin_channels,
         sr,
-        version,
-        **kwargs,
+        encoder_dim,
+        **kwargs
     ):
-        super().__init__()
-        if type(sr) == type("strr"):
+        super(SynthesizerTrnMsNSFsidM, self).__init__()
+        if isinstance(sr, str):
             sr = sr2sr[sr]
         self.spec_channels = spec_channels
         self.inter_channels = inter_channels
@@ -567,7 +624,7 @@ class SynthesizerTrnMsNSFsidM(nn.Module):
         self.n_heads = n_heads
         self.n_layers = n_layers
         self.kernel_size = kernel_size
-        self.p_dropout = p_dropout
+        self.p_dropout = float(p_dropout)
         self.resblock = resblock
         self.resblock_kernel_sizes = resblock_kernel_sizes
         self.resblock_dilation_sizes = resblock_dilation_sizes
@@ -578,25 +635,15 @@ class SynthesizerTrnMsNSFsidM(nn.Module):
         self.gin_channels = gin_channels
         # self.hop_length = hop_length#
         self.spk_embed_dim = spk_embed_dim
-        if version == "v1":
-            self.enc_p = TextEncoder256(
-                inter_channels,
-                hidden_channels,
-                filter_channels,
-                n_heads,
-                n_layers,
-                kernel_size,
-                p_dropout,
-            )
-        else:
-            self.enc_p = TextEncoder768(
-                inter_channels,
-                hidden_channels,
-                filter_channels,
-                n_heads,
-                n_layers,
-                kernel_size,
-                p_dropout,
-            )
+        self.enc_p = TextEncoder(
+            encoder_dim,
+            inter_channels,
+            hidden_channels,
+            filter_channels,
+            n_heads,
+            n_layers,
+            kernel_size,
+            float(p_dropout),
+        )
         self.dec = GeneratorNSF(
             inter_channels,
@@ -623,9 +670,11 @@ class SynthesizerTrnMsNSFsidM(nn.Module):
             inter_channels, hidden_channels, 5, 1, 3, gin_channels=gin_channels
         )
         self.emb_g = nn.Embedding(self.spk_embed_dim, gin_channels)
+        self.speaker_map = None
         logger.debug(
-            f"gin_channels: {gin_channels}, self.spk_embed_dim: {self.spk_embed_dim}"
+            "gin_channels: "
+            + str(gin_channels)
+            + ", self.spk_embed_dim: "
+            + str(self.spk_embed_dim)
         )
 
     def remove_weight_norm(self):
@@ -633,9 +682,9 @@ class SynthesizerTrnMsNSFsidM(nn.Module):
         self.flow.remove_weight_norm()
         self.enc_q.remove_weight_norm()
 
-    def construct_spkmixmap(self, n_speaker):
-        self.speaker_map = torch.zeros((n_speaker, 1, 1, self.gin_channels))
-        for i in range(n_speaker):
+    def construct_spkmixmap(self):
+        self.speaker_map = torch.zeros((self.n_speaker, 1, 1, self.gin_channels))
+        for i in range(self.n_speaker):
             self.speaker_map[i] = self.emb_g(torch.LongTensor([[i]]))
         self.speaker_map = self.speaker_map.unsqueeze(0)
@@ -810,6 +859,11 @@ class DiscriminatorP(torch.nn.Module):
         b, c, t = x.shape
         if t % self.period != 0:  # pad first
             n_pad = self.period - (t % self.period)
-            x = F.pad(x, (0, n_pad), "reflect")
+            if has_xpu and x.dtype == torch.bfloat16:
+                x = F.pad(x.to(dtype=torch.float16), (0, n_pad), "reflect").to(
+                    dtype=torch.bfloat16
+                )
+            else:
+                x = F.pad(x, (0, n_pad), "reflect")
             t = t + n_pad
         x = x.view(b, c, t // self.period, self.period)
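The new DiscriminatorP branch works around reflect padding lacking a bfloat16 kernel on XPU builds: the tensor detours through float16 just for the pad. A shape-level sketch (fp16 reflect-pad support varies by backend and torch version, so treat this as illustrative):

import torch
import torch.nn.functional as F

x = torch.randn(1, 1, 37, dtype=torch.bfloat16)
n_pad = 3
# Cast down, pad, cast back, as in the diff above.
padded = F.pad(x.to(dtype=torch.float16), (0, n_pad), "reflect").to(dtype=torch.bfloat16)
print(padded.shape)  # torch.Size([1, 1, 40]); t grows to a multiple of the period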

View File

@@ -18,7 +18,7 @@ def export_onnx(ModelPath, ExportedPath):
     device = "cpu"  # device used at export time (does not affect how the model is used)
     net_g = SynthesizerTrnMsNSFsidM(
-        *cpt["config"], is_half=False, version=cpt.get("version", "v1")
+        *cpt["config"], is_half=False, encoder_dim=vec_channels
     )  # export in fp32; fp16 support in C++ would require manually rearranging memory, so fp16 is skipped for now
     net_g.load_state_dict(cpt["weight"], strict=False)
     input_names = ["phone", "phone_lengths", "pitch", "pitchf", "ds", "rnd"]
@@ -44,8 +44,8 @@ def export_onnx(ModelPath, ExportedPath):
             "rnd": [2],
         },
         do_constant_folding=False,
-        opset_version=13,
-        verbose=False,
+        opset_version=18,
+        verbose=True,
         input_names=input_names,
         output_names=output_names,
     )
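With the export path fixed, the resulting graph can be driven from onnxruntime. A hedged sketch using the input_names declared above; shapes mirror the test tensors in the sibling __main__ script, 192 for "rnd" assumes the default inter_channels, and "model.onnx" is a placeholder path:

import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession("model.onnx", providers=["CPUExecutionProvider"])
audio = sess.run(
    None,
    {
        "phone": np.random.rand(1, 200, 256).astype(np.float32),        # hidden units
        "phone_lengths": np.array([200], dtype=np.int64),
        "pitch": np.random.randint(5, 255, (1, 200)).astype(np.int64),  # coarse f0
        "pitchf": np.random.rand(1, 200).astype(np.float32),            # nsf f0 (Hz)
        "ds": np.array([0], dtype=np.int64),                            # speaker id
        "rnd": np.random.rand(1, 192, 200).astype(np.float32),          # sampling noise
    },
)[0]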

View File

@@ -6,12 +6,12 @@ if __name__ == "__main__":
     ModelPath = "Shiroha/shiroha.pth"  # model path
     ExportedPath = "model.onnx"  # output path
-    hidden_channels = 256  # hidden_channels, in preparation for 768-dim vec features
+    encoder_dim = 256  # encoder_dim
     cpt = torch.load(ModelPath, map_location="cpu")
     cpt["config"][-3] = cpt["weight"]["emb_g.weight"].shape[0]  # n_spk
     print(*cpt["config"])
-    test_phone = torch.rand(1, 200, hidden_channels)  # hidden unit
+    test_phone = torch.rand(1, 200, encoder_dim)  # hidden unit
     test_phone_lengths = torch.tensor([200]).long()  # hidden unit length (seemingly unused)
     test_pitch = torch.randint(size=(1, 200), low=5, high=255)  # f0 (in Hz)
     test_pitchf = torch.rand(1, 200)  # nsf f0
@@ -21,7 +21,7 @@ if __name__ == "__main__":
     device = "cpu"  # device used at export time (does not affect how the model is used)
     net_g = SynthesizerTrnMsNSFsidM(
-        *cpt["config"], is_half=False
+        *cpt["config"], is_half=False, encoder_dim=encoder_dim
     )  # export in fp32; fp16 support in C++ would require manually rearranging memory, so fp16 is skipped for now
     net_g.load_state_dict(cpt["weight"], strict=False)
     input_names = ["phone", "phone_lengths", "pitch", "pitchf", "ds", "rnd"]
@@ -47,7 +47,7 @@ if __name__ == "__main__":
             "rnd": [2],
         },
         do_constant_folding=False,
-        opset_version=16,
+        opset_version=18,
         verbose=False,
         input_names=input_names,
         output_names=output_names,
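Both exporters now pin opset_version=18 (previously 13 in the function above and 16 here). A quick post-export check that the file really carries that opset, assuming the onnx package is installed:

import onnx

m = onnx.load("model.onnx")
onnx.checker.check_model(m)
print({imp.domain or "ai.onnx": imp.version for imp in m.opset_import})  # expect 18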