fix to support float8_*

This commit is contained in:
Won-Kyu Park 2024-08-31 23:22:43 +09:00
parent 219a0e2429
commit 9e57c722b2
No known key found for this signature in database
GPG Key ID: 53AA79C8C9535D15

View File

@ -160,7 +160,7 @@ class LoadStateDictOnMeta(ReplaceHelper):
self.state_dict = state_dict self.state_dict = state_dict
self.device = device self.device = device
self.weight_dtype_conversion = weight_dtype_conversion or {} self.weight_dtype_conversion = weight_dtype_conversion or {}
self.default_dtype = self.weight_dtype_conversion.get('') self.default_dtype = self.weight_dtype_conversion.get('', None)
def get_weight_dtype(self, key): def get_weight_dtype(self, key):
key_first_term, _ = key.split('.', 1) key_first_term, _ = key.split('.', 1)
@ -183,7 +183,11 @@ class LoadStateDictOnMeta(ReplaceHelper):
key = prefix + name key = prefix + name
sd_param = sd.pop(key, None) sd_param = sd.pop(key, None)
if sd_param is not None: if sd_param is not None:
state_dict[key] = sd_param.to(dtype=self.get_weight_dtype(key)) dtype = self.get_weight_dtype(key)
if dtype is None:
state_dict[key] = sd_param
else:
state_dict[key] = sd_param.to(dtype=dtype)
used_param_keys.append(key) used_param_keys.append(key)
if param.is_meta: if param.is_meta: