From 16b4d2cf3f51f1d88b97d1d459dec59d3a2d0642 Mon Sep 17 00:00:00 2001
From: Nuullll
Date: Sat, 6 Jan 2024 16:32:18 +0800
Subject: [PATCH] [IPEX] Fix SDPA attn_mask dtype

---
 modules/xpu_specific.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/modules/xpu_specific.py b/modules/xpu_specific.py
index f7687a66c..4e11125b2 100644
--- a/modules/xpu_specific.py
+++ b/modules/xpu_specific.py
@@ -41,6 +41,8 @@ def torch_xpu_scaled_dot_product_attention(
     # cast to same dtype first
     key = key.to(query.dtype)
     value = value.to(query.dtype)
+    if attn_mask is not None and attn_mask.dtype != torch.bool:
+        attn_mask = attn_mask.to(query.dtype)
 
     N = query.shape[:-2]  # Batch size
     L = query.size(-2)  # Target sequence length