Merge pull request #12526 from brkirch/mps-adjust-sub-quad

Fixes for `git checkout`, MPS/macOS fixes and optimizations

Commit 9cd0475c08 (https://github.com/AUTOMATIC1111/stable-diffusion-webui.git)
@@ -173,6 +173,9 @@ def git_clone(url, dir, name, commithash=None):
         if current_hash == commithash:
             return
 
+        if run_git(dir, name, 'config --get remote.origin.url', None, f"Couldn't determine {name}'s origin URL", live=False).strip() != url:
+            run_git(dir, name, f'remote set-url origin "{url}"', None, f"Failed to set {name}'s origin URL", live=False)
+
         run_git(dir, name, 'fetch', f"Fetching updates for {name}...", f"Couldn't fetch {name}", autofix=False)
 
         run_git(dir, name, f'checkout {commithash}', f"Checking out commit for {name} with hash: {commithash}...", f"Couldn't checkout commit {commithash} for {name}", live=True)
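The added lines repair clones whose origin remote no longer matches the configured URL, so the subsequent fetch can actually find the requested commit. A minimal standalone sketch of the same logic using subprocess directly instead of the repo's run_git helper (ensure_origin is a hypothetical name):

import subprocess

def ensure_origin(dir, url):
    # Read the URL currently configured for the 'origin' remote.
    current = subprocess.check_output(
        ['git', '-C', dir, 'config', '--get', 'remote.origin.url'], text=True
    ).strip()
    if current != url:
        # Repoint the clone so the following 'git fetch' pulls from the
        # intended repository instead of a stale fork.
        subprocess.check_call(['git', '-C', dir, 'remote', 'set-url', 'origin', url])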
@@ -52,9 +52,6 @@ def cumsum_fix(input, cumsum_func, *args, **kwargs):
 
 
 if has_mps:
-    # MPS fix for randn in torchsde
-    CondFunc('torchsde._brownian.brownian_interval._randn', lambda _, size, dtype, device, seed: torch.randn(size, dtype=dtype, device=torch.device("cpu"), generator=torch.Generator(torch.device("cpu")).manual_seed(int(seed))).to(device), lambda _, size, dtype, device, seed: device.type == 'mps')
-
     if platform.mac_ver()[0].startswith("13.2."):
         # MPS workaround for https://github.com/pytorch/pytorch/issues/95188, thanks to danieldk (https://github.com/explosion/curated-transformers/pull/124)
         CondFunc('torch.nn.functional.linear', lambda _, input, weight, bias: (torch.matmul(input, weight.t()) + bias) if bias is not None else torch.matmul(input, weight.t()), lambda _, input, weight, bias: input.numel() > 10485760)
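The deleted CondFunc hooked torchsde._brownian.brownian_interval._randn to draw noise on the CPU with a seeded generator and then move the result to the MPS device, since seeded generation directly on MPS was unreliable; this hunk only removes the hook from mac_specific.py. A minimal sketch of the underlying idea in isolation (seeded_randn is a hypothetical name):

import torch

def seeded_randn(size, dtype, device, seed):
    # Sample on the CPU so the seed gives reproducible values, then move
    # the tensor to the requested device (e.g. 'mps').
    generator = torch.Generator('cpu').manual_seed(int(seed))
    return torch.randn(size, dtype=dtype, device='cpu', generator=generator).to(device)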
@@ -1,6 +1,7 @@
 from __future__ import annotations
 import math
 import psutil
+import platform
 
 import torch
 from torch import einsum
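The new platform import supports the platform.processor() check added to sub_quad_attention later in this diff.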
@@ -94,7 +95,10 @@ class SdOptimizationSdp(SdOptimizationSdpNoMem):
 class SdOptimizationSubQuad(SdOptimization):
     name = "sub-quadratic"
     cmd_opt = "opt_sub_quad_attention"
-    priority = 10
+
+    @property
+    def priority(self):
+        return 1000 if shared.device.type == 'mps' else 10
 
     def apply(self):
         ldm.modules.attention.CrossAttention.forward = sub_quad_attention_forward
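Here priority changes from a class constant into a property, so the value is computed when the optimizer list is ranked rather than fixed at class-definition time, and can reflect the active device. The pattern in isolation (Example and active_device_type are hypothetical stand-ins for the SdOptimization subclasses and shared.device):

import torch

def active_device_type():
    # Hypothetical stand-in for shared.device in the webui codebase.
    mps = getattr(torch.backends, 'mps', None)
    return 'mps' if mps is not None and mps.is_available() else 'cpu'

class Example:
    # Before: priority = 10, baked in when the class is defined.
    # After: recomputed on each access, so it can depend on runtime state.
    @property
    def priority(self):
        return 1000 if active_device_type() == 'mps' else 10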
@@ -120,7 +124,7 @@ class SdOptimizationInvokeAI(SdOptimization):
 
     @property
     def priority(self):
-        return 1000 if not torch.cuda.is_available() else 10
+        return 1000 if shared.device.type != 'mps' and not torch.cuda.is_available() else 10
 
     def apply(self):
         ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward_invokeAI
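Taken together, the two priority changes flip the default cross-attention optimization on MPS: sub-quadratic now reports 1000 while InvokeAI drops to 10, so a highest-priority-wins selection prefers sub-quadratic on Apple Silicon while other non-CUDA setups keep InvokeAI. A self-contained sketch of that selection (names and tie-breaking are illustrative, not the webui's actual ranking code):

def pick(device_type, cuda_available=False):
    # Mirrors the two @property implementations above.
    sub_quad = 1000 if device_type == 'mps' else 10
    invokeai = 1000 if device_type != 'mps' and not cuda_available else 10
    return 'sub-quadratic' if sub_quad >= invokeai else 'InvokeAI'

assert pick('mps') == 'sub-quadratic'   # macOS now defaults to sub-quadratic
assert pick('cpu') == 'InvokeAI'        # CPU-only setups keep InvokeAI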
@@ -427,7 +431,10 @@ def sub_quad_attention(q, k, v, q_chunk_size=1024, kv_chunk_size=None, kv_chunk_
     qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens
 
     if chunk_threshold is None:
-        chunk_threshold_bytes = int(get_available_vram() * 0.9) if q.device.type == 'mps' else int(get_available_vram() * 0.7)
+        if q.device.type == 'mps':
+            chunk_threshold_bytes = 268435456 * (2 if platform.processor() == 'i386' else bytes_per_token)
+        else:
+            chunk_threshold_bytes = int(get_available_vram() * 0.7)
     elif chunk_threshold == 0:
         chunk_threshold_bytes = None
     else:
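On MPS the automatic chunk threshold becomes a fixed budget rather than a fraction of reported free VRAM: 268435456 is 256·1024², so the byte budget is 256 Mi times the element size, with a fixed ×2 on Intel Macs (where platform.processor() returns 'i386'). Worked arithmetic, assuming bytes_per_token is the element size in bytes as suggested by the surrounding qk_matmul_size_bytes computation:

assert 268435456 == 256 * 1024**2        # the constant is 256 Mi
assert 268435456 * 2 == 512 * 1024**2    # fp16 elements -> 512 MiB budget
assert 268435456 * 4 == 1024 * 1024**2   # fp32 elements -> 1 GiB budget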
@@ -11,7 +11,7 @@ from modules.models.diffusion.uni_pc import uni_pc
 def ddim(model, x, timesteps, extra_args=None, callback=None, disable=None, eta=0.0):
     alphas_cumprod = model.inner_model.inner_model.alphas_cumprod
     alphas = alphas_cumprod[timesteps]
-    alphas_prev = alphas_cumprod[torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))].to(torch.float64)
+    alphas_prev = alphas_cumprod[torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))].to(torch.float64 if x.device.type != 'mps' else torch.float32)
     sqrt_one_minus_alphas = torch.sqrt(1 - alphas)
     sigmas = eta * np.sqrt((1 - alphas_prev.cpu().numpy()) / (1 - alphas.cpu()) * (1 - alphas.cpu() / alphas_prev.cpu().numpy()))
 
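The conditional dtype avoids a hard error: the MPS backend has no float64 support, so casting a tensor that lives on an MPS device to torch.float64 raises. The same guard appears in plms in the next hunk. The pattern in isolation (widest_float is a hypothetical name):

import torch

def widest_float(t):
    # float64 where the backend supports it; MPS tensors stay in float32.
    return t.to(torch.float64 if t.device.type != 'mps' else torch.float32)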
@@ -42,7 +42,7 @@ def ddim(model, x, timesteps, extra_args=None, callback=None, disable=None, eta=
 def plms(model, x, timesteps, extra_args=None, callback=None, disable=None):
     alphas_cumprod = model.inner_model.inner_model.alphas_cumprod
     alphas = alphas_cumprod[timesteps]
-    alphas_prev = alphas_cumprod[torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))].to(torch.float64)
+    alphas_prev = alphas_cumprod[torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))].to(torch.float64 if x.device.type != 'mps' else torch.float32)
     sqrt_one_minus_alphas = torch.sqrt(1 - alphas)
 
     extra_args = {} if extra_args is None else extra_args
@@ -58,7 +58,7 @@ def _summarize_chunk(
     scale: float,
 ) -> AttnChunk:
     attn_weights = torch.baddbmm(
-        torch.empty(1, 1, 1, device=query.device, dtype=query.dtype),
+        torch.zeros(1, 1, 1, device=query.device, dtype=query.dtype),
         query,
         key.transpose(1,2),
         alpha=scale,
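torch.baddbmm(input, batch1, batch2, beta=b, alpha=a) computes b·input + a·(batch1 @ batch2); with beta=0 the input is documented to be ignored, which is why an uninitialized torch.empty served as a free bias placeholder. Switching to torch.zeros trades a negligible allocation for safety on backends where garbage in the placeholder could still surface (the same change is applied in _get_attention_scores_no_kv_chunking below, and this is an assumption about the motivation, not stated in the diff). A runnable illustration of the call shape:

import torch

q = torch.randn(2, 4, 8)                  # (batch*heads, q_tokens, dim)
k = torch.randn(2, 6, 8)                  # (batch*heads, k_tokens, dim)
scores = torch.baddbmm(
    torch.zeros(1, 1, 1, dtype=q.dtype),  # deterministic placeholder, broadcast
    q,
    k.transpose(1, 2),                    # (2, 8, 6)
    alpha=8 ** -0.5,                      # typical 1/sqrt(dim) attention scale
    beta=0,                               # placeholder contributes nothing
)
print(scores.shape)                       # torch.Size([2, 4, 6])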
@@ -121,7 +121,7 @@ def _get_attention_scores_no_kv_chunking(
     scale: float,
 ) -> Tensor:
     attn_scores = torch.baddbmm(
-        torch.empty(1, 1, 1, device=query.device, dtype=query.dtype),
+        torch.zeros(1, 1, 1, device=query.device, dtype=query.dtype),
         query,
         key.transpose(1,2),
         alpha=scale,
@@ -12,8 +12,6 @@ fi
 export install_dir="$HOME"
 export COMMANDLINE_ARGS="--skip-torch-cuda-test --upcast-sampling --no-half-vae --use-cpu interrogate"
 export TORCH_COMMAND="pip install torch==2.0.1 torchvision==0.15.2"
-export K_DIFFUSION_REPO="https://github.com/brkirch/k-diffusion.git"
-export K_DIFFUSION_COMMIT_HASH="51c9778f269cedb55a4d88c79c0246d35bdadb71"
 export PYTORCH_ENABLE_MPS_FALLBACK=1
 
 ####################################################################
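Dropping the K_DIFFUSION_REPO and K_DIFFUSION_COMMIT_HASH overrides returns macOS installs to the default k-diffusion source; the origin-URL repair added to git_clone at the top of this diff is presumably what lets existing clones of the old fork switch back automatically.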