diff --git a/modules/models/flux/flux.py b/modules/models/flux/flux.py index fc1e91e9d..eb17a31c7 100644 --- a/modules/models/flux/flux.py +++ b/modules/models/flux/flux.py @@ -80,7 +80,7 @@ class FluxCond(torch.nn.Module): with torch.no_grad(): self.clip_l = SDClipModel(layer="hidden", layer_idx=-2, device="cpu", dtype=devices.dtype, layer_norm_hidden_state=False, return_projected_pooled=False, textmodel_json_config=CLIPL_CONFIG) - if shared.opts.sd3_enable_t5: + if shared.opts.flux_enable_t5: self.t5xxl = T5XXLModel(T5_CONFIG, device="cpu", dtype=devices.dtype) else: self.t5xxl = None @@ -107,7 +107,7 @@ class FluxCond(torch.nn.Module): with safetensors.safe_open(clip_l_file, framework="pt") as file: self.clip_l.transformer.load_state_dict(SafetensorsMapping(file), strict=False) - if self.t5xxl and 'text_encoders.t5xxl.transformer.encoder.embed_tokens.weight' not in state_dict: + if self.t5xxl and 'text_encoders.t5xxl.transformer.encoder.block.0.layer.0.SelfAttention.k.weight' not in state_dict: t5_file = modelloader.load_file_from_url(T5_URL, model_dir=clip_path, file_name="t5xxl_fp16.safetensors") with safetensors.safe_open(t5_file, framework="pt") as file: self.t5xxl.transformer.load_state_dict(SafetensorsMapping(file), strict=False) @@ -194,7 +194,7 @@ class BaseModel(torch.nn.Module): guidance_embed=True, ) - self.diffusion_model = Flux(device=device, dtype=devices.dtype, **params) + self.diffusion_model = Flux(device=device, dtype=dtype, **params) self.model_sampling = ModelSamplingFlux() self.depth = params['depth'] self.depth_single_block = params['depth_single_blocks'] @@ -301,7 +301,10 @@ class FLUX1Inferencer(torch.nn.Module): def decode_first_stage(self, latent): latent = self.latent_format.process_out(latent) - return self.first_stage_model.decode(latent) + x = self.first_stage_model.decode(latent) + if x.dtype == torch.float16: + x = torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504) + return x def encode_first_stage(self, image): latent = self.first_stage_model.encode(image) diff --git a/modules/shared_options.py b/modules/shared_options.py index 6b6faf332..0cb10acfd 100644 --- a/modules/shared_options.py +++ b/modules/shared_options.py @@ -195,6 +195,9 @@ options_templates.update(options_section(('sdxl', "Stable Diffusion XL", "sd"), options_templates.update(options_section(('sd3', "Stable Diffusion 3", "sd"), { "sd3_enable_t5": OptionInfo(False, "Enable T5").info("load T5 text encoder; increases VRAM use by a lot, potentially improving quality of generation; requires model reload to apply"), })) +options_templates.update(options_section(('flux', "Stable Diffusion FLUX", "sd"), { + "flux_enable_t5": OptionInfo(False, "Enable T5").info("load T5 text encoder; increases VRAM use by a lot, potentially improving quality of generation; requires model reload to apply"), +})) options_templates.update(options_section(('vae', "VAE", "sd"), { "sd_vae_explanation": OptionHTML("""