def fix_unet_prefix(state_dict):
    """Add the ``model.diffusion_model.`` prefix to a UNet-only state dict.

    Intended to be applied to the result of ``read_state_dict(...)`` in
    ``get_checkpoint_state_dict`` before the weights are used.

    If any key already starts with one of the known component prefixes, the
    checkpoint is treated as a regular (already prefixed) one and returned
    untouched. Otherwise, if one of the well-known top-level UNet keys is
    present, every key is rewritten as ``model.diffusion_model.<key>``.

    Args:
        state_dict: Mapping of parameter names to values.

    Returns:
        The very same ``state_dict`` object when no fix is needed (or the
        layout is not recognized); otherwise a new dict with prefixed keys.
        The input mapping is never mutated.
    """
    # NOTE: "conditioner" intentionally has no trailing dot, matching the
    # original prefix list byte-for-byte.
    known_prefixes = (
        "model.diffusion_model.",
        "first_stage_model.",
        "cond_stage_model.",
        "conditioner",
        "vae.",
        "text_encoders.",
    )

    # str.startswith accepts a tuple, so one call per key covers every prefix
    # and any() stops at the first prefixed key.
    if any(key.startswith(known_prefixes) for key in state_dict):
        return state_dict

    # No known prefix found: this may be a UNet-only state dict. Confirm by
    # probing for a key that identifies each supported architecture.
    known_keys = (
        "input_blocks.0.0.weight",  # SD1.5, SD2, SDXL
        "joint_blocks.0.context_block.adaLN_modulation.1.weight",  # SD3
        "double_blocks.0.img_attn.proj.weight",  # FLUX
    )

    if not any(key in state_dict for key in known_keys):
        # Unrecognized layout: leave it alone rather than guess.
        return state_dict

    print("Fixed state_dict keys...")
    return {f"model.diffusion_model.{key}": value for key, value in state_dict.items()}