add fix_unet_prefix() to support unet only checkpoints

This commit is contained in:
Won-Kyu Park 2024-09-17 17:03:11 +09:00
parent 1318f6118e
commit eee7294200
No known key found for this signature in database
GPG Key ID: 53AA79C8C9535D15

View File

@ -282,6 +282,30 @@ def get_state_dict_from_checkpoint(pl_sd):
return pl_sd
def fix_unet_prefix(state_dict):
known_prefixes = ("model.diffusion_model.", "first_stage_model.", "cond_stage_model.", "conditioner", "vae.", "text_encoders.")
for k in state_dict.keys():
found = [prefix for prefix in known_prefixes if k.startswith(prefix)]
if len(found) > 0:
return state_dict
# no known prefix found.
# in this case, this is a unet only state_dict
known_keys = (
"input_blocks.0.0.weight", # SD1.5, SD2, SDXL
"joint_blocks.0.context_block.adaLN_modulation.1.weight", # SD3
"double_blocks.0.img_attn.proj.weight", # FLUX
)
if any(key in state_dict for key in known_keys):
state_dict = {f"model.diffusion_model.{k}": v for k, v in state_dict.items()}
print("Fixed state_dict keys...")
return state_dict
return state_dict
def read_metadata_from_safetensors(filename):
import json
@ -343,6 +367,7 @@ def get_checkpoint_state_dict(checkpoint_info: CheckpointInfo, timer):
print(f"Loading weights [{sd_model_hash}] from {checkpoint_info.filename}")
res = read_state_dict(checkpoint_info.filename)
res = fix_unet_prefix(res)
timer.record("load weights from disk")
return res