mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-02-07 06:02:53 +08:00
[IPEX] Fix SDPA attn_mask dtype
This commit is contained in:
parent
8b6848c6db
commit
16b4d2cf3f
@ -41,6 +41,8 @@ def torch_xpu_scaled_dot_product_attention(
|
|||||||
# cast to same dtype first
|
# cast to same dtype first
|
||||||
key = key.to(query.dtype)
|
key = key.to(query.dtype)
|
||||||
value = value.to(query.dtype)
|
value = value.to(query.dtype)
|
||||||
|
if attn_mask is not None and attn_mask.dtype != torch.bool:
|
||||||
|
attn_mask = attn_mask.to(query.dtype)
|
||||||
|
|
||||||
N = query.shape[:-2] # Batch size
|
N = query.shape[:-2] # Batch size
|
||||||
L = query.size(-2) # Target sequence length
|
L = query.size(-2) # Target sequence length
|
||||||
|
Loading…
Reference in New Issue
Block a user