Mirror of https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI.git (synced 2025-05-06 20:01:37 +08:00)
Merge pull request #1869 from yxlllc/dev
Correct the receptive field of the flow model and optimize the real-time inference code
This commit is contained in:
commit 886b6324bc
@@ -17,9 +17,10 @@ from infer.lib.infer_pack.commons import get_padding, init_weights
 has_xpu = bool(hasattr(torch, "xpu") and torch.xpu.is_available())
 
 
-class TextEncoder256(nn.Module):
+class TextEncoder(nn.Module):
     def __init__(
         self,
+        in_channels,
         out_channels,
         hidden_channels,
         filter_channels,
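The hunk above merges the two encoder variants into a single class whose input width is a constructor argument. As a rough illustration of why one `in_channels` parameter is enough to cover both the 256- and 768-dimensional phone features used elsewhere in this file, here is a minimal standalone sketch; the `TinyTextEncoder` name and the demo shapes are invented for this note and are not part of the patch.

import torch
from torch import nn


class TinyTextEncoder(nn.Module):  # hypothetical stand-in, not the project's class
    def __init__(self, in_channels: int, hidden_channels: int = 192):
        super().__init__()
        # The only dimension-dependent layer is the input projection, so a single
        # class with an in_channels argument can replace TextEncoder256/TextEncoder768.
        self.emb_phone = nn.Linear(in_channels, hidden_channels)

    def forward(self, phone: torch.Tensor) -> torch.Tensor:
        return self.emb_phone(phone)  # [b, t, hidden_channels]


enc_256 = TinyTextEncoder(256)  # plays the role of the old TextEncoder256
enc_768 = TinyTextEncoder(768)  # plays the role of the old TextEncoder768
print(enc_256(torch.randn(1, 100, 256)).shape)  # torch.Size([1, 100, 192])
print(enc_768(torch.randn(1, 100, 768)).shape)  # torch.Size([1, 100, 192])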
@@ -29,7 +30,7 @@ class TextEncoder256(nn.Module):
         p_dropout,
         f0=True,
     ):
-        super(TextEncoder256, self).__init__()
+        super(TextEncoder, self).__init__()
         self.out_channels = out_channels
         self.hidden_channels = hidden_channels
         self.filter_channels = filter_channels
@@ -37,7 +38,7 @@ class TextEncoder256(nn.Module):
         self.n_layers = n_layers
         self.kernel_size = kernel_size
         self.p_dropout = float(p_dropout)
-        self.emb_phone = nn.Linear(256, hidden_channels)
+        self.emb_phone = nn.Linear(in_channels, hidden_channels)
         self.lrelu = nn.LeakyReLU(0.1, inplace=True)
         if f0 == True:
             self.emb_pitch = nn.Embedding(256, hidden_channels)  # pitch 256
@@ -51,9 +52,7 @@ class TextEncoder256(nn.Module):
         )
         self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
 
-    def forward(
-        self, phone: torch.Tensor, pitch: Optional[torch.Tensor], lengths: torch.Tensor
-    ):
+    def forward(self, phone: torch.Tensor, pitch: torch.Tensor, lengths: torch.Tensor, skip_head: Optional[torch.Tensor] = None):
         if pitch is None:
             x = self.emb_phone(phone)
         else:
@@ -65,60 +64,12 @@ class TextEncoder256(nn.Module):
             x.dtype
         )
         x = self.encoder(x * x_mask, x_mask)
+        if skip_head is not None:
+            assert isinstance(skip_head, torch.Tensor)
+            head = int(skip_head.item())
+            x = x[:, :, head:]
+            x_mask = x_mask[:, :, head:]
         stats = self.proj(x) * x_mask
-
-        m, logs = torch.split(stats, self.out_channels, dim=1)
-        return m, logs, x_mask
-
-
-class TextEncoder768(nn.Module):
-    def __init__(
-        self,
-        out_channels,
-        hidden_channels,
-        filter_channels,
-        n_heads,
-        n_layers,
-        kernel_size,
-        p_dropout,
-        f0=True,
-    ):
-        super(TextEncoder768, self).__init__()
-        self.out_channels = out_channels
-        self.hidden_channels = hidden_channels
-        self.filter_channels = filter_channels
-        self.n_heads = n_heads
-        self.n_layers = n_layers
-        self.kernel_size = kernel_size
-        self.p_dropout = float(p_dropout)
-        self.emb_phone = nn.Linear(768, hidden_channels)
-        self.lrelu = nn.LeakyReLU(0.1, inplace=True)
-        if f0 == True:
-            self.emb_pitch = nn.Embedding(256, hidden_channels)  # pitch 256
-        self.encoder = attentions.Encoder(
-            hidden_channels,
-            filter_channels,
-            n_heads,
-            n_layers,
-            kernel_size,
-            float(p_dropout),
-        )
-        self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
-
-    def forward(self, phone: torch.Tensor, pitch: torch.Tensor, lengths: torch.Tensor):
-        if pitch is None:
-            x = self.emb_phone(phone)
-        else:
-            x = self.emb_phone(phone) + self.emb_pitch(pitch)
-        x = x * math.sqrt(self.hidden_channels)  # [b, t, h]
-        x = self.lrelu(x)
-        x = torch.transpose(x, 1, -1)  # [b, h, t]
-        x_mask = torch.unsqueeze(commons.sequence_mask(lengths, x.size(2)), 1).to(
-            x.dtype
-        )
-        x = self.encoder(x * x_mask, x_mask)
-        stats = self.proj(x) * x_mask
-
         m, logs = torch.split(stats, self.out_channels, dim=1)
         return m, logs, x_mask
 
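With the new optional `skip_head` argument, the encoder drops the leading frames right after the attention stack instead of leaving them to be discarded downstream. Since `self.proj` is a kernel-size-1 convolution, slicing before or after it gives the same values, so trimming early only saves work. A hedged sketch of that equivalence (plain tensors, names invented for the demo):

import torch
from torch import nn

torch.manual_seed(0)
proj = nn.Conv1d(4, 8, 1)   # stand-in for self.proj (pointwise, kernel size 1)
x = torch.randn(1, 4, 50)   # [b, hidden, t]
head = 20                   # frames to skip, like int(skip_head.item())

full_then_slice = proj(x)[:, :, head:]   # old order: project everything, slice later
slice_then_proj = proj(x[:, :, head:])   # new order: slice first, project fewer frames
print(torch.allclose(full_then_slice, slice_then_proj, atol=1e-6))  # True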
@@ -682,7 +633,8 @@ class SynthesizerTrnMs256NSFsid(nn.Module):
         self.gin_channels = gin_channels
         # self.hop_length = hop_length#
         self.spk_embed_dim = spk_embed_dim
-        self.enc_p = TextEncoder256(
+        self.enc_p = TextEncoder(
+            256,
             inter_channels,
             hidden_channels,
             filter_channels,
@@ -792,22 +744,28 @@ class SynthesizerTrnMs256NSFsid(nn.Module):
         return_length: Optional[torch.Tensor] = None,
     ):
         g = self.emb_g(sid).unsqueeze(-1)
-        m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths)
-        z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask
         if skip_head is not None and return_length is not None:
             assert isinstance(skip_head, torch.Tensor)
             assert isinstance(return_length, torch.Tensor)
             head = int(skip_head.item())
             length = int(return_length.item())
-            z_p = z_p[:, :, head : head + length]
-            x_mask = x_mask[:, :, head : head + length]
+            flow_head = torch.clamp(skip_head - 24, min=0)
+            dec_head = head - int(flow_head.item())
+            m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths, flow_head)
+            z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask
+            z = self.flow(z_p, x_mask, g=g, reverse=True)
+            z = z[:, :, dec_head : dec_head + length]
+            x_mask = x_mask[:, :, dec_head : dec_head + length]
             nsff0 = nsff0[:, head : head + length]
-        z = self.flow(z_p, x_mask, g=g, reverse=True)
+        else:
+            m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths)
+            z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask
+            z = self.flow(z_p, x_mask, g=g, reverse=True)
         o = self.dec(z * x_mask, nsff0, g=g)
         return o, x_mask, (z, z_p, m_p, logs_p)
 
 
-class SynthesizerTrnMs768NSFsid(nn.Module):
+class SynthesizerTrnMs768NSFsid(SynthesizerTrnMs256NSFsid):
     def __init__(
         self,
         spec_channels,
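The real-time path above no longer trims z_p down to exactly the returned window before running the flow. Instead, the encoder output is only cut back to `flow_head` (24 frames earlier than the requested split point, clamped at the start of the stream), so the flow still sees its left context, and `dec_head` removes those extra frames again before the decoder. A small sketch of the offset arithmetic, assuming the same 24-frame figure used in the patch; the helper name here is invented for illustration.

import torch

FLOW_RECEPTIVE_FIELD = 24  # frames of left context kept for the flow, as in the patch


def realtime_offsets(skip_head: torch.Tensor):
    # Mirrors the arithmetic in the patched infer(): the encoder output is only
    # trimmed down to flow_head, so the flow keeps up to 24 frames of left
    # context, and dec_head strips that context again before the decoder.
    head = int(skip_head.item())
    flow_head = torch.clamp(skip_head - FLOW_RECEPTIVE_FIELD, min=0)
    dec_head = head - int(flow_head.item())
    return head, int(flow_head.item()), dec_head


print(realtime_offsets(torch.tensor(100)))  # (100, 76, 24): full 24-frame context
print(realtime_offsets(torch.tensor(10)))   # (10, 0, 10): clamped near stream start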
@@ -830,28 +788,30 @@ class SynthesizerTrnMs768NSFsid(nn.Module):
         sr,
         **kwargs
     ):
-        super(SynthesizerTrnMs768NSFsid, self).__init__()
-        if isinstance(sr, str):
-            sr = sr2sr[sr]
-        self.spec_channels = spec_channels
-        self.inter_channels = inter_channels
-        self.hidden_channels = hidden_channels
-        self.filter_channels = filter_channels
-        self.n_heads = n_heads
-        self.n_layers = n_layers
-        self.kernel_size = kernel_size
-        self.p_dropout = float(p_dropout)
-        self.resblock = resblock
-        self.resblock_kernel_sizes = resblock_kernel_sizes
-        self.resblock_dilation_sizes = resblock_dilation_sizes
-        self.upsample_rates = upsample_rates
-        self.upsample_initial_channel = upsample_initial_channel
-        self.upsample_kernel_sizes = upsample_kernel_sizes
-        self.segment_size = segment_size
-        self.gin_channels = gin_channels
-        # self.hop_length = hop_length#
-        self.spk_embed_dim = spk_embed_dim
-        self.enc_p = TextEncoder768(
+        super(SynthesizerTrnMs768NSFsid, self).__init__(
+            spec_channels,
+            segment_size,
+            inter_channels,
+            hidden_channels,
+            filter_channels,
+            n_heads,
+            n_layers,
+            kernel_size,
+            p_dropout,
+            resblock,
+            resblock_kernel_sizes,
+            resblock_dilation_sizes,
+            upsample_rates,
+            upsample_initial_channel,
+            upsample_kernel_sizes,
+            spk_embed_dim,
+            gin_channels,
+            sr,
+            **kwargs
+        )
+        del self.enc_p
+        self.enc_p = TextEncoder(
+            768,
             inter_channels,
             hidden_channels,
             filter_channels,
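The 768-dimensional synthesizer now reuses the 256-dimensional class's __init__ and only swaps out the phone encoder, which is why the rest of its body can be deleted below. A minimal sketch of that pattern, with generic module names that are not the project's classes:

import torch
from torch import nn


class Base(nn.Module):
    def __init__(self, hidden: int = 192):
        super().__init__()
        self.enc_p = nn.Linear(256, hidden)  # built for 256-d inputs by default
        self.dec = nn.Linear(hidden, 1)


class Wide(Base):
    def __init__(self, hidden: int = 192):
        super().__init__(hidden)
        # Same idea as `del self.enc_p` followed by re-creating it in the patch:
        # drop the parent's encoder and register one sized for 768-d inputs,
        # while every other submodule set up by Base.__init__ is kept.
        del self.enc_p
        self.enc_p = nn.Linear(768, hidden)


m = Wide()
print(m.dec(m.enc_p(torch.randn(2, 768))).shape)  # torch.Size([2, 1])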
@@ -860,113 +820,6 @@ class SynthesizerTrnMs768NSFsid(nn.Module):
             kernel_size,
             float(p_dropout),
         )
-        self.dec = GeneratorNSF(
-            inter_channels,
-            resblock,
-            resblock_kernel_sizes,
-            resblock_dilation_sizes,
-            upsample_rates,
-            upsample_initial_channel,
-            upsample_kernel_sizes,
-            gin_channels=gin_channels,
-            sr=sr,
-            is_half=kwargs["is_half"],
-        )
-        self.enc_q = PosteriorEncoder(
-            spec_channels,
-            inter_channels,
-            hidden_channels,
-            5,
-            1,
-            16,
-            gin_channels=gin_channels,
-        )
-        self.flow = ResidualCouplingBlock(
-            inter_channels, hidden_channels, 5, 1, 3, gin_channels=gin_channels
-        )
-        self.emb_g = nn.Embedding(self.spk_embed_dim, gin_channels)
-        logger.debug(
-            "gin_channels: "
-            + str(gin_channels)
-            + ", self.spk_embed_dim: "
-            + str(self.spk_embed_dim)
-        )
-
-    def remove_weight_norm(self):
-        self.dec.remove_weight_norm()
-        self.flow.remove_weight_norm()
-        if hasattr(self, "enc_q"):
-            self.enc_q.remove_weight_norm()
-
-    def __prepare_scriptable__(self):
-        for hook in self.dec._forward_pre_hooks.values():
-            # The hook we want to remove is an instance of WeightNorm class, so
-            # normally we would do `if isinstance(...)` but this class is not accessible
-            # because of shadowing, so we check the module name directly.
-            # https://github.com/pytorch/pytorch/blob/be0ca00c5ce260eb5bcec3237357f7a30cc08983/torch/nn/utils/__init__.py#L3
-            if (
-                hook.__module__ == "torch.nn.utils.weight_norm"
-                and hook.__class__.__name__ == "WeightNorm"
-            ):
-                torch.nn.utils.remove_weight_norm(self.dec)
-        for hook in self.flow._forward_pre_hooks.values():
-            if (
-                hook.__module__ == "torch.nn.utils.weight_norm"
-                and hook.__class__.__name__ == "WeightNorm"
-            ):
-                torch.nn.utils.remove_weight_norm(self.flow)
-        if hasattr(self, "enc_q"):
-            for hook in self.enc_q._forward_pre_hooks.values():
-                if (
-                    hook.__module__ == "torch.nn.utils.weight_norm"
-                    and hook.__class__.__name__ == "WeightNorm"
-                ):
-                    torch.nn.utils.remove_weight_norm(self.enc_q)
-        return self
-
-    @torch.jit.ignore
-    def forward(
-        self, phone, phone_lengths, pitch, pitchf, y, y_lengths, ds
-    ):  # 这里ds是id,[bs,1]
-        # print(1,pitch.shape)#[bs,t]
-        g = self.emb_g(ds).unsqueeze(-1)  # [b, 256, 1]##1是t,广播的
-        m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths)
-        z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g)
-        z_p = self.flow(z, y_mask, g=g)
-        z_slice, ids_slice = commons.rand_slice_segments(
-            z, y_lengths, self.segment_size
-        )
-        # print(-1,pitchf.shape,ids_slice,self.segment_size,self.hop_length,self.segment_size//self.hop_length)
-        pitchf = commons.slice_segments2(pitchf, ids_slice, self.segment_size)
-        # print(-2,pitchf.shape,z_slice.shape)
-        o = self.dec(z_slice, pitchf, g=g)
-        return o, ids_slice, x_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q)
-
-    @torch.jit.export
-    def infer(
-        self,
-        phone: torch.Tensor,
-        phone_lengths: torch.Tensor,
-        pitch: torch.Tensor,
-        nsff0: torch.Tensor,
-        sid: torch.Tensor,
-        skip_head: Optional[torch.Tensor] = None,
-        return_length: Optional[torch.Tensor] = None,
-    ):
-        g = self.emb_g(sid).unsqueeze(-1)
-        m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths)
-        z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask
-        if skip_head is not None and return_length is not None:
-            assert isinstance(skip_head, torch.Tensor)
-            assert isinstance(return_length, torch.Tensor)
-            head = int(skip_head.item())
-            length = int(return_length.item())
-            z_p = z_p[:, :, head : head + length]
-            x_mask = x_mask[:, :, head : head + length]
-            nsff0 = nsff0[:, head : head + length]
-        z = self.flow(z_p, x_mask, g=g, reverse=True)
-        o = self.dec(z * x_mask, nsff0, g=g)
-        return o, x_mask, (z, z_p, m_p, logs_p)
-
-
 class SynthesizerTrnMs256NSFsid_nono(nn.Module):
@@ -1011,7 +864,8 @@ class SynthesizerTrnMs256NSFsid_nono(nn.Module):
         self.gin_channels = gin_channels
         # self.hop_length = hop_length#
         self.spk_embed_dim = spk_embed_dim
-        self.enc_p = TextEncoder256(
+        self.enc_p = TextEncoder(
+            256,
             inter_channels,
             hidden_channels,
             filter_channels,
@@ -1105,21 +959,27 @@ class SynthesizerTrnMs256NSFsid_nono(nn.Module):
         return_length: Optional[torch.Tensor] = None,
     ):
         g = self.emb_g(sid).unsqueeze(-1)
-        m_p, logs_p, x_mask = self.enc_p(phone, None, phone_lengths)
-        z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask
         if skip_head is not None and return_length is not None:
             assert isinstance(skip_head, torch.Tensor)
             assert isinstance(return_length, torch.Tensor)
             head = int(skip_head.item())
             length = int(return_length.item())
-            z_p = z_p[:, :, head : head + length]
-            x_mask = x_mask[:, :, head : head + length]
-        z = self.flow(z_p, x_mask, g=g, reverse=True)
+            flow_head = torch.clamp(skip_head - 24, min=0)
+            dec_head = head - int(flow_head.item())
+            m_p, logs_p, x_mask = self.enc_p(phone, None, phone_lengths, flow_head)
+            z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask
+            z = self.flow(z_p, x_mask, g=g, reverse=True)
+            z = z[:, :, dec_head : dec_head + length]
+            x_mask = x_mask[:, :, dec_head : dec_head + length]
+        else:
+            m_p, logs_p, x_mask = self.enc_p(phone, None, phone_lengths)
+            z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask
+            z = self.flow(z_p, x_mask, g=g, reverse=True)
         o = self.dec(z * x_mask, g=g)
         return o, x_mask, (z, z_p, m_p, logs_p)
 
 
-class SynthesizerTrnMs768NSFsid_nono(nn.Module):
+class SynthesizerTrnMs768NSFsid_nono(SynthesizerTrnMs256NSFsid_nono):
     def __init__(
         self,
         spec_channels,
@@ -1142,26 +1002,30 @@ class SynthesizerTrnMs768NSFsid_nono(nn.Module):
         sr=None,
         **kwargs
     ):
-        super(SynthesizerTrnMs768NSFsid_nono, self).__init__()
-        self.spec_channels = spec_channels
-        self.inter_channels = inter_channels
-        self.hidden_channels = hidden_channels
-        self.filter_channels = filter_channels
-        self.n_heads = n_heads
-        self.n_layers = n_layers
-        self.kernel_size = kernel_size
-        self.p_dropout = float(p_dropout)
-        self.resblock = resblock
-        self.resblock_kernel_sizes = resblock_kernel_sizes
-        self.resblock_dilation_sizes = resblock_dilation_sizes
-        self.upsample_rates = upsample_rates
-        self.upsample_initial_channel = upsample_initial_channel
-        self.upsample_kernel_sizes = upsample_kernel_sizes
-        self.segment_size = segment_size
-        self.gin_channels = gin_channels
-        # self.hop_length = hop_length#
-        self.spk_embed_dim = spk_embed_dim
-        self.enc_p = TextEncoder768(
+        super(SynthesizerTrnMs768NSFsid_nono, self).__init__(
+            spec_channels,
+            segment_size,
+            inter_channels,
+            hidden_channels,
+            filter_channels,
+            n_heads,
+            n_layers,
+            kernel_size,
+            p_dropout,
+            resblock,
+            resblock_kernel_sizes,
+            resblock_dilation_sizes,
+            upsample_rates,
+            upsample_initial_channel,
+            upsample_kernel_sizes,
+            spk_embed_dim,
+            gin_channels,
+            sr,
+            **kwargs
+        )
+        del self.enc_p
+        self.enc_p = TextEncoder(
+            768,
             inter_channels,
             hidden_channels,
             filter_channels,
@@ -1171,102 +1035,6 @@ class SynthesizerTrnMs768NSFsid_nono(nn.Module):
             float(p_dropout),
             f0=False,
         )
-        self.dec = Generator(
-            inter_channels,
-            resblock,
-            resblock_kernel_sizes,
-            resblock_dilation_sizes,
-            upsample_rates,
-            upsample_initial_channel,
-            upsample_kernel_sizes,
-            gin_channels=gin_channels,
-        )
-        self.enc_q = PosteriorEncoder(
-            spec_channels,
-            inter_channels,
-            hidden_channels,
-            5,
-            1,
-            16,
-            gin_channels=gin_channels,
-        )
-        self.flow = ResidualCouplingBlock(
-            inter_channels, hidden_channels, 5, 1, 3, gin_channels=gin_channels
-        )
-        self.emb_g = nn.Embedding(self.spk_embed_dim, gin_channels)
-        logger.debug(
-            "gin_channels: "
-            + str(gin_channels)
-            + ", self.spk_embed_dim: "
-            + str(self.spk_embed_dim)
-        )
-
-    def remove_weight_norm(self):
-        self.dec.remove_weight_norm()
-        self.flow.remove_weight_norm()
-        if hasattr(self, "enc_q"):
-            self.enc_q.remove_weight_norm()
-
-    def __prepare_scriptable__(self):
-        for hook in self.dec._forward_pre_hooks.values():
-            # The hook we want to remove is an instance of WeightNorm class, so
-            # normally we would do `if isinstance(...)` but this class is not accessible
-            # because of shadowing, so we check the module name directly.
-            # https://github.com/pytorch/pytorch/blob/be0ca00c5ce260eb5bcec3237357f7a30cc08983/torch/nn/utils/__init__.py#L3
-            if (
-                hook.__module__ == "torch.nn.utils.weight_norm"
-                and hook.__class__.__name__ == "WeightNorm"
-            ):
-                torch.nn.utils.remove_weight_norm(self.dec)
-        for hook in self.flow._forward_pre_hooks.values():
-            if (
-                hook.__module__ == "torch.nn.utils.weight_norm"
-                and hook.__class__.__name__ == "WeightNorm"
-            ):
-                torch.nn.utils.remove_weight_norm(self.flow)
-        if hasattr(self, "enc_q"):
-            for hook in self.enc_q._forward_pre_hooks.values():
-                if (
-                    hook.__module__ == "torch.nn.utils.weight_norm"
-                    and hook.__class__.__name__ == "WeightNorm"
-                ):
-                    torch.nn.utils.remove_weight_norm(self.enc_q)
-        return self
-
-    @torch.jit.ignore
-    def forward(self, phone, phone_lengths, y, y_lengths, ds):  # 这里ds是id,[bs,1]
-        g = self.emb_g(ds).unsqueeze(-1)  # [b, 256, 1]##1是t,广播的
-        m_p, logs_p, x_mask = self.enc_p(phone, None, phone_lengths)
-        z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g)
-        z_p = self.flow(z, y_mask, g=g)
-        z_slice, ids_slice = commons.rand_slice_segments(
-            z, y_lengths, self.segment_size
-        )
-        o = self.dec(z_slice, g=g)
-        return o, ids_slice, x_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q)
-
-    @torch.jit.export
-    def infer(
-        self,
-        phone: torch.Tensor,
-        phone_lengths: torch.Tensor,
-        sid: torch.Tensor,
-        skip_head: Optional[torch.Tensor] = None,
-        return_length: Optional[torch.Tensor] = None,
-    ):
-        g = self.emb_g(sid).unsqueeze(-1)
-        m_p, logs_p, x_mask = self.enc_p(phone, None, phone_lengths)
-        z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask
-        if skip_head is not None and return_length is not None:
-            assert isinstance(skip_head, torch.Tensor)
-            assert isinstance(return_length, torch.Tensor)
-            head = int(skip_head.item())
-            length = int(return_length.item())
-            z_p = z_p[:, :, head : head + length]
-            x_mask = x_mask[:, :, head : head + length]
-        z = self.flow(z_p, x_mask, g=g, reverse=True)
-        o = self.dec(z * x_mask, g=g)
-        return o, x_mask, (z, z_p, m_p, logs_p)
-
-
 class MultiPeriodDiscriminator(torch.nn.Module):