diff --git a/infer_pack/models.py b/infer_pack/models.py index a7e688e..085d26c 100644 --- a/infer_pack/models.py +++ b/infer_pack/models.py @@ -59,9 +59,7 @@ class TextEncoder256(nn.Module): m, logs = torch.split(stats, self.out_channels, dim=1) return m, logs, x_mask - - -class TextEncoder256Sim(nn.Module): +class TextEncoder768(nn.Module): def __init__( self, out_channels, @@ -81,14 +79,14 @@ class TextEncoder256Sim(nn.Module): self.n_layers = n_layers self.kernel_size = kernel_size self.p_dropout = p_dropout - self.emb_phone = nn.Linear(256, hidden_channels) + self.emb_phone = nn.Linear(768, hidden_channels) self.lrelu = nn.LeakyReLU(0.1, inplace=True) if f0 == True: self.emb_pitch = nn.Embedding(256, hidden_channels) # pitch 256 self.encoder = attentions.Encoder( hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout ) - self.proj = nn.Conv1d(hidden_channels, out_channels, 1) + self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1) def forward(self, phone, pitch, lengths): if pitch == None: @@ -102,9 +100,10 @@ class TextEncoder256Sim(nn.Module): x.dtype ) x = self.encoder(x * x_mask, x_mask) - x = self.proj(x) * x_mask - return x, x_mask + stats = self.proj(x) * x_mask + m, logs = torch.split(stats, self.out_channels, dim=1) + return m, logs, x_mask class ResidualCouplingBlock(nn.Module): def __init__( @@ -636,6 +635,115 @@ class SynthesizerTrnMs256NSFsid(nn.Module): z = self.flow(z_p, x_mask, g=g, reverse=True) o = self.dec((z * x_mask)[:, :, :max_len], nsff0, g=g) return o, x_mask, (z, z_p, m_p, logs_p) +class SynthesizerTrnMs768NSFsid(nn.Module): + def __init__( + self, + spec_channels, + segment_size, + inter_channels, + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size, + p_dropout, + resblock, + resblock_kernel_sizes, + resblock_dilation_sizes, + upsample_rates, + upsample_initial_channel, + upsample_kernel_sizes, + spk_embed_dim, + gin_channels, + sr, + **kwargs + ): + super().__init__() + if type(sr) == type("strr"): + sr = sr2sr[sr] + self.spec_channels = spec_channels + self.inter_channels = inter_channels + self.hidden_channels = hidden_channels + self.filter_channels = filter_channels + self.n_heads = n_heads + self.n_layers = n_layers + self.kernel_size = kernel_size + self.p_dropout = p_dropout + self.resblock = resblock + self.resblock_kernel_sizes = resblock_kernel_sizes + self.resblock_dilation_sizes = resblock_dilation_sizes + self.upsample_rates = upsample_rates + self.upsample_initial_channel = upsample_initial_channel + self.upsample_kernel_sizes = upsample_kernel_sizes + self.segment_size = segment_size + self.gin_channels = gin_channels + # self.hop_length = hop_length# + self.spk_embed_dim = spk_embed_dim + self.enc_p = TextEncoder768( + inter_channels, + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size, + p_dropout, + ) + self.dec = GeneratorNSF( + inter_channels, + resblock, + resblock_kernel_sizes, + resblock_dilation_sizes, + upsample_rates, + upsample_initial_channel, + upsample_kernel_sizes, + gin_channels=gin_channels, + sr=sr, + is_half=kwargs["is_half"], + ) + self.enc_q = PosteriorEncoder( + spec_channels, + inter_channels, + hidden_channels, + 5, + 1, + 16, + gin_channels=gin_channels, + ) + self.flow = ResidualCouplingBlock( + inter_channels, hidden_channels, 5, 1, 3, gin_channels=gin_channels + ) + self.emb_g = nn.Embedding(self.spk_embed_dim, gin_channels) + print("gin_channels:", gin_channels, "self.spk_embed_dim:", self.spk_embed_dim) + + def remove_weight_norm(self): + self.dec.remove_weight_norm() + self.flow.remove_weight_norm() + self.enc_q.remove_weight_norm() + + def forward( + self, phone, phone_lengths, pitch, pitchf, y, y_lengths, ds + ): # 这里ds是id,[bs,1] + # print(1,pitch.shape)#[bs,t] + g = self.emb_g(ds).unsqueeze(-1) # [b, 256, 1]##1是t,广播的 + m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths) + z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g) + z_p = self.flow(z, y_mask, g=g) + z_slice, ids_slice = commons.rand_slice_segments( + z, y_lengths, self.segment_size + ) + # print(-1,pitchf.shape,ids_slice,self.segment_size,self.hop_length,self.segment_size//self.hop_length) + pitchf = commons.slice_segments2(pitchf, ids_slice, self.segment_size) + # print(-2,pitchf.shape,z_slice.shape) + o = self.dec(z_slice, pitchf, g=g) + return o, ids_slice, x_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q) + + def infer(self, phone, phone_lengths, pitch, nsff0, sid, max_len=None): + g = self.emb_g(sid).unsqueeze(-1) + m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths) + z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask + z = self.flow(z_p, x_mask, g=g, reverse=True) + o = self.dec((z * x_mask)[:, :, :max_len], nsff0, g=g) + return o, x_mask, (z, z_p, m_p, logs_p) class SynthesizerTrnMs256NSFsid_nono(nn.Module): @@ -738,13 +846,7 @@ class SynthesizerTrnMs256NSFsid_nono(nn.Module): z = self.flow(z_p, x_mask, g=g, reverse=True) o = self.dec((z * x_mask)[:, :, :max_len], g=g) return o, x_mask, (z, z_p, m_p, logs_p) - - -class SynthesizerTrnMs256NSFsid_sim(nn.Module): - """ - Synthesizer for Training - """ - +class SynthesizerTrnMs768NSFsid_nono(nn.Module): def __init__( self, spec_channels, @@ -763,9 +865,8 @@ class SynthesizerTrnMs256NSFsid_sim(nn.Module): upsample_initial_channel, upsample_kernel_sizes, spk_embed_dim, - # hop_length, - gin_channels=0, - use_sdp=True, + gin_channels, + sr=None, **kwargs ): super().__init__() @@ -787,7 +888,7 @@ class SynthesizerTrnMs256NSFsid_sim(nn.Module): self.gin_channels = gin_channels # self.hop_length = hop_length# self.spk_embed_dim = spk_embed_dim - self.enc_p = TextEncoder256Sim( + self.enc_p = TextEncoder768( inter_channels, hidden_channels, filter_channels, @@ -795,8 +896,9 @@ class SynthesizerTrnMs256NSFsid_sim(nn.Module): n_layers, kernel_size, p_dropout, + f0=False, ) - self.dec = GeneratorNSF( + self.dec = Generator( inter_channels, resblock, resblock_kernel_sizes, @@ -805,9 +907,16 @@ class SynthesizerTrnMs256NSFsid_sim(nn.Module): upsample_initial_channel, upsample_kernel_sizes, gin_channels=gin_channels, - is_half=kwargs["is_half"], ) - + self.enc_q = PosteriorEncoder( + spec_channels, + inter_channels, + hidden_channels, + 5, + 1, + 16, + gin_channels=gin_channels, + ) self.flow = ResidualCouplingBlock( inter_channels, hidden_channels, 5, 1, 3, gin_channels=gin_channels ) @@ -819,28 +928,24 @@ class SynthesizerTrnMs256NSFsid_sim(nn.Module): self.flow.remove_weight_norm() self.enc_q.remove_weight_norm() - def forward( - self, phone, phone_lengths, pitch, pitchf, y_lengths, ds - ): # y是spec不需要了现在 + def forward(self, phone, phone_lengths, y, y_lengths, ds): # 这里ds是id,[bs,1] g = self.emb_g(ds).unsqueeze(-1) # [b, 256, 1]##1是t,广播的 - x, x_mask = self.enc_p(phone, pitch, phone_lengths) - x = self.flow(x, x_mask, g=g, reverse=True) + m_p, logs_p, x_mask = self.enc_p(phone, None, phone_lengths) + z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g) + z_p = self.flow(z, y_mask, g=g) z_slice, ids_slice = commons.rand_slice_segments( - x, y_lengths, self.segment_size + z, y_lengths, self.segment_size ) + o = self.dec(z_slice, g=g) + return o, ids_slice, x_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q) - pitchf = commons.slice_segments2(pitchf, ids_slice, self.segment_size) - o = self.dec(z_slice, pitchf, g=g) - return o, ids_slice - - def infer( - self, phone, phone_lengths, pitch, pitchf, ds, max_len=None - ): # y是spec不需要了现在 - g = self.emb_g(ds).unsqueeze(-1) # [b, 256, 1]##1是t,广播的 - x, x_mask = self.enc_p(phone, pitch, phone_lengths) - x = self.flow(x, x_mask, g=g, reverse=True) - o = self.dec((x * x_mask)[:, :, :max_len], pitchf, g=g) - return o, o + def infer(self, phone, phone_lengths, sid, max_len=None): + g = self.emb_g(sid).unsqueeze(-1) + m_p, logs_p, x_mask = self.enc_p(phone, None, phone_lengths) + z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask + z = self.flow(z_p, x_mask, g=g, reverse=True) + o = self.dec((z * x_mask)[:, :, :max_len], g=g) + return o, x_mask, (z, z_p, m_p, logs_p) class MultiPeriodDiscriminator(torch.nn.Module): @@ -872,6 +977,35 @@ class MultiPeriodDiscriminator(torch.nn.Module): return y_d_rs, y_d_gs, fmap_rs, fmap_gs +class MultiPeriodDiscriminatorV2(torch.nn.Module): + def __init__(self, use_spectral_norm=False): + super(MultiPeriodDiscriminatorV2, self).__init__() + # periods = [2, 3, 5, 7, 11, 17] + periods = [2,3, 5, 7, 11, 17, 23, 37] + + discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)] + discs = discs + [ + DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods + ] + self.discriminators = nn.ModuleList(discs) + + def forward(self, y, y_hat): + y_d_rs = [] # + y_d_gs = [] + fmap_rs = [] + fmap_gs = [] + for i, d in enumerate(self.discriminators): + y_d_r, fmap_r = d(y) + y_d_g, fmap_g = d(y_hat) + # for j in range(len(fmap_r)): + # print(i,j,y.shape,y_hat.shape,fmap_r[j].shape,fmap_g[j].shape) + y_d_rs.append(y_d_r) + y_d_gs.append(y_d_g) + fmap_rs.append(fmap_r) + fmap_gs.append(fmap_g) + + return y_d_rs, y_d_gs, fmap_rs, fmap_gs + class DiscriminatorS(torch.nn.Module): def __init__(self, use_spectral_norm=False):