Commit 90ff460

Author: Aman Gupta (committed)
use interleaved layout for mma
1 parent 3a9d404 commit 90ff460

2 files changed: +30 additions, −69 deletions

ggml/src/ggml-cuda/mmq.cuh

Lines changed: 1 addition & 18 deletions
@@ -804,25 +804,8 @@ static __device__ __forceinline__ void load_tiles_mxfp4_fp4(const char * __restr
 
         const block_mxfp4 * bxi = (const block_mxfp4 *) x + kbx0 + i * stride + kbx;
 
-        int aux_q4[4];
-        memcpy(aux_q4, bxi->qs, 16);
-
-        // Compress: extract low nibbles from each byte and pack into 16 bits
-        // Input byte layout: [hi3|lo3][hi2|lo2][hi1|lo1][hi0|lo0]
-        // Output: [lo3|lo2|lo1|lo0] as 16 bits
-        const auto compress = [](const int x) -> int {
-            const int m = x & 0x0F0F0F0F;               // isolate low nibbles: 0x0lo30lo20lo10lo0
-            // Pack nibbles: shift and combine
-            const int t1 = (m | (m >> 4)) & 0x00FF00FF; // 0x00_lo3lo2_00_lo1lo0
-            return (t1 | (t1 >> 8)) & 0x0000FFFF;       // 0x0000_lo3lo2lo1lo0
-        };
-
         const int k0 = kbx * 4;
-
-        x_qs[i * MMQ_MMA_TILE_X_K_FP4 + k0 + 0] = compress(aux_q4[1])      << 16 | compress(aux_q4[0]);
-        x_qs[i * MMQ_MMA_TILE_X_K_FP4 + k0 + 1] = compress(aux_q4[3])      << 16 | compress(aux_q4[2]);
-        x_qs[i * MMQ_MMA_TILE_X_K_FP4 + k0 + 2] = compress(aux_q4[1] >> 4) << 16 | compress(aux_q4[0] >> 4);
-        x_qs[i * MMQ_MMA_TILE_X_K_FP4 + k0 + 3] = compress(aux_q4[3] >> 4) << 16 | compress(aux_q4[2] >> 4);
+        memcpy(x_qs + i * MMQ_MMA_TILE_X_K_FP4 + k0, bxi->qs, 16);
 
         // Load E8M0 scales: pack 2 consecutive scales into one uint32
         if (kbx % 2 == 0) {
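Note: the deleted compress lambda de-interleaved block_mxfp4's nibble storage at load time: it gathers the low nibble of each byte of a 32-bit word into 16 bits, so the four qs words can be split into an "all low nibbles" half and an "all high nibbles" half. With the interleaved mma layout the tile keeps the native qs byte order, and a single memcpy replaces the repacking. Below is a minimal host-side sketch of the removed bit trick (standalone C++, no ggml dependencies; compress_low_nibbles is an illustrative name, not ggml API).

#include <cassert>
#include <cstdio>

// Sketch of the removed helper: pack the low nibble of each of the 4 bytes
// of a 32-bit word into the low 16 bits of the result.
static int compress_low_nibbles(const int x) {
    const int m  = x & 0x0F0F0F0F;               // keep only the low nibble of every byte
    const int t1 = (m | (m >> 4)) & 0x00FF00FF;  // pair nibbles of bytes {1,0} and {3,2}
    return (t1 | (t1 >> 8)) & 0x0000FFFF;        // collapse the two byte pairs into 16 bits
}

int main() {
    // Each byte is [hi|lo]; the low nibbles are 0xD, 0xC, 0xB, 0xA (byte 0 to byte 3).
    const int word   = 0x4A3B2C1D;
    const int packed = compress_low_nibbles(word);
    assert(packed == 0x0000ABCD);                           // low nibbles gathered, byte order kept
    assert(compress_low_nibbles(word >> 4) == 0x00004321);  // shifting first gathers the high nibbles
    printf("packed low nibbles: 0x%04X\n", (unsigned) packed);
    return 0;
}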

ggml/src/ggml-cuda/quantize.cu

Lines changed: 29 additions & 51 deletions
@@ -79,10 +79,8 @@ static __global__ void quantize_mmq_mxfp4(const float * __restrict__ x,
     const int nwarps = blockDim.y;
 
     const int64_t warp_start_offset = (blockIdx.y * nwarps + warp_id) * vals_per_warp;
-    const int64_t i0_block0 = warp_start_offset + lane_id_32;
-    const int64_t i0_block1 = warp_start_offset + vals_per_scale + lane_id_32;
 
-    if (i0_block0 >= ne0) {
+    if (warp_start_offset >= ne0) {
         return;
     }
 
@@ -101,66 +99,46 @@
     const int64_t ib = ib0 + (warp_start_offset / block_fp4_mmq_size) * ne1 + blockIdx.x;
     const int64_t pair_idx_in_block = (warp_start_offset % block_fp4_mmq_size) / vals_per_warp;
 
-    // Precompute common values
-    const int lane_id = lane_id_32 % 4;
     const int group_id = lane_id_32 / 4;
-    const int group_base = group_id * 4;
+    const int lane_in_group = lane_id_32 % 4;
+    const int base = group_id * 2;
     char2 * yqs2 = (char2 *) y[ib].qs;
 
     const int64_t base_pos = i03 * s03 + i02 * s02 + i01 * s01;
-    const float xi0 = (i0_block0 < ne00) ? x[base_pos + i0_block0] : 0.0f;
-    const float xi1 = (i0_block1 < ne00) ? x[base_pos + i0_block1] : 0.0f;
 
-    // === Process first block (0-31) ===
-    float amax0 = fabsf(xi0);
-    #pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        amax0 = fmaxf(amax0, __shfl_xor_sync(0xFFFFFFFF, amax0, mask, WARP_SIZE));
-    }
+    uint8_t scales[2];
 
-    const uint8_t e0 = compute_e8m0_scale(amax0);
-    const float inv_s0 = (amax0 == 0.0f) ? 0.0f : __frcp_rn(ggml_cuda_e8m0_to_fp32(e0));
-    const uint8_t q_val0 = ggml_cuda_float_to_fp4_e2m1(xi0, inv_s0);
-
-    // Gather 4 values from consecutive threads using shuffle
-    const uint8_t q0_0 = __shfl_sync(0xFFFFFFFF, q_val0, group_base + 0, WARP_SIZE);
-    const uint8_t q0_1 = __shfl_sync(0xFFFFFFFF, q_val0, group_base + 1, WARP_SIZE);
-    const uint8_t q0_2 = __shfl_sync(0xFFFFFFFF, q_val0, group_base + 2, WARP_SIZE);
-    const uint8_t q0_3 = __shfl_sync(0xFFFFFFFF, q_val0, group_base + 3, WARP_SIZE);
-
-    if (lane_id == 0) {
-        char2 q;
-        q.x = (q0_1 << 4) | q0_0;
-        q.y = (q0_3 << 4) | q0_2;
-        yqs2[pair_idx_in_block * 16 + group_id] = q;
-    }
-
-    // === Process second block (32-63) ===
-    float amax1 = fabsf(xi1);
     #pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        amax1 = fmaxf(amax1, __shfl_xor_sync(0xFFFFFFFF, amax1, mask, WARP_SIZE));
-    }
+    for (int b = 0; b < 2; ++b) {
+        const int64_t i0 = warp_start_offset + b * vals_per_scale + lane_id_32;
+        const float xi = (i0 < ne00) ? x[base_pos + i0] : 0.0f;
 
-    const uint8_t e1 = compute_e8m0_scale(amax1);
-    const float inv_s1 = (amax1 == 0.0f) ? 0.0f : __frcp_rn(ggml_cuda_e8m0_to_fp32(e1));
-    const uint8_t q_val1 = ggml_cuda_float_to_fp4_e2m1(xi1, inv_s1);
-
-    const uint8_t q1_0 = __shfl_sync(0xFFFFFFFF, q_val1, group_base + 0, WARP_SIZE);
-    const uint8_t q1_1 = __shfl_sync(0xFFFFFFFF, q_val1, group_base + 1, WARP_SIZE);
-    const uint8_t q1_2 = __shfl_sync(0xFFFFFFFF, q_val1, group_base + 2, WARP_SIZE);
-    const uint8_t q1_3 = __shfl_sync(0xFFFFFFFF, q_val1, group_base + 3, WARP_SIZE);
+        float amax = fabsf(xi);
+        #pragma unroll
+        for (int mask = 16; mask > 0; mask >>= 1) {
+            amax = fmaxf(amax, __shfl_xor_sync(0xFFFFFFFF, amax, mask, WARP_SIZE));
+        }
 
-    if (lane_id == 0) {
-        char2 q;
-        q.x = (q1_1 << 4) | q1_0;
-        q.y = (q1_3 << 4) | q1_2;
-        yqs2[pair_idx_in_block * 16 + 8 + group_id] = q;
+        const uint8_t e = compute_e8m0_scale(amax);
+        scales[b] = e;
+        const float inv_s = (amax == 0.0f) ? 0.0f : __frcp_rn(ggml_cuda_e8m0_to_fp32(e));
+        const uint8_t q_val = ggml_cuda_float_to_fp4_e2m1(xi, inv_s);
+
+        const uint8_t q_lo_0 = __shfl_sync(0xFFFFFFFF, q_val, base,      WARP_SIZE);
+        const uint8_t q_lo_1 = __shfl_sync(0xFFFFFFFF, q_val, base + 1,  WARP_SIZE);
+        const uint8_t q_hi_0 = __shfl_sync(0xFFFFFFFF, q_val, base + 16, WARP_SIZE);
+        const uint8_t q_hi_1 = __shfl_sync(0xFFFFFFFF, q_val, base + 17, WARP_SIZE);
+
+        if (lane_in_group == 0) {
+            char2 q;
+            q.x = (q_hi_0 << 4) | q_lo_0;
+            q.y = (q_hi_1 << 4) | q_lo_1;
+            yqs2[pair_idx_in_block * 16 + b * 8 + group_id] = q;
+        }
     }
 
-    // Write packed exponents
     if (lane_id_32 == 0) {
-        y[ib].d4[pair_idx_in_block] = (e1 << 8) | e0;
+        y[ib].d4[pair_idx_in_block] = (scales[1] << 8) | scales[0];
     }
 }
 
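Note: the shuffles from lanes base and base + 1 (low nibbles) and base + 16 and base + 17 (high nibbles) mean the quantized activations are now written in the same interleaved order as block_mxfp4: byte j of a 32-value block carries element j in its low nibble and element j + 16 in its high nibble, while the two E8M0 scales of a 64-value pair are packed as (scales[1] << 8) | scales[0]. This matching layout is what lets load_tiles_mxfp4_fp4 above memcpy bxi->qs straight into x_qs. A rough host-side illustration of that nibble convention follows (pack_interleaved and unpack_interleaved are hypothetical helpers for the sketch, not part of ggml).

#include <cassert>
#include <cstdint>

// Pack 32 4-bit codes into 16 bytes using the interleaved convention:
// byte j holds code j in its low nibble and code j+16 in its high nibble.
static void pack_interleaved(const uint8_t codes[32], uint8_t qs[16]) {
    for (int j = 0; j < 16; ++j) {
        qs[j] = (uint8_t) ((codes[j] & 0x0F) | ((codes[j + 16] & 0x0F) << 4));
    }
}

// Inverse, for checking: recover code j from byte j % 16 (low or high nibble).
static uint8_t unpack_interleaved(const uint8_t qs[16], int j) {
    return (j < 16) ? (uint8_t) (qs[j] & 0x0F) : (uint8_t) (qs[j - 16] >> 4);
}

int main() {
    uint8_t codes[32];
    for (int j = 0; j < 32; ++j) {
        codes[j] = (uint8_t) ((j * 7 + 3) & 0x0F); // arbitrary 4-bit test pattern
    }

    uint8_t qs[16];
    pack_interleaved(codes, qs);

    for (int j = 0; j < 32; ++j) {
        assert(unpack_interleaved(qs, j) == codes[j]); // round-trips for every element
    }
    return 0;
}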
