diff --git a/modules/devices.py b/modules/devices.py index dd50fe248..f00079c6b 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -2,9 +2,10 @@ import sys, os, shlex import contextlib import torch from modules import errors +from packaging import version -# has_mps is only available in nightly pytorch (for now) and MasOS 12.3+. +# has_mps is only available in nightly pytorch (for now) and macOS 12.3+. # check `getattr` and try it for compatibility def has_mps() -> bool: if not getattr(torch, 'has_mps', False): @@ -99,9 +100,25 @@ def autocast(disable=False): # MPS workaround for https://github.com/pytorch/pytorch/issues/79383 -def mps_contiguous(input_tensor, device): - return input_tensor.contiguous() if device.type == 'mps' else input_tensor +orig_tensor_to = torch.Tensor.to +def tensor_to_fix(self, *args, **kwargs): + if self.device.type != 'mps' and \ + ((len(args) > 0 and isinstance(args[0], torch.device) and args[0].type == 'mps') or \ + (isinstance(kwargs.get('device'), torch.device) and kwargs['device'].type == 'mps')): + self = self.contiguous() + return orig_tensor_to(self, *args, **kwargs) -def mps_contiguous_to(input_tensor, device): - return mps_contiguous(input_tensor, device).to(device) +# MPS workaround for https://github.com/pytorch/pytorch/issues/80800 +orig_layer_norm = torch.nn.functional.layer_norm +def layer_norm_fix(*args, **kwargs): + if len(args) > 0 and isinstance(args[0], torch.Tensor) and args[0].device.type == 'mps': + args = list(args) + args[0] = args[0].contiguous() + return orig_layer_norm(*args, **kwargs) + + +# PyTorch 1.13 doesn't need these fixes but unfortunately is slower and has regressions that prevent training from working +if has_mps() and version.parse(torch.__version__) < version.parse("1.13"): + torch.Tensor.to = tensor_to_fix + torch.nn.functional.layer_norm = layer_norm_fix diff --git a/modules/esrgan_model.py b/modules/esrgan_model.py index c61669b46..9a9c38f1f 100644 --- a/modules/esrgan_model.py +++ b/modules/esrgan_model.py @@ -199,7 +199,7 @@ def upscale_without_tiling(model, img): img = img[:, :, ::-1] img = np.ascontiguousarray(np.transpose(img, (2, 0, 1))) / 255 img = torch.from_numpy(img).float() - img = devices.mps_contiguous_to(img.unsqueeze(0), devices.device_esrgan) + img = img.unsqueeze(0).to(devices.device_esrgan) with torch.no_grad(): output = model(img) output = output.squeeze().float().cpu().clamp_(0, 1).numpy() diff --git a/modules/scunet_model.py b/modules/scunet_model.py index 59532274f..523602413 100644 --- a/modules/scunet_model.py +++ b/modules/scunet_model.py @@ -54,7 +54,7 @@ class UpscalerScuNET(modules.upscaler.Upscaler): img = img[:, :, ::-1] img = np.moveaxis(img, 2, 0) / 255 img = torch.from_numpy(img).float() - img = devices.mps_contiguous_to(img.unsqueeze(0), device) + img = img.unsqueeze(0).to(device) with torch.no_grad(): output = model(img) diff --git a/modules/swinir_model.py b/modules/swinir_model.py index 4253b66d5..facd262db 100644 --- a/modules/swinir_model.py +++ b/modules/swinir_model.py @@ -111,7 +111,7 @@ def upscale( img = img[:, :, ::-1] img = np.moveaxis(img, 2, 0) / 255 img = torch.from_numpy(img).float() - img = devices.mps_contiguous_to(img.unsqueeze(0), devices.device_swinir) + img = img.unsqueeze(0).to(devices.device_swinir) with torch.no_grad(), precision_scope("cuda"): _, _, h_old, w_old = img.size() h_pad = (h_old // window_size + 1) * window_size - h_old