mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-02-01 03:03:00 +08:00
add fix_unet_prefix() to support unet only checkpoints
This commit is contained in:
parent
1318f6118e
commit
eee7294200
@ -282,6 +282,30 @@ def get_state_dict_from_checkpoint(pl_sd):
|
|||||||
return pl_sd
|
return pl_sd
|
||||||
|
|
||||||
|
|
||||||
|
def fix_unet_prefix(state_dict):
|
||||||
|
known_prefixes = ("model.diffusion_model.", "first_stage_model.", "cond_stage_model.", "conditioner", "vae.", "text_encoders.")
|
||||||
|
|
||||||
|
for k in state_dict.keys():
|
||||||
|
found = [prefix for prefix in known_prefixes if k.startswith(prefix)]
|
||||||
|
if len(found) > 0:
|
||||||
|
return state_dict
|
||||||
|
|
||||||
|
# no known prefix found.
|
||||||
|
# in this case, this is a unet only state_dict
|
||||||
|
known_keys = (
|
||||||
|
"input_blocks.0.0.weight", # SD1.5, SD2, SDXL
|
||||||
|
"joint_blocks.0.context_block.adaLN_modulation.1.weight", # SD3
|
||||||
|
"double_blocks.0.img_attn.proj.weight", # FLUX
|
||||||
|
)
|
||||||
|
|
||||||
|
if any(key in state_dict for key in known_keys):
|
||||||
|
state_dict = {f"model.diffusion_model.{k}": v for k, v in state_dict.items()}
|
||||||
|
print("Fixed state_dict keys...")
|
||||||
|
return state_dict
|
||||||
|
|
||||||
|
return state_dict
|
||||||
|
|
||||||
|
|
||||||
def read_metadata_from_safetensors(filename):
|
def read_metadata_from_safetensors(filename):
|
||||||
import json
|
import json
|
||||||
|
|
||||||
@ -343,6 +367,7 @@ def get_checkpoint_state_dict(checkpoint_info: CheckpointInfo, timer):
|
|||||||
|
|
||||||
print(f"Loading weights [{sd_model_hash}] from {checkpoint_info.filename}")
|
print(f"Loading weights [{sd_model_hash}] from {checkpoint_info.filename}")
|
||||||
res = read_state_dict(checkpoint_info.filename)
|
res = read_state_dict(checkpoint_info.filename)
|
||||||
|
res = fix_unet_prefix(res)
|
||||||
timer.record("load weights from disk")
|
timer.record("load weights from disk")
|
||||||
|
|
||||||
return res
|
return res
|
||||||
|
Loading…
Reference in New Issue
Block a user