mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-01-21 13:50:12 +08:00
use shared.opts.lora_without_backup_weight option in the devices.autocast()
* add nn.Embedding in the devices.autocast() * do not cast forward args for some cases * add copy option in the devices.autocast()
This commit is contained in:
parent
03516f48f0
commit
ba499f92ac
@ -167,6 +167,7 @@ patch_module_list = [
|
||||
torch.nn.MultiheadAttention,
|
||||
torch.nn.GroupNorm,
|
||||
torch.nn.LayerNorm,
|
||||
torch.nn.Embedding,
|
||||
]
|
||||
|
||||
|
||||
@ -175,6 +176,10 @@ def manual_cast_forward(target_dtype, target_device=None, copy=False):
|
||||
if supports_non_blocking():
|
||||
params['non_blocking'] = True
|
||||
|
||||
supported_cast_dtypes = [torch.float16, torch.float32]
|
||||
if torch.cuda.is_bf16_supported():
|
||||
supported_cast_dtypes += [torch.bfloat16]
|
||||
|
||||
def forward_wrapper(self, *args, **kwargs):
|
||||
if target_device is not None:
|
||||
params['device'] = target_device
|
||||
@ -182,11 +187,13 @@ def manual_cast_forward(target_dtype, target_device=None, copy=False):
|
||||
|
||||
args = list(args)
|
||||
for j in (i for i, arg in enumerate(args) if isinstance(arg, torch.Tensor) and arg.dtype != target_dtype):
|
||||
args[j] = args[j].to(**params)
|
||||
if args[j].dtype in supported_cast_dtypes:
|
||||
args[j] = args[j].to(**params)
|
||||
args = tuple(args)
|
||||
|
||||
for key in (k for k, v in kwargs.items() if isinstance(v, torch.Tensor) and v.dtype != target_dtype):
|
||||
kwargs[key] = kwargs[key].to(**params)
|
||||
if kwargs[key].dtype in supported_cast_dtypes:
|
||||
kwargs[key] = kwargs[key].to(**params)
|
||||
|
||||
org_dtype = target_dtype
|
||||
for param in self.parameters():
|
||||
@ -227,10 +234,9 @@ def manual_cast_forward(target_dtype, target_device=None, copy=False):
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def manual_cast(target_dtype, target_device=None):
|
||||
def manual_cast(target_dtype, target_device=None, copy=None):
|
||||
applied = False
|
||||
|
||||
copy = shared.opts.lora_without_backup_weight
|
||||
|
||||
for module_type in patch_module_list:
|
||||
if hasattr(module_type, "org_forward"):
|
||||
@ -252,10 +258,12 @@ def manual_cast(target_dtype, target_device=None):
|
||||
delattr(module_type, "org_forward")
|
||||
|
||||
|
||||
def autocast(disable=False, current_dtype=None, target_dtype=None, target_device=None):
|
||||
def autocast(disable=False, current_dtype=None, target_dtype=None, target_device=None, copy=None):
|
||||
if disable:
|
||||
return contextlib.nullcontext()
|
||||
|
||||
copy = copy if copy is not None else shared.opts.lora_without_backup_weight
|
||||
|
||||
if target_dtype is None:
|
||||
target_dtype = dtype
|
||||
if target_device is None:
|
||||
@ -270,13 +278,13 @@ def autocast(disable=False, current_dtype=None, target_dtype=None, target_device
|
||||
return torch.autocast("cpu", dtype=torch.bfloat16, enabled=True)
|
||||
|
||||
if fp8 and dtype_inference == torch.float32:
|
||||
return manual_cast(target_dtype, target_device)
|
||||
return manual_cast(target_dtype, target_device, copy=copy)
|
||||
|
||||
if target_dtype != dtype_inference:
|
||||
return manual_cast(target_dtype, target_device)
|
||||
if target_dtype != dtype_inference or copy:
|
||||
return manual_cast(target_dtype, target_device, copy=copy)
|
||||
|
||||
if current_dtype is not None and current_dtype != target_dtype:
|
||||
return manual_cast(target_dtype, target_device)
|
||||
return manual_cast(target_dtype, target_device, copy=copy)
|
||||
|
||||
if target_dtype == torch.float32 or dtype_inference == torch.float32:
|
||||
return contextlib.nullcontext()
|
||||
|
Loading…
Reference in New Issue
Block a user