mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-01-20 05:10:15 +08:00
add fix_unet_prefix() to support unet only checkpoints
This commit is contained in:
parent
1318f6118e
commit
eee7294200
@ -282,6 +282,30 @@ def get_state_dict_from_checkpoint(pl_sd):
|
||||
return pl_sd
|
||||
|
||||
|
||||
def fix_unet_prefix(state_dict):
|
||||
known_prefixes = ("model.diffusion_model.", "first_stage_model.", "cond_stage_model.", "conditioner", "vae.", "text_encoders.")
|
||||
|
||||
for k in state_dict.keys():
|
||||
found = [prefix for prefix in known_prefixes if k.startswith(prefix)]
|
||||
if len(found) > 0:
|
||||
return state_dict
|
||||
|
||||
# no known prefix found.
|
||||
# in this case, this is a unet only state_dict
|
||||
known_keys = (
|
||||
"input_blocks.0.0.weight", # SD1.5, SD2, SDXL
|
||||
"joint_blocks.0.context_block.adaLN_modulation.1.weight", # SD3
|
||||
"double_blocks.0.img_attn.proj.weight", # FLUX
|
||||
)
|
||||
|
||||
if any(key in state_dict for key in known_keys):
|
||||
state_dict = {f"model.diffusion_model.{k}": v for k, v in state_dict.items()}
|
||||
print("Fixed state_dict keys...")
|
||||
return state_dict
|
||||
|
||||
return state_dict
|
||||
|
||||
|
||||
def read_metadata_from_safetensors(filename):
|
||||
import json
|
||||
|
||||
@ -343,6 +367,7 @@ def get_checkpoint_state_dict(checkpoint_info: CheckpointInfo, timer):
|
||||
|
||||
print(f"Loading weights [{sd_model_hash}] from {checkpoint_info.filename}")
|
||||
res = read_state_dict(checkpoint_info.filename)
|
||||
res = fix_unet_prefix(res)
|
||||
timer.record("load weights from disk")
|
||||
|
||||
return res
|
||||
|
Loading…
Reference in New Issue
Block a user