Mirror of https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
Fix some nn.Embedding layers to use dtype=float32 for float8 frozen models
This commit is contained in:
parent 8c9c139c65
commit 2e6533519b
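When the text encoder is frozen in a float8 storage format, the embedding lookup tables are the layers most sensitive to the reduced precision, so this change pins the affected nn.Embedding layers to float32 at construction time. Below is a minimal sketch of that pattern, not the webui's T5 code: TinyEncoder and proj are hypothetical names, and casting the lookup result back to the working dtype is an assumption about how the output would typically be consumed.

import torch

class TinyEncoder(torch.nn.Module):
    def __init__(self, vocab_size, dim, dtype, device=None):
        super().__init__()
        # The embedding table is pinned to float32, as in this commit, so it is
        # not created in whatever low-precision dtype the rest of the module uses.
        self.embed = torch.nn.Embedding(vocab_size, dim, device=device, dtype=torch.float32)
        self.proj = torch.nn.Linear(dim, dim, dtype=dtype, device=device)
        self.dtype = dtype

    def forward(self, ids):
        x = self.embed(ids)        # lookup happens in float32
        x = x.to(self.dtype)       # cast down to the working dtype afterwards
        return self.proj(x)

enc = TinyEncoder(vocab_size=32128, dim=64, dtype=torch.bfloat16)
out = enc(torch.tensor([[1, 2, 3]]))
print(out.dtype)  # torch.bfloat16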
@@ -377,7 +377,7 @@ class T5Attention(torch.nn.Module):
         if relative_attention_bias:
             self.relative_attention_num_buckets = 32
             self.relative_attention_max_distance = 128
-            self.relative_attention_bias = torch.nn.Embedding(self.relative_attention_num_buckets, self.num_heads, device=device)
+            self.relative_attention_bias = torch.nn.Embedding(self.relative_attention_num_buckets, self.num_heads, device=device, dtype=torch.float32)
 
     @staticmethod
     def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
@@ -482,7 +482,7 @@ class T5Block(torch.nn.Module):
 class T5Stack(torch.nn.Module):
     def __init__(self, num_layers, model_dim, inner_dim, ff_dim, num_heads, vocab_size, dtype, device):
         super().__init__()
-        self.embed_tokens = torch.nn.Embedding(vocab_size, model_dim, device=device)
+        #self.embed_tokens = torch.nn.Embedding(vocab_size, model_dim, device=device, dtype=torch.float32)
         self.block = torch.nn.ModuleList([T5Block(model_dim, inner_dim, ff_dim, num_heads, relative_attention_bias=(i == 0), dtype=dtype, device=device) for i in range(num_layers)])
         self.final_layer_norm = T5LayerNorm(model_dim, dtype=dtype, device=device)
 
@@ -507,7 +507,7 @@ class T5(torch.nn.Module):
         super().__init__()
         self.num_layers = config_dict["num_layers"]
         self.encoder = T5Stack(self.num_layers, config_dict["d_model"], config_dict["d_model"], config_dict["d_ff"], config_dict["num_heads"], config_dict["vocab_size"], dtype, device)
-        self.shared = torch.nn.Embedding(config_dict["vocab_size"], config_dict["d_model"])
+        self.shared = torch.nn.Embedding(config_dict["vocab_size"], config_dict["d_model"], device=device, dtype=torch.float32)
         self.dtype = dtype
 
     def get_input_embeddings(self):