Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -224,21 +224,28 @@ struct BlockUniversalGemmAsBsCr
}

// C += A * B
template <typename CBlockTensor,
template <bool DoLocalPrefetch = true,
typename CBlockTensor,
typename ASmemBlockWindow,
typename BSmemBlockWindow,
bool ALoadTranspose = false,
bool BLoadTranspose = false>
CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
const ASmemBlockWindow&,
const BSmemBlockWindow&,
bool_constant<ALoadTranspose> = {},
bool_constant<BLoadTranspose> = {})
const ASmemBlockWindow& a_block_window,
const BSmemBlockWindow& b_block_window,
bool_constant<ALoadTranspose> a_load_tr = {},
bool_constant<BLoadTranspose> b_load_tr = {})
{
static_assert(std::is_same_v<CDataType, typename CBlockTensor::DataType>,
"The CDataType as defined in traits should be the same as correspoinding "
"C block tensor data type!");

// LocalPrefetch can be controlled by pipeline via DoLocalPrefetch parameter
if constexpr(DoLocalPrefetch)
{
LocalPrefetch(a_block_window, b_block_window, a_load_tr, b_load_tr);
}

// hot loop:
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
Expand Down Expand Up @@ -275,6 +282,7 @@ struct BlockUniversalGemmAsBsCr
});
});
});
block_sync_lds();
}
};

Expand Down Expand Up @@ -355,7 +363,8 @@ struct BlockUniversalGemmAsBsCr
}

// C += A * B
template <typename CBlockTensor,
template <bool DoLocalPrefetch = true,
typename CBlockTensor,
typename ASmemBlockWindow,
typename BSmemBlockWindow,
bool ALoadTranspose = false,
Expand All @@ -372,7 +381,14 @@ struct BlockUniversalGemmAsBsCr

// hot loop:
static_for<0, KRepeat, 1>{}([&](auto kIter) {
LocalPrefetch<kIter.value>(a_block_window, b_block_window, a_load_tr, b_load_tr);
// Note: Interwave scheduler requires LocalPrefetch as part of its design
// The DoLocalPrefetch flag is provided for API consistency but should typically be
// true
if constexpr(DoLocalPrefetch)
{
LocalPrefetch<kIter.value>(
a_block_window, b_block_window, a_load_tr, b_load_tr);
}
Comment on lines +384 to +391
Copy link

Copilot AI Jan 13, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The comment states that LocalPrefetch is required for Interwave scheduler design, but the code still allows it to be disabled via DoLocalPrefetch=false. Consider either enforcing this requirement with a static_assert or updating the comment to clarify what happens if DoLocalPrefetch is false for Interwave.

Copilot uses AI. Check for mistakes.
__builtin_amdgcn_sched_barrier(
0); // Complete scheduling all pending instruction groups before this point

Expand Down Expand Up @@ -494,7 +510,8 @@ struct BlockUniversalGemmAsBsCr
}

// C += A * B
template <typename CBlockTensor,
template <bool DoLocalPrefetch = true,
typename CBlockTensor,
typename ASmemBlockWindow,
typename BSmemBlockWindow,
bool ALoadTranspose = false,
Expand All @@ -505,7 +522,8 @@ struct BlockUniversalGemmAsBsCr
bool_constant<ALoadTranspose> a_load_tr = {},
bool_constant<BLoadTranspose> b_load_tr = {})
{
block_gemm_impl_(c_block_tensor, a_block_window, b_block_window, a_load_tr, b_load_tr);
block_gemm_impl_.template operator()<DoLocalPrefetch>(
c_block_tensor, a_block_window, b_block_window, a_load_tr, b_load_tr);
}

// C = A * B
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -569,9 +569,9 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
load_tile_with_elementwise(b_copy_dram_window, b_element_func);
move_tile_window(b_copy_dram_window, b_dram_tile_window_step);

block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);

block_sync_lds();
block_gemm.template operator()<false>(
c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
// block_sync_lds();
Copy link

Copilot AI Jan 13, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider removing the commented-out block_sync_lds() calls rather than leaving them in the code. If these are no longer needed due to the refactoring (where synchronization is now handled within the BlockGemm operator), they should be removed to avoid confusion. If they might be needed for debugging or future reference, add a comment explaining why they are commented out.

Suggested change
// block_sync_lds();

Copilot uses AI. Check for mistakes.

block_gemm.LocalPrefetch(
a_lds_gemm_window, b_lds_gemm_window, is_a_load_tr_v, is_b_load_tr_v);
Expand All @@ -586,12 +586,14 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
{
// Leak last MFMA block to epilogue region, cover the potential lds-shuffle
// latency
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
block_gemm.template operator()<false>(
c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
}
else
{
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
block_sync_lds();
block_gemm.template operator()<false>(
c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
// block_sync_lds();
Copy link

Copilot AI Jan 13, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider removing the commented-out block_sync_lds() calls rather than leaving them in the code. If these are no longer needed due to the refactoring (where synchronization is now handled within the BlockGemm operator), they should be removed to avoid confusion. If they might be needed for debugging or future reference, add a comment explaining why they are commented out.

Suggested change
// block_sync_lds();

Copilot uses AI. Check for mistakes.

if constexpr(is_a_col_major && !is_a_load_tr_v())
{
Expand All @@ -616,9 +618,8 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
Base::LocalPrefill(b_copy_lds_window, elementwise_Bs_res);
}
block_sync_lds();
block_gemm.LocalPrefetch(
a_lds_gemm_window, b_lds_gemm_window, is_a_load_tr_v, is_b_load_tr_v);
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
block_gemm.template operator()<true>(
c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
}
// __builtin_amdgcn_sched_barrier(0);
return c_block_tile;
Expand Down
Loading