check Unet/VAE and load as is

- check float8 UNet dtype to save memory
- check VAE / text_encoder dtypes and use them as intended (see the sketch below)
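In practice this means the loader inspects the dtypes stored in the checkpoint before deciding on any conversion: a float8 UNet is kept in float8, and a bfloat16 VAE is kept in bfloat16 when the device supports it. A rough summary of the resulting policy (editor's sketch only; decide_unet_policy and its arguments are hypothetical, the actual logic is in load_model_weights below):

import torch

def decide_unet_policy(found_unet_dtype, no_half=False):
    # hypothetical summary of the new branches in load_model_weights
    if found_unet_dtype == torch.float8_e4m3fn:
        return "keep fp8 weights as-is; skip half()/float() and the fp8 re-conversion pass"
    if no_half:
        return "convert UNet to float32 (--no-half)"
    if found_unet_dtype in (torch.float16, torch.float32, torch.bfloat16):
        return "convert UNet to float16 via model.half()"
    return "dtype not detected; fall back to the old float16 path"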
Won-Kyu Park 2024-09-05 09:34:08 +09:00
parent 39328bd7db
commit c972951cf6


@@ -407,6 +407,90 @@ def set_model_fields(model):
        model.latent_channels = 4


def get_state_dict_dtype(state_dict):
    # detect dtypes of state_dict
    state_dict_dtype = {}
    known_prefixes = ("model.diffusion_model.", "first_stage_model.", "cond_stage_model.", "conditioner", "vae.", "text_encoders.")
    for k in state_dict.keys():
        found = [prefix for prefix in known_prefixes if k.startswith(prefix)]
        if len(found) > 0:
            prefix = found[0]
            dtype = state_dict[k].dtype
            dtypes = state_dict_dtype.get(prefix, {})
            if dtype in dtypes:
                dtypes[dtype] += 1
            else:
                dtypes[dtype] = 1
            state_dict_dtype[prefix] = dtypes

    for prefix in state_dict_dtype:
        dtypes = state_dict_dtype[prefix]
        # sort by count
        state_dict_dtype[prefix] = dict(sorted(dtypes.items(), key=lambda item: item[1], reverse=True))

    print("Detected dtypes:", state_dict_dtype)
    return state_dict_dtype


def get_loadable_dtype(prefix="model.diffusion_model.", dtype=None, state_dict=None, state_dict_dtype=None, count=490):
    if state_dict is not None:
        state_dict_dtype = get_state_dict_dtype(state_dict)

    aliases = {
        "FP8": "F8",
        "FP16": "F16",
        "FP32": "F32",
    }
    loadables = {
        "F8": (torch.float8_e4m3fn,),
        "F16": (torch.float16,),
        "F32": (torch.float32,),
        "BF16": (torch.bfloat16,),
    }
    if dtype is None:
        # get the first dtype
        if prefix in state_dict_dtype:
            return list(state_dict_dtype[prefix])[0]
        return None

    if dtype in aliases:
        dtype = aliases[dtype]
    loadable = loadables[dtype]

    if prefix in state_dict_dtype:
        dtypes = [d for d in state_dict_dtype[prefix].keys() if d in loadable]
        if len(dtypes) > 0 and state_dict_dtype[prefix][dtypes[0]] >= count:
            # the majority of weights are of this dtype
            return dtypes[0]

    return None


def get_vae_dtype(state_dict=None, state_dict_dtype=None):
    if state_dict is not None:
        state_dict_dtype = get_state_dict_dtype(state_dict)

    if state_dict_dtype is None:
        raise ValueError("failed to get VAE dtype")

    vae_prefixes = [prefix for prefix in ("vae.", "first_stage_model.") if prefix in state_dict_dtype]
    if len(vae_prefixes) > 0:
        vae_prefix = vae_prefixes[0]
        for dtype in state_dict_dtype[vae_prefix]:
            if state_dict_dtype[vae_prefix][dtype] > 240 and dtype in (torch.float16, torch.float32, torch.bfloat16):
                # VAE tensor count: 248 for SD1/SDXL, 245 for Flux
                return dtype

    return None


def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer):
    sd_model_hash = checkpoint_info.calculate_shorthash()
    timer.record("calculate hash")
@@ -441,6 +525,9 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
    if hasattr(model, "before_load_weights"):
        model.before_load_weights(state_dict)

    # get all dtypes of state_dict
    state_dict_dtype = get_state_dict_dtype(state_dict)

    model.load_state_dict(state_dict, strict=False)
    timer.record("apply weights to model")
@@ -466,7 +553,13 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
        model.to(memory_format=torch.channels_last)
        timer.record("apply channels_last")

-    if shared.cmd_opts.no_half:
    # check dtype of vae
    dtype_vae = get_vae_dtype(state_dict_dtype=state_dict_dtype)

    found_unet_dtype = get_loadable_dtype("model.diffusion_model.", state_dict_dtype=state_dict_dtype)
    unet_has_float = found_unet_dtype in (torch.float16, torch.float32, torch.bfloat16)
    if (found_unet_dtype is None or unet_has_float) and shared.cmd_opts.no_half:
        # unet type is not detected or unet has float dtypes
        model.float()
        model.alphas_cumprod_original = model.alphas_cumprod
        devices.dtype_unet = torch.float32
@@ -476,8 +569,11 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
        vae = model.first_stage_model
        depth_model = getattr(model, 'depth_model', None)

        if dtype_vae == torch.bfloat16 and dtype_vae in devices.supported_vae_dtypes:
            # preserve bfloat16 if it is supported
            model.first_stage_model = None
        # with --no-half-vae, remove VAE from model when doing half() to prevent its weights from being converted to float16
-        if shared.cmd_opts.no_half_vae:
        elif shared.cmd_opts.no_half_vae:
            model.first_stage_model = None
        # with --upcast-sampling, don't convert the depth model weights to float16
        if shared.cmd_opts.upcast_sampling and depth_model:
@@ -485,15 +581,28 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
        alphas_cumprod = model.alphas_cumprod
        model.alphas_cumprod = None
-        model.half()
        if found_unet_dtype in (torch.float16, torch.float32, torch.bfloat16):
            model.half()
        elif found_unet_dtype in (torch.float8_e4m3fn,):
            pass
        else:
            print("Failed to get a valid UNet dtype, ignoring...")
        model.alphas_cumprod = alphas_cumprod
        model.alphas_cumprod_original = alphas_cumprod
        model.first_stage_model = vae
        if depth_model:
            model.depth_model = depth_model

-        devices.dtype_unet = torch.float16
-        timer.record("apply half()")
        if found_unet_dtype in (torch.float16, torch.float32):
            devices.dtype_unet = torch.float16
            timer.record("apply half()")
        else:
            print(f"load UNet {found_unet_dtype} as is ...")
            devices.dtype_unet = found_unet_dtype if found_unet_dtype else torch.float16
            timer.record("load UNet")

    apply_alpha_schedule_override(model)
@@ -503,10 +612,18 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
            if hasattr(module, 'fp16_bias'):
                del module.fp16_bias

-    if check_fp8(model):
    if found_unet_dtype not in (torch.float8_e4m3fn,) and check_fp8(model):
        devices.fp8 = True
        # do not convert vae, text_encoders.clip_l, clip_g, t5xxl
        first_stage = model.first_stage_model
        model.first_stage_model = None
        vae = getattr(model, 'vae', None)
        if vae is not None:
            model.vae = None
        text_encoders = getattr(model, 'text_encoders', None)
        if text_encoders is not None:
            model.text_encoders = None

        for module in model.modules():
            if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear)):
                if shared.opts.cache_fp16_weight:
@@ -514,6 +631,10 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
                    if module.bias is not None:
                        module.fp16_bias = module.bias.data.clone().cpu().half()
                module.to(torch.float8_e4m3fn)
        if text_encoders is not None:
            model.text_encoders = text_encoders
        if vae is not None:
            model.vae = vae
        model.first_stage_model = first_stage
        timer.record("apply fp8")
    else:
@@ -521,8 +642,16 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
    devices.unet_needs_upcast = shared.cmd_opts.upcast_sampling and devices.dtype == torch.float16 and devices.dtype_unet == torch.float16

-    model.first_stage_model.to(devices.dtype_vae)
-    timer.record("apply dtype to VAE")
    # check supported vae dtype
    dtype_vae = get_vae_dtype(state_dict_dtype=state_dict_dtype)
    if dtype_vae == torch.bfloat16 and dtype_vae in devices.supported_vae_dtypes:
        devices.dtype_vae = torch.bfloat16
        print(f"VAE dtype {dtype_vae} detected. load as is.")
    else:
        # use default devices.dtype_vae
        model.first_stage_model.to(devices.dtype_vae)
        print(f"Use VAE dtype {devices.dtype_vae}")
        timer.record("apply dtype to VAE")

    # clean up cache if limit is reached
    while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache:
@@ -818,6 +947,18 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None, checkpoint_
    print(f"Creating model from config: {checkpoint_config}")

    # get all dtypes of state_dict
    state_dict_dtype = get_state_dict_dtype(state_dict)

    # check loadable unet dtype before loading
    loadable_unet_dtype = get_loadable_dtype("model.diffusion_model.", state_dict_dtype=state_dict_dtype)

    # check dtype of vae
    dtype_vae = get_vae_dtype(state_dict_dtype=state_dict_dtype)
    if dtype_vae == torch.bfloat16 and dtype_vae in devices.supported_vae_dtypes:
        devices.dtype_vae = torch.bfloat16
        print(f"VAE dtype {dtype_vae} detected.")

    sd_model = None
    try:
        with sd_disable_initialization.DisableInitialization(disable_clip=clip_is_included_into_sd or shared.cmd_opts.do_not_download_clip):
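The VAE handling in the two hunks above boils down to: keep a bfloat16 VAE as stored when the backend supports it, otherwise fall back to the default VAE dtype. A minimal sketch of that decision (editor's illustration; pick_vae_dtype is hypothetical, and supported stands in for devices.supported_vae_dtypes):

import torch

def pick_vae_dtype(detected, supported, default):
    # keep a bfloat16 VAE as stored when the device/backend supports it,
    # otherwise use the default VAE dtype (typically float16 or float32)
    if detected == torch.bfloat16 and detected in supported:
        return torch.bfloat16
    return default

pick_vae_dtype(torch.bfloat16, (torch.float16, torch.bfloat16), torch.float16)  # -> torch.bfloat16
pick_vae_dtype(torch.float16, (torch.float16,), torch.float16)                  # -> torch.float16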
@@ -843,7 +984,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None, checkpoint_
        weight_dtype_conversion = {
            'first_stage_model': None,
            'alphas_cumprod': None,
-            '': torch.float16,
            '': torch.float16 if loadable_unet_dtype in (torch.float16, torch.float32, torch.bfloat16) else None,
        }

    with sd_disable_initialization.LoadStateDictOnMeta(state_dict, device=model_target_device(sd_model), weight_dtype_conversion=weight_dtype_conversion):
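Assuming LoadStateDictOnMeta treats weight_dtype_conversion as a map from the first segment of each state_dict key to a target dtype ('' being the fallback entry, None meaning the stored dtype is kept), the change above disables the blanket float16 conversion whenever the UNet is not stored in a plain float dtype. A rough illustration of that lookup (editor's sketch; resolve_target_dtype is hypothetical, not the webui API):

import torch

def resolve_target_dtype(key, weight_dtype_conversion):
    # hypothetical helper: the first key segment selects an entry, "" is the fallback,
    # None means "keep the dtype stored in the checkpoint"
    first = key.split(".", 1)[0]
    if first in weight_dtype_conversion:
        return weight_dtype_conversion[first]
    return weight_dtype_conversion.get("")

conv = {"first_stage_model": None, "alphas_cumprod": None, "": torch.float16}
resolve_target_dtype("model.diffusion_model.input_blocks.0.0.weight", conv)  # torch.float16
resolve_target_dtype("first_stage_model.encoder.conv_in.weight", conv)       # None -> keep stored dtype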