diff --git a/modules/models/sd3/other_impls.py b/modules/models/sd3/other_impls.py
index 695d356cb..5492d5e11 100644
--- a/modules/models/sd3/other_impls.py
+++ b/modules/models/sd3/other_impls.py
@@ -449,7 +449,7 @@ class T5Attention(torch.nn.Module):
         else:
             mask = None
 
-        out = attention(q, k * ((k.shape[-1] / self.num_heads) ** 0.5), v, self.num_heads, mask.to(x.dtype) if mask is not None else None)
+        out = attention(q, k * ((k.shape[-1] / self.num_heads) ** 0.5), v, self.num_heads, mask.to(q.dtype) if mask is not None else None)
         return self.o(out), past_bias
 
 
@@ -486,15 +486,17 @@ class T5Stack(torch.nn.Module):
         self.block = torch.nn.ModuleList([T5Block(model_dim, inner_dim, ff_dim, num_heads, relative_attention_bias=(i == 0), dtype=dtype, device=device) for i in range(num_layers)])
         self.final_layer_norm = T5LayerNorm(model_dim, dtype=dtype, device=device)
 
-    def forward(self, input_ids, intermediate_output=None, final_layer_norm_intermediate=True):
+    def forward(self, x, intermediate_output=None, final_layer_norm_intermediate=True):
         intermediate = None
-        x = self.embed_tokens(input_ids).to(torch.float32)  # needs float32 or else T5 returns all zeroes
+        #x = self.embed_tokens(input_ids).to(torch.float32)  # needs float32 or else T5 returns all zeroes
+        # some T5XXL checkpoints have no encoder.embed_tokens; embed via the shared embedding instead, as ComfyUI does
         past_bias = None
         for i, layer in enumerate(self.block):
             x, past_bias = layer(x, past_bias)
             if i == intermediate_output:
                 intermediate = x.clone()
         x = self.final_layer_norm(x)
+        x = torch.nan_to_num(x)
         if intermediate is not None and final_layer_norm_intermediate:
             intermediate = self.final_layer_norm(intermediate)
         return x, intermediate
@@ -505,13 +507,18 @@ class T5(torch.nn.Module):
         super().__init__()
         self.num_layers = config_dict["num_layers"]
         self.encoder = T5Stack(self.num_layers, config_dict["d_model"], config_dict["d_model"], config_dict["d_ff"], config_dict["num_heads"], config_dict["vocab_size"], dtype, device)
+        self.shared = torch.nn.Embedding(config_dict["vocab_size"], config_dict["d_model"])
         self.dtype = dtype
 
     def get_input_embeddings(self):
-        return self.encoder.embed_tokens
+        #return self.encoder.embed_tokens
+        return self.shared
 
     def set_input_embeddings(self, embeddings):
-        self.encoder.embed_tokens = embeddings
+        #self.encoder.embed_tokens = embeddings
+        self.shared = embeddings
 
-    def forward(self, *args, **kwargs):
-        return self.encoder(*args, **kwargs)
+    def forward(self, input_ids, *args, **kwargs):
+        x = self.shared(input_ids).float()
+        x = torch.nan_to_num(x)
+        return self.encoder(x, *args, **kwargs)
diff --git a/modules/models/sd3/sd3_cond.py b/modules/models/sd3/sd3_cond.py
index 6a43f569b..7058950a5 100644
--- a/modules/models/sd3/sd3_cond.py
+++ b/modules/models/sd3/sd3_cond.py
@@ -43,7 +43,7 @@ CLIPG_CONFIG = {
     "textual_inversion_key": "clip_g",
 }
 
-T5_URL = f"{shared.hf_endpoint}/AUTOMATIC/stable-diffusion-3-medium-text-encoders/resolve/main/t5xxl_fp16.safetensors"
+T5_URL = f"{shared.hf_endpoint}/AUTOMATIC/stable-diffusion-3-medium-text-encoders/resolve/main/t5xxl_fp8_e4m3fn.safetensors"
 T5_CONFIG = {
     "d_ff": 10240,
     "d_model": 4096,
@@ -164,11 +164,11 @@ class SD3Cond(torch.nn.Module):
         self.tokenizer = SD3Tokenizer()
 
         with torch.no_grad():
-            self.clip_g = SDXLClipG(CLIPG_CONFIG, device="cpu", dtype=devices.dtype)
-            self.clip_l = SDClipModel(layer="hidden", layer_idx=-2, device="cpu", dtype=devices.dtype, layer_norm_hidden_state=False, return_projected_pooled=False, textmodel_json_config=CLIPL_CONFIG)
+            self.clip_g = SDXLClipG(CLIPG_CONFIG, device="cpu", dtype=devices.dtype_inference)
+            self.clip_l = SDClipModel(layer="hidden", layer_idx=-2, device="cpu", dtype=devices.dtype_inference, layer_norm_hidden_state=False, return_projected_pooled=False, textmodel_json_config=CLIPL_CONFIG)
 
         if shared.opts.sd3_enable_t5:
-            self.t5xxl = T5XXLModel(T5_CONFIG, device="cpu", dtype=devices.dtype)
+            self.t5xxl = T5XXLModel(T5_CONFIG, device="cpu", dtype=devices.dtype_inference)
         else:
             self.t5xxl = None
 
@@ -199,8 +199,8 @@ class SD3Cond(torch.nn.Module):
         with safetensors.safe_open(clip_l_file, framework="pt") as file:
             self.clip_l.transformer.load_state_dict(SafetensorsMapping(file), strict=False)
 
-        if self.t5xxl and 'text_encoders.t5xxl.transformer.encoder.embed_tokens.weight' not in state_dict:
-            t5_file = modelloader.load_file_from_url(T5_URL, model_dir=clip_path, file_name="t5xxl_fp16.safetensors")
+        if self.t5xxl and 'text_encoders.t5xxl.transformer.encoder.block.0.layer.0.SelfAttention.k.weight' not in state_dict:
+            t5_file = modelloader.load_file_from_url(T5_URL, model_dir=clip_path, file_name="t5xxl_fp8_e4m3fn.safetensors")
 
             with safetensors.safe_open(t5_file, framework="pt") as file:
                 self.t5xxl.transformer.load_state_dict(SafetensorsMapping(file), strict=False)