From b06f6f582c787b25775d7e59329cd591166b2f9c Mon Sep 17 00:00:00 2001 From: AviralGoelAMD Date: Sat, 10 Jan 2026 18:09:26 +0000 Subject: [PATCH 1/5] refactor: remove Default scheduler implementation as it is not used anymore --- .../block/block_universal_gemm_as_bs_cr.hpp | 77 ------------------- 1 file changed, 77 deletions(-) diff --git a/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp b/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp index f6e26ad206d..2a05fbf53e4 100644 --- a/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp +++ b/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp @@ -194,83 +194,6 @@ struct BlockUniversalGemmAsBsCr { }; - template - struct BlockGemmImpl - { - static constexpr auto ALdsTileDistr = - decltype(make_static_tile_distribution(MakeABlockDistributionEncode())){}; - static constexpr auto BLdsTileDistr = - decltype(make_static_tile_distribution(MakeBBlockDistributionEncode())){}; - - using ALdsTile = decltype(make_static_distributed_tensor(ALdsTileDistr)); - using BLdsTile = decltype(make_static_distributed_tensor(BLdsTileDistr)); - - ALdsTile a_warp_tile_; - BLdsTile b_warp_tile_; - - // C += A * B - template - CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor, - const ASmemBlockWindow& a_block_window, - const BSmemBlockWindow& b_block_window, - bool_constant = {}, - bool_constant = {}) - { - static_assert(std::is_same_v, - "The CDataType as defined in traits should be the same as correspoinding " - "C block tensor data type!"); - static_assert(std::is_same_v && - std::is_same_v, - "The ADataType and BDataType as defined in " - "traits should be the same as correspoinding block window data type!"); - - load_int4_tile(a_warp_tile_, - a_block_window); - load_int4_tile(b_warp_tile_, - b_block_window); - // hot loop: - static_for<0, GemmTraits::KIterPerWarp, 1>{}([&](auto kIter) { - static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - // read A warp tensor from A block tensor - AWarpTensor a_warp_tensor; - - a_warp_tensor.get_thread_buffer() = a_warp_tile_.get_y_sliced_thread_data( - merge_sequences(sequence{}, a_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, a_warp_y_lengths)); - - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - // read B warp tensor from B block tensor - BWarpTensor b_warp_tensor; - - b_warp_tensor.get_thread_buffer() = b_warp_tile_.get_y_sliced_thread_data( - merge_sequences(sequence{}, b_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, b_warp_y_lengths)); - - // read C warp tensor from C block tensor- - CWarpTensor c_warp_tensor; - - c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); - - // warp GEMM - WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor); - - // write C warp tensor into C block tensor - c_block_tensor.set_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), - c_warp_tensor.get_thread_buffer()); - }); - }); - }); - } - }; - template struct BlockGemmImpl { From 09aa53610cbe95933649649e065952c9f6b3d1807 Mon Sep 17 00:00:00 2001 From: AviralGoelAMD Date: Sat, 10 Jan 2026 19:20:48 +0000 Subject: [PATCH 2/5] refactor: remove dead code from gemm universal kernel --- include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp | 5 +---- include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp | 2 +- 2 files changed, 2 insertions(+),
5 deletions(-) diff --git a/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp index 628f5f7dc89..9583ac8a3f1 100644 --- a/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp @@ -1035,7 +1035,6 @@ struct UniversalGemmKernel * @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup. * */ - template CK_TILE_DEVICE static void RunGemm(const std::array& as_ptr, const std::array& bs_ptr, const std::array& ds_ptr, @@ -1161,9 +1160,7 @@ struct UniversalGemmKernel // allocate LDS __shared__ char smem_ptr[GetSmemSize()]; - constexpr auto scheduler_type = - GemmPipeline::DoubleSmemBuffer || (GemmPipeline::NumWaveGroups == 1); - RunGemm( + RunGemm( as_ptr, bs_ptr, kargs.ds_ptr, e_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n); } diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp index 46c1f69b126..3597590c0f3 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp @@ -80,7 +80,7 @@ struct GemmPipelineProblemBase static constexpr bool kPadK = Traits::kPadK; static constexpr bool DoubleSmemBuffer = Traits::DoubleSmemBuffer; - static constexpr auto Scheduler = GemmPipelineScheduler::Default; + static constexpr auto Scheduler = GemmPipelineScheduler::Intrawave; static constexpr index_t VectorLoadSize = Traits::_VectorSize; // In the base situation, the Preshuffle setting should be false. From 958c99e53883f813044f54cb8f4ea193344df5ba Mon Sep 17 00:00:00 2001 From: AviralGoelAMD Date: Sat, 10 Jan 2026 21:47:54 +0000 Subject: [PATCH 3/5] chore: add descriptive comments about amd intrinsic hardware sync instructions --- .../gemm/block/block_universal_gemm_as_bs_cr.hpp | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp b/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp index 2a05fbf53e4..79030fcd513 100644 --- a/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp +++ b/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp @@ -373,7 +373,9 @@ struct BlockUniversalGemmAsBsCr // hot loop: static_for<0, KRepeat, 1>{}([&](auto kIter) { LocalPrefetch(a_block_window, b_block_window, a_load_tr, b_load_tr); - __builtin_amdgcn_sched_barrier(0); + __builtin_amdgcn_sched_barrier( + 0); // Complete scheduling all pending instruction groups before this point + // NOTE: Synchronize threads in a workgroup at the start of each MAC // cluster, but except the first, as we can shorten non-MAC cluster a bit // and there's no observable negative impact. The desired effect is waves in @@ -383,8 +385,14 @@ struct BlockUniversalGemmAsBsCr // sync point. 
if constexpr(kIter.value != 0 || KRepeat == 1) { - __builtin_amdgcn_s_barrier(); - __builtin_amdgcn_sched_barrier(0); + // This pattern ensures: + // At runtime: All waves synchronize (hardware barrier) + // At compile-time: Instructions after the barrier don't get moved before it + // (scheduling barrier) + __builtin_amdgcn_s_barrier(); // Blocks execution until all waves (threads) in + // the workgroup reach this point + __builtin_amdgcn_sched_barrier( + 0); // Prevents instruction reordering across this boundary } static_for<0, KInnerLoopIter, 1>{}([&](auto kInnerIter) { From 2240aa51c27d7d698380dcfe8b818225ee21f474 Mon Sep 17 00:00:00 2001 From: AviralGoelAMD Date: Fri, 9 Jan 2026 15:53:36 +0000 Subject: [PATCH 4/5] fix: label existing memory pipeline for aquant as intrawave --- .../gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_mem.hpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_mem.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_mem.hpp index 91dfc8494ad..2f6497fdbae 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_mem.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_mem.hpp @@ -164,7 +164,7 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem }; template <> - struct PipelineImpl : public PipelineImplBase + struct PipelineImpl : public PipelineImplBase { using Base = PipelineImplBase; @@ -491,7 +491,7 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem void* p_smem, index_t m = 0) const { - return PipelineImpl{} + return PipelineImpl{} .template operator()( a_dram_block_window_tmp, [](const BDataType& a) { return a; }, From 0f2c9134d8991b5fd079df367f876c1f64d6e36c Mon Sep 17 00:00:00 2001 From: AviralGoelAMD Date: Mon, 12 Jan 2026 16:07:28 +0000 Subject: [PATCH 5/5] refactor: unify interwave and intrawave pipeline implementation at work-group level --- .../block/block_universal_gemm_as_bs_cr.hpp | 36 +- .../gemm_pipeline_ag_bg_cr_comp_v3.hpp | 19 +- .../pipeline/gemm_pipeline_ag_bg_cr_mem.hpp | 324 +----------------- 3 files changed, 40 insertions(+), 339 deletions(-) diff --git a/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp b/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp index 79030fcd513..97df6163a1a 100644 --- a/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp +++ b/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp @@ -224,21 +224,28 @@ struct BlockUniversalGemmAsBsCr } // C += A * B - template CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor, - const ASmemBlockWindow&, - const BSmemBlockWindow&, - bool_constant = {}, - bool_constant = {}) + const ASmemBlockWindow& a_block_window, + const BSmemBlockWindow& b_block_window, + bool_constant a_load_tr = {}, + bool_constant b_load_tr = {}) { static_assert(std::is_same_v, "The CDataType as defined in traits should be the same as correspoinding " "C block tensor data type!"); + // LocalPrefetch can be controlled by pipeline via DoLocalPrefetch parameter + if constexpr(DoLocalPrefetch) + { + LocalPrefetch(a_block_window, b_block_window, a_load_tr, b_load_tr); + } + // hot loop: static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { @@ -275,6 +282,7 @@ struct BlockUniversalGemmAsBsCr }); }); }); + block_sync_lds(); } }; @@ -355,7 +363,8 @@ struct BlockUniversalGemmAsBsCr } // C += A * 
B - template {}([&](auto kIter) { - LocalPrefetch(a_block_window, b_block_window, a_load_tr, b_load_tr); + // Note: Interwave scheduler requires LocalPrefetch as part of its design + // The DoLocalPrefetch flag is provided for API consistency but should typically be + // true + if constexpr(DoLocalPrefetch) + { + LocalPrefetch( + a_block_window, b_block_window, a_load_tr, b_load_tr); + } __builtin_amdgcn_sched_barrier( 0); // Complete scheduling all pending instruction groups before this point @@ -494,7 +510,8 @@ struct BlockUniversalGemmAsBsCr } // C += A * B - template a_load_tr = {}, bool_constant b_load_tr = {}) { - block_gemm_impl_(c_block_tensor, a_block_window, b_block_window, a_load_tr, b_load_tr); + block_gemm_impl_.template operator()( + c_block_tensor, a_block_window, b_block_window, a_load_tr, b_load_tr); } // C = A * B diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp index 8fae7042037..1ef405cc91b 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp @@ -569,9 +569,9 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 load_tile_with_elementwise(b_copy_dram_window, b_element_func); move_tile_window(b_copy_dram_window, b_dram_tile_window_step); - block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); - - block_sync_lds(); + block_gemm.template operator()( + c_block_tile, a_lds_gemm_window, b_lds_gemm_window); + // block_sync_lds(); block_gemm.LocalPrefetch( a_lds_gemm_window, b_lds_gemm_window, is_a_load_tr_v, is_b_load_tr_v); @@ -586,12 +586,14 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 { // Leak last MFMA block to epilogue region, cover the potential lds-shuffle // latency - block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); + block_gemm.template operator()( + c_block_tile, a_lds_gemm_window, b_lds_gemm_window); } else { - block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); - block_sync_lds(); + block_gemm.template operator()( + c_block_tile, a_lds_gemm_window, b_lds_gemm_window); + // block_sync_lds(); if constexpr(is_a_col_major && !is_a_load_tr_v()) { @@ -616,9 +618,8 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 Base::LocalPrefill(b_copy_lds_window, elementwise_Bs_res); } block_sync_lds(); - block_gemm.LocalPrefetch( - a_lds_gemm_window, b_lds_gemm_window, is_a_load_tr_v, is_b_load_tr_v); - block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); + block_gemm.template operator()( + c_block_tile, a_lds_gemm_window, b_lds_gemm_window); } // __builtin_amdgcn_sched_barrier(0); return c_block_tile; diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp index 38a22e38ac2..12a3bfe99d0 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp @@ -231,13 +231,10 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem return Policy::template GetSmemSize(); } - template + // Unified pipeline implementation for both Intrawave and Interwave schedulers + // Scheduler-specific behavior is encapsulated in BlockGemm (via BlockGemmImpl specializations) + template struct PipelineImpl : public PipelineImplBase - { - }; - - template <> - struct PipelineImpl : public PipelineImplBase { 
using Base = PipelineImplBase; @@ -407,318 +404,6 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem move_tile_window(b_copy_dram_window, b_dram_tile_window_step); }); - // main body - if constexpr(HasHotLoop) - { - index_t i = 0; - do - { - static_for<0, PrefetchStages, 1>{}([&](auto prefetch_idx) { - block_sync_lds(); - block_gemm.LocalPrefetch( - a_lds_gemm_window, b_lds_gemm_window, is_a_load_tr_v, is_b_load_tr_v); - block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); - - block_sync_lds(); - - if constexpr(is_a_col_major && !is_a_load_tr_v()) - { - auto a_shuffle_tmp = make_static_distributed_tensor( - Policy::template MakeShuffledARegTileDistribution()); - transpose_tile2d( - a_shuffle_tmp, - a_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{})); - Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp); - } - else - { - Base::LocalPrefill( - a_copy_lds_window, - a_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{})); - } - if constexpr(is_b_row_major && !is_b_load_tr_v()) - { - auto b_shuffle_tmp = make_static_distributed_tensor( - Policy::template MakeShuffledBRegTileDistribution()); - transpose_tile2d( - b_shuffle_tmp, - b_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{})); - Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp); - } - else - { - Base::LocalPrefill( - b_copy_lds_window, - b_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{})); - } - - a_block_tiles.at(number{}) = - load_tile_with_elementwise(a_copy_dram_window, a_element_func); - move_tile_window(a_copy_dram_window, a_dram_tile_window_step); - - b_block_tiles.at(number{}) = - load_tile_with_elementwise(b_copy_dram_window, b_element_func); - - move_tile_window(b_copy_dram_window, b_dram_tile_window_step); - }); - - i += PrefetchStages; - } while(i < (num_loop - PrefetchStages)); - } - - auto HotLoopTail = [&](auto tail_num) { - static_for<1, tail_num, 1>{}([&](auto prefetch_idx) { - block_sync_lds(); - - block_gemm.LocalPrefetch( - a_lds_gemm_window, b_lds_gemm_window, is_a_load_tr_v, is_b_load_tr_v); - block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); - - block_sync_lds(); - - if constexpr(is_a_col_major && !is_a_load_tr_v()) - { - auto a_shuffle_tmp = make_static_distributed_tensor( - Policy::template MakeShuffledARegTileDistribution()); - transpose_tile2d(a_shuffle_tmp, a_block_tiles.get(number{})); - Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp); - } - else - { - Base::LocalPrefill(a_copy_lds_window, - a_block_tiles.get(number{})); - } - if constexpr(is_b_row_major && !is_b_load_tr_v()) - { - auto b_shuffle_tmp = make_static_distributed_tensor( - Policy::template MakeShuffledBRegTileDistribution()); - transpose_tile2d(b_shuffle_tmp, b_block_tiles.get(number{})); - Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp); - } - else - { - Base::LocalPrefill(b_copy_lds_window, - b_block_tiles.get(number{})); - } - }); - - block_sync_lds(); - block_gemm.LocalPrefetch( - a_lds_gemm_window, b_lds_gemm_window, is_a_load_tr_v, is_b_load_tr_v); - block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); - }; - - if constexpr(TailNum == TailNumber::One) - { - block_sync_lds(); - block_gemm.LocalPrefetch( - a_lds_gemm_window, b_lds_gemm_window, is_a_load_tr_v, is_b_load_tr_v); - block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); - } - else if constexpr(TailNum == TailNumber::Two) - { - HotLoopTail(number<2>{}); - } - else if constexpr(TailNum == TailNumber::Three) - { - HotLoopTail(number<3>{}); - } - else if 
constexpr(TailNum == TailNumber::Four) - { - HotLoopTail(number<4>{}); - } - else if constexpr(TailNum == TailNumber::Five) - { - HotLoopTail(number<5>{}); - } - else if constexpr(TailNum == TailNumber::Six) - { - HotLoopTail(number<6>{}); - } - else if constexpr(TailNum == TailNumber::Seven) - { - HotLoopTail(number<7>{}); - } - else if constexpr(TailNum == TailNumber::Full) - { - HotLoopTail(number{}); - } - - return c_block_tile; - } - }; - - template <> - struct PipelineImpl : public PipelineImplBase - { - using Base = PipelineImplBase; - - template ::value && - is_detected::value, - bool>* = nullptr> - CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp, - const AElementFunction& a_element_func, - const BsDramBlockWindowTmp& b_dram_block_window_tmp, - const BElementFunction& b_element_func, - index_t num_loop, - void* p_smem) const - { - using ADramBlockWindowTmp = - remove_cvref_t{}, AsDramBlockWindowTmp>>; - using BDramBlockWindowTmp = - remove_cvref_t{}, BsDramBlockWindowTmp>>; - - static_assert( - std::is_same_v> && - std::is_same_v>, - "A/B Dram block window should have the same data type as appropriate " - "([A|B]DataType) defined in Problem definition!"); - - constexpr bool is_a_col_major = - std::is_same_v; - constexpr bool is_b_row_major = std::is_same_v; - - static_assert(is_a_col_major - ? (KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] && - MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}]) - : (MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] && - KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}]), - "A block window has incorrect lengths for defined ALayout!"); - static_assert(is_b_row_major - ? (KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] && - NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}]) - : (NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] && - KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}]), - "B block window has incorrect lengths for defined BLayout!"); - - // ------------------------------------------------------------------------------------ - // Definitions of all needed tiles - - // A/B tiles in LDS - // With c++20 could simplify to below line. 
- // Currently get error: captured structured bindings are a C++20 extension - // auto&& [a_lds_block, b_lds_block] = Base::GetABLdsTensorViews(p_smem); - auto ab_lds_blocks = Base::GetABLdsTensorViews(p_smem); - auto& a_lds_block = ab_lds_blocks.at(I0{}); - auto& b_lds_block = ab_lds_blocks.at(I1{}); - - // Tile distribution for load from lds - constexpr auto a_lds_load_tile_distr = - make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode()); - constexpr auto b_lds_load_tile_distr = - make_static_tile_distribution(BlockGemm::MakeBBlockDistributionEncode()); - - // A DRAM tile window for load - // A LDS tile window for store - // A LDS tile for block GEMM - auto a_windows = - Base::GetAWindows(a_dram_block_window_tmp, a_lds_block, a_lds_load_tile_distr); - auto& a_copy_dram_window = a_windows.at(I0{}); - auto& a_copy_lds_window = a_windows.at(I1{}); - auto& a_lds_gemm_window = a_windows.at(I2{}); - - // B DRAM tile window for load - // B LDS tile window for store - // B LDS tile for block GEMM - auto b_windows = - Base::GetBWindows(b_dram_block_window_tmp, b_lds_block, b_lds_load_tile_distr); - auto& b_copy_dram_window = b_windows.at(I0{}); - auto& b_copy_lds_window = b_windows.at(I1{}); - auto& b_lds_gemm_window = b_windows.at(I2{}); - - // Block GEMM - auto block_gemm = BlockGemm(); - auto c_block_tile = block_gemm.MakeCBlockTile(); - - using ABlockTileDistr = - decltype(a_copy_dram_window[number<0>{}].get_tile_distribution()); - using BBlockTileDistr = - decltype(b_copy_dram_window[number<0>{}].get_tile_distribution()); - - using ABlockTile = - decltype(make_static_distributed_tensor(ABlockTileDistr{})); - using BBlockTile = - decltype(make_static_distributed_tensor(BBlockTileDistr{})); - - tuple_array a_block_tiles; - tuple_array b_block_tiles; - - using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex; - using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex; - - constexpr ADramTileWindowStep a_dram_tile_window_step = - is_a_col_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock); - constexpr BDramTileWindowStep b_dram_tile_window_step = - is_b_row_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock); - // ----------------------------------------------------------------------------------------- - // Gemm pipeline start - - // prefetch - // global read 0 - - // Load tile — during value loading, an elementwise function is executed for each A0, - // A1, … AN. The values A0, A1, … AN are read by the same thread. - a_block_tiles.at(I0{}) = load_tile_with_elementwise(a_copy_dram_window, a_element_func); - - // Move each A — the enhanced function move_tile_window is executed, which takes a tuple - // as input. - move_tile_window(a_copy_dram_window, a_dram_tile_window_step); - - // Load tile — during value loading, an elementwise function is executed for each B0, - // B1, … BN. The values B0, B1, … BN are read by the same thread. - b_block_tiles.at(I0{}) = load_tile_with_elementwise(b_copy_dram_window, b_element_func); - - // Move each B — the enhanced function move_tile_window is executed, which takes a tuple - // as input. 
- move_tile_window(b_copy_dram_window, b_dram_tile_window_step); - - // initialize C - tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); - - // LDS write 0 - if constexpr(is_a_col_major && !is_a_load_tr_v()) - { - auto a_shuffle_tmp = make_static_distributed_tensor( - Policy::template MakeShuffledARegTileDistribution()); - transpose_tile2d(a_shuffle_tmp, a_block_tiles.get(I0{})); - Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp); - } - else - { - Base::LocalPrefill(a_copy_lds_window, a_block_tiles.get(I0{})); - } - if constexpr(is_b_row_major && !is_b_load_tr_v()) - { - auto b_shuffle_tmp = make_static_distributed_tensor( - Policy::template MakeShuffledBRegTileDistribution()); - transpose_tile2d(b_shuffle_tmp, b_block_tiles.get(I0{})); - Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp); - } - else - { - Base::LocalPrefill(b_copy_lds_window, b_block_tiles.get(I0{})); - } - - // Global prefetch [1, PrefetchStages] - static_for<1, PrefetchStages, 1>{}([&](auto prefetch_idx) { - a_block_tiles.at(number{}) = - load_tile_with_elementwise(a_copy_dram_window, a_element_func); - - move_tile_window(a_copy_dram_window, a_dram_tile_window_step); - - b_block_tiles.at(number{}) = - load_tile_with_elementwise(b_copy_dram_window, b_element_func); - - move_tile_window(b_copy_dram_window, b_dram_tile_window_step); - }); - // main body if constexpr(HasHotLoop) { @@ -732,7 +417,6 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem b_lds_gemm_window, is_a_load_tr_v, is_b_load_tr_v); - // no second block_sync_lds because it's interwave if constexpr(is_a_col_major && !is_a_load_tr_v()) { @@ -767,7 +451,6 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem a_block_tiles.at(number{}) = load_tile_with_elementwise(a_copy_dram_window, a_element_func); - move_tile_window(a_copy_dram_window, a_dram_tile_window_step); b_block_tiles.at(number{}) = @@ -788,7 +471,6 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem b_lds_gemm_window, is_a_load_tr_v, is_b_load_tr_v); - // no second block_sync_lds because it's interwave if constexpr(is_a_col_major && !is_a_load_tr_v()) {
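Note on the barrier pairing documented in patch 3: it is a general AMDGPU idiom. __builtin_amdgcn_s_barrier() aligns the waves of a workgroup at runtime, while __builtin_amdgcn_sched_barrier(0) pins the barrier in place at compile time (mask 0 means no instruction may be scheduled across it). Below is a minimal standalone sketch of the same loop-pacing pattern, assuming a HIP/Clang toolchain targeting AMDGPU; the kernel, tile size, and float arithmetic are illustrative stand-ins, not CK code.

#include <hip/hip_runtime.h>

constexpr int TILE     = 256;
constexpr int K_REPEAT = 4;

__global__ void paced_loop(const float* __restrict__ a, float* __restrict__ c)
{
    __shared__ float lds[TILE];
    lds[threadIdx.x] = a[blockIdx.x * TILE + threadIdx.x];
    __syncthreads(); // full barrier plus data fence: the staged tile is now visible

    float acc = 0.f;
#pragma unroll
    for(int k = 0; k < K_REPEAT; ++k)
    {
        float r = lds[(threadIdx.x + k) % TILE]; // stand-in for LocalPrefetch

        // Mask 0 = full scheduling barrier: the compiler must finish issuing
        // everything above (the LDS loads) before emitting anything below.
        __builtin_amdgcn_sched_barrier(0);

        if(k != 0) // uniform condition, so the barrier below is workgroup-uniform
        {
            // Execution-only barrier: waves re-align here so they enter the
            // MAC-heavy region together. It implies no data fence; visibility
            // was already handled by __syncthreads() above.
            __builtin_amdgcn_s_barrier();
            // Pin the barrier: nothing may be reordered across this point.
            __builtin_amdgcn_sched_barrier(0);
        }

        acc += r * r; // stand-in for the WarpGemm / MFMA cluster
    }
    c[blockIdx.x * TILE + threadIdx.x] = acc;
}

Without the sched_barrier calls the compiler is free to hoist the next iteration's LDS loads above the s_barrier, which defeats the wave pacing that the NOTE in patch 3 describes.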
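Note on the PipelineImpl change in patch 5: it follows a standard refactor. When two class template specializations differ only in a few scheduler-dependent statements, push the divergence down into the helper that is already specialized per scheduler (BlockGemmImpl) and keep a single primary template. A toy reduction of that structure, with all CK machinery (windows, LDS, tiles) stripped out; only the names mirror the patch.

#include <cstdio>

enum class Scheduler { Intrawave, Interwave };

// Scheduler-specific behavior lives in one small, specialized helper...
template <Scheduler S>
struct BlockGemmImpl;

template <>
struct BlockGemmImpl<Scheduler::Intrawave>
{
    static void run() { std::puts("intrawave MAC schedule"); }
};

template <>
struct BlockGemmImpl<Scheduler::Interwave>
{
    static void run() { std::puts("interwave MAC schedule"); }
};

// ...so the pipeline needs one implementation instead of one near-identical
// specialization per scheduler: the shape PipelineImpl takes after patch 5.
template <Scheduler S>
struct PipelineImpl
{
    void operator()(int num_loop) const
    {
        for(int i = 0; i < num_loop; ++i)
            BlockGemmImpl<S>::run(); // all divergence is inside the helper
    }
};

int main()
{
    PipelineImpl<Scheduler::Intrawave>{}(2);
    PipelineImpl<Scheduler::Interwave>{}(1);
}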
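Two bits of C++ mechanics in patch 5 also deserve a note: operator() now takes a bool non-type template parameter (DoLocalPrefetch), branched on with if constexpr so the disabled path generates no code, and call sites inside templates must use the block_gemm.template operator()<...>(...) spelling because block_gemm is a dependent name. A minimal host-side sketch of both, sharing only the DoLocalPrefetch and LocalPrefetch names with the patch.

#include <cstdio>

struct BlockGemmLike
{
    void LocalPrefetch() { std::puts("LocalPrefetch"); }

    // Defaults to true so plain callers keep the old behavior; pipelines
    // that schedule the prefetch themselves pass false explicitly.
    template <bool DoLocalPrefetch = true>
    void operator()()
    {
        if constexpr(DoLocalPrefetch)
            LocalPrefetch(); // compiled out entirely when the flag is false
        std::puts("MAC cluster");
    }
};

template <typename Gemm>
void pipeline_step(Gemm& gemm)
{
    // gemm is a dependent name here, so an explicit template argument to
    // operator() needs the .template spelling, exactly as in the patch:
    //   block_gemm.template operator()<DoLocalPrefetch>(c, a, b, ...);
    gemm.LocalPrefetch();               // pipeline-controlled prefetch
    gemm.template operator()<false>();  // skip the prefetch inside the call
}

int main()
{
    BlockGemmLike g;
    g();              // default path: prefetch happens inside operator()
    pipeline_step(g); // unified path: the pipeline places the prefetch itself
}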