use assign=True for some cases

This commit is contained in:
Won-Kyu Park 2024-09-18 02:18:00 +09:00
parent eee7294200
commit 6675d1f090
No known key found for this signature in database
GPG Key ID: 53AA79C8C9535D15

View File

@ -176,6 +176,11 @@ class LoadStateDictOnMeta(ReplaceHelper):
def load_from_state_dict(original, module, state_dict, prefix, *args, **kwargs):
used_param_keys = []
if type(module) in (torch.nn.Linear, torch.nn.Conv2d, torch.nn.GroupNorm, torch.nn.LayerNorm,):
# HACK add assign=True to local_metadata for some cases
args[0]['assign_to_params_buffers'] = True
for name, param in module._parameters.items():
if param is None:
continue