From 597e0a97d015561d304ceab522d608e2ee09174d Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=CE=9D=CE=B1=CF=81=CE=BF=CF=85=CF=83=CE=AD=C2=B7=CE=BC?=
 =?UTF-8?q?=C2=B7=CE=B3=CE=B9=CE=BF=CF=85=CE=BC=CE=B5=CE=BC=CE=AF=C2=B7?=
 =?UTF-8?q?=CE=A7=CE=B9=CE=BD=CE=B1=CE=BA=CE=AC=CE=BD=CE=BD=CE=B1?=
 <40709280+NaruseMioShirakana@users.noreply.github.com>
Date: Sun, 14 Apr 2024 13:55:15 +0800
Subject: [PATCH] Fix ONNX export (#1963)

* Add files via upload

* Add files via upload

* Add files via upload

---
 infer/lib/infer_pack/models_onnx.py | 278 +++++++++++++++++-----------
 infer/modules/onnx/export.py        |   6 +-
 tools/export_onnx.py                |   8 +-
 3 files changed, 173 insertions(+), 119 deletions(-)

diff --git a/infer/lib/infer_pack/models_onnx.py b/infer/lib/infer_pack/models_onnx.py
index a6d321f..f1b846a 100644
--- a/infer/lib/infer_pack/models_onnx.py
+++ b/infer/lib/infer_pack/models_onnx.py
@@ -1,5 +1,6 @@
 import math
 import logging
+from typing import Optional
 
 logger = logging.getLogger(__name__)
 
@@ -13,10 +14,13 @@ from torch.nn.utils import remove_weight_norm, spectral_norm, weight_norm
 from infer.lib.infer_pack import attentions, commons, modules
 from infer.lib.infer_pack.commons import get_padding, init_weights
 
+has_xpu = bool(hasattr(torch, "xpu") and torch.xpu.is_available())
 
-class TextEncoder256(nn.Module):
+
+class TextEncoder(nn.Module):
     def __init__(
         self,
+        in_channels,
         out_channels,
         hidden_channels,
         filter_channels,
@@ -26,25 +30,36 @@ class TextEncoder256(nn.Module):
         p_dropout,
         f0=True,
     ):
-        super().__init__()
+        super(TextEncoder, self).__init__()
         self.out_channels = out_channels
         self.hidden_channels = hidden_channels
         self.filter_channels = filter_channels
         self.n_heads = n_heads
         self.n_layers = n_layers
         self.kernel_size = kernel_size
-        self.p_dropout = p_dropout
-        self.emb_phone = nn.Linear(256, hidden_channels)
+        self.p_dropout = float(p_dropout)
+        self.emb_phone = nn.Linear(in_channels, hidden_channels)
         self.lrelu = nn.LeakyReLU(0.1, inplace=True)
         if f0 == True:
             self.emb_pitch = nn.Embedding(256, hidden_channels)  # pitch 256
         self.encoder = attentions.Encoder(
-            hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
+            hidden_channels,
+            filter_channels,
+            n_heads,
+            n_layers,
+            kernel_size,
+            float(p_dropout),
         )
         self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
 
-    def forward(self, phone, pitch, lengths):
-        if pitch == None:
+    def forward(
+        self,
+        phone: torch.Tensor,
+        pitch: torch.Tensor,
+        lengths: torch.Tensor,
+        skip_head: Optional[torch.Tensor] = None,
+    ):
+        if pitch is None:
             x = self.emb_phone(phone)
         else:
             x = self.emb_phone(phone) + self.emb_pitch(pitch)
@@ -56,54 +71,6 @@ class TextEncoder256(nn.Module):
         )
         x = self.encoder(x * x_mask, x_mask)
         stats = self.proj(x) * x_mask
-
-        m, logs = torch.split(stats, self.out_channels, dim=1)
-        return m, logs, x_mask
-
-
-class TextEncoder768(nn.Module):
-    def __init__(
-        self,
-        out_channels,
-        hidden_channels,
-        filter_channels,
-        n_heads,
-        n_layers,
-        kernel_size,
-        p_dropout,
-        f0=True,
-    ):
-        super().__init__()
-        self.out_channels = out_channels
-        self.hidden_channels = hidden_channels
-        self.filter_channels = filter_channels
-        self.n_heads = n_heads
-        self.n_layers = n_layers
-        self.kernel_size = kernel_size
-        self.p_dropout = p_dropout
-        self.emb_phone = nn.Linear(768, hidden_channels)
-        self.lrelu = nn.LeakyReLU(0.1, inplace=True)
-        if f0 == True:
-            self.emb_pitch = nn.Embedding(256, hidden_channels)  # pitch 256
-        self.encoder = attentions.Encoder(
-            hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
-        )
-        self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
-
-    def forward(self, phone, pitch, lengths):
-        if pitch == None:
-            x = self.emb_phone(phone)
-        else:
-            x = self.emb_phone(phone) + self.emb_pitch(pitch)
-        x = x * math.sqrt(self.hidden_channels)  # [b, t, h]
-        x = self.lrelu(x)
-        x = torch.transpose(x, 1, -1)  # [b, h, t]
-        x_mask = torch.unsqueeze(commons.sequence_mask(lengths, x.size(2)), 1).to(
-            x.dtype
-        )
-        x = self.encoder(x * x_mask, x_mask)
-        stats = self.proj(x) * x_mask
-
         m, logs = torch.split(stats, self.out_channels, dim=1)
         return m, logs, x_mask
 
@@ -119,7 +86,7 @@ class ResidualCouplingBlock(nn.Module):
         n_flows=4,
         gin_channels=0,
     ):
-        super().__init__()
+        super(ResidualCouplingBlock, self).__init__()
         self.channels = channels
         self.hidden_channels = hidden_channels
         self.kernel_size = kernel_size
@@ -143,19 +110,36 @@ class ResidualCouplingBlock(nn.Module):
             )
             self.flows.append(modules.Flip())
 
-    def forward(self, x, x_mask, g=None, reverse=False):
+    def forward(
+        self,
+        x: torch.Tensor,
+        x_mask: torch.Tensor,
+        g: Optional[torch.Tensor] = None,
+        reverse: bool = False,
+    ):
         if not reverse:
             for flow in self.flows:
                 x, _ = flow(x, x_mask, g=g, reverse=reverse)
         else:
             for flow in reversed(self.flows):
-                x, _ = flow(x, x_mask, g=g, reverse=reverse)
+                x, _ = flow.forward(x, x_mask, g=g, reverse=reverse)
         return x
 
     def remove_weight_norm(self):
         for i in range(self.n_flows):
             self.flows[i * 2].remove_weight_norm()
 
+    def __prepare_scriptable__(self):
+        for i in range(self.n_flows):
+            for hook in self.flows[i * 2]._forward_pre_hooks.values():
+                if (
+                    hook.__module__ == "torch.nn.utils.weight_norm"
+                    and hook.__class__.__name__ == "WeightNorm"
+                ):
+                    torch.nn.utils.remove_weight_norm(self.flows[i * 2])
+
+        return self
+
 
 class PosteriorEncoder(nn.Module):
     def __init__(
@@ -168,7 +152,7 @@ class PosteriorEncoder(nn.Module):
         n_layers,
         gin_channels=0,
     ):
-        super().__init__()
+        super(PosteriorEncoder, self).__init__()
         self.in_channels = in_channels
         self.out_channels = out_channels
         self.hidden_channels = hidden_channels
@@ -187,7 +171,9 @@ class PosteriorEncoder(nn.Module):
         )
         self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
 
-    def forward(self, x, x_lengths, g=None):
+    def forward(
+        self, x: torch.Tensor, x_lengths: torch.Tensor, g: Optional[torch.Tensor] = None
+    ):
         x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(
             x.dtype
         )
@@ -201,6 +187,15 @@ class PosteriorEncoder(nn.Module):
     def remove_weight_norm(self):
         self.enc.remove_weight_norm()
 
+    def __prepare_scriptable__(self):
+        for hook in self.enc._forward_pre_hooks.values():
+            if (
+                hook.__module__ == "torch.nn.utils.weight_norm"
+                and hook.__class__.__name__ == "WeightNorm"
+            ):
+                torch.nn.utils.remove_weight_norm(self.enc)
+        return self
+
 
 class Generator(torch.nn.Module):
     def __init__(
@@ -250,7 +245,7 @@ class Generator(torch.nn.Module):
         if gin_channels != 0:
             self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)
 
-    def forward(self, x, g=None):
+    def forward(self, x: torch.Tensor, g: Optional[torch.Tensor] = None):
         x = self.conv_pre(x)
         if g is not None:
             x = x + self.cond(g)
@@ -271,6 +266,28 @@ class Generator(torch.nn.Module):
 
         return x
 
+    def __prepare_scriptable__(self):
+        for l in self.ups:
+            for hook in l._forward_pre_hooks.values():
+                # The hook we want to remove is an instance of WeightNorm class, so
+                # normally we would do `if isinstance(...)` but this class is not accessible
+                # because of shadowing, so we check the module name directly.
+                # https://github.com/pytorch/pytorch/blob/be0ca00c5ce260eb5bcec3237357f7a30cc08983/torch/nn/utils/__init__.py#L3
+                if (
+                    hook.__module__ == "torch.nn.utils.weight_norm"
+                    and hook.__class__.__name__ == "WeightNorm"
+                ):
+                    torch.nn.utils.remove_weight_norm(l)
+
+        for l in self.resblocks:
+            for hook in l._forward_pre_hooks.values():
+                if (
+                    hook.__module__ == "torch.nn.utils.weight_norm"
+                    and hook.__class__.__name__ == "WeightNorm"
+                ):
+                    torch.nn.utils.remove_weight_norm(l)
+        return self
+
     def remove_weight_norm(self):
         for l in self.ups:
             remove_weight_norm(l)
@@ -291,7 +308,7 @@ class SineGen(torch.nn.Module):
         voiced_threshold: F0 threshold for U/V classification (default 0)
         flag_for_pulse: this SineGen is used inside PulseGen (default False)
     Note: when flag_for_pulse is True, the first time step of a voiced
-        segment is always sin(np.pi) or cos(0)
+        segment is always sin(torch.pi) or cos(0)
     """
 
     def __init__(
@@ -315,9 +332,11 @@ class SineGen(torch.nn.Module):
         # generate uv signal
        uv = torch.ones_like(f0)
         uv = uv * (f0 > self.voiced_threshold)
+        if uv.device.type == "privateuseone":  # for DirectML
+            uv = uv.float()
         return uv
 
-    def forward(self, f0, upp):
+    def forward(self, f0: torch.Tensor, upp: int):
         """sine_tensor, uv = forward(f0)
         input F0: tensor(batchsize=1, length, dim=1)
                   f0 for unvoiced steps should be 0
@@ -329,7 +348,7 @@ class SineGen(torch.nn.Module):
             f0_buf = torch.zeros(f0.shape[0], f0.shape[1], self.dim, device=f0.device)
             # fundamental component
             f0_buf[:, :, 0] = f0[:, :, 0]
-            for idx in np.arange(self.harmonic_num):
+            for idx in range(self.harmonic_num):
                 f0_buf[:, :, idx + 1] = f0_buf[:, :, 0] * (
                     idx + 2
                 )  # idx + 2: the (idx+1)-th overtone, (idx+2)-th harmonic
@@ -347,12 +366,12 @@ class SineGen(torch.nn.Module):
             tmp_over_one *= upp
             tmp_over_one = F.interpolate(
                 tmp_over_one.transpose(2, 1),
-                scale_factor=upp,
+                scale_factor=float(upp),
                 mode="linear",
                 align_corners=True,
             ).transpose(2, 1)
             rad_values = F.interpolate(
-                rad_values.transpose(2, 1), scale_factor=upp, mode="nearest"
+                rad_values.transpose(2, 1), scale_factor=float(upp), mode="nearest"
             ).transpose(
                 2, 1
             )  #######
@@ -361,12 +380,12 @@
             cumsum_shift = torch.zeros_like(rad_values)
             cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0
             sine_waves = torch.sin(
-                torch.cumsum(rad_values + cumsum_shift, dim=1) * 2 * np.pi
+                torch.cumsum(rad_values + cumsum_shift, dim=1) * 2 * torch.pi
             )
             sine_waves = sine_waves * self.sine_amp
             uv = self._f02uv(f0)
             uv = F.interpolate(
-                uv.transpose(2, 1), scale_factor=upp, mode="nearest"
+                uv.transpose(2, 1), scale_factor=float(upp), mode="nearest"
             ).transpose(2, 1)
             noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
             noise = noise_amp * torch.randn_like(sine_waves)
@@ -414,11 +433,19 @@ class SourceModuleHnNSF(torch.nn.Module):
         # to merge source harmonics into a single excitation
         self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
         self.l_tanh = torch.nn.Tanh()
+        # self.ddtype:int = -1
 
-    def forward(self, x, upp=None):
+    def forward(self, x: torch.Tensor, upp: int = 1):
+        # if self.ddtype ==-1:
+        #     self.ddtype = self.l_linear.weight.dtype
         sine_wavs, uv, _ = self.l_sin_gen(x, upp)
-        if self.is_half:
-            sine_wavs = sine_wavs.half()
+        # print(x.dtype,sine_wavs.dtype,self.l_linear.weight.dtype)
+        # if self.is_half:
+        #     sine_wavs = sine_wavs.half()
+        # sine_merge = self.l_tanh(self.l_linear(sine_wavs.to(x)))
+        # print(sine_wavs.dtype,self.ddtype)
+        # if sine_wavs.dtype != self.l_linear.weight.dtype:
+        sine_wavs = sine_wavs.to(dtype=self.l_linear.weight.dtype)
         sine_merge = self.l_tanh(self.l_linear(sine_wavs))
         return sine_merge, None, None  # noise, uv
@@ -441,7 +468,7 @@ class GeneratorNSF(torch.nn.Module):
         self.num_kernels = len(resblock_kernel_sizes)
         self.num_upsamples = len(upsample_rates)
 
-        self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(upsample_rates))
+        self.f0_upsamp = torch.nn.Upsample(scale_factor=math.prod(upsample_rates))
         self.m_source = SourceModuleHnNSF(
             sampling_rate=sr, harmonic_num=0, is_half=is_half
         )
@@ -466,7 +493,7 @@ class GeneratorNSF(torch.nn.Module):
                 )
             )
             if i + 1 < len(upsample_rates):
-                stride_f0 = np.prod(upsample_rates[i + 1 :])
+                stride_f0 = math.prod(upsample_rates[i + 1 :])
                 self.noise_convs.append(
                     Conv1d(
                         1,
@@ -493,27 +520,36 @@ class GeneratorNSF(torch.nn.Module):
         if gin_channels != 0:
             self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)
-        self.upp = np.prod(upsample_rates)
+        self.upp = math.prod(upsample_rates)
 
-    def forward(self, x, f0, g=None):
+        self.lrelu_slope = modules.LRELU_SLOPE
+
+    def forward(self, x, f0, g: Optional[torch.Tensor] = None):
         har_source, noi_source, uv = self.m_source(f0, self.upp)
         har_source = har_source.transpose(1, 2)
         x = self.conv_pre(x)
         if g is not None:
             x = x + self.cond(g)
-
-        for i in range(self.num_upsamples):
-            x = F.leaky_relu(x, modules.LRELU_SLOPE)
-            x = self.ups[i](x)
-            x_source = self.noise_convs[i](har_source)
-            x = x + x_source
-            xs = None
-            for j in range(self.num_kernels):
-                if xs is None:
-                    xs = self.resblocks[i * self.num_kernels + j](x)
-                else:
-                    xs += self.resblocks[i * self.num_kernels + j](x)
-            x = xs / self.num_kernels
+        # torch.jit.script() does not support direct indexing of torch modules,
+        # so iterate over the ModuleLists and match the flat index instead
+        for i, (ups, noise_convs) in enumerate(zip(self.ups, self.noise_convs)):
+            if i < self.num_upsamples:
+                x = F.leaky_relu(x, self.lrelu_slope)
+                x = ups(x)
+                x_source = noise_convs(har_source)
+                x = x + x_source
+                xs: Optional[torch.Tensor] = None
+                l = [i * self.num_kernels + j for j in range(self.num_kernels)]
+                for j, resblock in enumerate(self.resblocks):
+                    if j in l:
+                        if xs is None:
+                            xs = resblock(x)
+                        else:
+                            xs += resblock(x)
+                # This assertion cannot be ignored!
+                # If ignored, it will cause torch.jit.script() compilation errors
+                assert isinstance(xs, torch.Tensor)
+                x = xs / self.num_kernels
         x = F.leaky_relu(x)
         x = self.conv_post(x)
         x = torch.tanh(x)
@@ -525,6 +561,27 @@ class GeneratorNSF(torch.nn.Module):
         for l in self.resblocks:
             l.remove_weight_norm()
 
+    def __prepare_scriptable__(self):
+        for l in self.ups:
+            for hook in l._forward_pre_hooks.values():
+                # The hook we want to remove is an instance of WeightNorm class, so
+                # normally we would do `if isinstance(...)` but this class is not accessible
+                # because of shadowing, so we check the module name directly.
+                # https://github.com/pytorch/pytorch/blob/be0ca00c5ce260eb5bcec3237357f7a30cc08983/torch/nn/utils/__init__.py#L3
+                if (
+                    hook.__module__ == "torch.nn.utils.weight_norm"
+                    and hook.__class__.__name__ == "WeightNorm"
+                ):
+                    torch.nn.utils.remove_weight_norm(l)
+        for l in self.resblocks:
+            for hook in l._forward_pre_hooks.values():
+                if (
+                    hook.__module__ == "torch.nn.utils.weight_norm"
+                    and hook.__class__.__name__ == "WeightNorm"
+                ):
+                    torch.nn.utils.remove_weight_norm(l)
+        return self
+
 
 sr2sr = {
     "32k": 32000,
@@ -554,11 +611,11 @@ class SynthesizerTrnMsNSFsidM(nn.Module):
         spk_embed_dim,
         gin_channels,
         sr,
-        version,
-        **kwargs,
+        encoder_dim,
+        **kwargs
     ):
-        super().__init__()
-        if type(sr) == type("strr"):
+        super(SynthesizerTrnMsNSFsidM, self).__init__()
+        if isinstance(sr, str):
             sr = sr2sr[sr]
         self.spec_channels = spec_channels
         self.inter_channels = inter_channels
@@ -567,7 +624,7 @@ class SynthesizerTrnMsNSFsidM(nn.Module):
         self.n_heads = n_heads
         self.n_layers = n_layers
         self.kernel_size = kernel_size
-        self.p_dropout = p_dropout
+        self.p_dropout = float(p_dropout)
         self.resblock = resblock
         self.resblock_kernel_sizes = resblock_kernel_sizes
         self.resblock_dilation_sizes = resblock_dilation_sizes
@@ -578,25 +635,15 @@ class SynthesizerTrnMsNSFsidM(nn.Module):
         self.gin_channels = gin_channels
         # self.hop_length = hop_length#
         self.spk_embed_dim = spk_embed_dim
-        if version == "v1":
-            self.enc_p = TextEncoder256(
+        self.enc_p = TextEncoder(
+            encoder_dim,
             inter_channels,
             hidden_channels,
             filter_channels,
             n_heads,
             n_layers,
             kernel_size,
-            p_dropout,
-        )
-        else:
-            self.enc_p = TextEncoder768(
-                inter_channels,
-                hidden_channels,
-                filter_channels,
-                n_heads,
-                n_layers,
-                kernel_size,
-                p_dropout,
+            float(p_dropout),
         )
         self.dec = GeneratorNSF(
             inter_channels,
@@ -623,9 +670,11 @@ class SynthesizerTrnMsNSFsidM(nn.Module):
             inter_channels, hidden_channels, 5, 1, 3, gin_channels=gin_channels
         )
         self.emb_g = nn.Embedding(self.spk_embed_dim, gin_channels)
-        self.speaker_map = None
         logger.debug(
-            f"gin_channels: {gin_channels}, self.spk_embed_dim: {self.spk_embed_dim}"
+            "gin_channels: "
+            + str(gin_channels)
+            + ", self.spk_embed_dim: "
+            + str(self.spk_embed_dim)
         )
 
     def remove_weight_norm(self):
@@ -633,9 +682,9 @@ class SynthesizerTrnMsNSFsidM(nn.Module):
         self.flow.remove_weight_norm()
         self.enc_q.remove_weight_norm()
 
-    def construct_spkmixmap(self, n_speaker):
-        self.speaker_map = torch.zeros((n_speaker, 1, 1, self.gin_channels))
-        for i in range(n_speaker):
+    def construct_spkmixmap(self):
+        self.speaker_map = torch.zeros((self.n_speaker, 1, 1, self.gin_channels))
+        for i in range(self.n_speaker):
             self.speaker_map[i] = self.emb_g(torch.LongTensor([[i]]))
         self.speaker_map = self.speaker_map.unsqueeze(0)
@@ -810,7 +859,12 @@ class DiscriminatorP(torch.nn.Module):
         b, c, t = x.shape
         if t % self.period != 0:  # pad first
             n_pad = self.period - (t % self.period)
-            x = F.pad(x, (0, n_pad), "reflect")
+            if has_xpu and x.dtype == torch.bfloat16:
+                x = F.pad(x.to(dtype=torch.float16), (0, n_pad), "reflect").to(
+                    dtype=torch.bfloat16
+                )
+            else:
+                x = F.pad(x, (0, n_pad), "reflect")
             t = t + n_pad
         x = x.view(b, c, t // self.period, self.period)
diff --git a/infer/modules/onnx/export.py b/infer/modules/onnx/export.py
index ed4a416..7c95ae7 100644
--- a/infer/modules/onnx/export.py
+++ b/infer/modules/onnx/export.py
@@ -18,7 +18,7 @@ def export_onnx(ModelPath, ExportedPath):
     device = "cpu"  # device used for export (does not affect model use)
 
     net_g = SynthesizerTrnMsNSFsidM(
-        *cpt["config"], is_half=False, version=cpt.get("version", "v1")
+        *cpt["config"], is_half=False, encoder_dim=vec_channels
    )  # fp32 export (C++ fp16 support requires manual memory re-layout, so fp16 is skipped for now)
     net_g.load_state_dict(cpt["weight"], strict=False)
     input_names = ["phone", "phone_lengths", "pitch", "pitchf", "ds", "rnd"]
@@ -44,8 +44,8 @@ def export_onnx(ModelPath, ExportedPath):
             "rnd": [2],
         },
         do_constant_folding=False,
-        opset_version=13,
-        verbose=False,
+        opset_version=18,
+        verbose=True,
         input_names=input_names,
         output_names=output_names,
     )
diff --git a/tools/export_onnx.py b/tools/export_onnx.py
index 822e09e..9380aa7 100644
--- a/tools/export_onnx.py
+++ b/tools/export_onnx.py
@@ -6,12 +6,12 @@ if __name__ == "__main__":
     ModelPath = "Shiroha/shiroha.pth"  # model path
     ExportedPath = "model.onnx"  # output path
-    hidden_channels = 256  # hidden_channels, in preparation for the 768-dim Vec
+    encoder_dim = 256  # encoder_dim
     cpt = torch.load(ModelPath, map_location="cpu")
     cpt["config"][-3] = cpt["weight"]["emb_g.weight"].shape[0]  # n_spk
     print(*cpt["config"])
 
-    test_phone = torch.rand(1, 200, hidden_channels)  # hidden unit
+    test_phone = torch.rand(1, 200, encoder_dim)  # hidden unit
     test_phone_lengths = torch.tensor([200]).long()  # hidden unit length (seemingly unused)
     test_pitch = torch.randint(size=(1, 200), low=5, high=255)  # F0 (in Hz)
     test_pitchf = torch.rand(1, 200)  # NSF F0
@@ -21,7 +21,7 @@ if __name__ == "__main__":
     device = "cpu"  # device used for export (does not affect model use)
 
     net_g = SynthesizerTrnMsNSFsidM(
-        *cpt["config"], is_half=False
+        *cpt["config"], is_half=False, encoder_dim=encoder_dim
    )  # fp32 export (C++ fp16 support requires manual memory re-layout, so fp16 is skipped for now)
     net_g.load_state_dict(cpt["weight"], strict=False)
     input_names = ["phone", "phone_lengths", "pitch", "pitchf", "ds", "rnd"]
@@ -47,7 +47,7 @@ if __name__ == "__main__":
             "rnd": [2],
         },
         do_constant_folding=False,
-        opset_version=16,
+        opset_version=18,
         verbose=False,
         input_names=input_names,
         output_names=output_names,
     )
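
Usage note: SynthesizerTrnMsNSFsidM now takes the input feature dimension
directly as `encoder_dim` instead of a `version` string. A minimal caller
sketch, assuming the v1 -> 256 / v2 -> 768 "Vec" feature-dimension convention
implied by the removed TextEncoder256/TextEncoder768 classes (`vec_channels`
mirrors the name referenced in infer/modules/onnx/export.py; the checkpoint
path is illustrative):

    import torch

    from infer.lib.infer_pack.models_onnx import SynthesizerTrnMsNSFsidM

    cpt = torch.load("model.pth", map_location="cpu")
    # v1 checkpoints carry 256-dim features, v2 checkpoints carry 768-dim features
    vec_channels = 256 if cpt.get("version", "v1") == "v1" else 768
    net_g = SynthesizerTrnMsNSFsidM(
        *cpt["config"], is_half=False, encoder_dim=vec_channels
    )
    net_g.load_state_dict(cpt["weight"], strict=False)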