diff --git a/modules/devices.py b/modules/devices.py
index ee679141a..556e72d2e 100644
--- a/modules/devices.py
+++ b/modules/devices.py
@@ -128,6 +128,26 @@ dtype_unet: torch.dtype = torch.float16
 dtype_inference: torch.dtype = torch.float16
 unet_needs_upcast = False
 
+supported_vae_dtypes = [torch.float16, torch.float32]
+
+
+# prepare available dtypes
+if torch.version.cuda:
+    if torch.cuda.is_bf16_supported() and torch.cuda.get_device_properties(torch.cuda.current_device()).major >= 8:
+        supported_vae_dtypes = [torch.bfloat16] + supported_vae_dtypes
+if has_xpu():
+    supported_vae_dtypes = [torch.bfloat16] + supported_vae_dtypes
+
+
+def supports_non_blocking():
+    if has_mps() or has_xpu():
+        return False
+
+    if npu_specific.has_npu:
+        return False
+
+    return True
+
 
 def cond_cast_unet(input):
     if force_fp16:
@@ -149,14 +169,23 @@ patch_module_list = [
 ]
 
 
-def manual_cast_forward(target_dtype):
+def manual_cast_forward(target_dtype, target_device=None):
+    params = dict()
+    if supports_non_blocking():
+        params['non_blocking'] = True
+
     def forward_wrapper(self, *args, **kwargs):
-        if any(
-            isinstance(arg, torch.Tensor) and arg.dtype != target_dtype
-            for arg in args
-        ):
-            args = [arg.to(target_dtype) if isinstance(arg, torch.Tensor) else arg for arg in args]
-            kwargs = {k: v.to(target_dtype) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()}
+        if target_device is not None:
+            params['device'] = target_device
+        params['dtype'] = target_dtype
+
+        args = list(args)
+        for j in (i for i, arg in enumerate(args) if isinstance(arg, torch.Tensor) and arg.dtype != target_dtype):
+            args[j] = args[j].to(**params)
+        args = tuple(args)
+
+        for key in (k for k, v in kwargs.items() if isinstance(v, torch.Tensor) and v.dtype != target_dtype):
+            kwargs[key] = kwargs[key].to(**params)
 
         org_dtype = target_dtype
         for param in self.parameters():
@@ -165,37 +194,41 @@ def manual_cast_forward(target_dtype):
                 break
 
         if org_dtype != target_dtype:
-            self.to(target_dtype)
+            self.to(**params)
         result = self.org_forward(*args, **kwargs)
+
         if org_dtype != target_dtype:
-            self.to(org_dtype)
+            params['dtype'] = org_dtype
+            self.to(**params)
 
         if target_dtype != dtype_inference:
+            params['dtype'] = dtype_inference
             if isinstance(result, tuple):
                 result = tuple(
-                    i.to(dtype_inference)
+                    i.to(**params)
                     if isinstance(i, torch.Tensor)
                     else i
                     for i in result
                 )
             elif isinstance(result, torch.Tensor):
-                result = result.to(dtype_inference)
+                result = result.to(**params)
         return result
     return forward_wrapper
 
 
 @contextlib.contextmanager
-def manual_cast(target_dtype):
+def manual_cast(target_dtype, target_device=None):
     applied = False
+
     for module_type in patch_module_list:
         if hasattr(module_type, "org_forward"):
             continue
         applied = True
         org_forward = module_type.forward
         if module_type == torch.nn.MultiheadAttention:
-            module_type.forward = manual_cast_forward(torch.float32)
+            module_type.forward = manual_cast_forward(torch.float32, target_device)
         else:
-            module_type.forward = manual_cast_forward(target_dtype)
+            module_type.forward = manual_cast_forward(target_dtype, target_device)
         module_type.org_forward = org_forward
     try:
         yield None
@@ -207,26 +240,37 @@ def manual_cast(target_dtype):
                     delattr(module_type, "org_forward")
 
 
-def autocast(disable=False):
+def autocast(disable=False, current_dtype=None, target_dtype=None, target_device=None):
     if disable:
         return contextlib.nullcontext()
 
+    if target_dtype is None:
+        target_dtype = dtype
+    if target_device is None:
+        target_device = device
+
     if force_fp16:
         # No casting during inference if force_fp16 is enabled.
         # All tensor dtype conversion happens before inference.
         return contextlib.nullcontext()
 
-    if fp8 and device==cpu:
+    if fp8 and target_device==cpu:
         return torch.autocast("cpu", dtype=torch.bfloat16, enabled=True)
 
     if fp8 and dtype_inference == torch.float32:
-        return manual_cast(dtype)
+        return manual_cast(target_dtype, target_device)
 
-    if dtype == torch.float32 or dtype_inference == torch.float32:
+    if target_dtype != dtype_inference:
+        return manual_cast(target_dtype, target_device)
+
+    if current_dtype is not None and current_dtype != target_dtype:
+        return manual_cast(target_dtype, target_device)
+
+    if target_dtype == torch.float32 or dtype_inference == torch.float32:
        return contextlib.nullcontext()
 
     if has_xpu() or has_mps() or cuda_no_autocast():
-        return manual_cast(dtype)
+        return manual_cast(target_dtype, target_device)
 
     return torch.autocast("cuda")
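
Usage sketch (not part of the diff): the new keyword arguments let a caller pin the casting context to an explicit dtype and device, for example running the VAE in bfloat16 while the rest of inference stays at dtype_inference. The snippet below is a minimal illustration only; it assumes the module is imported as modules.devices, and the names vae and latents are hypothetical stand-ins for an already-loaded VAE module and its input.

# Minimal sketch, assuming `vae` is a loaded torch.nn.Module and `latents` a tensor.
import torch

from modules import devices

# Prefer bf16 for the VAE when the hardware advertises it, otherwise keep dtype_vae.
vae_dtype = torch.bfloat16 if torch.bfloat16 in devices.supported_vae_dtypes else devices.dtype_vae

# autocast() now accepts an explicit target dtype/device; whenever target_dtype differs
# from dtype_inference (or from current_dtype), it falls back to manual_cast(), which
# patches module forwards and moves tensors via .to(**params), passing non_blocking=True
# where supports_non_blocking() allows it.
with devices.autocast(current_dtype=devices.dtype, target_dtype=vae_dtype, target_device=devices.device):
    decoded = vae.decode(latents)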