From a2ff0b17dc7d7ca6c7e2dd608dfa222cb5032287 Mon Sep 17 00:00:00 2001
From: Blaise
Date: Wed, 11 Dec 2024 23:22:05 +0100
Subject: [PATCH] Remove unused ResBlock check as we only use ResBlock 1

---
 infer/lib/infer_pack/models.py      |  4 +-
 infer/lib/infer_pack/models_onnx.py |  4 +-
 infer/lib/infer_pack/modules.py     | 60 +----------------------------
 3 files changed, 6 insertions(+), 62 deletions(-)

diff --git a/infer/lib/infer_pack/models.py b/infer/lib/infer_pack/models.py
index a900048..afc246a 100644
--- a/infer/lib/infer_pack/models.py
+++ b/infer/lib/infer_pack/models.py
@@ -219,7 +219,7 @@ class Generator(torch.nn.Module):
         self.conv_pre = Conv1d(
             initial_channel, upsample_initial_channel, 7, 1, padding=3
         )
-        resblock = modules.ResBlock1 if resblock == "1" else modules.ResBlock2
+        resblock = modules.ResBlock

         self.ups = nn.ModuleList()
         for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
@@ -471,7 +471,7 @@ class GeneratorNSF(torch.nn.Module):
         self.conv_pre = Conv1d(
             initial_channel, upsample_initial_channel, 7, 1, padding=3
         )
-        resblock = modules.ResBlock1 if resblock == "1" else modules.ResBlock2
+        resblock = modules.ResBlock

         self.ups = nn.ModuleList()
         for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
diff --git a/infer/lib/infer_pack/models_onnx.py b/infer/lib/infer_pack/models_onnx.py
index e327019..e2b5987 100644
--- a/infer/lib/infer_pack/models_onnx.py
+++ b/infer/lib/infer_pack/models_onnx.py
@@ -230,7 +230,7 @@ class Generator(torch.nn.Module):
         self.conv_pre = Conv1d(
             initial_channel, upsample_initial_channel, 7, 1, padding=3
         )
-        resblock = modules.ResBlock1 if resblock == "1" else modules.ResBlock2
+        resblock = modules.ResBlock

         self.ups = nn.ModuleList()
         for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
@@ -442,7 +442,7 @@ class GeneratorNSF(torch.nn.Module):
         self.conv_pre = Conv1d(
             initial_channel, upsample_initial_channel, 7, 1, padding=3
         )
-        resblock = modules.ResBlock1 if resblock == "1" else modules.ResBlock2
+        resblock = modules.ResBlock

         self.ups = nn.ModuleList()
         for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
diff --git a/infer/lib/infer_pack/modules.py b/infer/lib/infer_pack/modules.py
index 51aeaf0..1189c03 100644
--- a/infer/lib/infer_pack/modules.py
+++ b/infer/lib/infer_pack/modules.py
@@ -249,9 +249,9 @@ class WN(torch.nn.Module):
         return self


-class ResBlock1(torch.nn.Module):
+class ResBlock(torch.nn.Module):
     def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
-        super(ResBlock1, self).__init__()
+        super(ResBlock, self).__init__()
         self.convs1 = nn.ModuleList(
             [
                 weight_norm(
@@ -364,62 +364,6 @@ class ResBlock1(torch.nn.Module):
         return self


-class ResBlock2(torch.nn.Module):
-    def __init__(self, channels, kernel_size=3, dilation=(1, 3)):
-        super(ResBlock2, self).__init__()
-        self.convs = nn.ModuleList(
-            [
-                weight_norm(
-                    Conv1d(
-                        channels,
-                        channels,
-                        kernel_size,
-                        1,
-                        dilation=dilation[0],
-                        padding=get_padding(kernel_size, dilation[0]),
-                    )
-                ),
-                weight_norm(
-                    Conv1d(
-                        channels,
-                        channels,
-                        kernel_size,
-                        1,
-                        dilation=dilation[1],
-                        padding=get_padding(kernel_size, dilation[1]),
-                    )
-                ),
-            ]
-        )
-        self.convs.apply(init_weights)
-        self.lrelu_slope = LRELU_SLOPE
-
-    def forward(self, x, x_mask: Optional[torch.Tensor] = None):
-        for c in self.convs:
-            xt = F.leaky_relu(x, self.lrelu_slope)
-            if x_mask is not None:
-                xt = xt * x_mask
-            xt = c(xt)
-            x = xt + x
-        if x_mask is not None:
-            x = x * x_mask
-        return x
-
-    def remove_weight_norm(self):
-        for l in self.convs:
-            remove_weight_norm(l)
-
-    def __prepare_scriptable__(self):
-        for l in self.convs:
-            for hook in l._forward_pre_hooks.values():
-                if (
-                    hook.__module__ == "torch.nn.utils.weight_norm"
-                    and hook.__class__.__name__ == "WeightNorm"
-                ):
-                    torch.nn.utils.remove_weight_norm(l)
-        return self
-
-
 class Log(nn.Module):
     def forward(
         self,