do not use copy option for nn.Embedding

This commit is contained in:
Won-Kyu Park 2024-09-25 21:58:28 +09:00
parent ba499f92ac
commit 5f3314ec43
No known key found for this signature in database
GPG Key ID: 53AA79C8C9535D15

View File

@ -201,7 +201,7 @@ def manual_cast_forward(target_dtype, target_device=None, copy=False):
org_dtype = param.dtype org_dtype = param.dtype
break break
if copy: if copy and not isinstance(self, torch.nn.Embedding):
copied = deepcopy(self) copied = deepcopy(self)
if org_dtype != target_dtype: if org_dtype != target_dtype:
copied.to(**params) copied.to(**params)
@ -266,8 +266,6 @@ def autocast(disable=False, current_dtype=None, target_dtype=None, target_device
if target_dtype is None: if target_dtype is None:
target_dtype = dtype target_dtype = dtype
if target_device is None:
target_device = device
if force_fp16: if force_fp16:
# No casting during inference if force_fp16 is enabled. # No casting during inference if force_fp16 is enabled.