mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-02-08 06:32:54 +08:00
simplified get_loadable_dtype
This commit is contained in:
parent
c972951cf6
commit
fcd609f4b4
@ -434,40 +434,13 @@ def get_state_dict_dtype(state_dict):
|
|||||||
return state_dict_dtype
|
return state_dict_dtype
|
||||||
|
|
||||||
|
|
||||||
def get_loadable_dtype(prefix="model.diffusion_model.", dtype=None, state_dict=None, state_dict_dtype=None, count=490):
|
def get_loadable_dtype(prefix="model.diffusion_model.", state_dict=None, state_dict_dtype=None):
|
||||||
if state_dict is not None:
|
if state_dict is not None:
|
||||||
state_dict_dtype = get_state_dict_dtype(state_dict)
|
state_dict_dtype = get_state_dict_dtype(state_dict)
|
||||||
|
|
||||||
aliases = {
|
# get the first dtype
|
||||||
"FP8": "F8",
|
|
||||||
"FP16": "F16",
|
|
||||||
"FP32": "F32",
|
|
||||||
}
|
|
||||||
|
|
||||||
loadables = {
|
|
||||||
"F8": (torch.float8_e4m3fn,),
|
|
||||||
"F16": (torch.float16,),
|
|
||||||
"F32": (torch.float32,),
|
|
||||||
"BF16": (torch.bfloat16,),
|
|
||||||
}
|
|
||||||
|
|
||||||
if dtype is None:
|
|
||||||
# get the first dtype
|
|
||||||
if prefix in state_dict_dtype:
|
|
||||||
return list(state_dict_dtype[prefix])[0]
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
if dtype in aliases:
|
|
||||||
dtype = aliases[dtype]
|
|
||||||
loadable = loadables[dtype]
|
|
||||||
|
|
||||||
if prefix in state_dict_dtype:
|
if prefix in state_dict_dtype:
|
||||||
dtypes = [d for d in state_dict_dtype[prefix].keys() if d in loadable]
|
return list(state_dict_dtype[prefix])[0]
|
||||||
if len(dtypes) > 0 and state_dict_dtype[prefix][dtypes[0]] >= count:
|
|
||||||
# mostly dtype weights.
|
|
||||||
return dtypes[0]
|
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user