diff --git a/modules/models/sd3/other_impls.py b/modules/models/sd3/other_impls.py
index 5492d5e11..4524fa019 100644
--- a/modules/models/sd3/other_impls.py
+++ b/modules/models/sd3/other_impls.py
@@ -377,7 +377,7 @@ class T5Attention(torch.nn.Module):
         if relative_attention_bias:
             self.relative_attention_num_buckets = 32
             self.relative_attention_max_distance = 128
-            self.relative_attention_bias = torch.nn.Embedding(self.relative_attention_num_buckets, self.num_heads, device=device)
+            self.relative_attention_bias = torch.nn.Embedding(self.relative_attention_num_buckets, self.num_heads, device=device, dtype=torch.float32)

     @staticmethod
     def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
@@ -482,7 +482,7 @@ class T5Block(torch.nn.Module):
 class T5Stack(torch.nn.Module):
     def __init__(self, num_layers, model_dim, inner_dim, ff_dim, num_heads, vocab_size, dtype, device):
         super().__init__()
-        self.embed_tokens = torch.nn.Embedding(vocab_size, model_dim, device=device)
+        #self.embed_tokens = torch.nn.Embedding(vocab_size, model_dim, device=device, dtype=torch.float32)
         self.block = torch.nn.ModuleList([T5Block(model_dim, inner_dim, ff_dim, num_heads, relative_attention_bias=(i == 0), dtype=dtype, device=device) for i in range(num_layers)])
         self.final_layer_norm = T5LayerNorm(model_dim, dtype=dtype, device=device)

@@ -507,7 +507,7 @@ class T5(torch.nn.Module):
         super().__init__()
         self.num_layers = config_dict["num_layers"]
         self.encoder = T5Stack(self.num_layers, config_dict["d_model"], config_dict["d_model"], config_dict["d_ff"], config_dict["num_heads"], config_dict["vocab_size"], dtype, device)
-        self.shared = torch.nn.Embedding(config_dict["vocab_size"], config_dict["d_model"])
+        self.shared = torch.nn.Embedding(config_dict["vocab_size"], config_dict["d_model"], device=device, dtype=torch.float32)
        self.dtype = dtype

     def get_input_embeddings(self):
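
The net effect of this diff is that the T5 embedding tables (the relative-attention bias table and the shared token embedding) are constructed in float32, and the separate `embed_tokens` construction in `T5Stack` is commented out. Below is a minimal, self-contained sketch of that pattern for illustration only; it is not the repository's code, and the class name `Fp32Embedding`, the `out_dtype` argument, and the sizes in the usage example are assumptions.

```python
import torch

# Illustrative sketch (not the repository's code): keep an embedding table in
# float32 while the surrounding model runs at a lower precision, casting the
# lookup result back to the working dtype.
class Fp32Embedding(torch.nn.Module):
    def __init__(self, vocab_size, model_dim, out_dtype=torch.float16, device=None):
        super().__init__()
        # The table itself stays in float32, mirroring the dtype=torch.float32
        # arguments added in the diff above.
        self.table = torch.nn.Embedding(vocab_size, model_dim, device=device, dtype=torch.float32)
        self.out_dtype = out_dtype

    def forward(self, input_ids):
        # Look up in float32, then hand activations back in the working dtype.
        return self.table(input_ids).to(self.out_dtype)


if __name__ == "__main__":
    # Hypothetical sizes for demonstration only.
    emb = Fp32Embedding(vocab_size=32128, model_dim=4096)
    tokens = torch.randint(0, 32128, (1, 77))
    print(emb(tokens).dtype)  # torch.float16
```

The cast back to the working dtype keeps downstream computation in half precision while only the embedding weights themselves are held at full precision.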