use text_encoders.t5xxl.transformer.shared.weight for the token embedding weights

* some T5XXL checkpoints do not have encoder.embed_tokens.weight; use shared.weight as the token embedding instead (see the sketch below)
* use the float8 text encoder t5xxl_fp8_e4m3fn.safetensors
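For context: a Hugging Face-style T5XXL export typically carries the token table both as shared.weight and as encoder.embed_tokens.weight (the two are normally tied), while some repackaged checkpoints (e.g. ComfyUI-style fp8 exports) keep only shared.weight. A minimal sketch, not part of this commit, of probing which layout a file uses; the file name is a placeholder:

    import safetensors

    with safetensors.safe_open("t5xxl.safetensors", framework="pt") as f:
        keys = set(f.keys())

    # Full exports carry both keys (tied weights); trimmed ones only shared.weight.
    print("shared.weight:", "shared.weight" in keys)
    print("encoder.embed_tokens.weight:", "encoder.embed_tokens.weight" in keys)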
Author: Won-Kyu Park
Date:   2024-09-19 23:56:07 +09:00
Parent: 71b430f703
Commit: f569f6eb1e
2 changed files with 20 additions and 13 deletions

File 1 of 2 (T5 text-encoder implementation):

@@ -449,7 +449,7 @@ class T5Attention(torch.nn.Module):
         else:
             mask = None
-        out = attention(q, k * ((k.shape[-1] / self.num_heads) ** 0.5), v, self.num_heads, mask.to(x.dtype) if mask is not None else None)
+        out = attention(q, k * ((k.shape[-1] / self.num_heads) ** 0.5), v, self.num_heads, mask.to(q.dtype) if mask is not None else None)
         return self.o(out), past_bias
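The mask is now cast to q.dtype, the dtype actually fed into the attention call, rather than the stack input's dtype; this avoids a mismatch when hidden states are kept in float32 while projections run in lower precision. A minimal illustration (shapes and dtypes arbitrary):

    import torch
    import torch.nn.functional as F

    q = k = v = torch.randn(1, 8, 4, 64, dtype=torch.float16)
    mask = torch.zeros(1, 1, 4, 4)  # float32 by default

    # Casting the additive mask to q.dtype avoids dtype-mismatch errors
    # that some attention backends raise for non-bool masks.
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask.to(q.dtype))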
@@ -486,15 +486,17 @@ class T5Stack(torch.nn.Module):
         self.block = torch.nn.ModuleList([T5Block(model_dim, inner_dim, ff_dim, num_heads, relative_attention_bias=(i == 0), dtype=dtype, device=device) for i in range(num_layers)])
         self.final_layer_norm = T5LayerNorm(model_dim, dtype=dtype, device=device)

-    def forward(self, input_ids, intermediate_output=None, final_layer_norm_intermediate=True):
+    def forward(self, x, intermediate_output=None, final_layer_norm_intermediate=True):
         intermediate = None
-        x = self.embed_tokens(input_ids).to(torch.float32)  # needs float32 or else T5 returns all zeroes
+        #x = self.embed_tokens(input_ids).to(torch.float32)  # needs float32 or else T5 returns all zeroes
+        # some T5XXL checkpoints do not have embed_tokens; use the shared embedding instead, like ComfyUI
         past_bias = None
         for i, layer in enumerate(self.block):
             x, past_bias = layer(x, past_bias)
             if i == intermediate_output:
                 intermediate = x.clone()
         x = self.final_layer_norm(x)
+        x = torch.nan_to_num(x)
         if intermediate is not None and final_layer_norm_intermediate:
             intermediate = self.final_layer_norm(intermediate)
         return x, intermediate
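The added torch.nan_to_num is a guard against non-finite activations, which are more likely with the fp8 weights this commit enables; by default it maps NaN to 0 and ±inf to the dtype's finite extremes. A quick illustration:

    import torch

    t = torch.tensor([float("nan"), float("inf"), -float("inf"), 1.5])
    print(torch.nan_to_num(t))
    # tensor([ 0.0000e+00,  3.4028e+38, -3.4028e+38,  1.5000e+00])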
@@ -505,13 +507,18 @@ class T5(torch.nn.Module):
         super().__init__()
         self.num_layers = config_dict["num_layers"]
         self.encoder = T5Stack(self.num_layers, config_dict["d_model"], config_dict["d_model"], config_dict["d_ff"], config_dict["num_heads"], config_dict["vocab_size"], dtype, device)
+        self.shared = torch.nn.Embedding(config_dict["vocab_size"], config_dict["d_model"])
         self.dtype = dtype

     def get_input_embeddings(self):
-        return self.encoder.embed_tokens
+        #return self.encoder.embed_tokens
+        return self.shared

     def set_input_embeddings(self, embeddings):
-        self.encoder.embed_tokens = embeddings
+        #self.encoder.embed_tokens = embeddings
+        self.shared = embeddings

-    def forward(self, *args, **kwargs):
-        return self.encoder(*args, **kwargs)
+    def forward(self, input_ids, *args, **kwargs):
+        x = self.shared(input_ids).float()  # needs float32 or else T5 returns all zeroes
+        x = torch.nan_to_num(x)
+        return self.encoder(x, *args, **kwargs)
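Moving the table to self.shared means the checkpoint key text_encoders.t5xxl.transformer.shared.weight lines up with the module's parameter path once the prefix is stripped, so load_state_dict(..., strict=False) fills it without any remapping. A toy sketch of the same mechanism (names here are illustrative only):

    import torch

    class Toy(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.shared = torch.nn.Embedding(10, 4)

    toy = Toy()
    # strict=False tolerates keys missing on either side, which is how the
    # loader copes with checkpoints lacking encoder.embed_tokens.weight.
    toy.load_state_dict({"shared.weight": torch.ones(10, 4)}, strict=False)
    print(toy.shared.weight[0])  # tensor([1., 1., 1., 1.])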

File 2 of 2 (SD3 text-encoder setup and loading):

@@ -43,7 +43,7 @@ CLIPG_CONFIG = {
     "textual_inversion_key": "clip_g",
 }

-T5_URL = f"{shared.hf_endpoint}/AUTOMATIC/stable-diffusion-3-medium-text-encoders/resolve/main/t5xxl_fp16.safetensors"
+T5_URL = f"{shared.hf_endpoint}/AUTOMATIC/stable-diffusion-3-medium-text-encoders/resolve/main/t5xxl_fp8_e4m3fn.safetensors"
 T5_CONFIG = {
     "d_ff": 10240,
     "d_model": 4096,
@@ -164,11 +164,11 @@ class SD3Cond(torch.nn.Module):
         self.tokenizer = SD3Tokenizer()

         with torch.no_grad():
-            self.clip_g = SDXLClipG(CLIPG_CONFIG, device="cpu", dtype=devices.dtype)
-            self.clip_l = SDClipModel(layer="hidden", layer_idx=-2, device="cpu", dtype=devices.dtype, layer_norm_hidden_state=False, return_projected_pooled=False, textmodel_json_config=CLIPL_CONFIG)
+            self.clip_g = SDXLClipG(CLIPG_CONFIG, device="cpu", dtype=devices.dtype_inference)
+            self.clip_l = SDClipModel(layer="hidden", layer_idx=-2, device="cpu", dtype=devices.dtype_inference, layer_norm_hidden_state=False, return_projected_pooled=False, textmodel_json_config=CLIPL_CONFIG)

             if shared.opts.sd3_enable_t5:
-                self.t5xxl = T5XXLModel(T5_CONFIG, device="cpu", dtype=devices.dtype)
+                self.t5xxl = T5XXLModel(T5_CONFIG, device="cpu", dtype=devices.dtype_inference)
             else:
                 self.t5xxl = None
@@ -199,8 +199,8 @@ class SD3Cond(torch.nn.Module):
         with safetensors.safe_open(clip_l_file, framework="pt") as file:
             self.clip_l.transformer.load_state_dict(SafetensorsMapping(file), strict=False)

-        if self.t5xxl and 'text_encoders.t5xxl.transformer.encoder.embed_tokens.weight' not in state_dict:
-            t5_file = modelloader.load_file_from_url(T5_URL, model_dir=clip_path, file_name="t5xxl_fp16.safetensors")
+        if self.t5xxl and 'text_encoders.t5xxl.transformer.encoder.block.0.layer.0.SelfAttention.k.weight' not in state_dict:
+            t5_file = modelloader.load_file_from_url(T5_URL, model_dir=clip_path, file_name="t5xxl_fp8_e4m3fn.safetensors")
             with safetensors.safe_open(t5_file, framework="pt") as file:
                 self.t5xxl.transformer.load_state_dict(SafetensorsMapping(file), strict=False)
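The old presence check keyed on encoder.embed_tokens.weight, which a bundled ComfyUI-style T5 lacks even when all transformer weights are present, so the loader would needlessly download a standalone file; probing an attention weight that every T5XXL layout contains avoids that false negative. A sketch of the same probe against a loose safetensors file (the path is a placeholder):

    import safetensors

    probe = "text_encoders.t5xxl.transformer.encoder.block.0.layer.0.SelfAttention.k.weight"
    with safetensors.safe_open("sd3_checkpoint.safetensors", framework="pt") as f:
        has_t5 = probe in f.keys()
    print("bundled T5XXL:", has_t5)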