Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -135,4 +135,34 @@ void abquant_quantgrouped_instance_factory(
BQuantGroupSize,
ck_tile::QuantType::ABQuantGrouped>(arg_parser);
};
lut[hash_multiple_strings({"fp8",
"abquant",
"non-preshuffleb",
"preshufflequant",
"1x1x128"})] = [](const ck_tile::ArgParser& arg_parser) {
using AQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
using BQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
using TypeConfig =
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
return run_gemm_example_prec_type<GemmConfigPreshuffleBQuantPrefill<ck_tile::fp8_t>,
TypeConfig,
AQuantGroupSize,
BQuantGroupSize,
ck_tile::QuantType::ABQuantGrouped>(arg_parser);
};
lut[hash_multiple_strings({"fp8",
"abquant",
"non-preshuffleb",
"preshufflequant",
"1x128x128"})] = [](const ck_tile::ArgParser& arg_parser) {
using AQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
using BQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
using TypeConfig =
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
return run_gemm_example_prec_type<GemmConfigPreshuffleBQuantPrefill<ck_tile::fp8_t>,
TypeConfig,
AQuantGroupSize,
BQuantGroupSize,
ck_tile::QuantType::ABQuantGrouped>(arg_parser);
};
}
12 changes: 7 additions & 5 deletions example/ck_tile/38_block_scale_gemm/gemm_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,8 @@ struct GemmConfigBase
static constexpr ck_tile::index_t TileParitionerGroupNum = 8;
static constexpr ck_tile::index_t TileParitionerM01 = 4;

static constexpr bool PreshuffleQuant = false;
static constexpr bool APreshuffleQuant = false;
static constexpr bool BPreshuffleQuant = false;
static constexpr bool PreshuffleB = false;
static constexpr bool DoubleSmemBuffer = false;
static constexpr bool TiledMMAPermuteN = false;
Expand Down Expand Up @@ -128,7 +129,8 @@ struct GemmConfigPreshuffleQuantDecode : public GemmConfigBase
static constexpr ck_tile::index_t K_Warp_Tile =
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile, true>();

static constexpr bool PreshuffleQuant = true;
static constexpr bool APreshuffleQuant = true;
static constexpr bool BPreshuffleQuant = true;
};

template <typename PrecType>
Expand Down Expand Up @@ -158,7 +160,7 @@ template <typename PrecType>
struct GemmConfigPreshuffleB_PreshuffleBQuant_Decode
: public GemmConfigPreshuffleB_BQuant_Decode<PrecType>
{
static constexpr bool PreshuffleQuant = true;
static constexpr bool BPreshuffleQuant = true;
};

template <typename PrecType>
Expand Down Expand Up @@ -189,7 +191,7 @@ template <typename PrecType>
struct GemmConfigPreshuffleB_PreshuffleBQuant_Prefill
: public GemmConfigPreshuffleB_BQuant_Prefill<PrecType>
{
static constexpr bool PreshuffleQuant = true;
static constexpr bool BPreshuffleQuant = true;
};

template <typename PrecType>
Expand Down Expand Up @@ -241,7 +243,7 @@ struct GemmConfigABQuantPrefill : public GemmConfigQuantPrefill<PrecType>
template <typename PrecType>
struct GemmConfigPreshuffleBQuantPrefill : public GemmConfigQuantPrefill<PrecType>
{
static constexpr bool PreshuffleQuant = true;
static constexpr bool BPreshuffleQuant = true;
};

template <typename PrecType>
Expand Down
42 changes: 19 additions & 23 deletions example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc
Original file line number Diff line number Diff line change
Expand Up @@ -50,14 +50,15 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str
using GemmTraits = ck_tile::TileGemmQuantTraits<GemmConfig::kPadM,
GemmConfig::kPadN,
GemmConfig::kPadK,
GemmConfig::PreshuffleQuant,
GemmConfig::APreshuffleQuant,
GemmConfig::BPreshuffleQuant,
GemmConfig::PreshuffleB,
ALayout,
BLayout,
CLayout,
QuantMode,
AQLayout, // for AQLayout
BQLayout, // for BQLayout
AQLayout,
BQLayout,
transpose_c,
GemmConfig::DoubleSmemBuffer>;

Expand All @@ -73,7 +74,7 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str
GemmConfig::PreshuffleB == true,
ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2<GemmPipelineProblem>,
std::conditional_t<
QuantMode == ck_tile::QuantType::AQuantGrouped && GemmConfig::PreshuffleQuant == true,
QuantMode == ck_tile::QuantType::AQuantGrouped && GemmConfig::APreshuffleQuant == true,
ck_tile::BaseGemmPipelineAgBgCrCompV3<GemmPipelineProblem>,
std::conditional_t<
QuantMode == ck_tile::QuantType::AQuantGrouped,
Expand Down Expand Up @@ -147,7 +148,7 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str
has_hot_loop_v,
tail_number_v>>>>;
using AQuantPipeline =
std::conditional_t<GemmConfig::PreshuffleQuant,
std::conditional_t<GemmConfig::APreshuffleQuant,
ck_tile::AQuantGemmPipelineAgBgCrCompV3<PipelineProblem>,
ck_tile::AQuantGemmPipelineAgBgCrMem<PipelineProblem>>;

Expand Down Expand Up @@ -391,8 +392,10 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
std::cout << " Acc_Type = " << ck_tile::DataTypeTraits<typename TypeConfig::AccDataType>::name
<< " C_Type = " << ck_tile::DataTypeTraits<typename TypeConfig::CDataType>::name
<< " QuantMode = " << quant_type_to_string(QuantMode)
<< " PreshuffleQuant = " << (GemmConfig::PreshuffleQuant ? "true" : "false") << " : "
<< " PreshuffleB = " << (GemmConfig::PreshuffleB ? "true" : "false") << " : "
<< " APreshuffleQuant = " << (GemmConfig::APreshuffleQuant ? "true" : "false")
<< " : "
<< " BPreshuffleQuant = " << (GemmConfig::BPreshuffleQuant ? "true" : "false")
<< " : " << " PreshuffleB = " << (GemmConfig::PreshuffleB ? "true" : "false") << " : "
<< ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< std::endl;

Expand Down Expand Up @@ -537,21 +540,13 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser,
// Create BQ tensor with appropriate shape
std::unique_ptr<ck_tile::HostTensor<BQDataType>> bq_tensor_ptr = nullptr;
if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped ||
QuantMode == ck_tile::QuantType::RowColQuant)
{
bq_tensor_ptr = std::make_unique<ck_tile::HostTensor<BQDataType>>(
ck_tile::host_tensor_descriptor(BQK, BQN, stride_BQ, is_row_major(bq_layout)));
}
else if constexpr(QuantMode == ck_tile::QuantType::ABQuantGrouped)
QuantMode == ck_tile::QuantType::ABQuantGrouped ||
QuantMode == ck_tile::QuantType::RowColQuant ||
QuantMode == ck_tile::QuantType::TensorQuant)
{
bq_tensor_ptr = std::make_unique<ck_tile::HostTensor<BQDataType>>(
ck_tile::host_tensor_descriptor(BQK, BQN, stride_BQ, is_row_major(bq_layout)));
}
else if constexpr(QuantMode == ck_tile::QuantType::TensorQuant)
{
bq_tensor_ptr = std::make_unique<ck_tile::HostTensor<BQDataType>>(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could I know why we need to delete this part? TensorQuant still need a 1,1 quant tensor.

ck_tile::host_tensor_descriptor(1, 1, stride_BQ, is_row_major(bq_layout)));
}

std::random_device rd;
std::mt19937 gen(rd());
Expand Down Expand Up @@ -694,7 +689,7 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser,
QuantMode == ck_tile::QuantType::RowColQuant ||
QuantMode == ck_tile::QuantType::TensorQuant)
{
if constexpr(GemmConfig::PreshuffleQuant)
if constexpr(GemmConfig::APreshuffleQuant)
{
ck_tile::HostTensor<AQDataType> aq_shuffle_host =
ck_tile::shuffle_aq(aq_tensor_ptr.get(), GemmConfig::K_Tile / AQuantGroupSize::kK);
Expand Down Expand Up @@ -753,7 +748,7 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser,
ck_tile::HostTensor<BQDataType> bq_permuted_host =
ck_tile::bq_permuteN<GemmConfig>(*bq_tensor_ptr, BQuantGroupSize::kN);

if constexpr(GemmConfig::PreshuffleQuant)
if constexpr(GemmConfig::BPreshuffleQuant)
{
ck_tile::HostTensor<BQDataType> bq_shuffle_host = ck_tile::shuffle_bq(
&bq_permuted_host, GemmConfig::K_Tile / BQuantGroupSize::kK);
Expand All @@ -764,7 +759,7 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser,
bq_dev_buf_ptr->ToDevice(bq_permuted_host.data());
}
}
else if constexpr(GemmConfig::PreshuffleQuant)
else if constexpr(GemmConfig::BPreshuffleQuant)
{
ck_tile::HostTensor<BQDataType> bq_shuffle_host =
ck_tile::shuffle_bq(bq_tensor_ptr.get(), GemmConfig::K_Tile / BQuantGroupSize::kK);
Expand Down Expand Up @@ -945,7 +940,7 @@ int run_gemm_example_prec_type(const ck_tile::ArgParser& arg_parser)

if constexpr((QuantMode == ck_tile::QuantType::AQuantGrouped ||
QuantMode == ck_tile::QuantType::ABQuantGrouped) &&
!GemmConfig::PreshuffleQuant && !GemmConfig::PreshuffleB)
!GemmConfig::APreshuffleQuant && !GemmConfig::PreshuffleB)
{
if(a_layout == "R" && b_layout == "R")
{
Expand All @@ -966,7 +961,8 @@ int run_gemm_example_prec_type(const ck_tile::ArgParser& arg_parser)
arg_parser, Col{}, Row{}, Row{}, Col{}, Row{});
}
}
if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped && !GemmConfig::PreshuffleQuant)
if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped &&
!GemmConfig::APreshuffleQuant)
{
if(a_layout == "C" && b_layout == "C")
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,9 +96,9 @@ struct AQPickerCommon : public BlockGemmQuantBase
if constexpr(Traits::TransposeC) // transposed C
{
index_t reg_offset =
Traits::PreshuffleQuant ? mIter : mIter * Traits::AQPerBlock + kQScale;
Traits::APreshuffleQuant ? mIter : mIter * Traits::AQPerBlock + kQScale;
auto scale_reg = aq_block_tensor.get_thread_buffer()[reg_offset];
if constexpr(Traits::PreshuffleQuant)
if constexpr(Traits::APreshuffleQuant)
{
auto pull_from_lane =
(__lane_id() & (Traits::WarpGemm::kN - 1)) * Traits::AQPerBlock + kQScale;
Expand All @@ -121,7 +121,7 @@ struct AQPickerCommon : public BlockGemmQuantBase
}
else
{
if constexpr(Traits::PreshuffleQuant)
if constexpr(Traits::APreshuffleQuant)
{
// A view is created on top of the preshuffled AQ, where each row of
// the view is composed of a row from a warp tile within an AQ block
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,8 @@ struct BlockGemmWeightPreshuffleABQuantARegBRegCReg
static constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WarpGemm::kN);
static constexpr index_t KIterPerWarp = KPerBlock / WarpGemm::kK;

static constexpr bool PreshuffleQuant = Problem::Traits::PreshuffleQuant;
static constexpr bool APreshuffleQuant = Problem::Traits::APreshuffleQuant;
static constexpr bool BPreshuffleQuant = Problem::Traits::BPreshuffleQuant;

static constexpr index_t QScalesPerBlockRow =
integer_divide_ceil(KPerBlock, BQuantGroupSize::kK);
Expand Down Expand Up @@ -127,9 +128,9 @@ struct BlockGemmWeightPreshuffleABQuantARegBRegCReg
using CDataType = remove_cvref_t<typename Problem::CDataType>;
using ComputeDataType = remove_cvref_t<typename Problem::ComputeDataType>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>; // TileFlatmmShape
using QuantGroupSize = remove_cvref_t<typename Problem::BQuantGroupSize>;
using BQuantGroupSize = remove_cvref_t<typename Problem::BQuantGroupSize>;

static_assert(QuantGroupSize::kM == 1, "only N/K blocks for BQuant preshuffle kernel!");
static_assert(BQuantGroupSize::kM == 1, "only N/K blocks for BQuant preshuffle kernel!");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If it is named as the BQuantGroupSize, why it has kM in here?


static constexpr auto I0 = number<0>();
static constexpr auto I1 = number<1>();
Expand Down Expand Up @@ -162,12 +163,12 @@ struct BlockGemmWeightPreshuffleABQuantARegBRegCReg
static constexpr auto MIter_2nd_last =
(MIterPerWarp >= 2) ? MIterPerWarp - 2 : MIterPerWarp - 1;

static constexpr index_t KPerBlockBQ = KPerBlock / QuantGroupSize::kK;
static constexpr index_t KPerBlockBQ = KPerBlock / BQuantGroupSize::kK;

static constexpr index_t QScalesPerBlockRow =
integer_divide_ceil(KPerBlock, QuantGroupSize::kK); // 128 / 128 = 1
integer_divide_ceil(KPerBlock, BQuantGroupSize::kK); // 128 / 128 = 1
static constexpr index_t QScalesPerWarpGemmRow =
integer_divide_ceil(WG::kK, QuantGroupSize::kK);
integer_divide_ceil(WG::kK, BQuantGroupSize::kK);

static constexpr index_t KIterPerQScale = KIterPerWarp / QScalesPerBlockRow; // 8 / 1 = 8
static constexpr index_t DsReadPreload = 2; // default 2, preload 2 ds read
Expand Down Expand Up @@ -289,9 +290,9 @@ struct BlockGemmWeightPreshuffleABQuantARegBRegCReg
CBlockTensor::PackedSize>{};

index_t reg_offset = [&]() {
if constexpr(QuantGroupSize::kN >= (NWarp * WG::kN))
if constexpr(BQuantGroupSize::kN >= (NWarp * WG::kN))
{
return (nIter * NWarp * WG::kN) / QuantGroupSize::kN * KPerBlockBQ +
return (nIter * NWarp * WG::kN) / BQuantGroupSize::kN * KPerBlockBQ +
kQScale;
}
else
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,9 @@ struct BlockGemmWeightPreshuffleBQuantARegBRegCReg
using CDataType = remove_cvref_t<typename Problem::CDataType>;
using ComputeDataType = remove_cvref_t<typename Problem::ComputeDataType>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>; // TileFlatmmShape
using QuantGroupSize = remove_cvref_t<typename Problem::BQuantGroupSize>;
using BQuantGroupSize = remove_cvref_t<typename Problem::BQuantGroupSize>;

static_assert(QuantGroupSize::kM == 1, "only N/K blocks for BQuant preshuffle kernel!");
static_assert(BQuantGroupSize::kM == 1, "only N/K blocks for BQuant preshuffle kernel!");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same as the previous file comment


static constexpr auto I0 = number<0>();
static constexpr auto I1 = number<1>();
Expand All @@ -53,7 +53,7 @@ struct BlockGemmWeightPreshuffleBQuantARegBRegCReg

static constexpr index_t kBlockSize = Problem::kBlockSize;

static constexpr bool PreshuffleQuant = Problem::Traits::PreshuffleQuant;
static constexpr bool BPreshuffleQuant = Problem::Traits::BPreshuffleQuant;

static constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
static constexpr index_t NIterPerWarp =
Expand All @@ -63,12 +63,12 @@ struct BlockGemmWeightPreshuffleBQuantARegBRegCReg
static constexpr auto MIter_2nd_last =
(MIterPerWarp >= 2) ? MIterPerWarp - 2 : MIterPerWarp - 1;

static constexpr index_t KPerBlockBQ = KPerBlock / QuantGroupSize::kK;
static constexpr index_t KPerBlockBQ = KPerBlock / BQuantGroupSize::kK;

static constexpr index_t QScalesPerBlockRow =
integer_divide_ceil(KPerBlock, QuantGroupSize::kK);
integer_divide_ceil(KPerBlock, BQuantGroupSize::kK);
static constexpr index_t QScalesPerWarpGemmRow =
integer_divide_ceil(WG::kK, QuantGroupSize::kK);
integer_divide_ceil(WG::kK, BQuantGroupSize::kK);

static constexpr index_t KIterPerQScale = KIterPerWarp / QScalesPerBlockRow;
static constexpr index_t DsReadPreload = 2; // default 2, preload 2 ds read
Expand Down Expand Up @@ -173,7 +173,7 @@ struct BlockGemmWeightPreshuffleBQuantARegBRegCReg
c_warp_y_index_zeros)) /
CBlockTensor::PackedSize>{};

if constexpr(PreshuffleQuant)
if constexpr(BPreshuffleQuant)
{
constexpr index_t reg_offset = nIter;
auto pull_from_lane = (__lane_id() & (WG::kN - 1)) * KPerBlockBQ + kQScale;
Expand Down Expand Up @@ -205,9 +205,10 @@ struct BlockGemmWeightPreshuffleBQuantARegBRegCReg
else
{
index_t reg_offset = [&]() {
if constexpr(QuantGroupSize::kN >= (NWarp * WG::kN))
if constexpr(BQuantGroupSize::kN >= (NWarp * WG::kN))
{
return (nIter * NWarp * WG::kN) / QuantGroupSize::kN * KPerBlockBQ +
return (nIter * NWarp * WG::kN) / BQuantGroupSize::kN *
KPerBlockBQ +
kQScale;
}
else
Expand Down
Loading