diff --git a/example/60_gemm_multi_ABD/gemm_multi_ABD_wmma_bias_fastgelu_bf16_i8.cpp b/example/60_gemm_multi_ABD/gemm_multi_ABD_wmma_bias_fastgelu_bf16_i8.cpp
index cf8dd31c3f3..78d98e92ce0 100644
--- a/example/60_gemm_multi_ABD/gemm_multi_ABD_wmma_bias_fastgelu_bf16_i8.cpp
+++ b/example/60_gemm_multi_ABD/gemm_multi_ABD_wmma_bias_fastgelu_bf16_i8.cpp
@@ -96,11 +96,11 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleABD_Wmm
         8,
         8,
         0,
-        S<8, 32, 1>,
+        S<8, 16, 1>,
         S<0, 2, 1>,
         S<0, 2, 1>,
         1,
-        1,
+        8,
         8,
         0,
         1,
@@ -108,7 +108,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleABD_Wmm
         S<1, 32, 1, 8>,
         S<8, 8, 8>,
         ck::BlockGemmPipelineScheduler::Intrawave,
-        ck::BlockGemmPipelineVersion::v3>;
+        ck::BlockGemmPipelineVersion::v1>;
 
 int main(int argc, char* argv[])
 {
@@ -174,6 +174,29 @@ int main(int argc, char* argv[])
         }
     };
 
+    auto f_get_default_stride =
+        [](std::size_t row, std::size_t col, ck::index_t stride, auto layout) {
+            if(stride == -1 || stride == 0)
+            {
+                // give a chance if stride is -1, return a default packed stride
+                if constexpr(std::is_same_v<decltype(layout), ck::tensor_layout::gemm::RowMajor>)
+                {
+                    return static_cast<ck::index_t>(col);
+                }
+                else
+                {
+                    return static_cast<ck::index_t>(row);
+                }
+            }
+            else
+                return static_cast<ck::index_t>(stride);
+        };
+
+    StrideA = f_get_default_stride(M, K, StrideA, A0Layout{});
+    StrideB = f_get_default_stride(K, N, StrideB, B0Layout{});
+    StrideD = f_get_default_stride(M, N, StrideD, D0Layout{});
+    StrideE = f_get_default_stride(M, N, StrideE, ELayout{});
+
     Tensor<A0DataType> a0_m_k(f_host_tensor_descriptor(M, K, StrideA, A0Layout{}));
     Tensor<B0DataType> b0_k_n(f_host_tensor_descriptor(K, N, StrideB, B0Layout{}));
     Tensor<B1DataType> b1_k_n(f_host_tensor_descriptor(K, N, StrideB, B1Layout{}));
diff --git a/example/60_gemm_multi_ABD/gemm_multi_ABD_wmma_fastgelu_bf16_i8.cpp b/example/60_gemm_multi_ABD/gemm_multi_ABD_wmma_fastgelu_bf16_i8.cpp
index e4033e5bacf..089404757ac 100644
--- a/example/60_gemm_multi_ABD/gemm_multi_ABD_wmma_fastgelu_bf16_i8.cpp
+++ b/example/60_gemm_multi_ABD/gemm_multi_ABD_wmma_fastgelu_bf16_i8.cpp
@@ -94,11 +94,11 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleABD_Wmm
         8,
         8,
         0,
-        S<8, 32, 1>,
+        S<8, 16, 1>,
         S<0, 2, 1>,
         S<0, 2, 1>,
         1,
-        1,
+        8,
         8,
         0,
         1,
@@ -106,7 +106,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleABD_Wmm
         S<1, 32, 1, 8>,
         S<8, 8, 8>,
         ck::BlockGemmPipelineScheduler::Intrawave,
-        ck::BlockGemmPipelineVersion::v3>;
+        ck::BlockGemmPipelineVersion::v1>;
 
 int main(int argc, char* argv[])
 {
@@ -133,7 +133,7 @@ int main(int argc, char* argv[])
         init_method = std::stoi(argv[2]);
         time_kernel = std::stoi(argv[3]);
     }
-    else if(argc == 11)
+    else if(argc == 10)
     {
         do_verification = std::stoi(argv[1]);
         init_method = std::stoi(argv[2]);
@@ -170,6 +170,28 @@ int main(int argc, char* argv[])
         }
     };
 
+    auto f_get_default_stride =
+        [](std::size_t row, std::size_t col, ck::index_t stride, auto layout) {
+            if(stride == -1 || stride == 0)
+            {
+                // give a chance if stride is -1, return a default packed stride
+                if constexpr(std::is_same_v<decltype(layout), ck::tensor_layout::gemm::RowMajor>)
+                {
+                    return static_cast<ck::index_t>(col);
+                }
+                else
+                {
+                    return static_cast<ck::index_t>(row);
+                }
+            }
+            else
+                return static_cast<ck::index_t>(stride);
+        };
+
+    StrideA = f_get_default_stride(M, K, StrideA, A0Layout{});
+    StrideB = f_get_default_stride(K, N, StrideB, B0Layout{});
+    StrideE = f_get_default_stride(M, N, StrideE, ELayout{});
+
     Tensor<A0DataType> a0_m_k(f_host_tensor_descriptor(M, K, StrideA, A0Layout{}));
     Tensor<B0DataType> b0_k_n(f_host_tensor_descriptor(K, N, StrideB, B0Layout{}));
     Tensor<B1DataType> b1_k_n(f_host_tensor_descriptor(K, N, StrideB, B1Layout{}));
diff
--git a/example/60_gemm_multi_ABD/gemm_multi_ABD_wmma_fp16.cpp b/example/60_gemm_multi_ABD/gemm_multi_ABD_wmma_fp16.cpp index 5817269fdfe..d5ccf7eb598 100644 --- a/example/60_gemm_multi_ABD/gemm_multi_ABD_wmma_fp16.cpp +++ b/example/60_gemm_multi_ABD/gemm_multi_ABD_wmma_fp16.cpp @@ -141,11 +141,11 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleABD_Wmm 8, 8, 0, - S<4, 64, 1>, + S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, - 1, + 8, 8, 0, 1, @@ -233,6 +233,29 @@ int main(int argc, char* argv[]) } }; + auto f_get_default_stride = + [](std::size_t row, std::size_t col, ck::index_t stride, auto layout) { + if(stride == -1 || stride == 0) + { + // give a chance if stride is -1, return a default packed stride + if constexpr(std::is_same_v) + { + return static_cast(col); + } + else + { + return static_cast(row); + } + } + else + return static_cast(stride); + }; + + StrideA = f_get_default_stride(M, K, StrideA, ALayout{}); + StrideB = f_get_default_stride(K, N, StrideB, BLayout{}); + StrideD = f_get_default_stride(M, N, StrideD, DLayout{}); + StrideE = f_get_default_stride(M, N, StrideE, ELayout{}); + Tensor a0_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); Tensor a1_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); diff --git a/example/60_gemm_multi_ABD/gemm_multi_ABD_wmma_multiply_bias_fastgelu_bf16_i8.cpp b/example/60_gemm_multi_ABD/gemm_multi_ABD_wmma_multiply_bias_fastgelu_bf16_i8.cpp index 4fb1a5ab4ec..2d07bc480d8 100644 --- a/example/60_gemm_multi_ABD/gemm_multi_ABD_wmma_multiply_bias_fastgelu_bf16_i8.cpp +++ b/example/60_gemm_multi_ABD/gemm_multi_ABD_wmma_multiply_bias_fastgelu_bf16_i8.cpp @@ -95,11 +95,11 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleABD_Wmm 8, 8, 0, - S<8, 32, 1>, + S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, - 1, + 8, 8, 0, 1, @@ -107,7 +107,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleABD_Wmm S<1, 32, 1, 8>, S<8, 8, 8>, ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v3>; + ck::BlockGemmPipelineVersion::v1>; int main(int argc, char* argv[]) { @@ -173,6 +173,29 @@ int main(int argc, char* argv[]) } }; + auto f_get_default_stride = + [](std::size_t row, std::size_t col, ck::index_t stride, auto layout) { + if(stride == -1 || stride == 0) + { + // give a chance if stride is -1, return a default packed stride + if constexpr(std::is_same_v) + { + return static_cast(col); + } + else + { + return static_cast(row); + } + } + else + return static_cast(stride); + }; + + StrideA = f_get_default_stride(M, K, StrideA, A0Layout{}); + StrideB = f_get_default_stride(K, N, StrideB, B0Layout{}); + StrideD = f_get_default_stride(M, N, StrideD, D0Layout{}); + StrideE = f_get_default_stride(M, N, StrideE, ELayout{}); + Tensor a0_m_k(f_host_tensor_descriptor(M, K, StrideA, A0Layout{})); Tensor b0_k_n(f_host_tensor_descriptor(K, N, StrideB, B0Layout{})); Tensor d0_m_n(f_host_tensor_descriptor(M, N, StrideD, D0Layout{})); diff --git a/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_global.hpp b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_global.hpp index 1c322fe4a73..d1c6f30a14b 100644 --- a/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_global.hpp +++ b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_global.hpp @@ -12,16 +12,17 @@ namespace ck { -template + bool DoTranspose, + index_t 
NumThreadScratch = 1> struct ThreadGroupTransferGlobal { static constexpr auto I0 = Number<0>{}; @@ -32,60 +33,63 @@ struct ThreadGroupTransferGlobal static constexpr auto I5 = Number<5>{}; static constexpr auto I6 = Number<6>{}; - static constexpr index_t nDim = remove_reference_t::GetNumOfDimension(); + // return a tuple of coordiantes for a tuple of tensor + template = false> + static constexpr auto MakeCoordinates(const Descs& descs, const Indices& indices) + { + return generate_tuple([&](auto i) { return make_tensor_coordinate(descs[i], indices[i]); }, + Number{}); + } + + static constexpr index_t nDim = + remove_cvref_t>::GetNumOfDimension(); + static constexpr index_t nSrc = SrcDescs::Size(); using Index = MultiIndex; - using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{})); - using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{})); - - __device__ ThreadGroupTransferGlobal(const SrcDesc& src_desc, - const DstDesc& dst_desc, - const Index& src_block_slice_origin, - const Index& dst_block_slice_origin, - const ElementwiseOperation& element_op) - : src_coord_(make_tensor_coordinate(src_desc, src_block_slice_origin)), + using SrcCoords = decltype(MakeCoordinates(SrcDescs{}, StaticallyIndexedArray{})); + using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{})); + + __device__ + ThreadGroupTransferGlobal(const SrcDescs& src_descs, + const DstDesc& dst_desc, + const StaticallyIndexedArray& src_block_slice_origins, + const Index& dst_block_slice_origin, + const ElementwiseOperation& element_op) + : src_coords_(MakeCoordinates(src_descs, src_block_slice_origins)), dst_coord_(make_tensor_coordinate(dst_desc, dst_block_slice_origin)), element_op_(element_op) { } - template - __device__ void RunRead(const SrcDesc& src_desc, const GridBufferType& grid_buf) + template + __device__ static auto generate_vectors() { - constexpr auto src_access_lengths = NumberOfIterations{}; - constexpr auto src_dim_access_order = IterationOrder{}; - constexpr auto ordered_src_access_lengths = - container_reorder_given_new2old(src_access_lengths, src_dim_access_order); - constexpr auto ordered_fwd_step = StepsPerIteration{}; + auto data_types = DataTypes_{}; - // make forward steps - // forward step for each iteration just add 1 - const auto src_forward_steps = generate_tuple( - [&](auto i) { - Index forward_step_idx; - - static_for<0, nDim, 1>{}([&](auto j) { - forward_step_idx(j) = (i.value == j.value) ? ordered_fwd_step[i] : 0; - }); + constexpr index_t num = data_types.Size(); - return make_tensor_coordinate_step(src_desc, forward_step_idx); - }, - Number{}); - - // make backward steps - // backward step at the end of the dimension iteration subtract IterationLength - 1 - const auto src_backward_steps = generate_tuple( + return generate_tuple( [&](auto i) { - Index backward_step_idx; + using DataType = remove_cvref_t; - static_for<0, nDim, 1>{}([&](auto j) { - backward_step_idx(j) = (i.value == j.value) - ? 
(-src_access_lengths[i] + 1) * ordered_fwd_step[i] - : 0; - }); - - return make_tensor_coordinate_step(src_desc, backward_step_idx); + return vector_type_maker_t{}; }, - Number{}); + Number{}); + } + + template = false> + __device__ void RunRead(SrcDescs& src_descs, + const GridBufferTypes& grid_bufs, + Number thread_scratch_id = Number{}) + { + constexpr auto src_access_lengths = NumberOfIterations{}; + constexpr auto src_dim_access_order = IterationOrder{}; + constexpr auto ordered_src_access_lengths = + container_reorder_given_new2old(src_access_lengths, src_dim_access_order); + constexpr auto ordered_fwd_step = StepsPerIteration{}; static_ford{}([&](auto ordered_src_access_idx) { // judge move forward or move backward @@ -157,10 +161,26 @@ struct ThreadGroupTransferGlobal }, Number{}); - // check if src element is valid - const bool is_src_valid = - coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, src_coord_); - oob_thread_scratch_.template SetAsType(vgpr_data_idx_seq, is_src_valid); + auto src_vectors = generate_vectors(); + bool oob_val = true; + + static_for<0, nSrc, 1>{}([&](auto i) { + using src_vector_t = typename remove_cvref_t::type; + // check if src element is valid + const bool is_src_valid = + coordinate_has_valid_offset_assuming_visible_index_is_valid(src_descs[i], + src_coords_[i]); + + oob_val = oob_val & is_src_valid; + + // Load data from memory in src_vector first + auto index = is_src_valid || !DoTranspose ? src_coords_[i].GetOffset() : 0; + src_vectors(i).template AsType()(I0) = + grid_bufs[i].template Get(index, true); + }); + + oob_thread_scratch_(thread_scratch_id) + .template SetAsType(vgpr_data_idx_seq, oob_val); // Vector length of elementwise operation constexpr auto get_elem_op_vec_len = []() { @@ -185,57 +205,105 @@ struct ThreadGroupTransferGlobal } }; - // This is 1 for pass through because internally it's doing type conversion constexpr index_t elem_op_vec_len = get_elem_op_vec_len(); - using src_vector_container = vector_type_maker_t; - using src_vector_container_t = typename src_vector_container::type; - - using elem_op_vec_t = typename vector_type::type; - using dst_vector_type = vector_type_maker_t; using dst_vector_t = typename dst_vector_type::type; - dst_vector_type op_r_v; - // Load data from memory in src_vector first - auto index = is_src_valid || !DoTranspose ? src_coord_.GetOffset() : 0; - src_vector_container src_vector = src_vector_container{ - grid_buf.template Get(index, true)}; - // apply the src elementwise op and convert to DstData under the hood if needed static_for<0, VectorSize / elem_op_vec_len, 1>{}([&](auto idx) { - element_op_(op_r_v.template AsType()(idx), - src_vector.template AsType()[idx]); + // get reference to src data + const auto src_data_refs = generate_tie( + // return type should be lvalue + [&](auto iSrc) -> const auto& { + using SrcData = remove_cvref_t>; + + using elem_op_vec_t = typename vector_type::type; + + return src_vectors[iSrc].template AsType()[idx]; + }, + Number{}); + + // get reference to dst data + auto dst_data_refs = generate_tie( + // return type should be lvalue + [&](auto) -> auto& { + using elem_op_vec_t = typename vector_type::type; + + return op_r_v.template AsType()(idx); + }, + Number<1>{}); + + // apply pointwise function + unpack2(element_op_, dst_data_refs, src_data_refs); }); // store result in dvgpr_ (static array holding loaded data). 
// At this point data is already converted to DstData type and // the elementwise operation has been applied - src_dvgpr_.template SetAsType(vgpr_data_idx_seq, - op_r_v.template AsType()[I0]); - - // For each dimension move fwd, bwd or don't move - static_for<0, nDim, 1>{}([&](auto i) { - if constexpr(move_on_dim[i]) - { - if constexpr(forward_sweep[i]) - { - move_tensor_coordinate( - src_desc, src_coord_, src_forward_steps[src_dim_access_order[i]]); - } - else + src_dvgpr_(thread_scratch_id) + .template SetAsType(vgpr_data_idx_seq, + op_r_v.template AsType()[I0]); + + // Move each src coordinate + static_for<0, nSrc, 1>{}([&](auto iSrc) { + // make forward steps + // forward step for each iteration just add 1 + const auto src_forward_steps = generate_tuple( + [&](auto iDim) { + Index forward_step_idx; + + static_for<0, nDim, 1>{}([&](auto j) { + forward_step_idx(j) = + (iDim.value == j.value) ? ordered_fwd_step[iDim] : 0; + }); + return make_tensor_coordinate_step(src_descs[iSrc], forward_step_idx); + }, + Number{}); + + // make backward steps + // backward step at the end of the dimension iteration subtract IterationLength - 1 + const auto src_backward_steps = generate_tuple( + [&](auto iDim) { + Index backward_step_idx; + + static_for<0, nDim, 1>{}([&](auto j) { + backward_step_idx(j) = + (iDim.value == j.value) + ? (-src_access_lengths[iDim] + 1) * ordered_fwd_step[iDim] + : 0; + }); + return make_tensor_coordinate_step(src_descs[iSrc], backward_step_idx); + }, + Number{}); + + // For each dimension move fwd, bwd or don't move + static_for<0, nDim, 1>{}([&](auto i) { + if constexpr(move_on_dim[i]) { - move_tensor_coordinate( - src_desc, src_coord_, src_backward_steps[src_dim_access_order[i]]); + if constexpr(forward_sweep[i]) + { + move_tensor_coordinate(src_descs[iSrc], + src_coords_(iSrc), + src_forward_steps[src_dim_access_order[i]]); + } + else + { + move_tensor_coordinate(src_descs[iSrc], + src_coords_(iSrc), + src_backward_steps[src_dim_access_order[i]]); + } } - } + }); }); }); } - template - __device__ void RunWrite(const DstDesc& dst_desc, BlockBufferType& dst_buf) + template + __device__ void RunWrite(const DstDesc& dst_desc, + BlockBufferType& dst_buf, + Number thread_scratch_id = Number{}) { using dst_vector_type = vector_type_maker_t; using dst_vector_t = typename dst_vector_type::type; @@ -272,9 +340,10 @@ struct ThreadGroupTransferGlobal }, Number{}); - auto op_r = src_dvgpr_.template GetAsType(vgpr_data_idx_seq); + auto op_r = + src_dvgpr_(thread_scratch_id).template GetAsType(vgpr_data_idx_seq); const bool is_src_valid = - oob_thread_scratch_.template GetAsType(vgpr_data_idx_seq); + oob_thread_scratch_(thread_scratch_id).template GetAsType(vgpr_data_idx_seq); auto op_r_v = is_src_valid ? 
op_r : dst_vector_t(0); dst_dvgpr_.template SetAsType(vgpr_data_idx_seq, op_r_v); }); @@ -404,10 +473,12 @@ struct ThreadGroupTransferGlobal }); } - __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, const Index& step) + __device__ void MoveSrcSliceWindow(const SrcDescs& src_descs, const Index& step) { - const auto adjusted_step = make_tensor_coordinate_step(src_desc, step); - move_tensor_coordinate(src_desc, src_coord_, adjusted_step); + static_for<0, nSrc, 1>{}([&](auto iSrc) { + const auto adjusted_step = make_tensor_coordinate_step(src_descs[iSrc], step); + move_tensor_coordinate(src_descs[iSrc], src_coords_(iSrc), adjusted_step); + }); } private: @@ -443,10 +514,10 @@ struct ThreadGroupTransferGlobal decltype(src_oob_thread_scratch_desc_), true>; - ThreadScratchData src_dvgpr_; + StaticallyIndexedArray src_dvgpr_; ThreadScratchData dst_dvgpr_; - OOBThreadScratch oob_thread_scratch_; - SrcCoord src_coord_; + StaticallyIndexedArray oob_thread_scratch_; + SrcCoords src_coords_; DstCoord dst_coord_; const ElementwiseOperation element_op_; }; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_thread_tiles.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_thread_tiles.hpp index 96387c6f64d..4d5c052e022 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_thread_tiles.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_thread_tiles.hpp @@ -488,6 +488,19 @@ struct ABTransferThreadTiles { return make_dynamic_buffer(p_shared_AB, size); } + + template + __device__ __forceinline__ static auto get_first_element_workaround(Type& array) + { + if constexpr(numElements > 1) + { + return array; + } + else + { + return array[I0]; + } + } }; } // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_thread_tiles_preshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_thread_tiles_preshuffle.hpp index ad9af92ae58..fb6d1451d3d 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_thread_tiles_preshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_thread_tiles_preshuffle.hpp @@ -133,6 +133,19 @@ struct ABTransferThreadTilesPreShuffle { return make_static_buffer(size); } + + template + __device__ __forceinline__ static auto get_first_element_workaround(Type& array) + { + if constexpr(numElements > 1) + { + return array; + } + else + { + return array[I0]; + } + } }; } // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_wave_tiles.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_wave_tiles.hpp index caf468d6cbb..63c02997506 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_wave_tiles.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_wave_tiles.hpp @@ -318,43 +318,43 @@ struct ABTransferWaveTiles const index_t block_mn_id, const index_t) { - // Note: GlobalBufferNum is currently not used but it will be needed - // once we add other pipelines. 
It is currently needed only for - // consistency with the thread tiles approach - static_assert(GlobalBufferNum == 1, "single global buffer is only supported"); constexpr index_t NumABTensor = ABsDataType::Size(); - static_assert(NumABTensor == 1, "multiAB currently not supported"); - - using ABDataType = remove_cvref_t>; const auto wave_idx = GetWaveIdx(); index_t wave_idK = wave_idx[I1]; index_t wave_idMN = wave_idx[I0]; - const auto grid_lane_id = GetGridLaneIdx(); - index_t lane_group_grid = grid_lane_id[I0]; - index_t lane_local_id_grid = grid_lane_id[I1]; - const auto block_lane_id = GetBlockLaneIdx(); index_t lane_group_block = block_lane_id[I0]; index_t lane_local_id_block = block_lane_id[I1]; - return ThreadGroupTransferGlobal>; + const auto grid_lane_id = GetGridLaneIdx(); + index_t lane_group_grid = grid_lane_id[I0]; + index_t lane_local_id_grid = grid_lane_id[I1]; + return make_multi_index(block_mn_id * (MNRepeat_ * MNWaves_) + wave_idMN, + wave_idK, + lane_group_grid, + lane_local_id_grid); + }, + Number{}); + + return ThreadGroupTransferGlobal, Sequence, Sequence, ABK1Value, - ABDoTranspose>( - grid_descriptor[I0], + ABDoTranspose, + GlobalBufferNum>( + grid_descriptor, block_descriptor, - make_multi_index(block_mn_id * (MNRepeat_ * MNWaves_) + wave_idMN, - wave_idK, - lane_group_grid, - lane_local_id_grid), + idx_as_block_begin, make_multi_index(wave_idMN, wave_idK, lane_group_block, lane_local_id_block), ab_element_op); } @@ -398,6 +398,12 @@ struct ABTransferWaveTiles { return make_dynamic_buffer(p_shared_AB, size); } + + template + __device__ __forceinline__ static auto get_first_element_workaround(Type& array) + { + return array; + } }; } // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_wave_tiles_interleave.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_wave_tiles_interleave.hpp index bfe5b7bd08a..e1ee47770ba 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_wave_tiles_interleave.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_wave_tiles_interleave.hpp @@ -218,45 +218,46 @@ struct ABTransferWaveTilesInterleave : ABTransferWaveTiles>; const auto wave_idx = GetWaveIdx(); index_t wave_idK = wave_idx[I1]; index_t wave_idMN = wave_idx[I0]; - const auto grid_lane_id = Base::template GetGridLaneIdx(); - index_t lane_group_grid = grid_lane_id[I0]; - index_t lane_local_id_grid = grid_lane_id[I1]; - const auto block_lane_id = GetBlockLaneIdx(); index_t lane_group_block = block_lane_id[I0]; index_t lane_local_id_block = block_lane_id[I1]; constexpr index_t MNRepeatRatio = MNRepeat_Grid / MNRepeat_; - return ThreadGroupTransferGlobal>; + const auto grid_lane_id = Base::template GetGridLaneIdx(); + index_t lane_group_grid = grid_lane_id[I0]; + index_t lane_local_id_grid = grid_lane_id[I1]; + return make_multi_index(block_mn_id * MNWaves_Grid + wave_idMN / MNRepeatRatio, + wave_idK * KRepeat_Grid, + (wave_idMN % MNRepeatRatio) * MNRepeat_, + lane_group_grid, + lane_local_id_grid); + }, + Number{}); + + return ThreadGroupTransferGlobal, Sequence, Sequence, ABK1Value, - ABDoTranspose>( - grid_descriptor[I0], + ABDoTranspose, + GlobalBufferNum>( + grid_descriptor, block_descriptor, - make_multi_index(block_mn_id * MNWaves_Grid + wave_idMN / MNRepeatRatio, - wave_idK * KRepeat_Grid, - (wave_idMN % MNRepeatRatio) * MNRepeat_, - lane_group_grid, - lane_local_id_grid), + idx_as_block_begin, make_multi_index(wave_idMN / MNRepeatRatio, wave_idK * KRepeat_, (wave_idMN % MNRepeatRatio) * MNRepeat_, 
diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp index 5b91e25ce5e..f1adf6282fa 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp @@ -249,7 +249,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_base __host__ __device__ static constexpr bool AWaveTransferApplicable() { - return !ForceThreadTileTransfer && NumATensor == 1 && APackedSize == 1 && + return !ForceThreadTileTransfer && APackedSize == 1 && ABlockTransferSrcScalarPerVector == 8 && ABlockTransferDstScalarPerVector_AK1 == 8 && BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 && AK1Value == 8 && !IsBPreShuffled; @@ -257,13 +257,11 @@ struct GridwiseGemm_wmma_cshuffle_v3_base __host__ __device__ static constexpr bool BWaveTransferApplicable() { - return !ForceThreadTileTransfer && NumBTensor == 1 && BPackedSize == 1 && + return !ForceThreadTileTransfer && BPackedSize == 1 && BBlockTransferSrcScalarPerVector == 8 && BBlockTransferDstScalarPerVector_BK1 == 8 && BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 && BK1Value == 8; } - // Limitations of the current implementation: - // - no multiAB #ifdef __gfx12__ static constexpr bool IsAWaveTransferApplicable = AWaveTransferApplicable(); @@ -1144,19 +1142,6 @@ struct GridwiseGemm_wmma_cshuffle_v3_base } } - template - __device__ __forceinline__ static auto get_first_element_workaround(Type& array) - { - if constexpr(numElements > 1) - { - return array; - } - else - { - return array[I0]; - } - } - // Note: arguments k_batch and k_id should be set if splitk is used // with implicit gemm (no pointer shift but shift using tensor descriptors) template ( - get_first_element_workaround(as_grid_desc_ak0_m_ak1), + ATransfer::template get_first_element_workaround(as_grid_desc_ak0_m_ak1), a_block_desc_ak0_m_ak1, a_blockwise_copy, - get_first_element_workaround(as_grid_buf), + ATransfer::template get_first_element_workaround(as_grid_buf), a_block_buf, a_block_slice_copy_step, - get_first_element_workaround(bs_grid_desc_bk0_n_bk1), + BTransfer::template get_first_element_workaround(bs_grid_desc_bk0_n_bk1), b_block_desc_bk0_n_bk1, b_blockwise_copy, - get_first_element_workaround(bs_grid_buf), + BTransfer::template get_first_element_workaround(bs_grid_buf), b_block_buf, b_block_slice_copy_step, c_thread_buf, diff --git a/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_common.hpp b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_common.hpp index 4cd4403436d..0dd666b3d9e 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_common.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_common.hpp @@ -73,14 +73,17 @@ template using device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_comp_instances = std::tuple< // clang-format off - //###################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| 
BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| BlkGemmPipeSched| BlkGemmPipelineVer| - //###################################| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MPerBlock_NBlock_NPerBlock| ScalarPerVector| | | - //###################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| | | | | - //###################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmMultipleABD_Wmma_CShuffleV3< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 256, 256, 32, 8, 8, 16, 16, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, - DeviceGemmMultipleABD_Wmma_CShuffleV3< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, - DeviceGemmMultipleABD_Wmma_CShuffleV3< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 128, 256, 32, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, - DeviceGemmMultipleABD_Wmma_CShuffleV3< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1> + //###################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| BlkGemmPipeSched| BlkGemmPipelineVer| + //###################################| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| 
ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MPerBlock_NBlock_NPerBlock| ScalarPerVector| | | + //###################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| | | | | + //###################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmMultipleABD_Wmma_CShuffleV3< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 256, 256, 32, 8, 8, 16, 16, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemmMultipleABD_Wmma_CShuffleV3< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemmMultipleABD_Wmma_CShuffleV3< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 128, 256, 32, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGemmMultipleABD_Wmma_CShuffleV3< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGemmMultipleABD_Wmma_CShuffleV3< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 8, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, S<8, 8, 8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGemmMultipleABD_Wmma_CShuffleV3< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 128, 256, 32, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGemmMultipleABD_Wmma_CShuffleV3< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, BlkGemmPipeSched, 
BlockGemmPipelineVersion::v1> // clang-format on >; diff --git a/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp index 76a92a19719..3587c6700c2 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp @@ -39,7 +39,7 @@ void add_device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_v1_instances( Multiply, PassThrough, GemmMNKPadding, - Interwave>{}); + Intrawave>{}); add_device_operation_instances(instances, device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_mem_instances< ck::Tuple, diff --git a/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_nk_mn_common.hpp b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_nk_mn_common.hpp index 1607b240f66..7cb50cd954d 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_nk_mn_common.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_nk_mn_common.hpp @@ -71,12 +71,15 @@ template using device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_nk_mn_comp_instances = std::tuple< // clang-format off - //###################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| BlkGemmPipeSched| BlkGemmPipelineVer| - //###################################| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MPerBlock_NBlock_NPerBlock| ScalarPerVector| | | - //###################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| | | | | - //###################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmMultipleABD_Wmma_CShuffleV3< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, - DeviceGemmMultipleABD_Wmma_CShuffleV3< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, 
BElementOp, CDEElementOp, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1> + //###################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| BlkGemmPipeSched| BlkGemmPipelineVer| + //###################################| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MPerBlock_NBlock_NPerBlock| ScalarPerVector| | | + //###################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| | | | | + //###################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmMultipleABD_Wmma_CShuffleV3< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemmMultipleABD_Wmma_CShuffleV3< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGemmMultipleABD_Wmma_CShuffleV3< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 128, 256, 32, 8, 8, 16, 16, 2, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 4>, S<8, 8, 8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGemmMultipleABD_Wmma_CShuffleV3< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGemmMultipleABD_Wmma_CShuffleV3< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, 
GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 8, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, S<8, 8, 8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1> // clang-format on >; } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_bias_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_bias_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp index 2a4aae98a53..731518257bd 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_bias_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_bias_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp @@ -39,7 +39,7 @@ void add_device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_bias_v1_instances( Multiply, Add, GemmMNKPadding, - Interwave>{}); + Intrawave>{}); add_device_operation_instances(instances, device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_mem_instances< ck::Tuple, diff --git a/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_bias_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_bias_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp index 477d6811d29..0a67f2357ee 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_bias_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_bias_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp @@ -39,7 +39,7 @@ void add_device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_bias_gelu_v1_instances Multiply, AddFastGelu, GemmMNKPadding, - Interwave>{}); + Intrawave>{}); add_device_operation_instances(instances, device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_mem_instances< ck::Tuple, diff --git a/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_bias_gelu_bf16_i8_bf16_mk_nk_mn_v1_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_bias_gelu_bf16_i8_bf16_mk_nk_mn_v1_instance.cpp index 71c04b3485a..c0b4cf7b9a8 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_bias_gelu_bf16_i8_bf16_mk_nk_mn_v1_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_bias_gelu_bf16_i8_bf16_mk_nk_mn_v1_instance.cpp @@ -36,7 +36,7 @@ void add_device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_nk_mn_bias_gelu_v1_instances ck::Tuple, AddFastGelu, GemmMNKPadding, - Interwave>{}); + Intrawave>{}); } void add_device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_nk_mn_bias_v1_instances( @@ -58,7 +58,7 @@ void add_device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_nk_mn_bias_v1_instances( ck::Tuple, Add, GemmMNKPadding, - Interwave>{}); + Intrawave>{}); } void add_device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_nk_mn_v1_instances( @@ -80,7 +80,7 @@ void add_device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_nk_mn_v1_instances( ck::Tuple<>, PassThrough, GemmMNKPadding, - Interwave>{}); + Intrawave>{}); } void add_device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_nk_mn_gelu_v1_instances( @@ -102,7 +102,7 @@ void add_device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_nk_mn_gelu_v1_instances( ck::Tuple<>, FastGelu, GemmMNKPadding, - Interwave>{}); + Intrawave>{}); } } // namespace instance diff --git 
a/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp index 33422fc6dbf..9176910cea6 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp @@ -39,7 +39,7 @@ void add_device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_gelu_v1_instances( Multiply, FastGelu, GemmMNKPadding, - Interwave>{}); + Intrawave>{}); add_device_operation_instances(instances, device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_mem_instances< diff --git a/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp index 639bda60179..669eb4144a1 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp @@ -39,7 +39,7 @@ void add_device_gemm_wmma_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_v1_instances( PassThrough, Multiply, GemmMNKPadding, - Interwave>{}); + Intrawave>{}); add_device_operation_instances( instances, device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_mem_instances, diff --git a/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_multiply_bias_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_multiply_bias_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp index 7f8fea44c58..c6a812645b7 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_multiply_bias_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_multiply_bias_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp @@ -39,7 +39,7 @@ void add_device_gemm_wmma_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_bias_v1_insta PassThrough, MultiplyAdd, GemmMNKPadding, - Interwave>{}); + Intrawave>{}); add_device_operation_instances(instances, device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_mem_instances< ck::Tuple, diff --git a/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_multiply_bias_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_multiply_bias_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp index b2bf9955078..2d7ffd120d5 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_multiply_bias_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_multiply_bias_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp @@ -39,7 +39,7 @@ void add_device_gemm_wmma_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_bias_gelu_v1_ PassThrough, MultiplyAddFastGelu, GemmMNKPadding, - Interwave>{}); + Intrawave>{}); add_device_operation_instances(instances, device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_mem_instances< ck::Tuple, diff --git 
a/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_multiply_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_multiply_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp index d2adc36dc3e..ab49d2f1c9c 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_multiply_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_multiply_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp @@ -39,7 +39,7 @@ void add_device_gemm_wmma_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_gelu_v1_insta PassThrough, MultiplyFastGelu, GemmMNKPadding, - Interwave>{}); + Intrawave>{}); add_device_operation_instances( instances, device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_mem_instances,
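The ThreadGroupTransferGlobal changes above generalize the transfer from a single source descriptor and coordinate to a tuple of them: MakeCoordinates is built with generate_tuple, src_coords_ replaces src_coord_, and RunRead/MoveSrcSliceWindow loop static_for<0, nSrc, 1> over the sources. The template headers themselves were lost when this diff was flattened, so the sketch below is only a simplified, self-contained analogue of the one-coordinate-per-descriptor pattern in plain C++; Desc, Coord, make_tensor_coordinate and the toy offset arithmetic are stand-ins for illustration, not CK's actual types or signatures.

    #include <array>
    #include <cstddef>
    #include <tuple>
    #include <utility>

    struct Desc  { std::size_t stride; };   // stand-in for a CK tensor descriptor
    struct Coord { std::size_t offset; };   // stand-in for a CK tensor coordinate

    template <typename Index>
    Coord make_tensor_coordinate(const Desc& desc, const Index& idx)
    {
        return Coord{desc.stride * idx[0] + idx[1]}; // toy 2D row-major offset
    }

    // One coordinate per (descriptor, slice origin) pair, analogous to
    // MakeCoordinates/generate_tuple over the number of source tensors.
    template <typename Descs, typename Origins, std::size_t... Is>
    auto make_coordinates_impl(const Descs& descs, const Origins& origins, std::index_sequence<Is...>)
    {
        return std::make_tuple(make_tensor_coordinate(std::get<Is>(descs), std::get<Is>(origins))...);
    }

    template <typename... Ds, typename... Os>
    auto make_coordinates(const std::tuple<Ds...>& descs, const std::tuple<Os...>& origins)
    {
        static_assert(sizeof...(Ds) == sizeof...(Os), "one slice origin per source descriptor");
        return make_coordinates_impl(descs, origins, std::index_sequence_for<Ds...>{});
    }

    int main()
    {
        // Two source tensors (e.g. A0 and A1), each with its own descriptor and origin.
        auto descs   = std::make_tuple(Desc{128}, Desc{256});
        auto origins = std::make_tuple(std::array<std::size_t, 2>{2, 4},
                                       std::array<std::size_t, 2>{1, 8});

        auto coords = make_coordinates(descs, origins);
        return std::get<0>(coords).offset == 260 && std::get<1>(coords).offset == 264 ? 0 : 1;
    }

The kernel keeps this per-source structure throughout: the coordinate tuple is built once in the constructor and every per-iteration step or window move is then applied to each entry, exactly as the diff shows.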
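Similarly, the get_first_element_workaround helpers added to ABTransferThreadTiles, ABTransferThreadTilesPreShuffle and ABTransferWaveTiles lost their template parameter lists in the flattened text. A minimal host-side sketch of the intended dispatch, assuming the helper is parameterized on an element count and deduces the array type (the CK versions are static __device__ members and index with I0 rather than [0]):

    #include <array>
    #include <cstddef>
    #include <iostream>

    // Sketch only: with several tensors the whole array is forwarded unchanged,
    // with a single tensor the sole element is unwrapped so callers keep a
    // non-tuple interface.
    template <std::size_t NumElements, typename Type>
    auto get_first_element_workaround(Type& array)
    {
        if constexpr(NumElements > 1)
        {
            return array;    // multi-tensor case: pass the container through
        }
        else
        {
            return array[0]; // single-tensor case: unwrap the only element
        }
    }

    int main()
    {
        std::array<int, 3> many{1, 2, 3};
        std::array<int, 1> one{42};

        auto a = get_first_element_workaround<3>(many); // std::array<int, 3>
        auto b = get_first_element_workaround<1>(one);  // int

        std::cout << a[2] << ' ' << b << '\n';          // prints "3 42"
    }

Note that the ABTransferWaveTiles variant in the diff returns array unconditionally, without the element-count branch.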