Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 25 additions & 22 deletions examples/00_bmg_gemm/00_bmg_gemm_padded.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@

This example makes use of BMGs subgroup cooperative 2d-block copy operations and DPAS instructions.
To support more input shapes using these instructions, rows of the input/output matrices are padded
to a multiple of 16 and each matrix in batch is padded to a multiple of 64, as required by these
to a multiple of 16 and each matrix in batch is padded to a multiple of 64, as required by these
instructions.

The shapes of the A and B matrix are defined at runtime by `options.m`, `.n` and `.k`, and the
Expand Down Expand Up @@ -161,14 +161,14 @@ struct ExampleRunner {

using ElementA = typename Gemm::ElementA;
using ElementB = typename Gemm::ElementB;
using ElementAcc = typename Gemm::ElementAccumulator;
using ElementAccumulator = typename Gemm::ElementAccumulator;

using CollectiveEpilogue = typename Gemm::CollectiveEpilogue;
using ElementC = typename Gemm::ElementC;
using ElementD = typename Gemm::ElementD;
using ElementOutput = typename CollectiveEpilogue::ElementOutput;
using ElementCompute = typename CollectiveEpilogue::ElementCompute;
using ElementAccumulator = typename CollectiveEpilogue::ElementAccumulator;


using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape;

Expand Down Expand Up @@ -200,15 +200,15 @@ struct ExampleRunner {

bool verify(const ProblemShapeType& problem_size, ElementCompute alpha, ElementCompute beta) {
auto [M, N, K, L] = problem_size;

// Padded values
// The inner dimension is padded. Since this example is all RowMajor,
// we require the following:
int N_B = cute::round_up(N, AlignElemB);
int N_C = cute::round_up(N, AlignElemC);
int N_D = cute::round_up(N, AlignElemD);
int K_A = cute::round_up(K, AlignElemA);

int AlignmentOuter = AlignmentPtr / AlignmentInner;
int M_ACD = cute::round_up(M, AlignmentOuter);
int K_B = cute::round_up(K, AlignmentOuter);
Expand Down Expand Up @@ -383,31 +383,35 @@ int main(int argc, const char** argv)
using LayoutC = cutlass::layout::RowMajor;
using LayoutD = cutlass::layout::RowMajor;

// The 2D block copy operations used for the A and B matrices
using GmemTiledCopyA = XE_2D_U16x32x32_LD_N;
using GmemTiledCopyB = XE_2D_U16x32x32_LD_V;
// [New Copy Atom] When left unspecified (void), MainloopXeL1Staged automatically selects
// appropriate 2D block copy operations for matrices A and B. Alternatively, you can
// explicitly specify new copy atom operations such as XE_LOAD_2D, XE_LOAD_2D_VNNI,
// or XE_LOAD_2D_TRANSPOSE.
// Refer https://github.com/intel/sycl-tla/blob/main/media/docs/cpp/xe_rearchitecture.md
using GmemTiledCopyA = void; //XE_LOAD_2D<16, 32, 32>;
using GmemTiledCopyB = void; //XE_LOAD_2D_VNNI<16, 32, 32>;

// Workgroup-level tile
using TileShape = Shape<_256, _256, _32>;

// A TiledMMA struct defines a tiling of an MMA atom over M, N and K, combining both additional
// hardware (sub-groups for Intel BMG) and iterations by each sub-group.
//
// The TiledMMAHelper struct defines a specific TiledMMA for a given MMA atom
// (XE_8x16x16_F32BF16BF16F32_TT), TileShape (<256, 256, 32>) and sub-group layout (8x4x1). The
// TiledMMA constructed using TiledMMAHelper has the property that each sub-group operates on a
// The TiledMMAHelper struct defines a specific TiledMMA for a given MMA atom. This example uses
// the XE_DPAS_TT<8, float, cute::bfloat16_t> atom, which represents an 8x16x16 DPAS operation with
//float32 accumulation and bfloat16 inputs, TileShape (<256, 256, 32>) and sub-group layout (8x4x1).
// The TiledMMA constructed using TiledMMAHelper has the property that each sub-group operates on a
// single contiguous chunk of the work-group TileShape. For this configuration, this implies that
// each sub-group operates on a contiguous 32x64x32 chunk (4x4x2 iterations). See
// 0t_mma_atom.md#TiledMMAs for more info. Sub-groups are arranged row-major (stride 4,1,0) for
// performance reasons.
using TiledMma = // M=8,N=16,K=16, D=f32,A=bf16,B=bf16,C=f32
typename TiledMMAHelper<MMA_Atom<XE_8x16x16_F32BF16BF16F32_TT>, Layout<TileShape>,
Layout<Shape<_8, _4, _1>, Stride<_4, _1, _0>>>::TiledMMA;
using TiledMma = typename TiledMMAHelper<MMA_Atom<XE_DPAS_TT<8, float, cute::bfloat16_t>>, Layout<TileShape>, Layout<Shape<_8, _4, _1>, Stride<_4, _1, _0>>>::TiledMMA;

// For Intel BMG, PipelineStages defines how many k-blocks ahead to prefetch from A and B.
constexpr int PipelineStages = 2;
using GEMMDispatchPolicy = cutlass::gemm::MainloopIntelXeXMX16<PipelineStages>;
using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeXMX16;
// For older version of copy/mma atom, use cutlass::gemm::MainloopIntelXeXMX16 as dispatch policy
using GEMMDispatchPolicy = cutlass::gemm::MainloopXeL1Staged<PipelineStages>;
using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeGeneric;

// This is the 'default' epilogue operation (Linear Combination) which performs everything in:
// (D = alpha * (A*B) + beta * C)
Expand All @@ -418,22 +422,21 @@ int main(int argc, const char** argv)

// FusionCallbacks ties the EpilogueOp to an implementation (based on the dispatch
// policy/architecture) and defines the epilogue arguments.
using FusionCallBacks = cutlass::epilogue::fusion::FusionCallbacks<EpilogueDispatchPolicy, EpilogueOp, TileShape,
using FusionCallbacks = cutlass::epilogue::fusion::FusionCallbacks<EpilogueDispatchPolicy, EpilogueOp, TileShape,
decltype(tile_shape(TiledMma()))>;
// GEMM Epilogue - loads & stores C/D matrices, performs epilogue operations & load/stores any
// auxiliary data required
using CollectiveEpilogue = cutlass::epilogue::collective::CollectiveEpilogue<
EpilogueDispatchPolicy,
TileShape,
void, // Epilogue tile (void = automatic)
ElementAccumulator,
cutlass::gemm::TagToStrideC_t<LayoutC>, // Converts CUTLASS 2.x to CUTLASS 3.x representation
ElementOutput,
cutlass::gemm::TagToStrideC_t<LayoutD>, // Converts CUTLASS 2.x to CUTLASS 3.x representation
FusionCallBacks,
XE_2D_U32x8x16_LD_N, // The copy atom used to load matrix C
void, void,
XE_2D_U32x8x16_ST_N, // The copy atom used to store matrix D
void, void>;
FusionCallbacks,
void, // The copy atom used to load matrix C (void = automatic)
void>; // The copy atom used to store matrix D (void = automatic)

// GEMM Mainloop - iteration over blocks in K dimension
using CollectiveMainloop = cutlass::gemm::collective::CollectiveMma<
Expand Down
47 changes: 27 additions & 20 deletions examples/00_bmg_gemm/00_bmg_gemm_with_sycl_queue.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -136,13 +136,12 @@ struct ExampleRunner {

using ElementA = typename Gemm::ElementA;
using ElementB = typename Gemm::ElementB;
using ElementAcc = typename Gemm::ElementAccumulator;
using ElementAccumulator = typename Gemm::ElementAccumulator;

using CollectiveEpilogue = typename Gemm::CollectiveEpilogue;
using ElementC = typename Gemm::ElementC;
using ElementOutput = typename CollectiveEpilogue::ElementOutput;
using ElementCompute = typename CollectiveEpilogue::ElementCompute;
using ElementAccumulator = typename CollectiveEpilogue::ElementAccumulator;

using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape;

Expand Down Expand Up @@ -348,42 +347,50 @@ int main(int argc, const char** argv)
using LayoutC = cutlass::layout::RowMajor;
using LayoutD = cutlass::layout::RowMajor;

using GmemTiledCopyA = XE_2D_U16x32x32_LD_N;
using GmemTiledCopyB = XE_2D_U16x32x32_LD_V;
// [New Copy Atom] When left unspecified (void), MainloopXeL1Staged automatically selects
// appropriate 2D block copy operations for matrices A and B. Alternatively, you can
// explicitly specify new copy atom operations such as XE_LOAD_2D, XE_LOAD_2D_VNNI,
// or XE_LOAD_2D_TRANSPOSE.
// Refer https://github.com/intel/sycl-tla/blob/main/media/docs/cpp/xe_rearchitecture.md
using GmemTiledCopyA = void; //XE_LOAD_2D<16, 32, 32>;
using GmemTiledCopyB = void; //XE_LOAD_2D_VNNI<16, 32, 32>;

// Workgroup-level tile
using TileShape = Shape<_256, _256, _32>;

// The Tile of this layout describes how 8x4x1 sub-groups tile the TileShape of <256, 256, 32>.
// This permutation (which can be thought of as a scatter operation on the default tiling)
// ensures that each sub-group operates on a contiguous 32x64x32 chunk (4x4x2 iterations)
// See 0t_mma_atom.md#TiledMMAs for more info.
// Sub-groups are arranged row-major (stride 4,1,0) for performance reasons.
using TiledMma =
typename TiledMMAHelper<MMA_Atom<XE_8x16x16_F32BF16BF16F32_TT>, Layout<TileShape>,
Layout<Shape<_8, _4, _1>, Stride<_4, _1, _0>>>::TiledMMA;
// A TiledMMA struct defines a tiling of an MMA atom over M, N and K, combining both additional
// hardware (sub-groups for Intel BMG) and iterations by each sub-group.
//
// The TiledMMAHelper struct defines a specific TiledMMA for a given MMA atom. This example uses
// the XE_DPAS_TT<8, float, cute::bfloat16_t> atom, which represents an 8x16x16 DPAS operation with
//float32 accumulation and bfloat16 inputs, TileShape (<256, 256, 32>) and sub-group layout (8x4x1).
// The TiledMMA constructed using TiledMMAHelper has the property that each sub-group operates on a
// single contiguous chunk of the work-group TileShape. For this configuration, this implies that
// each sub-group operates on a contiguous 32x64x32 chunk (4x4x2 iterations). See
// 0t_mma_atom.md#TiledMMAs for more info. Sub-groups are arranged row-major (stride 4,1,0) for
// performance reasons.
using TiledMma = typename TiledMMAHelper<MMA_Atom<XE_DPAS_TT<8, float, cute::bfloat16_t>>, Layout<TileShape>, Layout<Shape<_8, _4, _1>, Stride<_4, _1, _0>>>::TiledMMA;

constexpr int PipelineStages = 2;
using GEMMDispatchPolicy = cutlass::gemm::MainloopIntelXeXMX16<PipelineStages>;
using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeXMX16;
using GEMMDispatchPolicy = cutlass::gemm::MainloopXeL1Staged<PipelineStages>;
using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeGeneric;

using EpilogueOp = cutlass::epilogue::fusion::LinearCombination<ElementOutput, ElementComputeEpilogue,
ElementAccumulator, ElementAccumulator, cutlass::FloatRoundStyle::round_to_nearest>;

using FusionCallBacks = cutlass::epilogue::fusion::FusionCallbacks<EpilogueDispatchPolicy, EpilogueOp, TileShape,
using FusionCallbacks = cutlass::epilogue::fusion::FusionCallbacks<EpilogueDispatchPolicy, EpilogueOp, TileShape,
decltype(tile_shape(TiledMma()))>;
using CollectiveEpilogue = cutlass::epilogue::collective::CollectiveEpilogue<
EpilogueDispatchPolicy,
TileShape,
void,
ElementAccumulator,
cutlass::gemm::TagToStrideC_t<LayoutC>,
ElementOutput,
cutlass::gemm::TagToStrideC_t<LayoutD>,
FusionCallBacks,
XE_2D_U32x8x16_LD_N,
void, void,
XE_2D_U32x8x16_ST_N,
void, void>;
FusionCallbacks,
void,
void>;

// Mainloop
using CollectiveMainloop = cutlass::gemm::collective::CollectiveMma<
Expand Down
Loading
Loading