diff --git a/modules/sd_disable_initialization.py b/modules/sd_disable_initialization.py index 0fc1596b7..aa57c3e06 100644 --- a/modules/sd_disable_initialization.py +++ b/modules/sd_disable_initialization.py @@ -160,7 +160,7 @@ class LoadStateDictOnMeta(ReplaceHelper): self.state_dict = state_dict self.device = device self.weight_dtype_conversion = weight_dtype_conversion or {} - self.default_dtype = self.weight_dtype_conversion.get('') + self.default_dtype = self.weight_dtype_conversion.get('', None) def get_weight_dtype(self, key): key_first_term, _ = key.split('.', 1) @@ -183,7 +183,11 @@ class LoadStateDictOnMeta(ReplaceHelper): key = prefix + name sd_param = sd.pop(key, None) if sd_param is not None: - state_dict[key] = sd_param.to(dtype=self.get_weight_dtype(key)) + dtype = self.get_weight_dtype(key) + if dtype is None: + state_dict[key] = sd_param + else: + state_dict[key] = sd_param.to(dtype=dtype) used_param_keys.append(key) if param.is_meta: