fix for t5xxl

This commit is contained in:
Won-Kyu Park 2024-09-05 08:57:42 +09:00
parent 51c285265f
commit 7e2d51965f
No known key found for this signature in database
GPG Key ID: 53AA79C8C9535D15
2 changed files with 10 additions and 4 deletions

View File

@ -80,7 +80,7 @@ class FluxCond(torch.nn.Module):
with torch.no_grad():
self.clip_l = SDClipModel(layer="hidden", layer_idx=-2, device="cpu", dtype=devices.dtype, layer_norm_hidden_state=False, return_projected_pooled=False, textmodel_json_config=CLIPL_CONFIG)
if shared.opts.sd3_enable_t5:
if shared.opts.flux_enable_t5:
self.t5xxl = T5XXLModel(T5_CONFIG, device="cpu", dtype=devices.dtype)
else:
self.t5xxl = None
@ -107,7 +107,7 @@ class FluxCond(torch.nn.Module):
with safetensors.safe_open(clip_l_file, framework="pt") as file:
self.clip_l.transformer.load_state_dict(SafetensorsMapping(file), strict=False)
if self.t5xxl and 'text_encoders.t5xxl.transformer.encoder.embed_tokens.weight' not in state_dict:
if self.t5xxl and 'text_encoders.t5xxl.transformer.encoder.block.0.layer.0.SelfAttention.k.weight' not in state_dict:
t5_file = modelloader.load_file_from_url(T5_URL, model_dir=clip_path, file_name="t5xxl_fp16.safetensors")
with safetensors.safe_open(t5_file, framework="pt") as file:
self.t5xxl.transformer.load_state_dict(SafetensorsMapping(file), strict=False)
@ -194,7 +194,7 @@ class BaseModel(torch.nn.Module):
guidance_embed=True,
)
self.diffusion_model = Flux(device=device, dtype=devices.dtype, **params)
self.diffusion_model = Flux(device=device, dtype=dtype, **params)
self.model_sampling = ModelSamplingFlux()
self.depth = params['depth']
self.depth_single_block = params['depth_single_blocks']
@ -301,7 +301,10 @@ class FLUX1Inferencer(torch.nn.Module):
def decode_first_stage(self, latent):
latent = self.latent_format.process_out(latent)
return self.first_stage_model.decode(latent)
x = self.first_stage_model.decode(latent)
if x.dtype == torch.float16:
x = torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504)
return x
def encode_first_stage(self, image):
latent = self.first_stage_model.encode(image)

View File

@ -195,6 +195,9 @@ options_templates.update(options_section(('sdxl', "Stable Diffusion XL", "sd"),
options_templates.update(options_section(('sd3', "Stable Diffusion 3", "sd"), {
"sd3_enable_t5": OptionInfo(False, "Enable T5").info("load T5 text encoder; increases VRAM use by a lot, potentially improving quality of generation; requires model reload to apply"),
}))
options_templates.update(options_section(('flux', "Stable Diffusion FLUX", "sd"), {
"flux_enable_t5": OptionInfo(False, "Enable T5").info("load T5 text encoder; increases VRAM use by a lot, potentially improving quality of generation; requires model reload to apply"),
}))
options_templates.update(options_section(('vae', "VAE", "sd"), {
"sd_vae_explanation": OptionHTML("""