Add support for SD 2.1 Turbo, by converting the state dict from SGM to LDM on load

This commit is contained in:
MrCheeze 2023-12-01 22:58:05 -05:00
parent 293f44e6c1
commit 6080045b2a

View File

@ -230,15 +230,19 @@ def select_checkpoint():
return checkpoint_info return checkpoint_info
checkpoint_dict_replacements = { checkpoint_dict_replacements_sd1 = {
'cond_stage_model.transformer.embeddings.': 'cond_stage_model.transformer.text_model.embeddings.', 'cond_stage_model.transformer.embeddings.': 'cond_stage_model.transformer.text_model.embeddings.',
'cond_stage_model.transformer.encoder.': 'cond_stage_model.transformer.text_model.encoder.', 'cond_stage_model.transformer.encoder.': 'cond_stage_model.transformer.text_model.encoder.',
'cond_stage_model.transformer.final_layer_norm.': 'cond_stage_model.transformer.text_model.final_layer_norm.', 'cond_stage_model.transformer.final_layer_norm.': 'cond_stage_model.transformer.text_model.final_layer_norm.',
} }
checkpoint_dict_replacements_sd2_turbo = { # Converts SD 2.1 Turbo from SGM to LDM format.
'conditioner.embedders.0.': 'cond_stage_model.',
}
def transform_checkpoint_dict_key(k):
for text, replacement in checkpoint_dict_replacements.items(): def transform_checkpoint_dict_key(k, replacements):
for text, replacement in replacements.items():
if k.startswith(text): if k.startswith(text):
k = replacement + k[len(text):] k = replacement + k[len(text):]
@ -249,9 +253,14 @@ def get_state_dict_from_checkpoint(pl_sd):
pl_sd = pl_sd.pop("state_dict", pl_sd) pl_sd = pl_sd.pop("state_dict", pl_sd)
pl_sd.pop("state_dict", None) pl_sd.pop("state_dict", None)
is_sd2_turbo = 'conditioner.embedders.0.model.ln_final.weight' in pl_sd and pl_sd['conditioner.embedders.0.model.ln_final.weight'].size()[0] == 1024
sd = {} sd = {}
for k, v in pl_sd.items(): for k, v in pl_sd.items():
new_key = transform_checkpoint_dict_key(k) if is_sd2_turbo:
new_key = transform_checkpoint_dict_key(k, checkpoint_dict_replacements_sd2_turbo)
else:
new_key = transform_checkpoint_dict_key(k, checkpoint_dict_replacements_sd1)
if new_key is not None: if new_key is not None:
sd[new_key] = v sd[new_key] = v