* check supported dtypes
 * detect non_blocking
 * update autocast() to use non_blocking, target_device and current_dtype
This commit is contained in:
Won-Kyu Park 2024-09-05 09:17:26 +09:00
parent 821e76a415
commit 39328bd7db
No known key found for this signature in database
GPG Key ID: 53AA79C8C9535D15

View File

@ -128,6 +128,26 @@ dtype_unet: torch.dtype = torch.float16
dtype_inference: torch.dtype = torch.float16
unet_needs_upcast = False
supported_vae_dtypes = [torch.float16, torch.float32]
# prepare available dtypes
if torch.version.cuda:
if torch.cuda.is_bf16_supported() and torch.cuda.get_device_properties(torch.cuda.current_device()).major >= 8:
supported_vae_dtypes = [torch.bfloat16] + supported_vae_dtypes
if has_xpu():
supported_vae_dtypes = [torch.bfloat16] + supported_vae_dtypes
def supports_non_blocking():
if has_mps() or has_xpu():
return False
if npu_specific.has_npu:
return False
return True
def cond_cast_unet(input):
if force_fp16:
@ -149,14 +169,23 @@ patch_module_list = [
]
def manual_cast_forward(target_dtype):
def manual_cast_forward(target_dtype, target_device=None):
params = dict()
if supports_non_blocking():
params['non_blocking'] = True
def forward_wrapper(self, *args, **kwargs):
if any(
isinstance(arg, torch.Tensor) and arg.dtype != target_dtype
for arg in args
):
args = [arg.to(target_dtype) if isinstance(arg, torch.Tensor) else arg for arg in args]
kwargs = {k: v.to(target_dtype) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()}
if target_device is not None:
params['device'] = target_device
params['dtype'] = target_dtype
args = list(args)
for j in (i for i, arg in enumerate(args) if isinstance(arg, torch.Tensor) and arg.dtype != target_dtype):
args[j] = args[j].to(**params)
args = tuple(args)
for key in (k for k, v in kwargs.items() if isinstance(v, torch.Tensor) and v.dtype != target_dtype):
kwargs[key] = kwargs[key].to(**params)
org_dtype = target_dtype
for param in self.parameters():
@ -165,37 +194,41 @@ def manual_cast_forward(target_dtype):
break
if org_dtype != target_dtype:
self.to(target_dtype)
self.to(**params)
result = self.org_forward(*args, **kwargs)
if org_dtype != target_dtype:
self.to(org_dtype)
params['dtype'] = org_dtype
self.to(**params)
if target_dtype != dtype_inference:
params['dtype'] = dtype_inference
if isinstance(result, tuple):
result = tuple(
i.to(dtype_inference)
i.to(**params)
if isinstance(i, torch.Tensor)
else i
for i in result
)
elif isinstance(result, torch.Tensor):
result = result.to(dtype_inference)
result = result.to(**params)
return result
return forward_wrapper
@contextlib.contextmanager
def manual_cast(target_dtype):
def manual_cast(target_dtype, target_device=None):
applied = False
for module_type in patch_module_list:
if hasattr(module_type, "org_forward"):
continue
applied = True
org_forward = module_type.forward
if module_type == torch.nn.MultiheadAttention:
module_type.forward = manual_cast_forward(torch.float32)
module_type.forward = manual_cast_forward(torch.float32, target_device)
else:
module_type.forward = manual_cast_forward(target_dtype)
module_type.forward = manual_cast_forward(target_dtype, target_device)
module_type.org_forward = org_forward
try:
yield None
@ -207,26 +240,37 @@ def manual_cast(target_dtype):
delattr(module_type, "org_forward")
def autocast(disable=False):
def autocast(disable=False, current_dtype=None, target_dtype=None, target_device=None):
if disable:
return contextlib.nullcontext()
if target_dtype is None:
target_dtype = dtype
if target_device is None:
target_device = device
if force_fp16:
# No casting during inference if force_fp16 is enabled.
# All tensor dtype conversion happens before inference.
return contextlib.nullcontext()
if fp8 and device==cpu:
if fp8 and target_device==cpu:
return torch.autocast("cpu", dtype=torch.bfloat16, enabled=True)
if fp8 and dtype_inference == torch.float32:
return manual_cast(dtype)
return manual_cast(target_dtype, target_device)
if dtype == torch.float32 or dtype_inference == torch.float32:
if target_dtype != dtype_inference:
return manual_cast(target_dtype, target_device)
if current_dtype is not None and current_dtype != target_dtype:
return manual_cast(target_dtype, target_device)
if target_dtype == torch.float32 or dtype_inference == torch.float32:
return contextlib.nullcontext()
if has_xpu() or has_mps() or cuda_no_autocast():
return manual_cast(dtype)
return manual_cast(target_dtype, target_device)
return torch.autocast("cuda")