mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-01-04 13:55:06 +08:00
fix position_ids
This commit is contained in:
parent
3b18b6f482
commit
2ffdf01e05
@ -469,6 +469,25 @@ def get_vae_dtype(state_dict=None, state_dict_dtype=None):
|
||||
return None
|
||||
|
||||
|
||||
def fix_position_ids(state_dict, force=False):
|
||||
# for SD1.5 or some SDXL with position_ids
|
||||
for prefix in ("cond_stage_models.", "conditioner.embedders.0."):
|
||||
position_id_key = f"{prefix}transformer.text_model.embeddings.position_ids"
|
||||
if position_id_key in state_dict:
|
||||
original = state_dict[position_id_key]
|
||||
if original.dtype == torch.int64:
|
||||
return
|
||||
|
||||
if force:
|
||||
# regenerate
|
||||
fixed = torch.tensor([list(range(77))], dtype=torch.int64, device=original.device)
|
||||
else:
|
||||
fixed = state_dict[position_id_key].to(torch.int64)
|
||||
print(f"Warning: Fixed position_ids dtype from {original.dtype} to {fixed.dtype}")
|
||||
|
||||
state_dict[position_id_key] = fixed
|
||||
|
||||
|
||||
def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer):
|
||||
sd_model_hash = checkpoint_info.calculate_shorthash()
|
||||
timer.record("calculate hash")
|
||||
@ -490,6 +509,9 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
|
||||
else:
|
||||
model.ztsnr = False
|
||||
|
||||
fix_position_ids(state_dict)
|
||||
|
||||
|
||||
if model.is_sdxl:
|
||||
sd_models_xl.extend_sdxl(model)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user