mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-01-21 13:50:12 +08:00
fix misc
* check supported dtypes * detect non_blocking * update autocast() to use non_blocking, target_device and current_dtype
This commit is contained in:
parent
821e76a415
commit
39328bd7db
@ -128,6 +128,26 @@ dtype_unet: torch.dtype = torch.float16
|
||||
dtype_inference: torch.dtype = torch.float16
|
||||
unet_needs_upcast = False
|
||||
|
||||
supported_vae_dtypes = [torch.float16, torch.float32]
|
||||
|
||||
|
||||
# prepare available dtypes
|
||||
if torch.version.cuda:
|
||||
if torch.cuda.is_bf16_supported() and torch.cuda.get_device_properties(torch.cuda.current_device()).major >= 8:
|
||||
supported_vae_dtypes = [torch.bfloat16] + supported_vae_dtypes
|
||||
if has_xpu():
|
||||
supported_vae_dtypes = [torch.bfloat16] + supported_vae_dtypes
|
||||
|
||||
|
||||
def supports_non_blocking():
|
||||
if has_mps() or has_xpu():
|
||||
return False
|
||||
|
||||
if npu_specific.has_npu:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def cond_cast_unet(input):
|
||||
if force_fp16:
|
||||
@ -149,14 +169,23 @@ patch_module_list = [
|
||||
]
|
||||
|
||||
|
||||
def manual_cast_forward(target_dtype):
|
||||
def manual_cast_forward(target_dtype, target_device=None):
|
||||
params = dict()
|
||||
if supports_non_blocking():
|
||||
params['non_blocking'] = True
|
||||
|
||||
def forward_wrapper(self, *args, **kwargs):
|
||||
if any(
|
||||
isinstance(arg, torch.Tensor) and arg.dtype != target_dtype
|
||||
for arg in args
|
||||
):
|
||||
args = [arg.to(target_dtype) if isinstance(arg, torch.Tensor) else arg for arg in args]
|
||||
kwargs = {k: v.to(target_dtype) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()}
|
||||
if target_device is not None:
|
||||
params['device'] = target_device
|
||||
params['dtype'] = target_dtype
|
||||
|
||||
args = list(args)
|
||||
for j in (i for i, arg in enumerate(args) if isinstance(arg, torch.Tensor) and arg.dtype != target_dtype):
|
||||
args[j] = args[j].to(**params)
|
||||
args = tuple(args)
|
||||
|
||||
for key in (k for k, v in kwargs.items() if isinstance(v, torch.Tensor) and v.dtype != target_dtype):
|
||||
kwargs[key] = kwargs[key].to(**params)
|
||||
|
||||
org_dtype = target_dtype
|
||||
for param in self.parameters():
|
||||
@ -165,37 +194,41 @@ def manual_cast_forward(target_dtype):
|
||||
break
|
||||
|
||||
if org_dtype != target_dtype:
|
||||
self.to(target_dtype)
|
||||
self.to(**params)
|
||||
result = self.org_forward(*args, **kwargs)
|
||||
|
||||
if org_dtype != target_dtype:
|
||||
self.to(org_dtype)
|
||||
params['dtype'] = org_dtype
|
||||
self.to(**params)
|
||||
|
||||
if target_dtype != dtype_inference:
|
||||
params['dtype'] = dtype_inference
|
||||
if isinstance(result, tuple):
|
||||
result = tuple(
|
||||
i.to(dtype_inference)
|
||||
i.to(**params)
|
||||
if isinstance(i, torch.Tensor)
|
||||
else i
|
||||
for i in result
|
||||
)
|
||||
elif isinstance(result, torch.Tensor):
|
||||
result = result.to(dtype_inference)
|
||||
result = result.to(**params)
|
||||
return result
|
||||
return forward_wrapper
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def manual_cast(target_dtype):
|
||||
def manual_cast(target_dtype, target_device=None):
|
||||
applied = False
|
||||
|
||||
for module_type in patch_module_list:
|
||||
if hasattr(module_type, "org_forward"):
|
||||
continue
|
||||
applied = True
|
||||
org_forward = module_type.forward
|
||||
if module_type == torch.nn.MultiheadAttention:
|
||||
module_type.forward = manual_cast_forward(torch.float32)
|
||||
module_type.forward = manual_cast_forward(torch.float32, target_device)
|
||||
else:
|
||||
module_type.forward = manual_cast_forward(target_dtype)
|
||||
module_type.forward = manual_cast_forward(target_dtype, target_device)
|
||||
module_type.org_forward = org_forward
|
||||
try:
|
||||
yield None
|
||||
@ -207,26 +240,37 @@ def manual_cast(target_dtype):
|
||||
delattr(module_type, "org_forward")
|
||||
|
||||
|
||||
def autocast(disable=False):
|
||||
def autocast(disable=False, current_dtype=None, target_dtype=None, target_device=None):
|
||||
if disable:
|
||||
return contextlib.nullcontext()
|
||||
|
||||
if target_dtype is None:
|
||||
target_dtype = dtype
|
||||
if target_device is None:
|
||||
target_device = device
|
||||
|
||||
if force_fp16:
|
||||
# No casting during inference if force_fp16 is enabled.
|
||||
# All tensor dtype conversion happens before inference.
|
||||
return contextlib.nullcontext()
|
||||
|
||||
if fp8 and device==cpu:
|
||||
if fp8 and target_device==cpu:
|
||||
return torch.autocast("cpu", dtype=torch.bfloat16, enabled=True)
|
||||
|
||||
if fp8 and dtype_inference == torch.float32:
|
||||
return manual_cast(dtype)
|
||||
return manual_cast(target_dtype, target_device)
|
||||
|
||||
if dtype == torch.float32 or dtype_inference == torch.float32:
|
||||
if target_dtype != dtype_inference:
|
||||
return manual_cast(target_dtype, target_device)
|
||||
|
||||
if current_dtype is not None and current_dtype != target_dtype:
|
||||
return manual_cast(target_dtype, target_device)
|
||||
|
||||
if target_dtype == torch.float32 or dtype_inference == torch.float32:
|
||||
return contextlib.nullcontext()
|
||||
|
||||
if has_xpu() or has_mps() or cuda_no_autocast():
|
||||
return manual_cast(dtype)
|
||||
return manual_cast(target_dtype, target_device)
|
||||
|
||||
return torch.autocast("cuda")
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user