From 5f3314ec43b099a06637a207a1d5e4252d59ad58 Mon Sep 17 00:00:00 2001 From: Won-Kyu Park Date: Wed, 25 Sep 2024 21:58:28 +0900 Subject: [PATCH] do not use copy option for nn.Embedding --- modules/devices.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/modules/devices.py b/modules/devices.py index ec6ec5634..5b763ec85 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -201,7 +201,7 @@ def manual_cast_forward(target_dtype, target_device=None, copy=False): org_dtype = param.dtype break - if copy: + if copy and not isinstance(self, torch.nn.Embedding): copied = deepcopy(self) if org_dtype != target_dtype: copied.to(**params) @@ -266,8 +266,6 @@ def autocast(disable=False, current_dtype=None, target_dtype=None, target_device if target_dtype is None: target_dtype = dtype - if target_device is None: - target_device = device if force_fp16: # No casting during inference if force_fp16 is enabled.