Merge pull request #5796 from brkirch/invoke-fix

Improve InvokeAI cross attention reliability and speed when using MPS for large images
This commit is contained in:
AUTOMATIC1111 2022-12-24 08:21:19 +03:00 committed by GitHub
commit f0dfed2a17
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -127,7 +127,7 @@ def check_for_psutil():
invokeAI_mps_available = check_for_psutil() invokeAI_mps_available = check_for_psutil()
# -- Taken from https://github.com/invoke-ai/InvokeAI -- # -- Taken from https://github.com/invoke-ai/InvokeAI and modified --
if invokeAI_mps_available: if invokeAI_mps_available:
import psutil import psutil
mem_total_gb = psutil.virtual_memory().total // (1 << 30) mem_total_gb = psutil.virtual_memory().total // (1 << 30)
@ -152,14 +152,16 @@ def einsum_op_slice_1(q, k, v, slice_size):
return r return r
def einsum_op_mps_v1(q, k, v): def einsum_op_mps_v1(q, k, v):
if q.shape[1] <= 4096: # (512x512) max q.shape[1]: 4096 if q.shape[0] * q.shape[1] <= 2**16: # (512x512) max q.shape[1]: 4096
return einsum_op_compvis(q, k, v) return einsum_op_compvis(q, k, v)
else: else:
slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1])) slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1]))
if slice_size % 4096 == 0:
slice_size -= 1
return einsum_op_slice_1(q, k, v, slice_size) return einsum_op_slice_1(q, k, v, slice_size)
def einsum_op_mps_v2(q, k, v): def einsum_op_mps_v2(q, k, v):
if mem_total_gb > 8 and q.shape[1] <= 4096: if mem_total_gb > 8 and q.shape[0] * q.shape[1] <= 2**16:
return einsum_op_compvis(q, k, v) return einsum_op_compvis(q, k, v)
else: else:
return einsum_op_slice_0(q, k, v, 1) return einsum_op_slice_0(q, k, v, 1)
@ -188,7 +190,7 @@ def einsum_op(q, k, v):
return einsum_op_cuda(q, k, v) return einsum_op_cuda(q, k, v)
if q.device.type == 'mps': if q.device.type == 'mps':
if mem_total_gb >= 32: if mem_total_gb >= 32 and q.shape[0] % 32 != 0 and q.shape[0] * q.shape[1] < 2**18:
return einsum_op_mps_v1(q, k, v) return einsum_op_mps_v1(q, k, v)
return einsum_op_mps_v2(q, k, v) return einsum_op_mps_v2(q, k, v)