From fcd609f4b4e04263361420a6ecc1d5252ea6fb28 Mon Sep 17 00:00:00 2001 From: Won-Kyu Park Date: Sat, 14 Sep 2024 06:57:41 +0900 Subject: [PATCH] simplified get_loadable_dtype --- modules/sd_models.py | 33 +++------------------------------ 1 file changed, 3 insertions(+), 30 deletions(-) diff --git a/modules/sd_models.py b/modules/sd_models.py index 43e0b9208..75ce55e2a 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -434,40 +434,13 @@ def get_state_dict_dtype(state_dict): return state_dict_dtype -def get_loadable_dtype(prefix="model.diffusion_model.", dtype=None, state_dict=None, state_dict_dtype=None, count=490): +def get_loadable_dtype(prefix="model.diffusion_model.", state_dict=None, state_dict_dtype=None): if state_dict is not None: state_dict_dtype = get_state_dict_dtype(state_dict) - aliases = { - "FP8": "F8", - "FP16": "F16", - "FP32": "F32", - } - - loadables = { - "F8": (torch.float8_e4m3fn,), - "F16": (torch.float16,), - "F32": (torch.float32,), - "BF16": (torch.bfloat16,), - } - - if dtype is None: - # get the first dtype - if prefix in state_dict_dtype: - return list(state_dict_dtype[prefix])[0] - return None - - - if dtype in aliases: - dtype = aliases[dtype] - loadable = loadables[dtype] - + # get the first dtype if prefix in state_dict_dtype: - dtypes = [d for d in state_dict_dtype[prefix].keys() if d in loadable] - if len(dtypes) > 0 and state_dict_dtype[prefix][dtypes[0]] >= count: - # mostly dtype weights. - return dtypes[0] - + return list(state_dict_dtype[prefix])[0] return None