Mirror of https://github.com/AUTOMATIC1111/stable-diffusion-webui.git (synced 2025-01-21 13:50:12 +08:00)
fix misc

* check supported dtypes
* detect non_blocking
* update autocast() to use non_blocking, target_device and current_dtype
This commit is contained in: parent 821e76a415, commit 39328bd7db
@@ -128,6 +128,26 @@ dtype_unet: torch.dtype = torch.float16
 dtype_inference: torch.dtype = torch.float16
 unet_needs_upcast = False
 
+supported_vae_dtypes = [torch.float16, torch.float32]
+
+
+# prepare available dtypes
+if torch.version.cuda:
+    if torch.cuda.is_bf16_supported() and torch.cuda.get_device_properties(torch.cuda.current_device()).major >= 8:
+        supported_vae_dtypes = [torch.bfloat16] + supported_vae_dtypes
+
+if has_xpu():
+    supported_vae_dtypes = [torch.bfloat16] + supported_vae_dtypes
+
+
+def supports_non_blocking():
+    if has_mps() or has_xpu():
+        return False
+
+    if npu_specific.has_npu:
+        return False
+
+    return True
 
 
 def cond_cast_unet(input):
     if force_fp16:
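The hunk above adds a probe for VAE-capable dtypes and a supports_non_blocking() helper that rules out non-blocking copies on MPS, XPU and NPU backends. A rough usage sketch follows; only supported_vae_dtypes and supports_non_blocking() come from the diff, while pick_vae_dtype() and the example tensor are hypothetical, and the import assumes this hunk lands in modules/devices.py:

    import torch
    from modules import devices  # assumption: the hunk above patches modules/devices.py

    # Hypothetical helper: fall back to float32 when a requested VAE dtype
    # is not in the probed list of supported dtypes.
    def pick_vae_dtype(requested: torch.dtype) -> torch.dtype:
        return requested if requested in devices.supported_vae_dtypes else torch.float32

    # Non-blocking copies are only requested on backends where the helper allows them.
    t = torch.zeros(4)
    t = t.to(devices.device, non_blocking=devices.supports_non_blocking())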
@@ -149,14 +169,23 @@ patch_module_list = [
 ]
 
 
-def manual_cast_forward(target_dtype):
+def manual_cast_forward(target_dtype, target_device=None):
+    params = dict()
+    if supports_non_blocking():
+        params['non_blocking'] = True
+
     def forward_wrapper(self, *args, **kwargs):
-        if any(
-            isinstance(arg, torch.Tensor) and arg.dtype != target_dtype
-            for arg in args
-        ):
-            args = [arg.to(target_dtype) if isinstance(arg, torch.Tensor) else arg for arg in args]
-            kwargs = {k: v.to(target_dtype) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()}
+        if target_device is not None:
+            params['device'] = target_device
+        params['dtype'] = target_dtype
+
+        args = list(args)
+        for j in (i for i, arg in enumerate(args) if isinstance(arg, torch.Tensor) and arg.dtype != target_dtype):
+            args[j] = args[j].to(**params)
+        args = tuple(args)
+
+        for key in (k for k, v in kwargs.items() if isinstance(v, torch.Tensor) and v.dtype != target_dtype):
+            kwargs[key] = kwargs[key].to(**params)
 
         org_dtype = target_dtype
         for param in self.parameters():
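The rewritten manual_cast_forward() builds a single params dict (dtype, optional device, optional non_blocking) and reuses it for every .to() call on the wrapper's arguments. A minimal standalone sketch of that pattern with illustrative names; only Tensor.to(dtype=..., device=..., non_blocking=...) is standard PyTorch API:

    import torch

    def build_cast_params(target_dtype, target_device=None, allow_non_blocking=True):
        # Collect the .to() keyword arguments once so every cast uses the same
        # dtype/device/non_blocking combination.
        params = {'dtype': target_dtype}
        if target_device is not None:
            params['device'] = target_device
        if allow_non_blocking:
            params['non_blocking'] = True
        return params

    params = build_cast_params(torch.float16, torch.device('cpu'))
    x = torch.ones(2, 3).to(**params)  # same as .to(dtype=torch.float16, device='cpu', non_blocking=True)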
@@ -165,37 +194,41 @@ def manual_cast_forward(target_dtype):
                 break
 
         if org_dtype != target_dtype:
-            self.to(target_dtype)
+            self.to(**params)
         result = self.org_forward(*args, **kwargs)
 
         if org_dtype != target_dtype:
-            self.to(org_dtype)
+            params['dtype'] = org_dtype
+            self.to(**params)
 
         if target_dtype != dtype_inference:
+            params['dtype'] = dtype_inference
             if isinstance(result, tuple):
                 result = tuple(
-                    i.to(dtype_inference)
+                    i.to(**params)
                     if isinstance(i, torch.Tensor)
                     else i
                     for i in result
                 )
             elif isinstance(result, torch.Tensor):
-                result = result.to(dtype_inference)
+                result = result.to(**params)
         return result
     return forward_wrapper
 
 
 @contextlib.contextmanager
-def manual_cast(target_dtype):
+def manual_cast(target_dtype, target_device=None):
     applied = False
 
     for module_type in patch_module_list:
         if hasattr(module_type, "org_forward"):
             continue
         applied = True
         org_forward = module_type.forward
         if module_type == torch.nn.MultiheadAttention:
-            module_type.forward = manual_cast_forward(torch.float32)
+            module_type.forward = manual_cast_forward(torch.float32, target_device)
         else:
-            module_type.forward = manual_cast_forward(target_dtype)
+            module_type.forward = manual_cast_forward(target_dtype, target_device)
         module_type.org_forward = org_forward
     try:
         yield None
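For context, manual_cast() monkey-patches the forward methods of the modules in patch_module_list and restores them on exit; the change above only threads target_device through to manual_cast_forward(). A stripped-down illustration of the patch-and-restore idea, reduced to a single Linear layer and plain dtype casting (a sketch, not the webui implementation):

    import contextlib
    import torch

    @contextlib.contextmanager
    def cast_linear_inputs(target_dtype):
        # Temporarily wrap Linear.forward so inputs are cast to target_dtype,
        # then restore the original forward on exit.
        org_forward = torch.nn.Linear.forward

        def wrapper(self, x):
            return org_forward(self, x.to(target_dtype))

        torch.nn.Linear.forward = wrapper
        try:
            yield
        finally:
            torch.nn.Linear.forward = org_forward

    layer = torch.nn.Linear(4, 4)  # float32 weights
    with cast_linear_inputs(torch.float32):
        out = layer(torch.zeros(4, dtype=torch.float16))  # input upcast before the matmul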
@@ -207,26 +240,37 @@ def manual_cast(target_dtype):
                    delattr(module_type, "org_forward")
 
 
-def autocast(disable=False):
+def autocast(disable=False, current_dtype=None, target_dtype=None, target_device=None):
     if disable:
         return contextlib.nullcontext()
 
+    if target_dtype is None:
+        target_dtype = dtype
+    if target_device is None:
+        target_device = device
+
     if force_fp16:
         # No casting during inference if force_fp16 is enabled.
         # All tensor dtype conversion happens before inference.
         return contextlib.nullcontext()
 
-    if fp8 and device==cpu:
+    if fp8 and target_device==cpu:
         return torch.autocast("cpu", dtype=torch.bfloat16, enabled=True)
 
     if fp8 and dtype_inference == torch.float32:
-        return manual_cast(dtype)
+        return manual_cast(target_dtype, target_device)
 
-    if dtype == torch.float32 or dtype_inference == torch.float32:
+    if target_dtype != dtype_inference:
+        return manual_cast(target_dtype, target_device)
+
+    if current_dtype is not None and current_dtype != target_dtype:
+        return manual_cast(target_dtype, target_device)
+
+    if target_dtype == torch.float32 or dtype_inference == torch.float32:
         return contextlib.nullcontext()
 
     if has_xpu() or has_mps() or cuda_no_autocast():
-        return manual_cast(dtype)
+        return manual_cast(target_dtype, target_device)
 
     return torch.autocast("cuda")
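The extended autocast() keeps its old no-argument behaviour (target_dtype falls back to dtype, target_device to device) but now lets callers pin the device and declare both the dtype incoming tensors currently have and the dtype inference should run in. A hedged usage sketch; the surrounding setup is made up, only the keyword names come from the diff, and the import assumes autocast() lives in modules/devices.py:

    import torch
    from modules import devices  # assumption: autocast() lives in modules/devices.py

    x = torch.randn(1, 4, 64, 64)

    # Default call: keyword arguments fall back to devices.dtype / devices.device.
    with devices.autocast():
        pass  # run the usual inference here

    # Explicit override, e.g. to run a submodule such as the VAE in bfloat16 while
    # telling the context manager what dtype the incoming tensors currently have.
    with devices.autocast(current_dtype=x.dtype, target_dtype=torch.bfloat16, target_device=devices.device):
        pass  # code that expects bfloat16 activations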