restore org_dtype != compute dtype case

This commit is contained in:
Won-Kyu Park 2024-10-04 00:08:44 +09:00
parent b783a967c0
commit 310d0e6938
No known key found for this signature in database
GPG Key ID: 53AA79C8C9535D15

View File

@ -391,7 +391,8 @@ def restore_weights_backup(obj, field, weight):
setattr(obj, field, None)
return
getattr(obj, field).copy_(weight)
old_weight = getattr(obj, field)
old_weight.copy_(weight.to(dtype=old_weight.dtype))
def network_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention], cleanup=False):