diff --git a/modules/sd_disable_initialization.py b/modules/sd_disable_initialization.py index d60fb5591..47f98416e 100644 --- a/modules/sd_disable_initialization.py +++ b/modules/sd_disable_initialization.py @@ -176,7 +176,7 @@ class LoadStateDictOnMeta(ReplaceHelper): def load_from_state_dict(original, module, state_dict, prefix, *args, **kwargs): used_param_keys = [] - if type(module) in (torch.nn.Linear, torch.nn.Conv2d, torch.nn.GroupNorm, torch.nn.LayerNorm,): + if isinstance(module, (torch.nn.Linear, torch.nn.Conv2d, torch.nn.GroupNorm, torch.nn.LayerNorm)): # HACK add assign=True to local_metadata for some cases args[0]['assign_to_params_buffers'] = True