do not use copy option for nn.Embedding

This commit is contained in:
Won-Kyu Park 2024-09-25 21:58:28 +09:00
parent ba499f92ac
commit 5f3314ec43
No known key found for this signature in database
GPG Key ID: 53AA79C8C9535D15

View File

@ -201,7 +201,7 @@ def manual_cast_forward(target_dtype, target_device=None, copy=False):
org_dtype = param.dtype
break
if copy:
if copy and not isinstance(self, torch.nn.Embedding):
copied = deepcopy(self)
if org_dtype != target_dtype:
copied.to(**params)
@ -266,8 +266,6 @@ def autocast(disable=False, current_dtype=None, target_dtype=None, target_device
if target_dtype is None:
target_dtype = dtype
if target_device is None:
target_device = device
if force_fp16:
# No casting during inference if force_fp16 is enabled.