diff --git a/modules/launch_utils.py b/modules/launch_utils.py
index 94b2322a8..449a8755e 100644
--- a/modules/launch_utils.py
+++ b/modules/launch_utils.py
@@ -173,6 +173,9 @@ def git_clone(url, dir, name, commithash=None):
         if current_hash == commithash:
             return
 
+        if run_git(dir, name, 'config --get remote.origin.url', None, f"Couldn't determine {name}'s origin URL", live=False).strip() != url:
+            run_git(dir, name, f'remote set-url origin "{url}"', None, f"Failed to set {name}'s origin URL", live=False)
+
         run_git(dir, name, 'fetch', f"Fetching updates for {name}...", f"Couldn't fetch {name}", autofix=False)
 
         run_git(dir, name, f'checkout {commithash}', f"Checking out commit for {name} with hash: {commithash}...", f"Couldn't checkout commit {commithash} for {name}", live=True)
diff --git a/modules/mac_specific.py b/modules/mac_specific.py
index bce527ccc..89256c5b0 100644
--- a/modules/mac_specific.py
+++ b/modules/mac_specific.py
@@ -52,9 +52,6 @@ def cumsum_fix(input, cumsum_func, *args, **kwargs):
 
 
 if has_mps:
-    # MPS fix for randn in torchsde
-    CondFunc('torchsde._brownian.brownian_interval._randn', lambda _, size, dtype, device, seed: torch.randn(size, dtype=dtype, device=torch.device("cpu"), generator=torch.Generator(torch.device("cpu")).manual_seed(int(seed))).to(device), lambda _, size, dtype, device, seed: device.type == 'mps')
-
     if platform.mac_ver()[0].startswith("13.2."):
         # MPS workaround for https://github.com/pytorch/pytorch/issues/95188, thanks to danieldk (https://github.com/explosion/curated-transformers/pull/124)
         CondFunc('torch.nn.functional.linear', lambda _, input, weight, bias: (torch.matmul(input, weight.t()) + bias) if bias is not None else torch.matmul(input, weight.t()), lambda _, input, weight, bias: input.numel() > 10485760)
diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py
index 0e810eec8..7f9e328d0 100644
--- a/modules/sd_hijack_optimizations.py
+++ b/modules/sd_hijack_optimizations.py
@@ -1,6 +1,7 @@
 from __future__ import annotations
 import math
 import psutil
+import platform
 
 import torch
 from torch import einsum
@@ -94,7 +95,10 @@ class SdOptimizationSdp(SdOptimizationSdpNoMem):
 class SdOptimizationSubQuad(SdOptimization):
     name = "sub-quadratic"
     cmd_opt = "opt_sub_quad_attention"
-    priority = 10
+
+    @property
+    def priority(self):
+        return 1000 if shared.device.type == 'mps' else 10
 
     def apply(self):
         ldm.modules.attention.CrossAttention.forward = sub_quad_attention_forward
@@ -120,7 +124,7 @@ class SdOptimizationInvokeAI(SdOptimization):
 
     @property
     def priority(self):
-        return 1000 if not torch.cuda.is_available() else 10
+        return 1000 if shared.device.type != 'mps' and not torch.cuda.is_available() else 10
 
     def apply(self):
         ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward_invokeAI
@@ -427,7 +431,10 @@ def sub_quad_attention(q, k, v, q_chunk_size=1024, kv_chunk_size=None, kv_chunk_
     qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens
 
     if chunk_threshold is None:
-        chunk_threshold_bytes = int(get_available_vram() * 0.9) if q.device.type == 'mps' else int(get_available_vram() * 0.7)
+        if q.device.type == 'mps':
+            chunk_threshold_bytes = 268435456 * (2 if platform.processor() == 'i386' else bytes_per_token)
+        else:
+            chunk_threshold_bytes = int(get_available_vram() * 0.7)
     elif chunk_threshold == 0:
         chunk_threshold_bytes = None
     else:
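
Note on the two `priority` changes in modules/sd_hijack_optimizations.py: webui picks the applicable cross-attention optimization with the highest priority, so on MPS the sub-quadratic path (priority 1000) now wins over InvokeAI's split attention, which previously won whenever CUDA was unavailable. The added `!= 'mps'` clause on the InvokeAI side keeps the two from tying at 1000 on MPS. A minimal sketch of that selection rule, with hypothetical names (`Opt`, `pick`) rather than the project's actual classes:

```python
from dataclasses import dataclass

@dataclass
class Opt:
    name: str
    priority: int

def pick(device_type: str, cuda_available: bool) -> str:
    # Mirror the two priority properties from the diff above.
    candidates = [
        Opt("sub-quadratic", 1000 if device_type == 'mps' else 10),
        Opt("InvokeAI", 1000 if device_type != 'mps' and not cuda_available else 10),
    ]
    # The highest-priority candidate wins.
    return max(candidates, key=lambda o: o.priority).name

assert pick('mps', cuda_available=False) == "sub-quadratic"  # MPS: sub-quad now wins
assert pick('cpu', cuda_available=False) == "InvokeAI"       # CPU-only: unchanged
```
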
diff --git a/modules/sd_samplers_timesteps_impl.py b/modules/sd_samplers_timesteps_impl.py
index 48d7e6491..d32e35213 100644
--- a/modules/sd_samplers_timesteps_impl.py
+++ b/modules/sd_samplers_timesteps_impl.py
@@ -11,7 +11,7 @@ from modules.models.diffusion.uni_pc import uni_pc
 def ddim(model, x, timesteps, extra_args=None, callback=None, disable=None, eta=0.0):
     alphas_cumprod = model.inner_model.inner_model.alphas_cumprod
     alphas = alphas_cumprod[timesteps]
-    alphas_prev = alphas_cumprod[torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))].to(torch.float64)
+    alphas_prev = alphas_cumprod[torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))].to(torch.float64 if x.device.type != 'mps' else torch.float32)
     sqrt_one_minus_alphas = torch.sqrt(1 - alphas)
     sigmas = eta * np.sqrt((1 - alphas_prev.cpu().numpy()) / (1 - alphas.cpu()) * (1 - alphas.cpu() / alphas_prev.cpu().numpy()))
 
@@ -42,7 +42,7 @@ def ddim(model, x, timesteps, extra_args=None, callback=None, disable=None, eta=0.0):
 def plms(model, x, timesteps, extra_args=None, callback=None, disable=None):
     alphas_cumprod = model.inner_model.inner_model.alphas_cumprod
     alphas = alphas_cumprod[timesteps]
-    alphas_prev = alphas_cumprod[torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))].to(torch.float64)
+    alphas_prev = alphas_cumprod[torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))].to(torch.float64 if x.device.type != 'mps' else torch.float32)
     sqrt_one_minus_alphas = torch.sqrt(1 - alphas)
 
     extra_args = {} if extra_args is None else extra_args
diff --git a/modules/sub_quadratic_attention.py b/modules/sub_quadratic_attention.py
index 497568eb5..ae4ee4bbe 100644
--- a/modules/sub_quadratic_attention.py
+++ b/modules/sub_quadratic_attention.py
@@ -58,7 +58,7 @@ def _summarize_chunk(
     scale: float,
 ) -> AttnChunk:
     attn_weights = torch.baddbmm(
-        torch.empty(1, 1, 1, device=query.device, dtype=query.dtype),
+        torch.zeros(1, 1, 1, device=query.device, dtype=query.dtype),
         query,
         key.transpose(1,2),
         alpha=scale,
@@ -121,7 +121,7 @@ def _get_attention_scores_no_kv_chunking(
     scale: float,
 ) -> Tensor:
     attn_scores = torch.baddbmm(
-        torch.empty(1, 1, 1, device=query.device, dtype=query.dtype),
+        torch.zeros(1, 1, 1, device=query.device, dtype=query.dtype),
         query,
         key.transpose(1,2),
         alpha=scale,
diff --git a/webui-macos-env.sh b/webui-macos-env.sh
index 6354e73ba..24bc5c426 100644
--- a/webui-macos-env.sh
+++ b/webui-macos-env.sh
@@ -12,8 +12,6 @@ fi
 export install_dir="$HOME"
 export COMMANDLINE_ARGS="--skip-torch-cuda-test --upcast-sampling --no-half-vae --use-cpu interrogate"
 export TORCH_COMMAND="pip install torch==2.0.1 torchvision==0.15.2"
-export K_DIFFUSION_REPO="https://github.com/brkirch/k-diffusion.git"
-export K_DIFFUSION_COMMIT_HASH="51c9778f269cedb55a4d88c79c0246d35bdadb71"
 export PYTORCH_ENABLE_MPS_FALLBACK=1
 
 ####################################################################
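
Note on the `torch.empty` to `torch.zeros` change in modules/sub_quadratic_attention.py: `torch.baddbmm(input, batch1, batch2, beta=..., alpha=...)` computes `beta * input + alpha * (batch1 @ batch2)`, and these call sites pass `beta=0` (on the line just past the hunk context), so the seed tensor's values should be mathematically irrelevant. But `torch.empty` returns uninitialized memory that can contain NaN or inf, and not every backend reliably masks those out at `beta=0`; seeding with zeros guarantees defined values at negligible cost. The `float64` to `float32` switch in the timestep samplers is in the same spirit, since the MPS backend does not support `torch.float64`. A small self-contained sketch of the baddbmm pattern, with illustrative shapes, runnable on CPU:

```python
import torch

# beta * seed + alpha * (q @ k^T); with beta=0 the seed's values should be
# ignored, but torch.empty() yields uninitialized memory that may contain
# NaN/inf, and some backends have propagated them anyway. torch.zeros()
# guarantees defined values.
q = torch.randn(2, 4, 8)   # (batch * heads, q_tokens, head_dim)
k = torch.randn(2, 6, 8)   # (batch * heads, k_tokens, head_dim)
scale = q.shape[-1] ** -0.5

seed = torch.zeros(1, 1, 1, dtype=q.dtype)  # broadcasts to (2, 4, 6)
attn_scores = torch.baddbmm(seed, q, k.transpose(1, 2), beta=0, alpha=scale)
print(attn_scores.shape)  # torch.Size([2, 4, 6])
```
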