9 changes: 9 additions & 0 deletions tests/pytorch/attention/test_attention.py
@@ -429,6 +429,15 @@ def test_dpa_softmax(dtype, model_configs, model):
)


@pytest.mark.skipif(get_cudnn_version() < (9, 18, 0), reason="cuDNN 9.18.0+ is required.")
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("model_configs", [model_configs_softmax])
@pytest.mark.parametrize("model", model_configs_softmax.keys())
def test_dpa_softmax_thd(dtype, model_configs, model):
"""Test DotProductAttention module with different softmax types"""
test_dot_product_attention(dtype, model_configs, model, True, True, "thd_thd_thd", False, False)


model_configs_mla = {
# test: ModelConfig(b, sq, hq, dqk)
"mla_1_0": ModelConfig(8, 128, 16, 64, head_dim_v=128),
9 changes: 7 additions & 2 deletions tests/pytorch/attention/test_attention_with_cp.py
@@ -283,9 +283,14 @@ def test_cp_with_fused_attention(
pytest.skip(
"CP implementation only supports cp_comm_type=a2a for non-vanilla softmax types!"
)
if config.softmax_type != "vanilla" and qkv_format == "thd":
if (
get_cudnn_version() < (9, 18, 0)
and config.softmax_type != "vanilla"
and qkv_format == "thd"
):
pytest.skip(
"CP implementation does not support qkv_format=thd for non-vanilla softmax types!"
"Unless cudnn version >= 9.18.0, CP implementation does not support qkv_format=thd for"
" non-vanilla softmax types!"
)

dtypes = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp8": torch.bfloat16}
@@ -4026,28 +4026,30 @@ def attn_forward_func_with_cp(
assert not sliding_window_attn or cp_comm_type in [
"a2a",
"all_gather",
], "Context parallelism does not support sliding window attention with {cp_comm_type=}!"
], f"Context parallelism does not support sliding window attention with {cp_comm_type=}!"

enable_mla = k.shape[-1] != v.shape[-1]
assert not enable_mla or cp_comm_type in [
"p2p",
"a2a+p2p",
], "Context parallelism does not support MLA with {cp_comm_type=}!"
], f"Context parallelism does not support MLA with {cp_comm_type=}!"

if fp8 and fp8_meta is not None:
if fp8_meta["recipe"].fp8_dpa:
assert (
softmax_type == "vanilla"
), "Context parallelism does not support {softmax_type=} with FP8 attention!"
), f"Context parallelism does not support {softmax_type=} with FP8 attention!"
assert (
softmax_type == "vanilla" or use_fused_attention
), "Context parallelism only supports {softmax_type=} with FusedAttention backend!"
), f"Context parallelism only supports {softmax_type=} with FusedAttention backend!"
assert (
softmax_type == "vanilla" or cp_comm_type == "a2a"
), "Context parallelism only supports {softmax_type=} with cp_comm_type = 'a2a'!"
assert (
softmax_type == "vanilla" or qkv_format != "thd"
), "Context parallelism does not support {softmax_type=} with qkv_format = 'thd'!"
), f"Context parallelism only supports {softmax_type=} with cp_comm_type = 'a2a'!"
if get_cudnn_version() < (9, 18, 0):
assert softmax_type == "vanilla" or qkv_format != "thd", (
f"Before cuDNN 9.18.0, context parallelism does not support {softmax_type=} with"
" qkv_format = 'thd'!"
)

args = [
is_training,
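Note: besides the new cuDNN version check, this hunk also fixes the assertion messages, which previously used `{cp_comm_type=}` / `{softmax_type=}` inside plain (non-f) strings and therefore printed the braces literally. A quick illustration with a made-up value:

```python
# Why the added "f" prefixes matter: without them the braces are printed
# literally instead of interpolating the value. Illustrative value only;
# the "=" specifier requires Python 3.8+.
cp_comm_type = "p2p"

plain = "unsupported with {cp_comm_type=}!"    # old message text
fixed = f"unsupported with {cp_comm_type=}!"   # new message text

print(plain)  # unsupported with {cp_comm_type=}!
print(fixed)  # unsupported with cp_comm_type='p2p'!
```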
31 changes: 16 additions & 15 deletions transformer_engine/pytorch/attention/dot_product_attention/utils.py
@@ -716,22 +716,14 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt
)
use_unfused_attention = False
if qkv_format == "thd":
logger.debug(
"Disabling FusedAttention for softmax_type = %s and qkv_format = thd", softmax_type
)
use_fused_attention = False
logger.debug(
"Disabling UnfusedDotProductAttention for softmax_type = %s and qkv_format = thd",
softmax_type,
)
use_unfused_attention = False
if cudnn_version < (9, 18, 0):
logger.debug(
"Disabling FusedAttention for softmax_type = %s, qkv_format = thd and cuDNN"
" version < 9.18",
softmax_type,
)
use_fused_attention = False
if context_parallel:
logger.debug(
"Disabling UnfusedDotProductAttention for context parallelism with softmax_type"
" = %s",
softmax_type,
)
use_unfused_attention = False
if cp_comm_type != "a2a":
logger.debug(
"Disabling FusedAttention for context parallelism with softmax_type = %s and"
@@ -1049,6 +1041,15 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt
)
use_flash_attention_2 = False
if use_fused_attention and deterministic:
if softmax_type != "vanilla":
logger.debug(
"Disabling FusedAttention for determinism reasons with softmax_type = %s. "
"Sink attention (off-by-one and learnable softmax) requires "
"NVTE_ALLOW_NONDETERMINISTIC_ALGO=1",
softmax_type,
)
use_fused_attention = False
fused_attention_backend = None
if fused_attention_backend == FusedAttnBackend["FP8"] and is_training:
logger.debug("Disabling FusedAttention for determinism reasons with FP8")
use_fused_attention = False
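Note: taken together, the utils.py changes gate FusedAttention on the cuDNN version for thd inputs and disable it for non-vanilla softmax under deterministic execution. A rough sketch of that decision logic, with simplified names (this is not the actual Transformer Engine backend-selection code):

```python
# Rough sketch of the two gates added in utils.py: FusedAttention stays enabled
# for non-vanilla softmax with qkv_format = "thd" only on cuDNN 9.18.0+, and is
# disabled when a deterministic run is requested, since sink attention
# (off-by-one and learnable softmax) has no deterministic kernel.
def allow_fused_attention(softmax_type, qkv_format, cudnn_version, deterministic):
    if softmax_type != "vanilla":
        if qkv_format == "thd" and cudnn_version < (9, 18, 0):
            return False
        if deterministic:
            # Per the debug message above, re-enabling requires opting into
            # nondeterminism via NVTE_ALLOW_NONDETERMINISTIC_ALGO=1.
            return False
    return True
```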