2022-10-22 06:11:07 +08:00
import sys , os , shlex
2022-10-04 17:32:22 +08:00
import contextlib
2022-09-11 13:11:27 +08:00
import torch
2022-09-12 21:34:13 +08:00
from modules import errors
2022-11-17 16:52:17 +08:00
from packaging import version
2022-09-12 21:34:13 +08:00
2022-11-12 15:00:49 +08:00
2022-11-17 16:52:17 +08:00
def has_mps() -> bool:
    """Report whether the MPS backend is actually usable.

    `torch.has_mps` is only available in nightly pytorch (for now) and on
    macOS 12.3+, so the attribute is read through `getattr` with a False
    default. When the attribute is truthy, a one-element tensor is moved to
    the "mps" device as a probe; any exception from that probe means the
    backend is treated as unavailable.
    """
    if getattr(torch, 'has_mps', False):
        try:
            torch.zeros(1).to(torch.device("mps"))
        except Exception:
            return False
        return True
    return False
2022-09-11 13:11:27 +08:00
2022-09-11 23:48:36 +08:00
2022-10-22 06:11:07 +08:00
def extract_device_id(args, name):
    """Return the argument that follows the first argument containing *name*.

    Each entry of *args* is checked with a containment test (``name in arg``),
    so *name* may match a substring of the argument. The entry immediately
    after the first match is returned.

    Returns None when no argument contains *name*, or when the only match is
    the final argument and therefore has no following value (previously this
    raised IndexError).
    """
    # Pair every argument with its successor; the last argument has no
    # successor and so can never be treated as a flag with a value.
    for flag, value in zip(args, args[1:]):
        if name in flag:
            return value

    return None
2022-09-11 23:48:36 +08:00
2022-11-12 15:00:49 +08:00
2022-11-27 18:08:54 +08:00
def get_cuda_device_string():
    """Build the CUDA device string from the command-line options.

    Returns "cuda:<device_id>" when `shared.cmd_opts.device_id` is not None,
    otherwise plain "cuda".
    """
    # Imported inside the function, as in the original.
    from modules import shared

    device_id = shared.cmd_opts.device_id
    return "cuda" if device_id is None else f"cuda:{device_id}"
2022-10-22 19:04:14 +08:00
2022-11-27 18:08:54 +08:00
def get_optimal_device():
    """Pick the preferred torch device: CUDA first, then MPS, then CPU.

    Returns a `torch.device` for the configured CUDA device string when CUDA
    is available, a `torch.device("mps")` when `has_mps()` is true, and the
    module-level `cpu` device otherwise.
    """
    if torch.cuda.is_available():
        device_name = get_cuda_device_string()
    elif has_mps():
        device_name = "mps"
    else:
        return cpu

    return torch.device(device_name)
2022-09-12 04:24:24 +08:00
2022-12-03 23:06:33 +08:00
def get_device_for(task):
    """Return the device to use for *task*.

    If *task* is contained in `shared.cmd_opts.use_cpu`, the module-level
    `cpu` device is returned; otherwise the result of `get_optimal_device()`.
    """
    from modules import shared

    return cpu if task in shared.cmd_opts.use_cpu else get_optimal_device()
2022-09-12 04:24:24 +08:00
def torch_gc():
    """Call the CUDA cache-cleanup routines on the configured CUDA device.

    Does nothing when CUDA is unavailable. Otherwise, within the device
    context given by `get_cuda_device_string()`, calls
    `torch.cuda.empty_cache()` followed by `torch.cuda.ipc_collect()`.
    """
    if not torch.cuda.is_available():
        return

    with torch.cuda.device(get_cuda_device_string()):
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()
2022-09-12 21:34:13 +08:00
def enable_tf32():
    """Turn on TF32 for CUDA matmul and cuDNN when CUDA is available.

    Also sets `torch.backends.cudnn.benchmark = True` if any visible CUDA
    device reports compute capability (7, 5). Does nothing without CUDA.
    """
    if not torch.cuda.is_available():
        return

    # enabling benchmark option seems to enable a range of cards to do fp16 when they otherwise can't
    # see https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/4407
    capabilities = [
        torch.cuda.get_device_capability(devid)
        for devid in range(0, torch.cuda.device_count())
    ]
    if (7, 5) in capabilities:
        torch.backends.cudnn.benchmark = True

    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
2022-11-07 09:05:51 +08:00
2022-09-12 21:34:13 +08:00
# Invoke enable_tf32 through the project's errors.run helper, passing a label string.
# NOTE(review): errors.run lives in modules.errors (not visible here); presumably it
# calls the function and reports any failure under the given label — confirm.
errors . run ( enable_tf32 , " Enabling TF32 " )
2022-09-13 01:09:32 +08:00
2022-11-12 15:00:49 +08:00
# Shared handle for the CPU device, used as the fallback by the helpers in this module.
cpu = torch.device("cpu")

# Per-component device slots. They all start as None here.
# NOTE(review): presumably assigned by start-up code elsewhere — not visible in this file.
device = None
device_interrogate = None
device_gfpgan = None
device_esrgan = None
device_codeformer = None

# Default dtypes (half precision) for the model, the VAE and the UNet.
dtype = dtype_vae = dtype_unet = torch.float16
unet_needs_upcast = False
2022-09-13 01:09:32 +08:00
2022-11-12 15:00:49 +08:00
2022-09-13 01:09:32 +08:00
def randn(seed, shape):
    """Seed torch's global RNG with *seed* and return a `torch.randn` tensor of *shape*.

    The tensor ends up on the module-level `device`. When that device is of
    type 'mps', the values are generated on `cpu` and then moved over.
    NOTE(review): the reason for the CPU detour is not stated here; presumably
    a limitation of random generation on MPS — confirm.
    """
    torch.manual_seed(seed)

    if device.type != 'mps':
        return torch.randn(shape, device=device)

    cpu_noise = torch.randn(shape, device=cpu)
    return cpu_noise.to(device)
2022-09-14 02:49:58 +08:00
def randn_without_seed(shape):
    """Return a `torch.randn` tensor of *shape* on the module-level `device`.

    The global RNG state is used as-is (no reseeding). For an 'mps' device the
    values are generated on `cpu` first and then transferred.
    """
    on_mps = device.type == 'mps'
    noise = torch.randn(shape, device=cpu if on_mps else device)
    return noise.to(device) if on_mps else noise
2022-10-04 17:32:22 +08:00
2022-10-10 21:11:14 +08:00
def autocast(disable=False):
    """Return the context manager to wrap inference in.

    A `contextlib.nullcontext()` is returned when *disable* is truthy, when
    the module-level `dtype` is `torch.float32`, or when
    `shared.cmd_opts.precision` equals "full". Otherwise the result of
    `torch.autocast("cuda")` is returned.
    """
    from modules import shared

    # `or` short-circuits, so cmd_opts is only consulted when the earlier checks fail.
    skip_autocast = (
        disable
        or dtype == torch.float32
        or shared.cmd_opts.precision == "full"
    )
    return contextlib.nullcontext() if skip_autocast else torch.autocast("cuda")
2022-10-25 14:01:57 +08:00
2022-11-12 15:00:49 +08:00
2023-01-17 03:59:46 +08:00
class NansException(Exception):
    """Error raised by `test_for_nans` when a tensor consists entirely of NaNs."""
def test_for_nans(x, where):
    """Raise `NansException` if every element of tensor *x* is NaN.

    The check is skipped entirely when `shared.cmd_opts.disable_nan_check` is
    set. *where* selects the message: "unet" and "vae" get a specific message,
    plus a hint about --no-half / --no-half-vae when those options are not
    already active; any other value gets a generic message.
    """
    from modules import shared

    opts = shared.cmd_opts
    if opts.disable_nan_check or not torch.isnan(x).all().item():
        return

    if where == "unet":
        parts = ["A tensor with all NaNs was produced in Unet."]
        if not opts.no_half:
            parts.append(" This could be either because there's not enough precision to represent the picture, or because your video card does not support half type. Try using --no-half commandline argument to fix this.")
    elif where == "vae":
        parts = ["A tensor with all NaNs was produced in VAE."]
        if not (opts.no_half or opts.no_half_vae):
            parts.append(" This could be because there's not enough precision to represent the picture. Try adding --no-half-vae commandline argument to fix this.")
    else:
        parts = ["A tensor with all NaNs was produced."]

    raise NansException("".join(parts))
2022-10-25 14:01:57 +08:00
# MPS workaround for https://github.com/pytorch/pytorch/issues/79383
orig_tensor_to = torch.Tensor.to


def tensor_to_fix(self, *args, **kwargs):
    """Replacement for `torch.Tensor.to` that makes the tensor contiguous before a move to MPS.

    When the tensor is not already on an 'mps' device and the target, given
    either as the first positional argument or as the `device` keyword, is a
    `torch.device` of type 'mps', `contiguous()` is applied first. The call is
    then forwarded unchanged to the original `torch.Tensor.to`.
    """
    def _is_mps_device(candidate):
        return isinstance(candidate, torch.device) and candidate.type == 'mps'

    if self.device.type != 'mps':
        if (len(args) > 0 and _is_mps_device(args[0])) or _is_mps_device(kwargs.get('device')):
            self = self.contiguous()

    return orig_tensor_to(self, *args, **kwargs)
# MPS workaround for https://github.com/pytorch/pytorch/issues/80800
orig_layer_norm = torch.nn.functional.layer_norm


def layer_norm_fix(*args, **kwargs):
    """Replacement for `torch.nn.functional.layer_norm` that makes an MPS input contiguous.

    Only the first positional argument is inspected: if it is a tensor on an
    'mps' device it is replaced by its `contiguous()` version. Everything is
    then forwarded to the original `layer_norm`.
    """
    if args:
        first, *rest = args
        if isinstance(first, torch.Tensor) and first.device.type == 'mps':
            args = [first.contiguous(), *rest]

    return orig_layer_norm(*args, **kwargs)
2022-12-17 16:21:19 +08:00
# MPS workaround for https://github.com/pytorch/pytorch/issues/90532
orig_tensor_numpy = torch.Tensor.numpy


def numpy_fix(self, *args, **kwargs):
    """Replacement for `torch.Tensor.numpy` that detaches tensors requiring grad first.

    A tensor with `requires_grad` set is replaced by `self.detach()` before the
    call is forwarded to the original `torch.Tensor.numpy`; other tensors are
    passed through as they are.
    """
    source = self.detach() if self.requires_grad else self
    return orig_tensor_numpy(source, *args, **kwargs)
2023-01-04 09:43:05 +08:00
# MPS workaround for https://github.com/pytorch/pytorch/issues/89784
orig_cumsum = torch.cumsum
orig_Tensor_cumsum = torch.Tensor.cumsum


def cumsum_fix(input, cumsum_func, *args, **kwargs):
    """Run *cumsum_func* on *input*, routing around MPS cumsum problems.

    Inputs that are not on an 'mps' device go straight to *cumsum_func*.
    For MPS inputs the output dtype is taken from the `dtype` keyword, falling
    back to the input's dtype:

    * int64: the cumsum is computed on a CPU copy and moved back to the
      input's device;
    * bool (when `cumsum_needs_bool_fix` is set) or int8/int16 (when
      `cumsum_needs_int_fix` is set): the input is cast to int32 first and the
      result is converted to int64;
    * anything else: *cumsum_func* is called unchanged.

    The two `cumsum_needs_*` flags are module globals that this function only
    reads on the MPS path.
    """
    if input.device.type != 'mps':
        return cumsum_func(input, *args, **kwargs)

    output_dtype = kwargs.get('dtype', input.dtype)
    if output_dtype == torch.int64:
        return cumsum_func(input.cpu(), *args, **kwargs).to(input.device)

    needs_widening = (
        (cumsum_needs_bool_fix and output_dtype == torch.bool)
        or (cumsum_needs_int_fix and (output_dtype == torch.int8 or output_dtype == torch.int16))
    )
    if needs_widening:
        return cumsum_func(input.to(torch.int32), *args, **kwargs).to(torch.int64)

    return cumsum_func(input, *args, **kwargs)
# Install the MPS workarounds that apply to the running PyTorch version.
# NOTE(review): the nesting below was reconstructed from a source with no
# indentation; confirm that the torch.narrow patch belongs to the > 1.13.1 branch.
if has_mps():
    _torch_version = version.parse(torch.__version__)

    if _torch_version < version.parse("1.13"):
        # PyTorch 1.13 doesn't need these fixes but unfortunately is slower and has regressions that prevent training from working
        torch.Tensor.to = tensor_to_fix
        torch.nn.functional.layer_norm = layer_norm_fix
        torch.Tensor.numpy = numpy_fix
    elif _torch_version > version.parse("1.13.1"):
        _mps = torch.device("mps")

        # Probe cumsum on the MPS device: a flag is set when the probe result differs from the expected values.
        cumsum_needs_int_fix = not torch.Tensor([1, 2]).to(_mps).equal(torch.ShortTensor([1, 1]).to(_mps).cumsum(0))
        cumsum_needs_bool_fix = not torch.BoolTensor([True, True]).to(device=_mps, dtype=torch.int64).equal(torch.BoolTensor([True, False]).to(_mps).cumsum(0))

        # Route both the function and the method form of cumsum through cumsum_fix.
        torch.cumsum = lambda input, *args, **kwargs: cumsum_fix(input, orig_cumsum, *args, **kwargs)
        torch.Tensor.cumsum = lambda self, *args, **kwargs: cumsum_fix(self, orig_Tensor_cumsum, *args, **kwargs)

        # torch.narrow results are cloned before being returned.
        orig_narrow = torch.narrow
        torch.narrow = lambda *args, **kwargs: orig_narrow(*args, **kwargs).clone()
2023-01-17 03:59:46 +08:00