@@ -11,6 +11,7 @@ using namespace ggml_cuda_mma;
 
 #define MMQ_DP4A_MAX_BATCH_SIZE 64 // Max. batch size to use for dp4a MMQ kernels when FP16 tensor cores are available.
 #define MMQ_ITER_K 256
+#define MMQ_ITER_K_MXFP4_FP4 512
 #define MMQ_NWARPS 8
 
 typedef void (*load_tiles_mmq_t)(const char * __restrict__ x, int * x_tile, const int kbx0, const int i_max, const int stride);
@@ -46,13 +47,13 @@ struct block_q8_1_mmq {
 };
 
 struct block_fp4_mmq {
-    uint32_t d4[2];    // 1 8 bit (e8m0) scale per 32 values, packed LSB as d0-d1 in d4[0] and d4[1]
-    int8_t   qs[2*32]; // 128 values to 4 bit each (4 blocks)
+    uint32_t d4[4];    // 8 E8M0 scales (1 per 32 values), 2 packed per uint32: d4[0]={s0,s1}, d4[1]={s2,s3}, etc.
+    int8_t   qs[4*32]; // 256 FP4 values packed as 4-bit pairs (2 per byte), 8 blocks of 32 values
 };
 
 static_assert(sizeof(block_q8_1_mmq) == 4*QK8_1 + 4*sizeof(half2), "Unexpected block_q8_1_mmq size");
 static_assert(sizeof(block_q8_1_mmq) == 4*sizeof(block_q8_1),      "Unexpected block_q8_1_mmq size");
-static_assert(sizeof(block_fp4_mmq) == 72, "Unexpected block_fp4_mmq size");
+static_assert(sizeof(block_fp4_mmq) == sizeof(block_q8_1_mmq), "Unexpected block_fp4_mmq size");
 
 static mmq_q8_1_ds_layout mmq_get_q8_1_ds_layout(const ggml_type type_x) {
     switch (type_x) {
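Note on the resized block_fp4_mmq above: with d4 grown to 4 scale words and qs to 4*32 bytes, the struct occupies 16 + 128 = 144 bytes, exactly the size of block_q8_1_mmq (4*QK8_1 + 4*sizeof(half2) = 128 + 16), which is what the new static_assert pins down and what lets the Y tile loads use a single stride for both layouts. A minimal standalone sketch of that size arithmetic, using mirror structs (the mirror_* names are hypothetical, not the kernel's types) and assuming QK8_1 == 32:

    #include <cstdint>

    // Mirror of block_q8_1_mmq: 4 half2 scale pairs (16 bytes) + 4*32 int8 quants (128 bytes).
    struct mirror_q8_1_mmq {
        uint16_t ds4[4][2];  // stand-in for half2 ds4[4]
        int8_t   qs[4*32];
    };

    // Mirror of the new block_fp4_mmq: 4 uint32 scale words (16 bytes) + 4*32 bytes of packed FP4.
    struct mirror_fp4_mmq {
        uint32_t d4[4];
        int8_t   qs[4*32];
    };

    static_assert(sizeof(mirror_q8_1_mmq) == 144, "144 bytes expected");
    static_assert(sizeof(mirror_fp4_mmq) == sizeof(mirror_q8_1_mmq), "the two Y-block layouts must match in size");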
@@ -136,6 +137,14 @@ static int get_mmq_y_host(const int cc) {
         ((GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA) ? 128 : 64);
 }
 
+static constexpr __device__ int get_iter_k([[maybe_unused]] const ggml_type type) {
+#if defined(BLACKWELL_MMA_AVAILABLE)
+    return type == GGML_TYPE_MXFP4 ? MMQ_ITER_K_MXFP4_FP4 : MMQ_ITER_K;
+#else
+    return MMQ_ITER_K;
+#endif // defined(BLACKWELL_MMA_AVAILABLE)
+}
+
 static constexpr __device__ int get_mmq_y_device() {
 #if defined(GGML_USE_HIP)
 #if defined(RDNA1)
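For a sense of what get_iter_k changes downstream: MXFP4 blocks hold 32 values each (per the scale comments above), so an iteration covers 512/32 = 16 qk-blocks on Blackwell versus 256/32 = 8 elsewhere. A hedged host-side sketch of the same selection, with the architecture flag passed explicitly instead of via BLACKWELL_MMA_AVAILABLE:

    // Sketch only; assumes MMQ_ITER_K == 256, MMQ_ITER_K_MXFP4_FP4 == 512 and 32 values per MXFP4 block.
    constexpr int sketch_iter_k(bool is_mxfp4, bool blackwell_native_fp4) {
        return (is_mxfp4 && blackwell_native_fp4) ? 512 : 256;
    }

    static_assert(sketch_iter_k(true,  true)  / 32 == 16, "MXFP4 on Blackwell: 16 qk-blocks per iteration");
    static_assert(sketch_iter_k(true,  false) / 32 ==  8, "MXFP4 elsewhere: 8 qk-blocks per iteration");
    static_assert(sketch_iter_k(false, true)  / 32 ==  8, "qk == 32 types keep 8 qk-blocks per iteration");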
@@ -198,7 +207,7 @@ static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml
 }
 
 #define MMQ_MMA_TILE_X_K_Q8_0 (2*MMQ_TILE_NE_K + 2*MMQ_TILE_NE_K/QI8_0 + 4)
-#define MMQ_MMA_TILE_X_K_FP4  (MMQ_TILE_NE_K + MMQ_TILE_NE_K/QI8_0)
+#define MMQ_MMA_TILE_X_K_FP4  (2*MMQ_TILE_NE_K + 2*MMQ_TILE_NE_K/QI8_0 + 4)
 #define MMQ_MMA_TILE_X_K_Q8_1 (2*MMQ_TILE_NE_K + 2*MMQ_TILE_NE_K/QI8_0 + 4)
 #define MMQ_MMA_TILE_X_K_Q2_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K + 4)
 #define MMQ_MMA_TILE_X_K_Q3_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K/2 + 4)
@@ -209,7 +218,7 @@ static_assert(MMQ_MMA_TILE_X_K_Q8_1 % 8 == 4, "Wrong padding.");
 static_assert(MMQ_MMA_TILE_X_K_Q2_K % 8 == 4, "Wrong padding.");
 static_assert(MMQ_MMA_TILE_X_K_Q3_K % 8 == 4, "Wrong padding.");
 static_assert(MMQ_MMA_TILE_X_K_Q6_K % 8 == 4, "Wrong padding.");
-static_assert(MMQ_MMA_TILE_X_K_FP4 % 8 == 4, "Wrong padding.");
+static_assert(MMQ_MMA_TILE_X_K_FP4  % 8 == 4, "Wrong padding.");
 
 static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
     switch (type) {
@@ -218,11 +227,8 @@ static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
         case GGML_TYPE_Q5_0: return MMQ_MMA_TILE_X_K_Q8_0;
         case GGML_TYPE_Q5_1: return MMQ_MMA_TILE_X_K_Q8_1;
         case GGML_TYPE_Q8_0: return MMQ_MMA_TILE_X_K_Q8_0;
-#ifdef BLACKWELL_MMA_AVAILABLE
-        case GGML_TYPE_MXFP4: return MMQ_MMA_TILE_X_K_FP4;
-#else
+        // tile sizes are the same for Q8_1 and FP4 for blackwell
         case GGML_TYPE_MXFP4: return MMQ_MMA_TILE_X_K_Q8_1;
-#endif
         case GGML_TYPE_Q2_K: return MMQ_MMA_TILE_X_K_Q2_K;
         case GGML_TYPE_Q3_K: return MMQ_MMA_TILE_X_K_Q3_K;
         case GGML_TYPE_Q4_K: return MMQ_MMA_TILE_X_K_Q8_1;
@@ -242,7 +248,8 @@ static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
 
 // block_q8_1_mmq has (128 8-bit ints == 32 32-bit ints + 4 32-bit scales)
 #define MMQ_TILE_Y_K (MMQ_TILE_NE_K + MMQ_TILE_NE_K/QI8_1)
-#define MMQ_TILE_Y_FP4_K MMQ_TILE_Y_K / 2
+// #define MMQ_TILE_Y_FP4_K MMQ_TILE_Y_K / 2
+#define MMQ_TILE_Y_FP4_K MMQ_TILE_Y_K
 
 static int mmq_get_granularity_host(const int mmq_x, const int cc) {
     if (amd_mfma_available(cc) || amd_wmma_available(cc)) {
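Why MMQ_TILE_Y_FP4_K can simply alias MMQ_TILE_Y_K now: assuming MMQ_TILE_NE_K == 32 and QI8_1 == 8 (consistent with the "32 32-bit ints + 4 32-bit scales" comment above), the row stride is 36 ints, which equals sizeof(block_q8_1_mmq)/sizeof(int) and, after the resize, sizeof(block_fp4_mmq)/sizeof(int) as well. A small sketch of that arithmetic:

    // Sketch only; the constants mirror assumed values of MMQ_TILE_NE_K, QI8_1 and the 144-byte Y block.
    constexpr int tile_ne_k     = 32;
    constexpr int qi8_1         = 8;
    constexpr int tile_y_k      = tile_ne_k + tile_ne_k/qi8_1;  // 32 + 4 = 36 ints per row
    constexpr int block_sz_ints = 144 / 4;                      // 36 ints per Y block

    static_assert(tile_y_k == 36, "32 quant ints + 4 scale ints");
    static_assert(tile_y_k == block_sz_ints, "row stride equals one block_q8_1_mmq / block_fp4_mmq");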
@@ -785,15 +792,14 @@ static __device__ __forceinline__ void load_tiles_mxfp4_fp4(const char * __restr
     constexpr int warp_size = ggml_cuda_get_physical_warp_size();
 
     int * x_qs = (int *) x_tile;
-    uint32_t * x_sc = (uint32_t *) (x_qs + MMQ_TILE_NE_K);
+    uint32_t * x_sc = (uint32_t *) (x_qs + 2*MMQ_TILE_NE_K);
 
     const int txi = threadIdx.x;
 
-    // Use all 32 threads: 8 threads per row, process 4 rows per warp per iteration
-    constexpr int threads_per_row = 8;                            // 8 blocks per row
-    constexpr int rows_per_warp   = warp_size / threads_per_row;  // 4 rows per warp
-    const int kbx         = txi % threads_per_row;  // block id 0-7
-    const int row_in_warp = txi / threads_per_row;  // which of the 4 rows this thread handles
+    constexpr int threads_per_row = 16;  // 16 blocks per row for 512 values (ITER_K=512)
+    constexpr int rows_per_warp   = warp_size / threads_per_row;
+    const int kbx         = txi % threads_per_row;
+    const int row_in_warp = txi / threads_per_row;
 
 #pragma unroll
     for (int i0 = 0; i0 < mmq_y; i0 += rows_per_warp * nwarps) {
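The remapped loop above spreads one 512-value iteration over the warp: 16 qk-blocks per row means 16 lanes per row and, assuming the usual 32-thread warp, 2 rows per warp instead of the previous 4. A brief sketch of the lane math:

    // Sketch only; assumes warp_size == 32 and ITER_K == 512 with 32 values per MXFP4 block.
    constexpr int sketch_warp_size       = 32;
    constexpr int sketch_threads_per_row = 16;  // one lane per qk-block in the row
    constexpr int sketch_rows_per_warp   = sketch_warp_size / sketch_threads_per_row;

    static_assert(sketch_rows_per_warp == 2, "each warp now loads 2 rows per iteration");
    // For lane txi: kbx = txi % 16 picks the qk-block, row_in_warp = txi / 16 picks which of the 2 rows.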
@@ -805,6 +811,7 @@ static __device__ __forceinline__ void load_tiles_mxfp4_fp4(const char * __restr
 
         const block_mxfp4 * bxi = (const block_mxfp4 *) x + kbx0 + i*stride + kbx;
 
+        // quantize_mxfp4_mmq permutes nibbles to match the quantized format
         const int k0 = kbx * 4;
         memcpy(x_qs + i*MMQ_MMA_TILE_X_K_FP4 + k0, bxi->qs, 16);
 
@@ -1003,12 +1010,12 @@ static __device__ __forceinline__ void vec_dot_mxfp4_mxfp4_mma(const int * __res
 
     // Match layout from load_tiles_mxfp4_fp4
     const int * x_qs = (const int *) x;
-    const uint32_t * x_sc = (const uint32_t *) (x_qs + MMQ_TILE_NE_K);  // E8M0 scales at same offset as load
-    const int * y_qs = (const int *) y + 2;
-    const uint32_t * y_sc = (const uint32_t *) y;  // E8M0 scales for Y
+    const uint32_t * x_sc = (const uint32_t *) (x_qs + 2*MMQ_TILE_NE_K);
+    const int * y_qs = (const int *) y + 4;
+    const uint32_t * y_sc = (const uint32_t *) y;
 
-    tile_A   A[ntx][MMQ_TILE_NE_K / (2*QI8_0)];       // 2 x 4 A tiles. Per warp there will be 1 scale per tile
-    uint32_t scaleA[ntx][MMQ_TILE_NE_K / (2*QI8_0)];  // per tile you would only have 1 scale per thread
+    tile_A   A[ntx][MMQ_TILE_NE_K / QI8_0];
+    uint32_t scaleA[ntx][MMQ_TILE_NE_K / QI8_0];
 
     // Block scale
     // Each thread has to point to a 4 byte scale value
@@ -1019,8 +1026,8 @@ static __device__ __forceinline__ void vec_dot_mxfp4_mxfp4_mma(const int * __res
 #pragma unroll
     for (int n = 0; n < ntx; ++n) {
 #pragma unroll
-        for (int k01 = 0; k01 < MMQ_TILE_NE_K / 2; k01 += QI8_0) {
-            const int k0 = k00 / 2 + k01;
+        for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_0) {
+            const int k0 = k00 + k01;
 
             load_ldmatrix(A[n][k01 / QI8_0], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_FP4 + k0,
                           MMQ_MMA_TILE_X_K_FP4);
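With the K loop widened to the full tile, each n now carries MMQ_TILE_NE_K/QI8_0 A fragments and scale words; assuming MMQ_TILE_NE_K == 32 and QI8_0 == 8 that is 4 per n, matching k01 = 0, 8, 16, 24 in the loop above (previously 2 over the half-width range). A short check of that count:

    // Sketch only; assumes MMQ_TILE_NE_K == 32 and QI8_0 == 8.
    constexpr int sketch_tiles_per_n = 32 / 8;  // size of A[n][] and scaleA[n][]
    constexpr int sketch_last_k01    = 32 - 8;  // final k01 value taken by the loop
    static_assert(sketch_tiles_per_n == 4, "4 ldmatrix loads and 4 packed scale words per n");
    static_assert(sketch_last_k01 / 8 == sketch_tiles_per_n - 1, "k01/QI8_0 indexes A[n][0..3]");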
@@ -1034,7 +1041,7 @@ static __device__ __forceinline__ void vec_dot_mxfp4_mxfp4_mma(const int * __res
 #pragma unroll
     for (int j0 = 0; j0 < mmq_x; j0 += ntx * tile_C::J) {
 #pragma unroll
-        for (int k01 = 0; k01 < MMQ_TILE_NE_K / 2; k01 += QI8_0) {
+        for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_0) {
             tile_B B;
             uint32_t scaleB;  // 2xN scales
 
@@ -3367,34 +3374,24 @@ static __device__ __forceinline__ void mul_mat_q_process_tile(
     constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, need_check>;
 #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
 
-#if defined(BLACKWELL_MMA_AVAILABLE)
-    constexpr bool use_native_mxfp4 = (type == GGML_TYPE_MXFP4);
-#else
-    constexpr bool use_native_mxfp4 = false;
-#endif // defined(BLACKWELL_MMA_AVAILBLE)
-
-
-    constexpr int blocks_per_iter = MMQ_ITER_K / qk;
+    constexpr int ITER_K = get_iter_k(type);
+    constexpr int blocks_per_iter = ITER_K / qk;
 
     float sum[mmq_x*mmq_y / (nwarps*warp_size)] = {0.0f};
 
-    constexpr size_t sz       = use_native_mxfp4 ? sizeof(block_fp4_mmq) : sizeof(block_q8_1_mmq);
-    constexpr size_t y_stride = use_native_mxfp4 ? MMQ_TILE_Y_FP4_K : MMQ_TILE_Y_K;
+    constexpr int sz = sizeof(block_q8_1_mmq) / sizeof(int);
 
-    constexpr int y_block_stride = use_native_mxfp4 ? (sz / sizeof(int))  // 18 ints per block_fp4_mmq (covers 128 values = 4 qk-blocks)
-                                                    :
-                                                      (qk * sz / (4 * QK8_1 * sizeof(int)));  // original formula for Q8_1
+    // blocks_per_mmq: number of qk-blocks per Y-block structure
+    // MXFP4: block_fp4_mmq holds 8 qk-blocks (256 values)
+    // Others: block_q8_1_mmq holds 4 qk-blocks (128 values)
+    constexpr int blocks_per_mmq = (type == GGML_TYPE_MXFP4) ? 8 : 4;
 
     for (int kb0 = kb0_start; kb0 < kb0_stop; kb0 += blocks_per_iter) {
         load_tiles(x, tile_x, offset_x + kb0, tile_x_max_i, stride_row_x);
         {
-            const int * by0 =
-                use_native_mxfp4 ?
-                    y + ncols_y * ((kb0 / 4) * y_block_stride)  // kb0/4 for MXFP4 since 4 qk-blocks per block_fp4_mmq
-                    :
-                    y + ncols_y * (kb0 * y_block_stride);  // original for Q8_1
+            const int * by0 = y + ncols_y * (kb0 / blocks_per_mmq) * sz;
 #pragma unroll
-            for (int l0 = 0; l0 < mmq_x * y_stride; l0 += nwarps * warp_size) {
+            for (int l0 = 0; l0 < mmq_x * MMQ_TILE_Y_K; l0 += nwarps * warp_size) {
                 int l = l0 + threadIdx.y*warp_size + threadIdx.x;
 
                 tile_y[l] = by0[l];
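A worked check of the unified Y indexing above, under the earlier assumptions (36 ints per Y block, 32 values per qk-block): on the Blackwell MXFP4 path kb0 advances by 16 and blocks_per_mmq is 8, on the q8_1 path kb0 advances by 8 and blocks_per_mmq is 4, so in both cases each iteration consumes two Y blocks; the first load takes (kb0/blocks_per_mmq)*sz and the second load in the next hunk adds one more sz:

    // Sketch only; sketch_sz mirrors sizeof(block_q8_1_mmq)/sizeof(int) == 36 under the assumptions above.
    constexpr int sketch_sz = 36;

    // MXFP4 on Blackwell: blocks_per_iter = 512/32 = 16, blocks_per_mmq = 8.
    static_assert((16 / 8) * sketch_sz == 2 * sketch_sz, "second iteration starts two Y blocks in");
    // Generic q8_1 path: blocks_per_iter = 256/32 = 8, blocks_per_mmq = 4.
    static_assert((8 / 4) * sketch_sz == 2 * sketch_sz, "same advance per iteration on the q8_1 path");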
@@ -3408,14 +3405,9 @@ static __device__ __forceinline__ void mul_mat_q_process_tile(
         __syncthreads();
 
         {
-            const int * by0 =
-                use_native_mxfp4 ?
-                    y + ncols_y * ((kb0 / 4) * y_block_stride + y_block_stride)  // advance by one block_fp4_mmq
-                    :
-                    y + ncols_y * (kb0 * y_block_stride +
-                                   (int) (sz / sizeof(int)));  // original for Q8_1 (advance by one block)
-#pragma unroll
-            for (int l0 = 0; l0 < mmq_x * y_stride; l0 += nwarps * warp_size) {
+            const int * by0 = y + ncols_y * ((kb0 / blocks_per_mmq) * sz + sz);
+#pragma unroll
+            for (int l0 = 0; l0 < mmq_x * MMQ_TILE_Y_K; l0 += nwarps * warp_size) {
                 int l = l0 + threadIdx.y*warp_size + threadIdx.x;
 
                 tile_y[l] = by0[l];
@@ -3547,8 +3539,10 @@ static __global__ void mul_mat_q(
     }
 #endif // (defined(GGML_USE_HIP) && !defined(CDNA3)) || __CUDA_ARCH__ < GGML_CUDA_CC_VOLTA
 
+    constexpr int ITER_K = get_iter_k(type);
+
     const int64_t blocks_per_ne00 = ncols_x / qk;
-    constexpr int blocks_per_iter = MMQ_ITER_K / qk;
+    constexpr int blocks_per_iter = ITER_K / qk;
 
     // kbc == k block continuous, current index in continuous ijk space.
     int64_t kbc = (int64_t) blockIdx.x*nsamples_y*nchannels_y*ntx*nty*blocks_per_ne00 / gridDim.x;
@@ -3609,8 +3603,7 @@ static __global__ void mul_mat_q(
         __syncthreads();
     }
 
-    constexpr size_t sz = type == GGML_TYPE_MXFP4 ? sizeof(block_fp4_mmq) : sizeof(block_q8_1_mmq);
-    offset_y   += (col_low + jt*mmq_x)*(sz / sizeof(int));
+    offset_y   += (col_low + jt*mmq_x)*(sizeof(block_q8_1_mmq)/sizeof(int));
     offset_dst += it*mmq_y;
 
     const int tile_x_max_i = nrows_x - it*mmq_y - 1;
@@ -3677,8 +3670,7 @@ static __global__ void mul_mat_q(
         __syncthreads();
     }
 
-    constexpr size_t sz = type == GGML_TYPE_MXFP4 ? sizeof(block_fp4_mmq) : sizeof(block_q8_1_mmq);
-    offset_y   += (col_low + jt*mmq_x)*(sz / sizeof(int));
+    offset_y   += (col_low + jt*mmq_x)*(sizeof(block_q8_1_mmq)/sizeof(int));
     offset_dst += it*mmq_y;
 
     const int tile_x_max_i = nrows_x - it*mmq_y - 1;
@@ -3701,7 +3693,9 @@ static __global__ void mul_mat_q_stream_k_fixup(
         const int ncols_max) {
     constexpr int mmq_y = get_mmq_y_device();
     constexpr int qk = ggml_cuda_type_traits<type>::qk;
-    constexpr int blocks_per_iter = MMQ_ITER_K / qk;
+    constexpr int ITER_K = get_iter_k(type);
+
+    constexpr int blocks_per_iter = ITER_K / qk;
     const int64_t blocks_per_ne00 = ncols_x / qk;
 
     constexpr int nwarps = mmq_get_nwarps_device();