Remove unused ResBlock check, as we only use ResBlock1

Blaise 2024-12-11 23:22:05 +01:00
parent 7ef1986778
commit a2ff0b17dc
3 changed files with 6 additions and 62 deletions

File 1 of 3

@@ -219,7 +219,7 @@ class Generator(torch.nn.Module):
         self.conv_pre = Conv1d(
             initial_channel, upsample_initial_channel, 7, 1, padding=3
         )
-        resblock = modules.ResBlock1 if resblock == "1" else modules.ResBlock2
+        resblock = modules.ResBlock

         self.ups = nn.ModuleList()
         for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
@@ -471,7 +471,7 @@ class GeneratorNSF(torch.nn.Module):
         self.conv_pre = Conv1d(
             initial_channel, upsample_initial_channel, 7, 1, padding=3
         )
-        resblock = modules.ResBlock1 if resblock == "1" else modules.ResBlock2
+        resblock = modules.ResBlock

         self.ups = nn.ModuleList()
         for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):

File 2 of 3

@@ -230,7 +230,7 @@ class Generator(torch.nn.Module):
         self.conv_pre = Conv1d(
             initial_channel, upsample_initial_channel, 7, 1, padding=3
         )
-        resblock = modules.ResBlock1 if resblock == "1" else modules.ResBlock2
+        resblock = modules.ResBlock

         self.ups = nn.ModuleList()
         for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
@@ -442,7 +442,7 @@ class GeneratorNSF(torch.nn.Module):
         self.conv_pre = Conv1d(
             initial_channel, upsample_initial_channel, 7, 1, padding=3
         )
-        resblock = modules.ResBlock1 if resblock == "1" else modules.ResBlock2
+        resblock = modules.ResBlock

         self.ups = nn.ModuleList()
         for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
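
All four hunks above are the same one-line change in Generator and GeneratorNSF: the string dispatch between two residual-block variants becomes a direct reference to the single remaining class, so the call sites below the shown context compile unchanged. For orientation, here is a minimal sketch of how the selected class is consumed a few lines further down in HiFi-GAN-style generators; the function and parameter names below are illustrative assumptions, not code from this diff:

    import torch.nn as nn

    # Sketch: each upsampling stage is followed by a bank of residual blocks,
    # built from the `resblock` class picked in the hunks above.
    def build_resblocks(resblock, upsample_initial_channel, num_upsamples,
                        resblock_kernel_sizes=(3, 7, 11),
                        resblock_dilation_sizes=((1, 3, 5), (1, 3, 5), (1, 3, 5))):
        blocks = nn.ModuleList()
        for i in range(num_upsamples):
            ch = upsample_initial_channel // (2 ** (i + 1))  # channels halve per stage
            for k, d in zip(resblock_kernel_sizes, resblock_dilation_sizes):
                blocks.append(resblock(ch, k, d))
        return blocks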

File 3 of 3

@@ -249,9 +249,9 @@ class WN(torch.nn.Module):
         return self


-class ResBlock1(torch.nn.Module):
+class ResBlock(torch.nn.Module):
     def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
-        super(ResBlock1, self).__init__()
+        super(ResBlock, self).__init__()
         self.convs1 = nn.ModuleList(
             [
                 weight_norm(
@@ -364,62 +364,6 @@ class ResBlock1(torch.nn.Module):
         return self


-class ResBlock2(torch.nn.Module):
-    def __init__(self, channels, kernel_size=3, dilation=(1, 3)):
-        super(ResBlock2, self).__init__()
-        self.convs = nn.ModuleList(
-            [
-                weight_norm(
-                    Conv1d(
-                        channels,
-                        channels,
-                        kernel_size,
-                        1,
-                        dilation=dilation[0],
-                        padding=get_padding(kernel_size, dilation[0]),
-                    )
-                ),
-                weight_norm(
-                    Conv1d(
-                        channels,
-                        channels,
-                        kernel_size,
-                        1,
-                        dilation=dilation[1],
-                        padding=get_padding(kernel_size, dilation[1]),
-                    )
-                ),
-            ]
-        )
-        self.convs.apply(init_weights)
-        self.lrelu_slope = LRELU_SLOPE
-
-    def forward(self, x, x_mask: Optional[torch.Tensor] = None):
-        for c in self.convs:
-            xt = F.leaky_relu(x, self.lrelu_slope)
-            if x_mask is not None:
-                xt = xt * x_mask
-            xt = c(xt)
-            x = xt + x
-        if x_mask is not None:
-            x = x * x_mask
-        return x
-
-    def remove_weight_norm(self):
-        for l in self.convs:
-            remove_weight_norm(l)
-
-    def __prepare_scriptable__(self):
-        for l in self.convs:
-            for hook in l._forward_pre_hooks.values():
-                if (
-                    hook.__module__ == "torch.nn.utils.weight_norm"
-                    and hook.__class__.__name__ == "WeightNorm"
-                ):
-                    torch.nn.utils.remove_weight_norm(l)
-        return self
-
-
 class Log(nn.Module):
     def forward(
         self,
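
With ResBlock2 deleted and ResBlock1 renamed, this module now exposes a single residual block. A small usage sketch of the surviving class follows; the import path, channel count, and frame length are illustrative assumptions:

    import torch
    import modules  # assumed import path for the file edited above

    block = modules.ResBlock(channels=512, kernel_size=3, dilation=(1, 3, 5))
    x = torch.randn(1, 512, 100)  # (batch, channels, frames); sizes are arbitrary
    y = block(x)                  # residual connections keep the shape unchanged
    assert y.shape == x.shape
    block.remove_weight_norm()    # strip the weight_norm hooks, e.g. before export

The remove_weight_norm call matches the method the class keeps after this commit; it is typically run once before inference or export so the fused weights are used directly.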