mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-01-04 05:45:05 +08:00
Merge pull request #5796 from brkirch/invoke-fix
Improve InvokeAI cross attention reliability and speed when using MPS for large images
This commit is contained in:
commit
f0dfed2a17
@ -127,7 +127,7 @@ def check_for_psutil():
|
|||||||
|
|
||||||
invokeAI_mps_available = check_for_psutil()
|
invokeAI_mps_available = check_for_psutil()
|
||||||
|
|
||||||
# -- Taken from https://github.com/invoke-ai/InvokeAI --
|
# -- Taken from https://github.com/invoke-ai/InvokeAI and modified --
|
||||||
if invokeAI_mps_available:
|
if invokeAI_mps_available:
|
||||||
import psutil
|
import psutil
|
||||||
mem_total_gb = psutil.virtual_memory().total // (1 << 30)
|
mem_total_gb = psutil.virtual_memory().total // (1 << 30)
|
||||||
@ -152,14 +152,16 @@ def einsum_op_slice_1(q, k, v, slice_size):
|
|||||||
return r
|
return r
|
||||||
|
|
||||||
def einsum_op_mps_v1(q, k, v):
|
def einsum_op_mps_v1(q, k, v):
|
||||||
if q.shape[1] <= 4096: # (512x512) max q.shape[1]: 4096
|
if q.shape[0] * q.shape[1] <= 2**16: # (512x512) max q.shape[1]: 4096
|
||||||
return einsum_op_compvis(q, k, v)
|
return einsum_op_compvis(q, k, v)
|
||||||
else:
|
else:
|
||||||
slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1]))
|
slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1]))
|
||||||
|
if slice_size % 4096 == 0:
|
||||||
|
slice_size -= 1
|
||||||
return einsum_op_slice_1(q, k, v, slice_size)
|
return einsum_op_slice_1(q, k, v, slice_size)
|
||||||
|
|
||||||
def einsum_op_mps_v2(q, k, v):
|
def einsum_op_mps_v2(q, k, v):
|
||||||
if mem_total_gb > 8 and q.shape[1] <= 4096:
|
if mem_total_gb > 8 and q.shape[0] * q.shape[1] <= 2**16:
|
||||||
return einsum_op_compvis(q, k, v)
|
return einsum_op_compvis(q, k, v)
|
||||||
else:
|
else:
|
||||||
return einsum_op_slice_0(q, k, v, 1)
|
return einsum_op_slice_0(q, k, v, 1)
|
||||||
@ -188,7 +190,7 @@ def einsum_op(q, k, v):
|
|||||||
return einsum_op_cuda(q, k, v)
|
return einsum_op_cuda(q, k, v)
|
||||||
|
|
||||||
if q.device.type == 'mps':
|
if q.device.type == 'mps':
|
||||||
if mem_total_gb >= 32:
|
if mem_total_gb >= 32 and q.shape[0] % 32 != 0 and q.shape[0] * q.shape[1] < 2**18:
|
||||||
return einsum_op_mps_v1(q, k, v)
|
return einsum_op_mps_v1(q, k, v)
|
||||||
return einsum_op_mps_v2(q, k, v)
|
return einsum_op_mps_v2(q, k, v)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user