mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-04-25 06:19:00 +08:00
Merge pull request #14046 from hidenorly/AddFP32FallbackSupportOnSdVaeApprox
Add FP32 fallback support on sd_vae_approx
This commit is contained in:
commit
e12a26c253
@ -1,6 +1,7 @@
|
|||||||
import logging
|
import logging
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
from torch import Tensor
|
||||||
import platform
|
import platform
|
||||||
from modules.sd_hijack_utils import CondFunc
|
from modules.sd_hijack_utils import CondFunc
|
||||||
from packaging import version
|
from packaging import version
|
||||||
@ -51,6 +52,17 @@ def cumsum_fix(input, cumsum_func, *args, **kwargs):
|
|||||||
return cumsum_func(input, *args, **kwargs)
|
return cumsum_func(input, *args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
# MPS workaround for https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14046
|
||||||
|
def interpolate_with_fp32_fallback(orig_func, *args, **kwargs) -> Tensor:
|
||||||
|
try:
|
||||||
|
return orig_func(*args, **kwargs)
|
||||||
|
except RuntimeError as e:
|
||||||
|
if "not implemented for" in str(e) and "Half" in str(e):
|
||||||
|
input_tensor = args[0]
|
||||||
|
return orig_func(input_tensor.to(torch.float32), *args[1:], **kwargs).to(input_tensor.dtype)
|
||||||
|
else:
|
||||||
|
print(f"An unexpected RuntimeError occurred: {str(e)}")
|
||||||
|
|
||||||
if has_mps:
|
if has_mps:
|
||||||
if platform.mac_ver()[0].startswith("13.2."):
|
if platform.mac_ver()[0].startswith("13.2."):
|
||||||
# MPS workaround for https://github.com/pytorch/pytorch/issues/95188, thanks to danieldk (https://github.com/explosion/curated-transformers/pull/124)
|
# MPS workaround for https://github.com/pytorch/pytorch/issues/95188, thanks to danieldk (https://github.com/explosion/curated-transformers/pull/124)
|
||||||
@ -77,6 +89,9 @@ if has_mps:
|
|||||||
# MPS workaround for https://github.com/pytorch/pytorch/issues/96113
|
# MPS workaround for https://github.com/pytorch/pytorch/issues/96113
|
||||||
CondFunc('torch.nn.functional.layer_norm', lambda orig_func, x, normalized_shape, weight, bias, eps, **kwargs: orig_func(x.float(), normalized_shape, weight.float() if weight is not None else None, bias.float() if bias is not None else bias, eps).to(x.dtype), lambda _, input, *args, **kwargs: len(args) == 4 and input.device.type == 'mps')
|
CondFunc('torch.nn.functional.layer_norm', lambda orig_func, x, normalized_shape, weight, bias, eps, **kwargs: orig_func(x.float(), normalized_shape, weight.float() if weight is not None else None, bias.float() if bias is not None else bias, eps).to(x.dtype), lambda _, input, *args, **kwargs: len(args) == 4 and input.device.type == 'mps')
|
||||||
|
|
||||||
|
# MPS workaround for https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14046
|
||||||
|
CondFunc('torch.nn.functional.interpolate', interpolate_with_fp32_fallback, None)
|
||||||
|
|
||||||
# MPS workaround for https://github.com/pytorch/pytorch/issues/92311
|
# MPS workaround for https://github.com/pytorch/pytorch/issues/92311
|
||||||
if platform.processor() == 'i386':
|
if platform.processor() == 'i386':
|
||||||
for funcName in ['torch.argmax', 'torch.Tensor.argmax']:
|
for funcName in ['torch.argmax', 'torch.Tensor.argmax']:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user