Mirror of https://github.com/AUTOMATIC1111/stable-diffusion-webui.git, synced 2025-02-13 00:52:56 +08:00
check Unet/VAE and load as is

- check float8 unet dtype to save memory
- check vae/text_encoders dtype and use as intended
This commit is contained in:
parent 39328bd7db
commit c972951cf6
@@ -407,6 +407,90 @@ def set_model_fields(model):
         model.latent_channels = 4
 
 
+def get_state_dict_dtype(state_dict):
+    # detect dtypes of state_dict
+    state_dict_dtype = {}
+
+    known_prefixes = ("model.diffusion_model.", "first_stage_model.", "cond_stage_model.", "conditioner", "vae.", "text_encoders.")
+
+    for k in state_dict.keys():
+        found = [prefix for prefix in known_prefixes if k.startswith(prefix)]
+        if len(found) > 0:
+            prefix = found[0]
+            dtype = state_dict[k].dtype
+            dtypes = state_dict_dtype.get(prefix, {})
+            if dtype in dtypes:
+                dtypes[dtype] += 1
+            else:
+                dtypes[dtype] = 1
+            state_dict_dtype[prefix] = dtypes
+
+    for prefix in state_dict_dtype:
+        dtypes = state_dict_dtype[prefix]
+        # sort by count
+        state_dict_dtype[prefix] = dict(sorted(dtypes.items(), key=lambda item: item[1], reverse=True))
+
+    print("Detected dtypes:", state_dict_dtype)
+    return state_dict_dtype
+
+
+def get_loadable_dtype(prefix="model.diffusion_model.", dtype=None, state_dict=None, state_dict_dtype=None, count=490):
+    if state_dict is not None:
+        state_dict_dtype = get_state_dict_dtype(state_dict)
+
+    aliases = {
+        "FP8": "F8",
+        "FP16": "F16",
+        "FP32": "F32",
+    }
+
+    loadables = {
+        "F8": (torch.float8_e4m3fn,),
+        "F16": (torch.float16,),
+        "F32": (torch.float32,),
+        "BF16": (torch.bfloat16,),
+    }
+
+    if dtype is None:
+        # get the first dtype
+        if prefix in state_dict_dtype:
+            return list(state_dict_dtype[prefix])[0]
+        return None
+
+    if dtype in aliases:
+        dtype = aliases[dtype]
+    loadable = loadables[dtype]
+
+    if prefix in state_dict_dtype:
+        dtypes = [d for d in state_dict_dtype[prefix].keys() if d in loadable]
+        if len(dtypes) > 0 and state_dict_dtype[prefix][dtypes[0]] >= count:
+            # mostly dtype weights.
+            return dtypes[0]
+
+    return None
+
+
+def get_vae_dtype(state_dict=None, state_dict_dtype=None):
+    if state_dict is not None:
+        state_dict_dtype = get_state_dict_dtype(state_dict)
+
+    if state_dict_dtype is None:
+        raise ValueError("failed to get VAE dtype")
+
+    vae_prefixes = [prefix for prefix in ("vae.", "first_stage_model.") if prefix in state_dict_dtype]
+
+    if len(vae_prefixes) > 0:
+        vae_prefix = vae_prefixes[0]
+        for dtype in state_dict_dtype[vae_prefix]:
+            if state_dict_dtype[vae_prefix][dtype] > 240 and dtype in (torch.float16, torch.float32, torch.bfloat16):
+                # VAE items: 248 for SD1/SDXL, 245 for Flux
+                return dtype
+
+    return None
+
+
 def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer):
     sd_model_hash = checkpoint_info.calculate_shorthash()
     timer.record("calculate hash")
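For orientation, here is a rough usage sketch of the helpers introduced above, run against a checkpoint's state dict. It is not part of the commit: the import path assumes these functions live in modules/sd_models.py alongside load_model_weights, and the checkpoint filename is a placeholder.

import safetensors.torch
from modules import sd_models  # assumed module path for the helpers added above

# placeholder checkpoint path
state_dict = safetensors.torch.load_file("checkpoint-fp8.safetensors")

# prefix -> {dtype: count}, sorted by descending count, e.g.
# {"model.diffusion_model.": {torch.float8_e4m3fn: 780}, "vae.": {torch.bfloat16: 244}, ...}
detected = sd_models.get_state_dict_dtype(state_dict)

# torch.float8_e4m3fn if at least `count` (default 490) UNet tensors are fp8, else None
unet_fp8 = sd_models.get_loadable_dtype("model.diffusion_model.", dtype="FP8", state_dict_dtype=detected)

# the dominant float dtype of the VAE weights (more than 240 tensors), else None
vae_dtype = sd_models.get_vae_dtype(state_dict_dtype=detected)

print(unet_fp8, vae_dtype)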
@@ -441,6 +525,9 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer):
     if hasattr(model, "before_load_weights"):
         model.before_load_weights(state_dict)
 
+    # get all dtypes of state_dict
+    state_dict_dtype = get_state_dict_dtype(state_dict)
+
     model.load_state_dict(state_dict, strict=False)
     timer.record("apply weights to model")
 
@@ -466,7 +553,13 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer):
         model.to(memory_format=torch.channels_last)
         timer.record("apply channels_last")
 
-    if shared.cmd_opts.no_half:
+    # check dtype of vae
+    dtype_vae = get_vae_dtype(state_dict_dtype=state_dict_dtype)
+    found_unet_dtype = get_loadable_dtype("model.diffusion_model.", state_dict_dtype=state_dict_dtype)
+    unet_has_float = found_unet_dtype in (torch.float16, torch.float32, torch.bfloat16)
+
+    if (found_unet_dtype is None or unet_has_float) and shared.cmd_opts.no_half:
+        # unet type is not detected or unet has float dtypes
         model.float()
         model.alphas_cumprod_original = model.alphas_cumprod
         devices.dtype_unet = torch.float32
@@ -476,8 +569,11 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer):
         vae = model.first_stage_model
         depth_model = getattr(model, 'depth_model', None)
 
+        if dtype_vae == torch.bfloat16 and dtype_vae in devices.supported_vae_dtypes:
+            # preserve bfloat16 if it is supported
+            model.first_stage_model = None
         # with --no-half-vae, remove VAE from model when doing half() to prevent its weights from being converted to float16
-        if shared.cmd_opts.no_half_vae:
+        elif shared.cmd_opts.no_half_vae:
             model.first_stage_model = None
         # with --upcast-sampling, don't convert the depth model weights to float16
         if shared.cmd_opts.upcast_sampling and depth_model:
@@ -485,15 +581,28 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer):
 
         alphas_cumprod = model.alphas_cumprod
         model.alphas_cumprod = None
-        model.half()
+
+        if found_unet_dtype in (torch.float16, torch.float32, torch.bfloat16):
+            model.half()
+        elif found_unet_dtype in (torch.float8_e4m3fn,):
+            pass
+        else:
+            print("Failed to get a valid UNet dtype. Ignoring...")
+
         model.alphas_cumprod = alphas_cumprod
         model.alphas_cumprod_original = alphas_cumprod
         model.first_stage_model = vae
         if depth_model:
             model.depth_model = depth_model
 
-        devices.dtype_unet = torch.float16
-        timer.record("apply half()")
+        if found_unet_dtype in (torch.float16, torch.float32):
+            devices.dtype_unet = torch.float16
+            timer.record("apply half()")
+        else:
+            print(f"Load UNet {found_unet_dtype} as is...")
+            devices.dtype_unet = found_unet_dtype if found_unet_dtype else torch.float16
+            timer.record("load UNet")
 
     apply_alpha_schedule_override(model)
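The reason an fp8 UNet is left un-halved above is memory. The standalone snippet below (not part of the commit; it only needs a PyTorch build with float8 support, roughly 2.1+) illustrates the saving:

import torch

# fp8 stores 1 byte per element vs 2 for fp16, so keeping an fp8 UNet "as is"
# (skipping model.half()) roughly halves UNet weight memory.
w16 = torch.zeros(1024, 1024, dtype=torch.float16)
w8 = w16.to(torch.float8_e4m3fn)
print(w16.element_size() * w16.nelement())  # 2097152 bytes
print(w8.element_size() * w8.nelement())    # 1048576 bytes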
@@ -503,10 +612,18 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer):
         if hasattr(module, 'fp16_bias'):
             del module.fp16_bias
 
-    if check_fp8(model):
+    if found_unet_dtype not in (torch.float8_e4m3fn,) and check_fp8(model):
         devices.fp8 = True
+
+        # do not convert vae, text_encoders.clip_l, clip_g, t5xxl
         first_stage = model.first_stage_model
         model.first_stage_model = None
+        vae = getattr(model, 'vae', None)
+        if vae is not None:
+            model.vae = None
+        text_encoders = getattr(model, 'text_encoders', None)
+        if text_encoders is not None:
+            model.text_encoders = None
         for module in model.modules():
             if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear)):
                 if shared.opts.cache_fp16_weight:
@@ -514,6 +631,10 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer):
                     if module.bias is not None:
                         module.fp16_bias = module.bias.data.clone().cpu().half()
                 module.to(torch.float8_e4m3fn)
+        if text_encoders is not None:
+            model.text_encoders = text_encoders
+        if vae is not None:
+            model.vae = vae
         model.first_stage_model = first_stage
         timer.record("apply fp8")
     else:
@@ -521,8 +642,16 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer):
 
     devices.unet_needs_upcast = shared.cmd_opts.upcast_sampling and devices.dtype == torch.float16 and devices.dtype_unet == torch.float16
 
-    model.first_stage_model.to(devices.dtype_vae)
-    timer.record("apply dtype to VAE")
+    # check supported vae dtype
+    dtype_vae = get_vae_dtype(state_dict_dtype=state_dict_dtype)
+    if dtype_vae == torch.bfloat16 and dtype_vae in devices.supported_vae_dtypes:
+        devices.dtype_vae = torch.bfloat16
+        print(f"VAE dtype {dtype_vae} detected. load as is.")
+    else:
+        # use default devices.dtype_vae
+        model.first_stage_model.to(devices.dtype_vae)
+        print(f"Use VAE dtype {devices.dtype_vae}")
+    timer.record("apply dtype to VAE")
 
     # clean up cache if limit is reached
     while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache:
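As a standalone approximation of the decision in the hunk above (not part of the commit; devices.supported_vae_dtypes in the webui code may be built differently, so the bf16 capability check here is an assumption):

import torch

# keep a bf16 VAE only when the backend can actually run bf16,
# otherwise fall back to the default VAE dtype
supported_vae_dtypes = [torch.float16, torch.float32]
if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
    supported_vae_dtypes.append(torch.bfloat16)

checkpoint_vae_dtype = torch.bfloat16  # e.g. the value returned by get_vae_dtype()
if checkpoint_vae_dtype == torch.bfloat16 and checkpoint_vae_dtype in supported_vae_dtypes:
    dtype_vae = torch.bfloat16  # load VAE weights as stored
else:
    dtype_vae = torch.float16   # default path: cast the VAE
print(dtype_vae)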
@@ -818,6 +947,18 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None, checkpoint_
 
     print(f"Creating model from config: {checkpoint_config}")
 
+    # get all dtypes of state_dict
+    state_dict_dtype = get_state_dict_dtype(state_dict)
+
+    # check loadable unet dtype before loading
+    loadable_unet_dtype = get_loadable_dtype("model.diffusion_model.", state_dict_dtype=state_dict_dtype)
+
+    # check dtype of vae
+    dtype_vae = get_vae_dtype(state_dict_dtype=state_dict_dtype)
+    if dtype_vae == torch.bfloat16 and dtype_vae in devices.supported_vae_dtypes:
+        devices.dtype_vae = torch.bfloat16
+        print(f"VAE dtype {dtype_vae} detected.")
+
     sd_model = None
     try:
         with sd_disable_initialization.DisableInitialization(disable_clip=clip_is_included_into_sd or shared.cmd_opts.do_not_download_clip):
@@ -843,7 +984,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None, checkpoint_
         weight_dtype_conversion = {
             'first_stage_model': None,
             'alphas_cumprod': None,
-            '': torch.float16,
+            '': torch.float16 if loadable_unet_dtype in (torch.float16, torch.float32, torch.bfloat16) else None,
         }
 
     with sd_disable_initialization.LoadStateDictOnMeta(state_dict, device=model_target_device(sd_model), weight_dtype_conversion=weight_dtype_conversion):