mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-02-08 06:32:54 +08:00
simplified get_loadable_dtype
This commit is contained in:
parent
c972951cf6
commit
fcd609f4b4
@ -434,43 +434,16 @@ def get_state_dict_dtype(state_dict):
|
||||
return state_dict_dtype
|
||||
|
||||
|
||||
def get_loadable_dtype(prefix="model.diffusion_model.", dtype=None, state_dict=None, state_dict_dtype=None, count=490):
|
||||
def get_loadable_dtype(prefix="model.diffusion_model.", state_dict=None, state_dict_dtype=None):
|
||||
if state_dict is not None:
|
||||
state_dict_dtype = get_state_dict_dtype(state_dict)
|
||||
|
||||
aliases = {
|
||||
"FP8": "F8",
|
||||
"FP16": "F16",
|
||||
"FP32": "F32",
|
||||
}
|
||||
|
||||
loadables = {
|
||||
"F8": (torch.float8_e4m3fn,),
|
||||
"F16": (torch.float16,),
|
||||
"F32": (torch.float32,),
|
||||
"BF16": (torch.bfloat16,),
|
||||
}
|
||||
|
||||
if dtype is None:
|
||||
# get the first dtype
|
||||
if prefix in state_dict_dtype:
|
||||
return list(state_dict_dtype[prefix])[0]
|
||||
return None
|
||||
|
||||
|
||||
if dtype in aliases:
|
||||
dtype = aliases[dtype]
|
||||
loadable = loadables[dtype]
|
||||
|
||||
if prefix in state_dict_dtype:
|
||||
dtypes = [d for d in state_dict_dtype[prefix].keys() if d in loadable]
|
||||
if len(dtypes) > 0 and state_dict_dtype[prefix][dtypes[0]] >= count:
|
||||
# mostly dtype weights.
|
||||
return dtypes[0]
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def get_vae_dtype(state_dict=None, state_dict_dtype=None):
|
||||
if state_dict is not None:
|
||||
state_dict_dtype = get_state_dict_dtype(state_dict)
|
||||
|
Loading…
Reference in New Issue
Block a user