17 changes: 11 additions & 6 deletions tests/pytorch/attention/test_attention.py
@@ -1155,12 +1155,12 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
block.softmax_offset.requires_grad = True

# Run a forward and backward pass
if backend in ["FlashAttention", "UnfusedDotProductAttention"]:
if backend in ["UnfusedDotProductAttention"]:
q = inp_orig[0]
k = inp_orig[1]
v = inp_orig[2]
d_out = out_grad_orig
if backend == "FusedAttention":
if backend in ["FusedAttention", "FlashAttention"]:
q = inp[0]
k = inp[1]
v = inp[2]
@@ -1176,14 +1176,19 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
max_seqlen_kv=config.max_seqlen_kv,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_kv=cu_seqlens_kv,
cu_seqlens_q_padded=cu_seqlens_q_after_pad if backend == "FusedAttention" else None,
cu_seqlens_kv_padded=cu_seqlens_kv_after_pad if backend == "FusedAttention" else None,
cu_seqlens_q_padded=(
cu_seqlens_q_after_pad if backend in ["FusedAttention", "FlashAttention"] else None
),
cu_seqlens_kv_padded=(
cu_seqlens_kv_after_pad if backend in ["FusedAttention", "FlashAttention"] else None
),
attn_mask_type=config.attn_mask_type,
checkpoint_core_attention=ckpt_attn,
core_attention_bias_type=config.attn_bias_type,
core_attention_bias=bias,
alibi_slopes=alibi_slopes,
fast_zero_fill=True,
pad_between_seqs=pad_between_seqs,
# Only pass num_splits when exercising the FlashAttention path
num_splits=config.num_splits if backend == "FlashAttention" else 1,
)
@@ -1197,12 +1202,12 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
if is_training and config.softmax_type != "vanilla":
d_softmax_offset = block.softmax_offset.grad

if backend in ["FlashAttention", "UnfusedDotProductAttention"]:
if backend in ["UnfusedDotProductAttention"]:
if is_training:
return out, max_logit, (q.grad, k.grad, v.grad, d_softmax_offset)
else:
return out, max_logit, (None, None, None, d_softmax_offset)
if backend == "FusedAttention":
if backend in ["FusedAttention", "FlashAttention"]:
if qkv_format == "thd" and pad_between_seqs:
out_orig = torch.Tensor([]).to(device="cuda", dtype=dtype)
if is_training:
@@ -666,6 +666,8 @@ def forward(
qkv_layout: str = "sbh3d",
cu_seqlens_q: Optional[torch.Tensor] = None,
cu_seqlens_kv: Optional[torch.Tensor] = None,
cu_seqlens_q_padded: Optional[torch.Tensor] = None,
cu_seqlens_kv_padded: Optional[torch.Tensor] = None,
max_seqlen_q: Optional[int] = None,
max_seqlen_kv: Optional[int] = None,
attn_mask_type: str = "causal",
Expand All @@ -678,6 +680,7 @@ def forward(
fp8: bool = False,
fp8_meta: Optional[Dict[str, Any]] = None,
quantizers=None,
pad_between_seqs: bool = False,
inference_params: Optional[InferenceParams] = None,
flash_attention_backend: Optional[PkgVersion] = PkgVersion("0"),
fp8_output: bool = False,
@@ -924,8 +927,14 @@ def forward(
else:
func = flash_attn_with_kvcache_v3 # pylint: disable=possibly-used-before-assignment
if not use_flash_attn_3 or inference_params is None:
fa_optional_forward_args_thd.append(cu_seqlens_q)
fa_optional_forward_args_thd.append(cu_seqlens_kv)
fa_optional_forward_args_thd.append(
cu_seqlens_q_padded if cu_seqlens_q_padded is not None else cu_seqlens_q
)
fa_optional_forward_args_thd.append(
cu_seqlens_kv_padded
if cu_seqlens_kv_padded is not None
else cu_seqlens_kv
)
fa_optional_forward_args_thd.append(max_seqlen_q)
fa_optional_forward_args_thd.append(max_seqlen_kv)
if not use_flash_attn_3:
@@ -961,6 +970,17 @@ def forward(
fa_3_optional_forward_kwargs["num_splits"] = num_splits
if inference_params is None:
fa_3_optional_forward_kwargs["deterministic"] = self.deterministic

# If `pad_between_seqs` is True, pass `seqused_q` and `seqused_k` to flash_attn_3
# in addition to `cu_seqlens_q_padded` and `cu_seqlens_kv_padded`, so that the
# padded positions between sequences are not affected.
if pad_between_seqs:
fa_3_optional_forward_kwargs["seqused_q"] = (
cu_seqlens_q[1:] - cu_seqlens_q[:-1]
)
fa_3_optional_forward_kwargs["seqused_k"] = (
cu_seqlens_kv[1:] - cu_seqlens_kv[:-1]
)
Contributor comment on lines +974 to +983:
style: verify that flash_attn_3 with seqused_q/seqused_k truly avoids writing to padding positions. The related issue #2391 mentions "we need to manually set the output of the padded positions to zero" (similar to how FusedAttention zeroes output in C++ for THD format). If flash_attn_3 doesn't zero these internally, the output may have garbage values in padded positions. Have you verified that flash_attn_3 correctly handles padding internally with these parameters?
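
One way to settle this without relying on flash_attn_3 internals would be to assert in the test that the THD output rows at padded positions come back as zeros. A minimal sketch, assuming the test's `cu_seqlens_q` (cumulative actual lengths), `cu_seqlens_q_after_pad` (cumulative offsets including padding), and an output `out` of shape `[total_padded_tokens, h, d]`; the helper below is illustrative and not part of this PR:

```python
import torch

def padded_token_mask(cu_seqlens: torch.Tensor, cu_seqlens_padded: torch.Tensor) -> torch.Tensor:
    """Boolean mask over the THD token axis that is True at PAD rows."""
    total = int(cu_seqlens_padded[-1])
    mask = torch.zeros(total, dtype=torch.bool, device=cu_seqlens.device)
    seqused = cu_seqlens[1:] - cu_seqlens[:-1]       # actual tokens per sequence
    pad_starts = cu_seqlens_padded[:-1] + seqused    # first PAD row of each sequence
    pad_ends = cu_seqlens_padded[1:]                 # one past the last PAD row
    for start, end in zip(pad_starts.tolist(), pad_ends.tolist()):
        mask[start:end] = True
    return mask

# e.g. inside the test, after the forward pass:
#   pad_mask = padded_token_mask(cu_seqlens_q, cu_seqlens_q_after_pad)
#   assert out[pad_mask].abs().max().item() == 0.0
```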

else:
fa_3_optional_forward_kwargs["cu_seqlens_q"] = cu_seqlens_q
fa_3_optional_forward_kwargs["max_seqlen_q"] = max_seqlen_q
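For concreteness, this is what the new arguments amount to for the layout used as the example in this PR, [a, a, PAD, b, b, b, PAD, c, PAD]; the numbers below are illustrative, not taken from the test suite:

```python
import torch

# Three sequences of actual lengths 2, 3, 1, padded out to 3, 4, 2 tokens:
# [a, a, PAD, b, b, b, PAD, c, PAD]
cu_seqlens = torch.tensor([0, 2, 5, 6], dtype=torch.int32)         # cumulative actual lengths
cu_seqlens_padded = torch.tensor([0, 3, 7, 9], dtype=torch.int32)  # cumulative offsets incl. PAD

# What this hunk now feeds to flash_attn_3 on the thd path:
#   - the positional cu_seqlens slots receive cu_seqlens_padded, so token
#     offsets account for the PAD rows;
#   - seqused_q / seqused_k receive the per-sequence actual lengths.
seqused = cu_seqlens[1:] - cu_seqlens[:-1]
print(seqused.tolist())  # [2, 3, 1]
```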
@@ -1424,6 +1424,8 @@ def forward(
qkv_layout=qkv_layout,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_kv=cu_seqlens_kv,
cu_seqlens_q_padded=cu_seqlens_q_padded,
cu_seqlens_kv_padded=cu_seqlens_kv_padded,
attn_mask_type=attn_mask_type,
window_size=window_size,
alibi_slopes=alibi_slopes,
@@ -1436,6 +1438,7 @@ def forward(
fp8=self.fp8 and self.fp8_meta["recipe"].fp8_dpa,
fp8_meta=self.fp8_meta,
quantizers=self.quantizers,
pad_between_seqs=pad_between_seqs,
inference_params=inference_params,
flash_attention_backend=flash_attention_backend,
fp8_output=fp8_output,
@@ -675,14 +675,20 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt
# Filter: QKV layout
if qkv_format == "thd":
if pad_between_seqs:
if (use_flash_attention_2 and FlashAttentionUtils.is_installed) or (
use_flash_attention_3 and FlashAttentionUtils.v3_is_installed
):
if use_flash_attention_2 and FlashAttentionUtils.is_installed:
logger.debug(
"Disabling FlashAttention for qkv_format = thd when there is "
"padding between sequences, i.e. [a, a, PAD, b, b, b, PAD, c, PAD]"
)
use_flash_attention = False
use_flash_attention = False
if use_flash_attention_3 and FlashAttentionUtils.v3_is_installed:
# Turn on FlashAttention 3 for thd when there is padding between
# sequences, i.e. [a, a, PAD, b, b, b, PAD, c, PAD]:
# flash_attn_3 accepts `seqused_q` and `seqused_k` in addition to
# `cu_seqlens_q_padded` and `cu_seqlens_kv_padded`, so the padded
# positions are not affected.
use_flash_attention = True

if device_compute_capability == (12, 0):
if use_fused_attention:
logger.debug(
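Read together, the thd + pad_between_seqs branch of the backend filter now behaves roughly as follows; this is a standalone paraphrase for illustration, not the actual helper in the selection code:

```python
def thd_pad_between_seqs_filter(
    use_flash_attention: bool,    # value of the flag before this filter runs
    use_flash_attention_2: bool,
    fa2_installed: bool,
    use_flash_attention_3: bool,
    fa3_installed: bool,
) -> bool:
    """Paraphrase of the qkv_format == "thd" / pad_between_seqs branch above."""
    if use_flash_attention_2 and fa2_installed:
        # FlashAttention 2 cannot skip the PAD rows between sequences.
        use_flash_attention = False
    if use_flash_attention_3 and fa3_installed:
        # FlashAttention 3 can, via seqused_q/seqused_k plus the padded cu_seqlens.
        use_flash_attention = True
    return use_flash_attention
```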