Mirror of https://github.com/AUTOMATIC1111/stable-diffusion-webui.git (synced 2025-01-06 15:15:05 +08:00)
use shared.opts.lora_without_backup_weight option in the devices.autocast()

* add nn.Embedding in the devices.autocast()
* do not cast forward args for some cases
* add copy option in the devices.autocast()
parent 03516f48f0
commit ba499f92ac
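For orientation, a hedged usage sketch of what the commit enables inside the webui codebase. devices.autocast(), the copy argument, and shared.opts.lora_without_backup_weight come from the diff below; the import path, encode(), model, and x are illustrative, and this is not code from the commit itself.

import torch
from modules import devices  # module path assumed from the identifiers used in the diff

def encode(model, x):
    # copy is left unset, so autocast() falls back to shared.opts.lora_without_backup_weight
    with devices.autocast():
        return model(x)

def encode_forced_copy(model, x):
    # explicit override, independent of the shared option
    with devices.autocast(copy=True):
        return model(x)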
@@ -167,6 +167,7 @@ patch_module_list = [
     torch.nn.MultiheadAttention,
     torch.nn.GroupNorm,
     torch.nn.LayerNorm,
+    torch.nn.Embedding,
 ]
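The classes in patch_module_list are the ones whose forward gets wrapped while manual_cast() is active; adding torch.nn.Embedding brings embedding lookups under the same wrapper. A minimal standalone sketch of that patching pattern, assuming only what the diff shows: the org_forward attribute and the patch/restore cycle mirror the diff, while manual_cast_forward_sketch, forward_wrapper's body, and the Linear demo are illustrative.

import torch

def manual_cast_forward_sketch(target_dtype):
    # cast floating-point tensor args, then call the original forward stored on the class
    def forward_wrapper(self, *args, **kwargs):
        args = tuple(
            a.to(target_dtype) if isinstance(a, torch.Tensor) and a.is_floating_point() else a
            for a in args
        )
        return self.org_forward(*args, **kwargs)
    return forward_wrapper

patched = [torch.nn.Linear]  # stand-in for patch_module_list
for module_type in patched:
    module_type.org_forward = module_type.forward
    module_type.forward = manual_cast_forward_sketch(torch.float32)

try:
    layer = torch.nn.Linear(4, 4)                        # float32 weights
    out = layer(torch.ones(2, 4, dtype=torch.float16))   # fp16 input is cast before the real forward runs
    print(out.dtype)                                      # torch.float32
finally:
    for module_type in patched:
        module_type.forward = module_type.org_forward     # restore, as manual_cast() does on exit
        delattr(module_type, "org_forward")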
@@ -175,6 +176,10 @@ def manual_cast_forward(target_dtype, target_device=None, copy=False):
     if supports_non_blocking():
         params['non_blocking'] = True
 
+    supported_cast_dtypes = [torch.float16, torch.float32]
+    if torch.cuda.is_bf16_supported():
+        supported_cast_dtypes += [torch.bfloat16]
+
     def forward_wrapper(self, *args, **kwargs):
         if target_device is not None:
             params['device'] = target_device
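supported_cast_dtypes is a whitelist of the dtypes the wrapper is allowed to convert; bfloat16 only joins it when the CUDA build reports support. A short sketch of that gate, where the extra torch.cuda.is_available() check is my addition for CPU-only machines and is not in the diff:

import torch

supported_cast_dtypes = [torch.float16, torch.float32]
if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
    supported_cast_dtypes += [torch.bfloat16]

# anything outside the list, e.g. the int64 indices an nn.Embedding receives,
# is deliberately never converted by the wrapper
print(torch.float16 in supported_cast_dtypes)  # True
print(torch.int64 in supported_cast_dtypes)    # False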
@@ -182,11 +187,13 @@ def manual_cast_forward(target_dtype, target_device=None, copy=False):
 
         args = list(args)
         for j in (i for i, arg in enumerate(args) if isinstance(arg, torch.Tensor) and arg.dtype != target_dtype):
-            args[j] = args[j].to(**params)
+            if args[j].dtype in supported_cast_dtypes:
+                args[j] = args[j].to(**params)
         args = tuple(args)
 
         for key in (k for k, v in kwargs.items() if isinstance(v, torch.Tensor) and v.dtype != target_dtype):
-            kwargs[key] = kwargs[key].to(**params)
+            if kwargs[key].dtype in supported_cast_dtypes:
+                kwargs[key] = kwargs[key].to(**params)
 
         org_dtype = target_dtype
         for param in self.parameters():
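This guard is what makes patching nn.Embedding safe: its inputs are integer index tensors, and calling .to(dtype=torch.float16) on them would corrupt the lookup. An illustrative rerun of the args loop from the hunk over two tensors (the loop mirrors the diff; params and the sample tensors are mine):

import torch

supported_cast_dtypes = [torch.float16, torch.float32]
params = {'dtype': torch.float16}

args = [
    torch.zeros(2, dtype=torch.float32),      # eligible: cast to float16
    torch.tensor([0, 1], dtype=torch.int64),  # Embedding-style indices: skipped by the guard
]
for j in (i for i, arg in enumerate(args) if isinstance(arg, torch.Tensor) and arg.dtype != params['dtype']):
    if args[j].dtype in supported_cast_dtypes:
        args[j] = args[j].to(**params)

print([a.dtype for a in args])  # [torch.float16, torch.int64]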
@@ -227,10 +234,9 @@ def manual_cast_forward(target_dtype, target_device=None, copy=False):
 
 
 @contextlib.contextmanager
-def manual_cast(target_dtype, target_device=None):
+def manual_cast(target_dtype, target_device=None, copy=None):
     applied = False
 
-    copy = shared.opts.lora_without_backup_weight
 
     for module_type in patch_module_list:
         if hasattr(module_type, "org_forward"):
@@ -252,10 +258,12 @@ def manual_cast(target_dtype, target_device=None):
         delattr(module_type, "org_forward")
 
 
-def autocast(disable=False, current_dtype=None, target_dtype=None, target_device=None):
+def autocast(disable=False, current_dtype=None, target_dtype=None, target_device=None, copy=None):
     if disable:
         return contextlib.nullcontext()
 
+    copy = copy if copy is not None else shared.opts.lora_without_backup_weight
+
     if target_dtype is None:
         target_dtype = dtype
     if target_device is None:
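Moving the shared.opts lookup out of manual_cast() and resolving it here with `copy if copy is not None else ...` (rather than `copy or ...`) keeps three distinct states: unset, explicitly on, explicitly off. A tiny sketch, with resolve_copy and option_default standing in for the real function and for shared.opts.lora_without_backup_weight:

def resolve_copy(copy, option_default):
    # mirrors the hunk: only an unset (None) argument falls back to the option
    return copy if copy is not None else option_default

print(resolve_copy(None, True))   # True  - the shared option decides
print(resolve_copy(False, True))  # False - an explicit False still wins; `copy or option` would lose it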
@@ -270,13 +278,13 @@ def autocast(disable=False, current_dtype=None, target_dtype=None, target_device
         return torch.autocast("cpu", dtype=torch.bfloat16, enabled=True)
 
     if fp8 and dtype_inference == torch.float32:
-        return manual_cast(target_dtype, target_device)
+        return manual_cast(target_dtype, target_device, copy=copy)
 
-    if target_dtype != dtype_inference:
-        return manual_cast(target_dtype, target_device)
+    if target_dtype != dtype_inference or copy:
+        return manual_cast(target_dtype, target_device, copy=copy)
 
     if current_dtype is not None and current_dtype != target_dtype:
-        return manual_cast(target_dtype, target_device)
+        return manual_cast(target_dtype, target_device, copy=copy)
 
     if target_dtype == torch.float32 or dtype_inference == torch.float32:
         return contextlib.nullcontext()
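The `or copy` added to the dispatch means a caller asking for copy semantics always ends up in manual_cast(), even when target_dtype already matches dtype_inference and the old code would have fallen through. A simplified, hedged sketch of that decision order: pick_context is illustrative, the current_dtype branch is omitted, and the final fallback is not part of this hunk.

import torch

def pick_context(target_dtype, dtype_inference, copy, fp8=False):
    if fp8 and dtype_inference == torch.float32:
        return "manual_cast"
    if target_dtype != dtype_inference or copy:   # the `or copy` is the new part
        return "manual_cast"
    if target_dtype == torch.float32 or dtype_inference == torch.float32:
        return "nullcontext"
    return "default autocast path (not shown in this hunk)"

print(pick_context(torch.float16, torch.float16, copy=False))  # falls through as before
print(pick_context(torch.float16, torch.float16, copy=True))   # manual_cast, the new behaviour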