diff --git a/modules/models/sd3/other_impls.py b/modules/models/sd3/other_impls.py index d7b9b2621..002fe4832 100644 --- a/modules/models/sd3/other_impls.py +++ b/modules/models/sd3/other_impls.py @@ -479,7 +479,7 @@ class T5Stack(torch.nn.Module): def forward(self, input_ids, intermediate_output=None, final_layer_norm_intermediate=True): intermediate = None - x = self.embed_tokens(input_ids) + x = self.embed_tokens(input_ids).to(torch.float32) # needs float32 or else T5 returns all zeroes past_bias = None for i, layer in enumerate(self.block): x, past_bias = layer(x, past_bias)