From ac4ccfa1369e74492b467294eab96c3f558b297b Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Thu, 13 Jul 2023 09:30:33 +0300 Subject: [PATCH] get attention optimizations to work --- modules/hypernetworks/hypernetwork.py | 2 +- modules/launch_utils.py | 1 + modules/sd_hijack_optimizations.py | 14 +++++++------- modules/sd_models_xl.py | 3 +++ 4 files changed, 12 insertions(+), 8 deletions(-) diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 79670b877..c4821d21a 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -378,7 +378,7 @@ def apply_hypernetworks(hypernetworks, context, layer=None): return context_k, context_v -def attention_CrossAttention_forward(self, x, context=None, mask=None): +def attention_CrossAttention_forward(self, x, context=None, mask=None, **kwargs): h = self.heads q = self.to_q(x) diff --git a/modules/launch_utils.py b/modules/launch_utils.py index 56b972d51..183730d29 100644 --- a/modules/launch_utils.py +++ b/modules/launch_utils.py @@ -239,6 +239,7 @@ def mute_sdxl_imports(): sys.modules['sgm.data'] = module + def prepare_environment(): torch_index_url = os.environ.get('TORCH_INDEX_URL', "https://download.pytorch.org/whl/cu118") torch_command = os.environ.get('TORCH_COMMAND', f"pip install torch==2.0.1 torchvision==0.15.2 --extra-index-url {torch_index_url}") diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py index e99c9ba5f..b5f85ba51 100644 --- a/modules/sd_hijack_optimizations.py +++ b/modules/sd_hijack_optimizations.py @@ -173,7 +173,7 @@ def get_available_vram(): # see https://github.com/basujindal/stable-diffusion/pull/117 for discussion -def split_cross_attention_forward_v1(self, x, context=None, mask=None, additional_tokens=None, n_times_crossframe_attn_in_self=0): +def split_cross_attention_forward_v1(self, x, context=None, mask=None, **kwargs): h = self.heads q_in = self.to_q(x) @@ -214,7 +214,7 @@ def split_cross_attention_forward_v1(self, x, context=None, mask=None, additiona # taken from https://github.com/Doggettx/stable-diffusion and modified -def split_cross_attention_forward(self, x, context=None, mask=None, additional_tokens=None, n_times_crossframe_attn_in_self=0): +def split_cross_attention_forward(self, x, context=None, mask=None, **kwargs): h = self.heads q_in = self.to_q(x) @@ -355,7 +355,7 @@ def einsum_op(q, k, v): return einsum_op_tensor_mem(q, k, v, 32) -def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None, additional_tokens=None, n_times_crossframe_attn_in_self=0): +def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None, **kwargs): h = self.heads q = self.to_q(x) @@ -383,7 +383,7 @@ def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None, add # Based on Birch-san's modified implementation of sub-quadratic attention from https://github.com/Birch-san/diffusers/pull/1 # The sub_quad_attention_forward function is under the MIT License listed under Memory Efficient Attention in the Licenses section of the web UI interface -def sub_quad_attention_forward(self, x, context=None, mask=None, additional_tokens=None, n_times_crossframe_attn_in_self=0): +def sub_quad_attention_forward(self, x, context=None, mask=None, **kwargs): assert mask is None, "attention-mask not currently implemented for SubQuadraticCrossAttnProcessor." h = self.heads @@ -470,7 +470,7 @@ def get_xformers_flash_attention_op(q, k, v): return None -def xformers_attention_forward(self, x, context=None, mask=None, additional_tokens=None, n_times_crossframe_attn_in_self=0): +def xformers_attention_forward(self, x, context=None, mask=None, **kwargs): h = self.heads q_in = self.to_q(x) context = default(context, x) @@ -496,7 +496,7 @@ def xformers_attention_forward(self, x, context=None, mask=None, additional_toke # Based on Diffusers usage of scaled dot product attention from https://github.com/huggingface/diffusers/blob/c7da8fd23359a22d0df2741688b5b4f33c26df21/src/diffusers/models/cross_attention.py # The scaled_dot_product_attention_forward function contains parts of code under Apache-2.0 license listed under Scaled Dot Product Attention in the Licenses section of the web UI interface -def scaled_dot_product_attention_forward(self, x, context=None, mask=None, additional_tokens=None, n_times_crossframe_attn_in_self=0): +def scaled_dot_product_attention_forward(self, x, context=None, mask=None, **kwargs): batch_size, sequence_length, inner_dim = x.shape if mask is not None: @@ -537,7 +537,7 @@ def scaled_dot_product_attention_forward(self, x, context=None, mask=None, addit return hidden_states -def scaled_dot_product_no_mem_attention_forward(self, x, context=None, mask=None, additional_tokens=None, n_times_crossframe_attn_in_self=0): +def scaled_dot_product_no_mem_attention_forward(self, x, context=None, mask=None, **kwargs): with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=False): return scaled_dot_product_attention_forward(self, x, context, mask) diff --git a/modules/sd_models_xl.py b/modules/sd_models_xl.py index 9224c1a35..4d1aa497b 100644 --- a/modules/sd_models_xl.py +++ b/modules/sd_models_xl.py @@ -55,3 +55,6 @@ sgm.modules.diffusionmodules.model.print = lambda *args: None sgm.modules.diffusionmodules.openaimodel.print = lambda *args: None sgm.modules.encoders.modules.print = lambda *args: None +# this gets the code to load the vanilla attention that we override +sgm.modules.attention.SDP_IS_AVAILABLE = True +sgm.modules.attention.XFORMERS_IS_AVAILABLE = False \ No newline at end of file