fix load_vae() to check size mismatch

This commit is contained in:
Won-Kyu Park 2024-09-17 11:05:48 +09:00
parent 1e73a28707
commit 1318f6118e
No known key found for this signature in database
GPG Key ID: 53AA79C8C9535D15

View File

@ -197,47 +197,58 @@ def load_vae(model, vae_file=None, vae_source="from unknown source"):
cache_enabled = shared.opts.sd_vae_checkpoint_cache > 0
loaded = False
if vae_file:
if cache_enabled and vae_file in checkpoints_loaded:
# use vae checkpoint cache
print(f"Loading VAE weights {vae_source}: cached {get_filename(vae_file)}")
store_base_vae(model)
_load_vae_dict(model, checkpoints_loaded[vae_file])
loaded = _load_vae_dict(model, checkpoints_loaded[vae_file])
else:
assert os.path.isfile(vae_file), f"VAE {vae_source} doesn't exist: {vae_file}"
print(f"Loading VAE weights {vae_source}: {vae_file}")
store_base_vae(model)
vae_dict_1 = load_vae_dict(vae_file, map_location=shared.weight_load_location)
_load_vae_dict(model, vae_dict_1)
loaded = _load_vae_dict(model, vae_dict_1)
if cache_enabled:
if loaded and cache_enabled:
# cache newly loaded vae
checkpoints_loaded[vae_file] = vae_dict_1.copy()
# clean up cache if limit is reached
if cache_enabled:
if loaded and cache_enabled:
while len(checkpoints_loaded) > shared.opts.sd_vae_checkpoint_cache + 1: # we need to count the current model
checkpoints_loaded.popitem(last=False) # LRU
# If vae used is not in dict, update it
# It will be removed on refresh though
vae_opt = get_filename(vae_file)
if vae_opt not in vae_dict:
if loaded and vae_opt not in vae_dict:
vae_dict[vae_opt] = vae_file
elif loaded_vae_file:
restore_base_vae(model)
loaded = True
loaded_vae_file = vae_file
if loaded:
loaded_vae_file = vae_file
model.base_vae = base_vae
model.loaded_vae_file = loaded_vae_file
return loaded
# don't call this from outside
def _load_vae_dict(model, vae_dict_1):
conv_out = model.first_stage_model.state_dict().get("encoder.conv_out.weight")
# check shape of "encoder.conv_out.weight". SD1.5/SDXL: [8, 512, 3, 3], FLUX/SD3: [32, 512, 3, 3]
if conv_out.shape != vae_dict_1["encoder.conv_out.weight"].shape:
print("Failed to load VAE. Size mismatched!")
return False
model.first_stage_model.load_state_dict(vae_dict_1)
model.first_stage_model.to(devices.dtype_vae)
return True
def clear_loaded_vae():
@ -270,7 +281,7 @@ def reload_vae_weights(sd_model=None, vae_file=unspecified):
sd_hijack.model_hijack.undo_hijack(sd_model)
load_vae(sd_model, vae_file, vae_source)
loaded = load_vae(sd_model, vae_file, vae_source)
sd_hijack.model_hijack.hijack(sd_model)
@ -279,5 +290,6 @@ def reload_vae_weights(sd_model=None, vae_file=unspecified):
script_callbacks.model_loaded_callback(sd_model)
print("VAE weights loaded.")
if loaded:
print("VAE weights loaded.")
return sd_model