@@ -47,6 +47,18 @@ static __global__ void quantize_q8_1(
     y[ib].ds = make_half2(d, sum);
 }
 
+// Helper to compute E8M0 scale from amax using fast math
+__device__ __forceinline__ uint8_t compute_e8m0_scale(float amax) {
+    if (amax == 0.0f) {
+        return 127; // Special case: use scale of 1.0 for zero input
+    }
+    // log2(amax / 6.0) = log2(amax) - log2(6) ≈ log2(amax) - 2.585
+    // Use __log2f for fast approximate log2
+    const float log2_amax = __log2f(amax) - 2.5849625007211563f; // log2(6)
+    const int   e_int     = __float2int_rd(log2_amax) + 127;     // floor + bias
+    return static_cast<uint8_t>(max(1, min(254, e_int)));
+}
+
 static __global__ void quantize_mmq_mxfp4(const float * __restrict__ x,
                                           const int32_t * __restrict__ ids,
                                           void * __restrict__ vy,
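For reference, the fast-math helper above approximates the exact computation the removed code performed with `floorf(log2f(amax / 6.0f)) + 127`. Below is a minimal host-side sketch (not part of the patch; `e8m0_scale_ref` and `e8m0_to_fp32_ref` are hypothetical names) of that exact formula next to the E8M0 decode, assuming the FP4 E2M1 maximum magnitude of 6.0 and the [1, 254] clamp used by the helper:

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>

// Exact counterpart of compute_e8m0_scale(): biased exponent of the
// power-of-two scale 2^(e - 127) derived from amax / 6.0, where 6.0 is
// the largest magnitude representable in FP4 E2M1.
static uint8_t e8m0_scale_ref(float amax) {
    if (amax == 0.0f) {
        return 127; // scale of 1.0 for an all-zero block
    }
    const int e = (int) floorf(log2f(amax / 6.0f)) + 127;
    return (uint8_t) std::max(1, std::min(254, e));
}

// E8M0 decode: plain 2^(e - 127), no sign or mantissa bits.
static float e8m0_to_fp32_ref(uint8_t e) {
    return ldexpf(1.0f, (int) e - 127);
}

int main() {
    const float tests[] = {0.0f, 0.5f, 1.0f, 6.0f, 100.0f};
    for (float amax : tests) {
        const uint8_t e = e8m0_scale_ref(amax);
        printf("amax = %8.3f -> e = %3u, scale = %g\n", amax, (unsigned) e, e8m0_to_fp32_ref(e));
    }
    return 0;
}
```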
@@ -60,10 +72,15 @@ static __global__ void quantize_mmq_mxfp4(const float * __restrict__ x,
     constexpr int vals_per_scale = 32;
     constexpr int vals_per_warp  = 2 * vals_per_scale; // Each warp processes 2 blocks of 32
 
-    // Each warp processes 2 adjacent blocks of 32 values (64 values total)
-    const int64_t warp_start_offset = blockIdx.y * vals_per_warp;
-    const int64_t i0_block0         = warp_start_offset + threadIdx.x;                  // First block: 0-31
-    const int64_t i0_block1         = warp_start_offset + vals_per_scale + threadIdx.x; // Second block: 32-63
+    // Multiple warps per block - each warp handles a different 64-value slice
+    const int warp_id    = threadIdx.y;
+    const int lane_id_32 = threadIdx.x;
+
+    const int nwarps = blockDim.y;
+
+    const int64_t warp_start_offset = (blockIdx.y * nwarps + warp_id) * vals_per_warp;
+    const int64_t i0_block0         = warp_start_offset + lane_id_32;
+    const int64_t i0_block1         = warp_start_offset + vals_per_scale + lane_id_32;
 
     if (i0_block0 >= ne0) {
         return;
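To make the new indexing concrete: with `blockDim.y` warps per thread block, warp `w` of grid block `by` covers the 64-value slice starting at `(by * nwarps + w) * 64`, and each lane reads one element from each of the two 32-value sub-blocks. A small host-side sketch of that mapping (illustrative only, assuming nwarps = 8 as in the launch code further down):

```cpp
#include <cstdint>
#include <cstdio>

// Prints which 64-value slice of a row each (blockIdx.y, warp_id) pair
// covers under the multi-warp layout above.
int main() {
    const int nwarps        = 8;
    const int vals_per_warp = 64; // 2 blocks of 32
    for (int block_y = 0; block_y < 2; ++block_y) {
        for (int warp_id = 0; warp_id < nwarps; ++warp_id) {
            const int64_t warp_start = (int64_t)(block_y * nwarps + warp_id) * vals_per_warp;
            printf("block_y=%d warp=%d -> values [%lld, %lld)\n",
                   block_y, warp_id, (long long) warp_start,
                   (long long)(warp_start + vals_per_warp));
        }
    }
    return 0;
}
```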
@@ -80,117 +97,70 @@ static __global__ void quantize_mmq_mxfp4(const float * __restrict__ x,
     block_fp4_mmq * y = (block_fp4_mmq *) vy;
 
     const int64_t block_fp4_mmq_size = 4 * QK_MXFP4; // 128 values
-
-    const int64_t ib0 =
-        blockIdx.z * ((int64_t) gridDim.x * gridDim.y * vals_per_warp / block_fp4_mmq_size); // first block of channel
-    const int64_t ib = ib0 + (warp_start_offset / block_fp4_mmq_size) * ne1 + blockIdx.x; // block index in channel
-    const int64_t pair_idx_in_block =
-        (warp_start_offset % block_fp4_mmq_size) / vals_per_warp; // 0-1: which pair of blocks within block_fp4_mmq
-
-    uint8_t e_packed[2];
-
-    // Process first block (0-31)
-    {
-        const int64_t global_src_pos = i03 * s03 + i02 * s02 + i01 * s01 + i0_block0;
-        const float xi = i0_block0 < ne00 ? x[global_src_pos] : 0.0f;
-
-        float amax = fabsf(xi);
-
-        // Reduce max across all 32 threads in the warp
+    const int64_t ib0 = blockIdx.z * ((int64_t) gridDim.x * gridDim.y * nwarps * vals_per_warp / block_fp4_mmq_size);
+    const int64_t ib  = ib0 + (warp_start_offset / block_fp4_mmq_size) * ne1 + blockIdx.x;
+    const int64_t pair_idx_in_block = (warp_start_offset % block_fp4_mmq_size) / vals_per_warp;
+
+    // Precompute common values
+    const int lane_id    = lane_id_32 % 4;
+    const int group_id   = lane_id_32 / 4;
+    const int group_base = group_id * 4;
+    char2 * yqs2 = (char2 *) y[ib].qs;
+
+    const int64_t base_pos = i03 * s03 + i02 * s02 + i01 * s01;
+    const float xi0 = (i0_block0 < ne00) ? x[base_pos + i0_block0] : 0.0f;
+    const float xi1 = (i0_block1 < ne00) ? x[base_pos + i0_block1] : 0.0f;
+
+    // === Process first block (0-31) ===
+    float amax0 = fabsf(xi0);
 #pragma unroll
-        for (int mask = 16; mask > 0; mask >>= 1) {
-            amax = fmaxf(amax, __shfl_xor_sync(0xFFFFFFFF, amax, mask, WARP_SIZE));
-        }
-
-        uint8_t e = amax > 0.0f ? (uint8_t) (floorf(log2f(amax / 6.0f)) + 127) : 0;
-
-        float val   = ggml_cuda_e8m0_to_fp32(e);
-        float inv_s = (amax == 0.0f) ? 0.0f : 1.0f / val;
-
-        // Quantize: each thread processes 1 value
-        uint8_t q_val = ggml_cuda_float_to_fp4_e2m1(xi, inv_s);
-
-        if (e == 0) {
-            e = 127;
-        }
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        amax0 = fmaxf(amax0, __shfl_xor_sync(0xFFFFFFFF, amax0, mask, WARP_SIZE));
+    }
 
-        // Pack 4 values into char2: threads 0,1,2,3 -> first char2, etc.
-        const int lane_id  = threadIdx.x % 4;
-        const int group_id = threadIdx.x / 4;
+    const uint8_t e0     = compute_e8m0_scale(amax0);
+    const float   inv_s0 = (amax0 == 0.0f) ? 0.0f : __frcp_rn(ggml_cuda_e8m0_to_fp32(e0));
+    const uint8_t q_val0 = ggml_cuda_float_to_fp4_e2m1(xi0, inv_s0);
 
-        // Use shuffle to gather values from 4 consecutive threads
-        uint8_t q0 = __shfl_sync(0xFFFFFFFF, q_val, (group_id * 4) + 0, WARP_SIZE);
-        uint8_t q1 = __shfl_sync(0xFFFFFFFF, q_val, (group_id * 4) + 1, WARP_SIZE);
-        uint8_t q2 = __shfl_sync(0xFFFFFFFF, q_val, (group_id * 4) + 2, WARP_SIZE);
-        uint8_t q3 = __shfl_sync(0xFFFFFFFF, q_val, (group_id * 4) + 3, WARP_SIZE);
+    // Gather 4 values from consecutive threads using shuffle
+    const uint8_t q0_0 = __shfl_sync(0xFFFFFFFF, q_val0, group_base + 0, WARP_SIZE);
+    const uint8_t q0_1 = __shfl_sync(0xFFFFFFFF, q_val0, group_base + 1, WARP_SIZE);
+    const uint8_t q0_2 = __shfl_sync(0xFFFFFFFF, q_val0, group_base + 2, WARP_SIZE);
+    const uint8_t q0_3 = __shfl_sync(0xFFFFFFFF, q_val0, group_base + 3, WARP_SIZE);
 
+    if (lane_id == 0) {
         char2 q;
-        if (lane_id == 0) {
-            q.x = (q1 << 4) | q0;
-            q.y = (q3 << 4) | q2;
-
-            // Write to output: first block in pair uses positions based on pair_idx_in_block
-            // Each pair has 2 blocks of 32 = 64 values = 16 char2 elements
-            char2 * yqs2 = (char2 *) y[ib].qs;
-            yqs2[pair_idx_in_block * 16 + group_id] = q;
-        }
-
-        if (threadIdx.x == 0) {
-            e_packed[0] = e;
-        }
+        q.x = (q0_1 << 4) | q0_0;
+        q.y = (q0_3 << 4) | q0_2;
+        yqs2[pair_idx_in_block * 16 + group_id] = q;
     }
 
-    // Process second block (32-63)
-    {
-        const int64_t global_src_pos = i03 * s03 + i02 * s02 + i01 * s01 + i0_block1;
-        const float xi = i0_block1 < ne00 ? x[global_src_pos] : 0.0f;
-
-        float amax = fabsf(xi);
-
+    // === Process second block (32-63) ===
+    float amax1 = fabsf(xi1);
 #pragma unroll
-        for (int mask = 16; mask > 0; mask >>= 1) {
-            amax = fmaxf(amax, __shfl_xor_sync(0xFFFFFFFF, amax, mask, WARP_SIZE));
-        }
-
-        uint8_t e = amax > 0.0f ? (uint8_t) (floorf(log2f(amax / 6.0f)) + 127) : 0;
-
-        float val   = ggml_cuda_e8m0_to_fp32(e);
-        float inv_s = (amax == 0.0f) ? 0.0f : 1.0f / val;
-
-        if (e == 0) {
-            e = 127;
-        }
-
-        uint8_t q_val = ggml_cuda_float_to_fp4_e2m1(xi, inv_s);
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        amax1 = fmaxf(amax1, __shfl_xor_sync(0xFFFFFFFF, amax1, mask, WARP_SIZE));
+    }
 
-        const int lane_id  = threadIdx.x % 4;
-        const int group_id = threadIdx.x / 4;
+    const uint8_t e1     = compute_e8m0_scale(amax1);
+    const float   inv_s1 = (amax1 == 0.0f) ? 0.0f : __frcp_rn(ggml_cuda_e8m0_to_fp32(e1));
+    const uint8_t q_val1 = ggml_cuda_float_to_fp4_e2m1(xi1, inv_s1);
 
-        // Use shuffle to gather values from 4 consecutive threads
-        uint8_t q0 = __shfl_sync(0xFFFFFFFF, q_val, (group_id * 4) + 0, WARP_SIZE);
-        uint8_t q1 = __shfl_sync(0xFFFFFFFF, q_val, (group_id * 4) + 1, WARP_SIZE);
-        uint8_t q2 = __shfl_sync(0xFFFFFFFF, q_val, (group_id * 4) + 2, WARP_SIZE);
-        uint8_t q3 = __shfl_sync(0xFFFFFFFF, q_val, (group_id * 4) + 3, WARP_SIZE);
+    const uint8_t q1_0 = __shfl_sync(0xFFFFFFFF, q_val1, group_base + 0, WARP_SIZE);
+    const uint8_t q1_1 = __shfl_sync(0xFFFFFFFF, q_val1, group_base + 1, WARP_SIZE);
+    const uint8_t q1_2 = __shfl_sync(0xFFFFFFFF, q_val1, group_base + 2, WARP_SIZE);
+    const uint8_t q1_3 = __shfl_sync(0xFFFFFFFF, q_val1, group_base + 3, WARP_SIZE);
 
+    if (lane_id == 0) {
         char2 q;
-        if (lane_id == 0) {
-            q.x = (q1 << 4) | q0;
-            q.y = (q3 << 4) | q2;
-
-            // Write to output: second block in pair uses positions 8-15 within the pair
-            char2 * yqs2 = (char2 *) y[ib].qs;
-            yqs2[pair_idx_in_block * 16 + 8 + group_id] = q;
-        }
-
-        if (threadIdx.x == 0) {
-            e_packed[1] = e;
-        }
+        q.x = (q1_1 << 4) | q1_0;
+        q.y = (q1_3 << 4) | q1_2;
+        yqs2[pair_idx_in_block * 16 + 8 + group_id] = q;
     }
 
-    // Write packed exponents: d4[0-1] each stores 2 scales (for 2 blocks of 32)
-    // pair_idx_in_block tells us which d4 entry to use (0-1)
-    if (threadIdx.x == 0) {
-        y[ib].d4[pair_idx_in_block] = (e_packed[1] << 8) | e_packed[0];
+    // Write packed exponents
+    if (lane_id_32 == 0) {
+        y[ib].d4[pair_idx_in_block] = (e1 << 8) | e0;
     }
 }
 
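The packing both per-block sections perform can be checked in isolation: each warp holds 32 FP4 codes (one per lane), and the lanes with `lane_id == 0` combine the four codes of their group into one `char2` (low nibble = even element, high nibble = odd element), so a 32-value block occupies 8 consecutive `char2` entries and the second block of the pair lands at offset 8, matching `pair_idx_in_block * 16 + 8 + group_id`. A minimal host-side sketch of this layout (illustrative, not part of the patch):

```cpp
#include <cstdint>
#include <cstdio>

// Simulates the nibble packing done by the lanes with lane_id_32 % 4 == 0:
// 32 4-bit codes become 16 bytes, written as 8 char2 values per block of 32.
int main() {
    uint8_t q[32];             // stand-in for q_val0 across the 32 lanes
    for (int lane = 0; lane < 32; ++lane) {
        q[lane] = lane & 0x0F; // dummy 4-bit codes
    }
    for (int group = 0; group < 8; ++group) {                // group_id = lane / 4
        const int base = group * 4;                          // group_base in the kernel
        const uint8_t lo = (q[base + 1] << 4) | q[base + 0]; // q.x
        const uint8_t hi = (q[base + 3] << 4) | q[base + 2]; // q.y
        printf("char2[%d] = {0x%02X, 0x%02X}\n", group, (unsigned) lo, (unsigned) hi);
    }
    return 0;
}
```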
@@ -353,10 +323,13 @@ void quantize_mmq_mxfp4_cuda(const float * x,
                              cudaStream_t stream) {
     GGML_ASSERT(ne0 % (2 * QK_MXFP4) == 0); // Each warp processes 64 values
 
-    // ne1 tends to assume the highest values, therefore use it as the "x" dimension of the CUDA grid:
-    constexpr int vals_per_warp = 2 * QK_MXFP4; // 64
-    const int64_t block_num_y   = (ne0 + vals_per_warp - 1) / vals_per_warp;
+    constexpr int nwarps         = 8;
+    constexpr int vals_per_warp  = 2 * QK_MXFP4;           // 64 values per warp
+    constexpr int vals_per_block = nwarps * vals_per_warp; // 512 values per block
+
+    const int64_t block_num_y = (ne0 + vals_per_block - 1) / vals_per_block;
     const dim3 num_blocks(ne1, block_num_y, ne2 * ne3);
-    const dim3 block_size(32, 1, 1); // Warp size
+    const dim3 block_size(32, nwarps, 1); // 32 threads x 8 warps = 256 threads per block
+
     quantize_mmq_mxfp4<<<num_blocks, block_size, 0, stream>>>(x, ids, vy, ne00, s01, s02, s03, ne0, ne1, ne2);
 }
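As a worked example of the new launch geometry: each thread block now covers 8 * 64 = 512 values along `ne0`, so for `ne0 = 4096` the y-dimension of the grid shrinks from 64 single-warp blocks to 8 blocks of 256 threads. A quick sketch of the arithmetic, assuming `QK_MXFP4 == 32` and example tensor sizes:

```cpp
#include <cstdint>
#include <cstdio>

// Reproduces the grid/block computation from quantize_mmq_mxfp4_cuda.
int main() {
    const int64_t ne0 = 4096, ne1 = 512, ne2 = 1, ne3 = 1;
    const int     nwarps         = 8;
    const int     vals_per_warp  = 2 * 32;                 // 2 * QK_MXFP4
    const int     vals_per_block = nwarps * vals_per_warp; // 512
    const int64_t block_num_y    = (ne0 + vals_per_block - 1) / vals_per_block;
    printf("grid = (%lld, %lld, %lld), block = (32, %d, 1)\n",
           (long long) ne1, (long long) block_num_y, (long long)(ne2 * ne3), nwarps);
    // -> grid = (512, 8, 1), block = (32, 8, 1): 256 threads per block
    return 0;
}
```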