
Commit 16e8a11

Author: Aman Gupta

optimize quantize_mxfp4

1 parent e71ba01 commit 16e8a11

File tree: 2 files changed, +80 -105 lines changed

ggml/src/ggml-cuda/common.cuh

Lines changed: 2 additions & 0 deletions

@@ -726,6 +726,8 @@ __device__ __forceinline__ uint8_t ggml_cuda_float_to_fp4_e2m1(float x, float e)
 
     int best_i = 0;
     float best_err = fabsf(ax - pos_lut[0]);
+
+#pragma unroll
     for (int i = 1; i < 8; ++i) {
         float err = fabsf(ax - pos_lut[i]);
         if (err < best_err) {
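
The loop gaining #pragma unroll here is a nearest-entry search over the eight positive E2M1 magnitudes; with a fixed trip count of 7 the compiler can fully unroll it into straight-line compares. A minimal host-side sketch of that search pattern, assuming the standard E2M1 positive table {0, 0.5, 1, 1.5, 2, 3, 4, 6} (the surrounding device function is not shown in this diff):

#include <cmath>
#include <cstdint>
#include <cstdio>

// The 8 positive magnitudes representable in FP4 E2M1.
static const float pos_lut[8] = {0.0f, 0.5f, 1.0f, 1.5f, 2.0f, 3.0f, 4.0f, 6.0f};

// Host-side sketch of the search that #pragma unroll targets.
static uint8_t nearest_e2m1_index(float ax) {
    int   best_i   = 0;
    float best_err = fabsf(ax - pos_lut[0]);
    // Fixed trip count of 7 -> fully unrollable by the compiler.
    for (int i = 1; i < 8; ++i) {
        float err = fabsf(ax - pos_lut[i]);
        if (err < best_err) {
            best_err = err;
            best_i   = i;
        }
    }
    return (uint8_t) best_i;
}

int main() {
    printf("%d\n", (int) nearest_e2m1_index(2.6f)); // -> 5 (3.0 is nearest)
}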

ggml/src/ggml-cuda/quantize.cu

Lines changed: 78 additions & 105 deletions

@@ -47,6 +47,18 @@ static __global__ void quantize_q8_1(
     y[ib].ds = make_half2(d, sum);
 }
 
+// Helper to compute E8M0 scale from amax using fast math
+__device__ __forceinline__ uint8_t compute_e8m0_scale(float amax) {
+    if (amax == 0.0f) {
+        return 127; // Special case: use scale of 1.0 for zero input
+    }
+    // log2(amax / 6.0) = log2(amax) - log2(6) ≈ log2(amax) - 2.585
+    // Use __log2f for fast approximate log2
+    const float log2_amax = __log2f(amax) - 2.5849625007211563f; // log2(6)
+    const int e_int = __float2int_rd(log2_amax) + 127; // floor + bias
+    return static_cast<uint8_t>(max(1, min(254, e_int)));
+}
+
 static __global__ void quantize_mmq_mxfp4(const float * __restrict__ x,
                                           const int32_t * __restrict__ ids,
                                           void * __restrict__ vy,
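
A quick check of the identity in the new helper: for amax = 4.0, log2(4) - log2(6) ≈ -0.585, the floor is -1, and adding the bias 127 gives e = 126; the decoded scale is 2^(126-127) = 0.5. A host-side sketch of the same arithmetic, with the standard log2f/floorf standing in for the __log2f/__float2int_rd intrinsics:

#include <cassert>
#include <cmath>
#include <cstdint>

// Host mirror of compute_e8m0_scale (sketch; the device code uses the
// fast __log2f / __float2int_rd intrinsics instead).
static uint8_t e8m0_scale_ref(float amax) {
    if (amax == 0.0f) {
        return 127; // scale 1.0 for an all-zero block
    }
    const float log2_amax = log2f(amax) - 2.5849625007211563f; // log2(6)
    const int   e_int     = (int) floorf(log2_amax) + 127;     // floor + bias
    // Clamp to the valid E8M0 range [1, 254].
    return (uint8_t) (e_int < 1 ? 1 : (e_int > 254 ? 254 : e_int));
}

int main() {
    // amax = 4.0: floor(log2(4/6)) = -1, so e = 126 and the scale is 2^-1 = 0.5.
    assert(e8m0_scale_ref(4.0f) == 126);
    assert(e8m0_scale_ref(0.0f) == 127);
}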
@@ -60,10 +72,15 @@ static __global__ void quantize_mmq_mxfp4(const float * __restrict__ x,
     constexpr int vals_per_scale = 32;
     constexpr int vals_per_warp = 2 * vals_per_scale; // Each warp processes 2 blocks of 32
 
-    // Each warp processes 2 adjacent blocks of 32 values (64 values total)
-    const int64_t warp_start_offset = blockIdx.y * vals_per_warp;
-    const int64_t i0_block0 = warp_start_offset + threadIdx.x; // First block: 0-31
-    const int64_t i0_block1 = warp_start_offset + vals_per_scale + threadIdx.x; // Second block: 32-63
+    // Multiple warps per block - each warp handles different data
+    const int warp_id = threadIdx.y;
+    const int lane_id_32 = threadIdx.x;
+
+    const int nwarps = blockDim.y;
+
+    const int64_t warp_start_offset = (blockIdx.y * nwarps + warp_id) * vals_per_warp;
+    const int64_t i0_block0 = warp_start_offset + lane_id_32;
+    const int64_t i0_block1 = warp_start_offset + vals_per_scale + lane_id_32;
 
     if (i0_block0 >= ne0) {
         return;
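
As a concrete check of the new indexing: with nwarps = 8 and vals_per_warp = 64, the warp at blockIdx.y = 1, threadIdx.y = 3 starts at element (1 * 8 + 3) * 64 = 704, so its lane 5 loads elements 709 and 741. A host-side sketch of that mapping (WarpSlice and warp_slice are hypothetical helpers that mirror the kernel arithmetic only):

#include <cstdint>
#include <cstdio>

struct WarpSlice { int64_t i0_block0, i0_block1; };

// Mirrors the per-warp offset computation in quantize_mmq_mxfp4.
static WarpSlice warp_slice(int blockIdx_y, int warp_id, int lane, int nwarps) {
    const int vals_per_scale = 32;
    const int vals_per_warp  = 2 * vals_per_scale; // 64
    const int64_t warp_start = (int64_t)(blockIdx_y * nwarps + warp_id) * vals_per_warp;
    return { warp_start + lane, warp_start + vals_per_scale + lane };
}

int main() {
    // blockIdx.y = 1, warp 3 of 8, lane 5 -> elements 709 and 741.
    WarpSlice s = warp_slice(1, 3, 5, 8);
    printf("%lld %lld\n", (long long) s.i0_block0, (long long) s.i0_block1);
}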
@@ -80,117 +97,70 @@ static __global__ void quantize_mmq_mxfp4(const float * __restrict__ x,
     block_fp4_mmq * y = (block_fp4_mmq *) vy;
 
     const int64_t block_fp4_mmq_size = 4 * QK_MXFP4; // 128 values
-
-    const int64_t ib0 =
-        blockIdx.z * ((int64_t) gridDim.x * gridDim.y * vals_per_warp / block_fp4_mmq_size); // first block of channel
-    const int64_t ib = ib0 + (warp_start_offset / block_fp4_mmq_size) * ne1 + blockIdx.x; // block index in channel
-    const int64_t pair_idx_in_block =
-        (warp_start_offset % block_fp4_mmq_size) / vals_per_warp; // 0-1: which pair of blocks within block_fp4_mmq
-
-    uint8_t e_packed[2];
-
-    // Process first block (0-31)
-    {
-        const int64_t global_src_pos = i03 * s03 + i02 * s02 + i01 * s01 + i0_block0;
-        const float xi = i0_block0 < ne00 ? x[global_src_pos] : 0.0f;
-
-        float amax = fabsf(xi);
-
-        // Reduce max across all 32 threads in the warp
+    const int64_t ib0 = blockIdx.z * ((int64_t) gridDim.x * gridDim.y * nwarps * vals_per_warp / block_fp4_mmq_size);
+    const int64_t ib = ib0 + (warp_start_offset / block_fp4_mmq_size) * ne1 + blockIdx.x;
+    const int64_t pair_idx_in_block = (warp_start_offset % block_fp4_mmq_size) / vals_per_warp;
+
+    // Precompute common values
+    const int lane_id = lane_id_32 % 4;
+    const int group_id = lane_id_32 / 4;
+    const int group_base = group_id * 4;
+    char2 * yqs2 = (char2 *) y[ib].qs;
+
+    const int64_t base_pos = i03 * s03 + i02 * s02 + i01 * s01;
+    const float xi0 = (i0_block0 < ne00) ? x[base_pos + i0_block0] : 0.0f;
+    const float xi1 = (i0_block1 < ne00) ? x[base_pos + i0_block1] : 0.0f;
+
+    // === Process first block (0-31) ===
+    float amax0 = fabsf(xi0);
 #pragma unroll
-        for (int mask = 16; mask > 0; mask >>= 1) {
-            amax = fmaxf(amax, __shfl_xor_sync(0xFFFFFFFF, amax, mask, WARP_SIZE));
-        }
-
-        uint8_t e = amax > 0.0f ? (uint8_t) (floorf(log2f(amax / 6.0f)) + 127) : 0;
-
-        float val = ggml_cuda_e8m0_to_fp32(e);
-        float inv_s = (amax == 0.0f) ? 0.0f : 1.0f / val;
-
-        // Quantize: each thread processes 1 value
-        uint8_t q_val = ggml_cuda_float_to_fp4_e2m1(xi, inv_s);
-
-        if (e == 0) {
-            e = 127;
-        }
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        amax0 = fmaxf(amax0, __shfl_xor_sync(0xFFFFFFFF, amax0, mask, WARP_SIZE));
+    }
 
-        // Pack 4 values into char2: threads 0,1,2,3 -> first char2, etc.
-        const int lane_id = threadIdx.x % 4;
-        const int group_id = threadIdx.x / 4;
+    const uint8_t e0 = compute_e8m0_scale(amax0);
+    const float inv_s0 = (amax0 == 0.0f) ? 0.0f : __frcp_rn(ggml_cuda_e8m0_to_fp32(e0));
+    const uint8_t q_val0 = ggml_cuda_float_to_fp4_e2m1(xi0, inv_s0);
 
-        // Use shuffle to gather values from 4 consecutive threads
-        uint8_t q0 = __shfl_sync(0xFFFFFFFF, q_val, (group_id * 4) + 0, WARP_SIZE);
-        uint8_t q1 = __shfl_sync(0xFFFFFFFF, q_val, (group_id * 4) + 1, WARP_SIZE);
-        uint8_t q2 = __shfl_sync(0xFFFFFFFF, q_val, (group_id * 4) + 2, WARP_SIZE);
-        uint8_t q3 = __shfl_sync(0xFFFFFFFF, q_val, (group_id * 4) + 3, WARP_SIZE);
+    // Gather 4 values from consecutive threads using shuffle
+    const uint8_t q0_0 = __shfl_sync(0xFFFFFFFF, q_val0, group_base + 0, WARP_SIZE);
+    const uint8_t q0_1 = __shfl_sync(0xFFFFFFFF, q_val0, group_base + 1, WARP_SIZE);
+    const uint8_t q0_2 = __shfl_sync(0xFFFFFFFF, q_val0, group_base + 2, WARP_SIZE);
+    const uint8_t q0_3 = __shfl_sync(0xFFFFFFFF, q_val0, group_base + 3, WARP_SIZE);
 
+    if (lane_id == 0) {
         char2 q;
-        if (lane_id == 0) {
-            q.x = (q1 << 4) | q0;
-            q.y = (q3 << 4) | q2;
-
-            // Write to output: first block in pair uses positions based on pair_idx_in_block
-            // Each pair has 2 blocks of 32 = 64 values = 16 char2 elements
-            char2 * yqs2 = (char2 *) y[ib].qs;
-            yqs2[pair_idx_in_block * 16 + group_id] = q;
-        }
-
-        if (threadIdx.x == 0) {
-            e_packed[0] = e;
-        }
+        q.x = (q0_1 << 4) | q0_0;
+        q.y = (q0_3 << 4) | q0_2;
+        yqs2[pair_idx_in_block * 16 + group_id] = q;
     }
 
-    // Process second block (32-63)
-    {
-        const int64_t global_src_pos = i03 * s03 + i02 * s02 + i01 * s01 + i0_block1;
-        const float xi = i0_block1 < ne00 ? x[global_src_pos] : 0.0f;
-
-        float amax = fabsf(xi);
-
+    // === Process second block (32-63) ===
+    float amax1 = fabsf(xi1);
 #pragma unroll
-        for (int mask = 16; mask > 0; mask >>= 1) {
-            amax = fmaxf(amax, __shfl_xor_sync(0xFFFFFFFF, amax, mask, WARP_SIZE));
-        }
-
-        uint8_t e = amax > 0.0f ? (uint8_t) (floorf(log2f(amax / 6.0f)) + 127) : 0;
-
-        float val = ggml_cuda_e8m0_to_fp32(e);
-        float inv_s = (amax == 0.0f) ? 0.0f : 1.0f / val;
-
-        if (e == 0) {
-            e = 127;
-        }
-
-        uint8_t q_val = ggml_cuda_float_to_fp4_e2m1(xi, inv_s);
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        amax1 = fmaxf(amax1, __shfl_xor_sync(0xFFFFFFFF, amax1, mask, WARP_SIZE));
+    }
 
-        const int lane_id = threadIdx.x % 4;
-        const int group_id = threadIdx.x / 4;
+    const uint8_t e1 = compute_e8m0_scale(amax1);
+    const float inv_s1 = (amax1 == 0.0f) ? 0.0f : __frcp_rn(ggml_cuda_e8m0_to_fp32(e1));
+    const uint8_t q_val1 = ggml_cuda_float_to_fp4_e2m1(xi1, inv_s1);
 
-        // Use shuffle to gather values from 4 consecutive threads
-        uint8_t q0 = __shfl_sync(0xFFFFFFFF, q_val, (group_id * 4) + 0, WARP_SIZE);
-        uint8_t q1 = __shfl_sync(0xFFFFFFFF, q_val, (group_id * 4) + 1, WARP_SIZE);
-        uint8_t q2 = __shfl_sync(0xFFFFFFFF, q_val, (group_id * 4) + 2, WARP_SIZE);
-        uint8_t q3 = __shfl_sync(0xFFFFFFFF, q_val, (group_id * 4) + 3, WARP_SIZE);
+    const uint8_t q1_0 = __shfl_sync(0xFFFFFFFF, q_val1, group_base + 0, WARP_SIZE);
+    const uint8_t q1_1 = __shfl_sync(0xFFFFFFFF, q_val1, group_base + 1, WARP_SIZE);
+    const uint8_t q1_2 = __shfl_sync(0xFFFFFFFF, q_val1, group_base + 2, WARP_SIZE);
+    const uint8_t q1_3 = __shfl_sync(0xFFFFFFFF, q_val1, group_base + 3, WARP_SIZE);
 
+    if (lane_id == 0) {
         char2 q;
-        if (lane_id == 0) {
-            q.x = (q1 << 4) | q0;
-            q.y = (q3 << 4) | q2;
-
-            // Write to output: second block in pair uses positions 8-15 within the pair
-            char2 * yqs2 = (char2 *) y[ib].qs;
-            yqs2[pair_idx_in_block * 16 + 8 + group_id] = q;
-        }
-
-        if (threadIdx.x == 0) {
-            e_packed[1] = e;
-        }
+        q.x = (q1_1 << 4) | q1_0;
+        q.y = (q1_3 << 4) | q1_2;
+        yqs2[pair_idx_in_block * 16 + 8 + group_id] = q;
     }
 
-    // Write packed exponents: d4[0-1] each stores 2 scales (for 2 blocks of 32)
-    // pair_idx_in_block tells us which d4 entry to use (0-1)
-    if (threadIdx.x == 0) {
-        y[ib].d4[pair_idx_in_block] = (e_packed[1] << 8) | e_packed[0];
+    // Write packed exponents
+    if (lane_id_32 == 0) {
+        y[ib].d4[pair_idx_in_block] = (e1 << 8) | e0;
     }
 }
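
The char2 stores above pack four 4-bit codes from four consecutive lanes into two bytes, low nibble first. A host-side sketch of that packing, assuming the codes have already been produced by the quantizer (pack_fp4x4 and Char2 are hypothetical stand-ins):

#include <cassert>
#include <cstdint>

struct Char2 { int8_t x, y; };

// Four 4-bit codes q0..q3 (from lanes group_base+0..3) become two bytes:
// q0 in the low nibble of byte x, q2 in the low nibble of byte y.
static Char2 pack_fp4x4(uint8_t q0, uint8_t q1, uint8_t q2, uint8_t q3) {
    Char2 q;
    q.x = (int8_t) ((q1 << 4) | q0);
    q.y = (int8_t) ((q3 << 4) | q2);
    return q;
}

int main() {
    Char2 q = pack_fp4x4(0x1, 0x2, 0x3, 0x4);
    assert((uint8_t) q.x == 0x21 && (uint8_t) q.y == 0x43);
}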

@@ -353,10 +323,13 @@ void quantize_mmq_mxfp4_cuda(const float * x,
                              cudaStream_t stream) {
     GGML_ASSERT(ne0 % (2 * QK_MXFP4) == 0); // Each warp processes 64 values
 
-    // ne1 tends to assume the highest values, therefore use it as the "x" dimension of the CUDA grid:
-    constexpr int vals_per_warp = 2 * QK_MXFP4; // 64
-    const int64_t block_num_y = (ne0 + vals_per_warp - 1) / vals_per_warp;
+    constexpr int nwarps = 8;
+    constexpr int vals_per_warp = 2 * QK_MXFP4; // 64 values per warp
+    constexpr int vals_per_block = nwarps * vals_per_warp; // 512 values per block
+
+    const int64_t block_num_y = (ne0 + vals_per_block - 1) / vals_per_block;
     const dim3 num_blocks(ne1, block_num_y, ne2 * ne3);
-    const dim3 block_size(32, 1, 1); // Warp size
+    const dim3 block_size(32, nwarps, 1); // 32 threads x 8 warps = 256 threads per block
+
     quantize_mmq_mxfp4<<<num_blocks, block_size, 0, stream>>>(x, ids, vy, ne00, s01, s02, s03, ne0, ne1, ne2);
 }
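
To sanity-check the new launch geometry with a hypothetical row length ne0 = 4096: vals_per_block = 8 * 64 = 512, so block_num_y = ceil(4096 / 512) = 8 blocks of 256 threads in y, where the old single-warp launch needed 64 blocks of 32 threads. A host-side sketch of the arithmetic (QK_MXFP4 = 32, as in ggml):

#include <cstdint>
#include <cstdio>

int main() {
    const int QK_MXFP4       = 32;
    const int nwarps         = 8;
    const int vals_per_warp  = 2 * QK_MXFP4;           // 64
    const int vals_per_block = nwarps * vals_per_warp; // 512

    const int64_t ne0 = 4096; // hypothetical row length
    const int64_t block_num_y_new = (ne0 + vals_per_block - 1) / vals_per_block; // 8
    const int64_t block_num_y_old = (ne0 + vals_per_warp - 1) / vals_per_warp;   // 64

    printf("new: %lld blocks of 256 threads, old: %lld blocks of 32 threads\n",
           (long long) block_num_y_new, (long long) block_num_y_old);
}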
