fix for float8_e5m2 freeze model

This commit is contained in:
Won-Kyu Park 2024-09-17 10:07:58 +09:00
parent 2ffdf01e05
commit 1e73a28707
No known key found for this signature in database
GPG Key ID: 53AA79C8C9535D15

View File

@@ -585,7 +585,7 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
     if found_unet_dtype in (torch.float16, torch.float32, torch.bfloat16):
         model.half()
-    elif found_unet_dtype in (torch.float8_e4m3fn,):
+    elif found_unet_dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
         pass
     else:
         print("Fail to get a vaild UNet dtype. ignore...")
@@ -612,7 +612,7 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
             if hasattr(module, 'fp16_bias'):
                 del module.fp16_bias
-    if found_unet_dtype not in (torch.float8_e4m3fn,) and check_fp8(model):
+    if found_unet_dtype not in (torch.float8_e4m3fn, torch.float8_e5m2) and check_fp8(model):
         devices.fp8 = True
         # do not convert vae, text_encoders.clip_l, clip_g, t5xxl