Mirror of https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI.git (synced 2025-04-05 04:08:58 +08:00)
Remove unused ResBlock check as we only use ResBlock 1
parent: 7ef1986778
commit: a2ff0b17dc
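The change below collapses `resblock = modules.ResBlock1 if resblock == "1" else modules.ResBlock2` into `resblock = modules.ResBlock` in both `Generator` and `GeneratorNSF`, renames `ResBlock1` to `ResBlock`, and deletes the never-selected `ResBlock2`. As a minimal, self-contained sketch of the residual pattern the surviving block implements (LeakyReLU, then a dilated Conv1d, then a skip add); the `TinyResBlock` name, the 0.1 slope, and the two-dilation setup are illustrative assumptions, not the project's actual module:

import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyResBlock(nn.Module):
    """Illustrative residual block: LeakyReLU -> dilated Conv1d -> skip add."""

    def __init__(self, channels: int, kernel_size: int = 3, dilations=(1, 3)):
        super().__init__()
        # "same" padding for a dilated conv keeps the frame count unchanged,
        # so the residual addition below is shape-safe.
        self.convs = nn.ModuleList(
            nn.Conv1d(
                channels,
                channels,
                kernel_size,
                dilation=d,
                padding=(kernel_size - 1) * d // 2,
            )
            for d in dilations
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for conv in self.convs:
            x = conv(F.leaky_relu(x, 0.1)) + x  # residual (skip) connection
        return x


x = torch.randn(1, 8, 100)        # (batch, channels, frames)
print(TinyResBlock(8)(x).shape)   # torch.Size([1, 8, 100])

The "same" padding formula `(kernel_size - 1) * dilation // 2` is what makes the skip connection valid: the convolution output has the same number of frames as its input.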
@@ -219,7 +219,7 @@ class Generator(torch.nn.Module):
         self.conv_pre = Conv1d(
             initial_channel, upsample_initial_channel, 7, 1, padding=3
         )
-        resblock = modules.ResBlock1 if resblock == "1" else modules.ResBlock2
+        resblock = modules.ResBlock
 
         self.ups = nn.ModuleList()
         for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
@@ -471,7 +471,7 @@ class GeneratorNSF(torch.nn.Module):
         self.conv_pre = Conv1d(
             initial_channel, upsample_initial_channel, 7, 1, padding=3
         )
-        resblock = modules.ResBlock1 if resblock == "1" else modules.ResBlock2
+        resblock = modules.ResBlock
 
         self.ups = nn.ModuleList()
         for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
|
@@ -230,7 +230,7 @@ class Generator(torch.nn.Module):
         self.conv_pre = Conv1d(
             initial_channel, upsample_initial_channel, 7, 1, padding=3
         )
-        resblock = modules.ResBlock1 if resblock == "1" else modules.ResBlock2
+        resblock = modules.ResBlock
 
         self.ups = nn.ModuleList()
         for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
@@ -442,7 +442,7 @@ class GeneratorNSF(torch.nn.Module):
         self.conv_pre = Conv1d(
             initial_channel, upsample_initial_channel, 7, 1, padding=3
         )
-        resblock = modules.ResBlock1 if resblock == "1" else modules.ResBlock2
+        resblock = modules.ResBlock
 
         self.ups = nn.ModuleList()
         for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
|
@@ -249,9 +249,9 @@ class WN(torch.nn.Module):
         return self
 
 
-class ResBlock1(torch.nn.Module):
+class ResBlock(torch.nn.Module):
     def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
-        super(ResBlock1, self).__init__()
+        super(ResBlock, self).__init__()
         self.convs1 = nn.ModuleList(
             [
                 weight_norm(
@@ -364,62 +364,6 @@ class ResBlock1(torch.nn.Module):
         return self
 
 
-class ResBlock2(torch.nn.Module):
-    def __init__(self, channels, kernel_size=3, dilation=(1, 3)):
-        super(ResBlock2, self).__init__()
-        self.convs = nn.ModuleList(
-            [
-                weight_norm(
-                    Conv1d(
-                        channels,
-                        channels,
-                        kernel_size,
-                        1,
-                        dilation=dilation[0],
-                        padding=get_padding(kernel_size, dilation[0]),
-                    )
-                ),
-                weight_norm(
-                    Conv1d(
-                        channels,
-                        channels,
-                        kernel_size,
-                        1,
-                        dilation=dilation[1],
-                        padding=get_padding(kernel_size, dilation[1]),
-                    )
-                ),
-            ]
-        )
-        self.convs.apply(init_weights)
-        self.lrelu_slope = LRELU_SLOPE
-
-    def forward(self, x, x_mask: Optional[torch.Tensor] = None):
-        for c in self.convs:
-            xt = F.leaky_relu(x, self.lrelu_slope)
-            if x_mask is not None:
-                xt = xt * x_mask
-            xt = c(xt)
-            x = xt + x
-        if x_mask is not None:
-            x = x * x_mask
-        return x
-
-    def remove_weight_norm(self):
-        for l in self.convs:
-            remove_weight_norm(l)
-
-    def __prepare_scriptable__(self):
-        for l in self.convs:
-            for hook in l._forward_pre_hooks.values():
-                if (
-                    hook.__module__ == "torch.nn.utils.weight_norm"
-                    and hook.__class__.__name__ == "WeightNorm"
-                ):
-                    torch.nn.utils.remove_weight_norm(l)
-        return self
-
-
 class Log(nn.Module):
     def forward(
         self,