diff --git a/cmake/generic.cmake b/cmake/generic.cmake index 5ef6c042..b611eb77 100644 --- a/cmake/generic.cmake +++ b/cmake/generic.cmake @@ -121,9 +121,14 @@ function(cuda_test TARGET_NAME) cmake_parse_arguments(cuda_test "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) - list(APPEND UT_SRCS "${PROJECT_SOURCE_DIR}/tests/cpp/test_unit.cc" - "${PROJECT_SOURCE_DIR}/src/cuda_utils.cc" - "${PROJECT_SOURCE_DIR}/tests/cpp/common/test_utils.cc" ${cuda_test_SRCS}) + list( + APPEND + UT_SRCS + "${PROJECT_SOURCE_DIR}/tests/cpp/test_unit.cc" + "${PROJECT_SOURCE_DIR}/src/cuda_utils.cc" + "${PROJECT_SOURCE_DIR}/tests/cpp/common/test_utils.cc" + "${PROJECT_SOURCE_DIR}/src/cuda_info.cc" + ${cuda_test_SRCS}) cuda_add_executable(${TARGET_NAME} ${UT_SRCS}) target_link_libraries(${TARGET_NAME} ${cuda_test_DEPS} gtest glog::glog) diff --git a/include/cell/copy/global_to_shared_v2.hpp b/include/cell/copy/global_to_shared_v2.hpp new file mode 100644 index 00000000..aa643ddc --- /dev/null +++ b/include/cell/copy/global_to_shared_v2.hpp @@ -0,0 +1,44 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +#pragma once + +#include "cell/copy/mod.hpp" +#include "types/mod.hpp" + +namespace tilefusion::cell::copy { +using namespace atom; +namespace tl = tile_layout; + +/// @brief The thread-block level API that cooperatively transfers a data tile +/// from global memory to shared memory by all the threads within a +/// thread block. +template +struct GlobalToSharedLoaderV2 { + using Shared = Shared_; + using DType = Shared::DType; + using WarpLayout = WarpLayout_; + + // NOTE: The WarpShape calculated here is for the warp reuse mode `kCont`. + // If you use a different mode, update the WarpShape accordingly. + static_assert((Shared::kRows % WarpLayout ::kRows == 0) && + (Shared::kCols % WarpLayout::kCols == 0), + "The shape of SharedTile must be divisible by the shape of " + "WarpLayout."); + + template + DEVICE void operator()(const Global& src, Shared& dst) {} +}; + +template +struct SharedToGlobalStorerV2 { + using Shared = Shared_; + using DType = Shared::DType; + using WarpLayout = WarpLayout_; + + static const WarpReuse kMode = WarpReuse::kCont; // warp reuse mode + + template + DEVICE void operator()(const Shared& src_, Global& dst_) {} +}; + +} // namespace tilefusion::cell::copy diff --git a/include/cell/copy/mod.hpp b/include/cell/copy/mod.hpp index 05f87272..e67d791b 100644 --- a/include/cell/copy/mod.hpp +++ b/include/cell/copy/mod.hpp @@ -7,6 +7,7 @@ #include "cell/copy/copy_atom.hpp" #include "cell/copy/global_to_register.hpp" #include "cell/copy/global_to_shared.hpp" +#include "cell/copy/global_to_shared_v2.hpp" #include "cell/copy/register.hpp" #include "cell/copy/shared_to_register.hpp" #include "cell/copy/sync.hpp" diff --git a/include/cuda_info.hpp b/include/cuda_info.hpp index f66d283c..8b73a51e 100644 --- a/include/cuda_info.hpp +++ b/include/cuda_info.hpp @@ -35,4 +35,15 @@ std::string GetComputeCapability(); // Returns the maximum shared memory per block for the current device. int GetMaxSharedMemoryPerBlock(); + +void check_gpu_memory(); + +/** + * Configure dynamic shared memory for a kernel if needed + * @param kernel The CUDA kernel function pointer + * @param shared_memory_size Required shared memory size in bytes + */ +template +void configure_dynamic_shared_memory(KernelFunc kernel, int shared_memory_size); + } // namespace tilefusion diff --git a/include/cuda_utils.hpp b/include/cuda_utils.hpp index 4bb83482..c685b0ed 100644 --- a/include/cuda_utils.hpp +++ b/include/cuda_utils.hpp @@ -3,8 +3,6 @@ #pragma once -#include "config.hpp" - #include #include #include @@ -54,17 +52,4 @@ inline void __cublasCheck(const cublasStatus_t err, const char* file, } \ } while (0) -inline void check_gpu_memory() { - size_t free_byte; - size_t total_byte; - CUDA_CHECK(cudaMemGetInfo(&free_byte, &total_byte)); - - double free_db = (double)free_byte; - double total_db = (double)total_byte; - double used_db = total_db - free_db; - printf("GPU memory usage: used = %f MB, free = %f MB, total = %f MB\n", - used_db / 1024.0 / 1024.0, free_db / 1024.0 / 1024.0, - total_db / 1024.0 / 1024.0); -} - } // namespace tilefusion diff --git a/include/types/global.hpp b/include/types/global.hpp index c9998003..b59dfa25 100644 --- a/include/types/global.hpp +++ b/include/types/global.hpp @@ -10,19 +10,17 @@ namespace tilefusion { namespace tl = tile_layout; namespace { - /// @brief Helper for pretty printing a GlobalTile's static shape-related /// information. This printer works ONLY on the host. struct GlobalTilePrettyPrinter { template static HOST void print(std::ostream& out, const Global& tile) { // parameter `tile` here is not used - out << layout_type_to_str(Global::kType) << "(" << Global::kRows << ", " - << Global::kCols << ", " << Global::kRowStride << ", " - << Global::kColStride << "), numel = " << Global::kNumel; + out << "GlobalTile {" << std::endl + << " " << typename Global::Layout{} << std::endl + << "}"; } }; - } // namespace template diff --git a/include/types/shared.hpp b/include/types/shared.hpp index d1a32097..346f9980 100644 --- a/include/types/shared.hpp +++ b/include/types/shared.hpp @@ -12,7 +12,6 @@ namespace tilefusion { namespace tl = tile_layout; namespace { - /// @brief Helper for pretty printing a SharedTile's static shape-related /// information. This printer works ONLY on the host. struct SharedTilePrettyPrinter { @@ -27,6 +26,17 @@ struct SharedTilePrettyPrinter { } }; +/// @brief Helper for pretty printing a SharedTile's static shape-related +/// information. This printer works ONLY on the host. +struct SharedTileV2PrettyPrinter { + template + static HOST void print(std::ostream& out, const Shared& tile) { + // parameter `tile` here is not used + out << "SharedTile {" << std::endl + << " " << typename Shared::Layout{} << std::endl + << "}"; + } +}; } // namespace template +class SharedTileV2 { + public: + using DType = Element_; + using Layout = Layout_; + + // pre-computed values + static constexpr int kRows = Layout::kRows; + static constexpr int kCols = Layout::kCols; + static constexpr int kNumel = Layout::kNumel; + + static constexpr tl::Layout kType = Layout::kType; + static constexpr bool isRowMajor = tl::is_row_major::value; + + DEVICE SharedTileV2() : data_(nullptr), layout_(Layout{}) {} + + DEVICE SharedTileV2(DType* data) : data_(data), layout_(Layout{}) {} + + DEVICE SharedTileV2(const DType* data) + : data_(const_cast(data)), layout_(Layout{}) {} + + DEVICE DType* mutable_data() { return data_; } + + DEVICE const DType* data() const { return data_; } + + HOST_DEVICE const Layout& layout() const { return layout_; } + + // for write access + DEVICE DType& operator()(int x, int y) { return data_[layout_(x, y)]; } + + // for read access + DEVICE const DType& operator()(int x, int y) const { + return data_[layout_(x, y)]; + } + + DEVICE void dump_value() { util::print_tile(data_, layout_); } + + private: + DType* data_; + Layout layout_; +}; + +/// @brief Pretty printer for the static shape information of a SharedTile. +/// Note: This printer function works ONLY on the host. +template +static HOST std::ostream& operator<<( + std::ostream& out, const SharedTileV2& tile) { + SharedTileV2PrettyPrinter::print(out, tile); + return out; +} + } // namespace tilefusion diff --git a/include/types/swizzle.hpp b/include/types/swizzle.hpp index 67d86674..6852908a 100644 --- a/include/types/swizzle.hpp +++ b/include/types/swizzle.hpp @@ -173,19 +173,6 @@ struct SwizzleBaseTileShape { static constexpr int S = 3; }; -template <> -struct SwizzleBaseTileShape { - using DType = float; - - static constexpr int kRows = 8; - static constexpr int kCols = 32; - static constexpr int kNumel = kRows * kCols; - - static constexpr int B = 3; - static constexpr int M = 2; - static constexpr int S = 3; -}; - template requires HalfType struct SwizzleBaseTileShape { @@ -200,6 +187,19 @@ struct SwizzleBaseTileShape { static constexpr int S = 2; }; +template <> +struct SwizzleBaseTileShape { + using DType = float; + + static constexpr int kRows = 8; + static constexpr int kCols = 32; + static constexpr int kNumel = kRows * kCols; + + static constexpr int B = 3; + static constexpr int M = 2; + static constexpr int S = 3; +}; + template <> struct SwizzleBaseTileShape { using DType = float; diff --git a/src/cuda_info.cc b/src/cuda_info.cc index efe15391..9886f397 100644 --- a/src/cuda_info.cc +++ b/src/cuda_info.cc @@ -108,4 +108,32 @@ int GetMaxSharedMemoryPerBlock() { return prop.sharedMemPerBlock; } +void check_gpu_memory() { + size_t free_byte; + size_t total_byte; + CUDA_CHECK(cudaMemGetInfo(&free_byte, &total_byte)); + + double free_db = (double)free_byte; + double total_db = (double)total_byte; + double used_db = total_db - free_db; + printf("GPU memory usage: used = %f MB, free = %f MB, total = %f MB\n", + used_db / 1024.0 / 1024.0, free_db / 1024.0 / 1024.0, + total_db / 1024.0 / 1024.0); +} + +/** + * Configure dynamic shared memory for a kernel if needed + * @param kernel The CUDA kernel function pointer + * @param shared_memory_size Required shared memory size in bytes + */ +template +void configure_dynamic_shared_memory(KernelFunc kernel, + int shared_memory_size) { + if (shared_memory_size > GetMaxSharedMemoryPerBlock()) { + CUDA_CHECK(cudaFuncSetAttribute(kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + shared_memory_size)); + } +} + } // namespace tilefusion diff --git a/tests/cpp/cell/test_g2s_load_v2.cu b/tests/cpp/cell/test_g2s_load_v2.cu new file mode 100644 index 00000000..16ab3256 --- /dev/null +++ b/tests/cpp/cell/test_g2s_load_v2.cu @@ -0,0 +1,102 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "cell/mod.hpp" +#include "common/test_utils.hpp" +#include "cuda_info.hpp" + +#include +#include + +namespace tilefusion::testing { +using namespace cell; +using namespace copy; + +namespace tl = tile_layout; + +namespace { +template +__global__ void copy_g2s(const DType* src_ptr, DType* dst_ptr, Loader& loader, + Storer& storer) { + extern __shared__ __align__(sizeof(double)) unsigned char buf_[]; + auto* buf = reinterpret_cast(buf_); + + Global src(src_ptr); // global memory tile + Shared inter(buf); // shared memory tile + Global dst(dst_ptr); // global memory tile + + // if (thread(0)) { + // printf("\nglobal\n"); + // src.dump_value(); + + // // printf("\nshared\n"); + // // inter.dump_value(); + // // printf("\n"); + // } + + // loader(src, inter); + // __copy_async(); + // __syncthreads(); + + // storer(inter, dst); + // __syncthreads(); +} +} // namespace + +template +void run_test_row_major() { + static const int kThreads = tl::get_numel * 32; + + int numel = kRows * kCols; + thrust::host_vector h_A(numel); + for (int i = 0; i < h_A.size(); ++i) + h_A[i] = static_cast(rand_float(-5.f, 5.f)); + + thrust::device_vector d_B(numel); + thrust::fill(d_B.begin(), d_B.end(), static_cast(0.)); + thrust::device_vector d_A = h_A; + + using Global = GlobalTile>; + + using SharedLayout = + tl::BlockRowMajor, tl::RowMajor<8, 32>>; + using Shared = SharedTileV2; + + std::cout << Global{} << std::endl; + std::cout << Shared{} << std::endl; + + using Loader = GlobalToSharedLoaderV2; + Loader loader; + + using Storer = SharedToGlobalStorerV2; + Storer storer; + + auto kernel = copy_g2s; + int shm_size = kRows * kCols * sizeof(DType); + configure_dynamic_shared_memory(kernel, shm_size); + + dim3 dim_grid(1, 1); + dim3 dim_block(kThreads); + kernel<<>>( + thrust::raw_pointer_cast(d_A.data()), + thrust::raw_pointer_cast(d_B.data()), loader, storer); + cudaDeviceSynchronize(); + + thrust::host_vector h_B(numel); + h_B = d_B; + + // assert_equal(reinterpret_cast(thrust::raw_pointer_cast(h_A.data())), + // reinterpret_cast(thrust::raw_pointer_cast(h_B.data())), + // numel, 1e-5); +} + +TEST(GlobalToSharedLoad, test_row_major_load) { + run_test_row_major<__half, tl::RowMajor<1, 1>, 16, 32>(); +#ifdef CUDA_FP8_AVAILABLE + run_test_row_major<__fp8_e4m3, tl::RowMajor<1, 1>, 16, 32>(); + // run_test_row_major<__fp8_e5m2, tl::RowMajor<1, 1>, 16, 32>(); +#endif +} + +} // namespace tilefusion::testing diff --git a/tests/cpp/cell/test_gemm.cu b/tests/cpp/cell/test_gemm.cu index 0bd6d065..db034da4 100644 --- a/tests/cpp/cell/test_gemm.cu +++ b/tests/cpp/cell/test_gemm.cu @@ -16,13 +16,6 @@ using namespace cell::copy; namespace tl = tile_layout; namespace { -float rand_float(float a = 1e-3, float b = 1) { - float random = ((float)rand()) / (float)RAND_MAX; - float diff = b - a; - float r = random * diff; - return a + r; -} - bool check_correctness(const half* hc1, const float* hc2, int row, int col) { int numel = row * col; bool pass_unittest = true; diff --git a/tests/cpp/common/test_utils.cc b/tests/cpp/common/test_utils.cc index b8bcc1cb..7ba3cdff 100644 --- a/tests/cpp/common/test_utils.cc +++ b/tests/cpp/common/test_utils.cc @@ -27,4 +27,12 @@ void assert_equal(const float* v1, const float* v2, int64_t numel, << "v1[" << i << "] vs. v2[" << i << "] = " << v1[i] << " vs. " << v2[i] << std::endl; } + +float rand_float(float a, float b) { + float random = ((float)rand()) / (float)RAND_MAX; + float diff = b - a; + float r = random * diff; + return a + r; +} + } // namespace tilefusion::testing diff --git a/tests/cpp/common/test_utils.hpp b/tests/cpp/common/test_utils.hpp index 0ec0dc65..56684c43 100644 --- a/tests/cpp/common/test_utils.hpp +++ b/tests/cpp/common/test_utils.hpp @@ -5,7 +5,8 @@ #include "config.hpp" -#include +#include +#include #include #include @@ -14,4 +15,6 @@ namespace tilefusion::testing { template void assert_equal(const T* v1, const T* v2, int64_t numel, float epsilon); +float rand_float(float min = 0.f, float max = 1.f); + } // namespace tilefusion::testing diff --git a/tests/cpp/jit/test_jit.cc b/tests/cpp/jit/test_jit.cc index ebd22091..d006c9bb 100644 --- a/tests/cpp/jit/test_jit.cc +++ b/tests/cpp/jit/test_jit.cc @@ -18,13 +18,6 @@ namespace tilefusion::testing { using namespace tilefusion::jit; namespace { -float rand_float(float a = 1e-3, float b = 1) { - float random = ((float)rand()) / (float)RAND_MAX; - float diff = b - a; - float r = random * diff; - return a + r; -} - std::string generate_add_kernel_source(const std::string& dtype, int numel) { std::stringstream ss; ss << R"(