From c972951cf69d46bee9c09bd64e0215378f4fa852 Mon Sep 17 00:00:00 2001
From: Won-Kyu Park
Date: Thu, 5 Sep 2024 09:34:08 +0900
Subject: [PATCH] check UNet/VAE dtypes and load weights as-is

- check for a float8 UNet dtype to save memory
- check VAE / text_encoders dtypes and use them as intended

---
 modules/sd_models.py | 159 ++++++++++++++++++++++++++++++++++++++++---
 1 file changed, 150 insertions(+), 9 deletions(-)

diff --git a/modules/sd_models.py b/modules/sd_models.py
index 1c7d370e9..43e0b9208 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -407,6 +407,90 @@ def set_model_fields(model):
         model.latent_channels = 4
 
 
+def get_state_dict_dtype(state_dict):
+    # count the dtypes of state_dict entries, grouped by known component prefixes
+    state_dict_dtype = {}
+
+    known_prefixes = ("model.diffusion_model.", "first_stage_model.", "cond_stage_model.", "conditioner", "vae.", "text_encoders.")
+
+    for k in state_dict.keys():
+        found = [prefix for prefix in known_prefixes if k.startswith(prefix)]
+        if len(found) > 0:
+            prefix = found[0]
+            dtype = state_dict[k].dtype
+            dtypes = state_dict_dtype.get(prefix, {})
+            if dtype in dtypes:
+                dtypes[dtype] += 1
+            else:
+                dtypes[dtype] = 1
+            state_dict_dtype[prefix] = dtypes
+
+    for prefix in state_dict_dtype:
+        dtypes = state_dict_dtype[prefix]
+        # sort by count, most frequent dtype first
+        state_dict_dtype[prefix] = dict(sorted(dtypes.items(), key=lambda item: item[1], reverse=True))
+
+    print("Detected dtypes:", state_dict_dtype)
+    return state_dict_dtype
+
+
+def get_loadable_dtype(prefix="model.diffusion_model.", dtype=None, state_dict=None, state_dict_dtype=None, count=490):
+    if state_dict is not None:
+        state_dict_dtype = get_state_dict_dtype(state_dict)
+
+    aliases = {
+        "FP8": "F8",
+        "FP16": "F16",
+        "FP32": "F32",
+    }
+
+    loadables = {
+        "F8": (torch.float8_e4m3fn,),
+        "F16": (torch.float16,),
+        "F32": (torch.float32,),
+        "BF16": (torch.bfloat16,),
+    }
+
+    if dtype is None:
+        # no dtype requested: return the most frequent dtype found for this prefix
+        if prefix in state_dict_dtype:
+            return list(state_dict_dtype[prefix])[0]
+        return None
+
+
+    if dtype in aliases:
+        dtype = aliases[dtype]
+    loadable = loadables[dtype]
+
+    if prefix in state_dict_dtype:
+        dtypes = [d for d in state_dict_dtype[prefix].keys() if d in loadable]
+        if len(dtypes) > 0 and state_dict_dtype[prefix][dtypes[0]] >= count:
+            # the weights are mostly of this dtype
+            return dtypes[0]
+
+    return None
+
+
+def get_vae_dtype(state_dict=None, state_dict_dtype=None):
+    if state_dict is not None:
+        state_dict_dtype = get_state_dict_dtype(state_dict)
+
+    if state_dict_dtype is None:
+        raise ValueError("failed to get the VAE dtype")
+
+
+    vae_prefixes = [prefix for prefix in ("vae.", "first_stage_model.") if prefix in state_dict_dtype]
+
+    if len(vae_prefixes) > 0:
+        vae_prefix = vae_prefixes[0]
+        for dtype in state_dict_dtype[vae_prefix]:
+            if state_dict_dtype[vae_prefix][dtype] > 240 and dtype in (torch.float16, torch.float32, torch.bfloat16):
+                # typical VAE item counts: 248 for SD1/SDXL, 245 for Flux
+                return dtype
+
+    return None
+
+
 def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer):
     sd_model_hash = checkpoint_info.calculate_shorthash()
     timer.record("calculate hash")
@@ -441,6 +525,9 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
     if hasattr(model, "before_load_weights"):
         model.before_load_weights(state_dict)
 
+    # get all dtypes of the state_dict
+    state_dict_dtype = get_state_dict_dtype(state_dict)
+
     model.load_state_dict(state_dict, strict=False)
     timer.record("apply weights to model")
 
@@ -466,7 +553,13 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
         model.to(memory_format=torch.channels_last)
         timer.record("apply channels_last")
 
-    if shared.cmd_opts.no_half:
+    # check the dtype of the VAE
+    dtype_vae = get_vae_dtype(state_dict_dtype=state_dict_dtype)
+    found_unet_dtype = get_loadable_dtype("model.diffusion_model.", state_dict_dtype=state_dict_dtype)
+    unet_has_float = found_unet_dtype in (torch.float16, torch.float32, torch.bfloat16)
+
+    if (found_unet_dtype is None or unet_has_float) and shared.cmd_opts.no_half:
+        # the UNet dtype was not detected, or the UNet uses float dtypes
         model.float()
         model.alphas_cumprod_original = model.alphas_cumprod
         devices.dtype_unet = torch.float32
@@ -476,8 +569,11 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
         vae = model.first_stage_model
         depth_model = getattr(model, 'depth_model', None)
 
+        if dtype_vae == torch.bfloat16 and dtype_vae in devices.supported_vae_dtypes:
+            # preserve bfloat16 if it is supported
+            model.first_stage_model = None
         # with --no-half-vae, remove VAE from model when doing half() to prevent its weights from being converted to float16
-        if shared.cmd_opts.no_half_vae:
+        elif shared.cmd_opts.no_half_vae:
             model.first_stage_model = None
         # with --upcast-sampling, don't convert the depth model weights to float16
         if shared.cmd_opts.upcast_sampling and depth_model:
@@ -485,15 +581,28 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
         alphas_cumprod = model.alphas_cumprod
         model.alphas_cumprod = None
-        model.half()
+
+
+        if found_unet_dtype in (torch.float16, torch.float32, torch.bfloat16):
+            model.half()
+        elif found_unet_dtype in (torch.float8_e4m3fn,):
+            pass
+        else:
+            print("Failed to get a valid UNet dtype, ignoring...")
+
         model.alphas_cumprod = alphas_cumprod
         model.alphas_cumprod_original = alphas_cumprod
         model.first_stage_model = vae
         if depth_model:
             model.depth_model = depth_model
 
-        devices.dtype_unet = torch.float16
-        timer.record("apply half()")
+        if found_unet_dtype in (torch.float16, torch.float32):
+            devices.dtype_unet = torch.float16
+            timer.record("apply half()")
+        else:
+            print(f"Loading UNet {found_unet_dtype} as is...")
+            devices.dtype_unet = found_unet_dtype if found_unet_dtype else torch.float16
+            timer.record("load UNet")
 
     apply_alpha_schedule_override(model)
 
@@ -503,10 +612,18 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
             if hasattr(module, 'fp16_bias'):
                 del module.fp16_bias
 
-    if check_fp8(model):
+    if found_unet_dtype not in (torch.float8_e4m3fn,) and check_fp8(model):
         devices.fp8 = True
+
+        # do not convert the VAE or the text encoders (clip_l, clip_g, t5xxl)
        first_stage = model.first_stage_model
        model.first_stage_model = None
+        vae = getattr(model, 'vae', None)
+        if vae is not None:
+            model.vae = None
+        text_encoders = getattr(model, 'text_encoders', None)
+        if text_encoders is not None:
+            model.text_encoders = None
         for module in model.modules():
             if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear)):
                 if shared.opts.cache_fp16_weight:
@@ -514,6 +631,10 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
                     if module.bias is not None:
                         module.fp16_bias = module.bias.data.clone().cpu().half()
                 module.to(torch.float8_e4m3fn)
+        if text_encoders is not None:
+            model.text_encoders = text_encoders
+        if vae is not None:
+            model.vae = vae
         model.first_stage_model = first_stage
         timer.record("apply fp8")
     else:
@@ -521,8 +642,16 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
 
     devices.unet_needs_upcast = shared.cmd_opts.upcast_sampling and devices.dtype == torch.float16 and devices.dtype_unet == torch.float16
 
-    model.first_stage_model.to(devices.dtype_vae)
-    timer.record("apply dtype to VAE")
+    # check for a supported VAE dtype
+    dtype_vae = get_vae_dtype(state_dict_dtype=state_dict_dtype)
+    if dtype_vae == torch.bfloat16 and dtype_vae in devices.supported_vae_dtypes:
+        devices.dtype_vae = torch.bfloat16
+        print(f"VAE dtype {dtype_vae} detected, loading as is.")
+    else:
+        # fall back to the default devices.dtype_vae
+        model.first_stage_model.to(devices.dtype_vae)
+        print(f"Using VAE dtype {devices.dtype_vae}")
+    timer.record("apply dtype to VAE")
 
     # clean up cache if limit is reached
     while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache:
@@ -818,6 +947,18 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None, checkpoint_
 
     print(f"Creating model from config: {checkpoint_config}")
 
+    # get all dtypes of the state_dict
+    state_dict_dtype = get_state_dict_dtype(state_dict)
+
+    # check the loadable UNet dtype before loading
+    loadable_unet_dtype = get_loadable_dtype("model.diffusion_model.", state_dict_dtype=state_dict_dtype)
+
+    # check the dtype of the VAE
+    dtype_vae = get_vae_dtype(state_dict_dtype=state_dict_dtype)
+    if dtype_vae == torch.bfloat16 and dtype_vae in devices.supported_vae_dtypes:
+        devices.dtype_vae = torch.bfloat16
+        print(f"VAE dtype {dtype_vae} detected.")
+
     sd_model = None
     try:
         with sd_disable_initialization.DisableInitialization(disable_clip=clip_is_included_into_sd or shared.cmd_opts.do_not_download_clip):
@@ -843,7 +984,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None, checkpoint_
         weight_dtype_conversion = {
             'first_stage_model': None,
             'alphas_cumprod': None,
-            '': torch.float16,
+            '': torch.float16 if loadable_unet_dtype in (torch.float16, torch.float32, torch.bfloat16) else None,
         }
         with sd_disable_initialization.LoadStateDictOnMeta(state_dict, device=model_target_device(sd_model), weight_dtype_conversion=weight_dtype_conversion):
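
Note: the following is a minimal standalone sketch, not part of the patch, of the detection idea behind get_state_dict_dtype()/get_loadable_dtype(): count tensor dtypes per known key prefix and trust the majority dtype only when enough tensors carry it (the patch uses count=490 for the UNet and > 240 for the VAE). The helper name detect_dtypes and the toy key names are invented for illustration; torch.float8_e4m3fn requires PyTorch >= 2.1.

import torch

def detect_dtypes(state_dict, prefixes=("model.diffusion_model.", "first_stage_model.")):
    # count dtype occurrences per component prefix
    counts = {}
    for key, tensor in state_dict.items():
        for prefix in prefixes:
            if key.startswith(prefix):
                counts.setdefault(prefix, {}).setdefault(tensor.dtype, 0)
                counts[prefix][tensor.dtype] += 1
                break
    # most frequent dtype first, as in the patch
    return {p: dict(sorted(d.items(), key=lambda kv: kv[1], reverse=True)) for p, d in counts.items()}

# toy checkpoint: a float8 UNet and a bfloat16 VAE
toy_state_dict = {
    **{f"model.diffusion_model.block{i}.weight": torch.zeros(4, dtype=torch.float8_e4m3fn) for i in range(490)},
    **{f"first_stage_model.decoder.block{i}.weight": torch.zeros(4, dtype=torch.bfloat16) for i in range(248)},
}

dtypes = detect_dtypes(toy_state_dict)
unet_dtype = next(iter(dtypes["model.diffusion_model."]))  # torch.float8_e4m3fn
vae_dtype = next(iter(dtypes["first_stage_model."]))       # torch.bfloat16
print(unet_dtype, vae_dtype)

With a UNet detected as float8, the patch skips model.half() and loads the weights as-is, which stores one byte per weight instead of two for fp16, matching the commit's "save memory" goal.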
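The bfloat16 VAE path keeps the VAE as-is rather than casting it to the default devices.dtype_vae. A hedged sketch of that decision outside webui, where torch.cuda.is_bf16_supported() stands in for the webui-internal devices.supported_vae_dtypes check (an assumption for illustration, not the patch's actual code):

import torch

def pick_vae_dtype(detected_vae_dtype, default=torch.float16):
    # stand-in for devices.supported_vae_dtypes (assumption)
    supported = [torch.float16, torch.float32]
    if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
        supported.append(torch.bfloat16)
    if detected_vae_dtype == torch.bfloat16 and detected_vae_dtype in supported:
        return torch.bfloat16  # load the VAE as-is, skip the default cast
    return default

print(pick_vae_dtype(torch.bfloat16))  # bfloat16 on bf16-capable GPUs, else float16

Since bfloat16 keeps float32's exponent range at half the size, preserving it sidesteps the float16 overflow problems that --no-half-vae exists to work around, without the memory cost of a float32 VAE.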
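In the fp8 branch, the patch detaches first_stage_model, vae, and text_encoders before the conversion loop so that only the remaining Conv2d/Linear weights are cast to float8, then reattaches them. A minimal sketch of that module-level cast on plain torch.nn modules (not webui's model):

import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(8, 8), nn.LayerNorm(8), nn.Linear(8, 8))

# cast only Linear/Conv2d weights to float8_e4m3fn, as the patch's loop does;
# other modules (here LayerNorm) keep their original dtype
for module in model.modules():
    if isinstance(module, (nn.Conv2d, nn.Linear)):
        module.to(torch.float8_e4m3fn)

print([p.dtype for p in model.parameters()])

Temporarily setting the excluded submodules to None works because model.modules() only walks attached children, so nothing under the VAE or text encoders is visited by the isinstance check.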