diff --git a/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp b/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp index 79030fcd51..97df6163a1 100644 --- a/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp +++ b/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp @@ -224,21 +224,28 @@ struct BlockUniversalGemmAsBsCr } // C += A * B - template CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor, - const ASmemBlockWindow&, - const BSmemBlockWindow&, - bool_constant = {}, - bool_constant = {}) + const ASmemBlockWindow& a_block_window, + const BSmemBlockWindow& b_block_window, + bool_constant a_load_tr = {}, + bool_constant b_load_tr = {}) { static_assert(std::is_same_v, "The CDataType as defined in traits should be the same as correspoinding " "C block tensor data type!"); + // LocalPrefetch can be controlled by pipeline via DoLocalPrefetch parameter + if constexpr(DoLocalPrefetch) + { + LocalPrefetch(a_block_window, b_block_window, a_load_tr, b_load_tr); + } + // hot loop: static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { @@ -275,6 +282,7 @@ struct BlockUniversalGemmAsBsCr }); }); }); + block_sync_lds(); } }; @@ -355,7 +363,8 @@ struct BlockUniversalGemmAsBsCr } // C += A * B - template {}([&](auto kIter) { - LocalPrefetch(a_block_window, b_block_window, a_load_tr, b_load_tr); + // Note: Interwave scheduler requires LocalPrefetch as part of its design + // The DoLocalPrefetch flag is provided for API consistency but should typically be + // true + if constexpr(DoLocalPrefetch) + { + LocalPrefetch( + a_block_window, b_block_window, a_load_tr, b_load_tr); + } __builtin_amdgcn_sched_barrier( 0); // Complete scheduling all pending instruction groups before this point @@ -494,7 +510,8 @@ struct BlockUniversalGemmAsBsCr } // C += A * B - template a_load_tr = {}, bool_constant b_load_tr = {}) { - block_gemm_impl_(c_block_tensor, a_block_window, b_block_window, a_load_tr, b_load_tr); + block_gemm_impl_.template operator()( + c_block_tensor, a_block_window, b_block_window, a_load_tr, b_load_tr); } // C = A * B diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp index 8fae704203..1ef405cc91 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp @@ -569,9 +569,9 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 load_tile_with_elementwise(b_copy_dram_window, b_element_func); move_tile_window(b_copy_dram_window, b_dram_tile_window_step); - block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); - - block_sync_lds(); + block_gemm.template operator()( + c_block_tile, a_lds_gemm_window, b_lds_gemm_window); + // block_sync_lds(); block_gemm.LocalPrefetch( a_lds_gemm_window, b_lds_gemm_window, is_a_load_tr_v, is_b_load_tr_v); @@ -586,12 +586,14 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 { // Leak last MFMA block to epilogue region, cover the potential lds-shuffle // latency - block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); + block_gemm.template operator()( + c_block_tile, a_lds_gemm_window, b_lds_gemm_window); } else { - block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); - block_sync_lds(); + block_gemm.template operator()( + c_block_tile, a_lds_gemm_window, b_lds_gemm_window); + // block_sync_lds(); if constexpr(is_a_col_major && !is_a_load_tr_v()) { @@ -616,9 +618,8 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 Base::LocalPrefill(b_copy_lds_window, elementwise_Bs_res); } block_sync_lds(); - block_gemm.LocalPrefetch( - a_lds_gemm_window, b_lds_gemm_window, is_a_load_tr_v, is_b_load_tr_v); - block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); + block_gemm.template operator()( + c_block_tile, a_lds_gemm_window, b_lds_gemm_window); } // __builtin_amdgcn_sched_barrier(0); return c_block_tile; diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp index 38a22e38ac..12a3bfe99d 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp @@ -231,13 +231,10 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem return Policy::template GetSmemSize(); } - template + // Unified pipeline implementation for both Intrawave and Interwave schedulers + // Scheduler-specific behavior is encapsulated in BlockGemm (via BlockGemmImpl specializations) + template struct PipelineImpl : public PipelineImplBase - { - }; - - template <> - struct PipelineImpl : public PipelineImplBase { using Base = PipelineImplBase; @@ -407,318 +404,6 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem move_tile_window(b_copy_dram_window, b_dram_tile_window_step); }); - // main body - if constexpr(HasHotLoop) - { - index_t i = 0; - do - { - static_for<0, PrefetchStages, 1>{}([&](auto prefetch_idx) { - block_sync_lds(); - block_gemm.LocalPrefetch( - a_lds_gemm_window, b_lds_gemm_window, is_a_load_tr_v, is_b_load_tr_v); - block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); - - block_sync_lds(); - - if constexpr(is_a_col_major && !is_a_load_tr_v()) - { - auto a_shuffle_tmp = make_static_distributed_tensor( - Policy::template MakeShuffledARegTileDistribution()); - transpose_tile2d( - a_shuffle_tmp, - a_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{})); - Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp); - } - else - { - Base::LocalPrefill( - a_copy_lds_window, - a_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{})); - } - if constexpr(is_b_row_major && !is_b_load_tr_v()) - { - auto b_shuffle_tmp = make_static_distributed_tensor( - Policy::template MakeShuffledBRegTileDistribution()); - transpose_tile2d( - b_shuffle_tmp, - b_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{})); - Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp); - } - else - { - Base::LocalPrefill( - b_copy_lds_window, - b_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{})); - } - - a_block_tiles.at(number{}) = - load_tile_with_elementwise(a_copy_dram_window, a_element_func); - move_tile_window(a_copy_dram_window, a_dram_tile_window_step); - - b_block_tiles.at(number{}) = - load_tile_with_elementwise(b_copy_dram_window, b_element_func); - - move_tile_window(b_copy_dram_window, b_dram_tile_window_step); - }); - - i += PrefetchStages; - } while(i < (num_loop - PrefetchStages)); - } - - auto HotLoopTail = [&](auto tail_num) { - static_for<1, tail_num, 1>{}([&](auto prefetch_idx) { - block_sync_lds(); - - block_gemm.LocalPrefetch( - a_lds_gemm_window, b_lds_gemm_window, is_a_load_tr_v, is_b_load_tr_v); - block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); - - block_sync_lds(); - - if constexpr(is_a_col_major && !is_a_load_tr_v()) - { - auto a_shuffle_tmp = make_static_distributed_tensor( - Policy::template MakeShuffledARegTileDistribution()); - transpose_tile2d(a_shuffle_tmp, a_block_tiles.get(number{})); - Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp); - } - else - { - Base::LocalPrefill(a_copy_lds_window, - a_block_tiles.get(number{})); - } - if constexpr(is_b_row_major && !is_b_load_tr_v()) - { - auto b_shuffle_tmp = make_static_distributed_tensor( - Policy::template MakeShuffledBRegTileDistribution()); - transpose_tile2d(b_shuffle_tmp, b_block_tiles.get(number{})); - Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp); - } - else - { - Base::LocalPrefill(b_copy_lds_window, - b_block_tiles.get(number{})); - } - }); - - block_sync_lds(); - block_gemm.LocalPrefetch( - a_lds_gemm_window, b_lds_gemm_window, is_a_load_tr_v, is_b_load_tr_v); - block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); - }; - - if constexpr(TailNum == TailNumber::One) - { - block_sync_lds(); - block_gemm.LocalPrefetch( - a_lds_gemm_window, b_lds_gemm_window, is_a_load_tr_v, is_b_load_tr_v); - block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); - } - else if constexpr(TailNum == TailNumber::Two) - { - HotLoopTail(number<2>{}); - } - else if constexpr(TailNum == TailNumber::Three) - { - HotLoopTail(number<3>{}); - } - else if constexpr(TailNum == TailNumber::Four) - { - HotLoopTail(number<4>{}); - } - else if constexpr(TailNum == TailNumber::Five) - { - HotLoopTail(number<5>{}); - } - else if constexpr(TailNum == TailNumber::Six) - { - HotLoopTail(number<6>{}); - } - else if constexpr(TailNum == TailNumber::Seven) - { - HotLoopTail(number<7>{}); - } - else if constexpr(TailNum == TailNumber::Full) - { - HotLoopTail(number{}); - } - - return c_block_tile; - } - }; - - template <> - struct PipelineImpl : public PipelineImplBase - { - using Base = PipelineImplBase; - - template ::value && - is_detected::value, - bool>* = nullptr> - CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp, - const AElementFunction& a_element_func, - const BsDramBlockWindowTmp& b_dram_block_window_tmp, - const BElementFunction& b_element_func, - index_t num_loop, - void* p_smem) const - { - using ADramBlockWindowTmp = - remove_cvref_t{}, AsDramBlockWindowTmp>>; - using BDramBlockWindowTmp = - remove_cvref_t{}, BsDramBlockWindowTmp>>; - - static_assert( - std::is_same_v> && - std::is_same_v>, - "A/B Dram block window should have the same data type as appropriate " - "([A|B]DataType) defined in Problem definition!"); - - constexpr bool is_a_col_major = - std::is_same_v; - constexpr bool is_b_row_major = std::is_same_v; - - static_assert(is_a_col_major - ? (KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] && - MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}]) - : (MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] && - KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}]), - "A block window has incorrect lengths for defined ALayout!"); - static_assert(is_b_row_major - ? (KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] && - NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}]) - : (NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] && - KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}]), - "B block window has incorrect lengths for defined BLayout!"); - - // ------------------------------------------------------------------------------------ - // Definitions of all needed tiles - - // A/B tiles in LDS - // With c++20 could simplify to below line. - // Currently get error: captured structured bindings are a C++20 extension - // auto&& [a_lds_block, b_lds_block] = Base::GetABLdsTensorViews(p_smem); - auto ab_lds_blocks = Base::GetABLdsTensorViews(p_smem); - auto& a_lds_block = ab_lds_blocks.at(I0{}); - auto& b_lds_block = ab_lds_blocks.at(I1{}); - - // Tile distribution for load from lds - constexpr auto a_lds_load_tile_distr = - make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode()); - constexpr auto b_lds_load_tile_distr = - make_static_tile_distribution(BlockGemm::MakeBBlockDistributionEncode()); - - // A DRAM tile window for load - // A LDS tile window for store - // A LDS tile for block GEMM - auto a_windows = - Base::GetAWindows(a_dram_block_window_tmp, a_lds_block, a_lds_load_tile_distr); - auto& a_copy_dram_window = a_windows.at(I0{}); - auto& a_copy_lds_window = a_windows.at(I1{}); - auto& a_lds_gemm_window = a_windows.at(I2{}); - - // B DRAM tile window for load - // B LDS tile window for store - // B LDS tile for block GEMM - auto b_windows = - Base::GetBWindows(b_dram_block_window_tmp, b_lds_block, b_lds_load_tile_distr); - auto& b_copy_dram_window = b_windows.at(I0{}); - auto& b_copy_lds_window = b_windows.at(I1{}); - auto& b_lds_gemm_window = b_windows.at(I2{}); - - // Block GEMM - auto block_gemm = BlockGemm(); - auto c_block_tile = block_gemm.MakeCBlockTile(); - - using ABlockTileDistr = - decltype(a_copy_dram_window[number<0>{}].get_tile_distribution()); - using BBlockTileDistr = - decltype(b_copy_dram_window[number<0>{}].get_tile_distribution()); - - using ABlockTile = - decltype(make_static_distributed_tensor(ABlockTileDistr{})); - using BBlockTile = - decltype(make_static_distributed_tensor(BBlockTileDistr{})); - - tuple_array a_block_tiles; - tuple_array b_block_tiles; - - using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex; - using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex; - - constexpr ADramTileWindowStep a_dram_tile_window_step = - is_a_col_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock); - constexpr BDramTileWindowStep b_dram_tile_window_step = - is_b_row_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock); - // ----------------------------------------------------------------------------------------- - // Gemm pipeline start - - // prefetch - // global read 0 - - // Load tile — during value loading, an elementwise function is executed for each A0, - // A1, … AN. The values A0, A1, … AN are read by the same thread. - a_block_tiles.at(I0{}) = load_tile_with_elementwise(a_copy_dram_window, a_element_func); - - // Move each A — the enhanced function move_tile_window is executed, which takes a tuple - // as input. - move_tile_window(a_copy_dram_window, a_dram_tile_window_step); - - // Load tile — during value loading, an elementwise function is executed for each B0, - // B1, … BN. The values B0, B1, … BN are read by the same thread. - b_block_tiles.at(I0{}) = load_tile_with_elementwise(b_copy_dram_window, b_element_func); - - // Move each B — the enhanced function move_tile_window is executed, which takes a tuple - // as input. - move_tile_window(b_copy_dram_window, b_dram_tile_window_step); - - // initialize C - tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); - - // LDS write 0 - if constexpr(is_a_col_major && !is_a_load_tr_v()) - { - auto a_shuffle_tmp = make_static_distributed_tensor( - Policy::template MakeShuffledARegTileDistribution()); - transpose_tile2d(a_shuffle_tmp, a_block_tiles.get(I0{})); - Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp); - } - else - { - Base::LocalPrefill(a_copy_lds_window, a_block_tiles.get(I0{})); - } - if constexpr(is_b_row_major && !is_b_load_tr_v()) - { - auto b_shuffle_tmp = make_static_distributed_tensor( - Policy::template MakeShuffledBRegTileDistribution()); - transpose_tile2d(b_shuffle_tmp, b_block_tiles.get(I0{})); - Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp); - } - else - { - Base::LocalPrefill(b_copy_lds_window, b_block_tiles.get(I0{})); - } - - // Global prefetch [1, PrefetchStages] - static_for<1, PrefetchStages, 1>{}([&](auto prefetch_idx) { - a_block_tiles.at(number{}) = - load_tile_with_elementwise(a_copy_dram_window, a_element_func); - - move_tile_window(a_copy_dram_window, a_dram_tile_window_step); - - b_block_tiles.at(number{}) = - load_tile_with_elementwise(b_copy_dram_window, b_element_func); - - move_tile_window(b_copy_dram_window, b_dram_tile_window_step); - }); - // main body if constexpr(HasHotLoop) { @@ -732,7 +417,6 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem b_lds_gemm_window, is_a_load_tr_v, is_b_load_tr_v); - // no second block_sync_lds because it's interwave if constexpr(is_a_col_major && !is_a_load_tr_v()) { @@ -767,7 +451,6 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem a_block_tiles.at(number{}) = load_tile_with_elementwise(a_copy_dram_window, a_element_func); - move_tile_window(a_copy_dram_window, a_dram_tile_window_step); b_block_tiles.at(number{}) = @@ -788,7 +471,6 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem b_lds_gemm_window, is_a_load_tr_v, is_b_load_tr_v); - // no second block_sync_lds because it's interwave if constexpr(is_a_col_major && !is_a_load_tr_v()) {