Skip to content

Conversation

@KshitijLakhani
Copy link
Collaborator

@KshitijLakhani KshitijLakhani commented Dec 20, 2025

Description

TE common was not plumbing attention vector bias dimensions correctly to cuDNN.
Instead of using shape from Bias, i.e. [bias_sq, bias_skv] it was using [sq, skv] thereby passing larger than required dims. Using the reproducer : https://github.com/cyanguwa/TransformerEngine/tree/test_111s for bias [1,1,1,s] it can be seen in the cuDNN FE logs that prior to this PR the bias dims passed onto cuDNN from TE were
{"data_type":null,"dim":[1,1,128,128],"is_pass_by_value":false,"is_virtual":false,"name":"bias","pass_by_value":null,"reordering_type":"NONE","stride":[16384,16384,128,1],"uid":0,"uid_assigned":false},
and after this PR they are:
"bias":{"data_type":null,"dim":[1,1,1,128],"is_pass_by_value":false,"is_virtual":false,"name":"bias","pass_by_value":null,"reordering_type":"NONE","stride":[128,128,128,1],"uid":0,"uid_assigned":false},

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • Passing bias_sq and bias_skv to fused_attn_arbitrary_seqlen_fwd_impl() and fused_attn_arbitrary_seqlen_bwd_impl()
  • Adding new entries for bias_sq and bias_skv in FADescriptor_v1
  • Correct the bias passed to the MHA cuDNN graph to use bias_sq and bias_skv instead of s_q and s_kv

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • [] I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • [] New and existing unit tests pass locally with my changes

@KshitijLakhani
Copy link
Collaborator Author

/te-ci pytorch L0 L1

@KshitijLakhani KshitijLakhani force-pushed the klakhani/fix/bias-shape branch from 200fd98 to 8da3252 Compare December 22, 2025 18:21
@KshitijLakhani KshitijLakhani marked this pull request as ready for review December 22, 2025 18:24
@greptile-apps
Copy link
Contributor

greptile-apps bot commented Dec 22, 2025

Greptile Summary

This PR fixes a bug where TransformerEngine was passing incorrect bias tensor dimensions to cuDNN. Previously, TE used the full sequence dimensions [s_q, s_kv] instead of the actual bias tensor dimensions [bias_sq, bias_skv], which caused issues with broadcasted bias shapes like [1,1,1,s].

Key Changes:

  • Extracts actual bias dimensions from input tensors (input_Bias in forward, output_dBias in backward) and passes them to cuDNN graph creation
  • Adds bias_sq and bias_skv fields to FADescriptor_v1 struct for proper execution plan caching
  • Enables bias gradient computation for non-1hss shapes by removing the logic that disabled FusedAttention for these cases
  • Adds comprehensive test coverage for 111s bias shape with context parallelism testing
  • Properly handles None bias cases in gradient collection and comparison logic

The fix ensures cuDNN receives correct bias dimensions (e.g., [1,1,1,128] instead of [1,1,128,128] for 111s shape), enabling proper broadcasting behavior.

Confidence Score: 5/5

  • This PR is safe to merge with no identified issues
  • The implementation correctly addresses the bias dimension bug with proper tensor shape extraction at the right points (forward uses input_Bias, backward uses output_dBias for consistency). Previous review concerns about shape consistency have been addressed. The changes are well-tested with proper None handling, and the logic is straightforward without introducing new edge cases.
  • No files require special attention

Important Files Changed

Filename Overview
transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu Extracts bias dimensions (bias_sq, bias_skv) from input_Bias tensor and correctly passes them to cuDNN graph creation, fixing the dimension mismatch issue
transformer_engine/common/fused_attn/utils.h Adds bias_sq and bias_skv fields to FADescriptor_v1 struct and includes them in comparison operator for proper cache key generation
transformer_engine/pytorch/attention/dot_product_attention/utils.py Removes logic that disabled FusedAttention for bias gradient computation in non-1hss shapes, enabling support for shapes like 111s with backward pass

Sequence Diagram

sequenceDiagram
    participant PyT as PyTorch Layer
    participant Utils as utils.py
    participant F16Fwd as fused_attn_f16_fwd
    participant F16Bwd as fused_attn_f16_bwd
    participant cuDNN as cuDNN Graph

    Note over PyT,cuDNN: Forward Pass
    PyT->>Utils: check attention backend
    Utils->>Utils: Enable FusedAttention for bias shapes
    PyT->>F16Fwd: fused_attn_arbitrary_seqlen_fwd
    F16Fwd->>F16Fwd: Extract bias_b, bias_h, bias_sq, bias_skv<br/>from input_Bias->data.shape[0-3]
    F16Fwd->>cuDNN: Create bias tensor with dims<br/>[bias_b, bias_h, bias_sq, bias_skv]
    Note right of cuDNN: Uses actual bias tensor dims<br/>instead of [b, h, s_q, s_kv]
    cuDNN-->>F16Fwd: Forward result
    F16Fwd-->>PyT: Output

    Note over PyT,cuDNN: Backward Pass
    PyT->>F16Bwd: fused_attn_arbitrary_seqlen_bwd
    F16Bwd->>F16Bwd: Extract bias_b, bias_h, bias_sq, bias_skv<br/>from output_dBias->data.shape[0-3]
    F16Bwd->>cuDNN: Create bias/dBias tensors with dims<br/>[bias_b, bias_h, bias_sq, bias_skv]
    Note right of cuDNN: Consistent bias dims in forward/backward
    cuDNN-->>F16Bwd: Gradients (dQ, dK, dV, dBias)
    F16Bwd-->>PyT: Gradients
Loading

@greptile-apps
Copy link
Contributor

greptile-apps bot commented Dec 22, 2025

Greptile's behavior is changing!

From now on, if a review finishes with no comments, we will not post an additional "statistics" comment to confirm that our review found nothing to comment on. However, you can confirm that we reviewed your changes in the status check section.

This feature can be toggled off in your Code Review Settings by deselecting "Create a status check for each PR".

@KshitijLakhani KshitijLakhani changed the title Plumbing correct bias dims from TE to cudnn [PyT] Plumbing correct bias dims from TE to cudnn Dec 22, 2025
@KshitijLakhani KshitijLakhani added bug Something isn't working pytorch labels Dec 22, 2025
@cyanguwa
Copy link
Collaborator

Looks good - please pick the 111s test from my branch as well. Thanks!

Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Greptile Overview

Greptile Summary

Fixes bias dimension handling in fused attention by plumbing actual bias tensor dimensions (bias_sq, bias_skv) from input tensors through to cuDNN, replacing the previous incorrect usage of query/key sequence lengths (s_q, s_kv). This resolves dimension mismatches for broadcasted bias shapes like [1,1,1,s] where the bias dimensions are smaller than the attention matrix dimensions. The fix enables gradient computation for non-1hss bias shapes by removing the backward pass restriction in the Python layer.

Confidence Score: 4/5

  • Safe to merge after addressing minor consistency concern in backward pass dimension extraction
  • The core fix correctly addresses the bias dimension bug by extracting actual tensor shapes instead of using sequence lengths. The implementation is consistent across forward pass, backward pass, and FP8 paths. Test coverage has been expanded to validate the fix. One minor style issue: backward pass extracts bias_b/bias_h from output_dBias but bias_sq/bias_skv from input_Bias, creating potential inconsistency if shapes don't match, though this is unlikely in practice.
  • transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu for dimension extraction consistency in backward pass

Important Files Changed

File Analysis

Filename Score Overview
transformer_engine/common/fused_attn/utils.h 5/5 Adds bias_sq and bias_skv fields to FADescriptor_v1 struct and updates comparison operator
transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu 4/5 Updates fwd/bwd implementations to extract and use actual bias dimensions from input tensors instead of query/key sequence lengths
transformer_engine/pytorch/attention/dot_product_attention/utils.py 4/5 Removes restriction preventing bias gradient computation for non-1hss bias shapes, enabling backward pass support

Sequence Diagram

sequenceDiagram
    participant Py as Python Layer
    participant TE as TE Common (CUDA)
    participant cuDNN as cuDNN Backend
    
    Note over Py,cuDNN: Bias Dimension Propagation Fix
    
    Py->>TE: Pass bias tensor [b, h, bias_sq, bias_skv]
    Note over TE: Extract actual bias dims<br/>bias_sq = input_Bias->shape[2]<br/>bias_skv = input_Bias->shape[3]
    
    TE->>TE: Store in FADescriptor_v1<br/>(bias_sq, bias_skv)
    
    alt Before Fix
        Note over TE: Used s_q, s_kv incorrectly<br/>(e.g., [1,1,128,128] for [1,1,1,128])
    end
    
    alt After Fix
        Note over TE: Uses bias_sq, bias_skv correctly<br/>(e.g., [1,1,1,128] for [1,1,1,128])
    end
    
    TE->>cuDNN: Create bias tensor with<br/>dim={bias_b, bias_h, bias_sq, bias_skv}
    TE->>cuDNN: Create dBias tensor with same dims
    
    cuDNN->>TE: Compute attention + gradients
    TE->>Py: Return output with correct bias gradients
Loading

Comment on lines 1245 to 1248
bias_b = output_dBias->data.shape[0];
bias_h = output_dBias->data.shape[1];
bias_sq = input_Bias->data.shape[2];
bias_skv = input_Bias->data.shape[3];
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bias dimensions are sourced from different tensors: bias_b and bias_h from output_dBias, while bias_sq and bias_skv from input_Bias. This assumes both tensors have matching shapes. Consider extracting all dimensions from the same tensor (preferably input_Bias for consistency with forward pass) or adding a validation check that shapes match.

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Addressed in SHA 143ede5

Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Greptile Overview

Greptile Summary

Fixes bias dimension plumbing from TransformerEngine to cuDNN by passing actual bias tensor dimensions (bias_sq, bias_skv) instead of sequence dimensions (s_q, s_kv). This resolves incorrect bias shape information being sent to cuDNN, particularly noticeable for bias shapes like [1,1,1,s] where the bias sequence dimensions differ from query/key/value sequence lengths. The fix enables cuDNN backend support for bias gradient computation in previously unsupported shapes.

Confidence Score: 5/5

  • Safe to merge - correct bug fix with comprehensive test coverage and no breaking changes
  • This PR correctly fixes the bias dimension plumbing issue where TE was incorrectly passing sequence dimensions instead of actual bias dimensions to cuDNN. The fix is well-implemented across all affected code paths (F16 and FP8), properly extracts bias dimensions from input tensors, and includes comprehensive test coverage. No functional issues or edge cases were identified.
  • No files require special attention

Important Files Changed

File Analysis

Filename Score Overview
transformer_engine/common/fused_attn/utils.h 5/5 Added bias_sq and bias_skv fields to FADescriptor_v1 struct and updated comparison operator
transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu 5/5 Updated forward and backward implementations to extract and pass correct bias dimensions from input tensors to cuDNN
transformer_engine/pytorch/attention/dot_product_attention/utils.py 5/5 Removed restriction that disabled FusedAttention for bias gradients in non-1hss shapes, enabling cuDNN backend for these cases

Sequence Diagram

sequenceDiagram
    participant PyTorch as PyTorch Layer
    participant Utils as utils.py
    participant F16Impl as fused_attn_f16<br/>arbitrary_seqlen.cu
    participant Descriptor as FADescriptor_v1
    participant cuDNN as cuDNN FE Graph

    Note over PyTorch,cuDNN: Forward Pass with Bias [1,1,1,s]
    
    PyTorch->>Utils: get_attention_backend()<br/>check bias support
    Utils->>Utils: Enable cuDNN for<br/>bias gradient
    PyTorch->>F16Impl: fused_attn_arbitrary_seqlen_fwd()<br/>with input_Bias tensor
    F16Impl->>F16Impl: Extract bias dimensions:<br/>bias_sq = input_Bias.shape[2]<br/>bias_skv = input_Bias.shape[3]
    F16Impl->>Descriptor: Create FADescriptor_v1<br/>with bias_sq, bias_skv
    F16Impl->>cuDNN: Create bias tensor with<br/>dim=[bias_b, bias_h, bias_sq, bias_skv]
    Note over cuDNN: Correct dims [1,1,1,s]<br/>instead of [1,1,s,s]
    cuDNN-->>F16Impl: Execute attention
    F16Impl-->>PyTorch: Return output
Loading

@KshitijLakhani
Copy link
Collaborator Author

/te-ci pytorch L0 L1

Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Greptile Overview

Greptile Summary

This PR fixes a bug where TransformerEngine was incorrectly passing attention bias dimensions to cuDNN. Instead of using the actual bias tensor dimensions [bias_sq, bias_skv], it was using the full sequence dimensions [s_q, s_kv], which could be larger than the bias tensor.

Major Changes

  • Core Fix: Extract and pass actual bias dimensions (bias_sq, bias_skv) from the bias tensor shape throughout the call chain to cuDNN
  • Struct Update: Added bias_sq and bias_skv fields to FADescriptor_v1 for proper caching
  • Test Enhancement: Added bias gradient tracking and comparison in context parallelism tests
  • Backend Selection: Removed incorrect logic that disabled FusedAttention for non-1hss bias shapes when gradients weren't required

Issues Found

  • Critical Bug in Tests: run_attention_with_cp.py attempts to access bias.grad when bias is None (lines 342, 438), causing AttributeError for "no_bias" and "alibi" test cases

Confidence Score: 3/5

  • This PR fixes an important bug in bias dimension handling but introduces critical test failures
  • The core fix correctly addresses the bias dimension bug and is well-implemented across the C++/CUDA codebase. However, the test changes contain logic errors that will cause AttributeError when running tests with "no_bias" or "alibi" configurations, preventing proper validation of the fix.
  • Pay close attention to tests/pytorch/attention/run_attention_with_cp.py which has critical bugs on lines 342 and 438

Important Files Changed

File Analysis

Filename Score Overview
transformer_engine/common/fused_attn/utils.h 5/5 Added bias_sq and bias_skv fields to FADescriptor_v1 struct and updated the comparison operator. Changes are straightforward and correctly implemented.
transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu 5/5 Correctly extracts bias_sq and bias_skv from input_Bias->data.shape and passes them through the call chain to cuDNN. Bias tensor dimensions and strides are properly updated to use actual bias dimensions instead of sequence lengths.
tests/pytorch/attention/run_attention_with_cp.py 2/5 Adds bias gradient tracking and comparison logic for context parallelism tests. Contains critical bugs where bias.grad and bias_.grad are accessed when bias is None, causing AttributeError. Also adds proper reshaping logic for dbias comparison.

Sequence Diagram

sequenceDiagram
    participant Python as Python Layer<br/>(utils.py)
    participant ArbitraryFwd as fused_attn_arbitrary_seqlen_fwd<br/>(C++ wrapper)
    participant ArbitraryFwdImpl as fused_attn_arbitrary_seqlen_fwd_impl<br/>(C++ implementation)
    participant cuDNN as cuDNN Graph
    
    Note over Python,cuDNN: Forward Pass with Bias [1, 1, 1, s_kv]
    
    Python->>ArbitraryFwd: input_Bias tensor with shape [b, h, sq, skv]
    ArbitraryFwd->>ArbitraryFwd: Extract bias_b = input_Bias->shape[0]<br/>bias_h = input_Bias->shape[1]<br/>bias_sq = input_Bias->shape[2]<br/>bias_skv = input_Bias->shape[3]
    ArbitraryFwd->>ArbitraryFwdImpl: Pass bias_b, bias_h, bias_sq, bias_skv
    ArbitraryFwdImpl->>ArbitraryFwdImpl: Store in FADescriptor_v1 for caching
    ArbitraryFwdImpl->>cuDNN: Create bias tensor with dimensions<br/>[bias_b, bias_h, bias_sq, bias_skv]<br/>Previously used [bias_b, bias_h, s_q, s_kv] ❌
    Note over cuDNN: Now receives correct bias dimensions ✓
    
    Note over Python,cuDNN: Backward Pass
    ArbitraryFwd->>ArbitraryFwd: Extract from output_dBias->shape
    ArbitraryFwd->>ArbitraryFwdImpl: Pass bias_sq, bias_skv
    ArbitraryFwdImpl->>cuDNN: Set dBias dimensions to [bias_b, bias_h, bias_sq, bias_skv]
    Note over cuDNN: dBias only computed if (bias_b==1 && bias_h==h)
Loading

else:
out.backward(dout)
dq, dk, dv = q.grad, k.grad, v.grad
dq, dk, dv, dbias = q.grad, k.grad, v.grad, bias.grad
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

bias is None when attn_bias_type is "no_bias" or "alibi" (line 312), so bias.grad will raise AttributeError

Suggested change
dq, dk, dv, dbias = q.grad, k.grad, v.grad, bias.grad
dq, dk, dv, dbias = q.grad, k.grad, v.grad, bias.grad if bias is not None else None

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed

else:
out_.backward(dout_)
dq_, dk_, dv_ = q_.grad, k_.grad, v_.grad
dq_, dk_, dv_, dbias_ = q_.grad, k_.grad, v_.grad, bias_.grad
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

bias_ is None when bias is None (line 355), so bias_.grad will raise AttributeError

Suggested change
dq_, dk_, dv_, dbias_ = q_.grad, k_.grad, v_.grad, bias_.grad
dq_, dk_, dv_, dbias_ = q_.grad, k_.grad, v_.grad, bias_.grad if bias_ is not None else None

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed

dbias.shape[2] // (2 * world_size),
dbias.shape[3],
)
# bias has fixed axis (2) as dbias shape: (1, 1, max_seqlen_q, max_seqlen_kv)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think our CP implementation (after your C changes) should support all bias shapes, not just 111s. I also think your reshaping here should work for all shapes. Could you run the tests to confirm?

KshitijLakhani and others added 7 commits January 21, 2026 19:41
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
@KshitijLakhani KshitijLakhani force-pushed the klakhani/fix/bias-shape branch from 11c7107 to de3011e Compare January 21, 2026 19:41
@KshitijLakhani
Copy link
Collaborator Author

/te-ci pytorch L0 L1

…s, b1ss and bhss

Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
@KshitijLakhani
Copy link
Collaborator Author

/te-ci pytorch L0 L1

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants