Fix Flux conditioning to download and load the float8 (fp8_e4m3fn) T5-XXL text encoder instead of the fp16 variant

This commit is contained in:
Won-Kyu Park 2024-09-20 00:00:34 +09:00
parent f569f6eb1e
commit 28eca46959
No known key found for this signature in database
GPG Key ID: 53AA79C8C9535D15

View File

@ -108,7 +108,7 @@ class FluxCond(torch.nn.Module):
         self.clip_l.transformer.load_state_dict(SafetensorsMapping(file), strict=False)
     if self.t5xxl and 'text_encoders.t5xxl.transformer.encoder.block.0.layer.0.SelfAttention.k.weight' not in state_dict:
-        t5_file = modelloader.load_file_from_url(T5_URL, model_dir=clip_path, file_name="t5xxl_fp16.safetensors")
+        t5_file = modelloader.load_file_from_url(T5_URL, model_dir=clip_path, file_name="t5xxl_fp8_e4m3fn.safetensors")
         with safetensors.safe_open(t5_file, framework="pt") as file:
             self.t5xxl.transformer.load_state_dict(SafetensorsMapping(file), strict=False)