mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-01-04 05:45:05 +08:00
Change sub-quad chunk threshold to use percentage
This commit is contained in:
parent
b119815333
commit
b95a4c0ce5
@ -233,7 +233,7 @@ def sub_quad_attention_forward(self, x, context=None, mask=None):
|
|||||||
k = k.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1)
|
k = k.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1)
|
||||||
v = v.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1)
|
v = v.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1)
|
||||||
|
|
||||||
x = sub_quad_attention(q, k, v, q_chunk_size=shared.cmd_opts.sub_quad_q_chunk_size, kv_chunk_size=shared.cmd_opts.sub_quad_kv_chunk_size, chunk_threshold_bytes=shared.cmd_opts.sub_quad_chunk_threshold, use_checkpoint=self.training)
|
x = sub_quad_attention(q, k, v, q_chunk_size=shared.cmd_opts.sub_quad_q_chunk_size, kv_chunk_size=shared.cmd_opts.sub_quad_kv_chunk_size, chunk_threshold=shared.cmd_opts.sub_quad_chunk_threshold, use_checkpoint=self.training)
|
||||||
|
|
||||||
x = x.unflatten(0, (-1, h)).transpose(1,2).flatten(start_dim=2)
|
x = x.unflatten(0, (-1, h)).transpose(1,2).flatten(start_dim=2)
|
||||||
|
|
||||||
@ -243,20 +243,20 @@ def sub_quad_attention_forward(self, x, context=None, mask=None):
|
|||||||
|
|
||||||
return x
|
return x
|
||||||
|
|
||||||
def sub_quad_attention(q, k, v, q_chunk_size=1024, kv_chunk_size=None, kv_chunk_size_min=None, chunk_threshold_bytes=None, use_checkpoint=True):
|
def sub_quad_attention(q, k, v, q_chunk_size=1024, kv_chunk_size=None, kv_chunk_size_min=None, chunk_threshold=None, use_checkpoint=True):
|
||||||
bytes_per_token = torch.finfo(q.dtype).bits//8
|
bytes_per_token = torch.finfo(q.dtype).bits//8
|
||||||
batch_x_heads, q_tokens, _ = q.shape
|
batch_x_heads, q_tokens, _ = q.shape
|
||||||
_, k_tokens, _ = k.shape
|
_, k_tokens, _ = k.shape
|
||||||
qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens
|
qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens
|
||||||
|
|
||||||
available_vram = int(get_available_vram() * 0.9) if q.device.type == 'mps' else int(get_available_vram() * 0.7)
|
if chunk_threshold is None:
|
||||||
|
chunk_threshold_bytes = int(get_available_vram() * 0.9) if q.device.type == 'mps' else int(get_available_vram() * 0.7)
|
||||||
if chunk_threshold_bytes is None:
|
elif chunk_threshold == 0:
|
||||||
chunk_threshold_bytes = available_vram
|
|
||||||
elif chunk_threshold_bytes == 0:
|
|
||||||
chunk_threshold_bytes = None
|
chunk_threshold_bytes = None
|
||||||
|
else:
|
||||||
|
chunk_threshold_bytes = int(0.01 * chunk_threshold * get_available_vram())
|
||||||
|
|
||||||
if kv_chunk_size_min is None:
|
if kv_chunk_size_min is None and chunk_threshold_bytes is not None:
|
||||||
kv_chunk_size_min = chunk_threshold_bytes // (batch_x_heads * bytes_per_token * (k.shape[2] + v.shape[2]))
|
kv_chunk_size_min = chunk_threshold_bytes // (batch_x_heads * bytes_per_token * (k.shape[2] + v.shape[2]))
|
||||||
elif kv_chunk_size_min == 0:
|
elif kv_chunk_size_min == 0:
|
||||||
kv_chunk_size_min = None
|
kv_chunk_size_min = None
|
||||||
@ -382,7 +382,7 @@ def sub_quad_attnblock_forward(self, x):
|
|||||||
q = q.contiguous()
|
q = q.contiguous()
|
||||||
k = k.contiguous()
|
k = k.contiguous()
|
||||||
v = v.contiguous()
|
v = v.contiguous()
|
||||||
out = sub_quad_attention(q, k, v, q_chunk_size=shared.cmd_opts.sub_quad_q_chunk_size, kv_chunk_size=shared.cmd_opts.sub_quad_kv_chunk_size, chunk_threshold_bytes=shared.cmd_opts.sub_quad_chunk_threshold, use_checkpoint=self.training)
|
out = sub_quad_attention(q, k, v, q_chunk_size=shared.cmd_opts.sub_quad_q_chunk_size, kv_chunk_size=shared.cmd_opts.sub_quad_kv_chunk_size, chunk_threshold=shared.cmd_opts.sub_quad_chunk_threshold, use_checkpoint=self.training)
|
||||||
out = rearrange(out, 'b (h w) c -> b c h w', h=h)
|
out = rearrange(out, 'b (h w) c -> b c h w', h=h)
|
||||||
out = self.proj_out(out)
|
out = self.proj_out(out)
|
||||||
return x + out
|
return x + out
|
||||||
|
@ -59,7 +59,7 @@ parser.add_argument("--opt-split-attention", action='store_true', help="force-en
|
|||||||
parser.add_argument("--opt-sub-quad-attention", action='store_true', help="enable memory efficient sub-quadratic cross-attention layer optimization")
|
parser.add_argument("--opt-sub-quad-attention", action='store_true', help="enable memory efficient sub-quadratic cross-attention layer optimization")
|
||||||
parser.add_argument("--sub-quad-q-chunk-size", type=int, help="query chunk size for the sub-quadratic cross-attention layer optimization to use", default=1024)
|
parser.add_argument("--sub-quad-q-chunk-size", type=int, help="query chunk size for the sub-quadratic cross-attention layer optimization to use", default=1024)
|
||||||
parser.add_argument("--sub-quad-kv-chunk-size", type=int, help="kv chunk size for the sub-quadratic cross-attention layer optimization to use", default=None)
|
parser.add_argument("--sub-quad-kv-chunk-size", type=int, help="kv chunk size for the sub-quadratic cross-attention layer optimization to use", default=None)
|
||||||
parser.add_argument("--sub-quad-chunk-threshold", type=int, help="the size threshold in bytes for the sub-quadratic cross-attention layer optimization to use chunking", default=None)
|
parser.add_argument("--sub-quad-chunk-threshold", type=int, help="the percentage of VRAM threshold for the sub-quadratic cross-attention layer optimization to use chunking", default=None)
|
||||||
parser.add_argument("--opt-split-attention-invokeai", action='store_true', help="force-enables InvokeAI's cross-attention layer optimization. By default, it's on when cuda is unavailable.")
|
parser.add_argument("--opt-split-attention-invokeai", action='store_true', help="force-enables InvokeAI's cross-attention layer optimization. By default, it's on when cuda is unavailable.")
|
||||||
parser.add_argument("--opt-split-attention-v1", action='store_true', help="enable older version of split attention optimization that does not consume all the VRAM it can find")
|
parser.add_argument("--opt-split-attention-v1", action='store_true', help="enable older version of split attention optimization that does not consume all the VRAM it can find")
|
||||||
parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization")
|
parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization")
|
||||||
|
Loading…
Reference in New Issue
Block a user