Mirror of https://github.com/AUTOMATIC1111/stable-diffusion-webui.git, synced 2025-02-08 06:32:54 +08:00
use text_encoders.t5xxl.transformer.shared.weight token weights

* some T5XXL checkpoints do not have encoder.embed_tokens.weight; use the shared.weight embedding instead
* use the float8 text encoder t5xxl_fp8_e4m3fn.safetensors
parent 71b430f703
commit f569f6eb1e
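The change in one check: some T5XXL exports ship only the shared token-embedding matrix and omit the encoder-side copy. A minimal sketch of inspecting a checkpoint with safetensors (the local filename is a placeholder):

    from safetensors import safe_open

    with safe_open("t5xxl_fp8_e4m3fn.safetensors", framework="pt") as f:
        keys = set(f.keys())

    print("encoder.embed_tokens.weight" in keys)  # False in the affected exports
    print("shared.weight" in keys)                # True: the shared token embedding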
@@ -449,7 +449,7 @@ class T5Attention(torch.nn.Module):
         else:
             mask = None

-        out = attention(q, k * ((k.shape[-1] / self.num_heads) ** 0.5), v, self.num_heads, mask.to(x.dtype) if mask is not None else None)
+        out = attention(q, k * ((k.shape[-1] / self.num_heads) ** 0.5), v, self.num_heads, mask.to(q.dtype) if mask is not None else None)

         return self.o(out), past_bias
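Why the mask is now cast to q.dtype rather than x.dtype: once callers pass embeddings in directly, x is no longer guaranteed to match the dtype of the projected q/k/v tensors that enter the attention math. A toy illustration of the promotion issue (shapes arbitrary, not this file's attention() helper):

    import torch

    q = torch.randn(1, 8, 77, 64).half()
    k = torch.randn(1, 8, 77, 64).half()
    mask = torch.zeros(77, 77)  # float32 by default

    scores = q @ k.transpose(-2, -1) + mask.to(q.dtype)  # stays float16
    # adding the raw float32 mask instead would silently promote scores to float32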
@@ -486,15 +486,17 @@ class T5Stack(torch.nn.Module):
         self.block = torch.nn.ModuleList([T5Block(model_dim, inner_dim, ff_dim, num_heads, relative_attention_bias=(i == 0), dtype=dtype, device=device) for i in range(num_layers)])
         self.final_layer_norm = T5LayerNorm(model_dim, dtype=dtype, device=device)

-    def forward(self, input_ids, intermediate_output=None, final_layer_norm_intermediate=True):
+    def forward(self, x, intermediate_output=None, final_layer_norm_intermediate=True):
         intermediate = None
-        x = self.embed_tokens(input_ids).to(torch.float32)  # needs float32 or else T5 returns all zeroes
+        #x = self.embed_tokens(input_ids).to(torch.float32)  # needs float32 or else T5 returns all zeroes
+        # some T5XXL checkpoints do not have embed_tokens; the caller embeds via the shared weights instead, like ComfyUI
         past_bias = None
         for i, layer in enumerate(self.block):
             x, past_bias = layer(x, past_bias)
             if i == intermediate_output:
                 intermediate = x.clone()
         x = self.final_layer_norm(x)
+        x = torch.nan_to_num(x)
         if intermediate is not None and final_layer_norm_intermediate:
             intermediate = self.final_layer_norm(intermediate)
         return x, intermediate
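T5Stack.forward now consumes embeddings rather than token ids; the lookup, float32 cast, and NaN scrub move to the caller (T5.forward below does exactly this). A usage sketch, where t5 stands for a hypothetical T5 instance from this file:

    import torch

    input_ids = torch.tensor([[71, 1712, 1]])  # toy token ids
    x = t5.shared(input_ids).float()           # embed via shared.weight, in float32
    x = torch.nan_to_num(x)                    # scrub NaN/inf coming from low-precision weights
    hidden, intermediate = t5.encoder(x)       # intermediate is None unless intermediate_output is set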
@@ -505,13 +507,18 @@ class T5(torch.nn.Module):
         super().__init__()
         self.num_layers = config_dict["num_layers"]
         self.encoder = T5Stack(self.num_layers, config_dict["d_model"], config_dict["d_model"], config_dict["d_ff"], config_dict["num_heads"], config_dict["vocab_size"], dtype, device)
+        self.shared = torch.nn.Embedding(config_dict["vocab_size"], config_dict["d_model"])
         self.dtype = dtype

     def get_input_embeddings(self):
-        return self.encoder.embed_tokens
+        #return self.encoder.embed_tokens
+        return self.shared

     def set_input_embeddings(self, embeddings):
-        self.encoder.embed_tokens = embeddings
+        #self.encoder.embed_tokens = embeddings
+        self.shared = embeddings

-    def forward(self, *args, **kwargs):
-        return self.encoder(*args, **kwargs)
+    def forward(self, input_ids, *args, **kwargs):
+        x = self.shared(input_ids).float()
+        x = torch.nan_to_num(x)
+        return self.encoder(x, *args, **kwargs)
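Naming the new embedding self.shared is what makes loading work without any key remapping: PyTorch derives state-dict keys from attribute names, so the checkpoint's shared.weight tensor lands on the module directly. A minimal sketch (32128/4096 are the real T5XXL vocab and d_model; toy sizes used so it runs instantly):

    import torch

    class Tiny(torch.nn.Module):
        def __init__(self, vocab_size, d_model):
            super().__init__()
            self.shared = torch.nn.Embedding(vocab_size, d_model)

    m = Tiny(vocab_size=100, d_model=8)          # real model: 32128, 4096
    sd = {"shared.weight": torch.randn(100, 8)}  # key name as in T5XXL exports
    missing, unexpected = m.load_state_dict(sd, strict=False)
    print(unexpected)  # [] -- the key matched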
@@ -43,7 +43,7 @@ CLIPG_CONFIG = {
     "textual_inversion_key": "clip_g",
 }

-T5_URL = f"{shared.hf_endpoint}/AUTOMATIC/stable-diffusion-3-medium-text-encoders/resolve/main/t5xxl_fp16.safetensors"
+T5_URL = f"{shared.hf_endpoint}/AUTOMATIC/stable-diffusion-3-medium-text-encoders/resolve/main/t5xxl_fp8_e4m3fn.safetensors"
 T5_CONFIG = {
     "d_ff": 10240,
     "d_model": 4096,
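The default download switches to the float8 (e4m3fn) T5 encoder, halving weight storage relative to fp16. A quick sketch of what the format is, assuming a PyTorch build with float8 support (2.1+):

    import torch

    w16 = torch.randn(1024, 1024).half()
    w8 = w16.to(torch.float8_e4m3fn)              # 4 exponent bits, 3 mantissa bits, no inf encoding
    print(w16.element_size(), w8.element_size())  # 2 vs 1 byte per element
    w = w8.half()                                 # upcast before matmul; most kernels cannot consume fp8 directly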
@@ -164,11 +164,11 @@ class SD3Cond(torch.nn.Module):
         self.tokenizer = SD3Tokenizer()

         with torch.no_grad():
-            self.clip_g = SDXLClipG(CLIPG_CONFIG, device="cpu", dtype=devices.dtype)
-            self.clip_l = SDClipModel(layer="hidden", layer_idx=-2, device="cpu", dtype=devices.dtype, layer_norm_hidden_state=False, return_projected_pooled=False, textmodel_json_config=CLIPL_CONFIG)
+            self.clip_g = SDXLClipG(CLIPG_CONFIG, device="cpu", dtype=devices.dtype_inference)
+            self.clip_l = SDClipModel(layer="hidden", layer_idx=-2, device="cpu", dtype=devices.dtype_inference, layer_norm_hidden_state=False, return_projected_pooled=False, textmodel_json_config=CLIPL_CONFIG)

         if shared.opts.sd3_enable_t5:
-            self.t5xxl = T5XXLModel(T5_CONFIG, device="cpu", dtype=devices.dtype)
+            self.t5xxl = T5XXLModel(T5_CONFIG, device="cpu", dtype=devices.dtype_inference)
         else:
             self.t5xxl = None
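devices.dtype is the storage dtype of the main model's weights, while devices.dtype_inference is the dtype activations are computed in; the two can differ, e.g. fp16 weights with --precision full computation, and the text encoders should follow the inference dtype. A hedged sketch of the distinction (names mirror modules/devices.py; values are illustrative):

    import torch

    dtype = torch.float16            # storage dtype for weights
    dtype_inference = torch.float32  # computation dtype, e.g. with --precision full

    layer = torch.nn.Linear(64, 64).to(dtype)
    x = torch.randn(1, 64, dtype=dtype_inference)
    y = layer.to(dtype_inference)(x)  # cast to the inference dtype before computing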
@@ -199,8 +199,8 @@ class SD3Cond(torch.nn.Module):
         with safetensors.safe_open(clip_l_file, framework="pt") as file:
             self.clip_l.transformer.load_state_dict(SafetensorsMapping(file), strict=False)

-        if self.t5xxl and 'text_encoders.t5xxl.transformer.encoder.embed_tokens.weight' not in state_dict:
-            t5_file = modelloader.load_file_from_url(T5_URL, model_dir=clip_path, file_name="t5xxl_fp16.safetensors")
+        if self.t5xxl and 'text_encoders.t5xxl.transformer.encoder.block.0.layer.0.SelfAttention.k.weight' not in state_dict:
+            t5_file = modelloader.load_file_from_url(T5_URL, model_dir=clip_path, file_name="t5xxl_fp8_e4m3fn.safetensors")
             with safetensors.safe_open(t5_file, framework="pt") as file:
                 self.t5xxl.transformer.load_state_dict(SafetensorsMapping(file), strict=False)
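The download trigger changes too: embed_tokens.weight is no longer a reliable presence probe, since checkpoints affected by this commit legitimately lack it while still bundling full T5 weights, so the check probes an attention weight every T5XXL export carries. A sketch of the idea; the helper name is hypothetical:

    T5_PROBE_KEY = 'text_encoders.t5xxl.transformer.encoder.block.0.layer.0.SelfAttention.k.weight'

    def needs_t5_download(state_dict) -> bool:
        # if the probe key is absent, the checkpoint does not bundle the T5 encoder
        return T5_PROBE_KEY not in state_dict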