use the shared.opts.lora_without_backup_weight option in devices.autocast()

* add nn.Embedding to the module types patched by devices.autocast()
* do not cast forward args whose dtype is not in the supported cast dtypes
* add a copy option to devices.autocast()
commit ba499f92ac
parent 03516f48f0
Author: Won-Kyu Park
Date:   2024-09-25 02:20:28 +09:00

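For orientation before the diff: devices.autocast() is used as a context manager around model forward passes, and this commit threads a new copy keyword through it into manual_cast(), defaulting to the shared.opts.lora_without_backup_weight option. A minimal call-site sketch (denoise_step, unet and its arguments are placeholder names, not code from this commit):

from modules import devices

def denoise_step(unet, x, timesteps, context):
    # copy=None (the default) falls back to shared.opts.lora_without_backup_weight,
    # per the change below; pass copy=True or copy=False to override the option
    # for this block only.
    with devices.autocast(copy=True):
        return unet(x, timesteps, context)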

@@ -167,6 +167,7 @@ patch_module_list = [
     torch.nn.MultiheadAttention,
     torch.nn.GroupNorm,
     torch.nn.LayerNorm,
+    torch.nn.Embedding,
 ]
@@ -175,6 +176,10 @@ def manual_cast_forward(target_dtype, target_device=None, copy=False):
     if supports_non_blocking():
         params['non_blocking'] = True
 
+    supported_cast_dtypes = [torch.float16, torch.float32]
+    if torch.cuda.is_bf16_supported():
+        supported_cast_dtypes += [torch.bfloat16]
+
     def forward_wrapper(self, *args, **kwargs):
         if target_device is not None:
             params['device'] = target_device
@@ -182,11 +187,13 @@ def manual_cast_forward(target_dtype, target_device=None, copy=False):
         args = list(args)
         for j in (i for i, arg in enumerate(args) if isinstance(arg, torch.Tensor) and arg.dtype != target_dtype):
-            args[j] = args[j].to(**params)
+            if args[j].dtype in supported_cast_dtypes:
+                args[j] = args[j].to(**params)
         args = tuple(args)
 
         for key in (k for k, v in kwargs.items() if isinstance(v, torch.Tensor) and v.dtype != target_dtype):
-            kwargs[key] = kwargs[key].to(**params)
+            if kwargs[key].dtype in supported_cast_dtypes:
+                kwargs[key] = kwargs[key].to(**params)
 
         org_dtype = target_dtype
         for param in self.parameters():
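The guard above only casts tensors whose dtype is one of the supported float dtypes, which matters for the newly patched torch.nn.Embedding: its forward receives integer token ids that must stay int64 rather than be converted to the inference dtype. A standalone illustration of the check, not the patched code itself (maybe_cast and the sample tensors are invented for this sketch):

import torch

# Mirrors the supported_cast_dtypes list built in the hunk above.
supported_cast_dtypes = [torch.float16, torch.float32]
if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
    supported_cast_dtypes += [torch.bfloat16]

def maybe_cast(t: torch.Tensor, target_dtype=torch.float16) -> torch.Tensor:
    # Float tensors are cast to the target dtype; integer tensors (e.g. token
    # ids destined for nn.Embedding) pass through untouched, since casting them
    # to a float dtype would break the embedding lookup.
    if t.dtype != target_dtype and t.dtype in supported_cast_dtypes:
        return t.to(target_dtype)
    return t

token_ids = torch.tensor([3, 14, 159])   # int64: left as-is
activations = torch.randn(3, 4)          # float32: cast to float16
print(maybe_cast(token_ids).dtype, maybe_cast(activations).dtype)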
@@ -227,10 +234,9 @@ def manual_cast_forward(target_dtype, target_device=None, copy=False):
 @contextlib.contextmanager
-def manual_cast(target_dtype, target_device=None):
+def manual_cast(target_dtype, target_device=None, copy=None):
     applied = False
-    copy = shared.opts.lora_without_backup_weight
     for module_type in patch_module_list:
         if hasattr(module_type, "org_forward"):
@@ -252,10 +258,12 @@ def manual_cast(target_dtype, target_device=None):
                     delattr(module_type, "org_forward")
 
 
-def autocast(disable=False, current_dtype=None, target_dtype=None, target_device=None):
+def autocast(disable=False, current_dtype=None, target_dtype=None, target_device=None, copy=None):
     if disable:
         return contextlib.nullcontext()
 
+    copy = copy if copy is not None else shared.opts.lora_without_backup_weight
+
     if target_dtype is None:
         target_dtype = dtype
     if target_device is None:
@@ -270,13 +278,13 @@ def autocast(disable=False, current_dtype=None, target_dtype=None, target_device=None):
         return torch.autocast("cpu", dtype=torch.bfloat16, enabled=True)
 
     if fp8 and dtype_inference == torch.float32:
-        return manual_cast(target_dtype, target_device)
+        return manual_cast(target_dtype, target_device, copy=copy)
 
-    if target_dtype != dtype_inference:
-        return manual_cast(target_dtype, target_device)
+    if target_dtype != dtype_inference or copy:
+        return manual_cast(target_dtype, target_device, copy=copy)
 
     if current_dtype is not None and current_dtype != target_dtype:
-        return manual_cast(target_dtype, target_device)
+        return manual_cast(target_dtype, target_device, copy=copy)
 
     if target_dtype == torch.float32 or dtype_inference == torch.float32:
         return contextlib.nullcontext()
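For readers unfamiliar with the mechanism the branches above dispatch to: manual_cast() temporarily replaces forward on whole module classes, and each wrapped call casts floating-point inputs and the module's weights to the target dtype before running the original forward. The following is a condensed, self-contained sketch of that pattern under simplifying assumptions (no copy handling, no device move, no per-type special cases), not the webui implementation itself:

import contextlib
import torch

@contextlib.contextmanager
def manual_cast_sketch(target_dtype, module_types=(torch.nn.Linear, torch.nn.LayerNorm)):
    # Patch forward at the class level so every instance is affected, then
    # restore the original methods on exit -- the same overall shape as
    # devices.manual_cast().
    originals = {}

    def make_wrapper(org_forward):
        def wrapper(self, *args, **kwargs):
            p = next(self.parameters(), None)
            org_dtype = p.dtype if p is not None else target_dtype
            # Cast only floating-point inputs, in the spirit of the
            # supported_cast_dtypes guard added by this commit.
            args = [a.to(target_dtype) if isinstance(a, torch.Tensor) and a.is_floating_point() else a
                    for a in args]
            kwargs = {k: v.to(target_dtype) if isinstance(v, torch.Tensor) and v.is_floating_point() else v
                      for k, v in kwargs.items()}
            self.to(target_dtype)      # cast the module's weights for this call
            try:
                return org_forward(self, *args, **kwargs)
            finally:
                self.to(org_dtype)     # put the weights back afterwards
        return wrapper

    try:
        for cls in module_types:
            originals[cls] = cls.forward
            cls.forward = make_wrapper(originals[cls])
        yield
    finally:
        for cls, fwd in originals.items():
            cls.forward = fwd

Usage: a float32 model runs in bfloat16 only while the context is active, and both the weights and the patched forward methods are restored on exit.

model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.LayerNorm(8))
x = torch.randn(2, 8)
with manual_cast_sketch(torch.bfloat16):
    y = model(x)
print(y.dtype)                  # torch.bfloat16
print(model[0].weight.dtype)    # torch.float32 again after the context exits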