mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-02-08 06:32:54 +08:00
fix for float8_e5m2 freeze model
This commit is contained in:
parent
2ffdf01e05
commit
1e73a28707
@ -585,7 +585,7 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
|
||||
|
||||
if found_unet_dtype in (torch.float16, torch.float32, torch.bfloat16):
|
||||
model.half()
|
||||
elif found_unet_dtype in (torch.float8_e4m3fn,):
|
||||
elif found_unet_dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
|
||||
pass
|
||||
else:
|
||||
print("Fail to get a vaild UNet dtype. ignore...")
|
||||
@ -612,7 +612,7 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
|
||||
if hasattr(module, 'fp16_bias'):
|
||||
del module.fp16_bias
|
||||
|
||||
if found_unet_dtype not in (torch.float8_e4m3fn,) and check_fp8(model):
|
||||
if found_unet_dtype not in (torch.float8_e4m3fn,torch.float8_e5m2) and check_fp8(model):
|
||||
devices.fp8 = True
|
||||
|
||||
# do not convert vae, text_encoders.clip_l, clip_g, t5xxl
|
||||
|
Loading…
Reference in New Issue
Block a user