[TRITON] Add attention sink support to Triton MHA kernels #1576
base: main
Conversation
Force-pushed from f677369 to 76bfe95.

Rebased on top of the updated base branch.

@brunomazzottiamd I added some comments, specifically about the threshold for the tests.

Thank you so much for your time reviewing this PR. I'll answer you properly.
Force-pushed from b082964 to 6eccfc1.
@cagrikymk, please let me know if I have answered your questions properly. Feel free to suggest other changes or improvements. Thanks to your review I could decrease the test thresholds. FYI: the agreed deadline for merging this PR is December 15th.

@brunomazzottiamd LGTM!
Force-pushed from 6eccfc1 to e898512.

Rebased on top of the updated base branch.
Motivation
In the gpt-oss attention implementation, each attention head has a learned bias in the denominator of the softmax. This is similar to an attention sink, so we can enable gpt-oss by adding attention sink support to our AITER MHA kernels (both the forward and backward kernels). The target model is gpt-oss-20b.
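For intuition, here is a minimal PyTorch reference of that formulation (an illustrative sketch with hypothetical names, not the Triton kernel): the per-head learned `sink` logit joins the softmax as an extra column but has no matching value row, so it only adds to the denominator.

```python
import torch

def sdpa_with_sink(q, k, v, sink):
    # q, k, v: [batch, heads, seq, head_dim]; sink: [heads] learned logits.
    # Masking/causality is omitted to keep the sketch short.
    scale = q.shape[-1] ** -0.5
    scores = torch.einsum("bhqd,bhkd->bhqk", q, k) * scale
    # Append the per-head sink logit as a "virtual" key column.
    sink_col = sink.view(1, -1, 1, 1).expand(q.shape[0], -1, q.shape[2], 1)
    probs = torch.softmax(torch.cat([scores, sink_col], dim=-1), dim=-1)
    # Drop the sink column: its probability mass attends to nothing, which is
    # equivalent to adding exp(sink) to the softmax denominator.
    return torch.einsum("bhqk,bhkd->bhqd", probs[..., :-1], v)
```

Setting `sink` to `-inf` recovers standard softmax attention, since `exp(-inf) = 0`.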
Technical Details
- … `fp8` data types.
- The proposed changes were tested with `bf16` and `fp32` data types, but they should also work with the `fp16` data type.
- … `-inf`.
- … `fp32` to enable atomic ops.
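Assuming a flash-attention-style online softmax (the usual structure of these MHA kernels), the sink mainly touches the final normalization step. A rough single-head sketch in plain PyTorch (illustrative names, not the actual kernel code):

```python
import torch

def fa_forward_with_sink(q, k, v, sink, block=128):
    # Single head for brevity: q [seq_q, d], k/v [seq_k, d], sink: scalar tensor.
    scale = q.shape[-1] ** -0.5
    m = torch.full((q.shape[0],), float("-inf"))   # running row max
    l = torch.zeros(q.shape[0])                    # running softmax denominator
    acc = torch.zeros_like(q)                      # running numerator (P @ V)
    for s in range(0, k.shape[0], block):
        scores = (q @ k[s:s + block].T) * scale
        m_new = torch.maximum(m, scores.max(dim=-1).values)
        alpha = torch.exp(m - m_new)               # rescale previous partials
        p = torch.exp(scores - m_new[:, None])
        l = l * alpha + p.sum(dim=-1)
        acc = acc * alpha[:, None] + p @ v[s:s + block]
        m = m_new
    # The sink adds exp(sink - m) to the denominator but contributes no values.
    l = l + torch.exp(sink - m)
    return acc / l[:, None]
```

The backward pass additionally needs a gradient with respect to `sink`; reducing that gradient across thread blocks is likely where the `fp32` atomic accumulation mentioned above comes in.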
Test Plan

New tests were added to `op_tests/triton_tests/test_mha.py`: `test_mha_with_sink` and `test_mha_varlen_with_pe`. They cover 192 new cases and test both forward and backward passes.
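For reference, one way to run just these tests locally (assuming a standard pytest setup; the `-k` expression is only an illustration):

```python
# Run the new MHA sink tests; equivalent to invoking pytest from the shell.
import pytest

pytest.main(["op_tests/triton_tests/test_mha.py", "-k", "sink or varlen_with_pe", "-v"])
```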
Test Result

- All tests in `op_tests/triton_tests/test_mha.py` are passing on `gfx942` and `gfx950`.
- The pre-existing tests in `op_tests/triton_tests/test_mha.py` produce the same results as before the sink was added, on both `gfx942` and `gfx950`. So we can conclude that the newly added sink feature didn't break anything that was already working.

Performance Assessment
Target attention shapes:
- Data type: `bf16`.
- … for `thd` layout and 1 for `bshd` layout.

Forward performance on `gfx950`:

Backward performance on `gfx950`:

Conclusion: the attention sink feature doesn't change performance on `gfx950`. I did the same analysis on `gfx942` and reached the same conclusion; I'm not publishing those numbers in the PR for the sake of brevity.

Submission Checklist