From c93940a25d1d6b377c8d9c307ad58f3eb7c7ef3a Mon Sep 17 00:00:00 2001 From: RVC-Boss <129054828+RVC-Boss@users.noreply.github.com> Date: Sun, 28 May 2023 23:00:29 +0800 Subject: [PATCH] Add files via upload --- uvr5_pack/lib_v5/layers_new.py | 126 +++++++++++++++++++++++++++++++++ uvr5_pack/lib_v5/nets_new.py | 124 ++++++++++++++++++++++++++++++++ 2 files changed, 250 insertions(+) create mode 100644 uvr5_pack/lib_v5/layers_new.py create mode 100644 uvr5_pack/lib_v5/nets_new.py diff --git a/uvr5_pack/lib_v5/layers_new.py b/uvr5_pack/lib_v5/layers_new.py new file mode 100644 index 0000000..9c5ded3 --- /dev/null +++ b/uvr5_pack/lib_v5/layers_new.py @@ -0,0 +1,126 @@ +import torch +from torch import nn +import torch.nn.functional as F + +from uvr5_pack.lib_v5 import spec_utils + +class Conv2DBNActiv(nn.Module): + + def __init__(self, nin, nout, ksize=3, stride=1, pad=1, dilation=1, activ=nn.ReLU): + super(Conv2DBNActiv, self).__init__() + self.conv = nn.Sequential( + nn.Conv2d( + nin, nout, + kernel_size=ksize, + stride=stride, + padding=pad, + dilation=dilation, + bias=False), + nn.BatchNorm2d(nout), + activ() + ) + + def __call__(self, x): + return self.conv(x) + +class Encoder(nn.Module): + + def __init__(self, nin, nout, ksize=3, stride=1, pad=1, activ=nn.LeakyReLU): + super(Encoder, self).__init__() + self.conv1 = Conv2DBNActiv(nin, nout, ksize, stride, pad, activ=activ) + self.conv2 = Conv2DBNActiv(nout, nout, ksize, 1, pad, activ=activ) + + def __call__(self, x): + h = self.conv1(x) + h = self.conv2(h) + + return h + + +class Decoder(nn.Module): + + def __init__(self, nin, nout, ksize=3, stride=1, pad=1, activ=nn.ReLU, dropout=False): + super(Decoder, self).__init__() + self.conv1 = Conv2DBNActiv(nin, nout, ksize, 1, pad, activ=activ) + # self.conv2 = Conv2DBNActiv(nout, nout, ksize, 1, pad, activ=activ) + self.dropout = nn.Dropout2d(0.1) if dropout else None + + def __call__(self, x, skip=None): + x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=True) + + if skip is not None: + skip = spec_utils.crop_center(skip, x) + x = torch.cat([x, skip], dim=1) + + h = self.conv1(x) + # h = self.conv2(h) + + if self.dropout is not None: + h = self.dropout(h) + + return h + + +class ASPPModule(nn.Module): + + def __init__(self, nin, nout, dilations=(4, 8, 12), activ=nn.ReLU, dropout=False): + super(ASPPModule, self).__init__() + self.conv1 = nn.Sequential( + nn.AdaptiveAvgPool2d((1, None)), + Conv2DBNActiv(nin, nout, 1, 1, 0, activ=activ) + ) + self.conv2 = Conv2DBNActiv(nin, nout, 1, 1, 0, activ=activ) + self.conv3 = Conv2DBNActiv( + nin, nout, 3, 1, dilations[0], dilations[0], activ=activ + ) + self.conv4 = Conv2DBNActiv( + nin, nout, 3, 1, dilations[1], dilations[1], activ=activ + ) + self.conv5 = Conv2DBNActiv( + nin, nout, 3, 1, dilations[2], dilations[2], activ=activ + ) + self.bottleneck = Conv2DBNActiv(nout * 5, nout, 1, 1, 0, activ=activ) + self.dropout = nn.Dropout2d(0.1) if dropout else None + + def forward(self, x): + _, _, h, w = x.size() + feat1 = F.interpolate(self.conv1(x), size=(h, w), mode='bilinear', align_corners=True) + feat2 = self.conv2(x) + feat3 = self.conv3(x) + feat4 = self.conv4(x) + feat5 = self.conv5(x) + out = torch.cat((feat1, feat2, feat3, feat4, feat5), dim=1) + out = self.bottleneck(out) + + if self.dropout is not None: + out = self.dropout(out) + + return out + + +class LSTMModule(nn.Module): + + def __init__(self, nin_conv, nin_lstm, nout_lstm): + super(LSTMModule, self).__init__() + self.conv = Conv2DBNActiv(nin_conv, 1, 1, 1, 0) + self.lstm = nn.LSTM( + input_size=nin_lstm, + hidden_size=nout_lstm // 2, + bidirectional=True + ) + self.dense = nn.Sequential( + nn.Linear(nout_lstm, nin_lstm), + nn.BatchNorm1d(nin_lstm), + nn.ReLU() + ) + + def forward(self, x): + N, _, nbins, nframes = x.size() + h = self.conv(x)[:, 0] # N, nbins, nframes + h = h.permute(2, 0, 1) # nframes, N, nbins + h, _ = self.lstm(h) + h = self.dense(h.reshape(-1, h.size()[-1])) # nframes * N, nbins + h = h.reshape(nframes, N, 1, nbins) + h = h.permute(1, 2, 3, 0) + + return h diff --git a/uvr5_pack/lib_v5/nets_new.py b/uvr5_pack/lib_v5/nets_new.py new file mode 100644 index 0000000..1c50bf8 --- /dev/null +++ b/uvr5_pack/lib_v5/nets_new.py @@ -0,0 +1,124 @@ +import torch +from torch import nn +import torch.nn.functional as F +from uvr5_pack.lib_v5 import layers_new as layers + +class BaseNet(nn.Module): + + def __init__(self, nin, nout, nin_lstm, nout_lstm, dilations=((4, 2), (8, 4), (12, 6))): + super(BaseNet, self).__init__() + self.enc1 = layers.Conv2DBNActiv(nin, nout, 3, 1, 1) + self.enc2 = layers.Encoder(nout, nout * 2, 3, 2, 1) + self.enc3 = layers.Encoder(nout * 2, nout * 4, 3, 2, 1) + self.enc4 = layers.Encoder(nout * 4, nout * 6, 3, 2, 1) + self.enc5 = layers.Encoder(nout * 6, nout * 8, 3, 2, 1) + + self.aspp = layers.ASPPModule(nout * 8, nout * 8, dilations, dropout=True) + + self.dec4 = layers.Decoder(nout * (6 + 8), nout * 6, 3, 1, 1) + self.dec3 = layers.Decoder(nout * (4 + 6), nout * 4, 3, 1, 1) + self.dec2 = layers.Decoder(nout * (2 + 4), nout * 2, 3, 1, 1) + self.lstm_dec2 = layers.LSTMModule(nout * 2, nin_lstm, nout_lstm) + self.dec1 = layers.Decoder(nout * (1 + 2) + 1, nout * 1, 3, 1, 1) + + def __call__(self, x): + e1 = self.enc1(x) + e2 = self.enc2(e1) + e3 = self.enc3(e2) + e4 = self.enc4(e3) + e5 = self.enc5(e4) + + h = self.aspp(e5) + + h = self.dec4(h, e4) + h = self.dec3(h, e3) + h = self.dec2(h, e2) + h = torch.cat([h, self.lstm_dec2(h)], dim=1) + h = self.dec1(h, e1) + + return h + +class CascadedNet(nn.Module): + + def __init__(self, n_fft, nout=32, nout_lstm=128): + super(CascadedNet, self).__init__() + + self.max_bin = n_fft // 2 + self.output_bin = n_fft // 2 + 1 + self.nin_lstm = self.max_bin // 2 + self.offset = 64 + + self.stg1_low_band_net = nn.Sequential( + BaseNet(2, nout // 2, self.nin_lstm // 2, nout_lstm), + layers.Conv2DBNActiv(nout // 2, nout // 4, 1, 1, 0) + ) + + self.stg1_high_band_net = BaseNet(2, nout // 4, self.nin_lstm // 2, nout_lstm // 2) + + self.stg2_low_band_net = nn.Sequential( + BaseNet(nout // 4 + 2, nout, self.nin_lstm // 2, nout_lstm), + layers.Conv2DBNActiv(nout, nout // 2, 1, 1, 0) + ) + self.stg2_high_band_net = BaseNet(nout // 4 + 2, nout // 2, self.nin_lstm // 2, nout_lstm // 2) + + self.stg3_full_band_net = BaseNet(3 * nout // 4 + 2, nout, self.nin_lstm, nout_lstm) + + self.out = nn.Conv2d(nout, 2, 1, bias=False) + self.aux_out = nn.Conv2d(3 * nout // 4, 2, 1, bias=False) + + def forward(self, x): + x = x[:, :, :self.max_bin] + + bandw = x.size()[2] // 2 + l1_in = x[:, :, :bandw] + h1_in = x[:, :, bandw:] + l1 = self.stg1_low_band_net(l1_in) + h1 = self.stg1_high_band_net(h1_in) + aux1 = torch.cat([l1, h1], dim=2) + + l2_in = torch.cat([l1_in, l1], dim=1) + h2_in = torch.cat([h1_in, h1], dim=1) + l2 = self.stg2_low_band_net(l2_in) + h2 = self.stg2_high_band_net(h2_in) + aux2 = torch.cat([l2, h2], dim=2) + + f3_in = torch.cat([x, aux1, aux2], dim=1) + f3 = self.stg3_full_band_net(f3_in) + + mask = torch.sigmoid(self.out(f3)) + mask = F.pad( + input=mask, + pad=(0, 0, 0, self.output_bin - mask.size()[2]), + mode='replicate' + ) + + if self.training: + aux = torch.cat([aux1, aux2], dim=1) + aux = torch.sigmoid(self.aux_out(aux)) + aux = F.pad( + input=aux, + pad=(0, 0, 0, self.output_bin - aux.size()[2]), + mode='replicate' + ) + return mask, aux + else: + return mask + + def predict_mask(self, x): + mask = self.forward(x) + + if self.offset > 0: + mask = mask[:, :, :, self.offset:-self.offset] + assert mask.size()[3] > 0 + + return mask + + def predict(self, x,aggressiveness=None): + mask = self.forward(x) + pred_mag = x * mask + + if self.offset > 0: + pred_mag = pred_mag[:, :, :, self.offset:-self.offset] + assert pred_mag.size()[3] > 0 + + return pred_mag