mirror of
https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI.git
synced 2025-05-06 20:01:37 +08:00
FIx Onnx export (#1963)
* Add files via upload * Add files via upload * Add files via upload
This commit is contained in:
parent
9cae044bf2
commit
597e0a97d0
@ -1,5 +1,6 @@
|
||||
import math
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -13,10 +14,13 @@ from torch.nn.utils import remove_weight_norm, spectral_norm, weight_norm
|
||||
from infer.lib.infer_pack import attentions, commons, modules
|
||||
from infer.lib.infer_pack.commons import get_padding, init_weights
|
||||
|
||||
has_xpu = bool(hasattr(torch, "xpu") and torch.xpu.is_available())
|
||||
|
||||
class TextEncoder256(nn.Module):
|
||||
|
||||
class TextEncoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
hidden_channels,
|
||||
filter_channels,
|
||||
@ -26,25 +30,36 @@ class TextEncoder256(nn.Module):
|
||||
p_dropout,
|
||||
f0=True,
|
||||
):
|
||||
super().__init__()
|
||||
super(TextEncoder, self).__init__()
|
||||
self.out_channels = out_channels
|
||||
self.hidden_channels = hidden_channels
|
||||
self.filter_channels = filter_channels
|
||||
self.n_heads = n_heads
|
||||
self.n_layers = n_layers
|
||||
self.kernel_size = kernel_size
|
||||
self.p_dropout = p_dropout
|
||||
self.emb_phone = nn.Linear(256, hidden_channels)
|
||||
self.p_dropout = float(p_dropout)
|
||||
self.emb_phone = nn.Linear(in_channels, hidden_channels)
|
||||
self.lrelu = nn.LeakyReLU(0.1, inplace=True)
|
||||
if f0 == True:
|
||||
self.emb_pitch = nn.Embedding(256, hidden_channels) # pitch 256
|
||||
self.encoder = attentions.Encoder(
|
||||
hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
|
||||
hidden_channels,
|
||||
filter_channels,
|
||||
n_heads,
|
||||
n_layers,
|
||||
kernel_size,
|
||||
float(p_dropout),
|
||||
)
|
||||
self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
|
||||
|
||||
def forward(self, phone, pitch, lengths):
|
||||
if pitch == None:
|
||||
def forward(
|
||||
self,
|
||||
phone: torch.Tensor,
|
||||
pitch: torch.Tensor,
|
||||
lengths: torch.Tensor,
|
||||
skip_head: Optional[torch.Tensor] = None,
|
||||
):
|
||||
if pitch is None:
|
||||
x = self.emb_phone(phone)
|
||||
else:
|
||||
x = self.emb_phone(phone) + self.emb_pitch(pitch)
|
||||
@ -56,54 +71,6 @@ class TextEncoder256(nn.Module):
|
||||
)
|
||||
x = self.encoder(x * x_mask, x_mask)
|
||||
stats = self.proj(x) * x_mask
|
||||
|
||||
m, logs = torch.split(stats, self.out_channels, dim=1)
|
||||
return m, logs, x_mask
|
||||
|
||||
|
||||
class TextEncoder768(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
out_channels,
|
||||
hidden_channels,
|
||||
filter_channels,
|
||||
n_heads,
|
||||
n_layers,
|
||||
kernel_size,
|
||||
p_dropout,
|
||||
f0=True,
|
||||
):
|
||||
super().__init__()
|
||||
self.out_channels = out_channels
|
||||
self.hidden_channels = hidden_channels
|
||||
self.filter_channels = filter_channels
|
||||
self.n_heads = n_heads
|
||||
self.n_layers = n_layers
|
||||
self.kernel_size = kernel_size
|
||||
self.p_dropout = p_dropout
|
||||
self.emb_phone = nn.Linear(768, hidden_channels)
|
||||
self.lrelu = nn.LeakyReLU(0.1, inplace=True)
|
||||
if f0 == True:
|
||||
self.emb_pitch = nn.Embedding(256, hidden_channels) # pitch 256
|
||||
self.encoder = attentions.Encoder(
|
||||
hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
|
||||
)
|
||||
self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
|
||||
|
||||
def forward(self, phone, pitch, lengths):
|
||||
if pitch == None:
|
||||
x = self.emb_phone(phone)
|
||||
else:
|
||||
x = self.emb_phone(phone) + self.emb_pitch(pitch)
|
||||
x = x * math.sqrt(self.hidden_channels) # [b, t, h]
|
||||
x = self.lrelu(x)
|
||||
x = torch.transpose(x, 1, -1) # [b, h, t]
|
||||
x_mask = torch.unsqueeze(commons.sequence_mask(lengths, x.size(2)), 1).to(
|
||||
x.dtype
|
||||
)
|
||||
x = self.encoder(x * x_mask, x_mask)
|
||||
stats = self.proj(x) * x_mask
|
||||
|
||||
m, logs = torch.split(stats, self.out_channels, dim=1)
|
||||
return m, logs, x_mask
|
||||
|
||||
@ -119,7 +86,7 @@ class ResidualCouplingBlock(nn.Module):
|
||||
n_flows=4,
|
||||
gin_channels=0,
|
||||
):
|
||||
super().__init__()
|
||||
super(ResidualCouplingBlock, self).__init__()
|
||||
self.channels = channels
|
||||
self.hidden_channels = hidden_channels
|
||||
self.kernel_size = kernel_size
|
||||
@ -143,19 +110,36 @@ class ResidualCouplingBlock(nn.Module):
|
||||
)
|
||||
self.flows.append(modules.Flip())
|
||||
|
||||
def forward(self, x, x_mask, g=None, reverse=False):
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
x_mask: torch.Tensor,
|
||||
g: Optional[torch.Tensor] = None,
|
||||
reverse: bool = False,
|
||||
):
|
||||
if not reverse:
|
||||
for flow in self.flows:
|
||||
x, _ = flow(x, x_mask, g=g, reverse=reverse)
|
||||
else:
|
||||
for flow in reversed(self.flows):
|
||||
x, _ = flow(x, x_mask, g=g, reverse=reverse)
|
||||
x, _ = flow.forward(x, x_mask, g=g, reverse=reverse)
|
||||
return x
|
||||
|
||||
def remove_weight_norm(self):
|
||||
for i in range(self.n_flows):
|
||||
self.flows[i * 2].remove_weight_norm()
|
||||
|
||||
def __prepare_scriptable__(self):
|
||||
for i in range(self.n_flows):
|
||||
for hook in self.flows[i * 2]._forward_pre_hooks.values():
|
||||
if (
|
||||
hook.__module__ == "torch.nn.utils.weight_norm"
|
||||
and hook.__class__.__name__ == "WeightNorm"
|
||||
):
|
||||
torch.nn.utils.remove_weight_norm(self.flows[i * 2])
|
||||
|
||||
return self
|
||||
|
||||
|
||||
class PosteriorEncoder(nn.Module):
|
||||
def __init__(
|
||||
@ -168,7 +152,7 @@ class PosteriorEncoder(nn.Module):
|
||||
n_layers,
|
||||
gin_channels=0,
|
||||
):
|
||||
super().__init__()
|
||||
super(PosteriorEncoder, self).__init__()
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
self.hidden_channels = hidden_channels
|
||||
@ -187,7 +171,9 @@ class PosteriorEncoder(nn.Module):
|
||||
)
|
||||
self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
|
||||
|
||||
def forward(self, x, x_lengths, g=None):
|
||||
def forward(
|
||||
self, x: torch.Tensor, x_lengths: torch.Tensor, g: Optional[torch.Tensor] = None
|
||||
):
|
||||
x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(
|
||||
x.dtype
|
||||
)
|
||||
@ -201,6 +187,15 @@ class PosteriorEncoder(nn.Module):
|
||||
def remove_weight_norm(self):
|
||||
self.enc.remove_weight_norm()
|
||||
|
||||
def __prepare_scriptable__(self):
|
||||
for hook in self.enc._forward_pre_hooks.values():
|
||||
if (
|
||||
hook.__module__ == "torch.nn.utils.weight_norm"
|
||||
and hook.__class__.__name__ == "WeightNorm"
|
||||
):
|
||||
torch.nn.utils.remove_weight_norm(self.enc)
|
||||
return self
|
||||
|
||||
|
||||
class Generator(torch.nn.Module):
|
||||
def __init__(
|
||||
@ -250,7 +245,7 @@ class Generator(torch.nn.Module):
|
||||
if gin_channels != 0:
|
||||
self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)
|
||||
|
||||
def forward(self, x, g=None):
|
||||
def forward(self, x: torch.Tensor, g: Optional[torch.Tensor] = None):
|
||||
x = self.conv_pre(x)
|
||||
if g is not None:
|
||||
x = x + self.cond(g)
|
||||
@ -271,6 +266,28 @@ class Generator(torch.nn.Module):
|
||||
|
||||
return x
|
||||
|
||||
def __prepare_scriptable__(self):
|
||||
for l in self.ups:
|
||||
for hook in l._forward_pre_hooks.values():
|
||||
# The hook we want to remove is an instance of WeightNorm class, so
|
||||
# normally we would do `if isinstance(...)` but this class is not accessible
|
||||
# because of shadowing, so we check the module name directly.
|
||||
# https://github.com/pytorch/pytorch/blob/be0ca00c5ce260eb5bcec3237357f7a30cc08983/torch/nn/utils/__init__.py#L3
|
||||
if (
|
||||
hook.__module__ == "torch.nn.utils.weight_norm"
|
||||
and hook.__class__.__name__ == "WeightNorm"
|
||||
):
|
||||
torch.nn.utils.remove_weight_norm(l)
|
||||
|
||||
for l in self.resblocks:
|
||||
for hook in l._forward_pre_hooks.values():
|
||||
if (
|
||||
hook.__module__ == "torch.nn.utils.weight_norm"
|
||||
and hook.__class__.__name__ == "WeightNorm"
|
||||
):
|
||||
torch.nn.utils.remove_weight_norm(l)
|
||||
return self
|
||||
|
||||
def remove_weight_norm(self):
|
||||
for l in self.ups:
|
||||
remove_weight_norm(l)
|
||||
@ -291,7 +308,7 @@ class SineGen(torch.nn.Module):
|
||||
voiced_thoreshold: F0 threshold for U/V classification (default 0)
|
||||
flag_for_pulse: this SinGen is used inside PulseGen (default False)
|
||||
Note: when flag_for_pulse is True, the first time step of a voiced
|
||||
segment is always sin(np.pi) or cos(0)
|
||||
segment is always sin(torch.pi) or cos(0)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@ -315,9 +332,11 @@ class SineGen(torch.nn.Module):
|
||||
# generate uv signal
|
||||
uv = torch.ones_like(f0)
|
||||
uv = uv * (f0 > self.voiced_threshold)
|
||||
if uv.device.type == "privateuseone": # for DirectML
|
||||
uv = uv.float()
|
||||
return uv
|
||||
|
||||
def forward(self, f0, upp):
|
||||
def forward(self, f0: torch.Tensor, upp: int):
|
||||
"""sine_tensor, uv = forward(f0)
|
||||
input F0: tensor(batchsize=1, length, dim=1)
|
||||
f0 for unvoiced steps should be 0
|
||||
@ -329,7 +348,7 @@ class SineGen(torch.nn.Module):
|
||||
f0_buf = torch.zeros(f0.shape[0], f0.shape[1], self.dim, device=f0.device)
|
||||
# fundamental component
|
||||
f0_buf[:, :, 0] = f0[:, :, 0]
|
||||
for idx in np.arange(self.harmonic_num):
|
||||
for idx in range(self.harmonic_num):
|
||||
f0_buf[:, :, idx + 1] = f0_buf[:, :, 0] * (
|
||||
idx + 2
|
||||
) # idx + 2: the (idx+1)-th overtone, (idx+2)-th harmonic
|
||||
@ -347,12 +366,12 @@ class SineGen(torch.nn.Module):
|
||||
tmp_over_one *= upp
|
||||
tmp_over_one = F.interpolate(
|
||||
tmp_over_one.transpose(2, 1),
|
||||
scale_factor=upp,
|
||||
scale_factor=float(upp),
|
||||
mode="linear",
|
||||
align_corners=True,
|
||||
).transpose(2, 1)
|
||||
rad_values = F.interpolate(
|
||||
rad_values.transpose(2, 1), scale_factor=upp, mode="nearest"
|
||||
rad_values.transpose(2, 1), scale_factor=float(upp), mode="nearest"
|
||||
).transpose(
|
||||
2, 1
|
||||
) #######
|
||||
@ -361,12 +380,12 @@ class SineGen(torch.nn.Module):
|
||||
cumsum_shift = torch.zeros_like(rad_values)
|
||||
cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0
|
||||
sine_waves = torch.sin(
|
||||
torch.cumsum(rad_values + cumsum_shift, dim=1) * 2 * np.pi
|
||||
torch.cumsum(rad_values + cumsum_shift, dim=1) * 2 * torch.pi
|
||||
)
|
||||
sine_waves = sine_waves * self.sine_amp
|
||||
uv = self._f02uv(f0)
|
||||
uv = F.interpolate(
|
||||
uv.transpose(2, 1), scale_factor=upp, mode="nearest"
|
||||
uv.transpose(2, 1), scale_factor=float(upp), mode="nearest"
|
||||
).transpose(2, 1)
|
||||
noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
|
||||
noise = noise_amp * torch.randn_like(sine_waves)
|
||||
@ -414,11 +433,19 @@ class SourceModuleHnNSF(torch.nn.Module):
|
||||
# to merge source harmonics into a single excitation
|
||||
self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
|
||||
self.l_tanh = torch.nn.Tanh()
|
||||
# self.ddtype:int = -1
|
||||
|
||||
def forward(self, x, upp=None):
|
||||
def forward(self, x: torch.Tensor, upp: int = 1):
|
||||
# if self.ddtype ==-1:
|
||||
# self.ddtype = self.l_linear.weight.dtype
|
||||
sine_wavs, uv, _ = self.l_sin_gen(x, upp)
|
||||
if self.is_half:
|
||||
sine_wavs = sine_wavs.half()
|
||||
# print(x.dtype,sine_wavs.dtype,self.l_linear.weight.dtype)
|
||||
# if self.is_half:
|
||||
# sine_wavs = sine_wavs.half()
|
||||
# sine_merge = self.l_tanh(self.l_linear(sine_wavs.to(x)))
|
||||
# print(sine_wavs.dtype,self.ddtype)
|
||||
# if sine_wavs.dtype != self.l_linear.weight.dtype:
|
||||
sine_wavs = sine_wavs.to(dtype=self.l_linear.weight.dtype)
|
||||
sine_merge = self.l_tanh(self.l_linear(sine_wavs))
|
||||
return sine_merge, None, None # noise, uv
|
||||
|
||||
@ -441,7 +468,7 @@ class GeneratorNSF(torch.nn.Module):
|
||||
self.num_kernels = len(resblock_kernel_sizes)
|
||||
self.num_upsamples = len(upsample_rates)
|
||||
|
||||
self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(upsample_rates))
|
||||
self.f0_upsamp = torch.nn.Upsample(scale_factor=math.prod(upsample_rates))
|
||||
self.m_source = SourceModuleHnNSF(
|
||||
sampling_rate=sr, harmonic_num=0, is_half=is_half
|
||||
)
|
||||
@ -466,7 +493,7 @@ class GeneratorNSF(torch.nn.Module):
|
||||
)
|
||||
)
|
||||
if i + 1 < len(upsample_rates):
|
||||
stride_f0 = np.prod(upsample_rates[i + 1 :])
|
||||
stride_f0 = math.prod(upsample_rates[i + 1 :])
|
||||
self.noise_convs.append(
|
||||
Conv1d(
|
||||
1,
|
||||
@ -493,27 +520,36 @@ class GeneratorNSF(torch.nn.Module):
|
||||
if gin_channels != 0:
|
||||
self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)
|
||||
|
||||
self.upp = np.prod(upsample_rates)
|
||||
self.upp = math.prod(upsample_rates)
|
||||
|
||||
def forward(self, x, f0, g=None):
|
||||
self.lrelu_slope = modules.LRELU_SLOPE
|
||||
|
||||
def forward(self, x, f0, g: Optional[torch.Tensor] = None):
|
||||
har_source, noi_source, uv = self.m_source(f0, self.upp)
|
||||
har_source = har_source.transpose(1, 2)
|
||||
x = self.conv_pre(x)
|
||||
if g is not None:
|
||||
x = x + self.cond(g)
|
||||
|
||||
for i in range(self.num_upsamples):
|
||||
x = F.leaky_relu(x, modules.LRELU_SLOPE)
|
||||
x = self.ups[i](x)
|
||||
x_source = self.noise_convs[i](har_source)
|
||||
x = x + x_source
|
||||
xs = None
|
||||
for j in range(self.num_kernels):
|
||||
if xs is None:
|
||||
xs = self.resblocks[i * self.num_kernels + j](x)
|
||||
else:
|
||||
xs += self.resblocks[i * self.num_kernels + j](x)
|
||||
x = xs / self.num_kernels
|
||||
# torch.jit.script() does not support direct indexing of torch modules
|
||||
# That's why I wrote this
|
||||
for i, (ups, noise_convs) in enumerate(zip(self.ups, self.noise_convs)):
|
||||
if i < self.num_upsamples:
|
||||
x = F.leaky_relu(x, self.lrelu_slope)
|
||||
x = ups(x)
|
||||
x_source = noise_convs(har_source)
|
||||
x = x + x_source
|
||||
xs: Optional[torch.Tensor] = None
|
||||
l = [i * self.num_kernels + j for j in range(self.num_kernels)]
|
||||
for j, resblock in enumerate(self.resblocks):
|
||||
if j in l:
|
||||
if xs is None:
|
||||
xs = resblock(x)
|
||||
else:
|
||||
xs += resblock(x)
|
||||
# This assertion cannot be ignored! \
|
||||
# If ignored, it will cause torch.jit.script() compilation errors
|
||||
assert isinstance(xs, torch.Tensor)
|
||||
x = xs / self.num_kernels
|
||||
x = F.leaky_relu(x)
|
||||
x = self.conv_post(x)
|
||||
x = torch.tanh(x)
|
||||
@ -525,6 +561,27 @@ class GeneratorNSF(torch.nn.Module):
|
||||
for l in self.resblocks:
|
||||
l.remove_weight_norm()
|
||||
|
||||
def __prepare_scriptable__(self):
|
||||
for l in self.ups:
|
||||
for hook in l._forward_pre_hooks.values():
|
||||
# The hook we want to remove is an instance of WeightNorm class, so
|
||||
# normally we would do `if isinstance(...)` but this class is not accessible
|
||||
# because of shadowing, so we check the module name directly.
|
||||
# https://github.com/pytorch/pytorch/blob/be0ca00c5ce260eb5bcec3237357f7a30cc08983/torch/nn/utils/__init__.py#L3
|
||||
if (
|
||||
hook.__module__ == "torch.nn.utils.weight_norm"
|
||||
and hook.__class__.__name__ == "WeightNorm"
|
||||
):
|
||||
torch.nn.utils.remove_weight_norm(l)
|
||||
for l in self.resblocks:
|
||||
for hook in self.resblocks._forward_pre_hooks.values():
|
||||
if (
|
||||
hook.__module__ == "torch.nn.utils.weight_norm"
|
||||
and hook.__class__.__name__ == "WeightNorm"
|
||||
):
|
||||
torch.nn.utils.remove_weight_norm(l)
|
||||
return self
|
||||
|
||||
|
||||
sr2sr = {
|
||||
"32k": 32000,
|
||||
@ -554,11 +611,11 @@ class SynthesizerTrnMsNSFsidM(nn.Module):
|
||||
spk_embed_dim,
|
||||
gin_channels,
|
||||
sr,
|
||||
version,
|
||||
**kwargs,
|
||||
encoder_dim,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__()
|
||||
if type(sr) == type("strr"):
|
||||
super(SynthesizerTrnMsNSFsidM, self).__init__()
|
||||
if isinstance(sr, str):
|
||||
sr = sr2sr[sr]
|
||||
self.spec_channels = spec_channels
|
||||
self.inter_channels = inter_channels
|
||||
@ -567,7 +624,7 @@ class SynthesizerTrnMsNSFsidM(nn.Module):
|
||||
self.n_heads = n_heads
|
||||
self.n_layers = n_layers
|
||||
self.kernel_size = kernel_size
|
||||
self.p_dropout = p_dropout
|
||||
self.p_dropout = float(p_dropout)
|
||||
self.resblock = resblock
|
||||
self.resblock_kernel_sizes = resblock_kernel_sizes
|
||||
self.resblock_dilation_sizes = resblock_dilation_sizes
|
||||
@ -578,25 +635,15 @@ class SynthesizerTrnMsNSFsidM(nn.Module):
|
||||
self.gin_channels = gin_channels
|
||||
# self.hop_length = hop_length#
|
||||
self.spk_embed_dim = spk_embed_dim
|
||||
if version == "v1":
|
||||
self.enc_p = TextEncoder256(
|
||||
self.enc_p = TextEncoder(
|
||||
encoder_dim,
|
||||
inter_channels,
|
||||
hidden_channels,
|
||||
filter_channels,
|
||||
n_heads,
|
||||
n_layers,
|
||||
kernel_size,
|
||||
p_dropout,
|
||||
)
|
||||
else:
|
||||
self.enc_p = TextEncoder768(
|
||||
inter_channels,
|
||||
hidden_channels,
|
||||
filter_channels,
|
||||
n_heads,
|
||||
n_layers,
|
||||
kernel_size,
|
||||
p_dropout,
|
||||
float(p_dropout),
|
||||
)
|
||||
self.dec = GeneratorNSF(
|
||||
inter_channels,
|
||||
@ -623,9 +670,11 @@ class SynthesizerTrnMsNSFsidM(nn.Module):
|
||||
inter_channels, hidden_channels, 5, 1, 3, gin_channels=gin_channels
|
||||
)
|
||||
self.emb_g = nn.Embedding(self.spk_embed_dim, gin_channels)
|
||||
self.speaker_map = None
|
||||
logger.debug(
|
||||
f"gin_channels: {gin_channels}, self.spk_embed_dim: {self.spk_embed_dim}"
|
||||
"gin_channels: "
|
||||
+ str(gin_channels)
|
||||
+ ", self.spk_embed_dim: "
|
||||
+ str(self.spk_embed_dim)
|
||||
)
|
||||
|
||||
def remove_weight_norm(self):
|
||||
@ -633,9 +682,9 @@ class SynthesizerTrnMsNSFsidM(nn.Module):
|
||||
self.flow.remove_weight_norm()
|
||||
self.enc_q.remove_weight_norm()
|
||||
|
||||
def construct_spkmixmap(self, n_speaker):
|
||||
self.speaker_map = torch.zeros((n_speaker, 1, 1, self.gin_channels))
|
||||
for i in range(n_speaker):
|
||||
def construct_spkmixmap(self):
|
||||
self.speaker_map = torch.zeros((self.n_speaker, 1, 1, self.gin_channels))
|
||||
for i in range(self.n_speaker):
|
||||
self.speaker_map[i] = self.emb_g(torch.LongTensor([[i]]))
|
||||
self.speaker_map = self.speaker_map.unsqueeze(0)
|
||||
|
||||
@ -810,7 +859,12 @@ class DiscriminatorP(torch.nn.Module):
|
||||
b, c, t = x.shape
|
||||
if t % self.period != 0: # pad first
|
||||
n_pad = self.period - (t % self.period)
|
||||
x = F.pad(x, (0, n_pad), "reflect")
|
||||
if has_xpu and x.dtype == torch.bfloat16:
|
||||
x = F.pad(x.to(dtype=torch.float16), (0, n_pad), "reflect").to(
|
||||
dtype=torch.bfloat16
|
||||
)
|
||||
else:
|
||||
x = F.pad(x, (0, n_pad), "reflect")
|
||||
t = t + n_pad
|
||||
x = x.view(b, c, t // self.period, self.period)
|
||||
|
||||
|
@ -18,7 +18,7 @@ def export_onnx(ModelPath, ExportedPath):
|
||||
device = "cpu" # 导出时设备(不影响使用模型)
|
||||
|
||||
net_g = SynthesizerTrnMsNSFsidM(
|
||||
*cpt["config"], is_half=False, version=cpt.get("version", "v1")
|
||||
*cpt["config"], is_half=False, encoder_dim=vec_channels
|
||||
) # fp32导出(C++要支持fp16必须手动将内存重新排列所以暂时不用fp16)
|
||||
net_g.load_state_dict(cpt["weight"], strict=False)
|
||||
input_names = ["phone", "phone_lengths", "pitch", "pitchf", "ds", "rnd"]
|
||||
@ -44,8 +44,8 @@ def export_onnx(ModelPath, ExportedPath):
|
||||
"rnd": [2],
|
||||
},
|
||||
do_constant_folding=False,
|
||||
opset_version=13,
|
||||
verbose=False,
|
||||
opset_version=18,
|
||||
verbose=True,
|
||||
input_names=input_names,
|
||||
output_names=output_names,
|
||||
)
|
||||
|
@ -6,12 +6,12 @@ if __name__ == "__main__":
|
||||
|
||||
ModelPath = "Shiroha/shiroha.pth" # 模型路径
|
||||
ExportedPath = "model.onnx" # 输出路径
|
||||
hidden_channels = 256 # hidden_channels,为768Vec做准备
|
||||
encoder_dim = 256 # encoder_dim
|
||||
cpt = torch.load(ModelPath, map_location="cpu")
|
||||
cpt["config"][-3] = cpt["weight"]["emb_g.weight"].shape[0] # n_spk
|
||||
print(*cpt["config"])
|
||||
|
||||
test_phone = torch.rand(1, 200, hidden_channels) # hidden unit
|
||||
test_phone = torch.rand(1, 200, encoder_dim) # hidden unit
|
||||
test_phone_lengths = torch.tensor([200]).long() # hidden unit 长度(貌似没啥用)
|
||||
test_pitch = torch.randint(size=(1, 200), low=5, high=255) # 基频(单位赫兹)
|
||||
test_pitchf = torch.rand(1, 200) # nsf基频
|
||||
@ -21,7 +21,7 @@ if __name__ == "__main__":
|
||||
device = "cpu" # 导出时设备(不影响使用模型)
|
||||
|
||||
net_g = SynthesizerTrnMsNSFsidM(
|
||||
*cpt["config"], is_half=False
|
||||
*cpt["config"], is_half=False, encoder_dim = encoder_dim
|
||||
) # fp32导出(C++要支持fp16必须手动将内存重新排列所以暂时不用fp16)
|
||||
net_g.load_state_dict(cpt["weight"], strict=False)
|
||||
input_names = ["phone", "phone_lengths", "pitch", "pitchf", "ds", "rnd"]
|
||||
@ -47,7 +47,7 @@ if __name__ == "__main__":
|
||||
"rnd": [2],
|
||||
},
|
||||
do_constant_folding=False,
|
||||
opset_version=16,
|
||||
opset_version=18,
|
||||
verbose=False,
|
||||
input_names=input_names,
|
||||
output_names=output_names,
|
||||
|
Loading…
x
Reference in New Issue
Block a user