fix some nn.Embedding layers to use dtype=float32 for float8-frozen models

Won-Kyu Park 2024-09-22 00:42:45 +09:00
parent 8c9c139c65
commit 2e6533519b

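Below is a minimal sketch (not part of the commit) of the assumed rationale: an embedding lookup generally cannot be performed on a weight table stored in a float8 dtype, so the tables are kept in float32 and only the looked-up activations are cast down to the model's working dtype. All sizes, dtypes, and variable names here are illustrative, not taken from the patched file.

import torch

device = "cpu"
work_dtype = torch.float16            # assumed compute dtype of the rest of the T5 stack
vocab_size, model_dim = 32128, 4096   # illustrative T5-XXL-like sizes

# The embedding table is created in float32 even when the frozen checkpoint is
# stored in a float8 dtype such as torch.float8_e4m3fn; on typical PyTorch
# builds the lookup is not implemented for float8 weight tables.
shared = torch.nn.Embedding(vocab_size, model_dim, device=device, dtype=torch.float32)

tokens = torch.tensor([[0, 1, 2]], device=device)
hidden = shared(tokens).to(work_dtype)  # cast the looked-up rows, not the whole table

The same reasoning applies to the relative_attention_bias table in T5Attention, which is also an nn.Embedding.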

@@ -377,7 +377,7 @@ class T5Attention(torch.nn.Module):
         if relative_attention_bias:
             self.relative_attention_num_buckets = 32
             self.relative_attention_max_distance = 128
-            self.relative_attention_bias = torch.nn.Embedding(self.relative_attention_num_buckets, self.num_heads, device=device)
+            self.relative_attention_bias = torch.nn.Embedding(self.relative_attention_num_buckets, self.num_heads, device=device, dtype=torch.float32)

     @staticmethod
     def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
@@ -482,7 +482,7 @@ class T5Block(torch.nn.Module):

 class T5Stack(torch.nn.Module):
     def __init__(self, num_layers, model_dim, inner_dim, ff_dim, num_heads, vocab_size, dtype, device):
         super().__init__()
-        self.embed_tokens = torch.nn.Embedding(vocab_size, model_dim, device=device)
+        #self.embed_tokens = torch.nn.Embedding(vocab_size, model_dim, device=device, dtype=torch.float32)
         self.block = torch.nn.ModuleList([T5Block(model_dim, inner_dim, ff_dim, num_heads, relative_attention_bias=(i == 0), dtype=dtype, device=device) for i in range(num_layers)])
         self.final_layer_norm = T5LayerNorm(model_dim, dtype=dtype, device=device)
@@ -507,7 +507,7 @@ class T5(torch.nn.Module):
         super().__init__()
         self.num_layers = config_dict["num_layers"]
         self.encoder = T5Stack(self.num_layers, config_dict["d_model"], config_dict["d_model"], config_dict["d_ff"], config_dict["num_heads"], config_dict["vocab_size"], dtype, device)
-        self.shared = torch.nn.Embedding(config_dict["vocab_size"], config_dict["d_model"])
+        self.shared = torch.nn.Embedding(config_dict["vocab_size"], config_dict["d_model"], device=device, dtype=torch.float32)
         self.dtype = dtype

     def get_input_embeddings(self):