180 changes: 129 additions & 51 deletions tests/pytorch/attention/run_attention_with_cp.py
@@ -307,6 +307,7 @@ def run_dpa_with_cp(
if config.attn_bias_type not in ["no_bias", "alibi"]:
attn_bias_shape = (1, 1, config.max_seqlen_q, config.max_seqlen_kv)
bias = torch.randn(*attn_bias_shape, dtype=dtypes[dtype]).cuda()
bias.requires_grad = True
else:
bias = None

@@ -338,7 +339,7 @@ def run_dpa_with_cp(
out.backward(dout_fp8)
else:
out.backward(dout)
dq, dk, dv = q.grad, k.grad, v.grad
dq, dk, dv, dbias = q.grad, k.grad, v.grad, bias.grad if bias is not None else None
d_softmax_offset = None
if config.softmax_type != "vanilla":
d_softmax_offset = core_attn.softmax_offset.grad
@@ -394,6 +395,7 @@ def run_dpa_with_cp(
)
bias_ = bias_.index_select(2, seq_idx)
bias_ = bias_.view(*bias_.shape[:2], -1, bias_.shape[-1])
bias_.requires_grad = True
# set up environment
core_attn.set_context_parallel_group(
cp_comm_sub_groups if cp_comm_type == "a2a+p2p" else cp_comm_group,
@@ -433,23 +435,27 @@ def run_dpa_with_cp(
out_.backward(dout_fp8_)
else:
out_.backward(dout_)
dq_, dk_, dv_ = q_.grad, k_.grad, v_.grad
dq_, dk_, dv_, dbias_ = q_.grad, k_.grad, v_.grad, bias_.grad if bias_ is not None else None
d_softmax_offset_ = None
if config.softmax_type != "vanilla":
d_softmax_offset_ = core_attn.softmax_offset.grad.clone()

# get outputs
tensors = [out, dq, dk, dv, out_, dq_, dk_, dv_]
tensors = [out, dq, dk, dv, dbias, out_, dq_, dk_, dv_, dbias_]
if fp8_mha:
tensors_to_deq = [out, out_] if not fp8_bwd else tensors
for i, tensor in enumerate(tensors_to_deq):
tensors_to_deq[i] = tensor.dequantize()
# dbias/dbias_ could be None, so skip dequantization for them
if tensor is not None:
tensors_to_deq[i] = tensor.dequantize()
if not fp8_bwd:
tensors[0], tensors[4] = tensors_to_deq
tensors[0], tensors[5] = tensors_to_deq
for tensor in tensors:
assert torch.all(~torch.isnan(tensor))
assert torch.all(~torch.isinf(tensor))
out, dq, dk, dv, out_, dq_, dk_, dv_ = tensors
# dbias/dbias_ could be None, so skip the NaN/inf checks for them
if tensor is not None:
assert torch.all(~torch.isnan(tensor))
assert torch.all(~torch.isinf(tensor))
out, dq, dk, dv, dbias, out_, dq_, dk_, dv_, dbias_ = tensors

############ compare results between CP and no-CP ############
if qkv_format == "bshd" or qkv_format == "sbhd":
@@ -467,6 +473,22 @@ def run_dpa_with_cp(
x.view(*x.shape[:seq_dim], 2, x.shape[seq_dim] // 2, *x.shape[(seq_dim + 1) :])
for x in [dq_, dk_, dv_, out_]
]
if dbias is not None and dbias_ is not None:
dbias = dbias.view(
dbias.shape[0],
dbias.shape[1],
2 * world_size,
dbias.shape[2] // (2 * world_size),
dbias.shape[3],
)
# dbias has shape (1, 1, max_seqlen_q, max_seqlen_kv); the CP chunking below is along the query-sequence axis (dim 2)
Reviewer (Collaborator): I think our CP implementation (after your C changes) should support all bias shapes, not just 111s. I also think your reshaping here should work for all shapes. Could you run the tests to confirm?
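As a follow-up to the comment above, a minimal sketch (not part of this PR) of how the reshaping below could generalize to bias shapes that carry real batch/head dimensions; it reuses the test's world_size and seq_idx and assumes dim 2 is the query sequence (a broadcast 111s bias has nothing to chunk along that axis):

b, h, sq, skv = dbias.shape  # e.g. 1hss, 11ss, b1ss, bhss; batch/head dims pass through
dbias = dbias.view(b, h, 2 * world_size, sq // (2 * world_size), skv)
dbias = dbias.index_select(2, seq_idx)  # keep this rank's two load-balanced chunks
# result: (b, h, 2, sq // (2 * world_size), skv), ready for chunk-wise comparison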

dbias = dbias.index_select(2, seq_idx)
# keep the chunk axis: dbias is now (1, 1, 2, max_seqlen_q // (2 * world_size), max_seqlen_kv), matching the view of dbias_ below and the chunk-wise comparison later
dbias_ = dbias_.view(
dbias_.shape[0], dbias_.shape[1], 2, dbias_.shape[2] // 2, dbias_.shape[3]
)

elif qkv_format == "thd":
dq, out = [x.index_select(0, seq_idx_q).contiguous() for x in [dq, out]]
dk, dv = [x.index_select(0, seq_idx_kv).contiguous() for x in [dk, dv]]
@@ -509,57 +531,113 @@ def run_dpa_with_cp(
)

atol, rtol, rmse_tol = get_tols(config, dtype)
tensors_cp = [out_, dq_, dk_, dv_, d_softmax_offset_, max_logit_]
tensors_no_cp = [out, dq, dk, dv, d_softmax_offset, max_logit]
names = ["out", "dq", "dk", "dv", "d_softmax_offset", "max_logit"]
tensors_cp = [out_, dq_, dk_, dv_, dbias_, d_softmax_offset_, max_logit_]
tensors_no_cp = [out, dq, dk, dv, dbias, d_softmax_offset, max_logit]
names = ["out", "dq", "dk", "dv", "dbias", "d_softmax_offset", "max_logit"]
names_cp = [x + "_cp" for x in names]
names_no_cp = [x + "_no_cp" for x in names]
is_fp8 = dtype == "fp8"
for i, t in enumerate(tensors_no_cp):
if t is not None:
if "softmax_offset" not in names[i] and "max_logit" not in names[i]:
if qkv_format == "bshd":
compare_and_assert(
t[:, 0],
tensors_cp[i][:, 0],
names_no_cp[i],
names_cp[i],
atol,
rtol,
rmse_tol,
is_fp8,
)
compare_and_assert(
t[:, 1],
tensors_cp[i][:, 1],
names_no_cp[i],
names_cp[i],
atol,
rtol,
rmse_tol,
is_fp8,
)
# Compare the two sequence chunks separately
# Compare dbias
if names[i] == "dbias":
# After reshaping: (1, 1, 2, seq_q//2, seq_kv)
# Compare along dimension 2 (the split sequence dimension)
compare_and_assert(
t[:, :, 0], # First sequence chunk
tensors_cp[i][:, :, 0],
names_no_cp[i],
names_cp[i],
atol,
rtol,
rmse_tol,
is_fp8,
)
compare_and_assert(
t[:, :, 1], # Second sequence chunk
tensors_cp[i][:, :, 1],
names_no_cp[i],
names_cp[i],
atol,
rtol,
rmse_tol,
is_fp8,
)
# Compare Q/K/V/out
else:
# Compare along dimension 1 (the split sequence dimension)
compare_and_assert(
t[:, 0],
tensors_cp[i][:, 0],
names_no_cp[i],
names_cp[i],
atol,
rtol,
rmse_tol,
is_fp8,
)
compare_and_assert(
t[:, 1],
tensors_cp[i][:, 1],
names_no_cp[i],
names_cp[i],
atol,
rtol,
rmse_tol,
is_fp8,
)
elif qkv_format == "sbhd":
compare_and_assert(
t[0],
tensors_cp[i][0],
names_no_cp[i],
names_cp[i],
atol,
rtol,
rmse_tol,
is_fp8,
)
compare_and_assert(
t[1],
tensors_cp[i][1],
names_no_cp[i],
names_cp[i],
atol,
rtol,
rmse_tol,
is_fp8,
)
# Compare the two sequence chunks separately
# Compare dbias (same as BSHD)
if names[i] == "dbias":
# After reshaping: (1, 1, 2, seq_q//2, seq_kv)
# Compare along dimension 2 (the split sequence dimension)
compare_and_assert(
t[:, :, 0], # First sequence chunk
tensors_cp[i][:, :, 0],
names_no_cp[i],
names_cp[i],
atol,
rtol,
rmse_tol,
is_fp8,
)
compare_and_assert(
t[:, :, 1], # Second sequence chunk
tensors_cp[i][:, :, 1],
names_no_cp[i],
names_cp[i],
atol,
rtol,
rmse_tol,
is_fp8,
)
# Compare Q/K/V/out
else:
# Compare along dimension 0 (the split sequence dimension)
compare_and_assert(
t[0],
tensors_cp[i][0],
names_no_cp[i],
names_cp[i],
atol,
rtol,
rmse_tol,
is_fp8,
)
compare_and_assert(
t[1],
tensors_cp[i][1],
names_no_cp[i],
names_cp[i],
atol,
rtol,
rmse_tol,
is_fp8,
)
elif qkv_format == "thd":
compare_and_assert(
t, tensors_cp[i], names_no_cp[i], names_cp[i], atol, rtol, rmse_tol, is_fp8
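Aside on why the checks above compare two chunks per tensor: they rely on the load-balanced context-parallel split the test assumes, where each sequence is cut into 2 * world_size chunks and rank r is expected to own chunks r and 2 * world_size - 1 - r (the seq_idx used above). A small self-contained sketch of that split, with illustrative shapes rather than the test's own:

import torch

world_size, rank = 4, 1
# chunks this rank is assumed to own under the load-balanced split
seq_idx = torch.tensor([rank, 2 * world_size - 1 - rank])

x = torch.randn(2, 1024, 16, 64)  # a full (b, s, h, d) no-CP tensor
chunks = x.view(2, 2 * world_size, 1024 // (2 * world_size), 16, 64)
x_rank = chunks.index_select(1, seq_idx)  # this rank's two sequence chunks
print(x_rank.shape)  # torch.Size([2, 2, 128, 16, 64])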
9 changes: 6 additions & 3 deletions tests/pytorch/attention/test_attention.py
@@ -515,7 +515,7 @@ def test_dpa_mask(dtype, model_configs, model):

model_configs_bias = {
# test: ModelConfig(b, sq, hq, dqk)
"bias_1_0": ModelConfig(4, 128, 16, 64, attn_bias_type="post_scale_bias"),
"bias_1_0": ModelConfig(4, 128, 16, 64, attn_bias_type="post_scale_bias", bias_shape="111s"),
"bias_1_1": ModelConfig(2, 128, 16, 64, max_seqlen_kv=256, attn_bias_type="post_scale_bias"),
"bias_1_2": ModelConfig(4, 2048, 24, 128, attn_bias_type="post_scale_bias"),
"bias_1_3": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_bias_type="post_scale_bias"),
@@ -1131,11 +1131,14 @@ def _run_dot_product_attention(
bias = None
if config.attn_bias_type == "post_scale_bias":
shape = "_".join(config.bias_shape)
# For 1hss, 11ss, b1ss, bhss (shape mapping sketched after this file's diff)
shape_cache = shape
shape = shape.replace("_s_s", "_sq_skv")
if shape == shape_cache:
# For 111s
shape = shape.replace("_1_s", "_1_skv")
tensor_shape = [dim_to_num[j] for j in shape.split("_")]
bias = torch.randn(tensor_shape, dtype=dtype, device="cuda")
if config.bias_shape != "1hss":
bias.requires_grad = False

# Create RNG
_DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker()
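For reference, a stand-alone sketch of the bias_shape-to-tensor-shape mapping used in the hunk above; bias_tensor_shape is a hypothetical helper and the dim_to_num values are illustrative placeholders, while the real mapping lives in _run_dot_product_attention:

dim_to_num = {"1": 1, "b": 4, "h": 16, "sq": 128, "skv": 256}  # illustrative values only

def bias_tensor_shape(bias_shape: str) -> list:
    shape = "_".join(bias_shape)  # e.g. "b1ss" -> "b_1_s_s"
    expanded = shape.replace("_s_s", "_sq_skv")  # 1hss/11ss/b1ss/bhss: trailing dims are sq, skv
    if expanded == shape:
        # "111s": the single s is the KV sequence length
        expanded = expanded.replace("_1_s", "_1_skv")
    return [dim_to_num[tok] for tok in expanded.split("_")]

print(bias_tensor_shape("b1ss"))  # [4, 1, 128, 256]
print(bias_tensor_shape("111s"))  # [1, 1, 1, 256]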