mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-02-01 11:13:00 +08:00
fix load_vae() to check size mismatch
This commit is contained in:
parent
1e73a28707
commit
1318f6118e
@ -197,47 +197,58 @@ def load_vae(model, vae_file=None, vae_source="from unknown source"):
|
|||||||
|
|
||||||
cache_enabled = shared.opts.sd_vae_checkpoint_cache > 0
|
cache_enabled = shared.opts.sd_vae_checkpoint_cache > 0
|
||||||
|
|
||||||
|
loaded = False
|
||||||
if vae_file:
|
if vae_file:
|
||||||
if cache_enabled and vae_file in checkpoints_loaded:
|
if cache_enabled and vae_file in checkpoints_loaded:
|
||||||
# use vae checkpoint cache
|
# use vae checkpoint cache
|
||||||
print(f"Loading VAE weights {vae_source}: cached {get_filename(vae_file)}")
|
print(f"Loading VAE weights {vae_source}: cached {get_filename(vae_file)}")
|
||||||
store_base_vae(model)
|
store_base_vae(model)
|
||||||
_load_vae_dict(model, checkpoints_loaded[vae_file])
|
loaded = _load_vae_dict(model, checkpoints_loaded[vae_file])
|
||||||
else:
|
else:
|
||||||
assert os.path.isfile(vae_file), f"VAE {vae_source} doesn't exist: {vae_file}"
|
assert os.path.isfile(vae_file), f"VAE {vae_source} doesn't exist: {vae_file}"
|
||||||
print(f"Loading VAE weights {vae_source}: {vae_file}")
|
print(f"Loading VAE weights {vae_source}: {vae_file}")
|
||||||
store_base_vae(model)
|
store_base_vae(model)
|
||||||
|
|
||||||
vae_dict_1 = load_vae_dict(vae_file, map_location=shared.weight_load_location)
|
vae_dict_1 = load_vae_dict(vae_file, map_location=shared.weight_load_location)
|
||||||
_load_vae_dict(model, vae_dict_1)
|
loaded = _load_vae_dict(model, vae_dict_1)
|
||||||
|
|
||||||
if cache_enabled:
|
if loaded and cache_enabled:
|
||||||
# cache newly loaded vae
|
# cache newly loaded vae
|
||||||
checkpoints_loaded[vae_file] = vae_dict_1.copy()
|
checkpoints_loaded[vae_file] = vae_dict_1.copy()
|
||||||
|
|
||||||
# clean up cache if limit is reached
|
# clean up cache if limit is reached
|
||||||
if cache_enabled:
|
if loaded and cache_enabled:
|
||||||
while len(checkpoints_loaded) > shared.opts.sd_vae_checkpoint_cache + 1: # we need to count the current model
|
while len(checkpoints_loaded) > shared.opts.sd_vae_checkpoint_cache + 1: # we need to count the current model
|
||||||
checkpoints_loaded.popitem(last=False) # LRU
|
checkpoints_loaded.popitem(last=False) # LRU
|
||||||
|
|
||||||
# If vae used is not in dict, update it
|
# If vae used is not in dict, update it
|
||||||
# It will be removed on refresh though
|
# It will be removed on refresh though
|
||||||
vae_opt = get_filename(vae_file)
|
vae_opt = get_filename(vae_file)
|
||||||
if vae_opt not in vae_dict:
|
if loaded and vae_opt not in vae_dict:
|
||||||
vae_dict[vae_opt] = vae_file
|
vae_dict[vae_opt] = vae_file
|
||||||
|
|
||||||
elif loaded_vae_file:
|
elif loaded_vae_file:
|
||||||
restore_base_vae(model)
|
restore_base_vae(model)
|
||||||
|
loaded = True
|
||||||
|
|
||||||
|
if loaded:
|
||||||
loaded_vae_file = vae_file
|
loaded_vae_file = vae_file
|
||||||
model.base_vae = base_vae
|
model.base_vae = base_vae
|
||||||
model.loaded_vae_file = loaded_vae_file
|
model.loaded_vae_file = loaded_vae_file
|
||||||
|
return loaded
|
||||||
|
|
||||||
|
|
||||||
# don't call this from outside
|
# don't call this from outside
|
||||||
def _load_vae_dict(model, vae_dict_1):
|
def _load_vae_dict(model, vae_dict_1):
|
||||||
|
conv_out = model.first_stage_model.state_dict().get("encoder.conv_out.weight")
|
||||||
|
# check shape of "encoder.conv_out.weight". SD1.5/SDXL: [8, 512, 3, 3], FLUX/SD3: [32, 512, 3, 3]
|
||||||
|
if conv_out.shape != vae_dict_1["encoder.conv_out.weight"].shape:
|
||||||
|
print("Failed to load VAE. Size mismatched!")
|
||||||
|
return False
|
||||||
|
|
||||||
model.first_stage_model.load_state_dict(vae_dict_1)
|
model.first_stage_model.load_state_dict(vae_dict_1)
|
||||||
model.first_stage_model.to(devices.dtype_vae)
|
model.first_stage_model.to(devices.dtype_vae)
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
def clear_loaded_vae():
|
def clear_loaded_vae():
|
||||||
@ -270,7 +281,7 @@ def reload_vae_weights(sd_model=None, vae_file=unspecified):
|
|||||||
|
|
||||||
sd_hijack.model_hijack.undo_hijack(sd_model)
|
sd_hijack.model_hijack.undo_hijack(sd_model)
|
||||||
|
|
||||||
load_vae(sd_model, vae_file, vae_source)
|
loaded = load_vae(sd_model, vae_file, vae_source)
|
||||||
|
|
||||||
sd_hijack.model_hijack.hijack(sd_model)
|
sd_hijack.model_hijack.hijack(sd_model)
|
||||||
|
|
||||||
@ -279,5 +290,6 @@ def reload_vae_weights(sd_model=None, vae_file=unspecified):
|
|||||||
|
|
||||||
script_callbacks.model_loaded_callback(sd_model)
|
script_callbacks.model_loaded_callback(sd_model)
|
||||||
|
|
||||||
|
if loaded:
|
||||||
print("VAE weights loaded.")
|
print("VAE weights loaded.")
|
||||||
return sd_model
|
return sd_model
|
||||||
|
Loading…
Reference in New Issue
Block a user