simplified get_loadable_dtype

This commit is contained in:
Won-Kyu Park 2024-09-14 06:57:41 +09:00
parent c972951cf6
commit fcd609f4b4
No known key found for this signature in database
GPG Key ID: 53AA79C8C9535D15

View File

@ -434,40 +434,13 @@ def get_state_dict_dtype(state_dict):
return state_dict_dtype
def get_loadable_dtype(prefix="model.diffusion_model.", dtype=None, state_dict=None, state_dict_dtype=None, count=490):
def get_loadable_dtype(prefix="model.diffusion_model.", state_dict=None, state_dict_dtype=None):
if state_dict is not None:
state_dict_dtype = get_state_dict_dtype(state_dict)
aliases = {
"FP8": "F8",
"FP16": "F16",
"FP32": "F32",
}
loadables = {
"F8": (torch.float8_e4m3fn,),
"F16": (torch.float16,),
"F32": (torch.float32,),
"BF16": (torch.bfloat16,),
}
if dtype is None:
# get the first dtype
if prefix in state_dict_dtype:
return list(state_dict_dtype[prefix])[0]
return None
if dtype in aliases:
dtype = aliases[dtype]
loadable = loadables[dtype]
# get the first dtype
if prefix in state_dict_dtype:
dtypes = [d for d in state_dict_dtype[prefix].keys() if d in loadable]
if len(dtypes) > 0 and state_dict_dtype[prefix][dtypes[0]] >= count:
# mostly dtype weights.
return dtypes[0]
return list(state_dict_dtype[prefix])[0]
return None