
Commit 694b7ff

Author: Aman Gupta
CUDA: experimental native mxfp4 support for blackwell
1 parent 2fbe3b7 commit 694b7ff

File tree

7 files changed: +455 -15 lines changed

ggml/src/ggml-cuda/CMakeLists.txt

Lines changed: 5 additions & 0 deletions
@@ -15,6 +15,7 @@ if (CUDAToolkit_FOUND)
     # 80  == Ampere, asynchronous data loading, faster tensor core instructions
     # 86  == RTX 3000, needs CUDA v11.1
     # 89  == RTX 4000, needs CUDA v11.8
+    # 100 == Blackwell, needs CUDA v12.8, native FP4 tensor cores
     #
     # XX-virtual == compile CUDA code as PTX, do JIT compilation to binary code on first run
     # XX-real    == compile CUDA code as device code for this specific architecture
@@ -34,6 +35,10 @@ if (CUDAToolkit_FOUND)
         if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL "11.8")
             list(APPEND CMAKE_CUDA_ARCHITECTURES 89-real)
         endif()
+
+        if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL "12.8")
+            list(APPEND CMAKE_CUDA_ARCHITECTURES 100-real)
+        endif()
     endif()
 endif()
 message(STATUS "Using CUDA architectures: ${CMAKE_CUDA_ARCHITECTURES}")
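
Note: the new 100-real entry only affects which device binaries get compiled; whether the MXFP4 path is taken at runtime still depends on the compute capability reported by the device (see blackwell_mma_available() in common.cuh below). A standalone sketch of that runtime check, not part of this commit and using only the CUDA runtime API, mirroring GGML_CUDA_CC_BLACKWELL == 1000:

    // Hypothetical standalone check (not part of this commit): report whether the
    // active device has a Blackwell-class compute capability (>= 10.0), i.e. the
    // same threshold as GGML_CUDA_CC_BLACKWELL == 1000 in common.cuh.
    #include <cuda_runtime.h>
    #include <cstdio>

    int main() {
        cudaDeviceProp prop;
        cudaGetDeviceProperties(&prop, /*device=*/0);
        const int cc = 100*prop.major + 10*prop.minor; // ggml encodes CC as major*100 + minor*10
        printf("compute capability %d -> Blackwell MMA eligible: %s\n", cc, cc >= 1000 ? "yes" : "no");
        return 0;
    }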

ggml/src/ggml-cuda/common.cuh

Lines changed: 44 additions & 0 deletions
@@ -50,6 +50,7 @@
 #define GGML_CUDA_CC_TURING       750
 #define GGML_CUDA_CC_AMPERE       800
 #define GGML_CUDA_CC_ADA_LOVELACE 890
+#define GGML_CUDA_CC_BLACKWELL    1000
 #define GGML_CUDA_CC_OFFSET_AMD      0x1000000
 #define GGML_CUDA_CC_OFFSET_MTHREADS 0x0100000
 #define GGML_CUDA_CC_IS_NVIDIA(cc)   (cc < GGML_CUDA_CC_OFFSET_MTHREADS)
@@ -243,6 +244,10 @@ static const char * cu_get_error_str(CUresult err) {
 #define AMPERE_MMA_AVAILABLE
 #endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
 
+#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_BLACKWELL
+#    define BLACKWELL_MMA_AVAILABLE
+#endif
+
 #if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
 #define CP_ASYNC_AVAILABLE
 #endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
@@ -313,6 +318,10 @@ static bool cp_async_available(const int cc) {
     return GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_AMPERE;
 }
 
+static bool blackwell_mma_available(const int cc) {
+    return GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_BLACKWELL;
+}
+
 static constexpr __device__ int ggml_cuda_get_physical_warp_size() {
 #if defined(GGML_USE_HIP) && (defined(__GFX9__) || defined(__GFX8__))
     return 64;
@@ -698,6 +707,41 @@ static __device__ __forceinline__ float ggml_cuda_e8m0_to_fp32(uint8_t x) {
 #endif // CUDART_VERSION >= 12050
 }
 
+__device__ __forceinline__ uint8_t ggml_cuda_float_to_fp4_e2m1(float x, float e) {
+    // Handle exact zero early
+    if (x == 0.0f) {
+        return 0;
+    }
+
+    const float sign = x < 0.0f ? -1.0f : 1.0f;
+    float ax = fabsf(x) * e;
+
+    // Positive LUT
+    static constexpr float pos_lut[8] = { 0.0f, 0.5f, 1.0f, 1.5f, 2.0f, 3.0f, 4.0f, 6.0f };
+
+    // Saturate to max representable magnitude
+    if (ax > pos_lut[7]) {
+        ax = pos_lut[7];
+    }
+
+    int best_i = 0;
+    float best_err = fabsf(ax - pos_lut[0]);
+    for (int i = 1; i < 8; ++i) {
+        float err = fabsf(ax - pos_lut[i]);
+        if (err < best_err) {
+            best_err = err;
+            best_i   = i;
+        }
+    }
+
+    // Positive codes: 0..7, negative: 8..15 (sign bit = MSB)
+    if (sign > 0.0f) {
+        return static_cast<uint8_t>(best_i);       // 0..7
+    } else {
+        return static_cast<uint8_t>(best_i | 0x8); // 8..15
+    }
+}
+
 // See https://gmplib.org/~tege/divcnst-pldi94.pdf figure 4.1.
 // Precompute mp (m' in the paper) and L such that division
 // can be computed using a multiply (high 32b of 64b result)
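
For reference, ggml_cuda_float_to_fp4_e2m1() above snaps the scaled magnitude onto the eight-entry positive E2M1 table {0, 0.5, 1, 1.5, 2, 3, 4, 6} and stores the sign in the MSB of the 4-bit code. A minimal decode sketch of the inverse mapping (hypothetical helper, not part of this commit, assuming the usual common.cuh includes):

    // Hypothetical reference decoder for the codes produced by
    // ggml_cuda_float_to_fp4_e2m1(); not part of this commit.
    static __host__ __device__ float fp4_e2m1_to_fp32_ref(uint8_t code) {
        constexpr float pos_lut[8] = { 0.0f, 0.5f, 1.0f, 1.5f, 2.0f, 3.0f, 4.0f, 6.0f };
        const float mag = pos_lut[code & 0x7]; // low 3 bits index the magnitude table
        return (code & 0x8) ? -mag : mag;      // MSB is the sign bit
    }

    // Example round trip with scale e = 1.0f:
    //   ggml_cuda_float_to_fp4_e2m1(-2.7f, 1.0f) == 0xD  ->  fp4_e2m1_to_fp32_ref(0xD) == -3.0f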

ggml/src/ggml-cuda/mma.cuh

Lines changed: 19 additions & 0 deletions
@@ -812,6 +812,25 @@ namespace ggml_cuda_mma {
 #endif // AMPERE_MMA_AVAILABLE
     }
 
+    static __device__ __forceinline__ void mma_block_scaled(tile<16, 8, float> & D,
+                                                            const tile<16, 8, int> & A,
+                                                            const tile<8, 8, int> & B,
+                                                            uint32_t a_scale,
+                                                            uint32_t b_scale) {
+#ifdef BLACKWELL_MMA_AVAILABLE
+        const int * Axi = (const int *) A.x;
+        const int * Bxi = (const int *) B.x;
+        float * Dxi = (float *) D.x;
+
+        asm volatile(
+            "mma.sync.aligned.kind::mxf4.block_scale.scale_vec::2X.m16n8k64.row.col.f32.e2m1.e2m1.f32.ue8m0 "
+            "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3}, "
+            "%10, {0, 0}, %11, {0, 0};"
+            : "+f"(Dxi[0]), "+f"(Dxi[1]), "+f"(Dxi[2]), "+f"(Dxi[3])
+            : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[1]), "r"(a_scale), "r"(b_scale));
+#endif // BLACKWELL_MMA_AVAILABLE
+    }
+
     static __device__ __forceinline__ void mma(
         tile<16, 8, float> & D, const tile<16, 8, half2> & A, const tile<8, 8, half2> & B) {
 #ifdef TURING_MMA_AVAILABLE
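
The a_scale and b_scale operands carry UE8M0 block scales, i.e. unsigned power-of-two exponents with bias 127 (the format ggml_cuda_e8m0_to_fp32() in common.cuh decodes). A minimal sketch of producing one such scale byte from a block's absolute maximum, assuming the scale is simply chosen so the maximum fits the E2M1 range; hypothetical helper, not part of this commit:

    // Hypothetical sketch (not part of this commit): encode a per-block UE8M0 scale
    // from the block's absolute maximum, so that amax / 2^e fits the E2M1 range (<= 6.0f).
    // 255 is reserved in E8M0, so the biased exponent is clamped to [0, 254].
    static __device__ __forceinline__ uint8_t fp32_to_e8m0_sketch(float block_amax) {
        if (block_amax <= 0.0f) {
            return 0; // all-zero block: smallest representable scale
        }
        int e = (int) ceilf(log2f(block_amax / 6.0f)); // smallest e with amax / 2^e <= 6.0f
        e += 127;                                      // apply E8M0 bias
        if (e < 0)   { e = 0;   }
        if (e > 254) { e = 254; }
        return (uint8_t) e;
    }

    // e.g. fp32_to_e8m0_sketch(5.5f) == 127 (scale 1), fp32_to_e8m0_sketch(12.0f) == 128 (scale 2)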

ggml/src/ggml-cuda/mmq.cu

Lines changed: 25 additions & 6 deletions
@@ -123,12 +123,23 @@ void ggml_cuda_mul_mat_q(
         const int64_t s11 = src1->nb[1] / ts_src1;
         const int64_t s12 = src1->nb[2] / ts_src1;
         const int64_t s13 = src1->nb[3] / ts_src1;
-        quantize_mmq_q8_1_cuda(src1_d, nullptr, src1_q8_1.get(), src0->type,
-            ne10, s11, s12, s13, ne10_padded, ne11, ne12, ne13, stream);
+        if (blackwell_mma_available(cc) && src0->type == GGML_TYPE_MXFP4) {
+            quantize_mmq_mxfp4_cuda(src1_d, nullptr, src1_q8_1.get(), src0->type, ne10, s11, s12, s13, ne10_padded,
+                                    ne11, ne12, ne13, stream);
+
+        } else {
+            quantize_mmq_q8_1_cuda(src1_d, nullptr, src1_q8_1.get(), src0->type, ne10, s11, s12, s13, ne10_padded,
+                                   ne11, ne12, ne13, stream);
+        }
         CUDA_CHECK(cudaGetLastError());
     }
 
-    const int64_t s12 = ne11*ne10_padded * sizeof(block_q8_1)/(QK8_1*sizeof(int));
+    // Stride depends on quantization format
+    const int64_t s12 = (blackwell_mma_available(cc) && src0->type == GGML_TYPE_MXFP4) ?
+                            ne11 * ne10_padded * sizeof(block_fp4_mmq) /
+                                (4 * QK_MXFP4 * sizeof(int)) // block_fp4_mmq holds 128 values
+                            :
+                            ne11 * ne10_padded * sizeof(block_q8_1) / (QK8_1 * sizeof(int));
     const int64_t s13 = ne12*s12;
 
     const mmq_args args = {
@@ -175,12 +186,20 @@ void ggml_cuda_mul_mat_q(
         const int64_t s11 = src1->nb[1] / ts_src1;
         const int64_t s12 = src1->nb[2] / ts_src1;
         const int64_t s13 = src1->nb[2] / ts_src1;
-        quantize_mmq_q8_1_cuda(src1_d, ids_src1.get(), src1_q8_1.get(), src0->type,
-            ne10, s11, s12, s13, ne10_padded, ne11_flat, ne12_flat, ne13_flat, stream);
+
+        if (blackwell_mma_available(cc) && src0->type == GGML_TYPE_MXFP4) {
+            quantize_mmq_mxfp4_cuda(src1_d, ids_src1.get(), src1_q8_1.get(), src0->type, ne10, s11, s12, s13,
+                                    ne10_padded, ne11_flat, ne12_flat, ne13_flat, stream);
+        } else {
+            quantize_mmq_q8_1_cuda(src1_d, ids_src1.get(), src1_q8_1.get(), src0->type, ne10, s11, s12, s13,
+                                   ne10_padded, ne11_flat, ne12_flat, ne13_flat, stream);
+        }
         CUDA_CHECK(cudaGetLastError());
     }
 
-    const int64_t s12 = ne11*ne10_padded * sizeof(block_q8_1)/(QK8_1*sizeof(int));
+    const int64_t s12 = (blackwell_mma_available(cc) && src0->type == GGML_TYPE_MXFP4) ?
+                            ne11 * ne10_padded * sizeof(block_fp4_mmq) / (4 * QK_MXFP4 * sizeof(int)) :
+                            ne11 * ne10_padded * sizeof(block_q8_1) / (QK8_1 * sizeof(int));
     const int64_t s13 = ne12*s12;
 
     // Note that ne02 is used instead of ne12 because the number of y channels determines the z dimension of the CUDA grid.
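
The ternary expressions above compute the per-slice stride s12 in int units; factored out, the arithmetic is rows x blocks-per-row x ints-per-block, where block_fp4_mmq packs 4*QK_MXFP4 = 128 values and block_q8_1 packs QK8_1 values. A small restatement with the quantities named (hypothetical helper, not part of this commit):

    // Hypothetical restatement (not part of this commit) of the s12 stride above:
    // stride in int units = rows * blocks per padded row * ints per block.
    #include <cstdint>
    #include <cstddef>

    static int64_t mmq_src1_stride_sketch(int64_t ne10_padded, int64_t ne11,
                                          size_t block_bytes, int64_t values_per_block) {
        const int64_t blocks_per_row = ne10_padded / values_per_block;        // padded row length in blocks
        const int64_t ints_per_block = (int64_t) (block_bytes / sizeof(int)); // block size in int units
        return ne11 * blocks_per_row * ints_per_block;
    }

    // Equivalent to the expressions above (assuming exact division):
    //   MXFP4 path: mmq_src1_stride_sketch(ne10_padded, ne11, sizeof(block_fp4_mmq), 4*QK_MXFP4)
    //   Q8_1 path:  mmq_src1_stride_sketch(ne10_padded, ne11, sizeof(block_q8_1),    QK8_1)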
