diff --git a/modules/sd_disable_initialization.py b/modules/sd_disable_initialization.py index aa57c3e06..d60fb5591 100644 --- a/modules/sd_disable_initialization.py +++ b/modules/sd_disable_initialization.py @@ -176,6 +176,11 @@ class LoadStateDictOnMeta(ReplaceHelper): def load_from_state_dict(original, module, state_dict, prefix, *args, **kwargs): used_param_keys = [] + if type(module) in (torch.nn.Linear, torch.nn.Conv2d, torch.nn.GroupNorm, torch.nn.LayerNorm,): + # HACK add assign=True to local_metadata for some cases + args[0]['assign_to_params_buffers'] = True + + for name, param in module._parameters.items(): if param is None: continue