-
Notifications
You must be signed in to change notification settings - Fork 612
[PyT] Plumbing correct bias dims from TE to cudnn #2537
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
[PyT] Plumbing correct bias dims from TE to cudnn #2537
Conversation
|
/te-ci pytorch L0 L1 |
200fd98 to
8da3252
Compare
Greptile SummaryThis PR fixes a bug where TransformerEngine was passing incorrect bias tensor dimensions to cuDNN. Previously, TE used the full sequence dimensions Key Changes:
The fix ensures cuDNN receives correct bias dimensions (e.g., Confidence Score: 5/5
Important Files Changed
Sequence DiagramsequenceDiagram
participant PyT as PyTorch Layer
participant Utils as utils.py
participant F16Fwd as fused_attn_f16_fwd
participant F16Bwd as fused_attn_f16_bwd
participant cuDNN as cuDNN Graph
Note over PyT,cuDNN: Forward Pass
PyT->>Utils: check attention backend
Utils->>Utils: Enable FusedAttention for bias shapes
PyT->>F16Fwd: fused_attn_arbitrary_seqlen_fwd
F16Fwd->>F16Fwd: Extract bias_b, bias_h, bias_sq, bias_skv<br/>from input_Bias->data.shape[0-3]
F16Fwd->>cuDNN: Create bias tensor with dims<br/>[bias_b, bias_h, bias_sq, bias_skv]
Note right of cuDNN: Uses actual bias tensor dims<br/>instead of [b, h, s_q, s_kv]
cuDNN-->>F16Fwd: Forward result
F16Fwd-->>PyT: Output
Note over PyT,cuDNN: Backward Pass
PyT->>F16Bwd: fused_attn_arbitrary_seqlen_bwd
F16Bwd->>F16Bwd: Extract bias_b, bias_h, bias_sq, bias_skv<br/>from output_dBias->data.shape[0-3]
F16Bwd->>cuDNN: Create bias/dBias tensors with dims<br/>[bias_b, bias_h, bias_sq, bias_skv]
Note right of cuDNN: Consistent bias dims in forward/backward
cuDNN-->>F16Bwd: Gradients (dQ, dK, dV, dBias)
F16Bwd-->>PyT: Gradients
|
Greptile's behavior is changing!From now on, if a review finishes with no comments, we will not post an additional "statistics" comment to confirm that our review found nothing to comment on. However, you can confirm that we reviewed your changes in the status check section. This feature can be toggled off in your Code Review Settings by deselecting "Create a status check for each PR". |
|
Looks good - please pick the 111s test from my branch as well. Thanks! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Greptile Overview
Greptile Summary
Fixes bias dimension handling in fused attention by plumbing actual bias tensor dimensions (bias_sq, bias_skv) from input tensors through to cuDNN, replacing the previous incorrect usage of query/key sequence lengths (s_q, s_kv). This resolves dimension mismatches for broadcasted bias shapes like [1,1,1,s] where the bias dimensions are smaller than the attention matrix dimensions. The fix enables gradient computation for non-1hss bias shapes by removing the backward pass restriction in the Python layer.
Confidence Score: 4/5
- Safe to merge after addressing minor consistency concern in backward pass dimension extraction
- The core fix correctly addresses the bias dimension bug by extracting actual tensor shapes instead of using sequence lengths. The implementation is consistent across forward pass, backward pass, and FP8 paths. Test coverage has been expanded to validate the fix. One minor style issue: backward pass extracts bias_b/bias_h from output_dBias but bias_sq/bias_skv from input_Bias, creating potential inconsistency if shapes don't match, though this is unlikely in practice.
- transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu for dimension extraction consistency in backward pass
Important Files Changed
File Analysis
| Filename | Score | Overview |
|---|---|---|
| transformer_engine/common/fused_attn/utils.h | 5/5 | Adds bias_sq and bias_skv fields to FADescriptor_v1 struct and updates comparison operator |
| transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu | 4/5 | Updates fwd/bwd implementations to extract and use actual bias dimensions from input tensors instead of query/key sequence lengths |
| transformer_engine/pytorch/attention/dot_product_attention/utils.py | 4/5 | Removes restriction preventing bias gradient computation for non-1hss bias shapes, enabling backward pass support |
Sequence Diagram
sequenceDiagram
participant Py as Python Layer
participant TE as TE Common (CUDA)
participant cuDNN as cuDNN Backend
Note over Py,cuDNN: Bias Dimension Propagation Fix
Py->>TE: Pass bias tensor [b, h, bias_sq, bias_skv]
Note over TE: Extract actual bias dims<br/>bias_sq = input_Bias->shape[2]<br/>bias_skv = input_Bias->shape[3]
TE->>TE: Store in FADescriptor_v1<br/>(bias_sq, bias_skv)
alt Before Fix
Note over TE: Used s_q, s_kv incorrectly<br/>(e.g., [1,1,128,128] for [1,1,1,128])
end
alt After Fix
Note over TE: Uses bias_sq, bias_skv correctly<br/>(e.g., [1,1,1,128] for [1,1,1,128])
end
TE->>cuDNN: Create bias tensor with<br/>dim={bias_b, bias_h, bias_sq, bias_skv}
TE->>cuDNN: Create dBias tensor with same dims
cuDNN->>TE: Compute attention + gradients
TE->>Py: Return output with correct bias gradients
| bias_b = output_dBias->data.shape[0]; | ||
| bias_h = output_dBias->data.shape[1]; | ||
| bias_sq = input_Bias->data.shape[2]; | ||
| bias_skv = input_Bias->data.shape[3]; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Bias dimensions are sourced from different tensors: bias_b and bias_h from output_dBias, while bias_sq and bias_skv from input_Bias. This assumes both tensors have matching shapes. Consider extracting all dimensions from the same tensor (preferably input_Bias for consistency with forward pass) or adding a validation check that shapes match.
Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Addressed in SHA 143ede5
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Greptile Overview
Greptile Summary
Fixes bias dimension plumbing from TransformerEngine to cuDNN by passing actual bias tensor dimensions (bias_sq, bias_skv) instead of sequence dimensions (s_q, s_kv). This resolves incorrect bias shape information being sent to cuDNN, particularly noticeable for bias shapes like [1,1,1,s] where the bias sequence dimensions differ from query/key/value sequence lengths. The fix enables cuDNN backend support for bias gradient computation in previously unsupported shapes.
Confidence Score: 5/5
- Safe to merge - correct bug fix with comprehensive test coverage and no breaking changes
- This PR correctly fixes the bias dimension plumbing issue where TE was incorrectly passing sequence dimensions instead of actual bias dimensions to cuDNN. The fix is well-implemented across all affected code paths (F16 and FP8), properly extracts bias dimensions from input tensors, and includes comprehensive test coverage. No functional issues or edge cases were identified.
- No files require special attention
Important Files Changed
File Analysis
| Filename | Score | Overview |
|---|---|---|
| transformer_engine/common/fused_attn/utils.h | 5/5 | Added bias_sq and bias_skv fields to FADescriptor_v1 struct and updated comparison operator |
| transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu | 5/5 | Updated forward and backward implementations to extract and pass correct bias dimensions from input tensors to cuDNN |
| transformer_engine/pytorch/attention/dot_product_attention/utils.py | 5/5 | Removed restriction that disabled FusedAttention for bias gradients in non-1hss shapes, enabling cuDNN backend for these cases |
Sequence Diagram
sequenceDiagram
participant PyTorch as PyTorch Layer
participant Utils as utils.py
participant F16Impl as fused_attn_f16<br/>arbitrary_seqlen.cu
participant Descriptor as FADescriptor_v1
participant cuDNN as cuDNN FE Graph
Note over PyTorch,cuDNN: Forward Pass with Bias [1,1,1,s]
PyTorch->>Utils: get_attention_backend()<br/>check bias support
Utils->>Utils: Enable cuDNN for<br/>bias gradient
PyTorch->>F16Impl: fused_attn_arbitrary_seqlen_fwd()<br/>with input_Bias tensor
F16Impl->>F16Impl: Extract bias dimensions:<br/>bias_sq = input_Bias.shape[2]<br/>bias_skv = input_Bias.shape[3]
F16Impl->>Descriptor: Create FADescriptor_v1<br/>with bias_sq, bias_skv
F16Impl->>cuDNN: Create bias tensor with<br/>dim=[bias_b, bias_h, bias_sq, bias_skv]
Note over cuDNN: Correct dims [1,1,1,s]<br/>instead of [1,1,s,s]
cuDNN-->>F16Impl: Execute attention
F16Impl-->>PyTorch: Return output
|
/te-ci pytorch L0 L1 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Greptile Overview
Greptile Summary
This PR fixes a bug where TransformerEngine was incorrectly passing attention bias dimensions to cuDNN. Instead of using the actual bias tensor dimensions [bias_sq, bias_skv], it was using the full sequence dimensions [s_q, s_kv], which could be larger than the bias tensor.
Major Changes
- Core Fix: Extract and pass actual bias dimensions (
bias_sq,bias_skv) from the bias tensor shape throughout the call chain to cuDNN - Struct Update: Added
bias_sqandbias_skvfields toFADescriptor_v1for proper caching - Test Enhancement: Added bias gradient tracking and comparison in context parallelism tests
- Backend Selection: Removed incorrect logic that disabled FusedAttention for non-1hss bias shapes when gradients weren't required
Issues Found
- Critical Bug in Tests:
run_attention_with_cp.pyattempts to accessbias.gradwhenbiasisNone(lines 342, 438), causingAttributeErrorfor "no_bias" and "alibi" test cases
Confidence Score: 3/5
- This PR fixes an important bug in bias dimension handling but introduces critical test failures
- The core fix correctly addresses the bias dimension bug and is well-implemented across the C++/CUDA codebase. However, the test changes contain logic errors that will cause
AttributeErrorwhen running tests with "no_bias" or "alibi" configurations, preventing proper validation of the fix. - Pay close attention to
tests/pytorch/attention/run_attention_with_cp.pywhich has critical bugs on lines 342 and 438
Important Files Changed
File Analysis
| Filename | Score | Overview |
|---|---|---|
| transformer_engine/common/fused_attn/utils.h | 5/5 | Added bias_sq and bias_skv fields to FADescriptor_v1 struct and updated the comparison operator. Changes are straightforward and correctly implemented. |
| transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu | 5/5 | Correctly extracts bias_sq and bias_skv from input_Bias->data.shape and passes them through the call chain to cuDNN. Bias tensor dimensions and strides are properly updated to use actual bias dimensions instead of sequence lengths. |
| tests/pytorch/attention/run_attention_with_cp.py | 2/5 | Adds bias gradient tracking and comparison logic for context parallelism tests. Contains critical bugs where bias.grad and bias_.grad are accessed when bias is None, causing AttributeError. Also adds proper reshaping logic for dbias comparison. |
Sequence Diagram
sequenceDiagram
participant Python as Python Layer<br/>(utils.py)
participant ArbitraryFwd as fused_attn_arbitrary_seqlen_fwd<br/>(C++ wrapper)
participant ArbitraryFwdImpl as fused_attn_arbitrary_seqlen_fwd_impl<br/>(C++ implementation)
participant cuDNN as cuDNN Graph
Note over Python,cuDNN: Forward Pass with Bias [1, 1, 1, s_kv]
Python->>ArbitraryFwd: input_Bias tensor with shape [b, h, sq, skv]
ArbitraryFwd->>ArbitraryFwd: Extract bias_b = input_Bias->shape[0]<br/>bias_h = input_Bias->shape[1]<br/>bias_sq = input_Bias->shape[2]<br/>bias_skv = input_Bias->shape[3]
ArbitraryFwd->>ArbitraryFwdImpl: Pass bias_b, bias_h, bias_sq, bias_skv
ArbitraryFwdImpl->>ArbitraryFwdImpl: Store in FADescriptor_v1 for caching
ArbitraryFwdImpl->>cuDNN: Create bias tensor with dimensions<br/>[bias_b, bias_h, bias_sq, bias_skv]<br/>Previously used [bias_b, bias_h, s_q, s_kv] ❌
Note over cuDNN: Now receives correct bias dimensions ✓
Note over Python,cuDNN: Backward Pass
ArbitraryFwd->>ArbitraryFwd: Extract from output_dBias->shape
ArbitraryFwd->>ArbitraryFwdImpl: Pass bias_sq, bias_skv
ArbitraryFwdImpl->>cuDNN: Set dBias dimensions to [bias_b, bias_h, bias_sq, bias_skv]
Note over cuDNN: dBias only computed if (bias_b==1 && bias_h==h)
| else: | ||
| out.backward(dout) | ||
| dq, dk, dv = q.grad, k.grad, v.grad | ||
| dq, dk, dv, dbias = q.grad, k.grad, v.grad, bias.grad |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
bias is None when attn_bias_type is "no_bias" or "alibi" (line 312), so bias.grad will raise AttributeError
| dq, dk, dv, dbias = q.grad, k.grad, v.grad, bias.grad | |
| dq, dk, dv, dbias = q.grad, k.grad, v.grad, bias.grad if bias is not None else None |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed
| else: | ||
| out_.backward(dout_) | ||
| dq_, dk_, dv_ = q_.grad, k_.grad, v_.grad | ||
| dq_, dk_, dv_, dbias_ = q_.grad, k_.grad, v_.grad, bias_.grad |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
bias_ is None when bias is None (line 355), so bias_.grad will raise AttributeError
| dq_, dk_, dv_, dbias_ = q_.grad, k_.grad, v_.grad, bias_.grad | |
| dq_, dk_, dv_, dbias_ = q_.grad, k_.grad, v_.grad, bias_.grad if bias_ is not None else None |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed
| dbias.shape[2] // (2 * world_size), | ||
| dbias.shape[3], | ||
| ) | ||
| # bias has fixed axis (2) as dbias shape: (1, 1, max_seqlen_q, max_seqlen_kv) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think our CP implementation (after your C changes) should support all bias shapes, not just 111s. I also think your reshaping here should work for all shapes. Could you run the tests to confirm?
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
for more information, see https://pre-commit.ci
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
for more information, see https://pre-commit.ci
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
11c7107 to
de3011e
Compare
|
/te-ci pytorch L0 L1 |
…s, b1ss and bhss Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
|
/te-ci pytorch L0 L1 |
for more information, see https://pre-commit.ci
Description
TE common was not plumbing attention vector bias dimensions correctly to cuDNN.
Instead of using shape from Bias, i.e.
[bias_sq, bias_skv]it was using[sq, skv]thereby passing larger than required dims. Using the reproducer : https://github.com/cyanguwa/TransformerEngine/tree/test_111s for bias [1,1,1,s] it can be seen in the cuDNN FE logs that prior to this PR the bias dims passed onto cuDNN from TE were{"data_type":null,"dim":[1,1,128,128],"is_pass_by_value":false,"is_virtual":false,"name":"bias","pass_by_value":null,"reordering_type":"NONE","stride":[16384,16384,128,1],"uid":0,"uid_assigned":false},and after this PR they are:
"bias":{"data_type":null,"dim":[1,1,1,128],"is_pass_by_value":false,"is_virtual":false,"name":"bias","pass_by_value":null,"reordering_type":"NONE","stride":[128,128,128,1],"uid":0,"uid_assigned":false},Type of change
Changes
bias_sqandbias_skvtofused_attn_arbitrary_seqlen_fwd_impl()andfused_attn_arbitrary_seqlen_bwd_impl()bias_sqandbias_skvinFADescriptor_v1bias_sqandbias_skvinstead ofs_qands_kvChecklist: