diff --git a/modules/devices.py b/modules/devices.py index ff279ac50..6edfb1278 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -110,6 +110,7 @@ device_codeformer: torch.device = None dtype: torch.dtype = torch.float16 dtype_vae: torch.dtype = torch.float16 dtype_unet: torch.dtype = torch.float16 +dtype_inference: torch.dtype = torch.float16 unet_needs_upcast = False @@ -131,21 +132,49 @@ patch_module_list = [ ] -def manual_cast_forward(self, *args, **kwargs): - org_dtype = torch_utils.get_param(self).dtype - self.to(dtype) - args = [arg.to(dtype) if isinstance(arg, torch.Tensor) else arg for arg in args] - kwargs = {k: v.to(dtype) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()} - result = self.org_forward(*args, **kwargs) - self.to(org_dtype) - return result +def manual_cast_forward(target_dtype): + def forward_wrapper(self, *args, **kwargs): + org_dtype = torch_utils.get_param(self).dtype + if not target_dtype == org_dtype == dtype_inference: + self.to(target_dtype) + args = [ + arg.to(target_dtype) + if isinstance(arg, torch.Tensor) + else arg + for arg in args + ] + kwargs = { + k: v.to(target_dtype) + if isinstance(v, torch.Tensor) + else v + for k, v in kwargs.items() + } + + result = self.org_forward(*args, **kwargs) + self.to(org_dtype) + + if target_dtype != dtype_inference: + if isinstance(result, tuple): + result = tuple( + i.to(dtype_inference) + if isinstance(i, torch.Tensor) + else i + for i in result + ) + elif isinstance(result, torch.Tensor): + result = result.to(dtype_inference) + return result + return forward_wrapper @contextlib.contextmanager -def manual_cast(): +def manual_cast(target_dtype): for module_type in patch_module_list: org_forward = module_type.forward - module_type.forward = manual_cast_forward + if module_type == torch.nn.MultiheadAttention and has_xpu(): + module_type.forward = manual_cast_forward(torch.float32) + else: + module_type.forward = manual_cast_forward(target_dtype) module_type.org_forward = org_forward try: yield None @@ -161,15 +190,12 @@ def autocast(disable=False): if fp8 and device==cpu: return torch.autocast("cpu", dtype=torch.bfloat16, enabled=True) - if fp8 and (dtype == torch.float32 or shared.cmd_opts.precision == "full" or cuda_no_autocast()): - return manual_cast() - - if has_mps() and shared.cmd_opts.precision != "full": - return manual_cast() - - if dtype == torch.float32 or shared.cmd_opts.precision == "full": + if dtype == torch.float32 and shared.cmd_opts.precision == "full": return contextlib.nullcontext() + if has_xpu() or has_mps() or cuda_no_autocast(): + return manual_cast(dtype_inference) + return torch.autocast("cuda") diff --git a/modules/shared_init.py b/modules/shared_init.py index 586be3423..935e3a21c 100644 --- a/modules/shared_init.py +++ b/modules/shared_init.py @@ -29,6 +29,7 @@ def initialize(): devices.dtype = torch.float32 if cmd_opts.no_half else torch.float16 devices.dtype_vae = torch.float32 if cmd_opts.no_half or cmd_opts.no_half_vae else torch.float16 + devices.dtype_inference = torch.float32 if cmd_opts.precision == 'full' else devices.dtype shared.device = devices.device shared.weight_load_location = None if cmd_opts.lowram else "cpu"