Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
346 changes: 174 additions & 172 deletions tile_engine/ops/gemm/gemm_instance_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -676,69 +676,71 @@ def populate_launch(
if self.kernel_name_prefix == "gemm_multi_d":
instance_code += """

// Kernel type
using GemmKernelMultiD = ck_tile::GemmKernelMultiD<TilePartitioner, GemmPipeline, GemmEpilogue>;
// Kernel arguments
auto kargs = GemmKernelMultiD::MakeKernelArgs(args);
if (!GemmKernelMultiD::IsSupportedArgument(kargs)) {
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!");
}
// Kernel type
using GemmKernelMultiD = ck_tile::GemmKernelMultiD<TilePartitioner, GemmPipeline, GemmEpilogue>;

// Kernel arguments
auto kargs = GemmKernelMultiD::MakeKernelArgs(args);

if (!GemmKernelMultiD::IsSupportedArgument(kargs)) {
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!");
}

// Get grid and block sizes
const dim3 grids = GemmKernelMultiD::GridSize(args.M, args.N, args.k_batch);
const dim3 blocks = GemmKernelMultiD::BlockSize();
if(stream.log_level_ > 0) {
std::cout << "Launching kernel with args: " << GemmKernelMultiD::GetName() << '\\n'
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
<< std::endl;
}"""
// Get grid and block sizes
const dim3 grids = GemmKernelMultiD::GridSize(args.M, args.N, args.k_batch);
const dim3 blocks = GemmKernelMultiD::BlockSize();

if(stream.log_level_ > 0) {
std::cout << "Launching kernel with args: " << GemmKernelMultiD::GetName() << '\\n'
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
<< std::endl;
}"""

instance_code += f"""
// Launch kernel
constexpr int kBlockPerCu = {k_block_per_cu};
float ave_time = ck_tile::launch_kernel(
stream,
ck_tile::make_kernel<kBlockPerCu>(GemmKernelMultiD{{}}, grids, blocks, 0, kargs));

return ave_time;
}};"""
// Launch kernel
constexpr int kBlockPerCu = {k_block_per_cu};
float ave_time = ck_tile::launch_kernel(
stream,
ck_tile::make_kernel<kBlockPerCu>(GemmKernelMultiD{{}}, grids, blocks, 0, kargs));

return ave_time;
}}
}};
"""

elif self.kernel_name_prefix in ["gemm_universal", "gemm_preshuffle"]:
instance_code += f"""

// Kernel type
using GemmKernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;

// Kernel arguments
auto kargs = GemmKernel::MakeKernelArgs(args);
if (!GemmKernel::IsSupportedArgument(kargs)) {{
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!");
}}

// Get grid and block sizes
const dim3 grids = {"GemmKernel::MaxOccupancyGridSize(stream)" if persistent in [True, "true"] else "GemmKernel::GridSize(args.M, args.N, args.k_batch)"};
const dim3 blocks = GemmKernel::BlockSize();
if(stream.log_level_ > 0) {{
std::cout << "Launching kernel with args: " << GemmKernel::GetName() << '\\n'
<< "grid: {{" << grids.x << ", " << grids.y << ", " << grids.z << "}}"
<< ", blocks: {{" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}}"
<< std::endl;
}}"""
// Kernel arguments
auto kargs = GemmKernel::MakeKernelArgs(args);

if (!GemmKernel::IsSupportedArgument(kargs)) {{
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!");
}}

// Get grid and block sizes
const dim3 grids = {"GemmKernel::MaxOccupancyGridSize(stream)" if persistent in [True, "true"] else "GemmKernel::GridSize(args.M, args.N, args.k_batch)"};
const dim3 blocks = GemmKernel::BlockSize();

if(stream.log_level_ > 0) {{
std::cout << "Launching kernel with args: " << GemmKernel::GetName() << '\\n'
<< "grid: {{" << grids.x << ", " << grids.y << ", " << grids.z << "}}"
<< ", blocks: {{" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}}"
<< std::endl;
}}"""

instance_code += f"""
// Launch kernel
constexpr int kBlockPerCu = {k_block_per_cu};
float ave_time = ck_tile::launch_kernel(
stream,
ck_tile::make_kernel<kBlockPerCu>(GemmKernel{{}}, grids, blocks, 0, kargs));
return ave_time;
// Launch kernel
constexpr int kBlockPerCu = {k_block_per_cu};
float ave_time = ck_tile::launch_kernel(
stream,
ck_tile::make_kernel<kBlockPerCu>(GemmKernel{{}}, grids, blocks, 0, kargs));

return ave_time;
}}
}};
"""
Expand All @@ -747,8 +749,8 @@ def populate_launch(
def populate_epilogue(self, epilogue):
instance_code = """

// Epilogue
"""
// Epilogue
"""

if epilogue == "cshuffle":
if self.kernel_name_prefix == "gemm_universal":
Expand All @@ -769,145 +771,145 @@ def populate_epilogue(self, epilogue):

def populate_cshuffle_gemm_universal(self):
instance_code = """
using EpilogueProblem = ck_tile::CShuffleEpilogueProblem<
ADataType,
BDataType,
ck_tile::tuple<>, // DsDataType
AccDataType,
CDataType,
ck_tile::tuple<>, // DsLayout
CLayout,
ck_tile::element_wise::PassThrough,
TileM, // kM_
TileN, // kN_
WarpPerBlock_M, // MWave_
WarpPerBlock_N, // NWave_
WarpTileM, // MPerXdl_
WarpTileN, // NPerXdl_
WarpTileK, // KPerXdl_
TransposeC, // isCTransposed_
NumWaveGroups>; // kNumWaveGroups_
using GemmEpilogue = ck_tile::CShuffleEpilogue<EpilogueProblem>;"""
using EpilogueProblem = ck_tile::CShuffleEpilogueProblem<
ADataType,
BDataType,
ck_tile::tuple<>, // DsDataType
AccDataType,
CDataType,
ck_tile::tuple<>, // DsLayout
CLayout,
ck_tile::element_wise::PassThrough,
TileM, // kM_
TileN, // kN_
WarpPerBlock_M, // MWave_
WarpPerBlock_N, // NWave_
WarpTileM, // MPerXdl_
WarpTileN, // NPerXdl_
WarpTileK, // KPerXdl_
TransposeC, // isCTransposed_
NumWaveGroups>; // kNumWaveGroups_

using GemmEpilogue = ck_tile::CShuffleEpilogue<EpilogueProblem>;"""
return instance_code

def populate_cshuffle_gemm_multi_d(self):
instance_code = """
using EpilogueProblem = ck_tile::CShuffleEpilogueProblem<
ADataType,
BDataType,
DsDataType,
AccDataType,
CDataType,
DsLayout,
CLayout,
ElementWiseFn,
TileM, // kM_
TileN, // kN_
WarpPerBlock_M, // MWave_
WarpPerBlock_N, // NWave_
WarpTileM, // MPerXdl_
WarpTileN, // NPerXdl_
WarpTileK, // KPerXdl_
TransposeC>; // isCTransposed_
using GemmEpilogue = ck_tile::CShuffleEpilogue<EpilogueProblem>;"""
using EpilogueProblem = ck_tile::CShuffleEpilogueProblem<
ADataType,
BDataType,
DsDataType,
AccDataType,
CDataType,
DsLayout,
CLayout,
ElementWiseFn,
TileM, // kM_
TileN, // kN_
WarpPerBlock_M, // MWave_
WarpPerBlock_N, // NWave_
WarpTileM, // MPerXdl_
WarpTileN, // NPerXdl_
WarpTileK, // KPerXdl_
TransposeC>; // isCTransposed_

using GemmEpilogue = ck_tile::CShuffleEpilogue<EpilogueProblem>;"""
return instance_code

def populate_cshuffle_gemm_preshuffle(self):
instance_code = """
using EpilogueProblem = ck_tile::CShuffleEpilogueProblem<
ADataType,
BDataType,
ck_tile::tuple<>, // DsDataType
AccDataType,
CDataType,
ck_tile::tuple<>, // DsLayout
CLayout,
ck_tile::element_wise::PassThrough,
TileM, // kM_
TileN, // kN_
WarpPerBlock_M, // MWave_
WarpPerBlock_N, // NWave_
WarpTileM, // MPerXdl_
WarpTileN, // NPerXdl_
WarpTileK, // KPerXdl_
TransposeC, // isCTransposed_
NumWaveGroups, // kNumWaveGroups_
false, // FixedVectorSize_
1, // VectorSizeC_
PermuteN>; // isPermuteN_
using GemmEpilogue = ck_tile::CShuffleEpilogue<EpilogueProblem>;"""
using EpilogueProblem = ck_tile::CShuffleEpilogueProblem<
ADataType,
BDataType,
ck_tile::tuple<>, // DsDataType
AccDataType,
CDataType,
ck_tile::tuple<>, // DsLayout
CLayout,
ck_tile::element_wise::PassThrough,
TileM, // kM_
TileN, // kN_
WarpPerBlock_M, // MWave_
WarpPerBlock_N, // NWave_
WarpTileM, // MPerXdl_
WarpTileN, // NPerXdl_
WarpTileK, // KPerXdl_
TransposeC, // isCTransposed_
NumWaveGroups, // kNumWaveGroups_
false, // FixedVectorSize_
1, // VectorSizeC_
PermuteN>; // isPermuteN_

using GemmEpilogue = ck_tile::CShuffleEpilogue<EpilogueProblem>;"""
return instance_code

def populate_default_gemm_universal(self):
instance_code = """
using EpilogueProblem = ck_tile::DefaultGemm2DEpilogueProblem<
ADataType,
BDataType,
ck_tile::tuple<>, // DsDataType
AccDataType,
CDataType,
ck_tile::tuple<>, // DsLayout
CLayout,
ck_tile::element_wise::PassThrough,
TileM, // kM_
TileN, // kN_
kPadM,
kPadN,
WarpTileM, // kMPerXdl_
WarpTileN, // kNPerXdl_
WarpTileK, // kKPerXdl_
TransposeC>; // isCTransposed_
using GemmEpilogue = ck_tile::DefaultGemm2DEpilogue<EpilogueProblem>;"""
using EpilogueProblem = ck_tile::DefaultGemm2DEpilogueProblem<
ADataType,
BDataType,
ck_tile::tuple<>, // DsDataType
AccDataType,
CDataType,
ck_tile::tuple<>, // DsLayout
CLayout,
ck_tile::element_wise::PassThrough,
TileM, // kM_
TileN, // kN_
kPadM,
kPadN,
WarpTileM, // kMPerXdl_
WarpTileN, // kNPerXdl_
WarpTileK, // kKPerXdl_
TransposeC>; // isCTransposed_

using GemmEpilogue = ck_tile::DefaultGemm2DEpilogue<EpilogueProblem>;"""
return instance_code

def populate_default_gemm_multi_d(self):
instance_code = """
using EpilogueProblem = ck_tile::DefaultGemm2DEpilogueProblem<
ADataType,
BDataType,
DsDataType,
AccDataType,
CDataType,
DsLayout,
CLayout,
ElementWiseFn,
TileM, // kM_
TileN, // kN_
kPadM,
kPadN,
WarpTileM, // kMPerXdl_
WarpTileN, // kNPerXdl_
WarpTileK, // kKPerXdl_
TransposeC>; // isCTransposed_
using GemmEpilogue = ck_tile::DefaultGemm2DEpilogue<EpilogueProblem>;"""
using EpilogueProblem = ck_tile::DefaultGemm2DEpilogueProblem<
ADataType,
BDataType,
DsDataType,
AccDataType,
CDataType,
DsLayout,
CLayout,
ElementWiseFn,
TileM, // kM_
TileN, // kN_
kPadM,
kPadN,
WarpTileM, // kMPerXdl_
WarpTileN, // kNPerXdl_
WarpTileK, // kKPerXdl_
TransposeC>; // isCTransposed_

using GemmEpilogue = ck_tile::DefaultGemm2DEpilogue<EpilogueProblem>;"""
return instance_code

def populate_default_gemm_preshuffle(self):
instance_code = """
using EpilogueProblem = ck_tile::DefaultGemm2DEpilogueProblem<
ADataType,
BDataType,
ck_tile::tuple<>, // DsDataType
AccDataType,
CDataType,
ck_tile::tuple<>, // DsLayout
CLayout,
ck_tile::element_wise::PassThrough,
TileM, // kM_
TileN, // kN_
kPadM,
kPadN,
WarpTileM, // kMPerXdl_
WarpTileN, // kNPerXdl_
WarpTileK, // kKPerXdl_
TransposeC>; // isCTransposed_
using GemmEpilogue = ck_tile::DefaultGemm2DEpilogue<EpilogueProblem>;"""
using EpilogueProblem = ck_tile::DefaultGemm2DEpilogueProblem<
ADataType,
BDataType,
ck_tile::tuple<>, // DsDataType
AccDataType,
CDataType,
ck_tile::tuple<>, // DsLayout
CLayout,
ck_tile::element_wise::PassThrough,
TileM, // kM_
TileN, // kN_
kPadM,
kPadN,
WarpTileM, // kMPerXdl_
WarpTileN, // kNPerXdl_
WarpTileK, // kKPerXdl_
TransposeC>; // isCTransposed_

using GemmEpilogue = ck_tile::DefaultGemm2DEpilogue<EpilogueProblem>;"""
return instance_code

def _generate_cmake_individual_targets(self, kernel_list):
Expand Down