-
Notifications
You must be signed in to change notification settings - Fork 612
[TORCH] Added flex_attention hop function #4366
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
keshavvinayak01
wants to merge
24
commits into
llvm:main
Choose a base branch
from
keshavvinayak01:keshavvinayak01/torch-aten-flex_attention
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
[TORCH] Added flex_attention hop function #4366
keshavvinayak01
wants to merge
24
commits into
llvm:main
from
keshavvinayak01:keshavvinayak01/torch-aten-flex_attention
Conversation
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Signed-off-by: Keshav Vinayak Jha <[email protected]>
Signed-off-by: Keshav Vinayak Jha <[email protected]>
Signed-off-by: Keshav Vinayak Jha <[email protected]>
Signed-off-by: Keshav Vinayak Jha <[email protected]>
Signed-off-by: Keshav Vinayak Jha <[email protected]>
Signed-off-by: Keshav Vinayak Jha <[email protected]>
Change 1: Converts builtin tensors → Torch tensors when entering the loop body Change 2: Ensures Torch tensors → builtin tensors when yielding back to the loop condition Without these fixes, the conversion would fail when while loops carry tensor values Also modified basic_test.py FILECHECK statements. Signed-off-by: Keshav Vinayak Jha <[email protected]>
Signed-off-by: Keshav Vinayak Jha <[email protected]>
Signed-off-by: Keshav Vinayak Jha <[email protected]>
Signed-off-by: Keshav Vinayak Jha <[email protected]>
Signed-off-by: Keshav Vinayak Jha <[email protected]>
Signed-off-by: Keshav Vinayak Jha <[email protected]>
Signed-off-by: Keshav Vinayak Jha <[email protected]>
1. Better documentation for AtenFlexAttentionOp 2. Function referece added as attributes to aten.flex_attention 3. Updates to _import_hop_flex_attention reflecting latest changes of module import. 4. Removed discardable attributes; scored_mod_fn and mask_mod_fn added as optionalAttr Signed-off-by: Keshav Vinayak Jha <[email protected]>
Remove note about method usage for HOPs.
Removed TODO note for grouped query attention support in the docstring and comments.
095cb61 to
5e024f6
Compare
Signed-off-by: Keshav Vinayak Jha <[email protected]>
Signed-off-by: Keshav Vinayak Jha <[email protected]>
Signed-off-by: Keshav Vinayak Jha <[email protected]>
Signed-off-by: Keshav Vinayak Jha <[email protected]>
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Description
Torch_AtenFlexAttentionOpwith 6 operands (query, key, value, scale, enable_gqa, return_lse) and 2 optional attributes (score_mod_fn, mask_mod_fn) for function references._import_hop_flex_attention) correctly extracts score/mask modification functions fromget_attrnodes using module IDs, following the while_loop HOP pattern.kernel_optionsperformance tuning parameters.