@@ -79,10 +79,8 @@ static __global__ void quantize_mmq_mxfp4(const float * __restrict__ x,
7979 const int nwarps = blockDim .y ;
8080
8181 const int64_t warp_start_offset = (blockIdx .y * nwarps + warp_id) * vals_per_warp;
82- const int64_t i0_block0 = warp_start_offset + lane_id_32;
83- const int64_t i0_block1 = warp_start_offset + vals_per_scale + lane_id_32;
8482
85- if (i0_block0 >= ne0) {
83+ if (warp_start_offset >= ne0) {
8684 return ;
8785 }
8886
@@ -101,66 +99,46 @@ static __global__ void quantize_mmq_mxfp4(const float * __restrict__ x,
10199 const int64_t ib = ib0 + (warp_start_offset / block_fp4_mmq_size) * ne1 + blockIdx .x ;
102100 const int64_t pair_idx_in_block = (warp_start_offset % block_fp4_mmq_size) / vals_per_warp;
103101
104- // Precompute common values
105- const int lane_id = lane_id_32 % 4 ;
106102 const int group_id = lane_id_32 / 4 ;
107- const int group_base = group_id * 4 ;
103+ const int lane_in_group = lane_id_32 % 4 ;
104+ const int base = group_id * 2 ;
108105 char2 * yqs2 = (char2 *) y[ib].qs ;
109106
110107 const int64_t base_pos = i03 * s03 + i02 * s02 + i01 * s01;
111- const float xi0 = (i0_block0 < ne00) ? x[base_pos + i0_block0] : 0 .0f ;
112- const float xi1 = (i0_block1 < ne00) ? x[base_pos + i0_block1] : 0 .0f ;
113108
114- // === Process first block (0-31) ===
115- float amax0 = fabsf (xi0);
116- #pragma unroll
117- for (int mask = 16 ; mask > 0 ; mask >>= 1 ) {
118- amax0 = fmaxf (amax0, __shfl_xor_sync (0xFFFFFFFF , amax0, mask, WARP_SIZE));
119- }
109+ uint8_t scales[2 ];
120110
121- const uint8_t e0 = compute_e8m0_scale (amax0);
122- const float inv_s0 = (amax0 == 0 .0f ) ? 0 .0f : __frcp_rn (ggml_cuda_e8m0_to_fp32 (e0 ));
123- const uint8_t q_val0 = ggml_cuda_float_to_fp4_e2m1 (xi0, inv_s0);
124-
125- // Gather 4 values from consecutive threads using shuffle
126- const uint8_t q0_0 = __shfl_sync (0xFFFFFFFF , q_val0, group_base + 0 , WARP_SIZE);
127- const uint8_t q0_1 = __shfl_sync (0xFFFFFFFF , q_val0, group_base + 1 , WARP_SIZE);
128- const uint8_t q0_2 = __shfl_sync (0xFFFFFFFF , q_val0, group_base + 2 , WARP_SIZE);
129- const uint8_t q0_3 = __shfl_sync (0xFFFFFFFF , q_val0, group_base + 3 , WARP_SIZE);
130-
131- if (lane_id == 0 ) {
132- char2 q;
133- q.x = (q0_1 << 4 ) | q0_0;
134- q.y = (q0_3 << 4 ) | q0_2;
135- yqs2[pair_idx_in_block * 16 + group_id] = q;
136- }
137-
138- // === Process second block (32-63) ===
139- float amax1 = fabsf (xi1);
140111#pragma unroll
141- for (int mask = 16 ; mask > 0 ; mask >>= 1 ) {
142- amax1 = fmaxf (amax1, __shfl_xor_sync ( 0xFFFFFFFF , amax1, mask, WARP_SIZE)) ;
143- }
112+ for (int b = 0 ; b < 2 ; ++b ) {
113+ const int64_t i0 = warp_start_offset + b * vals_per_scale + lane_id_32 ;
114+ const float xi = (i0 < ne00) ? x[base_pos + i0] : 0 . 0f ;
144115
145- const uint8_t e1 = compute_e8m0_scale (amax1);
146- const float inv_s1 = (amax1 == 0 .0f ) ? 0 .0f : __frcp_rn (ggml_cuda_e8m0_to_fp32 (e1 ));
147- const uint8_t q_val1 = ggml_cuda_float_to_fp4_e2m1 (xi1, inv_s1);
148-
149- const uint8_t q1_0 = __shfl_sync (0xFFFFFFFF , q_val1, group_base + 0 , WARP_SIZE);
150- const uint8_t q1_1 = __shfl_sync (0xFFFFFFFF , q_val1, group_base + 1 , WARP_SIZE);
151- const uint8_t q1_2 = __shfl_sync (0xFFFFFFFF , q_val1, group_base + 2 , WARP_SIZE);
152- const uint8_t q1_3 = __shfl_sync (0xFFFFFFFF , q_val1, group_base + 3 , WARP_SIZE);
116+ float amax = fabsf (xi);
117+ #pragma unroll
118+ for (int mask = 16 ; mask > 0 ; mask >>= 1 ) {
119+ amax = fmaxf (amax, __shfl_xor_sync (0xFFFFFFFF , amax, mask, WARP_SIZE));
120+ }
153121
154- if (lane_id == 0 ) {
155- char2 q;
156- q.x = (q1_1 << 4 ) | q1_0;
157- q.y = (q1_3 << 4 ) | q1_2;
158- yqs2[pair_idx_in_block * 16 + 8 + group_id] = q;
122+ const uint8_t e = compute_e8m0_scale (amax);
123+ scales[b] = e;
124+ const float inv_s = (amax == 0 .0f ) ? 0 .0f : __frcp_rn (ggml_cuda_e8m0_to_fp32 (e));
125+ const uint8_t q_val = ggml_cuda_float_to_fp4_e2m1 (xi, inv_s);
126+
127+ const uint8_t q_lo_0 = __shfl_sync (0xFFFFFFFF , q_val, base, WARP_SIZE);
128+ const uint8_t q_lo_1 = __shfl_sync (0xFFFFFFFF , q_val, base + 1 , WARP_SIZE);
129+ const uint8_t q_hi_0 = __shfl_sync (0xFFFFFFFF , q_val, base + 16 , WARP_SIZE);
130+ const uint8_t q_hi_1 = __shfl_sync (0xFFFFFFFF , q_val, base + 17 , WARP_SIZE);
131+
132+ if (lane_in_group == 0 ) {
133+ char2 q;
134+ q.x = (q_hi_0 << 4 ) | q_lo_0;
135+ q.y = (q_hi_1 << 4 ) | q_lo_1;
136+ yqs2[pair_idx_in_block * 16 + b * 8 + group_id] = q;
137+ }
159138 }
160139
161- // Write packed exponents
162140 if (lane_id_32 == 0 ) {
163- y[ib].d4 [pair_idx_in_block] = (e1 << 8 ) | e0 ;
141+ y[ib].d4 [pair_idx_in_block] = (scales[ 1 ] << 8 ) | scales[ 0 ] ;
164142 }
165143}
166144
0 commit comments