@@ -177,6 +177,42 @@ def prepare_padding_mask(
     return local_padding_mask


+def _can_skip_causal_mask_xpu(
+    padding_mask: Optional[torch.Tensor],
+    query_length: int,
+    kv_length: int,
+    local_attention_size: Optional[int],
+) -> bool:
+    """
+    XPU-specific logic for determining whether causal mask creation can be skipped.
+
+    For XPU devices, we have special handling:
+    - Single query tokens (query_length == 1) use the same logic as CUDA.
+    - Multi-token queries can skip the mask if a padding_mask is provided and correctly
+      structured: all True values inside the query window and all False values after it.
+    """
+
+    if is_tracing(padding_mask):
+        return False
+
+    # Check the local attention constraint (same as CUDA)
+    if local_attention_size is not None and kv_length >= local_attention_size:
+        return False
+
+    if padding_mask is None:
+        # Without a padding mask, we can skip only for a single query token or full causal attention
+        return query_length == 1 or kv_length == query_length
+
+    # XPU allows skipping under additional conditions when a padding_mask is provided
+    if query_length == 1:
+        # Single query token: skip only if no padding tokens are present
+        return padding_mask.all()
+
+    # XPU-specific: check that the query window is all True and the rest is all False.
+    # This allows XPU to optimize the first token with a static cache.
+    return padding_mask[:, :query_length].all() and not padding_mask[:, query_length:].any()
+
+
 def _ignore_causal_mask_sdpa(
     padding_mask: Optional[torch.Tensor],
     query_length: int,
@@ -197,25 +233,24 @@ def _ignore_causal_mask_sdpa(
         mask_indices += kv_offset
         padding_mask = padding_mask[:, mask_indices]

+    if _is_torch_xpu_available:
+        # XPU devices have special handling for mask skipping:
+        # - Single query tokens use the same logic as CUDA.
+        # - Multi-token queries can skip if a padding_mask is provided and correctly structured
+        #   (all True in the query window, all False after it).
+        return _can_skip_causal_mask_xpu(padding_mask, query_length, kv_length, local_attention_size)
     # When using `torch.export` or `torch.onnx.dynamo_export`, we must pass an example input, and `is_causal` behavior is
     # hard-coded to the forward. If a user exports a model with query_length > 1, the exported model will hard-code `is_causal=True`
     # which is in general wrong (see https://github.com/pytorch/pytorch/issues/108108). Thus, we only set
     # `ignore_causal_mask = True` if we are not tracing
     if (
         not is_tracing(padding_mask)
         # only cases when lower and upper diags are the same, see https://github.com/pytorch/pytorch/issues/108108
-        and (query_length == 1 or (kv_length == query_length or _is_torch_xpu_available))
+        and (query_length == 1 or kv_length == query_length)
         # in this case we need to add special patterns to the mask so cannot be skipped otherwise
         and (local_attention_size is None or kv_length < local_attention_size)
         # In this case, we need to add padding to the mask, so cannot be skipped otherwise
-        and (
-            padding_mask is None
-            or (
-                padding_mask.all()
-                if not _is_torch_xpu_available or query_length == 1
-                else padding_mask[:, :query_length].all()
-            )
-        )
+        and (padding_mask is None or padding_mask.all())
     ):
         return True

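For reference, a minimal standalone sketch of the structural check the new XPU path relies on; the tensors, sizes, and the `can_skip` variable below are illustrative and not part of the patch. The idea is that a multi-token query can drop the explicit causal mask only when the padding mask is all True inside the query window and all False afterwards, which is the layout the patch's comments describe for the first token with a static cache.

import torch

# Hypothetical left-aligned mask for a static cache of length 8 after a 5-token prefill:
# the first query_length slots are valid, the remaining pre-allocated slots are still empty.
query_length = 5
padding_mask = torch.tensor([[True, True, True, True, True, False, False, False]])

# Structural check used by the XPU path: all True inside the query window, all False after it.
can_skip = bool(padding_mask[:, :query_length].all() and not padding_mask[:, query_length:].any())
print(can_skip)  # True -> the explicit causal mask can be skipped

# A batch with real padding inside the query window does not qualify.
ragged_mask = torch.tensor([[False, True, True, True, True, False, False, False]])
can_skip = bool(ragged_mask[:, :query_length].all() and not ragged_mask[:, query_length:].any())
print(can_skip)  # False -> the full causal + padding mask is still needed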