11 changes: 8 additions & 3 deletions cmake/generic.cmake
@@ -121,9 +121,14 @@ function(cuda_test TARGET_NAME)
cmake_parse_arguments(cuda_test "${options}" "${oneValueArgs}"
"${multiValueArgs}" ${ARGN})

list(APPEND UT_SRCS "${PROJECT_SOURCE_DIR}/tests/cpp/test_unit.cc"
"${PROJECT_SOURCE_DIR}/src/cuda_utils.cc"
"${PROJECT_SOURCE_DIR}/tests/cpp/common/test_utils.cc" ${cuda_test_SRCS})
list(
APPEND
UT_SRCS
"${PROJECT_SOURCE_DIR}/tests/cpp/test_unit.cc"
"${PROJECT_SOURCE_DIR}/src/cuda_utils.cc"
"${PROJECT_SOURCE_DIR}/tests/cpp/common/test_utils.cc"
"${PROJECT_SOURCE_DIR}/src/cuda_info.cc"
${cuda_test_SRCS})

cuda_add_executable(${TARGET_NAME} ${UT_SRCS})
target_link_libraries(${TARGET_NAME} ${cuda_test_DEPS} gtest glog::glog)
44 changes: 44 additions & 0 deletions include/cell/copy/global_to_shared_v2.hpp
@@ -0,0 +1,44 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once

#include "cell/copy/mod.hpp"
#include "types/mod.hpp"

namespace tilefusion::cell::copy {
using namespace atom;
namespace tl = tile_layout;

/// @brief Thread-block level API that cooperatively transfers a data tile
/// from global memory to shared memory using all threads within a
/// thread block.
template <typename Shared_, typename WarpLayout_>
struct GlobalToSharedLoaderV2 {
using Shared = Shared_;
using DType = typename Shared::DType;
using WarpLayout = WarpLayout_;

// NOTE: The WarpShape calculated here is for the warp reuse mode `kCont`.
// If you use a different mode, update the WarpShape accordingly.
static_assert((Shared::kRows % WarpLayout::kRows == 0) &&
(Shared::kCols % WarpLayout::kCols == 0),
"The shape of SharedTile must be divisible by the shape of "
"WarpLayout.");

template <typename Global>
DEVICE void operator()(const Global& src, Shared& dst) {}
};

template <typename Shared_, typename WarpLayout_>
struct SharedToGlobalStorerV2 {
using Shared = Shared_;
using DType = typename Shared::DType;
using WarpLayout = WarpLayout_;

static constexpr WarpReuse kMode = WarpReuse::kCont;  // warp reuse mode

template <typename Global>
DEVICE void operator()(const Shared& src_, Global& dst_) {}
};

} // namespace tilefusion::cell::copy
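Note: both operator() bodies above are empty placeholders in this PR. For orientation only, a minimal, unvectorized sketch of the cooperative copy the loader is documented to perform might look like the following; the flat thread indexing and the element-wise tile accessors are assumptions for illustration, not part of this change.

// A minimal sketch, assuming row-major tiles that expose (row, col)
// accessors; a real loader would issue vectorized, swizzled cp.async
// transfers instead of scalar loads.
template <typename Global, typename Shared>
__device__ void naive_g2s_copy(const Global& src, Shared& dst) {
    // Each thread copies a strided subset of the tile's elements.
    for (int i = threadIdx.x; i < Shared::kNumel; i += blockDim.x) {
        int row = i / Shared::kCols;
        int col = i % Shared::kCols;
        dst(row, col) = src(row, col);
    }
}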
1 change: 1 addition & 0 deletions include/cell/copy/mod.hpp
@@ -7,6 +7,7 @@
#include "cell/copy/copy_atom.hpp"
#include "cell/copy/global_to_register.hpp"
#include "cell/copy/global_to_shared.hpp"
#include "cell/copy/global_to_shared_v2.hpp"
#include "cell/copy/register.hpp"
#include "cell/copy/shared_to_register.hpp"
#include "cell/copy/sync.hpp"
11 changes: 11 additions & 0 deletions include/cuda_info.hpp
@@ -35,4 +35,15 @@ std::string GetComputeCapability();

// Returns the maximum shared memory per block for the current device.
int GetMaxSharedMemoryPerBlock();

// Prints the current device's GPU memory usage (used/free/total) to stdout.
void check_gpu_memory();

/**
 * Configures a kernel's dynamic shared memory limit when the requested size
 * exceeds the device's default per-block maximum.
 * @param kernel The CUDA kernel function pointer
 * @param shared_memory_size Required shared memory size in bytes
 */
template <typename KernelFunc>
void configure_dynamic_shared_memory(KernelFunc kernel, int shared_memory_size);

} // namespace tilefusion
15 changes: 0 additions & 15 deletions include/cuda_utils.hpp
@@ -3,8 +3,6 @@

#pragma once

#include "config.hpp"

#include <cublas_v2.h>
#include <cuda.h>
#include <cuda_runtime.h>
@@ -54,17 +52,4 @@ inline void __cublasCheck(const cublasStatus_t err, const char* file,
} \
} while (0)

inline void check_gpu_memory() {
size_t free_byte;
size_t total_byte;
CUDA_CHECK(cudaMemGetInfo(&free_byte, &total_byte));

double free_db = (double)free_byte;
double total_db = (double)total_byte;
double used_db = total_db - free_db;
printf("GPU memory usage: used = %f MB, free = %f MB, total = %f MB\n",
used_db / 1024.0 / 1024.0, free_db / 1024.0 / 1024.0,
total_db / 1024.0 / 1024.0);
}

} // namespace tilefusion
8 changes: 3 additions & 5 deletions include/types/global.hpp
@@ -10,19 +10,17 @@ namespace tilefusion {
namespace tl = tile_layout;

namespace {

/// @brief Helper for pretty printing a GlobalTile's static shape-related
/// information. This printer works ONLY on the host.
struct GlobalTilePrettyPrinter {
template <typename Global>
static HOST void print(std::ostream& out, const Global& tile) {
// parameter `tile` here is not used
out << layout_type_to_str(Global::kType) << "(" << Global::kRows << ", "
<< Global::kCols << ", " << Global::kRowStride << ", "
<< Global::kColStride << "), numel = " << Global::kNumel;
out << "GlobalTile {" << std::endl
<< " " << typename Global::Layout{} << std::endl
<< "}";
}
};

} // namespace

template <typename Element_, typename Layout_>
63 changes: 62 additions & 1 deletion include/types/shared.hpp
@@ -12,7 +12,6 @@ namespace tilefusion {
namespace tl = tile_layout;

namespace {

/// @brief Helper for pretty printing a SharedTile's static shape-related
/// information. This printer works ONLY on the host.
struct SharedTilePrettyPrinter {
@@ -27,6 +26,17 @@ struct SharedTilePrettyPrinter {
}
};

/// @brief Helper for pretty printing a SharedTileV2's static shape-related
/// information. This printer works ONLY on the host.
struct SharedTileV2PrettyPrinter {
template <typename Shared>
static HOST void print(std::ostream& out, const Shared& tile) {
// parameter `tile` here is not used
out << "SharedTile {" << std::endl
<< " " << typename Shared::Layout{} << std::endl
<< "}";
}
};
} // namespace

template <typename Element_, typename Layout_, const bool kSwizzled_ = false,
@@ -158,4 +168,55 @@ static HOST std::ostream& operator<<(
return out;
}

template <typename Element_, typename Layout_>
class SharedTileV2 {
public:
using DType = Element_;
using Layout = Layout_;

// pre-computed values
static constexpr int kRows = Layout::kRows;
static constexpr int kCols = Layout::kCols;
static constexpr int kNumel = Layout::kNumel;

static constexpr tl::Layout kType = Layout::kType;
static constexpr bool isRowMajor = tl::is_row_major<Layout>::value;

DEVICE SharedTileV2() : data_(nullptr), layout_(Layout{}) {}

DEVICE SharedTileV2(DType* data) : data_(data), layout_(Layout{}) {}

DEVICE SharedTileV2(const DType* data)
: data_(const_cast<DType*>(data)), layout_(Layout{}) {}

DEVICE DType* mutable_data() { return data_; }

DEVICE const DType* data() const { return data_; }

HOST_DEVICE const Layout& layout() const { return layout_; }

// for write access
DEVICE DType& operator()(int x, int y) { return data_[layout_(x, y)]; }

// for read access
DEVICE const DType& operator()(int x, int y) const {
return data_[layout_(x, y)];
}

DEVICE void dump_value() { util::print_tile(data_, layout_); }

private:
DType* data_;
Layout layout_;
};

/// @brief Pretty printer for the static shape information of a SharedTileV2.
/// Note: This printer function works ONLY on the host.
template <typename Element, typename Layout>
static HOST std::ostream& operator<<(
std::ostream& out, const SharedTileV2<Element, Layout>& tile) {
SharedTileV2PrettyPrinter::print(out, tile);
return out;
}

} // namespace tilefusion
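Note: a minimal usage sketch for the new SharedTileV2, assuming a plain row-major layout; the kernel and type names below are hypothetical and for illustration only (the real usage appears in the new test at the end of this PR).

// A minimal sketch: wrap dynamic shared memory in a SharedTileV2 and
// access elements through its layout (data_[layout_(x, y)]).
using Layout = tilefusion::tile_layout::RowMajor<16, 32>;
using Tile = tilefusion::SharedTileV2<float, Layout>;

__global__ void example_kernel() {
    extern __shared__ __align__(sizeof(double)) unsigned char buf_[];
    Tile tile(reinterpret_cast<float*>(buf_));
    if (threadIdx.x == 0) tile(0, 0) = 1.0f;
}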
26 changes: 13 additions & 13 deletions include/types/swizzle.hpp
@@ -173,19 +173,6 @@ struct SwizzleBaseTileShape<Element, 128> {
static constexpr int S = 3;
};

template <>
struct SwizzleBaseTileShape<float, 128> {
using DType = float;

static constexpr int kRows = 8;
static constexpr int kCols = 32;
static constexpr int kNumel = kRows * kCols;

static constexpr int B = 3;
static constexpr int M = 2;
static constexpr int S = 3;
};

template <typename Element>
requires HalfType<Element>
struct SwizzleBaseTileShape<Element, 64> {
@@ -200,6 +187,19 @@ struct SwizzleBaseTileShape<Element, 64> {
static constexpr int S = 2;
};

template <>
struct SwizzleBaseTileShape<float, 128> {
using DType = float;

static constexpr int kRows = 8;
static constexpr int kCols = 32;
static constexpr int kNumel = kRows * kCols;

static constexpr int B = 3;
static constexpr int M = 2;
static constexpr int S = 3;
};

template <>
struct SwizzleBaseTileShape<float, 64> {
using DType = float;
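Note: this hunk only reorders the specializations so that the half-type cases sit together ahead of the float cases; no values change. For readers unfamiliar with the B/M/S parameters, the sketch below shows the conventional Swizzle<B, M, S> bit-XOR offset mapping (the CuTe convention); that tilefusion's swizzle follows this exact convention is an assumption for illustration.

// A minimal sketch of the conventional Swizzle<B, M, S> mapping: XOR bits
// [M+S, M+S+B) of a flat element offset into bits [M, M+B).
template <int B, int M, int S>
constexpr int swizzle_offset(int offset) {
    constexpr int mask = ((1 << B) - 1) << (M + S);
    return offset ^ ((offset & mask) >> S);
}
// E.g. for float at 128 bytes (B = 3, M = 2, S = 3), each row of the 8x32
// base tile permutes its eight 16-byte (4-float) chunks by row index.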
28 changes: 28 additions & 0 deletions src/cuda_info.cc
@@ -108,4 +108,32 @@ int GetMaxSharedMemoryPerBlock() {
return prop.sharedMemPerBlock;
}

void check_gpu_memory() {
size_t free_byte;
size_t total_byte;
CUDA_CHECK(cudaMemGetInfo(&free_byte, &total_byte));

double free_db = (double)free_byte;
double total_db = (double)total_byte;
double used_db = total_db - free_db;
printf("GPU memory usage: used = %f MB, free = %f MB, total = %f MB\n",
used_db / 1024.0 / 1024.0, free_db / 1024.0 / 1024.0,
total_db / 1024.0 / 1024.0);
}

/**
 * Configures a kernel's dynamic shared memory limit when the requested size
 * exceeds the device's default per-block maximum.
 * @param kernel The CUDA kernel function pointer
 * @param shared_memory_size Required shared memory size in bytes
 */
template <typename KernelFunc>
void configure_dynamic_shared_memory(KernelFunc kernel,
int shared_memory_size) {
if (shared_memory_size > GetMaxSharedMemoryPerBlock()) {
CUDA_CHECK(cudaFuncSetAttribute(kernel,
cudaFuncAttributeMaxDynamicSharedMemorySize,
shared_memory_size));
}
}

} // namespace tilefusion
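Note: opting in via cudaFuncAttributeMaxDynamicSharedMemorySize is required whenever a launch requests more dynamic shared memory than the device's default per-block limit. A minimal call-site sketch follows, with a hypothetical kernel for illustration (the real call site is in the new test below); since the template is defined in this .cc file, the sketch assumes the definition is visible to the calling translation unit.

// Hypothetical kernel and launcher, for illustration only.
__global__ void my_kernel() { extern __shared__ unsigned char buf[]; }

void launch_with_large_smem(int shm_bytes) {
    tilefusion::configure_dynamic_shared_memory(my_kernel, shm_bytes);
    my_kernel<<<1, 256, shm_bytes>>>();  // dynamic size passed at launch
}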
102 changes: 102 additions & 0 deletions tests/cpp/cell/test_g2s_load_v2.cu
@@ -0,0 +1,102 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "cell/mod.hpp"
#include "common/test_utils.hpp"
#include "cuda_info.hpp"

#include <thrust/device_vector.h>
#include <thrust/host_vector.h>

namespace tilefusion::testing {
using namespace cell;
using namespace copy;

namespace tl = tile_layout;

namespace {
template <typename DType, typename Global, typename Shared, //
typename Loader, typename Storer>
// Functors are passed by value: __global__ parameters must not bind
// references to host-side objects.
__global__ void copy_g2s(const DType* src_ptr, DType* dst_ptr, Loader loader,
                         Storer storer) {
extern __shared__ __align__(sizeof(double)) unsigned char buf_[];
auto* buf = reinterpret_cast<DType*>(buf_);

Global src(src_ptr); // global memory tile
Shared inter(buf); // shared memory tile
Global dst(dst_ptr); // global memory tile

// if (thread(0)) {
// printf("\nglobal\n");
// src.dump_value();

// // printf("\nshared\n");
// // inter.dump_value();
// // printf("\n");
// }

// loader(src, inter);
// __copy_async();
// __syncthreads();

// storer(inter, dst);
// __syncthreads();
}
} // namespace

template <typename DType, typename WarpLayout, const int kRows, const int kCols>
void run_test_row_major() {
static const int kThreads = tl::get_numel<WarpLayout> * 32;

int numel = kRows * kCols;
thrust::host_vector<DType> h_A(numel);
for (int i = 0; i < h_A.size(); ++i)
h_A[i] = static_cast<DType>(rand_float(-5.f, 5.f));

thrust::device_vector<DType> d_B(numel);
thrust::fill(d_B.begin(), d_B.end(), static_cast<DType>(0.));
thrust::device_vector<DType> d_A = h_A;

using Global = GlobalTile<DType, tl::RowMajor<kRows, kCols>>;

using SharedLayout =
tl::BlockRowMajor<tl::RowMajor<kRows, kCols>, tl::RowMajor<8, 32>>;
using Shared = SharedTileV2<DType, SharedLayout>;

std::cout << Global{} << std::endl;
std::cout << Shared{} << std::endl;

using Loader = GlobalToSharedLoaderV2<Global, WarpLayout>;
Loader loader;

using Storer = SharedToGlobalStorerV2<Shared, WarpLayout>;
Storer storer;

auto kernel = copy_g2s<DType, Global, Shared, Loader, Storer>;
int shm_size = kRows * kCols * sizeof(DType);
configure_dynamic_shared_memory(kernel, shm_size);

dim3 dim_grid(1, 1);
dim3 dim_block(kThreads);
kernel<<<dim_grid, dim_block, shm_size>>>(
thrust::raw_pointer_cast(d_A.data()),
thrust::raw_pointer_cast(d_B.data()), loader, storer);
cudaDeviceSynchronize();

thrust::host_vector<DType> h_B(numel);
h_B = d_B;

// assert_equal(reinterpret_cast<DType*>(thrust::raw_pointer_cast(h_A.data())),
// reinterpret_cast<DType*>(thrust::raw_pointer_cast(h_B.data())),
// numel, 1e-5);
}

TEST(GlobalToSharedLoad, test_row_major_load) {
run_test_row_major<__half, tl::RowMajor<1, 1>, 16, 32>();
#ifdef CUDA_FP8_AVAILABLE
run_test_row_major<__fp8_e4m3, tl::RowMajor<1, 1>, 16, 32>();
// run_test_row_major<__fp8_e5m2, tl::RowMajor<1, 1>, 16, 32>();
#endif
}

} // namespace tilefusion::testing