Merged
Commits (38)
4cd9f64
curr
bdevorem Oct 10, 2025
ed54478
first stabs
bdevorem Oct 17, 2025
ae80823
lse and kernel2
bdevorem Oct 17, 2025
61bd398
env var, clang format
bdevorem Oct 20, 2025
94353e9
test case
bdevorem Oct 20, 2025
45252cf
remove prints, compiler warnings, don't rewrite if LSE
bdevorem Oct 20, 2025
7106627
more tests; way to set splits from tests instead of env var
bdevorem Oct 22, 2025
fc371c6
format
bdevorem Oct 22, 2025
300f5b4
tidy
bdevorem Oct 22, 2025
12bbacf
tidy
bdevorem Oct 22, 2025
172e9ee
tidy
bdevorem Oct 22, 2025
af612a8
formatttt
bdevorem Oct 22, 2025
7ee88df
fix value of
bdevorem Oct 29, 2025
c49e900
use get_ins_param_map
bdevorem Oct 29, 2025
151fdb6
edit documentation and change how num_splits is represented in the class
bdevorem Oct 29, 2025
0c77e30
verify tests, updated comments/errors
bdevorem Oct 29, 2025
2340dc9
format
bdevorem Oct 30, 2025
9c1fd01
one more format issue
bdevorem Oct 30, 2025
da6df82
add fp16 and fp32 tests; fix reshape bug, fix math in second kernel
bdevorem Nov 3, 2025
5c79e00
format
bdevorem Nov 3, 2025
2c589c8
fix expected values in test and more formatting
bdevorem Nov 3, 2025
b74cc18
format
bdevorem Nov 3, 2025
9d8427a
comment
bdevorem Nov 3, 2025
3ee53d8
format
bdevorem Nov 3, 2025
7ea5a1c
Merge branch 'develop' into bdevorem/flash-decoding
bdevorem Nov 5, 2025
cd9f60e
update var name, add back 4D test and update rocmlir commit for 4D test
bdevorem Nov 12, 2025
80465ec
do not apply flash decoding if n is not divisible by g
bdevorem Nov 12, 2025
59c8b75
add rocmlir commit that fixes tosatorock for 5d tensors; AIMIGRAPHX-242
bdevorem Nov 12, 2025
8c55d75
Merge branch 'develop' into bdevorem/flash-decoding
bdevorem Nov 14, 2025
cdb764e
change back rocMLIR commit; might have caused external failures
bdevorem Nov 14, 2025
d182146
flash decoding + input fusion
bdevorem Nov 17, 2025
ec180df
output fusion 3d test
bdevorem Nov 17, 2025
7da4b0e
remove rocmlir commit and comment out tests that need rocmlir fix
bdevorem Nov 19, 2025
088cfda
Merge branch 'develop' into bdevorem/flash-decoding
bdevorem Nov 19, 2025
23c2474
format
bdevorem Nov 19, 2025
4499605
merge conflicts w gqa refactor
bdevorem Nov 19, 2025
4c7b439
format
bdevorem Nov 19, 2025
60b0d2d
format
bdevorem Nov 19, 2025
8 changes: 8 additions & 0 deletions docs/reference/MIGraphX-dev-env-vars.rst
@@ -160,6 +160,14 @@ Model performance tunable variables change the compilation behavior of a model.

| Default: Split-k performance configurations are turned off.

* - | ``MIGRAPHX_FLASH_DECODING_NUM_SPLITS``
| Turns on flash decoding for attention fusion and sets the number of splits along the key-value sequence dimension.

- | ``0``: Flash decoding is turned off (0 splits).
| ``N`` (where N > 1): Enables flash decoding with N splits along the key-value sequence dimension. For example, ``4`` enables flash decoding with 4 splits.

| Default: Flash decoding is turned off.

* - | ``MIGRAPHX_DISABLE_FP16_INSTANCENORM_CONVERT``
| When set, FP16 is not converted to FP32 in the ``InstanceNormalization`` ONNX operator.

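To make the new ``MIGRAPHX_FLASH_DECODING_NUM_SPLITS`` entry above concrete, here is a small standalone C++ sketch (illustration only, not MIGraphX API; split_kv_len is a hypothetical helper): the value chooses how many pieces the key-value sequence dimension is cut into, and, per the "do not apply flash decoding if n is not divisible by g" commit, the rewrite is skipped when a shape would not split evenly.

#include <cstddef>
#include <iostream>
#include <optional>

// Hypothetical helper: per-split KV length for a given split factor, or nullopt
// when the sequence would not divide evenly and flash decoding should be skipped.
std::optional<std::size_t> split_kv_len(std::size_t kv_seq_len, std::size_t num_splits)
{
    if(num_splits < 2 or kv_seq_len % num_splits != 0)
        return std::nullopt;
    return kv_seq_len / num_splits;
}

int main()
{
    // e.g. the 256-long KV sequence used by the verify tests in this PR, split 4 ways
    if(auto len = split_kv_len(256, 4))
        std::cout << "each split covers " << *len << " KV positions\n";
    else
        std::cout << "flash decoding not applied\n";
}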
417 changes: 417 additions & 0 deletions src/fuse_attention.cpp

Large diffs are not rendered by default.

3 changes: 3 additions & 0 deletions src/include/migraphx/fuse_attention.hpp
@@ -27,6 +27,8 @@

#include <migraphx/config.hpp>
#include <string>
#include <optional>
#include <cstddef>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
@@ -35,6 +37,7 @@ struct module_pass_manager;

struct MIGRAPHX_EXPORT fuse_attention
{
std::optional<std::size_t> flash_decoding_num_splits = std::nullopt;
bool attn_enabled = false;

std::string name() const { return "fuse_attention"; }
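Per the "way to set splits from tests instead of env var" commit, the new optional member lets a test request flash decoding directly on the pass object rather than through the environment variable. A minimal sketch (the function name is illustrative, and handing the pass to the usual pass manager is not shown):

#include <migraphx/fuse_attention.hpp>

migraphx::fuse_attention make_flash_decoding_pass()
{
    migraphx::fuse_attention pass;
    pass.attn_enabled              = true; // enable the attention fusion itself
    pass.flash_decoding_num_splits = 4;    // intended to mirror MIGRAPHX_FLASH_DECODING_NUM_SPLITS=4
    return pass;
}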
298 changes: 290 additions & 8 deletions test/fuse_attention.cpp

Large diffs are not rendered by default.

66 changes: 66 additions & 0 deletions test/verify/test_attention_flash_decoding_3d.cpp
@@ -0,0 +1,66 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/

#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>

template <migraphx::shape::type_t DType>
struct test_attention_flash_decoding_3d : verify_program<test_attention_flash_decoding_3d<DType>>
{
migraphx::program create_program() const
{
// 3D Shape: [batch, sequence_length, head_dim]
migraphx::shape s_3d{DType, {1, 256, 256}};

migraphx::program p1;
auto* mm = p1.get_main_module();
auto a = mm->add_parameter("q", s_3d);
auto b = mm->add_parameter("k", s_3d);
auto b1 = mm->add_parameter("v", s_3d);
b = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), b);
b1 = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), b1);
auto gemm1 = mm->add_instruction(migraphx::make_op("dot"), a, b);
auto rmax = mm->add_instruction(migraphx::make_op("reduce_max", {{"axes", {2}}}), gemm1);
rmax = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", s_3d.lens()}}),
rmax);
auto sub = mm->add_instruction(migraphx::make_op("sub"), gemm1, rmax);
auto exp = mm->add_instruction(migraphx::make_op("exp"), sub);
auto rsum = mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {2}}}), exp);
rsum = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", s_3d.lens()}}),
rsum);
auto div = mm->add_instruction(migraphx::make_op("div"), exp, rsum);
auto gemm2 = mm->add_instruction(migraphx::make_op("dot"), div, b1);
mm->add_return({gemm2});
return p1;
}
};

// These tests are not run by default currently; the env vars below need to be set:
// MIGRAPHX_FLASH_DECODING_NUM_SPLITS=2 # or another split factor
// MIGRAPHX_MLIR_USE_SPECIFIC_OPS=attention
template struct test_attention_flash_decoding_3d<migraphx::shape::half_type>;
template struct test_attention_flash_decoding_3d<migraphx::shape::bf16_type>;
template struct test_attention_flash_decoding_3d<migraphx::shape::float_type>;
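The program above is the plain softmax-attention graph that the pass matches. With flash decoding, each of the N splits computes a partial attention over its slice of the key-value sequence, and a second kernel recombines the partial results using each split's running max and log-sum-exp (the "lse and kernel2" and "fix math in second kernel" commits touch this step). A hedged sketch of the standard recombination math, of which this rewrite is an instance: for split i with scores s_j over its KV slice S_i,

m_i = \max_{j \in S_i} s_j, \qquad \ell_i = \sum_{j \in S_i} e^{s_j - m_i}, \qquad o_i = \sum_{j \in S_i} e^{s_j - m_i} v_j

and the second kernel combines the splits as

m = \max_i m_i, \qquad \ell = \sum_i e^{m_i - m} \ell_i, \qquad o = \frac{1}{\ell} \sum_i e^{m_i - m} o_i

which equals the output of the un-split softmax attention.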
110 changes: 110 additions & 0 deletions test/verify/test_attention_flash_decoding_3d_input_fusion.cpp
@@ -0,0 +1,110 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/

#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>

template <migraphx::shape::type_t DType>
struct test_attention_flash_decoding_3d_input_fusion
: verify_program<test_attention_flash_decoding_3d_input_fusion<DType>>
{
migraphx::program create_program() const
{
// 3D Shape: [batch, sequence_length, head_dim]
migraphx::shape s_3d{DType, {1, 256, 256}};

migraphx::program p1;
auto* mm = p1.get_main_module();

// Input parameters
auto q_input = mm->add_parameter("q", s_3d);
auto k_input = mm->add_parameter("k", s_3d);
auto v_input = mm->add_parameter("v", s_3d);

// Bias parameters for input fusion
auto q_bias = mm->add_parameter("q_bias", s_3d);
auto k_bias = mm->add_parameter("k_bias", s_3d);
auto v_bias = mm->add_parameter("v_bias", s_3d);

// Scale parameter (typically 1/sqrt(head_dim))
migraphx::shape scale_shape{DType, {1}};
auto scale = mm->add_parameter("scale", scale_shape);

// Input fusion operations
// Add bias to Q, K, V
auto q_with_bias = mm->add_instruction(migraphx::make_op("add"), q_input, q_bias);
auto k_with_bias = mm->add_instruction(migraphx::make_op("add"), k_input, k_bias);
auto v_with_bias = mm->add_instruction(migraphx::make_op("add"), v_input, v_bias);

// Scale Q (common in attention mechanisms)
scale = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", s_3d.lens()}}), scale);
auto q_scaled = mm->add_instruction(migraphx::make_op("mul"), q_with_bias, scale);

// Apply activation (optional input fusion)
auto q_activated = mm->add_instruction(migraphx::make_op("tanh"), q_scaled);
auto k_activated = mm->add_instruction(migraphx::make_op("tanh"), k_with_bias);
auto v_activated = mm->add_instruction(migraphx::make_op("tanh"), v_with_bias);

// Now perform the attention mechanism with fused inputs
// Transpose K and V for matrix multiplication
auto k_transposed = mm->add_instruction(
migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), k_activated);
auto v_transposed = mm->add_instruction(
migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), v_activated);

// Compute attention scores: Q @ K^T
auto scores = mm->add_instruction(migraphx::make_op("dot"), q_activated, k_transposed);

// Apply softmax
auto scores_max =
mm->add_instruction(migraphx::make_op("reduce_max", {{"axes", {2}}}), scores);
scores_max = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", s_3d.lens()}}), scores_max);
auto scores_sub = mm->add_instruction(migraphx::make_op("sub"), scores, scores_max);
auto scores_exp = mm->add_instruction(migraphx::make_op("exp"), scores_sub);
auto scores_sum =
mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {2}}}), scores_exp);
scores_sum = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", s_3d.lens()}}), scores_sum);
auto attention_weights =
mm->add_instruction(migraphx::make_op("div"), scores_exp, scores_sum);

// Apply attention weights to values: attention_weights @ V^T
auto output =
mm->add_instruction(migraphx::make_op("dot"), attention_weights, v_transposed);

mm->add_return({output});
return p1;
}
};

// These tests are not run by default currently; the env vars below need to be set:
// MIGRAPHX_FLASH_DECODING_NUM_SPLITS=2 # or another split factor
// MIGRAPHX_MLIR_USE_SPECIFIC_OPS=attention
template struct test_attention_flash_decoding_3d_input_fusion<migraphx::shape::half_type>;
template struct test_attention_flash_decoding_3d_input_fusion<migraphx::shape::bf16_type>;
template struct test_attention_flash_decoding_3d_input_fusion<migraphx::shape::float_type>;
166 changes: 166 additions & 0 deletions test/verify/test_attention_flash_decoding_3d_output_fusion.cpp
@@ -0,0 +1,166 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/

#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>

template <migraphx::shape::type_t DType>
struct test_attention_flash_decoding_3d_output_fusion
: verify_program<test_attention_flash_decoding_3d_output_fusion<DType>>
{
migraphx::program create_program() const
{
// 3D Shape: [batch, sequence_length, head_dim]
migraphx::shape s_3d{DType, {1, 256, 256}};

migraphx::program p1;
auto* mm = p1.get_main_module();

// Input parameters for attention
auto q = mm->add_parameter("q", s_3d);
auto k = mm->add_parameter("k", s_3d);
auto v = mm->add_parameter("v", s_3d);

// Parameters for output fusion
// Output projection weight matrix
migraphx::shape proj_weight_shape{DType, {256, 256}};
auto output_proj_weight = mm->add_parameter("output_proj_weight", proj_weight_shape);

// Output bias
migraphx::shape bias_shape{DType, {256}};
auto output_bias = mm->add_parameter("output_bias", bias_shape);

// Residual input for skip connection
auto residual = mm->add_parameter("residual", s_3d);

// Layer norm parameters (gamma and beta)
auto ln_gamma = mm->add_parameter("ln_gamma", bias_shape);
auto ln_beta = mm->add_parameter("ln_beta", bias_shape);

// Gate for gated output
auto output_gate = mm->add_parameter("output_gate", s_3d);

// Standard attention mechanism (no input fusion)
// Transpose K and V for matrix multiplication
auto k_transposed =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), k);
auto v_transposed =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), v);

// Compute attention scores: Q @ K^T
auto scores = mm->add_instruction(migraphx::make_op("dot"), q, k_transposed);

// Apply softmax
auto scores_max =
mm->add_instruction(migraphx::make_op("reduce_max", {{"axes", {2}}}), scores);
scores_max = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", s_3d.lens()}}), scores_max);
auto scores_sub = mm->add_instruction(migraphx::make_op("sub"), scores, scores_max);
auto scores_exp = mm->add_instruction(migraphx::make_op("exp"), scores_sub);
auto scores_sum =
mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {2}}}), scores_exp);
scores_sum = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", s_3d.lens()}}), scores_sum);
auto attention_weights =
mm->add_instruction(migraphx::make_op("div"), scores_exp, scores_sum);

// Apply attention weights to values: attention_weights @ V^T
auto attention_output =
mm->add_instruction(migraphx::make_op("dot"), attention_weights, v_transposed);

// OUTPUT FUSION OPERATIONS START HERE

// 1. Output projection (linear transformation)
// Reshape for matrix multiplication with projection weight
auto attn_reshaped = mm->add_instruction(
migraphx::make_op("reshape", {{"dims", {256, 256}}}), attention_output);
auto projected =
mm->add_instruction(migraphx::make_op("dot"), attn_reshaped, output_proj_weight);
projected =
mm->add_instruction(migraphx::make_op("reshape", {{"dims", s_3d.lens()}}), projected);

// 2. Add output bias
output_bias = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", s_3d.lens()}}), output_bias);
auto with_bias = mm->add_instruction(migraphx::make_op("add"), projected, output_bias);

// 3. Apply dropout-like operation (using a gate for deterministic testing)
auto gate_sigmoid = mm->add_instruction(migraphx::make_op("sigmoid"), output_gate);
auto gated = mm->add_instruction(migraphx::make_op("mul"), with_bias, gate_sigmoid);

// 4. Add residual connection
auto with_residual = mm->add_instruction(migraphx::make_op("add"), gated, residual);

// 5. Layer normalization
// Compute mean
auto mean =
mm->add_instruction(migraphx::make_op("reduce_mean", {{"axes", {2}}}), with_residual);
mean = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", s_3d.lens()}}),
mean);
auto centered = mm->add_instruction(migraphx::make_op("sub"), with_residual, mean);

// Compute variance
auto squared = mm->add_instruction(migraphx::make_op("mul"), centered, centered);
auto variance =
mm->add_instruction(migraphx::make_op("reduce_mean", {{"axes", {2}}}), squared);
variance = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", s_3d.lens()}}), variance);

// Add epsilon for numerical stability
migraphx::shape epsilon_shape{DType, {1}};
auto epsilon = mm->add_literal(migraphx::literal{epsilon_shape, {1e-5f}});
epsilon = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", s_3d.lens()}}), epsilon);
auto var_plus_eps = mm->add_instruction(migraphx::make_op("add"), variance, epsilon);

// Compute standard deviation
auto std_dev = mm->add_instruction(migraphx::make_op("sqrt"), var_plus_eps);

// Normalize
auto normalized = mm->add_instruction(migraphx::make_op("div"), centered, std_dev);

// Scale and shift
ln_gamma = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", s_3d.lens()}}), ln_gamma);
ln_beta = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", s_3d.lens()}}), ln_beta);
auto scaled = mm->add_instruction(migraphx::make_op("mul"), normalized, ln_gamma);
auto ln_output = mm->add_instruction(migraphx::make_op("add"), scaled, ln_beta);

// 6. Final activation (ReLU)
auto final_output = mm->add_instruction(migraphx::make_op("relu"), ln_output);

mm->add_return({final_output});
return p1;
}
};

// These tests are not run by default currently; the env vars below need to be set:
// MIGRAPHX_FLASH_DECODING_NUM_SPLITS=2 # or another split factor
// MIGRAPHX_MLIR_USE_SPECIFIC_OPS=attention
template struct test_attention_flash_decoding_3d_output_fusion<migraphx::shape::half_type>;
template struct test_attention_flash_decoding_3d_output_fusion<migraphx::shape::bf16_type>;
template struct test_attention_flash_decoding_3d_output_fusion<migraphx::shape::float_type>;