get T5 to work both with and without --precision half

AUTOMATIC1111 2024-06-28 08:10:19 +03:00
parent 06fe174c74
commit fc8b126673


@@ -479,7 +479,7 @@ class T5Stack(torch.nn.Module):
     def forward(self, input_ids, intermediate_output=None, final_layer_norm_intermediate=True):
         intermediate = None
-        x = self.embed_tokens(input_ids)
+        x = self.embed_tokens(input_ids).to(torch.float32)  # needs float32 or else T5 returns all zeroes
         past_bias = None
         for i, layer in enumerate(self.block):
             x, past_bias = layer(x, past_bias)
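
A minimal sketch of what the cast changes, using a plain PyTorch nn.Embedding as a stand-in for T5Stack.embed_tokens (this is an illustration, not the webui code): with --precision half the embedding weights are fp16, so its output is fp16 as well; promoting that output to float32 is the committed fix that keeps the downstream T5 blocks from producing all-zero hidden states while still working in the default full-precision mode.

# Illustration only; embed_tokens here is a hypothetical stand-in module.
import torch

embed_tokens = torch.nn.Embedding(32128, 768).half()   # fp16 weights, as under --precision half
input_ids = torch.tensor([[0, 1, 2]])

x_old = embed_tokens(input_ids)                     # fp16 activations (pre-commit behaviour)
x_new = embed_tokens(input_ids).to(torch.float32)   # the committed fix: promote to fp32

print(x_old.dtype, x_new.dtype)                     # torch.float16 torch.float32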