use a local variable instead of dictionary entry for sd_merge_models in merge model metadata code

This commit is contained in:
AUTOMATIC 2023-05-17 17:44:07 +03:00
parent 36c14831b3
commit 76ebf750a4

View File

@ -245,7 +245,8 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
metadata = None metadata = None
if save_metadata: if save_metadata:
metadata = {"format": "pt", "sd_merge_models": {}} metadata = {"format": "pt"}
merge_recipe = { merge_recipe = {
"type": "webui", # indicate this model was merged with webui's built-in merger "type": "webui", # indicate this model was merged with webui's built-in merger
"primary_model_hash": primary_model_info.sha256, "primary_model_hash": primary_model_info.sha256,
@ -263,15 +264,17 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
} }
metadata["sd_merge_recipe"] = json.dumps(merge_recipe) metadata["sd_merge_recipe"] = json.dumps(merge_recipe)
sd_merge_models = {}
def add_model_metadata(checkpoint_info): def add_model_metadata(checkpoint_info):
checkpoint_info.calculate_shorthash() checkpoint_info.calculate_shorthash()
metadata["sd_merge_models"][checkpoint_info.sha256] = { sd_merge_models[checkpoint_info.sha256] = {
"name": checkpoint_info.name, "name": checkpoint_info.name,
"legacy_hash": checkpoint_info.hash, "legacy_hash": checkpoint_info.hash,
"sd_merge_recipe": checkpoint_info.metadata.get("sd_merge_recipe", None) "sd_merge_recipe": checkpoint_info.metadata.get("sd_merge_recipe", None)
} }
metadata["sd_merge_models"].update(checkpoint_info.metadata.get("sd_merge_models", {})) sd_merge_models.update(checkpoint_info.metadata.get("sd_merge_models", {}))
add_model_metadata(primary_model_info) add_model_metadata(primary_model_info)
if secondary_model_info: if secondary_model_info:
@ -279,7 +282,7 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
if tertiary_model_info: if tertiary_model_info:
add_model_metadata(tertiary_model_info) add_model_metadata(tertiary_model_info)
metadata["sd_merge_models"] = json.dumps(metadata["sd_merge_models"]) metadata["sd_merge_models"] = json.dumps(sd_merge_models)
_, extension = os.path.splitext(output_modelname) _, extension = os.path.splitext(output_modelname)
if extension.lower() == ".safetensors": if extension.lower() == ".safetensors":