detect diag_oft type

This commit is contained in:
v0xie 2023-11-02 00:11:32 -07:00
parent a2fad6ee05
commit 65ccd6305f

View File

@ -191,10 +191,17 @@ def load_network(name, network_on_disk):
key = key_network_without_network_parts.replace("lora_te1_text_model", "transformer_text_model") key = key_network_without_network_parts.replace("lora_te1_text_model", "transformer_text_model")
sd_module = shared.sd_model.network_layer_mapping.get(key, None) sd_module = shared.sd_model.network_layer_mapping.get(key, None)
# kohya_ss OFT module
elif sd_module is None and "oft_unet" in key_network_without_network_parts: elif sd_module is None and "oft_unet" in key_network_without_network_parts:
key = key_network_without_network_parts.replace("oft_unet", "diffusion_model") key = key_network_without_network_parts.replace("oft_unet", "diffusion_model")
sd_module = shared.sd_model.network_layer_mapping.get(key, None) sd_module = shared.sd_model.network_layer_mapping.get(key, None)
# KohakuBlueLeaf OFT module
if sd_module is None and "oft_diag" in key:
key = key_network_without_network_parts.replace("lora_unet", "diffusion_model")
key = key_network_without_network_parts.replace("lora_te1_text_model", "0_transformer_text_model")
sd_module = shared.sd_model.network_layer_mapping.get(key, None)
if sd_module is None: if sd_module is None:
keys_failed_to_match[key_network] = key keys_failed_to_match[key_network] = key
continue continue