Skip to content

Commit f10213c

Browse files
committed
Updated 00_bmg_gemm and 05_bmg_gemm to new Atom API
1 parent 1c847b7 commit f10213c

11 files changed

+2196
-54
lines changed

examples/00_bmg_gemm/legacy/00_bmg_gemm_padded.cpp

Lines changed: 467 additions & 0 deletions
Large diffs are not rendered by default.

examples/00_bmg_gemm/legacy/00_bmg_gemm_with_sycl_queue.cpp

Lines changed: 414 additions & 0 deletions
Large diffs are not rendered by default.

examples/00_bmg_gemm/legacy/CMakeLists.txt

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,3 +39,19 @@ cutlass_example_add_executable(
3939
TEST_LARGE
4040
TEST_SMALL_SHAPE
4141
)
42+
43+
set(TEST_SMALL_SHAPE_PADDABLE --m=1 --n=1 --k=2 --l=2)
44+
cutlass_example_add_executable(
45+
00_bmg_gemm_padded_legacy
46+
00_bmg_gemm_padded.cpp
47+
TEST_COMMAND_OPTIONS
48+
TEST_BATCHES
49+
TEST_SMALL_SHAPE_PADDABLE
50+
)
51+
52+
cutlass_example_add_executable(
53+
00_bmg_gemm_with_sycl_queue_legacy
54+
00_bmg_gemm_with_sycl_queue.cpp
55+
TEST_COMMAND_OPTIONS
56+
TEST_BATCHES
57+
)

examples/05_bmg_gemm_with_epilogues/05_bmg_gemm_with_epilogue_softmax.cpp

Lines changed: 14 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -402,54 +402,39 @@ int main(int argc, const char** argv)
402402
using LayoutC = cutlass::layout::RowMajor;
403403
using LayoutD = cutlass::layout::RowMajor;
404404

405-
// using GmemTiledCopyA = XE_2D_U16x8x16_LD_N;
406-
<<<<<<< HEAD
407-
using GmemTiledCopyA = XE_LOAD_2D<16, 8, 16>;
408-
// using GmemTiledCopyA = void;
409-
// using GmemTiledCopyB = XE_2D_U16x16x16_LD_V;
410-
using GmemTiledCopyB = XE_LOAD_2D_VNNI<16, 16, 16>;
411-
// using GmemTiledCopyB = void;
412-
=======
413-
using GmemTiledCopyA = void;
414-
// using GmemTiledCopyB = XE_2D_U16x16x16_LD_V;
415-
using GmemTiledCopyB = void;
416-
>>>>>>> afa071e0 (epilogue test)
405+
using GmemTiledCopyA = XE_2D_U16x8x16_LD_N;
406+
using GmemTiledCopyB = XE_2D_U16x16x16_LD_V;
417407

418408
// Workgroup-level tile
419409
using TileShape = Shape<_32, _512, _32>;
420410

421-
// using TiledMma =
422-
// typename TiledMMAHelper<MMA_Atom<XE_8x16x16_F32BF16BF16F32_TT>, Layout<TileShape>,
423-
// Layout<Shape<_2, _16, _1>, Stride<_16, _1, _0>>>::TiledMMA;
424-
using TiledMma = typename TiledMMAHelper<MMA_Atom<XE_DPAS_TT<8, float, cute::bfloat16_t>>, Layout<TileShape>, Layout<Shape<_2, _16, _1>, Stride<_16, _1, _0>>>::TiledMMA;
411+
using TiledMma =
412+
typename TiledMMAHelper<MMA_Atom<XE_8x16x16_F32BF16BF16F32_TT>, Layout<TileShape>,
413+
Layout<Shape<_2, _16, _1>, Stride<_16, _1, _0>>>::TiledMMA;
425414

426415
using EpilogueTile = Shape<_16, _32>;
427416
constexpr int PipelineStages = 3;
428-
using GEMMDispatchPolicy = cutlass::gemm::MainloopXeL1Staged<PipelineStages>;
429-
using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeGeneric;
417+
using GEMMDispatchPolicy = cutlass::gemm::MainloopIntelXeXMX16<PipelineStages>;
418+
using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeXMX16;
430419

431420
// Linear Combination + Row-wise Softmax Epilogue
432421
using EpilogueOp = cutlass::epilogue::fusion::LinCombSoftmaxRow<ElementOutput,
433-
ElementComputeEpilogue, XE_STORE_2D<32, 8, 16>/*XE_2D_U32x8x16_ST_N*/, ElementAccumulator, ElementAccumulator, cutlass::FloatRoundStyle::round_to_nearest>;
422+
ElementComputeEpilogue, XE_2D_U32x8x16_ST_N, ElementAccumulator, ElementAccumulator, cutlass::FloatRoundStyle::round_to_nearest>;
434423

435-
using FusionCallbacks = cutlass::epilogue::fusion::FusionCallbacks<EpilogueDispatchPolicy, EpilogueOp, TileShape,
424+
using FusionCallBacks = cutlass::epilogue::fusion::FusionCallbacks<EpilogueDispatchPolicy, EpilogueOp, TileShape,
436425
EpilogueTile>;
437426
using CollectiveEpilogue = cutlass::epilogue::collective::CollectiveEpilogue<
438427
EpilogueDispatchPolicy,
439-
TiledMma,
440-
void,
428+
TileShape,
441429
ElementAccumulator,
442430
cutlass::gemm::TagToStrideC_t<LayoutC>,
443431
ElementOutput,
444432
cutlass::gemm::TagToStrideC_t<LayoutD>,
445-
FusionCallbacks,
446-
<<<<<<< HEAD
447-
//XE_2D_U32x8x16_LD_N,
448-
XE_STORE_2D<32, 8 ,16>,
449-
=======
450-
>>>>>>> afa071e0 (epilogue test)
433+
FusionCallBacks,
434+
XE_2D_U32x8x16_LD_N,
435+
void, void,
451436
void,
452-
void>;
437+
void, void>;
453438

454439
// Mainloop
455440
using CollectiveMainloop = cutlass::gemm::collective::CollectiveMma<

examples/05_bmg_gemm_with_epilogues/05_bmg_gemm_with_epilogue_splitk.cpp

Lines changed: 6 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -434,27 +434,23 @@ int main(int argc, const char** argv)
434434
using LayoutC = cutlass::layout::RowMajor;
435435
using LayoutD = cutlass::layout::RowMajor;
436436

437-
// using GmemTiledCopyA = XE_2D_U16x8x16_LD_N;
438-
using GmemTiledCopyA =void;
439-
// using GmemTiledCopyB = XE_2D_U16x16x16_LD_V;
440-
using GmemTiledCopyB = void;
437+
using GmemTiledCopyA = XE_2D_U16x8x16_LD_N;
438+
using GmemTiledCopyB = XE_2D_U16x16x16_LD_V;
441439

442440
// Workgroup-level tile
443441
using TileShape = Shape<_32, _512, _32>;
444442

445-
// using TiledMma =
446-
// typename TiledMMAHelper<MMA_Atom<XE_8x16x16_F32BF16BF16F32_TT>, Layout<TileShape>,
447-
// Layout<Shape<_2, _16, _1>, Stride<_16, _1, _0>>>::TiledMMA;
448-
449-
using TiledMma = typename TiledMMAHelper<MMA_Atom<XE_DPAS_TT<8, float, cute::bfloat16_t>>, Layout<TileShape>, Layout<Shape<_2, _16, _1>, Stride<_16, _1, _0>>>::TiledMMA;
443+
using TiledMma =
444+
typename TiledMMAHelper<MMA_Atom<XE_8x16x16_F32BF16BF16F32_TT>, Layout<TileShape>,
445+
Layout<Shape<_2, _16, _1>, Stride<_16, _1, _0>>>::TiledMMA;
450446

451447
using EpilogueTile = Shape<_16, _32>;
452448
constexpr int PipelineStages = 3;
453449
using GEMMDispatchPolicy = cutlass::gemm::MainloopXeL1Staged<PipelineStages>;
454450
using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeGeneric;
455451

456452
using EpilogueOp = cutlass::epilogue::fusion::LinCombSplitK<ElementOutput,
457-
ElementComputeEpilogue, XE_STORE_2D<32, 8, 16>/*XE_2D_U32x8x16_ST_N*/, ElementAccumulator, ElementAccumulator, cutlass::FloatRoundStyle::round_to_nearest>;
453+
ElementComputeEpilogue, XE_2D_U32x8x16_ST_N, ElementAccumulator, ElementAccumulator, cutlass::FloatRoundStyle::round_to_nearest>;
458454

459455
using FusionCallBacks = cutlass::epilogue::fusion::FusionCallbacks<EpilogueDispatchPolicy, EpilogueOp, TileShape,
460456
EpilogueTile>;
@@ -467,10 +463,6 @@ int main(int argc, const char** argv)
467463
ElementOutput,
468464
cutlass::gemm::TagToStrideC_t<LayoutD>,
469465
FusionCallBacks,
470-
<<<<<<< HEAD
471-
XE_2D_U32x8x16_LD_N,
472-
=======
473-
>>>>>>> afa071e0 (epilogue test)
474466
void,
475467
void>;
476468

examples/05_bmg_gemm_with_epilogues/05_bmg_gemm_with_per_row_bias.cpp

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -151,12 +151,13 @@ struct ExampleRunner {
151151

152152
using ElementA = typename Gemm::ElementA;
153153
using ElementB = typename Gemm::ElementB;
154-
using ElementAccumulator = typename Gemm::ElementAccumulator;
154+
using ElementAcc = typename Gemm::ElementAccumulator;
155155

156156
using CollectiveEpilogue = typename Gemm::CollectiveEpilogue;
157157
using ElementC = typename Gemm::ElementC;
158158
using ElementOutput = typename CollectiveEpilogue::ElementOutput;
159159
using ElementCompute = typename CollectiveEpilogue::ElementCompute;
160+
using ElementAccumulator = typename CollectiveEpilogue::ElementAccumulator;
160161
using ElementBias = typename CollectiveEpilogue::ThreadEpilogueOp::ElementBias;
161162
using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape;
162163

@@ -210,7 +211,7 @@ struct ExampleRunner {
210211
compat::wait();
211212

212213
for(int batch = 0, offset = 0; batch < L; batch++, offset += M * N) {
213-
auto D_view =
214+
auto D_view =
214215
cutlass::TensorView(
215216
block_ref_D.get() + offset, LayoutD::packed({M, N}), cutlass::make_Coord(M, N));
216217

@@ -368,32 +369,34 @@ int main(int argc, const char** argv)
368369
using LayoutC = cutlass::layout::RowMajor;
369370
using LayoutD = cutlass::layout::RowMajor;
370371

371-
using GmemTiledCopyA = void;
372-
using GmemTiledCopyB = void;
372+
using GmemTiledCopyA = XE_2D_U16x32x32_LD_N;
373+
using GmemTiledCopyB = XE_2D_U16x32x32_LD_V;
373374

374375
// Workgroup-level tile
375376
using TileShape = Shape<_256, _256, _32>;
376377

377-
using TiledMma = typename TiledMMAHelper<MMA_Atom<XE_DPAS_TT<8, float, cute::bfloat16_t>>, Layout<TileShape>, Layout<Shape<_8, _4, _1>, Stride<_4, _1, _0>>>::TiledMMA;
378+
using TiledMma =
379+
typename TiledMMAHelper<MMA_Atom<XE_8x16x16_F32BF16BF16F32_TT>, Layout<TileShape>,
380+
Layout<Shape<_8, _4, _1>, Stride<_4, _1, _0>>>::TiledMMA;
378381

379382
constexpr int PipelineStages = 2;
380-
using GEMMDispatchPolicy = cutlass::gemm::MainloopXeL1Staged<PipelineStages>;
381-
using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeGeneric;
383+
using GEMMDispatchPolicy = cutlass::gemm::MainloopIntelXeXMX16<PipelineStages>;
384+
using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeXMX16;
382385

383386
// The Linear Combination + Per Row Bias epilogue operation
384387
using EpilogueOp = cutlass::epilogue::fusion::LinCombPerRowBias<
385388
ElementOutput, ElementComputeEpilogue, ElementBias, ElementAccumulator,
386389
ElementAccumulator, 128 / sizeof_bits_v<ElementBias>,
387390
cutlass::FloatRoundStyle::round_to_nearest>;
388391

389-
using FusionCallbacks = cutlass::epilogue::fusion::FusionCallbacks<
392+
using FusionCallBacks = cutlass::epilogue::fusion::FusionCallbacks<
390393
EpilogueDispatchPolicy, EpilogueOp, TileShape,
391394
decltype(tile_shape(TiledMma()))>;
392395
using CollectiveEpilogue = cutlass::epilogue::collective::CollectiveEpilogue<
393-
EpilogueDispatchPolicy, TiledMma, void, ElementAccumulator,
396+
EpilogueDispatchPolicy, TileShape, ElementAccumulator,
394397
cutlass::gemm::TagToStrideC_t<LayoutC>, ElementOutput,
395-
cutlass::gemm::TagToStrideC_t<LayoutD>, FusionCallbacks,
396-
void, void>;
398+
cutlass::gemm::TagToStrideC_t<LayoutD>, FusionCallBacks,
399+
XE_2D_U32x8x16_LD_N, void, void, XE_2D_U32x8x16_ST_N, void, void>;
397400

398401
// Mainloop
399402
using CollectiveMainloop = cutlass::gemm::collective::CollectiveMma<

0 commit comments

Comments
 (0)