Mirror of https://github.com/AUTOMATIC1111/stable-diffusion-webui.git, synced 2025-02-08 14:42:54 +08:00.
fix for float8_e5m2 freeze model

parent 2ffdf01e05
commit 1e73a28707
@@ -585,7 +585,7 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
     if found_unet_dtype in (torch.float16, torch.float32, torch.bfloat16):
         model.half()
-    elif found_unet_dtype in (torch.float8_e4m3fn,):
+    elif found_unet_dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
         pass
     else:
         print("Fail to get a vaild UNet dtype. ignore...")
@@ -612,7 +612,7 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
         if hasattr(module, 'fp16_bias'):
             del module.fp16_bias

-    if found_unet_dtype not in (torch.float8_e4m3fn,) and check_fp8(model):
+    if found_unet_dtype not in (torch.float8_e4m3fn,torch.float8_e5m2) and check_fp8(model):
        devices.fp8 = True

     # do not convert vae, text_encoders.clip_l, clip_g, t5xxl
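
The effect of the change is easier to see outside the diff. Below is a minimal sketch of the dtype probe and the gate this commit widens, assuming torch >= 2.1 (where torch.float8_e5m2 exists); find_unet_dtype and the key prefix are illustrative stand-ins, not the repo's actual helper, while the branch bodies mirror the patched code above.

    # Minimal sketch, not the repo's actual helper; assumes torch >= 2.1
    # for the float8 dtypes.
    import torch

    def find_unet_dtype(state_dict, prefix="model.diffusion_model."):
        # Return the dtype of the first UNet tensor in the checkpoint, else None.
        # The key prefix is the usual SD UNet prefix; treated here as an assumption.
        for key, tensor in state_dict.items():
            if key.startswith(prefix) and isinstance(tensor, torch.Tensor):
                return tensor.dtype
        return None

    def apply_dtype_policy(model, found_unet_dtype):
        # Mirrors the patched branches in load_model_weights.
        if found_unet_dtype in (torch.float16, torch.float32, torch.bfloat16):
            model.half()  # ordinary checkpoints are downcast to fp16
        elif found_unet_dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
            pass  # already-fp8 weights must be left untouched
        else:
            print("Fail to get a valid UNet dtype. ignore...")

Before the fix, a float8_e5m2 checkpoint fell through to the else branch, and the later `not in (torch.float8_e4m3fn,)` test re-enabled the fp8 conversion path (devices.fp8 = True) on weights that were already fp8, which is presumably the freeze the commit message refers to. Adding torch.float8_e5m2 to both tuples lets such checkpoints load as-is.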