@@ -782,56 +782,56 @@ static __device__ __forceinline__ void load_tiles_mxfp4_fp4(const char * __restr
                                                             const int i_max,
                                                             const int stride) {
     constexpr int nwarps = mmq_get_nwarps_device();
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();

 #if defined(BLACKWELL_MMA_AVAILABLE)
     int * x_qs = (int *) x_tile;
-    uint32_t * x_sc = (uint32_t *) (x_qs + MMQ_TILE_NE_K); // Same offset as original: 2*MMQ_TILE_NE_K
+    uint32_t * x_sc = (uint32_t *) (x_qs + MMQ_TILE_NE_K);

-    constexpr int nrows = 1;
-    const int txi = threadIdx.x; // txi
-    const int kbx = txi;
+    const int txi = threadIdx.x;

-    // TODO: only 8 threads of a warp at the moment for simplicity, use more threads
-    if (txi >= 8) {
-        return;
-    }
-#pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nrows * nwarps) {
-        int i = i0 + threadIdx.y;
+    // Use all 32 threads: 8 threads per row, process 4 rows per warp per iteration
+    constexpr int threads_per_row = 8;                           // 8 blocks per row
+    constexpr int rows_per_warp   = warp_size / threads_per_row; // 4 rows per warp
+    const int kbx         = txi % threads_per_row;               // block id 0-7
+    const int row_in_warp = txi / threads_per_row;               // which of the 4 rows this thread handles
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += rows_per_warp * nwarps) {
+        int i = i0 + threadIdx.y * rows_per_warp + row_in_warp;

         if (need_check) {
             i = min(i, i_max);
         }

         const block_mxfp4 * bxi = (const block_mxfp4 *) x + kbx0 + i * stride + kbx;

-        // Load packed FP4 data directly (no LUT dequantization)
-        const int aux_q4_0 = get_int_b1(bxi->qs, 0);
-        const int aux_q4_1 = get_int_b1(bxi->qs, 1);
-        const int aux_q4_2 = get_int_b1(bxi->qs, 2);
-        const int aux_q4_3 = get_int_b1(bxi->qs, 3);
+        // Load 16 bytes more efficiently using memcpy (compiler optimizes to vector loads)
+        int aux_q4[4];
+        memcpy(aux_q4, bxi->qs, 16);

+        // Compress: extract low nibbles from each byte and pack into 16 bits
+        // Input byte layout: [hi3|lo3][hi2|lo2][hi1|lo1][hi0|lo0]
+        // Output: [lo3|lo2|lo1|lo0] as 16 bits
         const auto compress = [](const int x) -> int {
-            uint16_t a = (x >> 24) & 0xF;
-            uint16_t b = (x >> 16) & 0xF;
-            uint16_t c = (x >>  8) & 0xF;
-            uint16_t d =  x        & 0xF;
-
-            return (a << 12) | (b << 8) | (c << 4) | d;
+            const int m  = x & 0x0F0F0F0F;              // isolate low nibbles: 0x0lo3_0lo2_0lo1_0lo0
+            // Pack nibbles: shift and combine
+            const int t1 = (m | (m >> 4)) & 0x00FF00FF; // 0x00_lo3lo2_00_lo1lo0
+            return (t1 | (t1 >> 8)) & 0x0000FFFF;       // 0x0000_lo3lo2lo1lo0
         };

-        const int k0 = kbx * 4; // each block takes 4 bytes
+        const int k0 = kbx * 4;

-        x_qs[i * MMQ_MMA_TILE_X_K_FP4 + k0 + 0] = compress(aux_q4_1)      << 16 | compress(aux_q4_0);
-        x_qs[i * MMQ_MMA_TILE_X_K_FP4 + k0 + 1] = compress(aux_q4_3)      << 16 | compress(aux_q4_2);
-        x_qs[i * MMQ_MMA_TILE_X_K_FP4 + k0 + 2] = compress(aux_q4_1 >> 4) << 16 | compress(aux_q4_0 >> 4);
-        x_qs[i * MMQ_MMA_TILE_X_K_FP4 + k0 + 3] = compress(aux_q4_3 >> 4) << 16 | compress(aux_q4_2 >> 4);
+        x_qs[i * MMQ_MMA_TILE_X_K_FP4 + k0 + 0] = compress(aux_q4[1])      << 16 | compress(aux_q4[0]);
+        x_qs[i * MMQ_MMA_TILE_X_K_FP4 + k0 + 1] = compress(aux_q4[3])      << 16 | compress(aux_q4[2]);
+        x_qs[i * MMQ_MMA_TILE_X_K_FP4 + k0 + 2] = compress(aux_q4[1] >> 4) << 16 | compress(aux_q4[0] >> 4);
+        x_qs[i * MMQ_MMA_TILE_X_K_FP4 + k0 + 3] = compress(aux_q4[3] >> 4) << 16 | compress(aux_q4[2] >> 4);

-        if (txi % 2 == 0) {
+        // Load E8M0 scales: pack 2 consecutive scales into one uint32
+        if (kbx % 2 == 0) {
             uint32_t e = bxi->e;
-            bxi++;
-            e |= (bxi->e << 8);
-            x_sc[i * MMQ_MMA_TILE_X_K_FP4 + txi / 2] = e;
+            e |= ((bxi + 1)->e << 8);
+            x_sc[i * MMQ_MMA_TILE_X_K_FP4 + kbx / 2] = e;
         }
     }
 #endif