Flash decoding round 1; AIMIGRAPHX-242 #4393
Status: Merged

Changes from all commits (38 commits, all authored by bdevorem):
4cd9f64  curr
ed54478  first stabs
ae80823  lse and kernel2
61bd398  env var, clang format
94353e9  test case
45252cf  remove prints, compiler warnings, don't rewrite if LSE
7106627  more tests; way to set splits from tests instead of env var
fc371c6  format
300f5b4  tidy
12bbacf  tidy
172e9ee  tidy
af612a8  formatttt
7ee88df  fix value of
c49e900  use get_ins_param_map
151fdb6  edit documentation and change how num_splits is represented in the class
0c77e30  verify tests, updated comments/errors
2340dc9  format
9c1fd01  one more format issue
da6df82  add fp16 and fp32 tests; fix reshape bug, fix math in second kernel
5c79e00  format
2c589c8  fix expected values in test and more formatting
b74cc18  format
9d8427a  comment
3ee53d8  format
7ea5a1c  Merge branch 'develop' into bdevorem/flash-decoding
cd9f60e  update var name, add back 4D test and update rocmlir commit for 4D test
80465ec  do not apply flash decoding if n is not divisible by g
59c8b75  add rocmlir commit that fixes tosatorock for 5d tensors; AIMIGRAPHX-242
8c55d75  Merge branch 'develop' into bdevorem/flash-decoding
cdb764e  change back rocMLIR commit; might have caused external failures
d182146  flash decoding + input fusion
ec180df  output fusion 3d test
7da4b0e  remove rocmlir commit and comment out tests that need rocmlir fix
088cfda  Merge branch 'develop' into bdevorem/flash-decoding
23c2474  format
4499605  merge conflicts w gqa refactor
4c7b439  format
60b0d2d  format
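Several of the commit messages above ("lse and kernel2", "fix math in second kernel", "do not apply flash decoding if n is not divisible by g") refer to the standard flash-decoding scheme: the softmax reduction axis n is split into g chunks, a first kernel computes an unnormalized partial output plus log-sum-exp (LSE) statistics per chunk, and a second kernel recombines the partials. As a sketch of that recombination math (this is the generally published formulation; the PR's exact kernels live in source diffs not rendered below), let S_s be the chunk of attention scores for split s and V_s the matching rows of V:

$$ m_s = \max_j S_{sj}, \qquad \ell_s = \sum_j e^{S_{sj} - m_s}, \qquad \tilde{O}_s = \sum_j e^{S_{sj} - m_s}\, V_{sj} $$

$$ m = \max_s m_s, \qquad O = \frac{\sum_s e^{m_s - m}\, \tilde{O}_s}{\sum_s e^{m_s - m}\, \ell_s} $$

Each chunk subtracts its own max, so rescaling by e^{m_s - m} brings every partial onto the global max m before the final normalization; this is presumably also why a split count that does not divide n evenly is rejected rather than handled with ragged chunks.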
(Two changed source files have large diffs that are not rendered in this view.)
New test file (66 additions):

/*
 * The MIT License (MIT)
 *
 * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved.
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in
 * all copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 * THE SOFTWARE.
 */

#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>

template <migraphx::shape::type_t DType>
struct test_attention_flash_decoding_3d : verify_program<test_attention_flash_decoding_3d<DType>>
{
    migraphx::program create_program() const
    {
        // 3D shape: [batch, sequence_length, head_dim]
        migraphx::shape s_3d{DType, {1, 256, 256}};

        migraphx::program p1;
        auto* mm = p1.get_main_module();
        auto a   = mm->add_parameter("q", s_3d);
        auto b   = mm->add_parameter("k", s_3d);
        auto b1  = mm->add_parameter("v", s_3d);
        b  = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), b);
        b1 = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), b1);
        auto gemm1 = mm->add_instruction(migraphx::make_op("dot"), a, b);
        auto rmax  = mm->add_instruction(migraphx::make_op("reduce_max", {{"axes", {2}}}), gemm1);
        rmax = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", s_3d.lens()}}),
                                   rmax);
        auto sub  = mm->add_instruction(migraphx::make_op("sub"), gemm1, rmax);
        auto exp  = mm->add_instruction(migraphx::make_op("exp"), sub);
        auto rsum = mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {2}}}), exp);
        rsum = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", s_3d.lens()}}),
                                   rsum);
        auto div   = mm->add_instruction(migraphx::make_op("div"), exp, rsum);
        auto gemm2 = mm->add_instruction(migraphx::make_op("dot"), div, b1);
        mm->add_return({gemm2});
        return p1;
    }
};

// These tests are not run by default currently; the env vars below need to be set:
// MIGRAPHX_FLASH_DECODING_NUM_SPLITS=2  # or another split factor
// MIGRAPHX_MLIR_USE_SPECIFIC_OPS=attention
template struct test_attention_flash_decoding_3d<migraphx::shape::half_type>;
template struct test_attention_flash_decoding_3d<migraphx::shape::bf16_type>;
template struct test_attention_flash_decoding_3d<migraphx::shape::float_type>;
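For reference, the graph above is the plain O = softmax(Q Kᵀ) Vᵀ attention pattern with the max-subtracted (numerically stable) softmax that the flash-decoding pass is meant to match. Below is a hypothetical sketch of what a num_splits = 2 rewrite of the simplified form O = softmax(Q Kᵀ) V can look like, built only from ops that already appear in these tests. It is not the PR's implementation (the actual rewrite pass is in the unrendered source files), and the helper name flash_decoding_sketch is made up for illustration.

// Hypothetical sketch of the split-k ("flash decoding") form with g = 2
// splits over the reduction axis n = 256. Kernel 1 computes per-split
// partials and LSE statistics; kernel 2 recombines them.
#include <migraphx/program.hpp>
#include <migraphx/make_op.hpp>

inline migraphx::program flash_decoding_sketch()
{
    using migraphx::make_op;
    migraphx::shape s{migraphx::shape::float_type, {1, 256, 256}};
    migraphx::program p;
    auto* mm = p.get_main_module();
    auto q  = mm->add_parameter("q", s);  // [batch, m, d]
    auto kt = mm->add_parameter("kt", s); // K pre-transposed: [batch, d, n]
    auto v  = mm->add_parameter("v", s);  // [batch, n, d]

    // "Kernel 1": scores [1, m, n]; split n into g = 2 chunks of 128.
    auto scores = mm->add_instruction(make_op("dot"), q, kt);
    auto s4 = mm->add_instruction(make_op("reshape", {{"dims", {1, 256, 2, 128}}}), scores);
    s4 = mm->add_instruction(make_op("transpose", {{"permutation", {0, 2, 1, 3}}}),
                             s4); // [1, g, m, n/g]
    auto m_s = mm->add_instruction(make_op("reduce_max", {{"axes", {3}}}), s4); // per-split max
    auto m_b =
        mm->add_instruction(make_op("multibroadcast", {{"out_lens", {1, 2, 256, 128}}}), m_s);
    auto e   = mm->add_instruction(make_op("exp"),
                                   mm->add_instruction(make_op("sub"), s4, m_b));
    auto l_s = mm->add_instruction(make_op("reduce_sum", {{"axes", {3}}}), e); // per-split sum
    auto v4  = mm->add_instruction(make_op("reshape", {{"dims", {1, 2, 128, 256}}}), v);
    auto o_s = mm->add_instruction(make_op("dot"), e, v4); // unnormalized partials [1, g, m, d]

    // "Kernel 2": recombine the partials using their LSE statistics.
    auto m_g = mm->add_instruction(make_op("reduce_max", {{"axes", {1}}}), m_s); // global max
    auto m_gb =
        mm->add_instruction(make_op("multibroadcast", {{"out_lens", {1, 2, 256, 1}}}), m_g);
    auto alpha = mm->add_instruction(make_op("exp"),
                                     mm->add_instruction(make_op("sub"), m_s, m_gb));
    auto alpha_b =
        mm->add_instruction(make_op("multibroadcast", {{"out_lens", {1, 2, 256, 256}}}), alpha);
    auto num = mm->add_instruction(make_op("reduce_sum", {{"axes", {1}}}),
                                   mm->add_instruction(make_op("mul"), alpha_b, o_s));
    auto den = mm->add_instruction(make_op("reduce_sum", {{"axes", {1}}}),
                                   mm->add_instruction(make_op("mul"), alpha, l_s));
    den = mm->add_instruction(make_op("multibroadcast", {{"out_lens", {1, 1, 256, 256}}}), den);
    auto out = mm->add_instruction(make_op("div"), num, den);
    out = mm->add_instruction(make_op("reshape", {{"dims", {1, 256, 256}}}), out);
    mm->add_return({out});
    return p;
}

Splitting the reduction axis turns one long softmax into g independent column-chunk softmaxes whose reductions can run in parallel, which matters most when the query length m is small (decoding) and the key length n is large.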
test/verify/test_attention_flash_decoding_3d_input_fusion.cpp (110 additions):
/* The MIT License (MIT) header, identical to the file above. */

#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>

template <migraphx::shape::type_t DType>
struct test_attention_flash_decoding_3d_input_fusion
    : verify_program<test_attention_flash_decoding_3d_input_fusion<DType>>
{
    migraphx::program create_program() const
    {
        // 3D shape: [batch, sequence_length, head_dim]
        migraphx::shape s_3d{DType, {1, 256, 256}};

        migraphx::program p1;
        auto* mm = p1.get_main_module();

        // Input parameters
        auto q_input = mm->add_parameter("q", s_3d);
        auto k_input = mm->add_parameter("k", s_3d);
        auto v_input = mm->add_parameter("v", s_3d);

        // Bias parameters for input fusion
        auto q_bias = mm->add_parameter("q_bias", s_3d);
        auto k_bias = mm->add_parameter("k_bias", s_3d);
        auto v_bias = mm->add_parameter("v_bias", s_3d);

        // Scale parameter (typically 1/sqrt(head_dim))
        migraphx::shape scale_shape{DType, {1}};
        auto scale = mm->add_parameter("scale", scale_shape);

        // Input fusion operations: add bias to Q, K, V
        auto q_with_bias = mm->add_instruction(migraphx::make_op("add"), q_input, q_bias);
        auto k_with_bias = mm->add_instruction(migraphx::make_op("add"), k_input, k_bias);
        auto v_with_bias = mm->add_instruction(migraphx::make_op("add"), v_input, v_bias);

        // Scale Q (common in attention mechanisms)
        scale = mm->add_instruction(
            migraphx::make_op("multibroadcast", {{"out_lens", s_3d.lens()}}), scale);
        auto q_scaled = mm->add_instruction(migraphx::make_op("mul"), q_with_bias, scale);

        // Apply activation (optional input fusion)
        auto q_activated = mm->add_instruction(migraphx::make_op("tanh"), q_scaled);
        auto k_activated = mm->add_instruction(migraphx::make_op("tanh"), k_with_bias);
        auto v_activated = mm->add_instruction(migraphx::make_op("tanh"), v_with_bias);

        // Now perform the attention mechanism with fused inputs.
        // Transpose K and V for matrix multiplication
        auto k_transposed = mm->add_instruction(
            migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), k_activated);
        auto v_transposed = mm->add_instruction(
            migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), v_activated);

        // Compute attention scores: Q @ K^T
        auto scores = mm->add_instruction(migraphx::make_op("dot"), q_activated, k_transposed);

        // Apply softmax
        auto scores_max =
            mm->add_instruction(migraphx::make_op("reduce_max", {{"axes", {2}}}), scores);
        scores_max = mm->add_instruction(
            migraphx::make_op("multibroadcast", {{"out_lens", s_3d.lens()}}), scores_max);
        auto scores_sub = mm->add_instruction(migraphx::make_op("sub"), scores, scores_max);
        auto scores_exp = mm->add_instruction(migraphx::make_op("exp"), scores_sub);
        auto scores_sum =
            mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {2}}}), scores_exp);
        scores_sum = mm->add_instruction(
            migraphx::make_op("multibroadcast", {{"out_lens", s_3d.lens()}}), scores_sum);
        auto attention_weights =
            mm->add_instruction(migraphx::make_op("div"), scores_exp, scores_sum);

        // Apply attention weights to values: attention_weights @ V^T
        auto output =
            mm->add_instruction(migraphx::make_op("dot"), attention_weights, v_transposed);

        mm->add_return({output});
        return p1;
    }
};

// These tests are not run by default currently; the env vars below need to be set:
// MIGRAPHX_FLASH_DECODING_NUM_SPLITS=2  # or another split factor
// MIGRAPHX_MLIR_USE_SPECIFIC_OPS=attention
template struct test_attention_flash_decoding_3d_input_fusion<migraphx::shape::half_type>;
template struct test_attention_flash_decoding_3d_input_fusion<migraphx::shape::bf16_type>;
template struct test_attention_flash_decoding_3d_input_fusion<migraphx::shape::float_type>;
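Written out, the fused preprocessing this test puts in front of the attention core is

$$ Q' = \tanh\big(s\,(Q + b_Q)\big), \qquad K' = \tanh(K + b_K), \qquad V' = \tanh(V + b_V), $$

followed by O = softmax(Q' K'ᵀ) V'ᵀ, with the scalar s supplied as a parameter (typically 1/\sqrt{d_{head}}). The test presumably checks that the flash-decoding rewrite still fires, and still verifies against the reference, when elementwise producers are fused into the attention inputs.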
test/verify/test_attention_flash_decoding_3d_output_fusion.cpp (166 additions):
/* The MIT License (MIT) header, identical to the first file above. */

#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>

template <migraphx::shape::type_t DType>
struct test_attention_flash_decoding_3d_output_fusion
    : verify_program<test_attention_flash_decoding_3d_output_fusion<DType>>
{
    migraphx::program create_program() const
    {
        // 3D shape: [batch, sequence_length, head_dim]
        migraphx::shape s_3d{DType, {1, 256, 256}};

        migraphx::program p1;
        auto* mm = p1.get_main_module();

        // Input parameters for attention
        auto q = mm->add_parameter("q", s_3d);
        auto k = mm->add_parameter("k", s_3d);
        auto v = mm->add_parameter("v", s_3d);

        // Parameters for output fusion:
        // output projection weight matrix
        migraphx::shape proj_weight_shape{DType, {256, 256}};
        auto output_proj_weight = mm->add_parameter("output_proj_weight", proj_weight_shape);

        // Output bias
        migraphx::shape bias_shape{DType, {256}};
        auto output_bias = mm->add_parameter("output_bias", bias_shape);

        // Residual input for skip connection
        auto residual = mm->add_parameter("residual", s_3d);

        // Layer norm parameters (gamma and beta)
        auto ln_gamma = mm->add_parameter("ln_gamma", bias_shape);
        auto ln_beta  = mm->add_parameter("ln_beta", bias_shape);

        // Gate for gated output
        auto output_gate = mm->add_parameter("output_gate", s_3d);

        // Standard attention mechanism (no input fusion).
        // Transpose K and V for matrix multiplication
        auto k_transposed =
            mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), k);
        auto v_transposed =
            mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), v);

        // Compute attention scores: Q @ K^T
        auto scores = mm->add_instruction(migraphx::make_op("dot"), q, k_transposed);

        // Apply softmax
        auto scores_max =
            mm->add_instruction(migraphx::make_op("reduce_max", {{"axes", {2}}}), scores);
        scores_max = mm->add_instruction(
            migraphx::make_op("multibroadcast", {{"out_lens", s_3d.lens()}}), scores_max);
        auto scores_sub = mm->add_instruction(migraphx::make_op("sub"), scores, scores_max);
        auto scores_exp = mm->add_instruction(migraphx::make_op("exp"), scores_sub);
        auto scores_sum =
            mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {2}}}), scores_exp);
        scores_sum = mm->add_instruction(
            migraphx::make_op("multibroadcast", {{"out_lens", s_3d.lens()}}), scores_sum);
        auto attention_weights =
            mm->add_instruction(migraphx::make_op("div"), scores_exp, scores_sum);

        // Apply attention weights to values: attention_weights @ V^T
        auto attention_output =
            mm->add_instruction(migraphx::make_op("dot"), attention_weights, v_transposed);

        // OUTPUT FUSION OPERATIONS START HERE

        // 1. Output projection (linear transformation):
        //    reshape for matrix multiplication with the projection weight
        auto attn_reshaped = mm->add_instruction(
            migraphx::make_op("reshape", {{"dims", {256, 256}}}), attention_output);
        auto projected =
            mm->add_instruction(migraphx::make_op("dot"), attn_reshaped, output_proj_weight);
        projected =
            mm->add_instruction(migraphx::make_op("reshape", {{"dims", s_3d.lens()}}), projected);

        // 2. Add output bias
        output_bias = mm->add_instruction(
            migraphx::make_op("multibroadcast", {{"out_lens", s_3d.lens()}}), output_bias);
        auto with_bias = mm->add_instruction(migraphx::make_op("add"), projected, output_bias);

        // 3. Apply a dropout-like operation (using a gate for deterministic testing)
        auto gate_sigmoid = mm->add_instruction(migraphx::make_op("sigmoid"), output_gate);
        auto gated = mm->add_instruction(migraphx::make_op("mul"), with_bias, gate_sigmoid);

        // 4. Add residual connection
        auto with_residual = mm->add_instruction(migraphx::make_op("add"), gated, residual);

        // 5. Layer normalization:
        //    compute mean
        auto mean =
            mm->add_instruction(migraphx::make_op("reduce_mean", {{"axes", {2}}}), with_residual);
        mean = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", s_3d.lens()}}),
                                   mean);
        auto centered = mm->add_instruction(migraphx::make_op("sub"), with_residual, mean);

        // Compute variance
        auto squared  = mm->add_instruction(migraphx::make_op("mul"), centered, centered);
        auto variance =
            mm->add_instruction(migraphx::make_op("reduce_mean", {{"axes", {2}}}), squared);
        variance = mm->add_instruction(
            migraphx::make_op("multibroadcast", {{"out_lens", s_3d.lens()}}), variance);

        // Add epsilon for numerical stability
        migraphx::shape epsilon_shape{DType, {1}};
        auto epsilon = mm->add_literal(migraphx::literal{epsilon_shape, {1e-5f}});
        epsilon = mm->add_instruction(
            migraphx::make_op("multibroadcast", {{"out_lens", s_3d.lens()}}), epsilon);
        auto var_plus_eps = mm->add_instruction(migraphx::make_op("add"), variance, epsilon);

        // Compute standard deviation
        auto std_dev = mm->add_instruction(migraphx::make_op("sqrt"), var_plus_eps);

        // Normalize
        auto normalized = mm->add_instruction(migraphx::make_op("div"), centered, std_dev);

        // Scale and shift
        ln_gamma = mm->add_instruction(
            migraphx::make_op("multibroadcast", {{"out_lens", s_3d.lens()}}), ln_gamma);
        ln_beta = mm->add_instruction(
            migraphx::make_op("multibroadcast", {{"out_lens", s_3d.lens()}}), ln_beta);
        auto scaled    = mm->add_instruction(migraphx::make_op("mul"), normalized, ln_gamma);
        auto ln_output = mm->add_instruction(migraphx::make_op("add"), scaled, ln_beta);

        // 6. Final activation (ReLU)
        auto final_output = mm->add_instruction(migraphx::make_op("relu"), ln_output);

        mm->add_return({final_output});
        return p1;
    }
};

// These tests are not run by default currently; the env vars below need to be set:
// MIGRAPHX_FLASH_DECODING_NUM_SPLITS=2  # or another split factor
// MIGRAPHX_MLIR_USE_SPECIFIC_OPS=attention
template struct test_attention_flash_decoding_3d_output_fusion<migraphx::shape::half_type>;
template struct test_attention_flash_decoding_3d_output_fusion<migraphx::shape::bf16_type>;
template struct test_attention_flash_decoding_3d_output_fusion<migraphx::shape::float_type>;
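The fused epilogue in this test, written out with A the attention output, W_o and b_o the projection weight and bias, G the gate, and R the residual:

$$ Y = \mathrm{ReLU}\Big(\mathrm{LN}\big(\sigma(G) \odot (A W_o + b_o) + R\big)\Big), \qquad \mathrm{LN}(x) = \gamma\,\frac{x - \mu}{\sqrt{v + \epsilon}} + \beta, $$

where \mu and v are the per-row mean and variance over the last axis and \epsilon = 10^{-5}, with the layer norm built from primitive reduce/broadcast ops rather than a fused layernorm op. This covers the opposite direction from the input-fusion test: consumers of the attention output (projection, gating, residual add, layer norm, ReLU) surrounding the flash-decoded core.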