From 68aef7b396580e208534eb03cead9d53a7996a0d Mon Sep 17 00:00:00 2001 From: amd-khushbu Date: Wed, 21 Jan 2026 18:00:09 +0000 Subject: [PATCH 1/5] initial commit --- .../38_block_scale_gemm/CMakeLists.txt | 42 +++--- .../38_block_scale_gemm/gemm_quant.cpp | 126 +++++++++--------- .../run_gemm_quant_example.inc | 18 +-- .../gemm_quant/kernel/gemm_quant_kernel.hpp | 4 +- 4 files changed, 91 insertions(+), 99 deletions(-) diff --git a/example/ck_tile/38_block_scale_gemm/CMakeLists.txt b/example/ck_tile/38_block_scale_gemm/CMakeLists.txt index ec536f72878..45323f8acf4 100644 --- a/example/ck_tile/38_block_scale_gemm/CMakeLists.txt +++ b/example/ck_tile/38_block_scale_gemm/CMakeLists.txt @@ -13,27 +13,27 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12") add_executable(${EXE_NAME} gemm_quant.cpp gemm_abquant_quantgrouped.cpp - gemm_aquant_quantgrouped.cpp - gemm_aquant_quantgrouped_preshufflequant.cpp - gemm_bquant_quantgrouped_bf8i4.cpp - gemm_bquant_quantgrouped_fp8i4.cpp - gemm_bquant_quantgrouped_bf16mxfp4.cpp - gemm_bquant_quantgrouped_bf8.cpp - gemm_bquant_quantgrouped_fp8.cpp - gemm_bquant_quantgrouped_preshuffleb_bf8i4.cpp - gemm_bquant_quantgrouped_preshuffleb_fp8i4.cpp - gemm_bquant_quantgrouped_preshuffleb_bf8.cpp - gemm_bquant_quantgrouped_preshuffleb_fp8.cpp - gemm_bquant_quantgrouped_preshufflequant_bf8i4.cpp - gemm_bquant_quantgrouped_preshufflequant_fp8i4.cpp - gemm_bquant_quantgrouped_preshufflequant_bf8.cpp - gemm_bquant_quantgrouped_preshufflequant_fp8.cpp - gemm_bquant_quantgrouped_preshuffleb_preshufflequant_bf8i4.cpp - gemm_bquant_quantgrouped_preshuffleb_preshufflequant_fp8i4.cpp - gemm_bquant_quantgrouped_preshuffleb_preshufflequant_bf8.cpp - gemm_bquant_quantgrouped_preshuffleb_preshufflequant_fp8.cpp - gemm_quant_rowcol.cpp - gemm_quant_tensor.cpp +# gemm_aquant_quantgrouped.cpp +# gemm_aquant_quantgrouped_preshufflequant.cpp +# gemm_bquant_quantgrouped_bf8i4.cpp +# gemm_bquant_quantgrouped_fp8i4.cpp +# gemm_bquant_quantgrouped_bf16mxfp4.cpp +# gemm_bquant_quantgrouped_bf8.cpp +# gemm_bquant_quantgrouped_fp8.cpp +# gemm_bquant_quantgrouped_preshuffleb_bf8i4.cpp +# gemm_bquant_quantgrouped_preshuffleb_fp8i4.cpp +# gemm_bquant_quantgrouped_preshuffleb_bf8.cpp +# gemm_bquant_quantgrouped_preshuffleb_fp8.cpp +# gemm_bquant_quantgrouped_preshufflequant_bf8i4.cpp +# gemm_bquant_quantgrouped_preshufflequant_fp8i4.cpp +# gemm_bquant_quantgrouped_preshufflequant_bf8.cpp +# gemm_bquant_quantgrouped_preshufflequant_fp8.cpp +# gemm_bquant_quantgrouped_preshuffleb_preshufflequant_bf8i4.cpp +# gemm_bquant_quantgrouped_preshuffleb_preshufflequant_fp8i4.cpp +# gemm_bquant_quantgrouped_preshuffleb_preshufflequant_bf8.cpp +# gemm_bquant_quantgrouped_preshuffleb_preshufflequant_fp8.cpp +# gemm_quant_rowcol.cpp +# gemm_quant_tensor.cpp ) target_compile_options(${EXE_NAME} PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) else() diff --git a/example/ck_tile/38_block_scale_gemm/gemm_quant.cpp b/example/ck_tile/38_block_scale_gemm/gemm_quant.cpp index 8de58b0a309..7adabb53b9a 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_quant.cpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_quant.cpp @@ -97,48 +97,48 @@ auto gen_lut_key(const ck_tile::ArgParser& arg_parser) void abquant_quantgrouped_instance_factory( std::unordered_map>& lut); -void aquant_quantgrouped_instance_factory( - std::unordered_map>& lut); -void aquant_quantgrouped_preshufflequant_instance_factory( - std::unordered_map>& lut); -void bquant_quantgrouped_fp8_instance_factory( - std::unordered_map>& lut); -void bquant_quantgrouped_bf8_instance_factory( - std::unordered_map>& lut); -void bquant_quantgrouped_fp8i4_instance_factory( - std::unordered_map>& lut); -void bquant_quantgrouped_bf8i4_instance_factory( - std::unordered_map>& lut); -void bquant_quantgrouped_bf16fp4_instance_factory( - std::unordered_map>& lut); -void bquant_quantgrouped_preshuffleb_fp8_instance_factory( - std::unordered_map>& lut); -void bquant_quantgrouped_preshuffleb_bf8_instance_factory( - std::unordered_map>& lut); -void bquant_quantgrouped_preshuffleb_fp8i4_instance_factory( - std::unordered_map>& lut); -void bquant_quantgrouped_preshuffleb_bf8i4_instance_factory( - std::unordered_map>& lut); -void bquant_quantgrouped_preshufflequant_fp8_instance_factory( - std::unordered_map>& lut); -void bquant_quantgrouped_preshufflequant_bf8_instance_factory( - std::unordered_map>& lut); -void bquant_quantgrouped_preshufflequant_fp8i4_instance_factory( - std::unordered_map>& lut); -void bquant_quantgrouped_preshufflequant_bf8i4_instance_factory( - std::unordered_map>& lut); -void bquant_quantgrouped_preshuffleb_preshufflequant_fp8_instance_factory( - std::unordered_map>& lut); -void bquant_quantgrouped_preshuffleb_preshufflequant_bf8_instance_factory( - std::unordered_map>& lut); -void bquant_quantgrouped_preshuffleb_preshufflequant_fp8i4_instance_factory( - std::unordered_map>& lut); -void bquant_quantgrouped_preshuffleb_preshufflequant_bf8i4_instance_factory( - std::unordered_map>& lut); -void quant_rowcol_instance_factory( - std::unordered_map>& lut); -void quant_tensor_instance_factory( - std::unordered_map>& lut); +// void aquant_quantgrouped_instance_factory( +// std::unordered_map>& lut); +// void aquant_quantgrouped_preshufflequant_instance_factory( +// std::unordered_map>& lut); +// void bquant_quantgrouped_fp8_instance_factory( +// std::unordered_map>& lut); +// void bquant_quantgrouped_bf8_instance_factory( +// std::unordered_map>& lut); +// void bquant_quantgrouped_fp8i4_instance_factory( +// std::unordered_map>& lut); +// void bquant_quantgrouped_bf8i4_instance_factory( +// std::unordered_map>& lut); +// void bquant_quantgrouped_bf16fp4_instance_factory( +// std::unordered_map>& lut); +// void bquant_quantgrouped_preshuffleb_fp8_instance_factory( +// std::unordered_map>& lut); +// void bquant_quantgrouped_preshuffleb_bf8_instance_factory( +// std::unordered_map>& lut); +// void bquant_quantgrouped_preshuffleb_fp8i4_instance_factory( +// std::unordered_map>& lut); +// void bquant_quantgrouped_preshuffleb_bf8i4_instance_factory( +// std::unordered_map>& lut); +// void bquant_quantgrouped_preshufflequant_fp8_instance_factory( +// std::unordered_map>& lut); +// void bquant_quantgrouped_preshufflequant_bf8_instance_factory( +// std::unordered_map>& lut); +// void bquant_quantgrouped_preshufflequant_fp8i4_instance_factory( +// std::unordered_map>& lut); +// void bquant_quantgrouped_preshufflequant_bf8i4_instance_factory( +// std::unordered_map>& lut); +// void bquant_quantgrouped_preshuffleb_preshufflequant_fp8_instance_factory( +// std::unordered_map>& lut); +// void bquant_quantgrouped_preshuffleb_preshufflequant_bf8_instance_factory( +// std::unordered_map>& lut); +// void bquant_quantgrouped_preshuffleb_preshufflequant_fp8i4_instance_factory( +// std::unordered_map>& lut); +// void bquant_quantgrouped_preshuffleb_preshufflequant_bf8i4_instance_factory( +// std::unordered_map>& lut); +// void quant_rowcol_instance_factory( +// std::unordered_map>& lut); +// void quant_tensor_instance_factory( +// std::unordered_map>& lut); int main(int argc, char* argv[]) { @@ -155,27 +155,27 @@ int main(int argc, char* argv[]) std::unordered_map> lut; abquant_quantgrouped_instance_factory(lut); - aquant_quantgrouped_instance_factory(lut); - aquant_quantgrouped_preshufflequant_instance_factory(lut); - bquant_quantgrouped_fp8_instance_factory(lut); - bquant_quantgrouped_bf8_instance_factory(lut); - bquant_quantgrouped_fp8i4_instance_factory(lut); - bquant_quantgrouped_bf8i4_instance_factory(lut); - bquant_quantgrouped_bf16fp4_instance_factory(lut); - bquant_quantgrouped_preshuffleb_fp8_instance_factory(lut); - bquant_quantgrouped_preshuffleb_bf8_instance_factory(lut); - bquant_quantgrouped_preshuffleb_fp8i4_instance_factory(lut); - bquant_quantgrouped_preshuffleb_bf8i4_instance_factory(lut); - bquant_quantgrouped_preshufflequant_fp8_instance_factory(lut); - bquant_quantgrouped_preshufflequant_bf8_instance_factory(lut); - bquant_quantgrouped_preshufflequant_fp8i4_instance_factory(lut); - bquant_quantgrouped_preshufflequant_bf8i4_instance_factory(lut); - bquant_quantgrouped_preshuffleb_preshufflequant_fp8_instance_factory(lut); - bquant_quantgrouped_preshuffleb_preshufflequant_bf8_instance_factory(lut); - bquant_quantgrouped_preshuffleb_preshufflequant_fp8i4_instance_factory(lut); - bquant_quantgrouped_preshuffleb_preshufflequant_bf8i4_instance_factory(lut); - quant_rowcol_instance_factory(lut); - quant_tensor_instance_factory(lut); + // aquant_quantgrouped_instance_factory(lut); + // aquant_quantgrouped_preshufflequant_instance_factory(lut); + // bquant_quantgrouped_fp8_instance_factory(lut); + // bquant_quantgrouped_bf8_instance_factory(lut); + // bquant_quantgrouped_fp8i4_instance_factory(lut); + // bquant_quantgrouped_bf8i4_instance_factory(lut); + // bquant_quantgrouped_bf16fp4_instance_factory(lut); + // bquant_quantgrouped_preshuffleb_fp8_instance_factory(lut); + // bquant_quantgrouped_preshuffleb_bf8_instance_factory(lut); + // bquant_quantgrouped_preshuffleb_fp8i4_instance_factory(lut); + // bquant_quantgrouped_preshuffleb_bf8i4_instance_factory(lut); + // bquant_quantgrouped_preshufflequant_fp8_instance_factory(lut); + // bquant_quantgrouped_preshufflequant_bf8_instance_factory(lut); + // bquant_quantgrouped_preshufflequant_fp8i4_instance_factory(lut); + // bquant_quantgrouped_preshufflequant_bf8i4_instance_factory(lut); + // bquant_quantgrouped_preshuffleb_preshufflequant_fp8_instance_factory(lut); + // bquant_quantgrouped_preshuffleb_preshufflequant_bf8_instance_factory(lut); + // bquant_quantgrouped_preshuffleb_preshufflequant_fp8i4_instance_factory(lut); + // bquant_quantgrouped_preshuffleb_preshufflequant_bf8i4_instance_factory(lut); + // quant_rowcol_instance_factory(lut); + // quant_tensor_instance_factory(lut); auto key = gen_lut_key(arg_parser); diff --git a/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc b/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc index 607c53d9afd..1d30a1946e8 100644 --- a/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc +++ b/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc @@ -55,8 +55,8 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str BLayout, CLayout, QuantMode, - AQLayout, // for AQLayout - BQLayout, // for BQLayout + AQLayout, + BQLayout, false, GemmConfig::DoubleSmemBuffer>; @@ -537,21 +537,13 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser, // Create BQ tensor with appropriate shape std::unique_ptr> bq_tensor_ptr = nullptr; if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped || - QuantMode == ck_tile::QuantType::RowColQuant) - { - bq_tensor_ptr = std::make_unique>( - ck_tile::host_tensor_descriptor(BQK, BQN, stride_BQ, is_row_major(bq_layout))); - } - else if constexpr(QuantMode == ck_tile::QuantType::ABQuantGrouped) + QuantMode == ck_tile::QuantType::ABQuantGrouped || + QuantMode == ck_tile::QuantType::RowColQuant || + QuantMode == ck_tile::QuantType::TensorQuant) { bq_tensor_ptr = std::make_unique>( ck_tile::host_tensor_descriptor(BQK, BQN, stride_BQ, is_row_major(bq_layout))); } - else if constexpr(QuantMode == ck_tile::QuantType::TensorQuant) - { - bq_tensor_ptr = std::make_unique>( - ck_tile::host_tensor_descriptor(1, 1, stride_BQ, is_row_major(bq_layout))); - } std::random_device rd; std::mt19937 gen(rd()); diff --git a/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp b/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp index fd94dfb6b3d..96462b6a14f 100644 --- a/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp +++ b/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp @@ -610,10 +610,10 @@ struct QuantGemmKernel static_assert(std::is_same_v); using QuantGroupSize = remove_cvref_t; constexpr auto block_m = TilePartitioner::MPerBlock; - constexpr auto block_k = TilePartitioner::KPerBlock; + constexpr auto aqk_per_block = TilePartitioner::KPerBlock / QuantGroupSize::kK; return make_tile_window( aq_tensor_view, - make_tuple(number{}, number{}), + make_tuple(number{}, number{}), {i_m, 0}); } else if constexpr(kQuantType == QuantType::RowColQuant) From 934d04b93e3a794e6c5ee7c619bd61fa45709044 Mon Sep 17 00:00:00 2001 From: amd-khushbu Date: Wed, 21 Jan 2026 23:22:56 +0000 Subject: [PATCH 2/5] preshuffleQuant support for ABQuant --- .../38_block_scale_gemm/CMakeLists.txt | 42 +++--- .../gemm_abquant_quantgrouped.cpp | 60 ++++++++ .../38_block_scale_gemm/gemm_quant.cpp | 126 ++++++++-------- ...versal_gemm_ar_aquant_flatbr_bquant_cr.hpp | 14 +- ...ock_universal_gemm_ar_flatbr_bquant_cr.hpp | 15 +- ..._universal_gemm_as_aquant_bs_bquant_cr.hpp | 27 +++- .../block_universal_gemm_as_aquant_bs_cr.hpp | 16 +-- .../block_universal_gemm_as_bs_bquant_cr.hpp | 27 ++-- .../gemm_quant/kernel/gemm_quant_kernel.hpp | 134 ++++++++---------- .../gemm_abquant_pipeline_ag_bg_cr_v3.hpp | 18 ++- .../gemm_aquant_pipeline_ag_bg_cr_base.hpp | 18 +-- .../gemm_aquant_pipeline_ag_bg_cr_mem.hpp | 22 +-- .../gemm_aquant_pipeline_ag_bg_cr_v3.hpp | 22 +-- .../gemm_bquant_pipeline_ag_bg_cr_base.hpp | 30 ++-- .../gemm_bquant_pipeline_ag_bg_cr_policy.hpp | 4 +- .../gemm_bquant_pipeline_ag_bg_cr_v3.hpp | 26 ++-- .../pipeline/gemm_quant_pipeline_problem.hpp | 8 +- .../gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp | 18 +-- 18 files changed, 354 insertions(+), 273 deletions(-) diff --git a/example/ck_tile/38_block_scale_gemm/CMakeLists.txt b/example/ck_tile/38_block_scale_gemm/CMakeLists.txt index 45323f8acf4..ec536f72878 100644 --- a/example/ck_tile/38_block_scale_gemm/CMakeLists.txt +++ b/example/ck_tile/38_block_scale_gemm/CMakeLists.txt @@ -13,27 +13,27 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12") add_executable(${EXE_NAME} gemm_quant.cpp gemm_abquant_quantgrouped.cpp -# gemm_aquant_quantgrouped.cpp -# gemm_aquant_quantgrouped_preshufflequant.cpp -# gemm_bquant_quantgrouped_bf8i4.cpp -# gemm_bquant_quantgrouped_fp8i4.cpp -# gemm_bquant_quantgrouped_bf16mxfp4.cpp -# gemm_bquant_quantgrouped_bf8.cpp -# gemm_bquant_quantgrouped_fp8.cpp -# gemm_bquant_quantgrouped_preshuffleb_bf8i4.cpp -# gemm_bquant_quantgrouped_preshuffleb_fp8i4.cpp -# gemm_bquant_quantgrouped_preshuffleb_bf8.cpp -# gemm_bquant_quantgrouped_preshuffleb_fp8.cpp -# gemm_bquant_quantgrouped_preshufflequant_bf8i4.cpp -# gemm_bquant_quantgrouped_preshufflequant_fp8i4.cpp -# gemm_bquant_quantgrouped_preshufflequant_bf8.cpp -# gemm_bquant_quantgrouped_preshufflequant_fp8.cpp -# gemm_bquant_quantgrouped_preshuffleb_preshufflequant_bf8i4.cpp -# gemm_bquant_quantgrouped_preshuffleb_preshufflequant_fp8i4.cpp -# gemm_bquant_quantgrouped_preshuffleb_preshufflequant_bf8.cpp -# gemm_bquant_quantgrouped_preshuffleb_preshufflequant_fp8.cpp -# gemm_quant_rowcol.cpp -# gemm_quant_tensor.cpp + gemm_aquant_quantgrouped.cpp + gemm_aquant_quantgrouped_preshufflequant.cpp + gemm_bquant_quantgrouped_bf8i4.cpp + gemm_bquant_quantgrouped_fp8i4.cpp + gemm_bquant_quantgrouped_bf16mxfp4.cpp + gemm_bquant_quantgrouped_bf8.cpp + gemm_bquant_quantgrouped_fp8.cpp + gemm_bquant_quantgrouped_preshuffleb_bf8i4.cpp + gemm_bquant_quantgrouped_preshuffleb_fp8i4.cpp + gemm_bquant_quantgrouped_preshuffleb_bf8.cpp + gemm_bquant_quantgrouped_preshuffleb_fp8.cpp + gemm_bquant_quantgrouped_preshufflequant_bf8i4.cpp + gemm_bquant_quantgrouped_preshufflequant_fp8i4.cpp + gemm_bquant_quantgrouped_preshufflequant_bf8.cpp + gemm_bquant_quantgrouped_preshufflequant_fp8.cpp + gemm_bquant_quantgrouped_preshuffleb_preshufflequant_bf8i4.cpp + gemm_bquant_quantgrouped_preshuffleb_preshufflequant_fp8i4.cpp + gemm_bquant_quantgrouped_preshuffleb_preshufflequant_bf8.cpp + gemm_bquant_quantgrouped_preshuffleb_preshufflequant_fp8.cpp + gemm_quant_rowcol.cpp + gemm_quant_tensor.cpp ) target_compile_options(${EXE_NAME} PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) else() diff --git a/example/ck_tile/38_block_scale_gemm/gemm_abquant_quantgrouped.cpp b/example/ck_tile/38_block_scale_gemm/gemm_abquant_quantgrouped.cpp index 155f19881ea..395a984b4ac 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_abquant_quantgrouped.cpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_abquant_quantgrouped.cpp @@ -129,4 +129,64 @@ void abquant_quantgrouped_instance_factory( BQuantGroupSize, ck_tile::QuantType::ABQuantGrouped>(arg_parser); }; + lut[hash_multiple_strings({"fp8", + "abquant", + "non-preshuffleb", + "preshufflequant", + "1x1x128"})] = [](const ck_tile::ArgParser& arg_parser) { + using AQuantGroupSize = ck_tile::QuantGroupShape>; + using BQuantGroupSize = ck_tile::QuantGroupShape>; + using TypeConfig = + decltype(GemmQuantTypeConfig{}); + return run_gemm_example_prec_type, + TypeConfig, + AQuantGroupSize, + BQuantGroupSize, + ck_tile::QuantType::ABQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings({"fp8", + "abquant", + "non-preshuffleb", + "preshufflequant", + "1x128x128"})] = [](const ck_tile::ArgParser& arg_parser) { + using AQuantGroupSize = ck_tile::QuantGroupShape>; + using BQuantGroupSize = ck_tile::QuantGroupShape>; + using TypeConfig = + decltype(GemmQuantTypeConfig{}); + return run_gemm_example_prec_type, + TypeConfig, + AQuantGroupSize, + BQuantGroupSize, + ck_tile::QuantType::ABQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings({"bf8", + "abquant", + "non-preshuffleb", + "preshufflequant", + "1x1x128"})] = [](const ck_tile::ArgParser& arg_parser) { + using AQuantGroupSize = ck_tile::QuantGroupShape>; + using BQuantGroupSize = ck_tile::QuantGroupShape>; + using TypeConfig = + decltype(GemmQuantTypeConfig{}); + return run_gemm_example_prec_type, + TypeConfig, + AQuantGroupSize, + BQuantGroupSize, + ck_tile::QuantType::ABQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings({"bf8", + "abquant", + "non-preshuffleb", + "preshufflequant", + "1x128x128"})] = [](const ck_tile::ArgParser& arg_parser) { + using AQuantGroupSize = ck_tile::QuantGroupShape>; + using BQuantGroupSize = ck_tile::QuantGroupShape>; + using TypeConfig = + decltype(GemmQuantTypeConfig{}); + return run_gemm_example_prec_type, + TypeConfig, + AQuantGroupSize, + BQuantGroupSize, + ck_tile::QuantType::ABQuantGrouped>(arg_parser); + }; } diff --git a/example/ck_tile/38_block_scale_gemm/gemm_quant.cpp b/example/ck_tile/38_block_scale_gemm/gemm_quant.cpp index 7adabb53b9a..8de58b0a309 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_quant.cpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_quant.cpp @@ -97,48 +97,48 @@ auto gen_lut_key(const ck_tile::ArgParser& arg_parser) void abquant_quantgrouped_instance_factory( std::unordered_map>& lut); -// void aquant_quantgrouped_instance_factory( -// std::unordered_map>& lut); -// void aquant_quantgrouped_preshufflequant_instance_factory( -// std::unordered_map>& lut); -// void bquant_quantgrouped_fp8_instance_factory( -// std::unordered_map>& lut); -// void bquant_quantgrouped_bf8_instance_factory( -// std::unordered_map>& lut); -// void bquant_quantgrouped_fp8i4_instance_factory( -// std::unordered_map>& lut); -// void bquant_quantgrouped_bf8i4_instance_factory( -// std::unordered_map>& lut); -// void bquant_quantgrouped_bf16fp4_instance_factory( -// std::unordered_map>& lut); -// void bquant_quantgrouped_preshuffleb_fp8_instance_factory( -// std::unordered_map>& lut); -// void bquant_quantgrouped_preshuffleb_bf8_instance_factory( -// std::unordered_map>& lut); -// void bquant_quantgrouped_preshuffleb_fp8i4_instance_factory( -// std::unordered_map>& lut); -// void bquant_quantgrouped_preshuffleb_bf8i4_instance_factory( -// std::unordered_map>& lut); -// void bquant_quantgrouped_preshufflequant_fp8_instance_factory( -// std::unordered_map>& lut); -// void bquant_quantgrouped_preshufflequant_bf8_instance_factory( -// std::unordered_map>& lut); -// void bquant_quantgrouped_preshufflequant_fp8i4_instance_factory( -// std::unordered_map>& lut); -// void bquant_quantgrouped_preshufflequant_bf8i4_instance_factory( -// std::unordered_map>& lut); -// void bquant_quantgrouped_preshuffleb_preshufflequant_fp8_instance_factory( -// std::unordered_map>& lut); -// void bquant_quantgrouped_preshuffleb_preshufflequant_bf8_instance_factory( -// std::unordered_map>& lut); -// void bquant_quantgrouped_preshuffleb_preshufflequant_fp8i4_instance_factory( -// std::unordered_map>& lut); -// void bquant_quantgrouped_preshuffleb_preshufflequant_bf8i4_instance_factory( -// std::unordered_map>& lut); -// void quant_rowcol_instance_factory( -// std::unordered_map>& lut); -// void quant_tensor_instance_factory( -// std::unordered_map>& lut); +void aquant_quantgrouped_instance_factory( + std::unordered_map>& lut); +void aquant_quantgrouped_preshufflequant_instance_factory( + std::unordered_map>& lut); +void bquant_quantgrouped_fp8_instance_factory( + std::unordered_map>& lut); +void bquant_quantgrouped_bf8_instance_factory( + std::unordered_map>& lut); +void bquant_quantgrouped_fp8i4_instance_factory( + std::unordered_map>& lut); +void bquant_quantgrouped_bf8i4_instance_factory( + std::unordered_map>& lut); +void bquant_quantgrouped_bf16fp4_instance_factory( + std::unordered_map>& lut); +void bquant_quantgrouped_preshuffleb_fp8_instance_factory( + std::unordered_map>& lut); +void bquant_quantgrouped_preshuffleb_bf8_instance_factory( + std::unordered_map>& lut); +void bquant_quantgrouped_preshuffleb_fp8i4_instance_factory( + std::unordered_map>& lut); +void bquant_quantgrouped_preshuffleb_bf8i4_instance_factory( + std::unordered_map>& lut); +void bquant_quantgrouped_preshufflequant_fp8_instance_factory( + std::unordered_map>& lut); +void bquant_quantgrouped_preshufflequant_bf8_instance_factory( + std::unordered_map>& lut); +void bquant_quantgrouped_preshufflequant_fp8i4_instance_factory( + std::unordered_map>& lut); +void bquant_quantgrouped_preshufflequant_bf8i4_instance_factory( + std::unordered_map>& lut); +void bquant_quantgrouped_preshuffleb_preshufflequant_fp8_instance_factory( + std::unordered_map>& lut); +void bquant_quantgrouped_preshuffleb_preshufflequant_bf8_instance_factory( + std::unordered_map>& lut); +void bquant_quantgrouped_preshuffleb_preshufflequant_fp8i4_instance_factory( + std::unordered_map>& lut); +void bquant_quantgrouped_preshuffleb_preshufflequant_bf8i4_instance_factory( + std::unordered_map>& lut); +void quant_rowcol_instance_factory( + std::unordered_map>& lut); +void quant_tensor_instance_factory( + std::unordered_map>& lut); int main(int argc, char* argv[]) { @@ -155,27 +155,27 @@ int main(int argc, char* argv[]) std::unordered_map> lut; abquant_quantgrouped_instance_factory(lut); - // aquant_quantgrouped_instance_factory(lut); - // aquant_quantgrouped_preshufflequant_instance_factory(lut); - // bquant_quantgrouped_fp8_instance_factory(lut); - // bquant_quantgrouped_bf8_instance_factory(lut); - // bquant_quantgrouped_fp8i4_instance_factory(lut); - // bquant_quantgrouped_bf8i4_instance_factory(lut); - // bquant_quantgrouped_bf16fp4_instance_factory(lut); - // bquant_quantgrouped_preshuffleb_fp8_instance_factory(lut); - // bquant_quantgrouped_preshuffleb_bf8_instance_factory(lut); - // bquant_quantgrouped_preshuffleb_fp8i4_instance_factory(lut); - // bquant_quantgrouped_preshuffleb_bf8i4_instance_factory(lut); - // bquant_quantgrouped_preshufflequant_fp8_instance_factory(lut); - // bquant_quantgrouped_preshufflequant_bf8_instance_factory(lut); - // bquant_quantgrouped_preshufflequant_fp8i4_instance_factory(lut); - // bquant_quantgrouped_preshufflequant_bf8i4_instance_factory(lut); - // bquant_quantgrouped_preshuffleb_preshufflequant_fp8_instance_factory(lut); - // bquant_quantgrouped_preshuffleb_preshufflequant_bf8_instance_factory(lut); - // bquant_quantgrouped_preshuffleb_preshufflequant_fp8i4_instance_factory(lut); - // bquant_quantgrouped_preshuffleb_preshufflequant_bf8i4_instance_factory(lut); - // quant_rowcol_instance_factory(lut); - // quant_tensor_instance_factory(lut); + aquant_quantgrouped_instance_factory(lut); + aquant_quantgrouped_preshufflequant_instance_factory(lut); + bquant_quantgrouped_fp8_instance_factory(lut); + bquant_quantgrouped_bf8_instance_factory(lut); + bquant_quantgrouped_fp8i4_instance_factory(lut); + bquant_quantgrouped_bf8i4_instance_factory(lut); + bquant_quantgrouped_bf16fp4_instance_factory(lut); + bquant_quantgrouped_preshuffleb_fp8_instance_factory(lut); + bquant_quantgrouped_preshuffleb_bf8_instance_factory(lut); + bquant_quantgrouped_preshuffleb_fp8i4_instance_factory(lut); + bquant_quantgrouped_preshuffleb_bf8i4_instance_factory(lut); + bquant_quantgrouped_preshufflequant_fp8_instance_factory(lut); + bquant_quantgrouped_preshufflequant_bf8_instance_factory(lut); + bquant_quantgrouped_preshufflequant_fp8i4_instance_factory(lut); + bquant_quantgrouped_preshufflequant_bf8i4_instance_factory(lut); + bquant_quantgrouped_preshuffleb_preshufflequant_fp8_instance_factory(lut); + bquant_quantgrouped_preshuffleb_preshufflequant_bf8_instance_factory(lut); + bquant_quantgrouped_preshuffleb_preshufflequant_fp8i4_instance_factory(lut); + bquant_quantgrouped_preshuffleb_preshufflequant_bf8i4_instance_factory(lut); + quant_rowcol_instance_factory(lut); + quant_tensor_instance_factory(lut); auto key = gen_lut_key(arg_parser); diff --git a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_aquant_flatbr_bquant_cr.hpp b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_aquant_flatbr_bquant_cr.hpp index 63a51511088..91747782bce 100644 --- a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_aquant_flatbr_bquant_cr.hpp +++ b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_aquant_flatbr_bquant_cr.hpp @@ -127,9 +127,9 @@ struct BlockGemmWeightPreshuffleABQuantARegBRegCReg using CDataType = remove_cvref_t; using ComputeDataType = remove_cvref_t; using BlockGemmShape = remove_cvref_t; // TileFlatmmShape - using QuantGroupSize = remove_cvref_t; + using BQuantGroupSize = remove_cvref_t; - static_assert(QuantGroupSize::kM == 1, "only N/K blocks for BQuant preshuffle kernel!"); + static_assert(BQuantGroupSize::kM == 1, "only N/K blocks for BQuant preshuffle kernel!"); static constexpr auto I0 = number<0>(); static constexpr auto I1 = number<1>(); @@ -162,12 +162,12 @@ struct BlockGemmWeightPreshuffleABQuantARegBRegCReg static constexpr auto MIter_2nd_last = (MIterPerWarp >= 2) ? MIterPerWarp - 2 : MIterPerWarp - 1; - static constexpr index_t KPerBlockBQ = KPerBlock / QuantGroupSize::kK; + static constexpr index_t KPerBlockBQ = KPerBlock / BQuantGroupSize::kK; static constexpr index_t QScalesPerBlockRow = - integer_divide_ceil(KPerBlock, QuantGroupSize::kK); // 128 / 128 = 1 + integer_divide_ceil(KPerBlock, BQuantGroupSize::kK); // 128 / 128 = 1 static constexpr index_t QScalesPerWarpGemmRow = - integer_divide_ceil(WG::kK, QuantGroupSize::kK); + integer_divide_ceil(WG::kK, BQuantGroupSize::kK); static constexpr index_t KIterPerQScale = KIterPerWarp / QScalesPerBlockRow; // 8 / 1 = 8 static constexpr index_t DsReadPreload = 2; // default 2, preload 2 ds read @@ -253,9 +253,9 @@ struct BlockGemmWeightPreshuffleABQuantARegBRegCReg CBlockTensor::PackedSize>{}; index_t reg_offset = [&]() { - if constexpr(QuantGroupSize::kN >= (NWarp * WG::kN)) + if constexpr(BQuantGroupSize::kN >= (NWarp * WG::kN)) { - return (nIter * NWarp * WG::kN) / QuantGroupSize::kN * KPerBlockBQ + + return (nIter * NWarp * WG::kN) / BQuantGroupSize::kN * KPerBlockBQ + kQScale; } else diff --git a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_flatbr_bquant_cr.hpp b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_flatbr_bquant_cr.hpp index 2b67b65856e..2f8a1b5cf43 100644 --- a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_flatbr_bquant_cr.hpp +++ b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_flatbr_bquant_cr.hpp @@ -25,9 +25,9 @@ struct BlockGemmWeightPreshuffleBQuantARegBRegCReg using CDataType = remove_cvref_t; using ComputeDataType = remove_cvref_t; using BlockGemmShape = remove_cvref_t; // TileFlatmmShape - using QuantGroupSize = remove_cvref_t; + using BQuantGroupSize = remove_cvref_t; - static_assert(QuantGroupSize::kM == 1, "only N/K blocks for BQuant preshuffle kernel!"); + static_assert(BQuantGroupSize::kM == 1, "only N/K blocks for BQuant preshuffle kernel!"); static constexpr auto I0 = number<0>(); static constexpr auto I1 = number<1>(); @@ -63,12 +63,12 @@ struct BlockGemmWeightPreshuffleBQuantARegBRegCReg static constexpr auto MIter_2nd_last = (MIterPerWarp >= 2) ? MIterPerWarp - 2 : MIterPerWarp - 1; - static constexpr index_t KPerBlockBQ = KPerBlock / QuantGroupSize::kK; + static constexpr index_t KPerBlockBQ = KPerBlock / BQuantGroupSize::kK; static constexpr index_t QScalesPerBlockRow = - integer_divide_ceil(KPerBlock, QuantGroupSize::kK); + integer_divide_ceil(KPerBlock, BQuantGroupSize::kK); static constexpr index_t QScalesPerWarpGemmRow = - integer_divide_ceil(WG::kK, QuantGroupSize::kK); + integer_divide_ceil(WG::kK, BQuantGroupSize::kK); static constexpr index_t KIterPerQScale = KIterPerWarp / QScalesPerBlockRow; static constexpr index_t DsReadPreload = 2; // default 2, preload 2 ds read @@ -205,9 +205,10 @@ struct BlockGemmWeightPreshuffleBQuantARegBRegCReg else { index_t reg_offset = [&]() { - if constexpr(QuantGroupSize::kN >= (NWarp * WG::kN)) + if constexpr(BQuantGroupSize::kN >= (NWarp * WG::kN)) { - return (nIter * NWarp * WG::kN) / QuantGroupSize::kN * KPerBlockBQ + + return (nIter * NWarp * WG::kN) / BQuantGroupSize::kN * + KPerBlockBQ + kQScale; } else diff --git a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_bquant_cr.hpp b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_bquant_cr.hpp index c44d330d139..25a6b98ee2a 100644 --- a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_bquant_cr.hpp +++ b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_bquant_cr.hpp @@ -33,6 +33,7 @@ struct ABQuantBlockUniversalGemmAsBsCr : public BlockGemmQuantBase using AQDataType = remove_cvref_t; using BDataType = remove_cvref_t; using BQDataType = remove_cvref_t; + using BLayout = remove_cvref_t; using BQLayout = remove_cvref_t; using ComputeDataType = remove_cvref_t; using CDataType = remove_cvref_t; @@ -134,8 +135,12 @@ struct ABQuantBlockUniversalGemmAsBsCr : public BlockGemmQuantBase using CDataType = remove_cvref_t; // BDataType gets converted from PkInt4 during loading - using OverrideBDataType = - std::conditional_t, ADataType, BDataType>; + using OverrideBDataType = std::conditional_t< + std::is_same_v && + std::is_same_v, + ADataType, + BDataType>; + using Base = BlockGemmQuantBase; using WarpGemm = remove_cvref_t; @@ -327,9 +332,25 @@ struct ABQuantBlockUniversalGemmAsBsCr : public BlockGemmQuantBase if constexpr(PreshuffleQuant) { - constexpr index_t reg_offset = nIter; + constexpr index_t reg_offset = [&]() { + if constexpr(GemmTraits::BQuantGroupSize::kN > + (NWarp * WarpGemm::kN)) + { + if constexpr(Traits::NPerBlock == + GemmTraits::BQuantGroupSize::kN) + return kQScale; + else + return nIter; // for prefill needs kQscale, for decode needs + // nIter + } + else + { + return nIter; + } + }(); auto pull_from_lane = (__lane_id() & (WarpGemm::kN - 1)) * Traits::KQPerBlock + kQScale; + auto& scale_reg = bq_block_tensor.get_thread_buffer()[reg_offset]; // cross lane ops uint32_t scale_reg_dword; diff --git a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp index 705a992b526..8aebc0790e9 100644 --- a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp +++ b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp @@ -34,7 +34,7 @@ struct AQuantBlockUniversalGemmAsBsCr using ComputeDataType = remove_cvref_t; using CDataType = remove_cvref_t; using BlockGemmShape = remove_cvref_t; - using QuantGroupSize = remove_cvref_t; + using AQuantGroupSize = remove_cvref_t; static constexpr index_t kBlockSize = Problem::kBlockSize; static constexpr auto Scheduler = Problem::Scheduler; @@ -43,7 +43,7 @@ struct AQuantBlockUniversalGemmAsBsCr static constexpr index_t MPerBlock = BlockGemmShape::kM; static constexpr index_t NPerBlock = BlockGemmShape::kN; static constexpr index_t KPerBlock = BlockGemmShape::kK; - static constexpr index_t AQPerBlock = KPerBlock / QuantGroupSize::kK; + static constexpr index_t AQPerBlock = KPerBlock / AQuantGroupSize::kK; static constexpr auto config = Policy::template GetWarpGemmMWarpNWarp(); using WarpGemm = remove_cvref_t())>; @@ -69,20 +69,20 @@ struct AQuantBlockUniversalGemmAsBsCr static constexpr index_t KIterPerWarp = KPerBlock / WarpGemm::kK; static constexpr index_t QScalesPerBlockRow = - integer_divide_ceil(KPerBlock, QuantGroupSize::kK); + integer_divide_ceil(KPerBlock, AQuantGroupSize::kK); static constexpr index_t QScalesPerWarpGemmRow = - integer_divide_ceil(WarpGemm::kK, QuantGroupSize::kK); + integer_divide_ceil(WarpGemm::kK, AQuantGroupSize::kK); static constexpr index_t KIterPerQScale = KIterPerWarp / QScalesPerBlockRow; - static_assert(QuantGroupSize::kK % WarpGemm::kK == 0, - "Error! WarpGemm::kK should be a multiple of QuantGroupSize"); + static_assert(AQuantGroupSize::kK % WarpGemm::kK == 0, + "Error! WarpGemm::kK should be a multiple of AQuantGroupSize"); static_assert(QScalesPerWarpGemmRow == 1, - "Error! QuantGroupSize shouldn't be smaller than WarpGemm::kK"); + "Error! AQuantGroupSize shouldn't be smaller than WarpGemm::kK"); static_assert(KIterPerWarp % QScalesPerBlockRow == 0, "Error! KItersPerWarp should be a multiple of QscalesPerBlockRow"); - static_assert(KPerBlock / QuantGroupSize::kK > 0, + static_assert(KPerBlock / AQuantGroupSize::kK > 0, "Error! Each row of blockgemm should have a separate scale"); static_assert(MIterPerWarp * MWarp * WarpGemm::kM == MPerBlock, diff --git a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp index 03b9dfe34db..e2083cfd2a1 100644 --- a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp +++ b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp @@ -36,7 +36,7 @@ struct BQuantBlockUniversalGemmAsBsCr using ComputeDataType = remove_cvref_t; using CDataType = remove_cvref_t; using BlockGemmShape = remove_cvref_t; - using QuantGroupSize = remove_cvref_t; + using BQuantGroupSize = remove_cvref_t; static constexpr index_t kBlockSize = Problem::kBlockSize; static constexpr auto Scheduler = Problem::Scheduler; @@ -46,8 +46,8 @@ struct BQuantBlockUniversalGemmAsBsCr static constexpr index_t NPerBlock = BlockGemmShape::kN; static constexpr index_t KPerBlock = BlockGemmShape::kK; - static constexpr index_t NQPerBlock = NPerBlock / QuantGroupSize::kN; - static constexpr index_t KQPerBlock = KPerBlock / QuantGroupSize::kK; + static constexpr index_t NQPerBlock = NPerBlock / BQuantGroupSize::kN; + static constexpr index_t KQPerBlock = KPerBlock / BQuantGroupSize::kK; static constexpr auto config = Policy::template GetWarpGemmMWarpNWarp(); using WarpGemm = remove_cvref_t())>; @@ -75,20 +75,20 @@ struct BQuantBlockUniversalGemmAsBsCr static constexpr bool PreshuffleQuant = Problem::Traits::PreshuffleQuant; static constexpr index_t QScalesPerBlockRow = - integer_divide_ceil(KPerBlock, QuantGroupSize::kK); + integer_divide_ceil(KPerBlock, BQuantGroupSize::kK); static constexpr index_t QScalesPerWarpGemmRow = - integer_divide_ceil(WarpGemm::kK, QuantGroupSize::kK); + integer_divide_ceil(WarpGemm::kK, BQuantGroupSize::kK); static constexpr index_t KIterPerQScale = KIterPerWarp / QScalesPerBlockRow; - static_assert(QuantGroupSize::kK % WarpGemm::kK == 0, - "Error! WarpGemm::kK should be a multiple of QuantGroupSize"); + static_assert(BQuantGroupSize::kK % WarpGemm::kK == 0, + "Error! WarpGemm::kK should be a multiple of BQuantGroupSize"); static_assert(QScalesPerWarpGemmRow == 1, - "Error! QuantGroupSize shouldn't be smaller than WarpGemm::kK"); + "Error! BQuantGroupSize shouldn't be smaller than WarpGemm::kK"); static_assert(KIterPerWarp % QScalesPerBlockRow == 0, "Error! KItersPerWarp should be a multiple of QscalesPerBlockRow"); - static_assert(KPerBlock / QuantGroupSize::kK > 0, + static_assert(KPerBlock / BQuantGroupSize::kK > 0, "Error! Each row of blockgemm should have a separate scale"); static_assert(MIterPerWarp * MWarp * WarpGemm::kM == MPerBlock, @@ -321,11 +321,11 @@ struct BQuantBlockUniversalGemmAsBsCr { // constexpr index_t reg_offset = nIter; constexpr index_t reg_offset = [&]() { - if constexpr(GemmTraits::QuantGroupSize::kN > + if constexpr(GemmTraits::BQuantGroupSize::kN > (NWarp * WarpGemm::kN)) { if constexpr(Traits::NPerBlock == - GemmTraits::QuantGroupSize::kN) + GemmTraits::BQuantGroupSize::kN) return kQScale; else return nIter; // for prefill needs kQscale, for decode needs @@ -370,10 +370,11 @@ struct BQuantBlockUniversalGemmAsBsCr { // Multiply bquant with accumulated C constexpr index_t reg_offset = [&]() { - if constexpr(GemmTraits::QuantGroupSize::kN >= + if constexpr(GemmTraits::BQuantGroupSize::kN >= (NWarp * WarpGemm::kN)) return (nIter * NWarp * WarpGemm::kN) / - GemmTraits::QuantGroupSize::kN * Traits::KQPerBlock + + GemmTraits::BQuantGroupSize::kN * + Traits::KQPerBlock + kQScale; else { diff --git a/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp b/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp index 96462b6a14f..303b4d48006 100644 --- a/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp +++ b/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp @@ -476,7 +476,9 @@ struct QuantGemmKernel { // Step 1: Create tensor view for AQ const auto& aq_tensor_view = [&]() { - if constexpr(kQuantType == QuantType::AQuantGrouped && PreshuffleQuant) + if constexpr((kQuantType == QuantType::AQuantGrouped || + kQuantType == QuantType::ABQuantGrouped) && + PreshuffleQuant) { static_assert(std::is_same_v); const auto aq_x = kargs.M * GemmPipeline::KPerBlockAQ; @@ -571,13 +573,15 @@ struct QuantGemmKernel // Step 2: Create tile window (no padding for AQ) const auto& aq_block_window = [&]() { - if constexpr(kQuantType == QuantType::AQuantGrouped && PreshuffleQuant) + if constexpr((kQuantType == QuantType::AQuantGrouped || + kQuantType == QuantType::ABQuantGrouped) && + PreshuffleQuant) { static_assert(std::is_same_v); - using QuantGroupSize = remove_cvref_t; + using AQuantGroupSize = remove_cvref_t; constexpr auto block_m = TilePartitioner::MPerBlock; constexpr auto warp_m = GemmPipeline::BlockGemmShape::WarpTile::at(I0); - constexpr auto aqk_per_block = TilePartitioner::KPerBlock / QuantGroupSize::kK; + constexpr auto aqk_per_block = TilePartitioner::KPerBlock / AQuantGroupSize::kK; constexpr auto tile_window_width = ck_tile::integer_least_multiple(warp_m * aqk_per_block, get_warp_size()); constexpr auto tile_window_height = block_m / warp_m; @@ -587,11 +591,19 @@ struct QuantGemmKernel make_tuple(number{}, number{}), {block_m_idx * tile_window_height, 0}); } - else if constexpr(kQuantType == QuantType::AQuantGrouped && !PreshuffleQuant) + else if constexpr((kQuantType == QuantType::AQuantGrouped || + kQuantType == QuantType::ABQuantGrouped) && + !PreshuffleQuant) { - using QuantGroupSize = remove_cvref_t; - constexpr auto aqk_per_block = TilePartitioner::KPerBlock / QuantGroupSize::kK; + + using AQuantGroupSize = remove_cvref_t; + constexpr auto aqk_per_block = TilePartitioner::KPerBlock / AQuantGroupSize::kK; constexpr auto block_m = TilePartitioner::MPerBlock; + if constexpr(kQuantType == QuantType::ABQuantGrouped) + { + static_assert(std::is_same_v, + "ABQuantGrouped requires RowMajor AQ layout"); + } if constexpr(std::is_same_v) { return make_tile_window(aq_tensor_view, @@ -605,17 +617,6 @@ struct QuantGemmKernel {0, i_m}); } } - else if constexpr(kQuantType == QuantType::ABQuantGrouped && !PreshuffleQuant) - { - static_assert(std::is_same_v); - using QuantGroupSize = remove_cvref_t; - constexpr auto block_m = TilePartitioner::MPerBlock; - constexpr auto aqk_per_block = TilePartitioner::KPerBlock / QuantGroupSize::kK; - return make_tile_window( - aq_tensor_view, - make_tuple(number{}, number{}), - {i_m, 0}); - } else if constexpr(kQuantType == QuantType::RowColQuant) { return make_tile_window(aq_tensor_view, @@ -808,14 +809,15 @@ struct QuantGemmKernel number<1>{}, number<1>{}); } - else if constexpr(kQuantType == QuantType::BQuantGrouped) + else if constexpr(kQuantType == QuantType::BQuantGrouped || + kQuantType == QuantType::ABQuantGrouped) { if constexpr(PreshuffleQuant) { static_assert(std::is_same_v, "PreshuffleQuant with BQuantGrouped currently only supports " "ColumnMajor BQ layout"); - using QuantGroupSize = remove_cvref_t; + using BQuantGroupSize = remove_cvref_t; return MakePreshuffledQuantTensorView< GemmPipeline::KPerBlockBQ, @@ -824,48 +826,42 @@ struct QuantGemmKernel TilePartitioner::BlockGemmShape::WarpTile::at(I1), GemmPipeline::GetVectorSizeBQ()>( bq_ptr, - ck_tile::integer_divide_ceil(kargs.N, QuantGroupSize::kN), - QuantGroupSize::kN, + ck_tile::integer_divide_ceil(kargs.N, BQuantGroupSize::kN), + BQuantGroupSize::kN, kargs.QK_B); } else { - using QuantGroupSize = remove_cvref_t; + using BQuantGroupSize = remove_cvref_t; + + if constexpr(kQuantType == QuantType::ABQuantGrouped) + { + static_assert(std::is_same_v, + "ABQuantGrouped requires RowMajor AQ layout"); + } if constexpr(std::is_same_v) { return make_naive_tensor_view( bq_ptr, - make_tuple(integer_divide_ceil(kargs.K, QuantGroupSize::kK), - integer_divide_ceil(kargs.N, QuantGroupSize::kN)), - make_tuple(integer_divide_ceil(kargs.N, QuantGroupSize::kN), 1), + make_tuple(integer_divide_ceil(kargs.K, BQuantGroupSize::kK), + integer_divide_ceil(kargs.N, BQuantGroupSize::kN)), + make_tuple(integer_divide_ceil(kargs.N, BQuantGroupSize::kN), 1), number{}, number<1>{}); } else { - static_assert(std::is_same_v); return make_naive_tensor_view( bq_ptr, - make_tuple(integer_divide_ceil(kargs.N, QuantGroupSize::kN), - integer_divide_ceil(kargs.K, QuantGroupSize::kK)), - make_tuple(integer_divide_ceil(kargs.K, QuantGroupSize::kK), 1), + make_tuple(integer_divide_ceil(kargs.N, BQuantGroupSize::kN), + integer_divide_ceil(kargs.K, BQuantGroupSize::kK)), + make_tuple(integer_divide_ceil(kargs.K, BQuantGroupSize::kK), 1), number{}, number<1>{}); } } } - else if constexpr(kQuantType == QuantType::ABQuantGrouped) - { - static_assert(std::is_same_v); - using QuantGroupSize = remove_cvref_t; - return make_naive_tensor_view( - bq_ptr, - make_tuple(integer_divide_ceil(kargs.N, QuantGroupSize::kN), kargs.QK_B), - make_tuple(kargs.stride_BQ, 1), - number{}, - number<1>{}); - } else { return nullptr; @@ -881,28 +877,29 @@ struct QuantGemmKernel number{}), {i_m, i_n}); } - else if constexpr(kQuantType == QuantType::BQuantGrouped) + else if constexpr(kQuantType == QuantType::BQuantGrouped || + kQuantType == QuantType::ABQuantGrouped) { - using QuantGroupSize = remove_cvref_t; + using BQuantGroupSize = remove_cvref_t; if constexpr(PreshuffleQuant) { static_assert(std::is_same_v); // Number of N-dimension quantization groups per block - constexpr auto block_n = (QuantGroupSize::kN <= TilePartitioner::NPerBlock) - ? TilePartitioner::NPerBlock / QuantGroupSize::kN - : QuantGroupSize::kN / TilePartitioner::NPerBlock; + constexpr auto block_n = (BQuantGroupSize::kN <= TilePartitioner::NPerBlock) + ? TilePartitioner::NPerBlock / BQuantGroupSize::kN + : BQuantGroupSize::kN / TilePartitioner::NPerBlock; // Number of N-dimension elements per warp constexpr auto warp_n = TilePartitioner::BlockGemmShape::WarpTile::at(I1); // Determine how many warps share the same scale in N-dimension - constexpr auto warp_per_group = (QuantGroupSize::kN < warp_n) - ? (warp_n / QuantGroupSize::kN) - : (QuantGroupSize::kN / warp_n); + constexpr auto warp_per_group = (BQuantGroupSize::kN < warp_n) + ? (warp_n / BQuantGroupSize::kN) + : (BQuantGroupSize::kN / warp_n); // Number of K-dimension quantization groups per block - constexpr auto bqk_per_block = TilePartitioner::KPerBlock / QuantGroupSize::kK; + constexpr auto bqk_per_block = TilePartitioner::KPerBlock / BQuantGroupSize::kK; // The pre-shuffled layout flattens warp_n × // bqk_per_block scales per row, Padded up to warp_size @@ -911,25 +908,25 @@ struct QuantGemmKernel ck_tile::integer_least_multiple(warp_n * bqk_per_block, get_warp_size()); // Adapts based on fine vs coarse quantization granularity: - // - Fine-grained (QuantGroupSize::kN < warp_n): + // - Fine-grained (BQuantGroupSize::kN < warp_n): // Multiple quant groups per warp → fewer rows needed per block. // height = block_n / warp_per_group // - // - Coarse-grained (QuantGroupSize::kN >= warp_n): + // - Coarse-grained (BQuantGroupSize::kN >= warp_n): // Each row represents one quant group. // height = block_n constexpr auto tile_window_height = - (QuantGroupSize::kN < warp_n) ? block_n / warp_per_group : block_n; + (BQuantGroupSize::kN < warp_n) ? block_n / warp_per_group : block_n; auto block_n_idx = i_n / TilePartitioner::NPerBlock; // For decode shapes GN: 128, Blocks needs to repeat 0,0,1,1,2,2 ... - if(QuantGroupSize::kN > TilePartitioner::NPerBlock) + if(BQuantGroupSize::kN > TilePartitioner::NPerBlock) { block_n_idx = block_n_idx >> 1; } - if(QuantGroupSize::kN > TilePartitioner::NPerBlock) + if(BQuantGroupSize::kN > TilePartitioner::NPerBlock) { return make_tile_window( bq_tensor_view, @@ -946,17 +943,22 @@ struct QuantGemmKernel } else { + if constexpr(kQuantType == QuantType::ABQuantGrouped) + { + static_assert(std::is_same_v, + "ABQuantGrouped requires RowMajor AQ layout"); + } constexpr auto tensor_dim = - (QuantGroupSize::kN <= TilePartitioner::NPerBlock) - ? TilePartitioner::NPerBlock / QuantGroupSize::kN + (BQuantGroupSize::kN <= TilePartitioner::NPerBlock) + ? TilePartitioner::NPerBlock / BQuantGroupSize::kN : 1; if constexpr(std::is_same_v) { return make_tile_window( bq_tensor_view, - make_tuple(number{}, + make_tuple(number{}, number{}), - {0, i_n / QuantGroupSize::kN}); + {0, i_n / BQuantGroupSize::kN}); } else { @@ -964,21 +966,11 @@ struct QuantGemmKernel return make_tile_window( bq_tensor_view, make_tuple(number{}, - number{}), - {i_n / QuantGroupSize::kN, 0}); + number{}), + {i_n / BQuantGroupSize::kN, 0}); } } } - else if constexpr(kQuantType == QuantType::ABQuantGrouped) - { - static_assert(std::is_same_v); - using QuantGroupSize = remove_cvref_t; - return make_tile_window( - bq_tensor_view, - make_tuple(number{}, - number{}), - {i_n / QuantGroupSize::kN, 0}); - } else { return nullptr; diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_abquant_pipeline_ag_bg_cr_v3.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_abquant_pipeline_ag_bg_cr_v3.hpp index cd70c2ca862..ddc1a260efa 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_abquant_pipeline_ag_bg_cr_v3.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_abquant_pipeline_ag_bg_cr_v3.hpp @@ -72,7 +72,10 @@ struct ABQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3(); } @@ -329,9 +332,12 @@ struct ABQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3{}), - 0) + (PreshuffleQuant) + ? make_array(((NPerBlockBQ <= BlockGemmShape::BlockWarps::at(number<1>{})) + ? ck_tile::integer_divide_ceil(n, BQuantGroupSize::kN) + : ck_tile::integer_least_multiple(n, NPerBlock) / + BlockGemmShape::WarpTile::at(number<1>{})), + 0) : is_bq_row_major ? make_array(KPerBlockBQ, 0) : make_array(0, KPerBlockBQ); @@ -484,7 +490,7 @@ struct ABQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3( Policy::template MakeShuffledARegTileDistribution()); @@ -495,7 +501,7 @@ struct ABQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3( diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_base.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_base.hpp index 1acf0444cf3..9ddb7eecacb 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_base.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_base.hpp @@ -12,21 +12,21 @@ namespace ck_tile { template struct GemmAQuantPipelineAgBgCrImplBase : public GemmPipelineAgBgCrImplBase { - using Base = GemmPipelineAgBgCrImplBase; - using ADataType = typename Base::ADataType; - using ALayout = typename Base::ALayout; - using BDataType = typename Base::BDataType; - using BLayout = typename Base::BLayout; - using BlockGemmShape = typename Base::BlockGemmShape; - using QuantGroupSize = remove_cvref_t; + using Base = GemmPipelineAgBgCrImplBase; + using ADataType = typename Base::ADataType; + using ALayout = typename Base::ALayout; + using BDataType = typename Base::BDataType; + using BLayout = typename Base::BLayout; + using BlockGemmShape = typename Base::BlockGemmShape; + using AQuantGroupSize = remove_cvref_t; static constexpr index_t MPerBlock = BlockGemmShape::kM; static constexpr index_t NPerBlock = BlockGemmShape::kN; static constexpr index_t KPerBlock = BlockGemmShape::kK; - static constexpr index_t KPerBlockAQ = KPerBlock / QuantGroupSize::kK; + static constexpr index_t KPerBlockAQ = KPerBlock / AQuantGroupSize::kK; - static_assert(KPerBlock % QuantGroupSize::kK == 0, + static_assert(KPerBlock % AQuantGroupSize::kK == 0, "KPerBlock must be a multiple of QuantGroupSize"); // Create DRAM tile window for AQ diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_mem.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_mem.hpp index 650cd947f73..04cf68b8f72 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_mem.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_mem.hpp @@ -23,19 +23,19 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem using Base = BaseGemmPipelineAgBgCrMem; using PipelineImplBase = GemmAQuantPipelineAgBgCrImplBase; - using ADataType = remove_cvref_t; - using AQDataType = remove_cvref_t; - using BDataType = remove_cvref_t; - using CDataType = remove_cvref_t; - using BlockGemmShape = remove_cvref_t; - using QuantGroupSize = remove_cvref_t; + using ADataType = remove_cvref_t; + using AQDataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using CDataType = remove_cvref_t; + using BlockGemmShape = remove_cvref_t; + using AQuantGroupSize = remove_cvref_t; // When ADataType is pk_int4_t, use BDataType instead for transpose operations // since packed 4-bit integers cannot be directly transposed (requires at least 8-bit precision) using OverrideADataType = std::conditional_t, BDataType, ADataType>; - static_assert(QuantGroupSize::kM == 1, "no block for M supported yet!"); - static_assert(QuantGroupSize::kN == 1, "only M/K blocks for AQuant kernel!"); + static_assert(AQuantGroupSize::kM == 1, "no block for M supported yet!"); + static_assert(AQuantGroupSize::kN == 1, "only M/K blocks for AQuant kernel!"); using I0 = number<0>; using I1 = number<1>; @@ -60,7 +60,7 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem static constexpr index_t MPerBlock = BlockGemmShape::kM; static constexpr index_t NPerBlock = BlockGemmShape::kN; static constexpr index_t KPerBlock = BlockGemmShape::kK; - static constexpr index_t KPerBlockAQ = BlockGemmShape::kK / QuantGroupSize::kK; + static constexpr index_t KPerBlockAQ = BlockGemmShape::kK / AQuantGroupSize::kK; static constexpr index_t GetVectorSizeA() { return Policy::template GetVectorSizeA(); } static constexpr index_t GetVectorSizeB() { return Policy::template GetVectorSizeB(); } @@ -99,7 +99,7 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem BlockSize, concat('x', WaveNumM, WaveNumN), concat('x', BlockGemm::WarpGemm::kM, BlockGemm::WarpGemm::kN, BlockGemm::WarpGemm::kK), - concat('x', kPadM, kPadN, kPadK), QuantGroupSize::GetName(), + concat('x', kPadM, kPadN, kPadK), AQuantGroupSize::GetName(), Scheduler == GemmPipelineScheduler::Interwave ? "interwave" : "intrawave"); // else Intrawave // clang-format on } @@ -156,7 +156,7 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem << "\n" << "A/B LDS read inst: " << A_LDS_Read_Inst_Num << ", " << B_LDS_Read_Inst_Num << "\n" << "C MFMA inst: " << C_MFMA_Inst_Num << "\n" - << "QuantGroupSize: " << QuantGroupSize::GetName() << "\n" + << "AQuantGroupSize: " << AQuantGroupSize::GetName() << "\n" << "KPack: " << BlockGemm::Traits::KPack << "\n" << "PrefetchStages: " << PrefetchStages << "\n"; return str.str(); diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_v3.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_v3.hpp index 71e4a744003..0bc455b7f83 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_v3.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_v3.hpp @@ -20,19 +20,19 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3; using PipelineImplBase = GemmAQuantPipelineAgBgCrImplBase; - using ADataType = remove_cvref_t; - using AQDataType = remove_cvref_t; - using BDataType = remove_cvref_t; - using CDataType = remove_cvref_t; - using BlockGemmShape = remove_cvref_t; - using QuantGroupSize = remove_cvref_t; + using ADataType = remove_cvref_t; + using AQDataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using CDataType = remove_cvref_t; + using BlockGemmShape = remove_cvref_t; + using AQuantGroupSize = remove_cvref_t; // When ADataType is pk_int4_t, use BDataType instead for transpose operations // since packed 4-bit integers cannot be directly transposed (requires at least 8-bit precision) using OverrideADataType = std::conditional_t, BDataType, ADataType>; - static_assert(QuantGroupSize::kM == 1, "no block for M supported yet!"); - static_assert(QuantGroupSize::kN == 1, "only M/K blocks for AQuant kernel!"); + static_assert(AQuantGroupSize::kM == 1, "no block for M supported yet!"); + static_assert(AQuantGroupSize::kN == 1, "only M/K blocks for AQuant kernel!"); using I0 = number<0>; using I1 = number<1>; @@ -57,7 +57,7 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3(); } static constexpr index_t GetVectorSizeB() { return Policy::template GetVectorSizeB(); } @@ -96,7 +96,7 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 struct GemmBQuantPipelineAgBgCrImplBase : public GemmPipelineAgBgCrImplBase { - using Base = GemmPipelineAgBgCrImplBase; - using ADataType = typename Base::ADataType; - using ALayout = typename Base::ALayout; - using BDataType = typename Base::BDataType; - using BLayout = typename Base::BLayout; - using BlockGemmShape = typename Base::BlockGemmShape; - using QuantGroupSize = remove_cvref_t; + using Base = GemmPipelineAgBgCrImplBase; + using ADataType = typename Base::ADataType; + using ALayout = typename Base::ALayout; + using BDataType = typename Base::BDataType; + using BLayout = typename Base::BLayout; + using BlockGemmShape = typename Base::BlockGemmShape; + using BQuantGroupSize = remove_cvref_t; using BQLayout = remove_cvref_t; @@ -27,16 +27,16 @@ struct GemmBQuantPipelineAgBgCrImplBase : public GemmPipelineAgBgCrImplBase= 1, "NPerBlock must be >= QuantGroupSize"); - static_assert(KPerBlockBQ >= 1, "KPerBlock must be >= QuantGroupSize"); + // static_assert(NPerBlockBQ >= 1, "NPerBlock must be >= BQuantGroupSize"); + static_assert(KPerBlockBQ >= 1, "KPerBlock must be >= BQuantGroupSize"); - // static_assert(NPerBlock % QuantGroupSize::kN == 0, - // "NPerBlock must be a multiple of QuantGroupSize::kN"); - static_assert(KPerBlock % QuantGroupSize::kK == 0, - "KPerBlock must be a multiple of QuantGroupSize::kK"); + // static_assert(NPerBlock % BQuantGroupSize::kN == 0, + // "NPerBlock must be a multiple of BQuantGroupSize::kN"); + static_assert(KPerBlock % BQuantGroupSize::kK == 0, + "KPerBlock must be a multiple of BQuantGroupSize::kK"); // Create DRAM tile window for BQ template diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_policy.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_policy.hpp index 5c4dfd37c79..9de21335ab1 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_policy.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_policy.hpp @@ -45,8 +45,8 @@ struct GemmBQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC constexpr index_t BlockSize = Problem::kBlockSize; constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; - constexpr index_t NPerBlockBQ = (Problem::QuantGroupSize::kN <= NPerBlock) - ? NPerBlock / Problem::QuantGroupSize::kN + constexpr index_t NPerBlockBQ = (Problem::BQuantGroupSize::kN <= NPerBlock) + ? NPerBlock / Problem::BQuantGroupSize::kN : 1; constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; constexpr index_t KPerBlockBQ = KPerBlock / Problem::BQuantGroupSize::kK; diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp index be91002cdbd..7e5dc1cbc4f 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp @@ -26,12 +26,12 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3; using PipelineImplBase = GemmBQuantPipelineAgBgCrImplBase; - using ADataType = remove_cvref_t; - using BDataType = remove_cvref_t; - using BQDataType = remove_cvref_t; - using CDataType = remove_cvref_t; - using BlockGemmShape = remove_cvref_t; - using QuantGroupSize = remove_cvref_t; + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using BQDataType = remove_cvref_t; + using CDataType = remove_cvref_t; + using BlockGemmShape = remove_cvref_t; + using BQuantGroupSize = remove_cvref_t; using ALayout = remove_cvref_t; using BQLayout = remove_cvref_t; @@ -45,7 +45,7 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3; - static_assert(QuantGroupSize::kM == 1, "only N/K blocks for BQuant kernel!"); + static_assert(BQuantGroupSize::kM == 1, "only N/K blocks for BQuant kernel!"); using I0 = number<0>; using I1 = number<1>; using I2 = number<2>; @@ -66,11 +66,11 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3(); } static constexpr index_t GetVectorSizeB() { return Policy::template GetVectorSizeB(); } @@ -109,7 +109,7 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3{})) - ? ck_tile::integer_divide_ceil(n, QuantGroupSize::kN) + ? ck_tile::integer_divide_ceil(n, BQuantGroupSize::kN) : ck_tile::integer_least_multiple(n, NPerBlock) / BlockGemmShape::WarpTile::at(number<1>{})), 0) diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_quant_pipeline_problem.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_quant_pipeline_problem.hpp index 39b00d2501b..1edbe9ac164 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_quant_pipeline_problem.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_quant_pipeline_problem.hpp @@ -120,7 +120,7 @@ template ; using ComputeDataType = remove_cvref_t; using BlockGemmShape = remove_cvref_t; - using QuantGroupSize = remove_cvref_t; + using BQuantGroupSize = remove_cvref_t; using ALayout = remove_cvref_t; using BLayout = remove_cvref_t; @@ -72,11 +72,11 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV static constexpr bool PreshuffleQuant = Problem::Traits::PreshuffleQuant; static constexpr index_t VectorLoadSize = Problem::VectorLoadSize; static constexpr index_t NPerBlockBQ = - integer_divide_ceil(BlockGemmShape::kN, QuantGroupSize::kN); + integer_divide_ceil(BlockGemmShape::kN, BQuantGroupSize::kN); static constexpr index_t KPerBlockBQ = - integer_divide_ceil(BlockGemmShape::kK, QuantGroupSize::kK); + integer_divide_ceil(BlockGemmShape::kK, BQuantGroupSize::kK); static constexpr index_t QScalesPerBlockRow = - integer_divide_ceil(kKPerBlock, QuantGroupSize::kK); + integer_divide_ceil(kKPerBlock, BQuantGroupSize::kK); static constexpr index_t GetVectorSizeBQ() { @@ -94,7 +94,7 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV BlockSize, concat('x', WaveNumM, WaveNumN), concat('x', Base::GetVectorSizeA(), Base::GetVectorSizeB(), GetVectorSizeBQ()), - concat('x', kPadM, kPadN, kPadK), QuantGroupSize::GetName()); + concat('x', kPadM, kPadN, kPadK), BQuantGroupSize::GetName()); // clang-format on } @@ -115,7 +115,7 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV // then by vector width to get an approximate number of vector loads. constexpr index_t BQload_inst = ck_tile::integer_divide_ceil( ck_tile::integer_divide_ceil(kKPerBlock * kNPerBlock * sizeof(BQDataType), - QuantGroupSize::kK * QuantGroupSize::kK), + BQuantGroupSize::kK * BQuantGroupSize::kK), VectorLoadSize); // ToDo: Hardcoded, need to change in future. How many instruction emit per iteration @@ -364,7 +364,7 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV { move_tile_window(bq_copy_dram_window, {((NPerBlockBQ <= BlockGemmShape::BlockWarps::at(number<1>{})) - ? ck_tile::integer_divide_ceil(n, QuantGroupSize::kN) + ? ck_tile::integer_divide_ceil(n, BQuantGroupSize::kN) : ck_tile::integer_least_multiple(n, kNPerBlock) / BlockGemmShape::WarpTile::at(number<1>{})), 0}); @@ -441,7 +441,7 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV { move_tile_window(bq_copy_dram_window, {((NPerBlockBQ <= BlockGemmShape::BlockWarps::at(number<1>{})) - ? ck_tile::integer_divide_ceil(n, QuantGroupSize::kN) + ? ck_tile::integer_divide_ceil(n, BQuantGroupSize::kN) : ck_tile::integer_least_multiple(n, kNPerBlock) / BlockGemmShape::WarpTile::at(number<1>{})), 0}); @@ -478,7 +478,7 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV { move_tile_window(bq_copy_dram_window, {((NPerBlockBQ <= BlockGemmShape::BlockWarps::at(number<1>{})) - ? ck_tile::integer_divide_ceil(n, QuantGroupSize::kN) + ? ck_tile::integer_divide_ceil(n, BQuantGroupSize::kN) : ck_tile::integer_least_multiple(n, kNPerBlock) / BlockGemmShape::WarpTile::at(number<1>{})), 0}); From bc13451746e1fed53d1e3e6d21ea2fc4e77ee409 Mon Sep 17 00:00:00 2001 From: amd-khushbu Date: Wed, 21 Jan 2026 23:31:22 +0000 Subject: [PATCH 3/5] fix mxfp4 to use correct QuantGroupSize --- .../gemm_mxfp4_pipeline_ag_bg_cr_base.hpp | 30 +++++++++---------- .../gemm_mxfp4_pipeline_ag_bg_cr_policy.hpp | 8 ++--- .../gemm_mxfp4_pipeline_ag_bg_cr_v3.hpp | 26 ++++++++-------- 3 files changed, 32 insertions(+), 32 deletions(-) diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_base.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_base.hpp index 95122630ee7..facec252a35 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_base.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_base.hpp @@ -12,13 +12,13 @@ namespace ck_tile { template struct GemmMxFp4PipelineAgBgCrImplBase : public GemmPipelineAgBgCrImplBase { - using Base = GemmPipelineAgBgCrImplBase; - using ADataType = typename Base::ADataType; - using ALayout = typename Base::ALayout; - using BDataType = typename Base::BDataType; - using BLayout = typename Base::BLayout; - using BlockGemmShape = typename Base::BlockGemmShape; - using QuantGroupSize = remove_cvref_t; + using Base = GemmPipelineAgBgCrImplBase; + using ADataType = typename Base::ADataType; + using ALayout = typename Base::ALayout; + using BDataType = typename Base::BDataType; + using BLayout = typename Base::BLayout; + using BlockGemmShape = typename Base::BlockGemmShape; + using BQuantGroupSize = remove_cvref_t; using BQLayout = remove_cvref_t; @@ -26,16 +26,16 @@ struct GemmMxFp4PipelineAgBgCrImplBase : public GemmPipelineAgBgCrImplBase= 1, "NPerBlock must be >= QuantGroupSize"); - static_assert(KPerBlockBQ >= 1, "KPerBlock must be >= QuantGroupSize"); + static_assert(NPerBlockBQ >= 1, "NPerBlock must be >= BQuantGroupSize"); + static_assert(KPerBlockBQ >= 1, "KPerBlock must be >= BQuantGroupSize"); - static_assert(NPerBlock % QuantGroupSize::kN == 0, - "NPerBlock must be a multiple of QuantGroupSize::kN"); - static_assert(KPerBlock % QuantGroupSize::kK == 0, - "KPerBlock must be a multiple of QuantGroupSize::kK"); + static_assert(NPerBlock % BQuantGroupSize::kN == 0, + "NPerBlock must be a multiple of BQuantGroupSize::kN"); + static_assert(KPerBlock % BQuantGroupSize::kK == 0, + "KPerBlock must be a multiple of BQuantGroupSize::kK"); // Create DRAM tile window for BQ template diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_policy.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_policy.hpp index 7a2d1db2c85..6cf9e22f414 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_policy.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_policy.hpp @@ -22,9 +22,9 @@ struct GemmMxFp4PipelineAgBgCrPolicy : public UniversalGemmPipelineAgBgCrPolicy using BQLayout = remove_cvref_t; using BQDataType = remove_cvref_t; constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; - constexpr index_t NPerBlockBQ = NPerBlock / Problem::QuantGroupSize::kN; + constexpr index_t NPerBlockBQ = NPerBlock / Problem::BQuantGroupSize::kN; constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; - constexpr index_t KPerBlockBQ = KPerBlock / Problem::QuantGroupSize::kK; + constexpr index_t KPerBlockBQ = KPerBlock / Problem::BQuantGroupSize::kK; static_assert(std::is_same_v); return GetABQGlobalVectorLoadSize(); @@ -76,7 +76,7 @@ struct GemmMxFp4PipelineAgBgCrPolicy : public UniversalGemmPipelineAgBgCrPolicy constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; - constexpr index_t KScale = KPerBlock / Problem::QuantGroupSize::kK; // k_scale num //2 + constexpr index_t KScale = KPerBlock / Problem::BQuantGroupSize::kK; // k_scale num //2 constexpr index_t VecLoadSize = Problem::FixedVectorSize ? Problem::VectorSizeB : GetVectorSizeB(); constexpr index_t NumWaveGroups = Problem::NumWaveGroups; @@ -109,7 +109,7 @@ struct GemmMxFp4PipelineAgBgCrPolicy : public UniversalGemmPipelineAgBgCrPolicy using BlockWarps = typename Problem::BlockGemmShape::BlockWarps; using WarpTile = typename Problem::BlockGemmShape::WarpTile; - static_assert(Problem::QuantGroupSize::kK % WarpTile::at(I2) == 0, + static_assert(Problem::BQuantGroupSize::kK % WarpTile::at(I2) == 0, "KPerWarpGemm must be a multiple of QuantGroupSize!"); using WarpGemm = WarpGemmDispatcher; using PipelineImplBase = GemmMxFp4PipelineAgBgCrImplBase; - using ADataType = remove_cvref_t; - using BDataType = remove_cvref_t; - using BDqDataType = remove_cvref_t; - using BQDataType = remove_cvref_t; - using CDataType = remove_cvref_t; - using BlockGemmShape = remove_cvref_t; - using QuantGroupSize = remove_cvref_t; + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using BDqDataType = remove_cvref_t; + using BQDataType = remove_cvref_t; + using CDataType = remove_cvref_t; + using BlockGemmShape = remove_cvref_t; + using BQuantGroupSize = remove_cvref_t; - static_assert(QuantGroupSize::kM == 1, "only N/K blocks for BQuant kernel!"); + static_assert(BQuantGroupSize::kM == 1, "only N/K blocks for BQuant kernel!"); using I0 = number<0>; using I1 = number<1>; @@ -58,8 +58,8 @@ struct MxFp4GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3(); } static constexpr index_t GetVectorSizeB() { return Policy::template GetVectorSizeB(); } @@ -93,7 +93,7 @@ struct MxFp4GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 Date: Thu, 22 Jan 2026 21:29:57 +0000 Subject: [PATCH 4/5] adding tests for preshuffleQuant --- test/ck_tile/gemm_block_scale/CMakeLists.txt | 40 ++++++++++------- ...est_gemm_quant_abquant_preshuffleQuant.cpp | 44 +++++++++++++++++++ 2 files changed, 68 insertions(+), 16 deletions(-) create mode 100644 test/ck_tile/gemm_block_scale/test_gemm_quant_abquant_preshuffleQuant.cpp diff --git a/test/ck_tile/gemm_block_scale/CMakeLists.txt b/test/ck_tile/gemm_block_scale/CMakeLists.txt index 5749a8d3b27..b4e215a67b2 100644 --- a/test/ck_tile/gemm_block_scale/CMakeLists.txt +++ b/test/ck_tile/gemm_block_scale/CMakeLists.txt @@ -25,10 +25,24 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12") add_gtest_executable(test_tile_gemm_quant_aquant_base_ccr test_gemm_quant_aquant_base_ccr.cpp ) - target_compile_options(test_tile_gemm_quant_aquant_base_ccr PRIVATE ${TEST_GEMM_COMPILE_OPTIONS}) - # ABQuant tests + add_gtest_executable(test_tile_gemm_quant_aquant_prefill + test_gemm_quant_aquant_prefill.cpp + ) + target_compile_options(test_tile_gemm_quant_aquant_prefill PRIVATE ${TEST_GEMM_COMPILE_OPTIONS}) + + add_gtest_executable(test_tile_gemm_quant_aquant_transpose_c + test_gemm_quant_aquant_transpose_c.cpp + ) + target_compile_options(test_tile_gemm_quant_aquant_transpose_c PRIVATE ${TEST_GEMM_COMPILE_OPTIONS}) + + add_gtest_executable(test_tile_gemm_quant_aquant_preshuffle + test_gemm_quant_aquant_preshuffle.cpp + ) + target_compile_options(test_tile_gemm_quant_aquant_preshuffle PRIVATE ${TEST_GEMM_COMPILE_OPTIONS}) + + # ABQuant tests split into 4 files add_gtest_executable(test_tile_gemm_quant_abquant_base test_gemm_quant_abquant_base.cpp ) @@ -44,21 +58,10 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12") ) target_compile_options(test_tile_gemm_quant_abquant_preshuffle PRIVATE ${TEST_GEMM_COMPILE_OPTIONS}) - # AQuant tests - add_gtest_executable(test_tile_gemm_quant_aquant_prefill - test_gemm_quant_aquant_prefill.cpp - ) - target_compile_options(test_tile_gemm_quant_aquant_prefill PRIVATE ${TEST_GEMM_COMPILE_OPTIONS}) - - add_gtest_executable(test_tile_gemm_quant_aquant_transpose_c - test_gemm_quant_aquant_transpose_c.cpp - ) - target_compile_options(test_tile_gemm_quant_aquant_transpose_c PRIVATE ${TEST_GEMM_COMPILE_OPTIONS}) - - add_gtest_executable(test_tile_gemm_quant_aquant_preshuffle - test_gemm_quant_aquant_preshuffle.cpp + add_gtest_executable(test_tile_gemm_quant_abquant_preshuffleQuant + test_gemm_quant_abquant_preshuffleQuant.cpp ) - target_compile_options(test_tile_gemm_quant_aquant_preshuffle PRIVATE ${TEST_GEMM_COMPILE_OPTIONS}) + target_compile_options(test_tile_gemm_quant_abquant_preshuffleQuant PRIVATE ${TEST_GEMM_COMPILE_OPTIONS}) # BQuant tests (without PreshuffleB) - split into 6 files add_gtest_executable(test_tile_gemm_quant_bquant_1d_128 @@ -160,6 +163,11 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12") test_tile_gemm_quant_aquant_prefill test_tile_gemm_quant_aquant_transpose_c test_tile_gemm_quant_aquant_preshuffle + # ABQuant tests + test_tile_gemm_quant_abquant_base + test_tile_gemm_quant_abquant_padding + test_tile_gemm_quant_abquant_preshuffle + test_tile_gemm_quant_abquant_preshuffleQuant # BQuant tests test_tile_gemm_quant_bquant_1d_128 test_tile_gemm_quant_bquant_1d_64 diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_abquant_preshuffleQuant.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_abquant_preshuffleQuant.cpp new file mode 100644 index 00000000000..63d84ce99f6 --- /dev/null +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_abquant_preshuffleQuant.cpp @@ -0,0 +1,44 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck_tile/host.hpp" +#include "ck_tile/ops/gemm.hpp" + +#include +#include + +#include "test_gemm_quant_fixtures.hpp" + +// Type aliases for readability +using RowMajor = ck_tile::tensor_layout::gemm::RowMajor; +using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor; +using FP8 = ck_tile::fp8_t; +using BF8 = ck_tile::bf8_t; +using Half = ck_tile::half_t; +using PkInt4 = ck_tile::pk_int4_t; +using ABQuantGrouped = + std::integral_constant; +using GroupSize = ck_tile::QuantGroupShape>; + +// 2d block sizes for BQuant +using GroupSize2D128N = ck_tile::QuantGroupShape>; + +// Type combinations for ABQuant tests +// Tuple format: +// clang-format off +using ABQuantPreshuffleQuantTypes = ::testing::Types< + // PreshuffleQuant = false && TransposeC = false (RCR layout with RowMajor AQ) + std::tuple, + std::tuple +>; +// clang-format on + +// Test suite for ABQuant +TYPED_TEST_SUITE(TestCkTileGemmABQuant, ABQuantPreshuffleQuantTypes); + +// AQuant tests +TYPED_TEST(TestCkTileGemmABQuant, ABQuantGroupedTest) +{ + this->run_test_with_validation(1024, 1024, 1024); +} From 56a9b179965e2944cba4e83663c8f7ad6c320603 Mon Sep 17 00:00:00 2001 From: khuagarw Date: Sat, 24 Jan 2026 03:09:55 +0000 Subject: [PATCH 5/5] fixing A/B preshuffle quant to enable only BpreshuffleQaunt in ABQuant --- .../gemm_abquant_quantgrouped.cpp | 30 ---------- .../38_block_scale_gemm/gemm_utils.hpp | 12 ++-- .../run_gemm_quant_example.inc | 24 ++++---- .../block/block_gemm_quant_common.hpp | 6 +- ...versal_gemm_ar_aquant_flatbr_bquant_cr.hpp | 3 +- ...ock_universal_gemm_ar_flatbr_bquant_cr.hpp | 4 +- ..._universal_gemm_as_aquant_bs_bquant_cr.hpp | 18 +++--- .../block_universal_gemm_as_aquant_bs_cr.hpp | 4 +- .../block_universal_gemm_as_bs_bquant_cr.hpp | 16 ++--- .../gemm_quant/kernel/gemm_quant_kernel.hpp | 59 +++++++++++-------- .../gemm_abquant_pipeline_ag_bg_cr_v3.hpp | 9 +-- .../gemm_aquant_pipeline_ag_bg_cr_mem.hpp | 4 +- .../gemm_aquant_pipeline_ag_bg_cr_policy.hpp | 36 +++++------ .../gemm_aquant_pipeline_ag_bg_cr_v3.hpp | 4 +- .../gemm_bquant_pipeline_ag_bg_cr_policy.hpp | 20 +++---- .../gemm_bquant_pipeline_ag_bg_cr_v3.hpp | 6 +- .../pipeline/gemm_group_quant_utils.hpp | 15 ++--- .../gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp | 8 +-- .../pipeline/tile_gemm_quant_traits.hpp | 8 ++- .../gemm_block_scale/test_gemm_quant_base.hpp | 6 +- .../test_gemm_quant_fixtures.hpp | 23 ++++---- 21 files changed, 151 insertions(+), 164 deletions(-) diff --git a/example/ck_tile/38_block_scale_gemm/gemm_abquant_quantgrouped.cpp b/example/ck_tile/38_block_scale_gemm/gemm_abquant_quantgrouped.cpp index 619fd11b914..c887dae7d01 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_abquant_quantgrouped.cpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_abquant_quantgrouped.cpp @@ -165,34 +165,4 @@ void abquant_quantgrouped_instance_factory( BQuantGroupSize, ck_tile::QuantType::ABQuantGrouped>(arg_parser); }; - lut[hash_multiple_strings({"bf8", - "abquant", - "non-preshuffleb", - "preshufflequant", - "1x1x128"})] = [](const ck_tile::ArgParser& arg_parser) { - using AQuantGroupSize = ck_tile::QuantGroupShape>; - using BQuantGroupSize = ck_tile::QuantGroupShape>; - using TypeConfig = - decltype(GemmQuantTypeConfig{}); - return run_gemm_example_prec_type, - TypeConfig, - AQuantGroupSize, - BQuantGroupSize, - ck_tile::QuantType::ABQuantGrouped>(arg_parser); - }; - lut[hash_multiple_strings({"bf8", - "abquant", - "non-preshuffleb", - "preshufflequant", - "1x128x128"})] = [](const ck_tile::ArgParser& arg_parser) { - using AQuantGroupSize = ck_tile::QuantGroupShape>; - using BQuantGroupSize = ck_tile::QuantGroupShape>; - using TypeConfig = - decltype(GemmQuantTypeConfig{}); - return run_gemm_example_prec_type, - TypeConfig, - AQuantGroupSize, - BQuantGroupSize, - ck_tile::QuantType::ABQuantGrouped>(arg_parser); - }; } diff --git a/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp b/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp index a95ca4862cf..72720244457 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp @@ -72,7 +72,8 @@ struct GemmConfigBase static constexpr ck_tile::index_t TileParitionerGroupNum = 8; static constexpr ck_tile::index_t TileParitionerM01 = 4; - static constexpr bool PreshuffleQuant = false; + static constexpr bool APreshuffleQuant = false; + static constexpr bool BPreshuffleQuant = false; static constexpr bool PreshuffleB = false; static constexpr bool DoubleSmemBuffer = false; static constexpr bool TiledMMAPermuteN = false; @@ -128,7 +129,8 @@ struct GemmConfigPreshuffleQuantDecode : public GemmConfigBase static constexpr ck_tile::index_t K_Warp_Tile = ck_tile::get_k_warp_tile(); - static constexpr bool PreshuffleQuant = true; + static constexpr bool APreshuffleQuant = true; + static constexpr bool BPreshuffleQuant = true; }; template @@ -158,7 +160,7 @@ template struct GemmConfigPreshuffleB_PreshuffleBQuant_Decode : public GemmConfigPreshuffleB_BQuant_Decode { - static constexpr bool PreshuffleQuant = true; + static constexpr bool BPreshuffleQuant = true; }; template @@ -189,7 +191,7 @@ template struct GemmConfigPreshuffleB_PreshuffleBQuant_Prefill : public GemmConfigPreshuffleB_BQuant_Prefill { - static constexpr bool PreshuffleQuant = true; + static constexpr bool BPreshuffleQuant = true; }; template @@ -241,7 +243,7 @@ struct GemmConfigABQuantPrefill : public GemmConfigQuantPrefill template struct GemmConfigPreshuffleBQuantPrefill : public GemmConfigQuantPrefill { - static constexpr bool PreshuffleQuant = true; + static constexpr bool BPreshuffleQuant = true; }; template diff --git a/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc b/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc index 6875784970d..9d7b81f943d 100644 --- a/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc +++ b/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc @@ -50,7 +50,8 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str using GemmTraits = ck_tile::TileGemmQuantTraits, std::conditional_t< - QuantMode == ck_tile::QuantType::AQuantGrouped && GemmConfig::PreshuffleQuant == true, + QuantMode == ck_tile::QuantType::AQuantGrouped && GemmConfig::APreshuffleQuant == true, ck_tile::BaseGemmPipelineAgBgCrCompV3, std::conditional_t< QuantMode == ck_tile::QuantType::AQuantGrouped, @@ -147,7 +148,7 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str has_hot_loop_v, tail_number_v>>>>; using AQuantPipeline = - std::conditional_t, ck_tile::AQuantGemmPipelineAgBgCrMem>; @@ -391,8 +392,10 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf, std::cout << " Acc_Type = " << ck_tile::DataTypeTraits::name << " C_Type = " << ck_tile::DataTypeTraits::name << " QuantMode = " << quant_type_to_string(QuantMode) - << " PreshuffleQuant = " << (GemmConfig::PreshuffleQuant ? "true" : "false") << " : " - << " PreshuffleB = " << (GemmConfig::PreshuffleB ? "true" : "false") << " : " + << " APreshuffleQuant = " << (GemmConfig::APreshuffleQuant ? "true" : "false") + << " : " + << " BPreshuffleQuant = " << (GemmConfig::BPreshuffleQuant ? "true" : "false") + << " : " << " PreshuffleB = " << (GemmConfig::PreshuffleB ? "true" : "false") << " : " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " << std::endl; @@ -686,7 +689,7 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser, QuantMode == ck_tile::QuantType::RowColQuant || QuantMode == ck_tile::QuantType::TensorQuant) { - if constexpr(GemmConfig::PreshuffleQuant) + if constexpr(GemmConfig::APreshuffleQuant) { ck_tile::HostTensor aq_shuffle_host = ck_tile::shuffle_aq(aq_tensor_ptr.get(), GemmConfig::K_Tile / AQuantGroupSize::kK); @@ -745,7 +748,7 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser, ck_tile::HostTensor bq_permuted_host = ck_tile::bq_permuteN(*bq_tensor_ptr, BQuantGroupSize::kN); - if constexpr(GemmConfig::PreshuffleQuant) + if constexpr(GemmConfig::BPreshuffleQuant) { ck_tile::HostTensor bq_shuffle_host = ck_tile::shuffle_bq( &bq_permuted_host, GemmConfig::K_Tile / BQuantGroupSize::kK); @@ -756,7 +759,7 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser, bq_dev_buf_ptr->ToDevice(bq_permuted_host.data()); } } - else if constexpr(GemmConfig::PreshuffleQuant) + else if constexpr(GemmConfig::BPreshuffleQuant) { ck_tile::HostTensor bq_shuffle_host = ck_tile::shuffle_bq(bq_tensor_ptr.get(), GemmConfig::K_Tile / BQuantGroupSize::kK); @@ -937,7 +940,7 @@ int run_gemm_example_prec_type(const ck_tile::ArgParser& arg_parser) if constexpr((QuantMode == ck_tile::QuantType::AQuantGrouped || QuantMode == ck_tile::QuantType::ABQuantGrouped) && - !GemmConfig::PreshuffleQuant && !GemmConfig::PreshuffleB) + !GemmConfig::APreshuffleQuant && !GemmConfig::PreshuffleB) { if(a_layout == "R" && b_layout == "R") { @@ -958,7 +961,8 @@ int run_gemm_example_prec_type(const ck_tile::ArgParser& arg_parser) arg_parser, Col{}, Row{}, Row{}, Col{}, Row{}); } } - if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped && !GemmConfig::PreshuffleQuant) + if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped && + !GemmConfig::APreshuffleQuant) { if(a_layout == "C" && b_layout == "C") { diff --git a/include/ck_tile/ops/gemm_quant/block/block_gemm_quant_common.hpp b/include/ck_tile/ops/gemm_quant/block/block_gemm_quant_common.hpp index fb4a701eac8..fcf1261754b 100644 --- a/include/ck_tile/ops/gemm_quant/block/block_gemm_quant_common.hpp +++ b/include/ck_tile/ops/gemm_quant/block/block_gemm_quant_common.hpp @@ -96,9 +96,9 @@ struct AQPickerCommon : public BlockGemmQuantBase if constexpr(Traits::TransposeC) // transposed C { index_t reg_offset = - Traits::PreshuffleQuant ? mIter : mIter * Traits::AQPerBlock + kQScale; + Traits::APreshuffleQuant ? mIter : mIter * Traits::AQPerBlock + kQScale; auto scale_reg = aq_block_tensor.get_thread_buffer()[reg_offset]; - if constexpr(Traits::PreshuffleQuant) + if constexpr(Traits::APreshuffleQuant) { auto pull_from_lane = (__lane_id() & (Traits::WarpGemm::kN - 1)) * Traits::AQPerBlock + kQScale; @@ -121,7 +121,7 @@ struct AQPickerCommon : public BlockGemmQuantBase } else { - if constexpr(Traits::PreshuffleQuant) + if constexpr(Traits::APreshuffleQuant) { // A view is created on top of the preshuffled AQ, where each row of // the view is composed of a row from a warp tile within an AQ block diff --git a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_aquant_flatbr_bquant_cr.hpp b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_aquant_flatbr_bquant_cr.hpp index ddaf28a7c6d..abb9de6cf36 100644 --- a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_aquant_flatbr_bquant_cr.hpp +++ b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_aquant_flatbr_bquant_cr.hpp @@ -69,7 +69,8 @@ struct BlockGemmWeightPreshuffleABQuantARegBRegCReg static constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WarpGemm::kN); static constexpr index_t KIterPerWarp = KPerBlock / WarpGemm::kK; - static constexpr bool PreshuffleQuant = Problem::Traits::PreshuffleQuant; + static constexpr bool APreshuffleQuant = Problem::Traits::APreshuffleQuant; + static constexpr bool BPreshuffleQuant = Problem::Traits::BPreshuffleQuant; static constexpr index_t QScalesPerBlockRow = integer_divide_ceil(KPerBlock, BQuantGroupSize::kK); diff --git a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_flatbr_bquant_cr.hpp b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_flatbr_bquant_cr.hpp index 2f8a1b5cf43..d2cfaca7b72 100644 --- a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_flatbr_bquant_cr.hpp +++ b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_flatbr_bquant_cr.hpp @@ -53,7 +53,7 @@ struct BlockGemmWeightPreshuffleBQuantARegBRegCReg static constexpr index_t kBlockSize = Problem::kBlockSize; - static constexpr bool PreshuffleQuant = Problem::Traits::PreshuffleQuant; + static constexpr bool BPreshuffleQuant = Problem::Traits::BPreshuffleQuant; static constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM); static constexpr index_t NIterPerWarp = @@ -173,7 +173,7 @@ struct BlockGemmWeightPreshuffleBQuantARegBRegCReg c_warp_y_index_zeros)) / CBlockTensor::PackedSize>{}; - if constexpr(PreshuffleQuant) + if constexpr(BPreshuffleQuant) { constexpr index_t reg_offset = nIter; auto pull_from_lane = (__lane_id() & (WG::kN - 1)) * KPerBlockBQ + kQScale; diff --git a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_bquant_cr.hpp b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_bquant_cr.hpp index b6749021c4b..53793bea855 100644 --- a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_bquant_cr.hpp +++ b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_bquant_cr.hpp @@ -76,7 +76,8 @@ struct ABQuantBlockUniversalGemmAsBsCr : public BlockGemmQuantBase static constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WarpGemm::kN); static constexpr index_t KIterPerWarp = KPerBlock / WarpGemm::kK; - static constexpr bool PreshuffleQuant = Problem::Traits::PreshuffleQuant; + static constexpr bool APreshuffleQuant = Problem::Traits::APreshuffleQuant; + static constexpr bool BPreshuffleQuant = Problem::Traits::BPreshuffleQuant; static constexpr index_t QScalesPerBlockRow = integer_divide_ceil(KPerBlock, BQuantGroupSize::kK); @@ -161,7 +162,8 @@ struct ABQuantBlockUniversalGemmAsBsCr : public BlockGemmQuantBase using BWarpTensor = typename WarpGemm::BWarpTensor; using CWarpTensor = typename WarpGemm::CWarpTensor; - static constexpr bool PreshuffleQuant = Traits::PreshuffleQuant; + static constexpr bool APreshuffleQuant = Traits::APreshuffleQuant; + static constexpr bool BPreshuffleQuant = Traits::BPreshuffleQuant; static_assert(std::is_same_v); @@ -359,18 +361,14 @@ struct ABQuantBlockUniversalGemmAsBsCr : public BlockGemmQuantBase AQPickerCommon aq_picker( aq_block_tensor); - if constexpr(PreshuffleQuant) + if constexpr(BPreshuffleQuant) { constexpr index_t reg_offset = [&]() { if constexpr(GemmTraits::BQuantGroupSize::kN > - (NWarp * WarpGemm::kN)) + (NWarp * WarpGemm::kN) && + Traits::NPerBlock == GemmTraits::BQuantGroupSize::kN) { - if constexpr(Traits::NPerBlock == - GemmTraits::BQuantGroupSize::kN) - return kQScale; - else - return nIter; // for prefill needs kQscale, for decode needs - // nIter + return kQScale; } else { diff --git a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp index 8aebc0790e9..0d6b91953d4 100644 --- a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp +++ b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp @@ -110,8 +110,8 @@ struct AQuantBlockUniversalGemmAsBsCr static constexpr index_t KPack = WarpGemm::kKPerThread; static constexpr index_t KPerThread = KIterPerWarp * WarpGemm::kKPerThread; - static constexpr bool PreshuffleQuant = Problem::Traits::PreshuffleQuant; - static constexpr bool TransposeC = Problem::TransposeC; + static constexpr bool APreshuffleQuant = Problem::Traits::APreshuffleQuant; + static constexpr bool TransposeC = Problem::TransposeC; }; public: diff --git a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp index e2083cfd2a1..d2d25416e48 100644 --- a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp +++ b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp @@ -72,7 +72,7 @@ struct BQuantBlockUniversalGemmAsBsCr static constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WarpGemm::kN); static constexpr index_t KIterPerWarp = KPerBlock / WarpGemm::kK; - static constexpr bool PreshuffleQuant = Problem::Traits::PreshuffleQuant; + static constexpr bool BPreshuffleQuant = Problem::Traits::BPreshuffleQuant; static constexpr index_t QScalesPerBlockRow = integer_divide_ceil(KPerBlock, BQuantGroupSize::kK); @@ -153,7 +153,7 @@ struct BQuantBlockUniversalGemmAsBsCr using BWarpTensor = typename WarpGemm::BWarpTensor; using CWarpTensor = typename WarpGemm::CWarpTensor; - static constexpr bool PreshuffleQuant = Traits::PreshuffleQuant; + static constexpr bool BPreshuffleQuant = Traits::BPreshuffleQuant; static_assert(std::is_same_v); @@ -317,19 +317,15 @@ struct BQuantBlockUniversalGemmAsBsCr c_warp_y_index_zeros)) / CBlockTensor::PackedSize>{}; - if constexpr(PreshuffleQuant) + if constexpr(BPreshuffleQuant) { // constexpr index_t reg_offset = nIter; constexpr index_t reg_offset = [&]() { if constexpr(GemmTraits::BQuantGroupSize::kN > - (NWarp * WarpGemm::kN)) + (NWarp * WarpGemm::kN) && + Traits::NPerBlock == GemmTraits::BQuantGroupSize::kN) { - if constexpr(Traits::NPerBlock == - GemmTraits::BQuantGroupSize::kN) - return kQScale; - else - return nIter; // for prefill needs kQscale, for decode needs - // nIter + return kQScale; } else { diff --git a/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp b/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp index 303b4d48006..6ca5f3050de 100644 --- a/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp +++ b/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp @@ -67,15 +67,27 @@ struct get_bq_data_type_or> }; template -struct is_quantpreshuffle_enabled +struct is_Aquantpreshuffle_enabled { static constexpr bool value = false; }; template -struct is_quantpreshuffle_enabled> +struct is_Aquantpreshuffle_enabled> { - static constexpr bool value = T::PreshuffleQuant; + static constexpr bool value = T::APreshuffleQuant; +}; + +template +struct is_Bquantpreshuffle_enabled +{ + static constexpr bool value = false; +}; + +template +struct is_Bquantpreshuffle_enabled> +{ + static constexpr bool value = T::BPreshuffleQuant; }; template @@ -206,8 +218,10 @@ struct QuantGemmKernel typename detail::get_bq_layout_or::type>; static constexpr index_t kBlockSize = GemmPipeline::BlockSize; - static constexpr bool PreshuffleQuant = - detail::is_quantpreshuffle_enabled::value; + static constexpr bool APreshuffleQuant = + detail::is_Aquantpreshuffle_enabled::value; + static constexpr bool BPreshuffleQuant = + detail::is_Bquantpreshuffle_enabled::value; static constexpr bool PreshuffleB = detail::is_preshuffleB_enabled::value; using ADataType = remove_cvref_t; @@ -476,9 +490,7 @@ struct QuantGemmKernel { // Step 1: Create tensor view for AQ const auto& aq_tensor_view = [&]() { - if constexpr((kQuantType == QuantType::AQuantGrouped || - kQuantType == QuantType::ABQuantGrouped) && - PreshuffleQuant) + if constexpr(kQuantType == QuantType::AQuantGrouped && APreshuffleQuant) { static_assert(std::is_same_v); const auto aq_x = kargs.M * GemmPipeline::KPerBlockAQ; @@ -533,9 +545,9 @@ struct QuantGemmKernel return make_tensor_view(aq_ptr, aq_merge_pad1_desc); } - else if constexpr((kQuantType == QuantType::AQuantGrouped || - kQuantType == QuantType::ABQuantGrouped) && - !PreshuffleQuant) + else if constexpr((kQuantType == QuantType::ABQuantGrouped || + kQuantType == QuantType::AQuantGrouped) && + !APreshuffleQuant) { if constexpr(std::is_same_v) { @@ -573,9 +585,7 @@ struct QuantGemmKernel // Step 2: Create tile window (no padding for AQ) const auto& aq_block_window = [&]() { - if constexpr((kQuantType == QuantType::AQuantGrouped || - kQuantType == QuantType::ABQuantGrouped) && - PreshuffleQuant) + if constexpr(kQuantType == QuantType::AQuantGrouped && APreshuffleQuant) { static_assert(std::is_same_v); using AQuantGroupSize = remove_cvref_t; @@ -591,11 +601,10 @@ struct QuantGemmKernel make_tuple(number{}, number{}), {block_m_idx * tile_window_height, 0}); } - else if constexpr((kQuantType == QuantType::AQuantGrouped || - kQuantType == QuantType::ABQuantGrouped) && - !PreshuffleQuant) + else if constexpr((kQuantType == QuantType::ABQuantGrouped || + kQuantType == QuantType::AQuantGrouped) && + !APreshuffleQuant) { - using AQuantGroupSize = remove_cvref_t; constexpr auto aqk_per_block = TilePartitioner::KPerBlock / AQuantGroupSize::kK; constexpr auto block_m = TilePartitioner::MPerBlock; @@ -812,10 +821,10 @@ struct QuantGemmKernel else if constexpr(kQuantType == QuantType::BQuantGrouped || kQuantType == QuantType::ABQuantGrouped) { - if constexpr(PreshuffleQuant) + if constexpr(BPreshuffleQuant) { static_assert(std::is_same_v, - "PreshuffleQuant with BQuantGrouped currently only supports " + "BPreshuffleQuant with BQuantGrouped currently only supports " "ColumnMajor BQ layout"); using BQuantGroupSize = remove_cvref_t; @@ -881,7 +890,7 @@ struct QuantGemmKernel kQuantType == QuantType::ABQuantGrouped) { using BQuantGroupSize = remove_cvref_t; - if constexpr(PreshuffleQuant) + if constexpr(BPreshuffleQuant) { static_assert(std::is_same_v); @@ -1215,7 +1224,7 @@ struct QuantGemmKernel if constexpr(kQuantType == QuantType::AQuantGrouped) { index_t m = 0; - if constexpr(PreshuffleQuant) + if constexpr(APreshuffleQuant) { m = kargs.M; } @@ -1225,7 +1234,7 @@ struct QuantGemmKernel else if constexpr(kQuantType == QuantType::BQuantGrouped) { index_t n = 0; - if constexpr(PreshuffleQuant) + if constexpr(BPreshuffleQuant) { n = kargs.N; } @@ -1236,9 +1245,9 @@ struct QuantGemmKernel { index_t m = 0; index_t n = 0; - if constexpr(PreshuffleQuant) + if constexpr(BPreshuffleQuant) { - m = kargs.M; + // m = kargs.M; n = kargs.N; } return GemmPipeline{}.template operator()(a_block_window, diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_abquant_pipeline_ag_bg_cr_v3.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_abquant_pipeline_ag_bg_cr_v3.hpp index ddc1a260efa..5902dd0c4f7 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_abquant_pipeline_ag_bg_cr_v3.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_abquant_pipeline_ag_bg_cr_v3.hpp @@ -98,7 +98,8 @@ struct ABQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3{})) ? ck_tile::integer_divide_ceil(n, BQuantGroupSize::kN) : ck_tile::integer_least_multiple(n, NPerBlock) / diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_mem.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_mem.hpp index 04cf68b8f72..177f0c5e546 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_mem.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_mem.hpp @@ -78,7 +78,7 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem static constexpr bool kPadK = Problem::kPadK; static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer; - static constexpr bool PreshuffleQuant = Problem::Traits::PreshuffleQuant; + static constexpr bool APreshuffleQuant = Problem::Traits::APreshuffleQuant; static constexpr bool HasHotLoop = Problem::HasHotLoop; static constexpr auto TailNum = Problem::TailNum; @@ -216,7 +216,7 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem std::is_same_v; constexpr bool is_b_row_major = std::is_same_v; - static_assert(!PreshuffleQuant, "Memory pipeline does not support PreshuffleQuant!"); + static_assert(!APreshuffleQuant, "Memory pipeline does not support APreshuffleQuant!"); static_assert(is_a_col_major ? (KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] && diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_policy.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_policy.hpp index 4485144f5f3..27828cce634 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_policy.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_policy.hpp @@ -32,22 +32,22 @@ struct GemmAQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC using AQLayout = remove_cvref_t; using BlockGemmShape = typename Problem::BlockGemmShape; - constexpr index_t BlockSize = Problem::kBlockSize; - constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; - constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; - constexpr index_t KPerBlockAQ = KPerBlock / Problem::AQuantGroupSize::kK; - constexpr index_t VecLoadSize = GetVectorSizeAQ(); - constexpr bool PreshuffleQuant = Problem::Traits::PreshuffleQuant; - using WarpTile = typename Problem::BlockGemmShape::WarpTile; - using WarpGemm = WarpGemmDispatcher; + constexpr index_t BlockSize = Problem::kBlockSize; + constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + constexpr index_t KPerBlockAQ = KPerBlock / Problem::AQuantGroupSize::kK; + constexpr index_t VecLoadSize = GetVectorSizeAQ(); + constexpr bool APreshuffleQuant = Problem::Traits::APreshuffleQuant; + using WarpTile = typename Problem::BlockGemmShape::WarpTile; + using WarpGemm = WarpGemmDispatcher; - if constexpr(PreshuffleQuant) + if constexpr(APreshuffleQuant) { using TileEncodingPattern = tile_distribution_encoding_pattern_aq< BlockGemmShape, @@ -57,7 +57,7 @@ struct GemmAQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC ck_tile::integer_least_multiple(WarpGemm::kM * KPerBlockAQ, get_warp_size()), KPerBlockAQ, VecLoadSize, - PreshuffleQuant>; + APreshuffleQuant>; return TileEncodingPattern::make_2d_static_tile_distribution(); } @@ -89,7 +89,7 @@ struct GemmAQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC KPerBlockAQ, KPerBlockAQ, VecLoadSize, - PreshuffleQuant>; + APreshuffleQuant>; return TileEncodingPattern::make_2d_static_tile_distribution(); } @@ -103,7 +103,7 @@ struct GemmAQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC MPerBlock, // XPerTile KPerBlockAQ, VecLoadSize, - PreshuffleQuant>; + APreshuffleQuant>; return TileEncodingPattern::make_2d_static_tile_distribution_transposed(); } } diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_v3.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_v3.hpp index 0bc455b7f83..76d8985fb15 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_v3.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_v3.hpp @@ -75,7 +75,7 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3; using BlockGemmShape = typename Problem::BlockGemmShape; - constexpr index_t BlockSize = Problem::kBlockSize; - constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; - constexpr index_t NPerBlockBQ = (Problem::BQuantGroupSize::kN <= NPerBlock) - ? NPerBlock / Problem::BQuantGroupSize::kN - : 1; - constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; - constexpr index_t KPerBlockBQ = KPerBlock / Problem::BQuantGroupSize::kK; - constexpr bool PreshuffleQuant = Problem::Traits::PreshuffleQuant; + constexpr index_t BlockSize = Problem::kBlockSize; + constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; + constexpr index_t NPerBlockBQ = (Problem::BQuantGroupSize::kN <= NPerBlock) + ? NPerBlock / Problem::BQuantGroupSize::kN + : 1; + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + constexpr index_t KPerBlockBQ = KPerBlock / Problem::BQuantGroupSize::kK; + constexpr bool BPreshuffleQuant = Problem::Traits::BPreshuffleQuant; using WarpTile = typename Problem::BlockGemmShape::WarpTile; using WarpGemm = WarpGemmDispatcher; - if constexpr(PreshuffleQuant) + if constexpr(BPreshuffleQuant) { using TileEncodingPattern = tile_distribution_encoding_pattern_bq< BlockGemmShape, @@ -72,7 +72,7 @@ struct GemmBQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC Problem::BQuantGroupSize::kN, Problem::BQuantGroupSize::kK, BQLayout, - PreshuffleQuant>; + BPreshuffleQuant>; return TileEncodingPattern::make_2d_static_tile_distribution(); } else diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp index 7e5dc1cbc4f..df94eb72731 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp @@ -88,7 +88,7 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3{})) ? ck_tile::integer_divide_ceil(n, BQuantGroupSize::kN) : ck_tile::integer_least_multiple(n, NPerBlock) / diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_group_quant_utils.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_group_quant_utils.hpp index bde0be89c02..48c27945b3f 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_group_quant_utils.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_group_quant_utils.hpp @@ -52,7 +52,7 @@ template + bool APreshuffleQuant> struct tile_distribution_encoding_pattern_aq : public tile_distribution_encoding_pattern { static_assert(XPerTile % VecSize == 0, "XPerTile must be a multiple of VecSize!"); @@ -72,7 +72,7 @@ struct tile_distribution_encoding_pattern_aq : public tile_distribution_encoding CK_TILE_HOST_DEVICE static constexpr auto make_2d_static_tile_distribution() { - if constexpr(PreshuffleQuant) + if constexpr(APreshuffleQuant) { // # of elements per thread static_assert(XPerTile >= warp_size && XPerTile % warp_size == 0); @@ -193,8 +193,8 @@ template + typename BQLayout = tensor_layout::gemm::ColumnMajor, + bool BPreshuffleQuant = false> struct tile_distribution_encoding_pattern_bq : public tile_distribution_encoding_pattern { static constexpr index_t warp_size = get_warp_size(); @@ -212,10 +212,11 @@ struct tile_distribution_encoding_pattern_bq : public tile_distribution_encoding CK_TILE_HOST_DEVICE static constexpr auto make_2d_static_tile_distribution() { // Preshuffle only supported for ColumnMajor currently - static_assert(!(PreshuffleQuant && std::is_same_v), - "PreshuffleQuant only supported for ColumnMajor BQLayout"); + static_assert( + !(BPreshuffleQuant && std::is_same_v), + "PreshuffleQuant only supported for ColumnMajor BQLayout"); - if constexpr(PreshuffleQuant) + if constexpr(BPreshuffleQuant) { // ============================================================================= // PRE-SHUFFLED BQ SCALE TILE DISTRIBUTION diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp index 827f8864d83..a49279585e8 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp @@ -69,7 +69,7 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV using Base::m_preload; - static constexpr bool PreshuffleQuant = Problem::Traits::PreshuffleQuant; + static constexpr bool BPreshuffleQuant = Problem::Traits::BPreshuffleQuant; static constexpr index_t VectorLoadSize = Problem::VectorLoadSize; static constexpr index_t NPerBlockBQ = integer_divide_ceil(BlockGemmShape::kN, BQuantGroupSize::kN); @@ -360,7 +360,7 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV BQBlockTile bq_block_tile, bq_block_tile_2; bq_block_tile = load_tile(bq_copy_dram_window); // move BQ to tile 1 - if constexpr(PreshuffleQuant) + if constexpr(BPreshuffleQuant) { move_tile_window(bq_copy_dram_window, {((NPerBlockBQ <= BlockGemmShape::BlockWarps::at(number<1>{})) @@ -437,7 +437,7 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock}); bq_block_tile_2 = load_tile(bq_copy_dram_window); - if constexpr(PreshuffleQuant) + if constexpr(BPreshuffleQuant) { move_tile_window(bq_copy_dram_window, {((NPerBlockBQ <= BlockGemmShape::BlockWarps::at(number<1>{})) @@ -474,7 +474,7 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock}); bq_block_tile = load_tile(bq_copy_dram_window); - if constexpr(PreshuffleQuant) + if constexpr(BPreshuffleQuant) { move_tile_window(bq_copy_dram_window, {((NPerBlockBQ <= BlockGemmShape::BlockWarps::at(number<1>{})) diff --git a/include/ck_tile/ops/gemm_quant/pipeline/tile_gemm_quant_traits.hpp b/include/ck_tile/ops/gemm_quant/pipeline/tile_gemm_quant_traits.hpp index b956caa14ff..5db09a0c46d 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/tile_gemm_quant_traits.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/tile_gemm_quant_traits.hpp @@ -33,7 +33,8 @@ inline std::string quant_type_to_string(QuantType quant_type) template @@ -212,7 +213,7 @@ class TestCkTileGemmAQuant : public TestCkTileGemmQuantBase aq_shuffle_host = ck_tile::shuffle_aq(&aq_m_aqk, Base::GemmConfig::K_Tile / QuantGroupSize::kK); @@ -478,7 +479,7 @@ class TestCkTileGemmBQuant : public TestCkTileGemmQuantBase(bq_bqk_bqn, QuantGroupSize::kN); bq_bqk_bqn_dev_buf.ToDevice(bq_shuffle_host.data()); } - else if constexpr(GemmConfig::PreshuffleQuant) + else if constexpr(GemmConfig::BPreshuffleQuant) { ck_tile::HostTensor bq_shuffle_host = ck_tile::shuffle_bq(&bq_bqk_bqn, GemmConfig::K_Tile / QuantGroupSize::kK); @@ -775,7 +776,7 @@ class TestCkTileGemmABQuant : public TestCkTileGemmQuantBase aq_shuffle_host = ck_tile::shuffle_aq(&aq_m_aqk, Base::GemmConfig::K_Tile / AQuantGroupSize::kK); @@ -792,7 +793,7 @@ class TestCkTileGemmABQuant : public TestCkTileGemmQuantBase(bq_bqk_bqn, BQuantGroupSize::kN); bq_bqk_bqn_dev_buf.ToDevice(bq_shuffle_host.data()); } - else if constexpr(GemmConfig::PreshuffleQuant) + else if constexpr(GemmConfig::BPreshuffleQuant) { ck_tile::HostTensor bq_shuffle_host = ck_tile::shuffle_bq(&bq_bqk_bqn, GemmConfig::K_Tile / BQuantGroupSize::kK);