Skip to content

Commit e71ba01

Browse files
author
Aman Gupta
committed
optimize load_tiles
1 parent 694b7ff commit e71ba01

File tree

1 file changed

+31
-31
lines changed

1 file changed

+31
-31
lines changed

ggml/src/ggml-cuda/mmq.cuh

Lines changed: 31 additions & 31 deletions
Original file line number | Diff line number | Diff line change
@@ -782,56 +782,56 @@ static __device__ __forceinline__ void load_tiles_mxfp4_fp4(const char * __restr
782782
const int i_max,
783783
const int stride) {
784784
constexpr int nwarps = mmq_get_nwarps_device();
785+
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
785786

786787
#if defined(BLACKWELL_MMA_AVAILABLE)
787788
int * x_qs = (int *) x_tile;
788-
uint32_t * x_sc = (uint32_t *) (x_qs + MMQ_TILE_NE_K); // Same offset as original: 2*MMQ_TILE_NE_K
789+
uint32_t * x_sc = (uint32_t *) (x_qs + MMQ_TILE_NE_K);
789790

790-
constexpr int nrows = 1;
791-
const int txi = threadIdx.x; // txi
792-
const int kbx = txi;
791+
const int txi = threadIdx.x;
793792

794-
// TODO: only 8 threads of a warp at the moment for simplicity, use more threads
795-
if (txi >= 8) {
796-
return;
797-
}
798-
# pragma unroll
799-
for (int i0 = 0; i0 < mmq_y; i0 += nrows * nwarps) {
800-
int i = i0 + threadIdx.y;
793+
// Use all 32 threads: 8 threads per row, process 4 rows per warp per iteration
794+
constexpr int threads_per_row = 8; // 8 blocks per row
795+
constexpr int rows_per_warp = warp_size / threads_per_row; // 4 rows per warp
796+
const int kbx = txi % threads_per_row; // block id 0-7
797+
const int row_in_warp = txi / threads_per_row; // which of the 4 rows this thread handles
798+
799+
#pragma unroll
800+
for (int i0 = 0; i0 < mmq_y; i0 += rows_per_warp * nwarps) {
801+
int i = i0 + threadIdx.y * rows_per_warp + row_in_warp;
801802

802803
if (need_check) {
803804
i = min(i, i_max);
804805
}
805806

806807
const block_mxfp4 * bxi = (const block_mxfp4 *) x + kbx0 + i * stride + kbx;
807808

808-
// Load packed FP4 data directly (no LUT dequantization)
809-
const int aux_q4_0 = get_int_b1(bxi->qs, 0);
810-
const int aux_q4_1 = get_int_b1(bxi->qs, 1);
811-
const int aux_q4_2 = get_int_b1(bxi->qs, 2);
812-
const int aux_q4_3 = get_int_b1(bxi->qs, 3);
809+
// Load 16 bytes more efficiently using memcpy (compiler optimizes to vector loads)
810+
int aux_q4[4];
811+
memcpy(aux_q4, bxi->qs, 16);
813812

813+
// Compress: extract low nibbles from each byte and pack into 16 bits
814+
// Input byte layout: [hi3|lo3][hi2|lo2][hi1|lo1][hi0|lo0]
815+
// Output: [lo3|lo2|lo1|lo0] as 16 bits
814816
const auto compress = [](const int x) -> int {
815-
uint16_t a = (x >> 24) & 0xF;
816-
uint16_t b = (x >> 16) & 0xF;
817-
uint16_t c = (x >> 8) & 0xF;
818-
uint16_t d = x & 0xF;
819-
820-
return (a << 12) | (b << 8) | (c << 4) | d;
817+
const int m = x & 0x0F0F0F0F; // isolate low nibbles: 0x0lo30lo20lo10lo0
818+
// Pack nibbles: shift and combine
819+
const int t1 = (m | (m >> 4)) & 0x00FF00FF; // 0x00_lo3lo2_00_lo1lo0
820+
return (t1 | (t1 >> 8)) & 0x0000FFFF; // 0x0000_lo3lo2lo1lo0
821821
};
822822

823-
const int k0 = kbx * 4; // each block takes 4 bytes
823+
const int k0 = kbx * 4;
824824

825-
x_qs[i * MMQ_MMA_TILE_X_K_FP4 + k0 + 0] = compress(aux_q4_1) << 16 | compress(aux_q4_0);
826-
x_qs[i * MMQ_MMA_TILE_X_K_FP4 + k0 + 1] = compress(aux_q4_3) << 16 | compress(aux_q4_2);
827-
x_qs[i * MMQ_MMA_TILE_X_K_FP4 + k0 + 2] = compress(aux_q4_1 >> 4) << 16 | compress(aux_q4_0 >> 4);
828-
x_qs[i * MMQ_MMA_TILE_X_K_FP4 + k0 + 3] = compress(aux_q4_3 >> 4) << 16 | compress(aux_q4_2 >> 4);
825+
x_qs[i * MMQ_MMA_TILE_X_K_FP4 + k0 + 0] = compress(aux_q4[1]) << 16 | compress(aux_q4[0]);
826+
x_qs[i * MMQ_MMA_TILE_X_K_FP4 + k0 + 1] = compress(aux_q4[3]) << 16 | compress(aux_q4[2]);
827+
x_qs[i * MMQ_MMA_TILE_X_K_FP4 + k0 + 2] = compress(aux_q4[1] >> 4) << 16 | compress(aux_q4[0] >> 4);
828+
x_qs[i * MMQ_MMA_TILE_X_K_FP4 + k0 + 3] = compress(aux_q4[3] >> 4) << 16 | compress(aux_q4[2] >> 4);
829829

830-
if (txi % 2 == 0) {
830+
// Load E8M0 scales: pack 2 consecutive scales into one uint32
831+
if (kbx % 2 == 0) {
831832
uint32_t e = bxi->e;
832-
bxi++;
833-
e |= (bxi->e << 8);
834-
x_sc[i * MMQ_MMA_TILE_X_K_FP4 + txi / 2] = e;
833+
e |= ((bxi + 1)->e << 8);
834+
x_sc[i * MMQ_MMA_TILE_X_K_FP4 + kbx / 2] = e;
835835
}
836836
}
837837
#endif

0 commit comments

Comments
 (0)