mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-02-01 11:13:00 +08:00
fix to support float8_*
This commit is contained in:
parent
219a0e2429
commit
9e57c722b2
@ -160,7 +160,7 @@ class LoadStateDictOnMeta(ReplaceHelper):
|
||||
self.state_dict = state_dict
|
||||
self.device = device
|
||||
self.weight_dtype_conversion = weight_dtype_conversion or {}
|
||||
self.default_dtype = self.weight_dtype_conversion.get('')
|
||||
self.default_dtype = self.weight_dtype_conversion.get('', None)
|
||||
|
||||
def get_weight_dtype(self, key):
|
||||
key_first_term, _ = key.split('.', 1)
|
||||
@ -183,7 +183,11 @@ class LoadStateDictOnMeta(ReplaceHelper):
|
||||
key = prefix + name
|
||||
sd_param = sd.pop(key, None)
|
||||
if sd_param is not None:
|
||||
state_dict[key] = sd_param.to(dtype=self.get_weight_dtype(key))
|
||||
dtype = self.get_weight_dtype(key)
|
||||
if dtype is None:
|
||||
state_dict[key] = sd_param
|
||||
else:
|
||||
state_dict[key] = sd_param.to(dtype=dtype)
|
||||
used_param_keys.append(key)
|
||||
|
||||
if param.is_meta:
|
||||
|
Loading…
Reference in New Issue
Block a user