diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3.hpp index a18f108e47..e37bd381bf 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3.hpp @@ -13,105 +13,12 @@ #include "ck/tensor_operation/gpu/device/device_batched_gemm.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp" -#include "ck/host_utility/device_prop.hpp" -#include "ck/host_utility/kernel_launch.hpp" -#include "ck/host_utility/flush_cache.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3_common.hpp" namespace ck { namespace tensor_operation { namespace device { -template -__global__ void -#if CK_USE_LAUNCH_BOUNDS -__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) -#endif - kernel_batched_gemm_wmma_cshuffle_v3( - typename GridwiseGemm::Argument karg, // This works for now but it actually receives a - // DeviceBatchedGemm_Wmma_CShuffleV3::Argument - // argument through implicit conversion to base class! - const ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch) -{ -#if(defined(__gfx11__) || defined(__gfx12__)) -#if defined(__gfx11__) - // gfx11 does not support *_atomic_pk_add_f16/bf16 instructions - using c_data_type = remove_cvref_t>; - if constexpr(!(CGlobalMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd && - (std::is_same_v || - std::is_same_v))) - { -#endif - // The normal approach to batching would be to increase the grid size by just stretching out - // the grid Z dimension (which is the outermost dimension), but this depends on lower level - // functions not directly using the Z dimension for other calculations. As it turns out, k - // batching does rely directly on blockIdx.Z through SplitKBatchOffset. Therefore, for now - // we will use the grid Y dimension for batching. This may be a bit fragile. 
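// Illustrative sketch (assumed names, not CK's GridwiseGemm interface): a minimal,
// self-contained HIP example of the batching scheme described in the comment above.
// blockIdx.z stays reserved for split-K, so the batch index is carried in blockIdx.y
// and each block offsets its pointers by a per-batch stride. CK additionally routes
// the offsets through amd_wave_read_first_lane() so they stay wave-uniform.
#include <hip/hip_runtime.h>

__global__ void batched_gemm_like_kernel_sketch(const float* a,
                                                const float* b,
                                                float* c,
                                                long long a_batch_stride,
                                                long long b_batch_stride,
                                                long long c_batch_stride)
{
    const int g_idx = blockIdx.y; // batch index (blockIdx.z is left for split-K)

    const float* a_batch = a + g_idx * a_batch_stride;
    const float* b_batch = b + g_idx * b_batch_stride;
    float*       c_batch = c + g_idx * c_batch_stride;

    // ... per-tile GEMM work for this (batch, split-K slice) would go here ...
    (void)a_batch;
    (void)b_batch;
    (void)c_batch;
}

// Host side, mirroring the `gdy *= arg.Batch;` done by the Invoker further down:
//   dim3 grid(gdx, gdy * batch_count, k_batch);
//   batched_gemm_like_kernel_sketch<<<grid, dim3(block_size), 0, stream>>>(/* ... */);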
- const index_t g_idx = amd_wave_read_first_lane(blockIdx.y); - - const long_index_t a_batch_offset = - amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)); - const long_index_t b_batch_offset = - amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)); - const long_index_t c_batch_offset = - amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx)); - - using EpilogueType = - typename std::conditional::type; - - constexpr index_t LDS_size = - GridwiseGemm::template GetSharedMemoryNumberOfByte(); - __shared__ char p_shared[LDS_size]; - - auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); - - // shift A matrices pointer for splitk - typename GridwiseGemm::AsGridPointer p_as_grid_shift; - static_for<0, GridwiseGemm::NumATensor, 1>{}([&](auto i) { - using ADataType_ = - remove_cvref_t>; - p_as_grid_shift(i) = static_cast(karg.p_as_grid[i]) + - splitk_batch_offset.a_k_split_offset[i] + a_batch_offset; - }); - - // shift B matrices pointer for splitk - typename GridwiseGemm::BsGridPointer p_bs_grid_shift; - static_for<0, GridwiseGemm::NumBTensor, 1>{}([&](auto i) { - using BDataType_ = - remove_cvref_t>; - p_bs_grid_shift(i) = static_cast(karg.p_bs_grid[i]) + - splitk_batch_offset.b_k_split_offset[i] + b_batch_offset; - }); - - auto epilogue_args = EpilogueType{}; - - GridwiseGemm::template Run( - p_as_grid_shift, - p_bs_grid_shift, - karg.p_ds_grid, - karg.p_e_grid + splitk_batch_offset.c_reduce_offset + c_batch_offset, - p_shared, - karg, - karg.a_element_op, - karg.b_element_op, - karg.cde_element_op, - epilogue_args); -#if defined(__gfx11__) - } -#endif -#else - ignore = karg; - ignore = compute_ptr_offset_of_batch; -#endif -} - /// @brief \"Universal\" Batched GEMM operation without SplitK support. /// /// @par Overview @@ -399,285 +306,35 @@ struct DeviceBatchedGemm_Wmma_CShuffleV3 : public DeviceBatchedGemm, + Tuple, + Tuple<>, + CDataType, + MPerBlock, + NPerBlock, + KPerBlock, + BlockSize, + AK1, + BK1, + GemmSpec, + Sequence, + BlkGemmPipeSched, + BlkGemmPipelineVer, + ComputeTypeA, + ComputeTypeB>; - /// @brief Helper structure responsible for kernel invocation. - /// - /// @paragraph The `Invoker` class is responsible for preparation and invocation of actual GPU - /// kernel function. It usually determines the launched grid size prepares kernel - /// arguments as well as perform specific kernel configuration selection based on - /// runtime arguments. - /// - /// @note If appropriately configured it may measure kernel execution time. - /// - struct Invoker : public BaseInvoker - { - /// @brief This function issues GPU kernel execution. - /// @param arg The GPU kernel arguments. - /// @param stream_config The HIP stream configuration helper structure. - /// @return The kernel's average execution time (if time measurement is - /// enabled). - float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) - { - if(stream_config.log_level_ > 0) - { - arg.Print(); - GridwiseGemm::BlockwiseGemmPipe::HotLoopInstList::Print(); - } - - if(!GridwiseGemm::CheckValidity(arg)) - { - throw std::runtime_error("wrong! 
GridwiseGemm has invalid setting"); - } - - index_t gdx, gdy, gdz; - std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N, arg.KBatch); - - // The normal approach to batching would be to increase the grid size by just stretching - // out the grid Z dimension (which is the outermost dimension), but this depends on - // lower level functions not directly using the Z dimension for other calculations. As - // it turns out, k batching does rely directly on blockIdx.Z through SplitKBatchOffset. - // Therefore, for now we will use the grid Y dimension for batching. This may be a bit - // fragile. - gdy *= arg.Batch; - - float ave_time = 0; - - index_t k_grain = arg.KBatch * KPerBlock; - index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock; - - const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split); - - const auto Run = [&](const auto& kernel) { - if(stream_config.flush_cache) - { - Argument arg_ = arg; - - const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAsGridDescriptor_AK0_M_AK1( - arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideAs, arg_.AK0); - const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBsGridDescriptor_BK0_N_BK1( - arg_.K, arg_.KPadded, arg_.N, arg_.NPadded, arg_.StrideBs, arg_.BK0); - - // Packed sizes are 1 for all implemented data types but we include it anyway - // for future compatibility. - // Note: the grid descriptors and size_a / size_b do *not* take batching into - // account, so we have to manually multiply overall buffer sizes for rotating - // memory by batch. - std::array size_as_buffers; - size_as_buffers[0] = a_grid_desc_ak0_m_ak1[Number<0>{}].GetElementSpaceSize() * - sizeof(ADataType) / GridwiseGemm::APackedSize * arg_.Batch; - - std::array size_bs_buffers; - size_bs_buffers[0] = b_grid_desc_bk0_n_bk1[Number<0>{}].GetElementSpaceSize() * - sizeof(BDataType) / GridwiseGemm::BPackedSize * arg_.Batch; - - ck::utility::RotatingMemWrapperMultiABD, - Tuple, - Tuple<>> - rotating_mem(arg_, - stream_config.rotating_count, - size_as_buffers, - size_bs_buffers, - std::array{}); - rotating_mem.Print(); - - auto run_flush_cache = [&]() { - // flush icache - ck::utility::flush_icache(); - // rotating mem - rotating_mem.Next(); - // clear c mem - if(arg_.KBatch > 1) - // Note: we multiply by batch since we want to clear the C matrix for - // the whole batch. Untested since we don't have k batching ATM. - // Note: This seems incorrect for non-contiguous memory layouts for C - // (padding, gaps). - HIP_CHECK_ERROR( - hipMemsetAsync(arg_.p_e_grid, - 0, - arg_.Batch * arg_.M * arg_.N * sizeof(CDataType), - stream_config.stream_id_)); - }; - - ave_time = ck::utility::launch_and_time_kernel_with_preprocess( - stream_config, - run_flush_cache, - kernel, - dim3(gdx, gdy, gdz), - dim3(BlockSize), - 0, - arg_, - arg_.compute_ptr_offset_of_batch); - } - else - { - auto clear_workspace = [&]() { - // clear c mem - if(arg.KBatch > 1) - // Note: we multiply by batch since we want to clear the C matrix for - // the whole batch. Untested since we don't have k batching ATM. - // Note: This seems incorrect for non-contiguous memory layouts for C - // (padding, gaps). 
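// Illustrative sketch of the host-side clear performed by the call below. With
// KBatch > 1 the kernel writes C with atomic adds, so the output must start at zero.
// This assumes the same packed layout the note above warns about (Batch * M * N
// contiguous elements, no padding between rows or batches); the helper name and
// signature are illustrative, not CK API.
#include <hip/hip_runtime.h>
#include <cstddef>

template <typename CDataType>
hipError_t clear_output_for_splitk(CDataType* p_e_grid,
                                   std::size_t batch,
                                   std::size_t m,
                                   std::size_t n,
                                   int k_batch,
                                   hipStream_t stream)
{
    if(k_batch <= 1)
        return hipSuccess; // the Set path simply overwrites C, nothing to clear

    return hipMemsetAsync(p_e_grid, 0, batch * m * n * sizeof(CDataType), stream);
}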
- HIP_CHECK_ERROR( - hipMemsetAsync(arg.p_e_grid, - 0, - arg.Batch * arg.M * arg.N * sizeof(CDataType), - stream_config.stream_id_)); - }; - - ave_time = ck::utility::launch_and_time_kernel_with_preprocess( - stream_config, - clear_workspace, - kernel, - dim3(gdx, gdy, gdz), - dim3(BlockSize), - 0, - arg, - arg.compute_ptr_offset_of_batch); - } - }; - - constexpr index_t minimum_occupancy = []() { - if constexpr(BlkGemmPipeSched == BlockGemmPipelineScheduler::Interwave) - { - return 2; - } - else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) - { - return (MPerBlock * NPerBlock / BlockSize <= 128) ? 2 : 1; - } - else - { - return 1; - } - }(); - - if(has_main_k_block_loop) - { - // Tail number always full - if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 || - BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) - { - if(arg.KBatch > 1) - { - const auto kernel = kernel_batched_gemm_wmma_cshuffle_v3< - GridwiseGemm, - ComputePtrOffsetOfStridedBatch, - true, - InMemoryDataOperationEnum::AtomicAdd, - minimum_occupancy>; - Run(kernel); - } - else - { - const auto kernel = kernel_batched_gemm_wmma_cshuffle_v3< - GridwiseGemm, - remove_reference_t, - true, - InMemoryDataOperationEnum::Set, - minimum_occupancy>; - Run(kernel); - } - } - else - { - // TODO: Implement - } - } - else - { - // Tail number always 1 - if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) - { - if(arg.KBatch > 1) - { - const auto kernel = kernel_batched_gemm_wmma_cshuffle_v3< - GridwiseGemm, - ComputePtrOffsetOfStridedBatch, - false, - InMemoryDataOperationEnum::AtomicAdd, - minimum_occupancy>; - Run(kernel); - } - else - { - const auto kernel = kernel_batched_gemm_wmma_cshuffle_v3< - GridwiseGemm, - remove_reference_t, - false, - InMemoryDataOperationEnum::Set, - minimum_occupancy>; - Run(kernel); - } - } - } - - return ave_time; - } - - // polymorphic - float Run(const BaseArgument* p_arg, - const StreamConfig& stream_config = StreamConfig{}) override - { - return Run(*dynamic_cast(p_arg), stream_config); - } - }; - - static constexpr bool IsValidCompilationParameter() - { - // TODO: properly implement this check - return true; - } - - static bool IsSupportedArgument(const Argument& arg) - { - if(!ck::is_gfx11_supported() && !ck::is_gfx12_supported()) - { - return false; - } - - if constexpr(std::is_same_v || - std::is_same_v) - { - if(arg.KBatch > 1 && ck::is_gfx11_supported()) - { - // gfx11 does not support *_atomic_pk_add_f16/bf16 instructions - return false; - } - } - - if constexpr(std::is_same_v || std::is_same_v || - std::is_same_v || std::is_same_v) - { - if(ck::is_gfx11_supported()) - { - return false; - } - } - - if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding || - GemmSpec == GemmSpecialization::NKPadding || - GemmSpec == GemmSpecialization::MNKPadding || - GemmSpec == GemmSpecialization::KPadding)) - { - return false; - } - - return GridwiseGemm::CheckValidity(arg); - } + // Invoker + using Invoker = typename DeviceGemmCommon::Invoker; // polymorphic bool IsSupportedArgument(const BaseArgument* p_arg) override { - return IsSupportedArgument(*dynamic_cast(p_arg)); + return DeviceGemmCommon::IsSupportedArgument(*dynamic_cast(p_arg)); } - // TODO: This is not part of the DeviceBatchedGemm base class but it was part of - // DeviceBatchedGemmV2. Remove? 
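// Illustrative sketch of the refactoring pattern this diff applies: each device
// operation keeps only its own Argument/MakeArgument surface and forwards the
// Invoker and the argument checks to a single shared "Common" struct, so the
// launch logic exists exactly once. CommonSketch and DeviceOpSketch are made-up
// names; the real types are DeviceBatchedGemm_Wmma_CShuffleV3_Common and the two
// device ops in this diff.
template <typename GridwiseGemmT, typename ArgumentT>
struct CommonSketch
{
    static bool IsSupportedArgument(const ArgumentT& arg)
    {
        // shared hardware / padding / validity checks live here once
        return GridwiseGemmT::CheckValidity(arg);
    }

    struct Invoker
    {
        float Run(const ArgumentT& /*arg*/)
        {
            /* shared grid-size computation and kernel launch logic */
            return 0.f;
        }
    };
};

template <typename GridwiseGemmT, typename ArgumentT>
struct DeviceOpSketch
{
    using Common  = CommonSketch<GridwiseGemmT, ArgumentT>;
    using Invoker = typename Common::Invoker; // reuse, do not duplicate

    static bool IsSupportedArgument(const ArgumentT& arg)
    {
        return Common::IsSupportedArgument(arg); // thin forwarder
    }
};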
- // index_t GetKPerBlock() override { return KPerBlock; } - // bool GetPermuteA() override { return PermuteA; } - // bool GetPermuteB() override { return PermuteB; } - static auto MakeArgument(const ADataType* p_a, const BDataType* p_b, CDataType* p_c, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3_b_scale.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3_b_scale.hpp index b88f071a96..c96a4acd65 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3_b_scale.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3_b_scale.hpp @@ -13,109 +13,12 @@ #include "ck/tensor_operation/gpu/device/device_batched_gemm.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_ab_scale.hpp" -#include "ck/host_utility/device_prop.hpp" -#include "ck/host_utility/kernel_launch.hpp" -#include "ck/host_utility/flush_cache.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3_common.hpp" namespace ck { namespace tensor_operation { namespace device { -template -__global__ void -#if CK_USE_LAUNCH_BOUNDS -__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) -#endif - kernel_batched_gemm_b_scale_wmma_cshuffle_v3( - typename GridwiseGemm::Argument karg, // This works for now but it actually receives a - // DeviceBatchedGemm_Wmma_CShuffleV3::Argument - // argument through implicit conversion to base class! - const ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch) -{ -#if(defined(__gfx11__) || defined(__gfx12__)) -#if defined(__gfx11__) - // gfx11 does not support *_atomic_pk_add_f16/bf16 instructions - using c_data_type = remove_cvref_t>; - if constexpr(!(CGlobalMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd && - (std::is_same_v || - std::is_same_v))) - { -#endif - using EpilogueType = - typename std::conditional::type; - - constexpr index_t LDS_size = - GridwiseGemm::template GetSharedMemoryNumberOfByte(); - // The normal approach to batching would be to increase the grid size by just stretching out - // the grid Z dimension (which is the outermost dimension), but this depends on lower level - // functions not directly using the Z dimension for other calculations. As it turns out, k - // batching does rely directly on blockIdx.Z through SplitKBatchOffset. Therefore, for now - // we will use the grid Y dimension for batching. This may be a bit fragile. 
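// Illustrative sketch of the pointer shifting the kernel below performs for every
// A/B tensor: base pointer + batch offset (derived from blockIdx.y) + split-K
// offset (derived from blockIdx.z). The flat, K-contiguous layout assumed here is
// for illustration only; CK obtains the real element offsets from
// GridwiseGemm::SplitKBatchOffset and ComputePtrOffsetOfStridedBatch.
#include <hip/hip_runtime.h>

template <typename T>
__device__ const T* shift_ptr_for_batch_and_splitk(const T* p_base,
                                                   long long batch_stride,
                                                   long long elems_per_k_slice)
{
    const int g_idx      = blockIdx.y; // batch index
    const int kbatch_idx = blockIdx.z; // split-K slice index

    const long long batch_offset   = g_idx * batch_stride;
    const long long k_split_offset = kbatch_idx * elems_per_k_slice;

    // a single pointer add lands this block on its (batch, k-slice) sub-matrix
    return p_base + batch_offset + k_split_offset;
}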
- __shared__ char p_shared[LDS_size]; - - const index_t g_idx = amd_wave_read_first_lane(blockIdx.y); - - const long_index_t a_batch_offset = - amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)); - const long_index_t b_batch_offset = - amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)); - const long_index_t c_batch_offset = - amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx)); - const long_index_t b_scale_batch_offset = - amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetScaleBPtrOffset(g_idx)); - - auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); - - // shift A matrices pointer for splitk - typename GridwiseGemm::AsGridPointer p_as_grid_shift; - static_for<0, GridwiseGemm::NumATensor, 1>{}([&](auto i) { - using ADataType_ = - remove_cvref_t>; - p_as_grid_shift(i) = static_cast(karg.p_as_grid[i]) + - splitk_batch_offset.a_k_split_offset[i] + a_batch_offset; - }); - - // shift B matrices pointer for splitk - typename GridwiseGemm::BsGridPointer p_bs_grid_shift; - static_for<0, GridwiseGemm::NumBTensor, 1>{}([&](auto i) { - using BDataType_ = - remove_cvref_t>; - p_bs_grid_shift(i) = static_cast(karg.p_bs_grid[i]) + - splitk_batch_offset.b_k_split_offset[i] + b_batch_offset; - }); - - auto epilogue_args = EpilogueType{}; - - GridwiseGemm::template Run( - p_as_grid_shift, - p_bs_grid_shift, - karg.p_ds_grid, - karg.p_e_grid + splitk_batch_offset.c_reduce_offset + c_batch_offset, - karg.p_a_scale_grid, - karg.p_b_scale_grid + b_scale_batch_offset + splitk_batch_offset.scale_b_k_split_offset, - p_shared, - karg, - karg.a_element_op, - karg.b_element_op, - karg.cde_element_op, - epilogue_args); -#if defined(__gfx11__) - } -#endif -#else - ignore = karg; - ignore = compute_ptr_offset_of_batch; -#endif -} - /// @brief \"Universal\" Batched GEMM operation without SplitK support. /// /// @par Overview @@ -433,274 +336,34 @@ struct DeviceBatchedGemm_Wmma_CShuffleV3_BScale ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch; }; - /// @brief Helper structure responsible for kernel invocation. - /// - /// @paragraph The `Invoker` class is responsible for preparation and invocation of actual GPU - /// kernel function. It usually determines the launched grid size prepares kernel - /// arguments as well as perform specific kernel configuration selection based on - /// runtime arguments. - /// - /// @note If appropriately configured it may measure kernel execution time. - /// - struct Invoker : public BaseInvoker - { - /// @brief This function issues GPU kernel execution. - /// @param arg The GPU kernel arguments. - /// @param stream_config The HIP stream configuration helper structure. - /// @return The kernel's average execution time (if time measurement is - /// enabled). - float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) - { - if(stream_config.log_level_ > 0) - { - arg.Print(); - GridwiseGemm::BlockwiseGemmPipe::HotLoopInstList::Print(); - } - - if(!GridwiseGemm::CheckValidity(arg)) - { - throw std::runtime_error("wrong! GridwiseGemm has invalid setting"); - } - - index_t gdx, gdy, gdz; - std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N, arg.KBatch); - - // The normal approach to batching would be to increase the grid size by just stretching - // out the grid Z dimension (which is the outermost dimension), but this depends on - // lower level functions not directly using the Z dimension for other calculations. 
As - // it turns out, k batching does rely directly on blockIdx.Z through SplitKBatchOffset. - // Therefore, for now we will use the grid Y dimension for batching. This may be a bit - // fragile. - gdy *= arg.Batch; - - float ave_time = 0; - - index_t k_grain = arg.KBatch * KPerBlock; - index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock; - - const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split); - - const auto Run = [&](const auto& kernel) { - if(stream_config.flush_cache) - { - Argument arg_ = arg; - - const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAsGridDescriptor_AK0_M_AK1( - arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideAs, arg_.AK0); - const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBsGridDescriptor_BK0_N_BK1( - arg_.K, arg_.KPadded, arg_.N, arg_.NPadded, arg_.StrideBs, arg_.BK0); - - // Packed sizes are 1 for all implemented data types but we include it anyway - // for future compatibility. - // Note: the grid descriptors and size_a / size_b do *not* take batching into - // account, so we have to manually multiply overall buffer sizes for rotating - // memory by batch. - std::array size_as_buffers; - size_as_buffers[0] = a_grid_desc_ak0_m_ak1[Number<0>{}].GetElementSpaceSize() * - sizeof(ADataType) / GridwiseGemm::APackedSize * arg_.Batch; - - std::array size_bs_buffers; - size_bs_buffers[0] = b_grid_desc_bk0_n_bk1[Number<0>{}].GetElementSpaceSize() * - sizeof(BDataType) / GridwiseGemm::BPackedSize * arg_.Batch; - - ck::utility::RotatingMemWrapperMultiABD, - Tuple, - Tuple<>> - rotating_mem(arg_, - stream_config.rotating_count, - size_as_buffers, - size_bs_buffers, - std::array{}); - rotating_mem.Print(); - - auto run_flush_cache = [&]() { - ck::utility::flush_icache(); - rotating_mem.Next(); - // clear c mem - if(arg_.KBatch > 1) - // Note: we multiply by batch since we want to clear the C matrix for - // the whole batch. Untested since we don't have k batching ATM. - // Note: This seems incorrect for non-contiguous memory layouts for C - // (padding, gaps). - HIP_CHECK_ERROR( - hipMemsetAsync(arg_.p_e_grid, - 0, - arg_.Batch * arg_.M * arg_.N * sizeof(CDataType), - stream_config.stream_id_)); - }; - - ave_time = ck::utility::launch_and_time_kernel_with_preprocess( - stream_config, - run_flush_cache, - kernel, - dim3(gdx, gdy, gdz), - dim3(BlockSize), - 0, - arg_, - arg_.compute_ptr_offset_of_batch); - } - else - { - auto clear_workspace = [&]() { - // clear c mem - if(arg.KBatch > 1) - // Note: we multiply by batch since we want to clear the C matrix for - // the whole batch. Untested since we don't have k batching ATM. - // Note: This seems incorrect for non-contiguous memory layouts for C - // (padding, gaps). - HIP_CHECK_ERROR( - hipMemsetAsync(arg.p_e_grid, - 0, - arg.Batch * arg.M * arg.N * sizeof(CDataType), - stream_config.stream_id_)); - }; - - ave_time = ck::utility::launch_and_time_kernel_with_preprocess( - stream_config, - clear_workspace, - kernel, - dim3(gdx, gdy, gdz), - dim3(BlockSize), - 0, - arg, - arg.compute_ptr_offset_of_batch); - } - }; - - constexpr index_t minimum_occupancy = []() { - if constexpr(BlkGemmPipeSched == BlockGemmPipelineScheduler::Interwave) - { - return 2; - } - else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) - { - return (MPerBlock * NPerBlock / BlockSize <= 128) ? 
2 : 1; - } - else - { - return 1; - } - }(); - - if(has_main_k_block_loop) - { - // Tail number always full - if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 || - BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) - { - if(arg.KBatch > 1) - { - const auto kernel = kernel_batched_gemm_b_scale_wmma_cshuffle_v3< - GridwiseGemm, - ComputePtrOffsetOfStridedBatch, - true, - InMemoryDataOperationEnum::AtomicAdd, - minimum_occupancy>; - Run(kernel); - } - else - { - const auto kernel = kernel_batched_gemm_b_scale_wmma_cshuffle_v3< - GridwiseGemm, - remove_reference_t, - true, - InMemoryDataOperationEnum::Set, - minimum_occupancy>; - Run(kernel); - } - } - else - { - throw std::runtime_error("Pipeline not implemented"); - } - } - else - { - // Tail number always 1 - if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) - { - if(arg.KBatch > 1) - { - const auto kernel = kernel_batched_gemm_b_scale_wmma_cshuffle_v3< - GridwiseGemm, - ComputePtrOffsetOfStridedBatch, - false, - InMemoryDataOperationEnum::AtomicAdd, - minimum_occupancy>; - Run(kernel); - } - else - { - const auto kernel = kernel_batched_gemm_b_scale_wmma_cshuffle_v3< - GridwiseGemm, - remove_reference_t, - false, - InMemoryDataOperationEnum::Set, - minimum_occupancy>; - Run(kernel); - } - } - } - - return ave_time; - } - - // polymorphic - float Run(const BaseArgument* p_arg, - const StreamConfig& stream_config = StreamConfig{}) override - { - return Run(*dynamic_cast(p_arg), stream_config); - } - }; - - static constexpr bool IsValidCompilationParameter() - { - // TODO: properly implement this check - return true; - } - - static bool IsSupportedArgument(const Argument& arg) - { - if(!ck::is_gfx11_supported() && !ck::is_gfx12_supported()) - { - return false; - } - - if constexpr(std::is_same_v || - std::is_same_v) - { - if(arg.KBatch > 1 && ck::is_gfx11_supported()) - { - // gfx11 does not support *_atomic_pk_add_f16/bf16 instructions - return false; - } - } - - if constexpr(std::is_same_v || std::is_same_v || - std::is_same_v || std::is_same_v) - { - if(ck::is_gfx11_supported()) - { - return false; - } - } - - if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding || - GemmSpec == GemmSpecialization::NKPadding || - GemmSpec == GemmSpecialization::MNKPadding || - GemmSpec == GemmSpecialization::KPadding)) - { - return false; - } + using DeviceGemmCommon = DeviceBatchedGemm_Wmma_CShuffleV3_Common< + GridwiseGemm, + Argument, + Tuple, + Tuple, + Tuple<>, + CDataType, + MPerBlock, + NPerBlock, + KPerBlock, + BlockSize, + AK1, + BK1, + GemmSpec, + Sequence, + BlkGemmPipeSched, + BlkGemmPipelineVer, + ComputeTypeA, + ComputeTypeB, + true>; // IsBScaled - return GridwiseGemm::CheckValidity(arg); - } + // Invoker + using Invoker = typename DeviceGemmCommon::Invoker; // polymorphic bool IsSupportedArgument(const BaseArgument* p_arg) override { - return IsSupportedArgument(*dynamic_cast(p_arg)); + return DeviceGemmCommon::IsSupportedArgument(*dynamic_cast(p_arg)); } index_t GetKPerBlock() override { return KPerBlock; } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3_common.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3_common.hpp new file mode 100644 index 0000000000..f2e15ba89e --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3_common.hpp @@ -0,0 +1,311 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/flush_cache.hpp" +#include "ck/tensor_operation/gpu/device/device_base.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_wmma_cshuffle_v3.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceBatchedGemm_Wmma_CShuffleV3_Common +{ + /// @brief Helper structure responsible for kernel invocation. + /// + /// @paragraph The `Invoker` class is responsible for preparation and invocation of actual GPU + /// kernel function. It usually determines the launched grid size prepares kernel + /// arguments as well as perform specific kernel configuration selection based on + /// runtime arguments. + /// + /// @note If appropriately configured it may measure kernel execution time. + /// + struct Invoker : public BaseInvoker + { + /// @brief This function issues GPU kernel execution. + /// @param arg The GPU kernel arguments. + /// @param stream_config The HIP stream configuration helper structure. + /// @return The kernel's average execution time (if time measurement is + /// enabled). + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(stream_config.log_level_ > 0) + { + arg.Print(); + GridwiseGemm::BlockwiseGemmPipe::HotLoopInstList::Print(); + } + + if(!GridwiseGemm::CheckValidity(arg)) + { + throw std::runtime_error("wrong! GridwiseGemm has invalid setting"); + } + + index_t gdx, gdy, gdz; + std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N, arg.KBatch); + + // The normal approach to batching would be to increase the grid size by just stretching + // out the grid Z dimension (which is the outermost dimension), but this depends on + // lower level functions not directly using the Z dimension for other calculations. As + // it turns out, k batching does rely directly on blockIdx.Z through SplitKBatchOffset. + // Therefore, for now we will use the grid Y dimension for batching. This may be a bit + // fragile. + gdy *= arg.Batch; + + float ave_time = 0; + + index_t k_grain = arg.KBatch * KPerBlock; + index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock; + + const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split); + + const auto Run = [&](const auto& kernel) { + if(stream_config.flush_cache) + { + Argument arg_ = arg; + + const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAsGridDescriptor_AK0_M_AK1( + arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideAs, arg_.AK0); + const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBsGridDescriptor_BK0_N_BK1( + arg_.K, arg_.KPadded, arg_.N, arg_.NPadded, arg_.StrideBs, arg_.BK0); + + // Packed sizes are 1 for all implemented data types but we include it anyway + // for future compatibility. + // Note: the grid descriptors and size_a / size_b do *not* take batching into + // account, so we have to manually multiply overall buffer sizes for rotating + // memory by batch. 
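// Illustrative sketch of the "rotating memory" idea used by
// RotatingMemWrapperMultiABD below: keep several device copies of the input
// buffers and switch to a different copy before each timed launch, so repeated
// launches never benefit from warm caches. The std::vector bookkeeping is a
// stand-in for CK's wrapper; as the comment below notes, the per-copy sizes must
// be multiplied by the batch count because the grid descriptors describe a single
// GEMM only.
#include <cstddef>
#include <vector>

struct RotatingBufferSketch
{
    std::vector<void*> copies; // rotating_count device allocations, all batch-sized
    std::size_t        next = 0;

    void* Next()
    {
        void* p = copies[next];
        next    = (next + 1) % copies.size();
        return p; // bind this copy as the kernel input for the upcoming launch
    }
};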
+ std::array size_as_buffers; + size_as_buffers[0] = a_grid_desc_ak0_m_ak1[Number<0>{}].GetElementSpaceSize() * + GridwiseGemm::NumATensor / GridwiseGemm::APackedSize * + arg_.Batch; + + std::array size_bs_buffers; + size_bs_buffers[0] = b_grid_desc_bk0_n_bk1[Number<0>{}].GetElementSpaceSize() * + GridwiseGemm::NumBTensor / GridwiseGemm::BPackedSize * + arg_.Batch; + + ck::utility:: + RotatingMemWrapperMultiABD> + rotating_mem(arg_, + stream_config.rotating_count, + size_as_buffers, + size_bs_buffers, + std::array{}); + rotating_mem.Print(); + + auto run_flush_cache = [&]() { + ck::utility::flush_icache(); + rotating_mem.Next(); + // clear c mem + if(arg_.KBatch > 1) + // Note: we multiply by batch since we want to clear the C matrix for + // the whole batch. Untested since we don't have k batching ATM. + // Note: This seems incorrect for non-contiguous memory layouts for C + // (padding, gaps). + HIP_CHECK_ERROR( + hipMemsetAsync(arg_.p_e_grid, + 0, + arg_.Batch * arg_.M * arg_.N * sizeof(EDataType), + stream_config.stream_id_)); + }; + + ave_time = ck::utility::launch_and_time_kernel_with_preprocess( + stream_config, + run_flush_cache, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + arg_, + arg_.compute_ptr_offset_of_batch); + } + else + { + auto clear_workspace = [&]() { + // clear c mem + if(arg.KBatch > 1) + // Note: we multiply by batch since we want to clear the C matrix for + // the whole batch. Untested since we don't have k batching ATM. + // Note: This seems incorrect for non-contiguous memory layouts for C + // (padding, gaps). + HIP_CHECK_ERROR( + hipMemsetAsync(arg.p_e_grid, + 0, + arg.Batch * arg.M * arg.N * sizeof(EDataType), + stream_config.stream_id_)); + }; + + ave_time = ck::utility::launch_and_time_kernel_with_preprocess( + stream_config, + clear_workspace, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + arg, + arg.compute_ptr_offset_of_batch); + } + }; + + constexpr index_t minimum_occupancy = []() { + if constexpr(BlkGemmPipeSched == BlockGemmPipelineScheduler::Interwave) + { + return 2; + } + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + return (MPerBlock * NPerBlock / BlockSize <= 128) ? 
2 : 1; + } + else + { + return 1; + } + }(); + + using ComputePtrOffsetOfStridedBatch = decltype(arg.compute_ptr_offset_of_batch); + if(has_main_k_block_loop) + { + // Tail number always full + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 || + BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + if(arg.KBatch > 1) + { + const auto kernel = kernel_batched_gemm_wmma_cshuffle_v3< + GridwiseGemm, + ComputePtrOffsetOfStridedBatch, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + IsBScaled>; + Run(kernel); + } + else + { + const auto kernel = kernel_batched_gemm_wmma_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + IsBScaled>; + Run(kernel); + } + } + else + { + throw std::runtime_error("Pipeline not implemented"); + } + } + else + { + // Tail number always 1 + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) + { + if(arg.KBatch > 1) + { + const auto kernel = kernel_batched_gemm_wmma_cshuffle_v3< + GridwiseGemm, + ComputePtrOffsetOfStridedBatch, + false, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + IsBScaled>; + Run(kernel); + } + else + { + const auto kernel = kernel_batched_gemm_wmma_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + false, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + IsBScaled>; + Run(kernel); + } + } + } + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + if(!ck::is_gfx11_supported() && !ck::is_gfx12_supported()) + { + return false; + } + + if constexpr(std::is_same_v || + std::is_same_v) + { + if(arg.KBatch > 1 && ck::is_gfx11_supported()) + { + // gfx11 does not support *_atomic_pk_add_f16/bf16 instructions + return false; + } + } + + if constexpr(std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v) + { + if(ck::is_gfx11_supported()) + { + return false; + } + } + + if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding || + GemmSpec == GemmSpecialization::KPadding)) + { + return false; + } + + return GridwiseGemm::CheckValidity(arg); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp index ec7710d066..5657374247 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp @@ -77,6 +77,122 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #endif } +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +#endif + kernel_batched_gemm_wmma_cshuffle_v3( + typename GridwiseGemm::Argument karg, // This works for now but it actually receives a + // DeviceBatchedGemm_Wmma_CShuffleV3::Argument + // argument through implicit conversion to base class! 
+ const ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch) +{ +#if(defined(__gfx11__) || defined(__gfx12__)) +#if defined(__gfx11__) + // gfx11 does not support *_atomic_pk_add_f16/bf16 instructions + using c_data_type = remove_cvref_t>; + if constexpr(!(CGlobalMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd && + (std::is_same_v || + std::is_same_v))) + { +#endif + // The normal approach to batching would be to increase the grid size by just stretching out + // the grid Z dimension (which is the outermost dimension), but this depends on lower level + // functions not directly using the Z dimension for other calculations. As it turns out, k + // batching does rely directly on blockIdx.Z through SplitKBatchOffset. Therefore, for now + // we will use the grid Y dimension for batching. This may be a bit fragile. + const index_t g_idx = amd_wave_read_first_lane(blockIdx.y); + + const long_index_t a_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)); + const long_index_t b_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)); + const long_index_t c_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx)); + + using EpilogueType = + typename std::conditional::type; + + constexpr index_t LDS_size = + GridwiseGemm::template GetSharedMemoryNumberOfByte(); + + __shared__ char p_shared[LDS_size]; + + auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); + + // shift A matrices pointer for splitk + typename GridwiseGemm::AsGridPointer p_as_grid_shift; + static_for<0, GridwiseGemm::NumATensor, 1>{}([&](auto i) { + using ADataType_ = + remove_cvref_t>; + p_as_grid_shift(i) = static_cast(karg.p_as_grid[i]) + + splitk_batch_offset.a_k_split_offset[i] + a_batch_offset; + }); + + // shift B matrices pointer for splitk + typename GridwiseGemm::BsGridPointer p_bs_grid_shift; + static_for<0, GridwiseGemm::NumBTensor, 1>{}([&](auto i) { + using BDataType_ = + remove_cvref_t>; + p_bs_grid_shift(i) = static_cast(karg.p_bs_grid[i]) + + splitk_batch_offset.b_k_split_offset[i] + b_batch_offset; + }); + + auto epilogue_args = EpilogueType{}; + + if constexpr(IsBScaled) + { + const long_index_t b_scale_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetScaleBPtrOffset(g_idx)); + + GridwiseGemm::template Run( + p_as_grid_shift, + p_bs_grid_shift, + karg.p_ds_grid, + karg.p_e_grid + splitk_batch_offset.c_reduce_offset + c_batch_offset, + karg.p_a_scale_grid, + karg.p_b_scale_grid + b_scale_batch_offset + + splitk_batch_offset.scale_b_k_split_offset, + p_shared, + karg, + karg.a_element_op, + karg.b_element_op, + karg.cde_element_op, + epilogue_args); + } + else + { + GridwiseGemm::template Run( + p_as_grid_shift, + p_bs_grid_shift, + karg.p_ds_grid, + karg.p_e_grid + splitk_batch_offset.c_reduce_offset + c_batch_offset, + p_shared, + karg, + karg.a_element_op, + karg.b_element_op, + karg.cde_element_op, + epilogue_args); + } +#if defined(__gfx11__) + } +#endif +#else + ignore = karg; + ignore = compute_ptr_offset_of_batch; +#endif +} + template
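// Illustrative, self-contained example of why one merged __global__ kernel can now
// serve both device ops: `if constexpr(IsBScaled)` discards the unused branch at
// compile time, so members that only exist in the B-scaled argument struct
// (p_b_scale_grid here) are never required from the plain one. The two Arg structs
// are stand-ins for the real GridwiseGemm::Argument types.
#include <cstdio>

struct PlainArg  { const float* p_b_grid; };
struct ScaledArg { const float* p_b_grid; const float* p_b_scale_grid; };

template <bool IsBScaled, typename Arg>
void run_sketch(const Arg& arg)
{
    if constexpr(IsBScaled)
    {
        std::printf("B-scaled path, scale ptr %p\n",
                    static_cast<const void*>(arg.p_b_scale_grid));
    }
    else
    {
        std::printf("plain path\n");
    }
}

int main()
{
    PlainArg  plain{nullptr};
    ScaledArg scaled{nullptr, nullptr};
    run_sketch<false>(plain); // compiles: the scaled branch is discarded
    run_sketch<true>(scaled);
    return 0;
}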