fix position_ids

This commit is contained in:
Won-Kyu Park 2024-09-17 10:05:13 +09:00
parent 3b18b6f482
commit 2ffdf01e05
No known key found for this signature in database
GPG Key ID: 53AA79C8C9535D15

View File

@ -469,6 +469,25 @@ def get_vae_dtype(state_dict=None, state_dict_dtype=None):
return None
def fix_position_ids(state_dict, force=False):
# for SD1.5 or some SDXL with position_ids
for prefix in ("cond_stage_models.", "conditioner.embedders.0."):
position_id_key = f"{prefix}transformer.text_model.embeddings.position_ids"
if position_id_key in state_dict:
original = state_dict[position_id_key]
if original.dtype == torch.int64:
return
if force:
# regenerate
fixed = torch.tensor([list(range(77))], dtype=torch.int64, device=original.device)
else:
fixed = state_dict[position_id_key].to(torch.int64)
print(f"Warning: Fixed position_ids dtype from {original.dtype} to {fixed.dtype}")
state_dict[position_id_key] = fixed
def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer):
sd_model_hash = checkpoint_info.calculate_shorthash()
timer.record("calculate hash")
@ -490,6 +509,9 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
else:
model.ztsnr = False
fix_position_ids(state_dict)
if model.is_sdxl:
sd_models_xl.extend_sdxl(model)