mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-02-01 11:13:00 +08:00
fix load_vae() to check size mismatch
This commit is contained in:
parent
1e73a28707
commit
1318f6118e
@ -197,47 +197,58 @@ def load_vae(model, vae_file=None, vae_source="from unknown source"):
|
||||
|
||||
cache_enabled = shared.opts.sd_vae_checkpoint_cache > 0
|
||||
|
||||
loaded = False
|
||||
if vae_file:
|
||||
if cache_enabled and vae_file in checkpoints_loaded:
|
||||
# use vae checkpoint cache
|
||||
print(f"Loading VAE weights {vae_source}: cached {get_filename(vae_file)}")
|
||||
store_base_vae(model)
|
||||
_load_vae_dict(model, checkpoints_loaded[vae_file])
|
||||
loaded = _load_vae_dict(model, checkpoints_loaded[vae_file])
|
||||
else:
|
||||
assert os.path.isfile(vae_file), f"VAE {vae_source} doesn't exist: {vae_file}"
|
||||
print(f"Loading VAE weights {vae_source}: {vae_file}")
|
||||
store_base_vae(model)
|
||||
|
||||
vae_dict_1 = load_vae_dict(vae_file, map_location=shared.weight_load_location)
|
||||
_load_vae_dict(model, vae_dict_1)
|
||||
loaded = _load_vae_dict(model, vae_dict_1)
|
||||
|
||||
if cache_enabled:
|
||||
if loaded and cache_enabled:
|
||||
# cache newly loaded vae
|
||||
checkpoints_loaded[vae_file] = vae_dict_1.copy()
|
||||
|
||||
# clean up cache if limit is reached
|
||||
if cache_enabled:
|
||||
if loaded and cache_enabled:
|
||||
while len(checkpoints_loaded) > shared.opts.sd_vae_checkpoint_cache + 1: # we need to count the current model
|
||||
checkpoints_loaded.popitem(last=False) # LRU
|
||||
|
||||
# If vae used is not in dict, update it
|
||||
# It will be removed on refresh though
|
||||
vae_opt = get_filename(vae_file)
|
||||
if vae_opt not in vae_dict:
|
||||
if loaded and vae_opt not in vae_dict:
|
||||
vae_dict[vae_opt] = vae_file
|
||||
|
||||
elif loaded_vae_file:
|
||||
restore_base_vae(model)
|
||||
loaded = True
|
||||
|
||||
if loaded:
|
||||
loaded_vae_file = vae_file
|
||||
model.base_vae = base_vae
|
||||
model.loaded_vae_file = loaded_vae_file
|
||||
return loaded
|
||||
|
||||
|
||||
# don't call this from outside
|
||||
def _load_vae_dict(model, vae_dict_1):
|
||||
conv_out = model.first_stage_model.state_dict().get("encoder.conv_out.weight")
|
||||
# check shape of "encoder.conv_out.weight". SD1.5/SDXL: [8, 512, 3, 3], FLUX/SD3: [32, 512, 3, 3]
|
||||
if conv_out.shape != vae_dict_1["encoder.conv_out.weight"].shape:
|
||||
print("Failed to load VAE. Size mismatched!")
|
||||
return False
|
||||
|
||||
model.first_stage_model.load_state_dict(vae_dict_1)
|
||||
model.first_stage_model.to(devices.dtype_vae)
|
||||
return True
|
||||
|
||||
|
||||
def clear_loaded_vae():
|
||||
@ -270,7 +281,7 @@ def reload_vae_weights(sd_model=None, vae_file=unspecified):
|
||||
|
||||
sd_hijack.model_hijack.undo_hijack(sd_model)
|
||||
|
||||
load_vae(sd_model, vae_file, vae_source)
|
||||
loaded = load_vae(sd_model, vae_file, vae_source)
|
||||
|
||||
sd_hijack.model_hijack.hijack(sd_model)
|
||||
|
||||
@ -279,5 +290,6 @@ def reload_vae_weights(sd_model=None, vae_file=unspecified):
|
||||
|
||||
script_callbacks.model_loaded_callback(sd_model)
|
||||
|
||||
if loaded:
|
||||
print("VAE weights loaded.")
|
||||
return sd_model
|
||||
|
Loading…
Reference in New Issue
Block a user