@@ -193,23 +193,43 @@ struct VerificationHelper {
};
///////////////////////////////////////////////////////////////////////////////////////////////////

template <class TA, class TB> auto choose_tiled_mma(TA *A, TB *B) {
template <typename TA, typename TB, typename TD> auto choose_mma_op() {
if constexpr (is_any_of_v<TA, bfloat16_t, half_t> &&
is_any_of_v<TB, bfloat16_t, half_t>) {
return XE_DPAS_TT<8, float, TA, TB>{};
} else if constexpr (is_complete_v<XE_DPAS_TT<8, TD, TA, TB>>) {
return XE_DPAS_TT<8, TD, TA, TB>{};
} else if constexpr (is_same_v<TA, cute::bfloat16_t>) {
return XE_DPAS_TT<8, float, cute::bfloat16_t>{};
} else { /* Use f16 by default as upconversion sequences are typically faster
*/
return XE_DPAS_TT<8, float, cute::half_t>{};
}
}
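// `is_complete_v` above is presumably a detection trait that reports whether a
// given XE_DPAS_TT specialization is defined, i.e. whether the DPAS op exists
// for that type combination. A minimal sketch of such a trait, using the usual
// sizeof/void_t idiom (illustrative only; the namespace and exact definition
// are assumptions, not this repository's code; requires <type_traits>):
namespace detail_sketch {
template <class T, class = void> struct is_complete : std::false_type {};

// sizeof(T) is well-formed only for complete types, so this specialization is
// selected exactly when T is fully defined.
template <class T>
struct is_complete<T, std::void_t<decltype(sizeof(T))>> : std::true_type {};

template <class T> inline constexpr bool is_complete_v = is_complete<T>::value;
} // namespace detail_sketch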

template <typename TA, typename TB, typename TD>
auto choose_tiled_mma(TA *const &A, TB *const &B, TD *D) {
using TA_non_CV = cutlass::platform::remove_cv_t<TA>;
using TB_non_CV = cutlass::platform::remove_cv_t<TB>;
auto op = XE_DPAS_TT<8, float, TA_non_CV, TB_non_CV>{};
auto op = choose_mma_op<TA_non_CV, TB_non_CV, TD>();

using WGTile = Shape<_256, _128, _32>; // 256x128 WG tile size
using SGLayout =
constexpr bool use_4x8_sg = (sizeof_bits_v<TB> < sizeof_bits_v<TA>);
using WGTileShape =
conditional_t<use_4x8_sg, Shape<_256, _256, _32>, Shape<_256, _128, _32>>;
using SGLayout8x2 =
Layout<Shape<_8, _2, _1>, Stride<_2, _1, _0>>; // 8x2 SG tiling, n-major

using MMA = typename TiledMMAHelper<MMA_Atom<decltype(op)>, Layout<WGTile>,
SGLayout>::TiledMMA;

return MMA{};
using SGLayout4x8 =
Layout<Shape<_4, _8, _1>, Stride<_8, _1, _0>>; // 4x8 SG tiling, n-major
using SGLayout = conditional_t<use_4x8_sg, SGLayout4x8, SGLayout8x2>;

using MMA = typename TiledMMAHelper<MMA_Atom<decltype(op)>,
Layout<WGTileShape>, SGLayout>::TiledMMA;
auto mma = MMA{};
return mma;
}
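// Sanity check of the per-subgroup C tile sizes implied by the two
// configurations above (an illustrative sketch; the constant names are not
// from the code). A 256x256x32 WG tile split over a 4x8 subgroup grid gives
// each subgroup a 64x32 tile of C; a 256x128x32 WG tile split over an 8x2
// grid gives 32x64.
constexpr int sg_tile_m_4x8 = 256 / 4, sg_tile_n_4x8 = 256 / 8;
constexpr int sg_tile_m_8x2 = 256 / 8, sg_tile_n_8x2 = 128 / 2;
static_assert(sg_tile_m_4x8 == 64 && sg_tile_n_4x8 == 32, "4x8 SG tiling");
static_assert(sg_tile_m_8x2 == 32 && sg_tile_n_8x2 == 64, "8x2 SG tiling");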

// type tag to define a unique sycl kernel name
template <typename, typename, typename, char, char> class GemmCuteName;
template <typename, typename, typename, char, char, int> class GemmCuteName;

template <char layoutA, char layoutB, typename ElementA, typename ElementB,
typename ElementS, typename ElementD>
@@ -236,20 +256,34 @@ void MoEGEMMLauncher(const ElementA *activations, const ElementB *weights,
auto dummy_group_problem_shape =
cutlass::gemm::GroupProblemShape<Shape<int, int, int>>{
1, &dummy_problem_shape, nullptr};
using TileShape = Shape<_256, _128, _32>;
constexpr bool use_4x8_sg =
(sizeof_bits_v<ElementB> < sizeof_bits_v<ElementA>);
using WGTileShape =
conditional_t<use_4x8_sg, Shape<_256, _256, _32>, Shape<_256, _128, _32>>;
using ClusterShape = Shape<_1, _1, _1>;
auto scheduler_params =
PersistentTileSchedulerXeMoE<ProblemShape>::to_underlying_arguments(
dummy_group_problem_shape, TileShape{}, ClusterShape{}, hw_info,
dummy_group_problem_shape, WGTileShape{}, ClusterShape{}, hw_info,
PersistentTileSchedulerXeMoE<ProblemShape>::Arguments{
1, RasterOrderOptions::AlongN});
auto group_distribution =
PersistentTileSchedulerXeMoE<ProblemShape>::get_grid_shape(
scheduler_params, dummy_group_problem_shape, TileShape{},
scheduler_params, dummy_group_problem_shape, WGTileShape{},
ClusterShape{}, hw_info,
PersistentTileSchedulerXeMoE<ProblemShape>::Arguments{
1, RasterOrderOptions::AlongN});
auto mma = choose_tiled_mma(activations, weights);

using SGLayout8x2 =
Layout<Shape<_8, _2, _1>, Stride<_2, _1, _0>>; // 8x2 SG tiling, n-major
using SGLayout4x8 =
Layout<Shape<_4, _8, _1>, Stride<_8, _1, _0>>; // 4x8 SG tiling, n-major
using SGLayout = conditional_t<use_4x8_sg, SGLayout4x8, SGLayout8x2>;

auto mma = choose_tiled_mma(activations, weights, outputs);
constexpr auto wg_n = get<1>(mma.tile_mnk());
constexpr auto sg_n = wg_n / get<1>(SGLayout{}.shape());
constexpr auto q_group_size = 32;
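// With the shapes above, wg_n is the workgroup tile N (256 for the 4x8 SG
// layout, 128 for 8x2) and sg_n is the per-subgroup slice of it
// (256 / 8 = 32 or 128 / 2 = 64). q_group_size = 32 means one scale per 32
// consecutive K elements of the quantized B matrix.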

auto MaxThreadsPerWorkgroup = size(mma);
dim3 local_range{MaxThreadsPerWorkgroup, 1, 1};

@@ -268,16 +302,17 @@

GPU_Clock timer;
timer.start();
auto event = Q.parallel_for<
GemmCuteName<ElementA, ElementB, ElementD, layoutA, layoutB>>(
auto event = Q.parallel_for<GemmCuteName<ElementA, ElementB, ElementD,
layoutA, layoutB, q_group_size>>(
sycl::nd_range<3>(global, local), kernel_props, [=](auto) {
// Can also use void for copy atoms.
// In that case, they will be chosen automatically.
MoE::MoEGEMM<XE_LOAD_2D<16, 32, 32, 16>,
XE_LOAD_2D_VNNI<16, 32, 16, 16>, XE_STORE_2D<16, 8, 32>,
'R', 'R', 'R'>(activations, weights, scales, outputs, mma,
num_rows_per_expert_device, num_experts,
gemm_n, gemm_k, scheduler_params);
'R', 'R', 'R', sg_n, wg_n, q_group_size>(
activations, weights, scales, outputs, mma,
num_rows_per_expert_device, num_experts, gemm_n, gemm_k,
scheduler_params);
});
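// Per the comment inside the kernel, the copy atoms can also be passed as
// void and will then be chosen automatically; a sketch of that variant
// (untested here, all other parameters unchanged):
//   MoE::MoEGEMM<void, void, void, 'R', 'R', 'R', sg_n, wg_n, q_group_size>(
//       activations, weights, scales, outputs, mma,
//       num_rows_per_expert_device, num_experts, gemm_n, gemm_k,
//       scheduler_params);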
EventManager::getInstance().addEvent(event);
Q.wait_and_throw();
@@ -413,8 +448,7 @@ int main(int argc, const char **argv) {
{6, 13, 123, 28, 197, 0, 202, 69, 0, 6, 0, 21, 1434, 1582, 11, 0, 6,
0, 7, 190, 4, 1700, 6, 434, 1886, 0, 14, 28, 8, 30, 25, 18},
{5, 27, 1442, 18, 0, 6, 0, 73, 6, 781, 0, 1915, 291, 649, 98, 4,
33, 77, 6, 22, 73, 9, 8, 587, 1486, 32, 10, 244, 37, 0, 100, 9}
};
33, 77, 6, 22, 73, 9, 8, 587, 1486, 32, 10, 244, 37, 0, 100, 9}};

for (int i = 0; i < num_layers; i++) {
launcher(total_rows_for_each_expert[i], 5760, 2880, num_experts);
233 changes: 223 additions & 10 deletions examples/12_bmg_moe_gemm_cute_interface/moe_gemms.hpp
@@ -72,7 +72,7 @@ template <
CUTE_DEVICE void moe_gemm(ATensor const &A, // (M,K)
BTensor const &B, // (N,K)
DTensor &D, // (M,N)
Coord<int, int, cute::Underscore, int> blk_coord,
Coord<int, int, cute::Underscore, int> &blk_coord,
TiledMMA const &mma) {
auto item = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
auto local_id = item.get_local_linear_id();
@@ -100,10 +100,10 @@ CUTE_DEVICE void moe_gemm(ATensor const &A, // (M,K)
auto thr_copy_b = tiled_copy_b.get_slice(local_id);
auto thr_copy_d = tiled_copy_d.get_slice(local_id);

auto tCrA = thr_mma.partition_sg_fragment_A(gA(_, _, 0));
auto tCrB = thr_mma.partition_sg_fragment_B(gB(_, _, 0));
auto tCrD = thr_mma.partition_sg_fragment_C(gD);
auto tCrD_final = thr_copy_d.partition_sg_fragment_S(gD);
auto tDrA = thr_mma.partition_sg_fragment_A(gA(_, _, 0));
auto tDrB = thr_mma.partition_sg_fragment_B(gB(_, _, 0));
auto tDrD = thr_mma.partition_sg_fragment_C(gD);
auto tDrD_final = thr_copy_d.partition_sg_fragment_S(gD);

auto tArA = thr_copy_a.partition_sg_fragment_D(gA(_, _, 0));
auto tBrB = thr_copy_b.partition_sg_fragment_D(gB(_, _, 0));
@@ -145,14 +145,227 @@ CUTE_DEVICE void moe_gemm(ATensor const &A, // (M,K)
prefetch(prefetch_b, pBgB(_, _, _, prefetch_k));
}

reorder(tArA, tCrA);
reorder(tBrB, tCrB);
reorder(tArA, tDrA);
reorder(tBrB, tDrB);

cute::gemm(mma, tCrA, tCrB, tCrD);
cute::gemm(mma, tDrA, tDrB, tDrD);
barrier_wait(barrier_scope);
}
reorder(tCrD, tCrD_final);
copy(tiled_copy_d, tCrD_final, tCgD);
reorder(tDrD, tDrD_final);
copy(tiled_copy_d, tDrD_final, tCgD);
}

template <class GmemTiledCopyA, class GmemTiledCopyB, class GmemTiledCopyD,
int SG_N, int WG_N, int q_group_size, class ATensor, class BTensor,
class STensor, class DTensor, class TiledMMA,
class = std::enable_if_t<
!cute::is_void_v<typename STensor::element_type> &&
is_any_of_v<typename BTensor::element_type, float_e2m1_t,
float_e4m3_t, float_e5m2_t, int4_t> &&
is_any_of_v<typename STensor::element_type, float_ue8m0_t, half_t,
bfloat16_t> &&
is_any_of_v<typename ATensor::element_type, bfloat16_t, half_t>>>
CUTE_DEVICE void moe_gemm(ATensor const &A, // (M,K)
BTensor const &B, // (N,K)
STensor const &S, // (K/q_group_size, N)
DTensor &D, // (M,N)
Coord<int, int, cute::Underscore, int> &blk_coord,
TiledMMA const &mma) {
auto item = sycl::ext::oneapi::this_work_item::get_nd_item<2>();
auto wg_m = int(item.get_group(1));
auto wg_n = int(item.get_group(0));
auto local_id = int(item.get_local_id(0));
auto sg = sycl::ext::oneapi::this_work_item::get_sub_group();
uint32_t sg_id = sg.get_group_linear_id();
uint32_t lane = sg.get_local_linear_id();

auto total_N = get<0>(B.shape());

Tensor cA = make_identity_tensor(A.shape()); // (M,K)
Tensor cB = make_identity_tensor(B.shape()); // (N,K)
Tensor cD = make_identity_tensor(D.shape()); // (M,N)
Tensor cS = make_identity_tensor(S.shape()); // (K/q_group_size,N)
Tensor cScales_per_sg =
make_identity_tensor(make_shape(Int<1>{}, Int<SG_N>{}));
auto wg_tile = mma.tile_mnk();
auto wg_coord = make_coord(wg_m, wg_n, 0);

Tensor gA = local_tile(cA, select<0, 2>(wg_tile),
make_coord(wg_m, _)); // (BLK_M,BLK_K,k)
Tensor gB = local_tile(cB, select<1, 2>(wg_tile),
make_coord(wg_n, _)); // (BLK_N,BLK_K,k)
Tensor gD =
local_tile(cD, wg_tile, wg_coord, Step<_1, _1, X>{}); // (BLK_M,BLK_N)

constexpr int num_N_SG_tiles = WG_N / SG_N;
constexpr int num_scales_per_col = (SG_N == 32) ? 4 : 2;

// When the scales are E8M0, the compiler behaves differently and loads more
// data than needed; the excess is discarded.
// For int4 weights, the scales may instead be FP16 or BF16.
using scaleLoadType =
conditional_t<is_same_v<typename STensor::element_type, float_ue8m0_t>,
int8_t, int16_t>;

auto S_tile = coalesce(local_tile(S, make_shape(Int<1>{}, get<1>(wg_tile)),
make_coord(_, wg_n)));

auto copy_a = get_block_2d_copy_A<GmemTiledCopyA>(mma, A);
auto copy_b = get_block_2d_copy_B<GmemTiledCopyB>(mma, B);
auto copy_d = get_block_2d_copy_D<GmemTiledCopyD>(mma, D);

auto thr_mma = mma.get_slice(local_id);
auto thr_copy_a = copy_a.get_slice(local_id);
auto thr_copy_b = copy_b.get_slice(local_id);
auto thr_copy_d = copy_d.get_slice(local_id);

auto tDrA = thr_mma.partition_sg_fragment_A(gA(_, _, 0));
auto tDrB = thr_mma.partition_sg_fragment_B(gB(_, _, 0));

auto tArA = thr_copy_a.partition_sg_fragment_D(gA(_, _, 0));
auto tBrB = thr_copy_b.partition_sg_fragment_D(gB(_, _, 0));
auto tDrD = thr_mma.partition_sg_fragment_C(gD);

Tensor tAgA = thr_copy_a.partition_S(gA);
Tensor tBgB = thr_copy_b.partition_S(gB);
Tensor tDgD = thr_copy_d.partition_D(gD);

auto prefetch_a = make_block_2d_prefetch(copy_a);
auto prefetch_b = make_block_2d_prefetch(copy_b);

auto thr_prefetch_A = prefetch_a.get_slice(local_id);
auto thr_prefetch_B = prefetch_b.get_slice(local_id);

auto pAgA = thr_prefetch_A.partition_S(gA);
auto pBgB = thr_prefetch_B.partition_S(gB);

const int prefetch_dist = 3;

constexpr int barrier_scope = 2;
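// The split barrier below (barrier_arrive at the top of each k iteration,
// barrier_wait after the MMA) keeps the subgroups of a workgroup loosely in
// step so their cooperative prefetches stay effective; a scope value of 2
// presumably corresponds to workgroup scope (SPIR-V Scope::Workgroup).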

int k_tile_count = ceil_div(shape<1>(A), get<2>(wg_tile));
int k_tile_prefetch = 0;
constexpr int num_threads_per_sg = 16;

typename STensor::element_type
    frag[num_scales_per_col / 2]; // per-thread scale storage; the compiler
                                  // keeps this small array in registers
float frag_fp32[num_scales_per_col];
// assuming SG_K = WG_K
constexpr int frequency_scale_change = q_group_size / get<2>(wg_tile);
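// With the launcher's q_group_size of 32 and a WG K tile of 32, this is 1,
// i.e. a fresh set of scales is loaded on every k-tile iteration.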
Tensor scales_e8m0 =
make_tensor(make_rmem_ptr(frag),
make_layout(make_shape(Int<num_scales_per_col / 2>{})));
Tensor scales_float =
make_tensor(make_rmem_ptr(frag_fp32),
make_layout(make_shape(Int<num_scales_per_col>{})));

auto srcTVLayout = make_layout(
make_shape(Int<num_threads_per_sg>{}, Int<num_scales_per_col / 2>{}),
make_stride(Int<1>{}, Int<num_threads_per_sg>{}));
auto dstTVLayout = make_layout(
make_shape(make_shape(Int<2>{}, Int<num_threads_per_sg / 2>{}),
make_shape(Int<num_scales_per_col / 2>{})),
make_stride(make_stride(Int<0>{}, Int<1>{}), make_stride(Int<8>{})));
auto scales_e8m0_sg_tensor = make_subgroup_tensor(scales_e8m0, srcTVLayout);
auto scales_float_sg_tensor = make_subgroup_tensor(scales_float, dstTVLayout);

/* Warm up loops with prefetch to L1 */
CUTE_UNROLL
for (; k_tile_prefetch < prefetch_dist; k_tile_prefetch++) {
prefetch(prefetch_a, pAgA(_, _, _, k_tile_prefetch));
prefetch(prefetch_b, pBgB(_, _, _, k_tile_prefetch));
}
/* Main loop */
for (int k_tile = 0; k_tile < k_tile_count; k_tile++, k_tile_prefetch++) {
barrier_arrive(barrier_scope);
copy(copy_b, tBgB(_, _, _, k_tile), tBrB);
prefetch(prefetch_b, pBgB(_, _, _, k_tile_prefetch));
reorder(tBrB, tDrB);

if (k_tile % frequency_scale_change == 0) {
auto scales_tensor = make_tensor(
make_gmem_ptr(reinterpret_cast<scaleLoadType *>(
static_cast<void *>(cute::raw_pointer_cast(
S_tile.data() + (SG_N * (sg_id % num_N_SG_tiles)) +
(k_tile / frequency_scale_change) * total_N)))),
make_layout(make_shape(Int<1>{}, Int<SG_N>{})));
auto copy_scales = make_block_2d_copy(
XE_LOAD_2D<sizeof_bits_v<typename STensor::element_type>, 1, SG_N,
SG_N>{},
scales_tensor);
auto thr_copy_scales = copy_scales.get_slice(lane);
auto scales_per_thread = thr_copy_scales.partition_S(cScales_per_sg);
copy(copy_scales, scales_per_thread(_, 0, 0), scales_e8m0);
reorder(scales_e8m0_sg_tensor, scales_float_sg_tensor);
if (k_tile != (k_tile_count - frequency_scale_change)) {
auto next_scales_tensor = make_tensor(
make_gmem_ptr(reinterpret_cast<scaleLoadType *>(
static_cast<void *>(cute::raw_pointer_cast(
S_tile.data() + (SG_N * (sg_id % num_N_SG_tiles)) +
((k_tile / frequency_scale_change) + 1) * total_N)))),
make_layout(make_shape(Int<1>{}, Int<SG_N>{})));
auto prefetch_scales = make_block_2d_prefetch<1>(
make_shape(Int<1>{}, Int<SG_N>{}), next_scales_tensor);
auto thr_prefetch_scales = prefetch_scales.get_slice(lane);
auto pSgS = thr_prefetch_scales.partition_S(cScales_per_sg);
prefetch(prefetch_scales, pSgS(_, 0, 0));
}
}
copy(copy_a, tAgA(_, _, _, k_tile), tArA);
prefetch(prefetch_a, pAgA(_, _, _, k_tile_prefetch));
reorder(tArA, tDrA);
// TODO: replace this hardcoded scale interleaving with CuTe-algebra-based
// transformations so the code generalizes to other tile shapes.
auto scale0 = scales_float_sg_tensor[0];
auto scale1 = scales_float_sg_tensor[1];
if (num_scales_per_col == 4) {
auto scale2 = scales_float_sg_tensor[2];
auto scale3 = scales_float_sg_tensor[3];
CUTE_UNROLL
for (int i = 0; i < 16; i += 2) {
tDrB[i] = static_cast<typename ATensor::element_type>(
scale0 * static_cast<float>(tDrB[i]));
tDrB[i + 1] = static_cast<typename ATensor::element_type>(
scale1 * static_cast<float>(tDrB[i + 1]));
}
CUTE_UNROLL
for (int i = 16; i < 32; i += 2) {
tDrB[i] = static_cast<typename ATensor::element_type>(
scale2 * static_cast<float>(tDrB[i]));
tDrB[i + 1] = static_cast<typename ATensor::element_type>(
scale3 * static_cast<float>(tDrB[i + 1]));
}
CUTE_UNROLL
for (int i = 32; i < 48; i += 2) {
tDrB[i] = static_cast<typename ATensor::element_type>(
scale0 * static_cast<float>(tDrB[i]));
tDrB[i + 1] = static_cast<typename ATensor::element_type>(
scale1 * static_cast<float>(tDrB[i + 1]));
}
CUTE_UNROLL
for (int i = 48; i < 64; i += 2) {
tDrB[i] = static_cast<typename ATensor::element_type>(
scale2 * static_cast<float>(tDrB[i]));
tDrB[i + 1] = static_cast<typename ATensor::element_type>(
scale3 * static_cast<float>(tDrB[i + 1]));
}
} else {
CUTE_UNROLL
for (int i = 0; i < 32; i += 2) {
tDrB[i] = static_cast<typename ATensor::element_type>(
scale0 * static_cast<float>(tDrB[i]));
tDrB[i + 1] = static_cast<typename ATensor::element_type>(
scale1 * static_cast<float>(tDrB[i + 1]));
}
}

gemm(mma, tDrA, tDrB, tDrD);
barrier_wait(barrier_scope);
}
auto tDrD_final = thr_copy_d.partition_sg_fragment_S(gD);
reorder(tDrD, tDrD_final);
copy(copy_d, tDrD_final, tDgD);
}
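// For reference, the group-wise quantization scheme handled by the overload
// above, written as a plain scalar loop. This is an illustrative sketch (the
// function name is hypothetical and it is not used by the kernel): B is stored
// quantized with shape (N,K), S has shape (K/q_group_size, N) with one scale
// per q_group_size consecutive K elements of a column, and D = A * dequant(B)^T
// with A of shape (M,K). Row-major, byte-addressable element types are assumed
// for simplicity.
template <int q_group_size, class TA, class TB, class TS, class TD>
void moe_gemm_groupwise_reference(TA const *A, TB const *B, TS const *S, TD *D,
                                  int M, int N, int K) {
  for (int m = 0; m < M; m++) {
    for (int n = 0; n < N; n++) {
      float acc = 0.f;
      for (int k = 0; k < K; k++) {
        // One scale covers q_group_size consecutive K elements of column n.
        float scale = static_cast<float>(S[(k / q_group_size) * N + n]);
        float b = scale * static_cast<float>(B[n * K + k]);
        acc += static_cast<float>(A[m * K + k]) * b;
      }
      D[m * N + n] = static_cast<TD>(acc);
    }
  }
}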

} // namespace MoE