mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-01-01 20:35:06 +08:00
Merge pull request #6055 from brkirch/sub-quad_attn_opt
Add Birch-san's sub-quadratic attention implementation
This commit is contained in:
commit
c295e4a244
@ -141,6 +141,7 @@ Licenses for borrowed code can be found in `Settings -> Licenses` screen, and al
|
|||||||
- Ideas for optimizations - https://github.com/basujindal/stable-diffusion
|
- Ideas for optimizations - https://github.com/basujindal/stable-diffusion
|
||||||
- Cross Attention layer optimization - Doggettx - https://github.com/Doggettx/stable-diffusion, original idea for prompt editing.
|
- Cross Attention layer optimization - Doggettx - https://github.com/Doggettx/stable-diffusion, original idea for prompt editing.
|
||||||
- Cross Attention layer optimization - InvokeAI, lstein - https://github.com/invoke-ai/InvokeAI (originally http://github.com/lstein/stable-diffusion)
|
- Cross Attention layer optimization - InvokeAI, lstein - https://github.com/invoke-ai/InvokeAI (originally http://github.com/lstein/stable-diffusion)
|
||||||
|
- Sub-quadratic Cross Attention layer optimization - Alex Birch (https://github.com/Birch-san/diffusers/pull/1), Amin Rezaei (https://github.com/AminRezaei0x443/memory-efficient-attention)
|
||||||
- Textual Inversion - Rinon Gal - https://github.com/rinongal/textual_inversion (we're not using his code, but we are using his ideas).
|
- Textual Inversion - Rinon Gal - https://github.com/rinongal/textual_inversion (we're not using his code, but we are using his ideas).
|
||||||
- Idea for SD upscale - https://github.com/jquesnelle/txt2imghd
|
- Idea for SD upscale - https://github.com/jquesnelle/txt2imghd
|
||||||
- Noise generation for outpainting mk2 - https://github.com/parlance-zz/g-diffuser-bot
|
- Noise generation for outpainting mk2 - https://github.com/parlance-zz/g-diffuser-bot
|
||||||
|
@ -184,7 +184,7 @@ SOFTWARE.
|
|||||||
</pre>
|
</pre>
|
||||||
|
|
||||||
<h2><a href="https://github.com/JingyunLiang/SwinIR/blob/main/LICENSE">SwinIR</a></h2>
|
<h2><a href="https://github.com/JingyunLiang/SwinIR/blob/main/LICENSE">SwinIR</a></h2>
|
||||||
<small>Code added by contirubtors, most likely copied from this repository.</small>
|
<small>Code added by contributors, most likely copied from this repository.</small>
|
||||||
|
|
||||||
<pre>
|
<pre>
|
||||||
Apache License
|
Apache License
|
||||||
@ -390,3 +390,30 @@ SOFTWARE.
|
|||||||
limitations under the License.
|
limitations under the License.
|
||||||
</pre>
|
</pre>
|
||||||
|
|
||||||
|
<h2><a href="https://github.com/AminRezaei0x443/memory-efficient-attention/blob/main/LICENSE">Memory Efficient Attention</a></h2>
|
||||||
|
<small>The sub-quadratic cross attention optimization uses modified code from the Memory Efficient Attention package that Alex Birch optimized for 3D tensors. This license is updated to reflect that.</small>
|
||||||
|
<pre>
|
||||||
|
MIT License
|
||||||
|
|
||||||
|
Copyright (c) 2023 Alex Birch
|
||||||
|
Copyright (c) 2023 Amin Rezaei
|
||||||
|
|
||||||
|
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||||
|
of this software and associated documentation files (the "Software"), to deal
|
||||||
|
in the Software without restriction, including without limitation the rights
|
||||||
|
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||||
|
copies of the Software, and to permit persons to whom the Software is
|
||||||
|
furnished to do so, subject to the following conditions:
|
||||||
|
|
||||||
|
The above copyright notice and this permission notice shall be included in all
|
||||||
|
copies or substantial portions of the Software.
|
||||||
|
|
||||||
|
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
|
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
|
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||||
|
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
|
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||||
|
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||||
|
SOFTWARE.
|
||||||
|
</pre>
|
||||||
|
|
||||||
|
@ -7,8 +7,6 @@ from modules.hypernetworks import hypernetwork
|
|||||||
from modules.shared import cmd_opts
|
from modules.shared import cmd_opts
|
||||||
from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr
|
from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr
|
||||||
|
|
||||||
from modules.sd_hijack_optimizations import invokeAI_mps_available
|
|
||||||
|
|
||||||
import ldm.modules.attention
|
import ldm.modules.attention
|
||||||
import ldm.modules.diffusionmodules.model
|
import ldm.modules.diffusionmodules.model
|
||||||
import ldm.modules.diffusionmodules.openaimodel
|
import ldm.modules.diffusionmodules.openaimodel
|
||||||
@ -43,20 +41,19 @@ def apply_optimizations():
|
|||||||
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.xformers_attention_forward
|
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.xformers_attention_forward
|
||||||
ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.xformers_attnblock_forward
|
ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.xformers_attnblock_forward
|
||||||
optimization_method = 'xformers'
|
optimization_method = 'xformers'
|
||||||
|
elif cmd_opts.opt_sub_quad_attention:
|
||||||
|
print("Applying sub-quadratic cross attention optimization.")
|
||||||
|
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.sub_quad_attention_forward
|
||||||
|
ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.sub_quad_attnblock_forward
|
||||||
|
optimization_method = 'sub-quadratic'
|
||||||
elif cmd_opts.opt_split_attention_v1:
|
elif cmd_opts.opt_split_attention_v1:
|
||||||
print("Applying v1 cross attention optimization.")
|
print("Applying v1 cross attention optimization.")
|
||||||
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1
|
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1
|
||||||
optimization_method = 'V1'
|
optimization_method = 'V1'
|
||||||
elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention_invokeai or not torch.cuda.is_available()):
|
elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention_invokeai or not cmd_opts.opt_split_attention and not torch.cuda.is_available()):
|
||||||
if not invokeAI_mps_available and shared.device.type == 'mps':
|
print("Applying cross attention optimization (InvokeAI).")
|
||||||
print("The InvokeAI cross attention optimization for MPS requires the psutil package which is not installed.")
|
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_invokeAI
|
||||||
print("Applying v1 cross attention optimization.")
|
optimization_method = 'InvokeAI'
|
||||||
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1
|
|
||||||
optimization_method = 'V1'
|
|
||||||
else:
|
|
||||||
print("Applying cross attention optimization (InvokeAI).")
|
|
||||||
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_invokeAI
|
|
||||||
optimization_method = 'InvokeAI'
|
|
||||||
elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention or torch.cuda.is_available()):
|
elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention or torch.cuda.is_available()):
|
||||||
print("Applying cross attention optimization (Doggettx).")
|
print("Applying cross attention optimization (Doggettx).")
|
||||||
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward
|
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
import math
|
import math
|
||||||
import sys
|
import sys
|
||||||
import traceback
|
import traceback
|
||||||
import importlib
|
import psutil
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import einsum
|
from torch import einsum
|
||||||
@ -12,6 +12,8 @@ from einops import rearrange
|
|||||||
from modules import shared
|
from modules import shared
|
||||||
from modules.hypernetworks import hypernetwork
|
from modules.hypernetworks import hypernetwork
|
||||||
|
|
||||||
|
from .sub_quadratic_attention import efficient_dot_product_attention
|
||||||
|
|
||||||
|
|
||||||
if shared.cmd_opts.xformers or shared.cmd_opts.force_enable_xformers:
|
if shared.cmd_opts.xformers or shared.cmd_opts.force_enable_xformers:
|
||||||
try:
|
try:
|
||||||
@ -22,6 +24,19 @@ if shared.cmd_opts.xformers or shared.cmd_opts.force_enable_xformers:
|
|||||||
print(traceback.format_exc(), file=sys.stderr)
|
print(traceback.format_exc(), file=sys.stderr)
|
||||||
|
|
||||||
|
|
||||||
|
def get_available_vram():
    """Estimate how many bytes of memory can still be allocated.

    On a CUDA device the estimate is the free device memory reported by
    ``torch.cuda.mem_get_info`` plus the memory PyTorch's caching allocator
    has reserved but is not actively using (``reserved - active``).
    Note that the allocator statistics are read for ``shared.device`` while
    the free-memory query targets ``torch.cuda.current_device()``.

    On any other device type the available system memory reported by
    ``psutil.virtual_memory()`` is returned instead.
    """
    if shared.device.type != 'cuda':
        return psutil.virtual_memory().available

    allocator_stats = torch.cuda.memory_stats(shared.device)
    free_on_device, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
    # memory held by the caching allocator that is not backing a live tensor
    reusable_in_allocator = (
        allocator_stats['reserved_bytes.all.current']
        - allocator_stats['active_bytes.all.current']
    )
    return free_on_device + reusable_in_allocator
|
||||||
|
|
||||||
|
|
||||||
# see https://github.com/basujindal/stable-diffusion/pull/117 for discussion
|
# see https://github.com/basujindal/stable-diffusion/pull/117 for discussion
|
||||||
def split_cross_attention_forward_v1(self, x, context=None, mask=None):
|
def split_cross_attention_forward_v1(self, x, context=None, mask=None):
|
||||||
h = self.heads
|
h = self.heads
|
||||||
@ -76,12 +91,7 @@ def split_cross_attention_forward(self, x, context=None, mask=None):
|
|||||||
|
|
||||||
r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
|
r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
|
||||||
|
|
||||||
stats = torch.cuda.memory_stats(q.device)
|
mem_free_total = get_available_vram()
|
||||||
mem_active = stats['active_bytes.all.current']
|
|
||||||
mem_reserved = stats['reserved_bytes.all.current']
|
|
||||||
mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
|
|
||||||
mem_free_torch = mem_reserved - mem_active
|
|
||||||
mem_free_total = mem_free_cuda + mem_free_torch
|
|
||||||
|
|
||||||
gb = 1024 ** 3
|
gb = 1024 ** 3
|
||||||
tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size()
|
tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size()
|
||||||
@ -118,19 +128,8 @@ def split_cross_attention_forward(self, x, context=None, mask=None):
|
|||||||
return self.to_out(r2)
|
return self.to_out(r2)
|
||||||
|
|
||||||
|
|
||||||
def check_for_psutil():
|
|
||||||
try:
|
|
||||||
spec = importlib.util.find_spec('psutil')
|
|
||||||
return spec is not None
|
|
||||||
except ModuleNotFoundError:
|
|
||||||
return False
|
|
||||||
|
|
||||||
invokeAI_mps_available = check_for_psutil()
|
|
||||||
|
|
||||||
# -- Taken from https://github.com/invoke-ai/InvokeAI and modified --
|
# -- Taken from https://github.com/invoke-ai/InvokeAI and modified --
|
||||||
if invokeAI_mps_available:
|
mem_total_gb = psutil.virtual_memory().total // (1 << 30)
|
||||||
import psutil
|
|
||||||
mem_total_gb = psutil.virtual_memory().total // (1 << 30)
|
|
||||||
|
|
||||||
def einsum_op_compvis(q, k, v):
|
def einsum_op_compvis(q, k, v):
|
||||||
s = einsum('b i d, b j d -> b i j', q, k)
|
s = einsum('b i d, b j d -> b i j', q, k)
|
||||||
@ -215,6 +214,71 @@ def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None):
|
|||||||
|
|
||||||
# -- End of code from https://github.com/invoke-ai/InvokeAI --
|
# -- End of code from https://github.com/invoke-ai/InvokeAI --
|
||||||
|
|
||||||
|
|
||||||
|
# Based on Birch-san's modified implementation of sub-quadratic attention from https://github.com/Birch-san/diffusers/pull/1
|
||||||
|
# The sub_quad_attention_forward function is under the MIT License listed under Memory Efficient Attention in the Licenses section of the web UI interface
|
||||||
|
def sub_quad_attention_forward(self, x, context=None, mask=None):
    """Replacement ``CrossAttention.forward`` that uses sub-quadratic attention.

    Projects ``x`` (and ``context``, which defaults to ``x``) to queries, keys
    and values, applies the currently loaded hypernetwork to the context,
    folds the head dimension into the batch dimension, runs
    ``sub_quad_attention`` with the chunk sizes taken from the command line
    options, then unfolds the heads again and applies the output projection
    and dropout from ``self.to_out``.

    Masks are not supported; passing one trips the assertion below.
    """
    assert mask is None, "attention-mask not currently implemented for SubQuadraticCrossAttnProcessor."

    heads = self.heads

    def fold_heads(t):
        # presumably (batch, tokens, heads * dim) -> (batch * heads, tokens, dim) — TODO confirm input rank
        return t.unflatten(-1, (heads, -1)).transpose(1,2).flatten(end_dim=1)

    q = self.to_q(x)
    context = default(context, x)

    context_k, context_v = hypernetwork.apply_hypernetwork(shared.loaded_hypernetwork, context)
    k = self.to_k(context_k)
    v = self.to_v(context_v)
    # drop references to the inputs before the attention computation
    del context, context_k, context_v, x

    # one tensor at a time so each un-folded projection is released as soon as it is replaced
    q = fold_heads(q)
    k = fold_heads(k)
    v = fold_heads(v)

    opts = shared.cmd_opts
    attended = sub_quad_attention(
        q,
        k,
        v,
        q_chunk_size=opts.sub_quad_q_chunk_size,
        kv_chunk_size=opts.sub_quad_kv_chunk_size,
        chunk_threshold=opts.sub_quad_chunk_threshold,
        use_checkpoint=self.training,
    )

    # undo the head folding performed above
    attended = attended.unflatten(0, (-1, heads)).transpose(1,2).flatten(start_dim=2)

    out_proj, dropout = self.to_out
    return dropout(out_proj(attended))
|
||||||
|
|
||||||
|
def sub_quad_attention(q, k, v, q_chunk_size=1024, kv_chunk_size=None, kv_chunk_size_min=None, chunk_threshold=None, use_checkpoint=True):
    """Choose chunking parameters from available memory and run chunked attention.

    Args:
        q, k, v: 3-D tensors; the first dimension is unpacked as
            ``batch_x_heads`` and the second as the token count.
        q_chunk_size: query chunk size forwarded to
            ``efficient_dot_product_attention``.
        kv_chunk_size: key/value chunk size; overridden with the full key
            length when the whole QK^T matmul fits under the threshold.
        kv_chunk_size_min: lower bound for the key/value chunk size.
            ``None`` derives it from the byte threshold (when one is set),
            ``0`` disables the lower bound.
        chunk_threshold: ``None`` uses 70% of the available memory (90% on
            ``mps``), ``0`` disables the threshold entirely, any other value
            is interpreted as a percentage of the available memory.
        use_checkpoint: forwarded to ``efficient_dot_product_attention``.

    Returns:
        Whatever ``efficient_dot_product_attention`` returns for the chosen
        chunk sizes.
    """
    bytes_per_token = torch.finfo(q.dtype).bits//8
    batch_x_heads, q_tokens, _ = q.shape
    _, k_tokens, _ = k.shape
    # size of the full (un-chunked) attention score matrix
    qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens

    if chunk_threshold is None:
        chunk_threshold_bytes = int(get_available_vram() * 0.9) if q.device.type == 'mps' else int(get_available_vram() * 0.7)
    elif chunk_threshold == 0:
        chunk_threshold_bytes = None
    else:
        chunk_threshold_bytes = int(0.01 * chunk_threshold * get_available_vram())

    if kv_chunk_size_min is None and chunk_threshold_bytes is not None:
        kv_chunk_size_min = chunk_threshold_bytes // (batch_x_heads * bytes_per_token * (k.shape[2] + v.shape[2]))
    elif kv_chunk_size_min == 0:
        kv_chunk_size_min = None

    if chunk_threshold_bytes is not None and qk_matmul_size_bytes <= chunk_threshold_bytes:
        # the big matmul fits into our memory limit; disable key/value
        # chunking so each query chunk takes the unchunked fast-path.
        # Query chunking is still governed by q_chunk_size: the previous
        # `query_chunk_size = q_tokens` assignment here was a dead store
        # that was never passed on, so it has been removed.
        kv_chunk_size = k_tokens

    return efficient_dot_product_attention(
        q,
        k,
        v,
        query_chunk_size=q_chunk_size,
        kv_chunk_size=kv_chunk_size,
        kv_chunk_size_min=kv_chunk_size_min,
        use_checkpoint=use_checkpoint,
    )
|
||||||
|
|
||||||
|
|
||||||
def xformers_attention_forward(self, x, context=None, mask=None):
|
def xformers_attention_forward(self, x, context=None, mask=None):
|
||||||
h = self.heads
|
h = self.heads
|
||||||
q_in = self.to_q(x)
|
q_in = self.to_q(x)
|
||||||
@ -252,12 +316,7 @@ def cross_attention_attnblock_forward(self, x):
|
|||||||
|
|
||||||
h_ = torch.zeros_like(k, device=q.device)
|
h_ = torch.zeros_like(k, device=q.device)
|
||||||
|
|
||||||
stats = torch.cuda.memory_stats(q.device)
|
mem_free_total = get_available_vram()
|
||||||
mem_active = stats['active_bytes.all.current']
|
|
||||||
mem_reserved = stats['reserved_bytes.all.current']
|
|
||||||
mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
|
|
||||||
mem_free_torch = mem_reserved - mem_active
|
|
||||||
mem_free_total = mem_free_cuda + mem_free_torch
|
|
||||||
|
|
||||||
tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size()
|
tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size()
|
||||||
mem_required = tensor_size * 2.5
|
mem_required = tensor_size * 2.5
|
||||||
@ -312,3 +371,19 @@ def xformers_attnblock_forward(self, x):
|
|||||||
return x + out
|
return x + out
|
||||||
except NotImplementedError:
|
except NotImplementedError:
|
||||||
return cross_attention_attnblock_forward(self, x)
|
return cross_attention_attnblock_forward(self, x)
|
||||||
|
|
||||||
|
def sub_quad_attnblock_forward(self, x):
    """Replacement ``AttnBlock.forward`` that uses sub-quadratic attention.

    Normalises ``x``, projects it to q/k/v, flattens the two spatial axes
    into a single token axis, runs ``sub_quad_attention`` with the chunk
    sizes taken from the command line options, restores the spatial layout,
    applies ``self.proj_out`` and adds the result to the input (residual).
    """
    normed = self.norm(x)
    q = self.q(normed)
    k = self.k(normed)
    v = self.v(normed)

    # only the height is needed later; the unpack also requires q to be 4-D
    _, _, height, _ = q.shape

    to_tokens = 'b c h w -> b (h w) c'
    q = rearrange(q, to_tokens).contiguous()
    k = rearrange(k, to_tokens).contiguous()
    v = rearrange(v, to_tokens).contiguous()

    opts = shared.cmd_opts
    out = sub_quad_attention(
        q,
        k,
        v,
        q_chunk_size=opts.sub_quad_q_chunk_size,
        kv_chunk_size=opts.sub_quad_kv_chunk_size,
        chunk_threshold=opts.sub_quad_chunk_threshold,
        use_checkpoint=self.training,
    )

    out = rearrange(out, 'b (h w) c -> b c h w', h=height)
    return x + self.proj_out(out)
|
||||||
|
@ -56,6 +56,10 @@ parser.add_argument("--xformers", action='store_true', help="enable xformers for
|
|||||||
parser.add_argument("--force-enable-xformers", action='store_true', help="enable xformers for cross attention layers regardless of whether the checking code thinks you can run it; do not make bug reports if this fails to work")
|
parser.add_argument("--force-enable-xformers", action='store_true', help="enable xformers for cross attention layers regardless of whether the checking code thinks you can run it; do not make bug reports if this fails to work")
|
||||||
parser.add_argument("--deepdanbooru", action='store_true', help="does not do anything")
|
parser.add_argument("--deepdanbooru", action='store_true', help="does not do anything")
|
||||||
parser.add_argument("--opt-split-attention", action='store_true', help="force-enables Doggettx's cross-attention layer optimization. By default, it's on for torch cuda.")
|
parser.add_argument("--opt-split-attention", action='store_true', help="force-enables Doggettx's cross-attention layer optimization. By default, it's on for torch cuda.")
|
||||||
|
parser.add_argument("--opt-sub-quad-attention", action='store_true', help="enable memory efficient sub-quadratic cross-attention layer optimization")
|
||||||
|
parser.add_argument("--sub-quad-q-chunk-size", type=int, help="query chunk size for the sub-quadratic cross-attention layer optimization to use", default=1024)
|
||||||
|
parser.add_argument("--sub-quad-kv-chunk-size", type=int, help="kv chunk size for the sub-quadratic cross-attention layer optimization to use", default=None)
|
||||||
|
parser.add_argument("--sub-quad-chunk-threshold", type=int, help="the percentage of VRAM threshold for the sub-quadratic cross-attention layer optimization to use chunking", default=None)
|
||||||
parser.add_argument("--opt-split-attention-invokeai", action='store_true', help="force-enables InvokeAI's cross-attention layer optimization. By default, it's on when cuda is unavailable.")
|
parser.add_argument("--opt-split-attention-invokeai", action='store_true', help="force-enables InvokeAI's cross-attention layer optimization. By default, it's on when cuda is unavailable.")
|
||||||
parser.add_argument("--opt-split-attention-v1", action='store_true', help="enable older version of split attention optimization that does not consume all the VRAM it can find")
|
parser.add_argument("--opt-split-attention-v1", action='store_true', help="enable older version of split attention optimization that does not consume all the VRAM it can find")
|
||||||
parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization")
|
parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization")
|
||||||
|
205
modules/sub_quadratic_attention.py
Normal file
205
modules/sub_quadratic_attention.py
Normal file
@ -0,0 +1,205 @@
|
|||||||
|
# original source:
|
||||||
|
# https://github.com/AminRezaei0x443/memory-efficient-attention/blob/1bc0d9e6ac5f82ea43a375135c4e1d3896ee1694/memory_efficient_attention/attention_torch.py
|
||||||
|
# license:
|
||||||
|
# MIT License (see Memory Efficient Attention under the Licenses section in the web UI interface for the full license)
|
||||||
|
# credit:
|
||||||
|
# Amin Rezaei (original author)
|
||||||
|
# Alex Birch (optimized algorithm for 3D tensors, at the expense of removing bias, masking and callbacks)
|
||||||
|
# brkirch (modified to use torch.narrow instead of dynamic_slice implementation)
|
||||||
|
# implementation of:
|
||||||
|
#   "Self-attention Does Not Need O(n^2) Memory":
|
||||||
|
# https://arxiv.org/abs/2112.05682v2
|
||||||
|
|
||||||
|
from functools import partial
|
||||||
|
import torch
|
||||||
|
from torch import Tensor
|
||||||
|
from torch.utils.checkpoint import checkpoint
|
||||||
|
import math
|
||||||
|
from typing import Optional, NamedTuple, Protocol, List
|
||||||
|
|
||||||
|
def narrow_trunc(
    input: Tensor,
    dim: int,
    start: int,
    length: int
) -> Tensor:
    """Like ``torch.narrow``, but clamp ``length`` to what is left after ``start``.

    When ``start + length`` would run past the end of ``input`` along ``dim``,
    the slice is shortened to end at the last element instead.
    """
    remaining = input.shape[dim] - start
    if remaining >= length:
        span = length
    else:
        span = remaining
    return torch.narrow(input, dim, start, span)
|
||||||
|
|
||||||
|
class AttnChunk(NamedTuple):
    """Partial attention result for one key/value chunk (see ``_summarize_chunk``)."""
    # exp(scores - max_score) multiplied with the value chunk
    exp_values: Tensor
    # exp(scores - max_score) summed over the key axis
    exp_weights_sum: Tensor
    # per-query maximum score of this chunk, used to rescale chunks against each other
    max_score: Tensor
|
||||||
|
|
||||||
|
class SummarizeChunk(Protocol):
    """Structural type of a callable that summarises one key/value chunk.

    Matches ``_summarize_chunk`` once its ``scale`` argument has been bound.
    """
    @staticmethod
    def __call__(
        query: Tensor,
        key: Tensor,
        value: Tensor,
    ) -> AttnChunk: ...
|
||||||
|
|
||||||
|
class ComputeQueryChunkAttn(Protocol):
    """Structural type of a callable that computes attention for one query chunk.

    Both attention strategies used by ``efficient_dot_product_attention``
    are bound to this shape with ``functools.partial``.
    """
    @staticmethod
    def __call__(
        query: Tensor,
        key: Tensor,
        value: Tensor,
    ) -> Tensor: ...
|
||||||
|
|
||||||
|
def _summarize_chunk(
    query: Tensor,
    key: Tensor,
    value: Tensor,
    scale: float,
) -> AttnChunk:
    """Compute the un-normalised softmax contribution of one key/value chunk.

    Returns the exponentiated scores applied to ``value``, the row sums of the
    exponentiated scores and the per-row maximum score that was subtracted
    before exponentiation, so that chunks can be rescaled and combined later.
    """
    # beta=0 means the placeholder's contents are ignored; baddbmm is only
    # used to get a scaled batched matmul
    placeholder = torch.empty(1, 1, 1, device=query.device, dtype=query.dtype)
    scores = torch.baddbmm(placeholder, query, key.transpose(1,2), alpha=scale, beta=0)

    row_max, _ = torch.max(scores, -1, keepdim=True)
    # the maximum is only a stabilising shift, so keep it out of the graph
    row_max = row_max.detach()

    exp_scores = torch.exp(scores - row_max)
    weighted_values = torch.bmm(exp_scores, value)
    return AttnChunk(weighted_values, exp_scores.sum(dim=-1), row_max.squeeze(-1))
|
||||||
|
|
||||||
|
def _query_chunk_attention(
    query: Tensor,
    key: Tensor,
    value: Tensor,
    summarize_chunk: SummarizeChunk,
    kv_chunk_size: int,
) -> Tensor:
    """Attend one query chunk to all keys/values, processing them chunk by chunk.

    Each key/value chunk is summarised independently by ``summarize_chunk``;
    the summaries are then rescaled to a common maximum score and combined
    into the normalised attention output.
    """
    # both unpacks also require key and value to be 3-D
    _, k_tokens, _ = key.shape
    _, _, _ = value.shape

    # summaries are handled positionally (values, weight sums, maxima) so that
    # plain tuples are accepted as well as AttnChunk instances
    summaries: List[AttnChunk] = [
        summarize_chunk(
            query,
            narrow_trunc(key, 1, offset, kv_chunk_size),
            narrow_trunc(value, 1, offset, kv_chunk_size),
        )
        for offset in torch.arange(0, k_tokens, kv_chunk_size)
    ]
    per_chunk_values, per_chunk_weights, per_chunk_max = zip(*summaries)
    chunk_values = torch.stack(per_chunk_values)
    chunk_weights = torch.stack(per_chunk_weights)
    chunk_max = torch.stack(per_chunk_max)

    # bring every chunk onto the scale of the largest maximum across chunks
    global_max, _ = torch.max(chunk_max, 0, keepdim=True)
    rescale = torch.exp(chunk_max - global_max)
    chunk_values *= rescale.unsqueeze(-1)
    chunk_weights *= rescale

    numerator = chunk_values.sum(dim=0)
    denominator = chunk_weights.unsqueeze(-1).sum(dim=0)
    return numerator / denominator
|
||||||
|
|
||||||
|
# TODO: refactor CrossAttention#get_attention_scores to share code with this
|
||||||
|
def _get_attention_scores_no_kv_chunking(
|
||||||
|
query: Tensor,
|
||||||
|
key: Tensor,
|
||||||
|
value: Tensor,
|
||||||
|
scale: float,
|
||||||
|
) -> Tensor:
|
||||||
|
attn_scores = torch.baddbmm(
|
||||||
|
torch.empty(1, 1, 1, device=query.device, dtype=query.dtype),
|
||||||
|
query,
|
||||||
|
key.transpose(1,2),
|
||||||
|
alpha=scale,
|
||||||
|
beta=0,
|
||||||
|
)
|
||||||
|
attn_probs = attn_scores.softmax(dim=-1)
|
||||||
|
del attn_scores
|
||||||
|
hidden_states_slice = torch.bmm(attn_probs, value)
|
||||||
|
return hidden_states_slice
|
||||||
|
|
||||||
|
class ScannedChunk(NamedTuple):
    """Pairs a chunk index with the ``AttnChunk`` summary computed for it.

    NOTE(review): not referenced by any function visible in this module —
    confirm whether it is still needed.
    """
    chunk_idx: int
    attn_chunk: AttnChunk
|
||||||
|
|
||||||
|
def efficient_dot_product_attention(
    query: Tensor,
    key: Tensor,
    value: Tensor,
    query_chunk_size=1024,
    kv_chunk_size: Optional[int] = None,
    kv_chunk_size_min: Optional[int] = None,
    use_checkpoint=True,
):
    """Compute dot-product attention in chunks to bound peak memory.

    This is the memory-efficient attention presented in
    https://arxiv.org/abs/2112.05682v2, which the paper describes as
    needing O(sqrt(n)) memory.

    Args:
        query: queries for calculating attention with shape of
            `[batch * num_heads, tokens, channels_per_head]`.
        key: keys for calculating attention with shape of
            `[batch * num_heads, tokens, channels_per_head]`.
        value: values to be used in attention with shape of
            `[batch * num_heads, tokens, channels_per_head]`.
        query_chunk_size: int: number of query tokens processed per chunk.
        kv_chunk_size: Optional[int]: key/value chunk size. If None (or 0),
            defaults to sqrt(key_tokens). Never larger than key_tokens unless
            raised by `kv_chunk_size_min`.
        kv_chunk_size_min: Optional[int]: lower bound applied to the final
            key/value chunk size whenever it is not None, to ensure chunks
            don't get too small (smaller chunks = more chunks = less
            concurrent work done).
        use_checkpoint: bool: whether to wrap the per-chunk summary in
            `torch.utils.checkpoint.checkpoint` (recommended True for
            training, False for inference).

    Returns:
        Output of shape `[batch * num_heads, query_tokens, channels_per_head]`.
    """
    _, q_tokens, q_channels_per_head = query.shape
    _, k_tokens, _ = key.shape
    scale = q_channels_per_head ** -0.5

    if not kv_chunk_size:
        kv_chunk_size = int(math.sqrt(k_tokens))
    kv_chunk_size = min(kv_chunk_size, k_tokens)
    if kv_chunk_size_min is not None:
        kv_chunk_size = max(kv_chunk_size, kv_chunk_size_min)

    compute_query_chunk_attn: ComputeQueryChunkAttn
    if k_tokens <= kv_chunk_size:
        # fast-path for when there's just 1 key-value chunk per query chunk
        # (this is just sliced attention)
        compute_query_chunk_attn = partial(
            _get_attention_scores_no_kv_chunking,
            scale=scale,
        )
    else:
        summarize_chunk: SummarizeChunk = partial(_summarize_chunk, scale=scale)
        if use_checkpoint:
            summarize_chunk = partial(checkpoint, summarize_chunk)
        compute_query_chunk_attn = partial(
            _query_chunk_attention,
            kv_chunk_size=kv_chunk_size,
            summarize_chunk=summarize_chunk,
        )

    if q_tokens <= query_chunk_size:
        # fast-path for when there's just 1 query chunk
        return compute_query_chunk_attn(query=query, key=key, value=value)

    # TODO: maybe we should use torch.empty_like(query) to allocate storage in-advance,
    # and pass slices to be mutated, instead of torch.cat()ing the returned slices
    q_span = min(query_chunk_size, q_tokens)
    num_query_chunks = math.ceil(q_tokens / query_chunk_size)
    return torch.cat([
        compute_query_chunk_attn(
            query=narrow_trunc(query, 1, chunk_no * query_chunk_size, q_span),
            key=key,
            value=value,
        )
        for chunk_no in range(num_query_chunks)
    ], dim=1)
|
@ -30,4 +30,4 @@ inflection
|
|||||||
GitPython
|
GitPython
|
||||||
torchsde
|
torchsde
|
||||||
safetensors
|
safetensors
|
||||||
psutil; sys_platform == 'darwin'
|
psutil
|
||||||
|
Loading…
Reference in New Issue
Block a user