
Commit 58a3f8c

fix test failure of speculative_generation on xpu (#42052)
* fix test failure of speculative_generation on xpu

  Signed-off-by: Wang, Yi A <[email protected]>

* code refine

  Signed-off-by: Wang, Yi A <[email protected]>

* address review comment

  Signed-off-by: Wang, Yi A <[email protected]>

---------

Signed-off-by: Wang, Yi A <[email protected]>
1 parent fcea1e1 commit 58a3f8c

File tree

2 files changed: +45, -10 lines changed

src/transformers/masking_utils.py

Lines changed: 44 additions & 9 deletions
@@ -177,6 +177,42 @@ def prepare_padding_mask(
     return local_padding_mask
 
 
+def _can_skip_causal_mask_xpu(
+    padding_mask: Optional[torch.Tensor],
+    query_length: int,
+    kv_length: int,
+    local_attention_size: Optional[int],
+) -> bool:
+    """
+    XPU-specific logic for determining if we can skip causal mask creation.
+
+    For XPU devices, we have special handling:
+    - Single query tokens (query_length == 1) use the same logic as CUDA
+    - Multi-query tokens can skip if padding_mask is provided and correctly structured:
+      the mask must have all True values in the query window and all False after
+    """
+    if is_tracing(padding_mask):
+        return False
+
+    # Check local attention constraint (same as CUDA)
+    if local_attention_size is not None and kv_length >= local_attention_size:
+        return False
+
+    if padding_mask is None:
+        # Without padding mask, can skip if single query token or full causal attention
+        return query_length == 1 or kv_length == query_length
+
+    # XPU allows skipping under additional conditions when padding_mask is provided
+    if query_length == 1:
+        # Single query token: skip only if no padding tokens present
+        return padding_mask.all()
+
+    # XPU-specific: check if query window is all True and rest is all False
+    # This allows XPU to optimize the 1st token in static cache
+    return padding_mask[:, :query_length].all() and not padding_mask[:, query_length:].any()
+
+
 def _ignore_causal_mask_sdpa(
     padding_mask: Optional[torch.Tensor],
     query_length: int,
@@ -197,25 +233,24 @@ def _ignore_causal_mask_sdpa(
         mask_indices += kv_offset
         padding_mask = padding_mask[:, mask_indices]
 
+    if _is_torch_xpu_available:
+        # XPU devices have special handling for mask skipping:
+        # - Single query tokens use the same logic as CUDA
+        # - Multi-query tokens can skip if padding_mask is provided and correctly structured
+        #   (all True in query window, all False after)
+        return _can_skip_causal_mask_xpu(padding_mask, query_length, kv_length, local_attention_size)
     # When using `torch.export` or `torch.onnx.dynamo_export`, we must pass an example input, and `is_causal` behavior is
     # hard-coded to the forward. If a user exports a model with query_length > 1, the exported model will hard-code `is_causal=True`
     # which is in general wrong (see https://github.com/pytorch/pytorch/issues/108108). Thus, we only set
     # `ignore_causal_mask = True` if we are not tracing
     if (
         not is_tracing(padding_mask)
         # only cases when lower and upper diags are the same, see https://github.com/pytorch/pytorch/issues/108108
-        and (query_length == 1 or (kv_length == query_length or _is_torch_xpu_available))
+        and (query_length == 1 or kv_length == query_length)
         # in this case we need to add special patterns to the mask so cannot be skipped otherwise
         and (local_attention_size is None or kv_length < local_attention_size)
         # In this case, we need to add padding to the mask, so cannot be skipped otherwise
-        and (
-            padding_mask is None
-            or (
-                padding_mask.all()
-                if not _is_torch_xpu_available or query_length == 1
-                else padding_mask[:, :query_length].all()
-            )
-        )
+        and (padding_mask is None or padding_mask.all())
     ):
         return True
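
For readers skimming the diff: the sketch below is not part of the commit. It restates the multi-query skip condition from _can_skip_causal_mask_xpu as a standalone snippet; the helper name can_skip_for_xpu_multi_query and the sample masks are hypothetical, chosen only to show which padding layouts let XPU drop the causal mask.

    import torch

    # Hypothetical restatement of the XPU multi-query skip condition above;
    # the helper name and the sample masks are illustrative, not from the commit.
    def can_skip_for_xpu_multi_query(padding_mask: torch.Tensor, query_length: int) -> bool:
        # Skip only when the query window is all True and everything after it is
        # all False, i.e. the first prefill step into a right-padded static cache.
        return bool(padding_mask[:, :query_length].all() and not padding_mask[:, query_length:].any())

    # A 5-token prompt written into a static cache of length 8: mask can be skipped.
    mask = torch.tensor([[True, True, True, True, True, False, False, False]])
    print(can_skip_for_xpu_multi_query(mask, query_length=5))  # True

    # A padding hole inside the query window forces real mask construction.
    mask = torch.tensor([[True, False, True, True, True, False, False, False]])
    print(can_skip_for_xpu_multi_query(mask, query_length=5))  # False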

tests/models/qwen3/test_modeling_qwen3.py

Lines changed: 1 addition & 1 deletion
@@ -165,7 +165,7 @@ def test_model_600m_long_prompt_sdpa(self):
     def test_speculative_generation(self):
         EXPECTED_TEXT_COMPLETIONS = Expectations(
             {
-                ("xpu", 3): "My favourite condiment is 100% peanut butter. I love it so much that I can't help but use it",
+                ("xpu", 3): "My favourite condiment is 100% beef and comes in a 12 oz. jar. It is sold in",
                 ("cuda", 7): "My favourite condiment is 100% natural. It's a little spicy and a little sweet, but it's the",
                 ("cuda", 8): "My favourite condiment is 100% beef, 100% beef, 100% beef.",
             }
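
For context on why only the ("xpu", 3) entry changes: a minimal usage sketch, assuming the Expectations helper from transformers.testing_utils and its get_expectation() method, which resolves to the entry matching the device the test runs on. The expected strings are taken from the diff above.

    from transformers.testing_utils import Expectations

    # Keys are (device_type, major_version) pairs; get_expectation() is assumed
    # to return the entry matching the current machine, so each accelerator
    # family carries its own expected completion.
    expectations = Expectations(
        {
            ("xpu", 3): "My favourite condiment is 100% beef and comes in a 12 oz. jar. It is sold in",
            ("cuda", 7): "My favourite condiment is 100% natural. It's a little spicy and a little sweet, but it's the",
            ("cuda", 8): "My favourite condiment is 100% beef, 100% beef, 100% beef.",
        }
    )
    expected_text = expectations.get_expectation()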
