Commit 61c41a0

Author: Aman Gupta (committed)
Commit message: use iter_k as 512, cleanup
1 parent a1672f6 commit 61c41a0

4 files changed: +75 additions, -89 deletions

ggml/src/ggml-cuda/common.cuh

Lines changed: 0 additions & 5 deletions
@@ -719,11 +719,6 @@ __device__ __forceinline__ uint8_t ggml_cuda_float_to_fp4_e2m1(float x, float e)
     // Positive LUT
     static constexpr float pos_lut[8] = { 0.0f, 0.5f, 1.0f, 1.5f, 2.0f, 3.0f, 4.0f, 6.0f };
 
-    // Saturate to max representable magnitude
-    if (ax > pos_lut[7]) {
-        ax = pos_lut[7];
-    }
-
     int best_i = 0;
     float best_err = fabsf(ax - pos_lut[0]);
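
The removed saturation branch is redundant: the nearest-entry search that follows it already maps any magnitude above 6.0 to the last LUT entry. Below is a minimal host-side sketch of that argument, assuming the same E2M1 positive LUT; the helper name and standalone scaffolding are illustrative, not part of ggml.

    // Illustrative sketch: nearest-value search over the E2M1 positive LUT.
    // Any |x| larger than pos_lut[7] (= 6.0f) still ends up at index 7, so an
    // explicit clamp before the search changes nothing.
    #include <cmath>
    #include <cstdio>

    static int fp4_e2m1_pos_index(float ax) {
        static constexpr float pos_lut[8] = { 0.0f, 0.5f, 1.0f, 1.5f, 2.0f, 3.0f, 4.0f, 6.0f };
        int   best_i   = 0;
        float best_err = fabsf(ax - pos_lut[0]);
        for (int i = 1; i < 8; ++i) {
            const float err = fabsf(ax - pos_lut[i]);
            if (err < best_err) {
                best_err = err;
                best_i   = i;
            }
        }
        return best_i;
    }

    int main() {
        printf("%d %d\n", fp4_e2m1_pos_index(5.5f), fp4_e2m1_pos_index(1000.0f)); // prints: 7 7
        return 0;
    }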

ggml/src/ggml-cuda/mmq.cu

Lines changed: 3 additions & 4 deletions
@@ -140,7 +140,7 @@ void ggml_cuda_mul_mat_q(
         // Stride depends on quantization format
         const int64_t s12 = use_native_mxfp4 ?
             ne11 * ne10_padded * sizeof(block_fp4_mmq) /
-                (4 * QK_MXFP4 * sizeof(int)) // block_fp4_mmq holds 128 values
+                (8 * QK_MXFP4 * sizeof(int)) // block_fp4_mmq holds 256 values (8 blocks of 32)
             :
             ne11 * ne10_padded * sizeof(block_q8_1) / (QK8_1 * sizeof(int));
         const int64_t s13 = ne12*s12;
@@ -200,9 +200,8 @@
         CUDA_CHECK(cudaGetLastError());
     }
 
-    const int64_t s12 = use_native_mxfp4 ?
-        ne11 * ne10_padded * sizeof(block_fp4_mmq) / (4 * QK_MXFP4 * sizeof(int)) :
-        ne11 * ne10_padded * sizeof(block_q8_1) / (QK8_1 * sizeof(int));
+    const int64_t s12 = use_native_mxfp4 ? ne11 * ne10_padded * sizeof(block_fp4_mmq) / (8 * QK_MXFP4 * sizeof(int)) :
+                                           ne11 * ne10_padded * sizeof(block_q8_1) / (QK8_1 * sizeof(int));
     const int64_t s13 = ne12*s12;
 
     // Note that ne02 is used instead of ne12 because the number of y channels determines the z dimension of the CUDA grid.
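
As a sanity check on the new divisor, here is a small host-side sketch, assuming QK_MXFP4 == 32 and the 144-byte block_fp4_mmq declared in mmq.cuh; the struct mock-up and the example extents are illustrative only.

    // With 8 * QK_MXFP4 = 256 values per block_fp4_mmq, the divisor turns a byte
    // count into "ints per 256-value block", matching the units of the Q8_1 formula.
    #include <cstdint>
    #include <cstdio>

    constexpr int64_t QK_MXFP4 = 32;                               // values per MXFP4 block (assumed)
    struct block_fp4_mmq { uint32_t d4[4]; int8_t qs[4 * 32]; };   // 144 bytes, mirrors mmq.cuh

    int main() {
        const int64_t ne11 = 64, ne10_padded = 1024;               // example extents
        const int64_t s12  = ne11 * ne10_padded * sizeof(block_fp4_mmq) / (8 * QK_MXFP4 * sizeof(int));
        // 65536 values / 256 per block = 256 blocks; 256 blocks * (144/4) ints = 9216 ints
        printf("s12 = %lld ints\n", (long long) s12);              // prints: s12 = 9216 ints
        return 0;
    }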

ggml/src/ggml-cuda/mmq.cuh

Lines changed: 51 additions & 57 deletions
@@ -11,6 +11,7 @@ using namespace ggml_cuda_mma;
 
 #define MMQ_DP4A_MAX_BATCH_SIZE 64 // Max. batch size to use for dp4a MMQ kernels when FP16 tensor cores are available.
 #define MMQ_ITER_K 256
+#define MMQ_ITER_K_MXFP4_FP4 512
 #define MMQ_NWARPS 8
 
 typedef void (*load_tiles_mmq_t)(const char * __restrict__ x, int * x_tile, const int kbx0, const int i_max, const int stride);
@@ -46,13 +47,13 @@ struct block_q8_1_mmq {
 };
 
 struct block_fp4_mmq {
-    uint32_t d4[2]; // 1 8 bit (e8m0) scale per 32 values, packed LSB as d0-d1 in d4[0] and d4[1]
-    int8_t qs[2 * 32]; // 128 values to 4 bit each (4 blocks)
+    uint32_t d4[4]; // 8 E8M0 scales (1 per 32 values), 2 packed per uint32: d4[0]={s0,s1}, d4[1]={s2,s3}, etc.
+    int8_t qs[4 * 32]; // 256 FP4 values packed as 4-bit pairs (2 per byte), 8 blocks of 32 values
 };
 
 static_assert(sizeof(block_q8_1_mmq) == 4*QK8_1 + 4*sizeof(half2), "Unexpected block_q8_1_mmq size");
 static_assert(sizeof(block_q8_1_mmq) == 4*sizeof(block_q8_1), "Unexpected block_q8_1_mmq size");
-static_assert(sizeof(block_fp4_mmq) == 72, "Unexpected block_fp4_mmq size");
+static_assert(sizeof(block_fp4_mmq) == sizeof(block_q8_1_mmq), "Unexpected block_fp4_mmq size");
 
 static mmq_q8_1_ds_layout mmq_get_q8_1_ds_layout(const ggml_type type_x) {
     switch (type_x) {
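
The byte arithmetic behind the new static_assert, as a compile-time sketch; it assumes QK8_1 == 32 and a 2-byte half, and the mock structs are stand-ins rather than the real definitions.

    #include <cstdint>
    // Old layout: 2*4 + 2*32 = 72 bytes (the literal checked by the old assert).
    struct old_block_fp4_mmq { uint32_t d4[2]; int8_t qs[2 * 32]; };
    // New layout: 4*4 + 4*32 = 144 bytes, the same as block_q8_1_mmq
    // (4*QK8_1 quant bytes + 4 half2 scales = 128 + 16 = 144 bytes).
    struct new_block_fp4_mmq { uint32_t d4[4]; int8_t qs[4 * 32]; };
    static_assert(sizeof(old_block_fp4_mmq) ==  72, "old size");
    static_assert(sizeof(new_block_fp4_mmq) == 144, "new size, equal to block_q8_1_mmq");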
@@ -136,6 +137,14 @@ static int get_mmq_y_host(const int cc) {
         ((GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA) ? 128 : 64);
 }
 
+static constexpr __device__ int get_iter_k([[maybe_unused]] const ggml_type type) {
+#if defined(BLACKWELL_MMA_AVAILABLE)
+    return type == GGML_TYPE_MXFP4 ? MMQ_ITER_K_MXFP4_FP4 : MMQ_ITER_K;
+#else
+    return MMQ_ITER_K;
+#endif // defined(BLACKWELL_MMA_AVAILABLE)
+}
+
 static constexpr __device__ int get_mmq_y_device() {
 #if defined(GGML_USE_HIP)
 #if defined(RDNA1)
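
The effect of get_iter_k() shows up in blocks_per_iter = ITER_K / qk further down. A small compile-time sketch, assuming qk == 32 for MXFP4:

    // On Blackwell, MXFP4 now consumes 512 values of K per iteration instead of 256,
    // i.e. 16 blocks of 32 values per iteration rather than 8.
    constexpr int MMQ_ITER_K           = 256;
    constexpr int MMQ_ITER_K_MXFP4_FP4 = 512;
    constexpr int qk_mxfp4             = 32;   // QK_MXFP4 (assumed)
    static_assert(MMQ_ITER_K           / qk_mxfp4 ==  8, "pre-change blocks per iteration");
    static_assert(MMQ_ITER_K_MXFP4_FP4 / qk_mxfp4 == 16, "post-change blocks per iteration");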
@@ -198,7 +207,7 @@ static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml
 }
 
 #define MMQ_MMA_TILE_X_K_Q8_0 (2*MMQ_TILE_NE_K + 2*MMQ_TILE_NE_K/QI8_0 + 4)
-#define MMQ_MMA_TILE_X_K_FP4 (MMQ_TILE_NE_K + MMQ_TILE_NE_K / QI8_0)
+#define MMQ_MMA_TILE_X_K_FP4 (2*MMQ_TILE_NE_K + 2*MMQ_TILE_NE_K/QI8_0 + 4)
 #define MMQ_MMA_TILE_X_K_Q8_1 (2*MMQ_TILE_NE_K + 2*MMQ_TILE_NE_K/QI8_0 + 4)
 #define MMQ_MMA_TILE_X_K_Q2_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K + 4)
 #define MMQ_MMA_TILE_X_K_Q3_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K/2 + 4)
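
Assuming MMQ_TILE_NE_K == 32 and QI8_0 == 8 (values defined elsewhere in these headers, not quoted in this diff), the FP4 tile stride grows from 36 to 76 ints and keeps the "% 8 == 4" property asserted in the next hunk; a compile-time sketch:

    constexpr int MMQ_TILE_NE_K = 32;  // assumed
    constexpr int QI8_0         = 8;   // assumed
    constexpr int old_fp4_k = MMQ_TILE_NE_K + MMQ_TILE_NE_K / QI8_0;          // 32 + 4 = 36
    constexpr int new_fp4_k = 2*MMQ_TILE_NE_K + 2*MMQ_TILE_NE_K/QI8_0 + 4;    // 64 + 8 + 4 = 76
    static_assert(old_fp4_k == 36 && new_fp4_k == 76, "stride in ints per tile row");
    static_assert(new_fp4_k % 8 == 4, "satisfies the 'Wrong padding.' static_assert");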
@@ -209,7 +218,7 @@ static_assert(MMQ_MMA_TILE_X_K_Q8_1 % 8 == 4, "Wrong padding.");
 static_assert(MMQ_MMA_TILE_X_K_Q2_K % 8 == 4, "Wrong padding.");
 static_assert(MMQ_MMA_TILE_X_K_Q3_K % 8 == 4, "Wrong padding.");
 static_assert(MMQ_MMA_TILE_X_K_Q6_K % 8 == 4, "Wrong padding.");
-static_assert(MMQ_MMA_TILE_X_K_FP4 % 8 == 4, "Wrong padding.");
+static_assert(MMQ_MMA_TILE_X_K_FP4 % 8 == 4, "Wrong padding.");
 
 static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
     switch (type) {
@@ -218,11 +227,8 @@ static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
         case GGML_TYPE_Q5_0: return MMQ_MMA_TILE_X_K_Q8_0;
         case GGML_TYPE_Q5_1: return MMQ_MMA_TILE_X_K_Q8_1;
         case GGML_TYPE_Q8_0: return MMQ_MMA_TILE_X_K_Q8_0;
-#ifdef BLACKWELL_MMA_AVAILABLE
-        case GGML_TYPE_MXFP4: return MMQ_MMA_TILE_X_K_FP4;
-#else
+        // tile sizes are the same for Q8_1 and FP4 for blackwell
         case GGML_TYPE_MXFP4: return MMQ_MMA_TILE_X_K_Q8_1;
-#endif
         case GGML_TYPE_Q2_K: return MMQ_MMA_TILE_X_K_Q2_K;
         case GGML_TYPE_Q3_K: return MMQ_MMA_TILE_X_K_Q3_K;
         case GGML_TYPE_Q4_K: return MMQ_MMA_TILE_X_K_Q8_1;
@@ -242,7 +248,8 @@ static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
 
 // block_q8_1_mmq has (128 8-bit ints == 32 32-bit ints + 4 32-bit scales)
 #define MMQ_TILE_Y_K (MMQ_TILE_NE_K + MMQ_TILE_NE_K/QI8_1)
-#define MMQ_TILE_Y_FP4_K MMQ_TILE_Y_K / 2
+//#define MMQ_TILE_Y_FP4_K MMQ_TILE_Y_K / 2
+#define MMQ_TILE_Y_FP4_K MMQ_TILE_Y_K
 
 static int mmq_get_granularity_host(const int mmq_x, const int cc) {
     if (amd_mfma_available(cc) || amd_wmma_available(cc)) {
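
With the new equal-size Y blocks, MMQ_TILE_Y_FP4_K can simply alias MMQ_TILE_Y_K. Assuming MMQ_TILE_NE_K == 32 and QI8_1 == 8, the Y tile stride works out to 36 ints, exactly one 144-byte block per column; a compile-time sketch:

    constexpr int MMQ_TILE_NE_K = 32;  // assumed
    constexpr int QI8_1         = 8;   // assumed
    constexpr int MMQ_TILE_Y_K  = MMQ_TILE_NE_K + MMQ_TILE_NE_K / QI8_1;  // 32 + 4 = 36 ints
    static_assert(MMQ_TILE_Y_K * 4 == 144, "36 ints == sizeof(block_q8_1_mmq) == sizeof(block_fp4_mmq)");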
@@ -785,15 +792,14 @@ static __device__ __forceinline__ void load_tiles_mxfp4_fp4(const char * __restr
     constexpr int warp_size = ggml_cuda_get_physical_warp_size();
 
     int * x_qs = (int *) x_tile;
-    uint32_t * x_sc = (uint32_t *) (x_qs + MMQ_TILE_NE_K);
+    uint32_t * x_sc = (uint32_t *) (x_qs + 2 * MMQ_TILE_NE_K);
 
     const int txi = threadIdx.x;
 
-    // Use all 32 threads: 8 threads per row, process 4 rows per warp per iteration
-    constexpr int threads_per_row = 8; // 8 blocks per row
-    constexpr int rows_per_warp = warp_size / threads_per_row; // 4 rows per warp
-    const int kbx = txi % threads_per_row; // block id 0-7
-    const int row_in_warp = txi / threads_per_row; // which of the 4 rows this thread handles
+    constexpr int threads_per_row = 16; // 16 blocks per row for 512 values (ITER_K=512)
+    constexpr int rows_per_warp = warp_size / threads_per_row;
+    const int kbx = txi % threads_per_row;
+    const int row_in_warp = txi / threads_per_row;
 
 #pragma unroll
     for (int i0 = 0; i0 < mmq_y; i0 += rows_per_warp * nwarps) {
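
With 16 threads per row, each thread owns one 32-value MXFP4 block and a 32-thread warp covers 2 rows of 16 blocks, i.e. 512 values of K per pass. A hypothetical host-side sketch of that mapping (warp size assumed to be 32):

    #include <cstdio>
    int main() {
        constexpr int warp_size       = 32;
        constexpr int threads_per_row = 16;                          // 16 blocks of 32 values = 512 = ITER_K
        constexpr int rows_per_warp   = warp_size / threads_per_row; // 2 (was 4 with 8 threads per row)
        printf("rows per warp: %d\n", rows_per_warp);
        for (int txi = 0; txi < warp_size; ++txi) {
            const int kbx         = txi % threads_per_row; // which 32-value block along K
            const int row_in_warp = txi / threads_per_row; // which of the rows_per_warp rows
            printf("thread %2d -> row %d, block %2d\n", txi, row_in_warp, kbx);
        }
        return 0;
    }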
@@ -805,6 +811,7 @@ static __device__ __forceinline__ void load_tiles_mxfp4_fp4(const char * __restr
 
         const block_mxfp4 * bxi = (const block_mxfp4 *) x + kbx0 + i * stride + kbx;
 
+        // quantize_mxfp4_mmq permutes nibbles to match the quantized format
         const int k0 = kbx * 4;
         memcpy(x_qs + i * MMQ_MMA_TILE_X_K_FP4 + k0, bxi->qs, 16);
 
@@ -1003,12 +1010,12 @@ static __device__ __forceinline__ void vec_dot_mxfp4_mxfp4_mma(const int * __res
 
     // Match layout from load_tiles_mxfp4_fp4
     const int * x_qs = (const int *) x;
-    const uint32_t * x_sc = (const uint32_t *) (x_qs + MMQ_TILE_NE_K); // E8M0 scales at same offset as load
-    const int * y_qs = (const int *) y + 2;
-    const uint32_t * y_sc = (const uint32_t *) y; // E8M0 scales for Y
+    const uint32_t * x_sc = (const uint32_t *) (x_qs + 2 * MMQ_TILE_NE_K);
+    const int * y_qs = (const int *) y + 4;
+    const uint32_t * y_sc = (const uint32_t *) y;
 
-    tile_A A[ntx][MMQ_TILE_NE_K / (2 * QI8_0)]; // 2 x 4 A tiles. Per warp there will be 1 scale per tile
-    uint32_t scaleA[ntx][MMQ_TILE_NE_K / (2 * QI8_0)]; // per tile you would only have 1 scale per thread
+    tile_A A[ntx][MMQ_TILE_NE_K / QI8_0];
+    uint32_t scaleA[ntx][MMQ_TILE_NE_K / QI8_0];
 
     // Block scale
     // Each thread has to point to a 4 byte scale value
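
The y + 4 / (uint32_t *) y pair mirrors the enlarged block_fp4_mmq layout: four uint32 of E8M0 scales first, then the packed 4-bit values. A sketch of that addressing using an illustrative stand-in struct rather than the real ggml types:

    #include <cassert>
    #include <cstdint>

    struct block_fp4_mmq_sketch { uint32_t d4[4]; int8_t qs[4 * 32]; }; // 144 bytes

    int main() {
        block_fp4_mmq_sketch blk = {};
        const int      * y    = (const int *) &blk;
        const uint32_t * y_sc = (const uint32_t *) y;  // scales live in the first 4 ints
        const int      * y_qs = y + 4;                 // quantized values start right after them
        assert((const void *) y_sc == (const void *) blk.d4);
        assert((const void *) y_qs == (const void *) blk.qs);
        return 0;
    }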
@@ -1019,8 +1026,8 @@ static __device__ __forceinline__ void vec_dot_mxfp4_mxfp4_mma(const int * __res
 #pragma unroll
     for (int n = 0; n < ntx; ++n) {
 #pragma unroll
-        for (int k01 = 0; k01 < MMQ_TILE_NE_K / 2; k01 += QI8_0) {
-            const int k0 = k00 / 2 + k01;
+        for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_0) {
+            const int k0 = k00 + k01;
 
             load_ldmatrix(A[n][k01 / QI8_0], x_qs + (i0 + n * tile_A::I) * MMQ_MMA_TILE_X_K_FP4 + k0,
                 MMQ_MMA_TILE_X_K_FP4);
@@ -1034,7 +1041,7 @@ static __device__ __forceinline__ void vec_dot_mxfp4_mxfp4_mma(const int * __res
 #pragma unroll
     for (int j0 = 0; j0 < mmq_x; j0 += ntx * tile_C::J) {
 #pragma unroll
-        for (int k01 = 0; k01 < MMQ_TILE_NE_K / 2; k01 += QI8_0) {
+        for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_0) {
             tile_B B;
             uint32_t scaleB; // 2xN scales
 
@@ -3367,34 +3374,24 @@ static __device__ __forceinline__ void mul_mat_q_process_tile(
     constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, need_check>;
 #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
 
-#if defined(BLACKWELL_MMA_AVAILABLE)
-    constexpr bool use_native_mxfp4 = (type == GGML_TYPE_MXFP4);
-#else
-    constexpr bool use_native_mxfp4 = false;
-#endif // defined(BLACKWELL_MMA_AVAILBLE)
-
-
-    constexpr int blocks_per_iter = MMQ_ITER_K / qk;
+    constexpr int ITER_K = get_iter_k(type);
+    constexpr int blocks_per_iter = ITER_K / qk;
 
     float sum[mmq_x*mmq_y / (nwarps*warp_size)] = {0.0f};
 
-    constexpr size_t sz = use_native_mxfp4 ? sizeof(block_fp4_mmq) : sizeof(block_q8_1_mmq);
-    constexpr size_t y_stride = use_native_mxfp4 ? MMQ_TILE_Y_FP4_K : MMQ_TILE_Y_K;
+    constexpr int sz = sizeof(block_q8_1_mmq) / sizeof(int);
 
-    constexpr int y_block_stride = use_native_mxfp4 ? (sz / sizeof(int)) // 18 ints per block_fp4_mmq (covers 128 values = 4 qk-blocks)
-                                                    :
-                                                      (qk * sz / (4 * QK8_1 * sizeof(int))); // original formula for Q8_1
+    // blocks_per_mmq: number of qk-blocks per Y-block structure
+    // MXFP4:  block_fp4_mmq holds 8 qk-blocks (256 values)
+    // Others: block_q8_1_mmq holds 4 qk-blocks (128 values)
+    constexpr int blocks_per_mmq = (type == GGML_TYPE_MXFP4) ? 8 : 4;
 
     for (int kb0 = kb0_start; kb0 < kb0_stop; kb0 += blocks_per_iter) {
         load_tiles(x, tile_x, offset_x + kb0, tile_x_max_i, stride_row_x);
         {
-            const int * by0 =
-                use_native_mxfp4 ?
-                    y + ncols_y * ((kb0 / 4) * y_block_stride) // kb0/4 for MXFP4 since 4 qk-blocks per block_fp4_mmq
-                    :
-                    y + ncols_y * (kb0 * y_block_stride); // original for Q8_1
+            const int * by0 = y + ncols_y * (kb0 / blocks_per_mmq) * sz;
 #pragma unroll
-            for (int l0 = 0; l0 < mmq_x * y_stride; l0 += nwarps * warp_size) {
+            for (int l0 = 0; l0 < mmq_x * MMQ_TILE_Y_K; l0 += nwarps * warp_size) {
                 int l = l0 + threadIdx.y*warp_size + threadIdx.x;
 
                 tile_y[l] = by0[l];
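
A worked example of the simplified Y indexing, assuming qk == 32, sz == 36 ints (one 144-byte block) and blocks_per_mmq == 8 for MXFP4; the extents below are illustrative:

    #include <cstdio>
    int main() {
        constexpr int sz             = 144 / 4; // ints per Y block (block_q8_1_mmq or block_fp4_mmq)
        constexpr int blocks_per_mmq = 8;       // MXFP4: 8 qk-blocks (256 values) per block_fp4_mmq
        const int ncols_y = 16;                 // example column count
        for (int kb0 = 0; kb0 <= 32; kb0 += 16) {                        // blocks_per_iter = 512/32 = 16
            const int offset = ncols_y * (kb0 / blocks_per_mmq) * sz;    // ints from the start of y
            printf("kb0 = %2d -> by0 offset = %d ints\n", kb0, offset);  // prints 0, 1152, 2304
        }
        return 0;
    }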
@@ -3408,14 +3405,9 @@ static __device__ __forceinline__ void mul_mat_q_process_tile(
         __syncthreads();
 
         {
-            const int * by0 =
-                use_native_mxfp4 ?
-                    y + ncols_y * ((kb0 / 4) * y_block_stride + y_block_stride) // advance by one block_fp4_mmq
-                    :
-                    y + ncols_y * (kb0 * y_block_stride +
-                                   (int) (sz / sizeof(int))); // original for Q8_1 (advance by one block)
-#pragma unroll
-            for (int l0 = 0; l0 < mmq_x * y_stride; l0 += nwarps * warp_size) {
+            const int * by0 = y + ncols_y * ((kb0 / blocks_per_mmq) * sz + sz);
+#pragma unroll
+            for (int l0 = 0; l0 < mmq_x * MMQ_TILE_Y_K; l0 += nwarps * warp_size) {
                 int l = l0 + threadIdx.y*warp_size + threadIdx.x;
 
                 tile_y[l] = by0[l];
@@ -3547,8 +3539,10 @@ static __global__ void mul_mat_q(
     }
 #endif // (defined(GGML_USE_HIP) && !defined(CDNA3)) || __CUDA_ARCH__ < GGML_CUDA_CC_VOLTA
 
+    constexpr int ITER_K = get_iter_k(type);
+
     const int64_t blocks_per_ne00 = ncols_x / qk;
-    constexpr int blocks_per_iter = MMQ_ITER_K / qk;
+    constexpr int blocks_per_iter = ITER_K / qk;
 
     // kbc == k block continuous, current index in continuous ijk space.
     int64_t kbc = (int64_t) blockIdx.x *nsamples_y*nchannels_y*ntx*nty*blocks_per_ne00 / gridDim.x;
@@ -3609,8 +3603,7 @@ static __global__ void mul_mat_q(
         __syncthreads();
     }
 
-    constexpr size_t sz = type == GGML_TYPE_MXFP4 ? sizeof(block_fp4_mmq) : sizeof(block_q8_1_mmq);
-    offset_y += (col_low + jt * mmq_x) * (sz / sizeof(int));
+    offset_y += (col_low + jt * mmq_x) * (sizeof(block_q8_1_mmq) / sizeof(int));
     offset_dst += it*mmq_y;
 
     const int tile_x_max_i = nrows_x - it*mmq_y - 1;
@@ -3677,8 +3670,7 @@ static __global__ void mul_mat_q(
         __syncthreads();
     }
 
-    constexpr size_t sz = type == GGML_TYPE_MXFP4 ? sizeof(block_fp4_mmq) : sizeof(block_q8_1_mmq);
-    offset_y += (col_low + jt * mmq_x) * (sz / sizeof(int));
+    offset_y += (col_low + jt * mmq_x) * (sizeof(block_q8_1_mmq) / sizeof(int));
     offset_dst += it*mmq_y;
 
     const int tile_x_max_i = nrows_x - it*mmq_y - 1;
@@ -3701,7 +3693,9 @@ static __global__ void mul_mat_q_stream_k_fixup(
         const int ncols_max) {
     constexpr int mmq_y = get_mmq_y_device();
     constexpr int qk = ggml_cuda_type_traits<type>::qk;
-    constexpr int blocks_per_iter = MMQ_ITER_K / qk;
+    constexpr int ITER_K = get_iter_k(type);
+
+    constexpr int blocks_per_iter = ITER_K / qk;
     const int64_t blocks_per_ne00 = ncols_x / qk;
 
     constexpr int nwarps = mmq_get_nwarps_device();

Comments (0)