def fix_position_ids(state_dict, force=False):
    """Repair the dtype of CLIP ``position_ids`` entries in a checkpoint state dict.

    The text-encoder ``position_ids`` buffer is an index tensor and must be
    ``torch.int64``.  Some SD1.x checkpoints (and some SDXL checkpoints that
    still carry the key) store it with another dtype.  Every known key
    location is checked and each offending entry is replaced in place.

    ``load_model_weights`` calls this as ``fix_position_ids(state_dict)``,
    i.e. with ``force`` left at its default.

    Args:
        state_dict: Mutable mapping of parameter names to tensors; modified
            in place.
        force: When true, discard the stored values and regenerate the
            canonical ``[[0, 1, ..., 76]]`` tensor instead of converting the
            existing values.

    Returns:
        None.
    """
    # SD1.x keeps the text encoder under "cond_stage_model." (singular);
    # SDXL's first embedder lives under "conditioner.embedders.0.".
    prefixes = ("cond_stage_model.", "conditioner.embedders.0.")

    for prefix in prefixes:
        position_id_key = f"{prefix}transformer.text_model.embeddings.position_ids"

        original = state_dict.get(position_id_key)
        if original is None:
            continue

        # Already correct for this prefix; keep checking the remaining ones
        # rather than leaving the function.
        if original.dtype == torch.int64:
            continue

        if force:
            # Regenerate the canonical 77-token index row on the same device.
            fixed = torch.tensor([list(range(77))], dtype=torch.int64, device=original.device)
        else:
            # Round before casting: a plain cast truncates toward zero, so a
            # float id that drifted slightly below its integer (e.g. 40.9999)
            # would be mapped to the previous position.
            source = original.round() if original.is_floating_point() else original
            fixed = source.to(torch.int64)

        print(f"Warning: Fixed position_ids dtype from {original.dtype} to {fixed.dtype}")
        state_dict[position_id_key] = fixed