mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-02-01 11:13:00 +08:00
fix to support float8_*
This commit is contained in:
parent
219a0e2429
commit
9e57c722b2
@ -160,7 +160,7 @@ class LoadStateDictOnMeta(ReplaceHelper):
|
|||||||
self.state_dict = state_dict
|
self.state_dict = state_dict
|
||||||
self.device = device
|
self.device = device
|
||||||
self.weight_dtype_conversion = weight_dtype_conversion or {}
|
self.weight_dtype_conversion = weight_dtype_conversion or {}
|
||||||
self.default_dtype = self.weight_dtype_conversion.get('')
|
self.default_dtype = self.weight_dtype_conversion.get('', None)
|
||||||
|
|
||||||
def get_weight_dtype(self, key):
|
def get_weight_dtype(self, key):
|
||||||
key_first_term, _ = key.split('.', 1)
|
key_first_term, _ = key.split('.', 1)
|
||||||
@ -183,7 +183,11 @@ class LoadStateDictOnMeta(ReplaceHelper):
|
|||||||
key = prefix + name
|
key = prefix + name
|
||||||
sd_param = sd.pop(key, None)
|
sd_param = sd.pop(key, None)
|
||||||
if sd_param is not None:
|
if sd_param is not None:
|
||||||
state_dict[key] = sd_param.to(dtype=self.get_weight_dtype(key))
|
dtype = self.get_weight_dtype(key)
|
||||||
|
if dtype is None:
|
||||||
|
state_dict[key] = sd_param
|
||||||
|
else:
|
||||||
|
state_dict[key] = sd_param.to(dtype=dtype)
|
||||||
used_param_keys.append(key)
|
used_param_keys.append(key)
|
||||||
|
|
||||||
if param.is_meta:
|
if param.is_meta:
|
||||||
|
Loading…
Reference in New Issue
Block a user