diff --git a/modules/models/sd3/other_impls.py b/modules/models/sd3/other_impls.py index 002fe4832..f992db9bd 100644 --- a/modules/models/sd3/other_impls.py +++ b/modules/models/sd3/other_impls.py @@ -39,9 +39,9 @@ class Mlp(nn.Module): out_features = out_features or in_features hidden_features = hidden_features or in_features - self.fc1 = AutocastLinear(in_features, hidden_features, bias=bias, dtype=dtype, device=device) + self.fc1 = nn.Linear(in_features, hidden_features, bias=bias, dtype=dtype, device=device) self.act = act_layer - self.fc2 = AutocastLinear(hidden_features, out_features, bias=bias, dtype=dtype, device=device) + self.fc2 = nn.Linear(hidden_features, out_features, bias=bias, dtype=dtype, device=device) def forward(self, x): x = self.fc1(x)