From 49ea395f325b1f36ac156e6ac3430335c6cae2c5 Mon Sep 17 00:00:00 2001 From: ThruptiRajLakshmanaGowda Date: Thu, 15 Jan 2026 17:39:04 +0000 Subject: [PATCH] Fixing GEMM Multi D on Tile Engine --- tile_engine/ops/gemm/gemm_instance_builder.py | 346 +++++++++--------- 1 file changed, 174 insertions(+), 172 deletions(-) diff --git a/tile_engine/ops/gemm/gemm_instance_builder.py b/tile_engine/ops/gemm/gemm_instance_builder.py index 9c60c565de..3607bbc59a 100644 --- a/tile_engine/ops/gemm/gemm_instance_builder.py +++ b/tile_engine/ops/gemm/gemm_instance_builder.py @@ -676,36 +676,38 @@ def populate_launch( if self.kernel_name_prefix == "gemm_multi_d": instance_code += """ - // Kernel type - using GemmKernelMultiD = ck_tile::GemmKernelMultiD; - - // Kernel arguments - auto kargs = GemmKernelMultiD::MakeKernelArgs(args); - - if (!GemmKernelMultiD::IsSupportedArgument(kargs)) { - throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!"); - } + // Kernel type + using GemmKernelMultiD = ck_tile::GemmKernelMultiD; + + // Kernel arguments + auto kargs = GemmKernelMultiD::MakeKernelArgs(args); + + if (!GemmKernelMultiD::IsSupportedArgument(kargs)) { + throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!"); + } - // Get grid and block sizes - const dim3 grids = GemmKernelMultiD::GridSize(args.M, args.N, args.k_batch); - const dim3 blocks = GemmKernelMultiD::BlockSize(); - - if(stream.log_level_ > 0) { - std::cout << "Launching kernel with args: " << GemmKernelMultiD::GetName() << '\\n' - << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" - << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" - << std::endl; - }""" + // Get grid and block sizes + const dim3 grids = GemmKernelMultiD::GridSize(args.M, args.N, args.k_batch); + const dim3 blocks = GemmKernelMultiD::BlockSize(); + + if(stream.log_level_ > 0) { + std::cout << "Launching kernel with args: " << GemmKernelMultiD::GetName() << '\\n' + << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" + << std::endl; + }""" instance_code += f""" - // Launch kernel - constexpr int kBlockPerCu = {k_block_per_cu}; - float ave_time = ck_tile::launch_kernel( - stream, - ck_tile::make_kernel(GemmKernelMultiD{{}}, grids, blocks, 0, kargs)); - - return ave_time; - }};""" + // Launch kernel + constexpr int kBlockPerCu = {k_block_per_cu}; + float ave_time = ck_tile::launch_kernel( + stream, + ck_tile::make_kernel(GemmKernelMultiD{{}}, grids, blocks, 0, kargs)); + + return ave_time; + }} +}}; +""" elif self.kernel_name_prefix in ["gemm_universal", "gemm_preshuffle"]: instance_code += f""" @@ -713,32 +715,32 @@ def populate_launch( // Kernel type using GemmKernel = ck_tile::GemmKernel; - // Kernel arguments - auto kargs = GemmKernel::MakeKernelArgs(args); - - if (!GemmKernel::IsSupportedArgument(kargs)) {{ - throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!"); - }} - - // Get grid and block sizes - const dim3 grids = {"GemmKernel::MaxOccupancyGridSize(stream)" if persistent in [True, "true"] else "GemmKernel::GridSize(args.M, args.N, args.k_batch)"}; - const dim3 blocks = GemmKernel::BlockSize(); - - if(stream.log_level_ > 0) {{ - std::cout << "Launching kernel with args: " << GemmKernel::GetName() << '\\n' - << "grid: {{" << grids.x << ", " << grids.y << ", " << grids.z << "}}" - << ", blocks: {{" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}}" - << std::endl; - }}""" + // Kernel arguments + auto kargs = GemmKernel::MakeKernelArgs(args); + + if (!GemmKernel::IsSupportedArgument(kargs)) {{ + throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!"); + }} + + // Get grid and block sizes + const dim3 grids = {"GemmKernel::MaxOccupancyGridSize(stream)" if persistent in [True, "true"] else "GemmKernel::GridSize(args.M, args.N, args.k_batch)"}; + const dim3 blocks = GemmKernel::BlockSize(); + + if(stream.log_level_ > 0) {{ + std::cout << "Launching kernel with args: " << GemmKernel::GetName() << '\\n' + << "grid: {{" << grids.x << ", " << grids.y << ", " << grids.z << "}}" + << ", blocks: {{" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}}" + << std::endl; + }}""" instance_code += f""" - // Launch kernel - constexpr int kBlockPerCu = {k_block_per_cu}; - float ave_time = ck_tile::launch_kernel( - stream, - ck_tile::make_kernel(GemmKernel{{}}, grids, blocks, 0, kargs)); - - return ave_time; + // Launch kernel + constexpr int kBlockPerCu = {k_block_per_cu}; + float ave_time = ck_tile::launch_kernel( + stream, + ck_tile::make_kernel(GemmKernel{{}}, grids, blocks, 0, kargs)); + + return ave_time; }} }}; """ @@ -747,8 +749,8 @@ def populate_launch( def populate_epilogue(self, epilogue): instance_code = """ - // Epilogue - """ + // Epilogue + """ if epilogue == "cshuffle": if self.kernel_name_prefix == "gemm_universal": @@ -769,145 +771,145 @@ def populate_epilogue(self, epilogue): def populate_cshuffle_gemm_universal(self): instance_code = """ - using EpilogueProblem = ck_tile::CShuffleEpilogueProblem< - ADataType, - BDataType, - ck_tile::tuple<>, // DsDataType - AccDataType, - CDataType, - ck_tile::tuple<>, // DsLayout - CLayout, - ck_tile::element_wise::PassThrough, - TileM, // kM_ - TileN, // kN_ - WarpPerBlock_M, // MWave_ - WarpPerBlock_N, // NWave_ - WarpTileM, // MPerXdl_ - WarpTileN, // NPerXdl_ - WarpTileK, // KPerXdl_ - TransposeC, // isCTransposed_ - NumWaveGroups>; // kNumWaveGroups_ - - using GemmEpilogue = ck_tile::CShuffleEpilogue;""" + using EpilogueProblem = ck_tile::CShuffleEpilogueProblem< + ADataType, + BDataType, + ck_tile::tuple<>, // DsDataType + AccDataType, + CDataType, + ck_tile::tuple<>, // DsLayout + CLayout, + ck_tile::element_wise::PassThrough, + TileM, // kM_ + TileN, // kN_ + WarpPerBlock_M, // MWave_ + WarpPerBlock_N, // NWave_ + WarpTileM, // MPerXdl_ + WarpTileN, // NPerXdl_ + WarpTileK, // KPerXdl_ + TransposeC, // isCTransposed_ + NumWaveGroups>; // kNumWaveGroups_ + + using GemmEpilogue = ck_tile::CShuffleEpilogue;""" return instance_code def populate_cshuffle_gemm_multi_d(self): instance_code = """ - using EpilogueProblem = ck_tile::CShuffleEpilogueProblem< - ADataType, - BDataType, - DsDataType, - AccDataType, - CDataType, - DsLayout, - CLayout, - ElementWiseFn, - TileM, // kM_ - TileN, // kN_ - WarpPerBlock_M, // MWave_ - WarpPerBlock_N, // NWave_ - WarpTileM, // MPerXdl_ - WarpTileN, // NPerXdl_ - WarpTileK, // KPerXdl_ - TransposeC>; // isCTransposed_ - - using GemmEpilogue = ck_tile::CShuffleEpilogue;""" + using EpilogueProblem = ck_tile::CShuffleEpilogueProblem< + ADataType, + BDataType, + DsDataType, + AccDataType, + CDataType, + DsLayout, + CLayout, + ElementWiseFn, + TileM, // kM_ + TileN, // kN_ + WarpPerBlock_M, // MWave_ + WarpPerBlock_N, // NWave_ + WarpTileM, // MPerXdl_ + WarpTileN, // NPerXdl_ + WarpTileK, // KPerXdl_ + TransposeC>; // isCTransposed_ + + using GemmEpilogue = ck_tile::CShuffleEpilogue;""" return instance_code def populate_cshuffle_gemm_preshuffle(self): instance_code = """ - using EpilogueProblem = ck_tile::CShuffleEpilogueProblem< - ADataType, - BDataType, - ck_tile::tuple<>, // DsDataType - AccDataType, - CDataType, - ck_tile::tuple<>, // DsLayout - CLayout, - ck_tile::element_wise::PassThrough, - TileM, // kM_ - TileN, // kN_ - WarpPerBlock_M, // MWave_ - WarpPerBlock_N, // NWave_ - WarpTileM, // MPerXdl_ - WarpTileN, // NPerXdl_ - WarpTileK, // KPerXdl_ - TransposeC, // isCTransposed_ - NumWaveGroups, // kNumWaveGroups_ - false, // FixedVectorSize_ - 1, // VectorSizeC_ - PermuteN>; // isPermuteN_ - - using GemmEpilogue = ck_tile::CShuffleEpilogue;""" + using EpilogueProblem = ck_tile::CShuffleEpilogueProblem< + ADataType, + BDataType, + ck_tile::tuple<>, // DsDataType + AccDataType, + CDataType, + ck_tile::tuple<>, // DsLayout + CLayout, + ck_tile::element_wise::PassThrough, + TileM, // kM_ + TileN, // kN_ + WarpPerBlock_M, // MWave_ + WarpPerBlock_N, // NWave_ + WarpTileM, // MPerXdl_ + WarpTileN, // NPerXdl_ + WarpTileK, // KPerXdl_ + TransposeC, // isCTransposed_ + NumWaveGroups, // kNumWaveGroups_ + false, // FixedVectorSize_ + 1, // VectorSizeC_ + PermuteN>; // isPermuteN_ + + using GemmEpilogue = ck_tile::CShuffleEpilogue;""" return instance_code def populate_default_gemm_universal(self): instance_code = """ - using EpilogueProblem = ck_tile::DefaultGemm2DEpilogueProblem< - ADataType, - BDataType, - ck_tile::tuple<>, // DsDataType - AccDataType, - CDataType, - ck_tile::tuple<>, // DsLayout - CLayout, - ck_tile::element_wise::PassThrough, - TileM, // kM_ - TileN, // kN_ - kPadM, - kPadN, - WarpTileM, // kMPerXdl_ - WarpTileN, // kNPerXdl_ - WarpTileK, // kKPerXdl_ - TransposeC>; // isCTransposed_ - - using GemmEpilogue = ck_tile::DefaultGemm2DEpilogue;""" + using EpilogueProblem = ck_tile::DefaultGemm2DEpilogueProblem< + ADataType, + BDataType, + ck_tile::tuple<>, // DsDataType + AccDataType, + CDataType, + ck_tile::tuple<>, // DsLayout + CLayout, + ck_tile::element_wise::PassThrough, + TileM, // kM_ + TileN, // kN_ + kPadM, + kPadN, + WarpTileM, // kMPerXdl_ + WarpTileN, // kNPerXdl_ + WarpTileK, // kKPerXdl_ + TransposeC>; // isCTransposed_ + + using GemmEpilogue = ck_tile::DefaultGemm2DEpilogue;""" return instance_code def populate_default_gemm_multi_d(self): instance_code = """ - using EpilogueProblem = ck_tile::DefaultGemm2DEpilogueProblem< - ADataType, - BDataType, - DsDataType, - AccDataType, - CDataType, - DsLayout, - CLayout, - ElementWiseFn, - TileM, // kM_ - TileN, // kN_ - kPadM, - kPadN, - WarpTileM, // kMPerXdl_ - WarpTileN, // kNPerXdl_ - WarpTileK, // kKPerXdl_ - TransposeC>; // isCTransposed_ - - using GemmEpilogue = ck_tile::DefaultGemm2DEpilogue;""" + using EpilogueProblem = ck_tile::DefaultGemm2DEpilogueProblem< + ADataType, + BDataType, + DsDataType, + AccDataType, + CDataType, + DsLayout, + CLayout, + ElementWiseFn, + TileM, // kM_ + TileN, // kN_ + kPadM, + kPadN, + WarpTileM, // kMPerXdl_ + WarpTileN, // kNPerXdl_ + WarpTileK, // kKPerXdl_ + TransposeC>; // isCTransposed_ + + using GemmEpilogue = ck_tile::DefaultGemm2DEpilogue;""" return instance_code def populate_default_gemm_preshuffle(self): instance_code = """ - using EpilogueProblem = ck_tile::DefaultGemm2DEpilogueProblem< - ADataType, - BDataType, - ck_tile::tuple<>, // DsDataType - AccDataType, - CDataType, - ck_tile::tuple<>, // DsLayout - CLayout, - ck_tile::element_wise::PassThrough, - TileM, // kM_ - TileN, // kN_ - kPadM, - kPadN, - WarpTileM, // kMPerXdl_ - WarpTileN, // kNPerXdl_ - WarpTileK, // kKPerXdl_ - TransposeC>; // isCTransposed_ - - using GemmEpilogue = ck_tile::DefaultGemm2DEpilogue;""" + using EpilogueProblem = ck_tile::DefaultGemm2DEpilogueProblem< + ADataType, + BDataType, + ck_tile::tuple<>, // DsDataType + AccDataType, + CDataType, + ck_tile::tuple<>, // DsLayout + CLayout, + ck_tile::element_wise::PassThrough, + TileM, // kM_ + TileN, // kN_ + kPadM, + kPadN, + WarpTileM, // kMPerXdl_ + WarpTileN, // kNPerXdl_ + WarpTileK, // kKPerXdl_ + TransposeC>; // isCTransposed_ + + using GemmEpilogue = ck_tile::DefaultGemm2DEpilogue;""" return instance_code def _generate_cmake_individual_targets(self, kernel_list):