From abfa4ad8bc995dcaf832c07a7cf75b6e295a8ca9 Mon Sep 17 00:00:00 2001
From: brkirch
Date: Mon, 8 May 2023 18:16:01 -0400
Subject: [PATCH 1/8] Use fixed size for sub-quadratic chunking on MPS

Even if this causes chunks to be much smaller, performance isn't
significantly impacted. This will usually reduce memory usage but should
also help with poor performance when free memory is low.
---
 modules/sd_hijack_optimizations.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py
index 0e810eec8..b3e712707 100644
--- a/modules/sd_hijack_optimizations.py
+++ b/modules/sd_hijack_optimizations.py
@@ -1,6 +1,7 @@
 from __future__ import annotations
 import math
 import psutil
+import platform
 
 import torch
 from torch import einsum
@@ -427,7 +428,10 @@ def sub_quad_attention(q, k, v, q_chunk_size=1024, kv_chunk_size=None, kv_chunk_
     qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens
 
     if chunk_threshold is None:
-        chunk_threshold_bytes = int(get_available_vram() * 0.9) if q.device.type == 'mps' else int(get_available_vram() * 0.7)
+        if q.device.type == 'mps':
+            chunk_threshold_bytes = 268435456 * (2 if platform.processor() == 'i386' else bytes_per_token)
+        else:
+            chunk_threshold_bytes = int(get_available_vram() * 0.7)
     elif chunk_threshold == 0:
         chunk_threshold_bytes = None
     else:

From 87dd685224b5f7dbbd832fc73cc08e7e470c9f28 Mon Sep 17 00:00:00 2001
From: brkirch
Date: Sun, 21 May 2023 05:00:27 -0400
Subject: [PATCH 2/8] Make sub-quadratic the default for MPS

---
 modules/sd_hijack_optimizations.py | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py
index b3e712707..7f9e328d0 100644
--- a/modules/sd_hijack_optimizations.py
+++ b/modules/sd_hijack_optimizations.py
@@ -95,7 +95,10 @@ class SdOptimizationSdp(SdOptimizationSdpNoMem):
 class SdOptimizationSubQuad(SdOptimization):
     name = "sub-quadratic"
     cmd_opt = "opt_sub_quad_attention"
-    priority = 10
+
+    @property
+    def priority(self):
+        return 1000 if shared.device.type == 'mps' else 10
 
     def apply(self):
         ldm.modules.attention.CrossAttention.forward = sub_quad_attention_forward
@@ -121,7 +124,7 @@ class SdOptimizationInvokeAI(SdOptimization):
 
     @property
     def priority(self):
-        return 1000 if not torch.cuda.is_available() else 10
+        return 1000 if shared.device.type != 'mps' and not torch.cuda.is_available() else 10
 
     def apply(self):
         ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward_invokeAI

From 2489252099c299bed49a9d4a39a4ead73b6b6f10 Mon Sep 17 00:00:00 2001
From: brkirch
Date: Tue, 25 Jul 2023 03:03:06 -0400
Subject: [PATCH 3/8] `torch.empty` can create issues; use `torch.zeros`

For MPS, using a tensor created with `torch.empty()` can cause
`torch.baddbmm()` to include NaNs in the tensor it returns, even though
`beta=0`. However, with a tensor of shape [1,1,1], there should be a
negligible performance difference between `torch.empty()` and
`torch.zeros()` anyway, so it's better to just use `torch.zeros()` for
this and avoid unnecessarily creating issues.
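
For illustration, a minimal sketch of the call pattern the hunks below
change (the tensors and shapes here are made up; only the first argument
to `torch.baddbmm()` differs from before). With `beta=0` the `input`
tensor should only matter for broadcasting, which is why a zero-filled
[1,1,1] tensor costs essentially nothing:

    import torch

    query = torch.rand(2, 4, 8)
    key = torch.rand(2, 4, 8)
    scale = query.shape[-1] ** -0.5

    attn_weights = torch.baddbmm(
        torch.zeros(1, 1, 1, device=query.device, dtype=query.dtype),  # was torch.empty(...)
        query,
        key.transpose(1, 2),
        alpha=scale,
        beta=0,  # with beta=0 only alpha * (query @ key^T) should remain
    )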
---
 modules/sub_quadratic_attention.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/modules/sub_quadratic_attention.py b/modules/sub_quadratic_attention.py
index 497568eb5..ae4ee4bbe 100644
--- a/modules/sub_quadratic_attention.py
+++ b/modules/sub_quadratic_attention.py
@@ -58,7 +58,7 @@ def _summarize_chunk(
     scale: float,
 ) -> AttnChunk:
     attn_weights = torch.baddbmm(
-        torch.empty(1, 1, 1, device=query.device, dtype=query.dtype),
+        torch.zeros(1, 1, 1, device=query.device, dtype=query.dtype),
         query,
         key.transpose(1,2),
         alpha=scale,
@@ -121,7 +121,7 @@ def _get_attention_scores_no_kv_chunking(
     scale: float,
 ) -> Tensor:
     attn_scores = torch.baddbmm(
-        torch.empty(1, 1, 1, device=query.device, dtype=query.dtype),
+        torch.zeros(1, 1, 1, device=query.device, dtype=query.dtype),
         query,
         key.transpose(1,2),
         alpha=scale,

From 9058620cec2788495d295f4e68ef2932d6d700e6 Mon Sep 17 00:00:00 2001
From: brkirch
Date: Sat, 12 Aug 2023 04:44:16 -0400
Subject: [PATCH 4/8] `git checkout` with commit hash

---
 modules/launch_utils.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/modules/launch_utils.py b/modules/launch_utils.py
index 4fc254a25..e77baa521 100644
--- a/modules/launch_utils.py
+++ b/modules/launch_utils.py
@@ -175,7 +175,7 @@ def git_clone(url, dir, name, commithash=None):
 
         run_git(dir, name, 'fetch', f"Fetching updates for {name}...", f"Couldn't fetch {name}", autofix=False)
 
-        run_git(dir, name, 'checkout', f"Checking out commit for {name} with hash: {commithash}...", f"Couldn't checkout commit {commithash} for {name}", live=True)
+        run_git(dir, name, f'checkout {commithash}', f"Checking out commit for {name} with hash: {commithash}...", f"Couldn't checkout commit {commithash} for {name}", live=True)
 
         return
 

From f4dbb0c820344798e3481d4104618b95594a3d10 Mon Sep 17 00:00:00 2001
From: brkirch
Date: Thu, 20 Jul 2023 01:44:45 -0400
Subject: [PATCH 5/8] Change the repositories origin URLs when necessary

---
 modules/launch_utils.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/modules/launch_utils.py b/modules/launch_utils.py
index e77baa521..9eda7c9d3 100644
--- a/modules/launch_utils.py
+++ b/modules/launch_utils.py
@@ -173,6 +173,9 @@ def git_clone(url, dir, name, commithash=None):
         if current_hash == commithash:
             return
 
+        if run_git(dir, name, 'config --get remote.origin.url', None, f"Couldn't determine {name}'s origin URL", live=False).strip() != url:
+            run_git(dir, name, f'remote set-url origin "{url}"', None, f"Failed to set {name}'s origin URL", live=False)
+
         run_git(dir, name, 'fetch', f"Fetching updates for {name}...", f"Couldn't fetch {name}", autofix=False)
 
         run_git(dir, name, f'checkout {commithash}', f"Checking out commit for {name} with hash: {commithash}...", f"Couldn't checkout commit {commithash} for {name}", live=True)
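
For reference, the origin-URL check added above corresponds roughly to
the following standalone sketch (plain `subprocess` instead of the
repo's `run_git` helper; `repo_dir` and `url` are placeholders, and no
error handling is shown):

    import subprocess

    def ensure_origin(repo_dir: str, url: str) -> None:
        # Read the current origin URL; an unset origin yields an empty string here.
        current = subprocess.run(
            ["git", "-C", repo_dir, "config", "--get", "remote.origin.url"],
            capture_output=True, text=True,
        ).stdout.strip()
        if current != url:
            # Point origin at the expected URL before fetching and checking out.
            subprocess.run(["git", "-C", repo_dir, "remote", "set-url", "origin", url], check=True)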

From 232c931f4082ea73bbaca8f77469cfea9d5db459 Mon Sep 17 00:00:00 2001
From: brkirch
Date: Mon, 7 Aug 2023 10:33:43 -0400
Subject: [PATCH 6/8] Mac k-diffusion workarounds are no longer needed

---
 webui-macos-env.sh | 2 --
 1 file changed, 2 deletions(-)

diff --git a/webui-macos-env.sh b/webui-macos-env.sh
index 6354e73ba..24bc5c426 100644
--- a/webui-macos-env.sh
+++ b/webui-macos-env.sh
@@ -12,8 +12,6 @@ fi
 export install_dir="$HOME"
 export COMMANDLINE_ARGS="--skip-torch-cuda-test --upcast-sampling --no-half-vae --use-cpu interrogate"
 export TORCH_COMMAND="pip install torch==2.0.1 torchvision==0.15.2"
-export K_DIFFUSION_REPO="https://github.com/brkirch/k-diffusion.git"
-export K_DIFFUSION_COMMIT_HASH="51c9778f269cedb55a4d88c79c0246d35bdadb71"
 export PYTORCH_ENABLE_MPS_FALLBACK=1
 
 ####################################################################

From 5df535b7c2374c3324485faaea62fbdbffc71f71 Mon Sep 17 00:00:00 2001
From: brkirch
Date: Mon, 7 Aug 2023 10:20:10 -0400
Subject: [PATCH 7/8] Remove duplicate code for torchsde randn

---
 modules/mac_specific.py | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/modules/mac_specific.py b/modules/mac_specific.py
index bce527ccc..89256c5b0 100644
--- a/modules/mac_specific.py
+++ b/modules/mac_specific.py
@@ -52,9 +52,6 @@ def cumsum_fix(input, cumsum_func, *args, **kwargs):
 
 
 if has_mps:
-    # MPS fix for randn in torchsde
-    CondFunc('torchsde._brownian.brownian_interval._randn', lambda _, size, dtype, device, seed: torch.randn(size, dtype=dtype, device=torch.device("cpu"), generator=torch.Generator(torch.device("cpu")).manual_seed(int(seed))).to(device), lambda _, size, dtype, device, seed: device.type == 'mps')
-
     if platform.mac_ver()[0].startswith("13.2."):
         # MPS workaround for https://github.com/pytorch/pytorch/issues/95188, thanks to danieldk (https://github.com/explosion/curated-transformers/pull/124)
         CondFunc('torch.nn.functional.linear', lambda _, input, weight, bias: (torch.matmul(input, weight.t()) + bias) if bias is not None else torch.matmul(input, weight.t()), lambda _, input, weight, bias: input.numel() > 10485760)

From 2035cbbd5d6e7678450c701fce1a5de7d8bd7084 Mon Sep 17 00:00:00 2001
From: brkirch
Date: Sat, 12 Aug 2023 06:01:36 -0400
Subject: [PATCH 8/8] Fix DDIM and PLMS samplers on MPS

---
 modules/sd_samplers_timesteps_impl.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/modules/sd_samplers_timesteps_impl.py b/modules/sd_samplers_timesteps_impl.py
index 48d7e6491..d32e35213 100644
--- a/modules/sd_samplers_timesteps_impl.py
+++ b/modules/sd_samplers_timesteps_impl.py
@@ -11,7 +11,7 @@ from modules.models.diffusion.uni_pc import uni_pc
 def ddim(model, x, timesteps, extra_args=None, callback=None, disable=None, eta=0.0):
     alphas_cumprod = model.inner_model.inner_model.alphas_cumprod
     alphas = alphas_cumprod[timesteps]
-    alphas_prev = alphas_cumprod[torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))].to(torch.float64)
+    alphas_prev = alphas_cumprod[torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))].to(torch.float64 if x.device.type != 'mps' else torch.float32)
     sqrt_one_minus_alphas = torch.sqrt(1 - alphas)
     sigmas = eta * np.sqrt((1 - alphas_prev.cpu().numpy()) / (1 - alphas.cpu()) * (1 - alphas.cpu() / alphas_prev.cpu().numpy()))
 
@@ -42,7 +42,7 @@ def plms(model, x, timesteps, extra_args=None, callback=None, disable=None):
     alphas_cumprod = model.inner_model.inner_model.alphas_cumprod
     alphas = alphas_cumprod[timesteps]
-    alphas_prev = alphas_cumprod[torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))].to(torch.float64)
+    alphas_prev = alphas_cumprod[torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))].to(torch.float64 if x.device.type != 'mps' else torch.float32)
     sqrt_one_minus_alphas = torch.sqrt(1 - alphas)
 
     extra_args = {} if extra_args is None else extra_args
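
For reference, the conditional dtype above is needed because the MPS
backend does not support float64 tensors; a rough sketch of the same
device-dependent choice (the helper name and tensors are made up):

    import torch

    def wide_float_dtype(device: torch.device) -> torch.dtype:
        # float64 where supported, float32 on MPS.
        return torch.float64 if device.type != 'mps' else torch.float32

    x = torch.zeros(4)  # stand-in for the sampler's input tensor
    alphas_prev = torch.rand(10).to(wide_float_dtype(x.device))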