Skip to content

FP4#1

Open
z52527 wants to merge 42 commits into
mainfrom
fea-fp4
Open

FP4#1
z52527 wants to merge 42 commits into
mainfrom
fea-fp4

Conversation

@z52527

@z52527 z52527 commented Feb 2, 2026

Copy link
Copy Markdown
Owner

No description provided.

z52527 and others added 4 commits April 14, 2026 07:31
1. FP4-specific block_n heuristic (get_best_fp4_config):
   - Uses stages²×bn composite score to balance pipeline depth vs tile size
   - Replaces generic FP8 heuristic which picked suboptimal block_n for FP4
   - Supports arbitrary block_n (no N%bn divisibility requirement)

2. B-multicast support (2CTA, UMMA_M=256):
   - Enabled for M>=512 when M-tiles are even and GemmType is Normal/KGroupedContiguous
   - Each CTA loads full A, B is split (LOAD_BLOCK_N = BLOCK_N/2)
   - UTCOMMA instruction supports M=128 (1CTA) and M=256 (2CTA)
   - Previously M>=512 shapes crashed; now functional at 77-88% of cuBLAS

3. SFB SMEM zero-fill bugfix (fill_sfb_missing_k_groups):
   - Old code replicated first 32 SF elements across the 128-aligned block,
     corrupting valid data for BLOCK_N in (32, 128)
   - New code zero-fills only positions [BLOCK_N, 128) before warp-transpose
   - Fixes correctness for non-power-of-2 block_n with random scale factors

4. SF SMEM/TMEM calculations now depend on block_k (not hardcoded to 2)

Performance vs cuBLAS NVFP4 (DG/CB ratio, higher=better):
  1CTA (M<512): 0.69→0.78 (+13% relative improvement)
  2CTA (M>=512): CRASH→0.77-0.88 (new functionality)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
The FP4 heuristic in get_best_fp4_config() previously picked the
highest-score BLOCK_N within the min-wave tie, which often landed on
BLOCK_N=256. At BN=256 the kernel's TMEM budget (2*BN + SF cols > 512)
forces kNumEpilogueStages=1, serializing MMA and epilogue.

Add a tiebreaker that prefers BNs satisfying the 2-epi-stage constraint,
but only when num_waves >= 2. Single-wave shapes see no benefit from
TMEM double-buffering (the kernel alternates accum_stage_idx across tile
iterations, so there must be >=2 tiles per SM to overlap MMA and
epilogue), so they fall back to pure score selection.

Measured on B200 (148 SMs):
  4096x4096x7168:  3647 -> 3925 TFLOPs (+7.6%)
  256x7168x2048:    289 ->  305 TFLOPs (+5.6%)
  512x2048x7168:    566 ->  594 TFLOPs (+5.0%)
  1024x4096x7168:  2292 -> 2301 TFLOPs (unchanged, 1 wave, gate keeps BN=256)
  Other 8 shapes:  within +/- 2% noise, no regressions.

Also adds:
  - tests/bench_fp4_vs_cublas.py: perf + correctness sweep vs cuBLAS NVFP4
  - tests/probe_tmem_doublebuf.py: per-shape BN sweep that validated the gate

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
z52527 and others added 8 commits April 21, 2026 02:58
Lay the groundwork for Cluster Launch Control (CLC) dynamic tile
scheduling. This commit does not change runtime behavior: kUseCLC
defaults to false and the host wrapper hardcodes it to false.

What's here:
  - scheduler_clc.cuh: standalone header with PTX wrappers
    (clc_try_cancel / clc_query_cancel) and a SchedulerCLC<N, C>
    mailbox pipeline class modeled on CUTLASS 3.x
    PersistentTileSchedulerSm100.
  - sm100_fp4_gemm_1d1d.cuh: add `bool kUseCLC = false` template
    param at the end of the parameter list; static_assert restricts
    CLC to kNumMulticast == 1 (cluster-wide sync for 2-CTA is a
    follow-up).
  - sm100_fp4_gemm_1d1d.hpp: pass `use_clc = false` to the kernel
    instantiation string.

What's NOT here (follow-ups):
  - Kernel SMEM allocation for CLC mailbox + barriers when kUseCLC.
  - Warp 3 scheduler role and work-warp consumption of mailbox.
  - Host launch attributes to enable CLC queue population.
  - Heuristic selection of kUseCLC for shapes that would benefit
    (small-M single-wave cases at ~74-79% of cuBLAS).

Correctness: all 7 tests (constant/random/sweep/asymmetric/
uniform_sf/random_sf/multicast) pass after JIT cache wipe.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
The previous `bench_deepgemm` used `bench()` which measures end-to-end
wall time of `fp8_gemm_nt`, including two launches of the host-side
`transpose_and_pack_fp32_into_ue8m0` kernel (~10 us fixed overhead per
GEMM call for SFA+SFB float32->UE8M0 conversion).

cuBLAS NVFP4 and FlashInfer MXF4 benchmarks measure only their main
GEMM kernel (SF quantization runs once outside the bench loop), as does
DeepGEMM's own FP8 benchmark in tests/test_fp8.py which uses
`bench_kineto(fn, 'fp8_gemm')` to filter to the GEMM kernel only.

Switching `bench_deepgemm` to `bench_kineto(fn, 'sm100_fp4_gemm')` makes
the comparison apples-to-apples: the ~10 us SF transform overhead is
now excluded, which is also the realistic deployment scenario since SF
in real inference is pre-quantized as part of the model weights.

Measured impact (B200): most shapes flip from behind cuBLAS to at or
above cuBLAS. Small-M shapes (M<=256, K=2048) see the biggest swing
because the fixed ~10 us SF transform was dominating their ~20-30 us
total call.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
The cuBLAS path in bench_fp4_vs_cublas.py was producing bf16 output
while DG hardcodes fp32, which inflated DG's apparent bandwidth.
Switch the cuBLAS call to fp32 output so both sides move the same
bytes. Also flip the ratio so >1.0 means DG is faster, and add an
8192^3 square shape.

Caveat: this bench routes cuBLAS through flashinfer, which adds
wrapper overhead and inflates DG's reported speedup vs a direct
cuBLAS call. Use the cublasTest-direct path for an honest baseline.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Extends the existing FP4 (MXFP4) kernel with grouped-along-M support:

- csrc/apis/gemm.hpp: m_grouped_fp8_gemm_nt_contiguous detects FP4
  packed (int32) inputs and dispatches to the new FP4 path. Relaxes
  dtype assertions (kInt for a/b, kFloat for d) accordingly.

- csrc/jit_kernels/impls/sm100_fp4_gemm_1d1d.hpp: new
  sm100_m_grouped_fp4_gemm_contiguous_1d1d. Mirrors the FP8 grouped
  contiguous wrapper but handles FP4-specific quirks: K aligned to
  32 int32 elements, zero-pad A (2D) and B (3D [G,N,K]) when k is
  unaligned, SF block_k = 4 (int32 unit, not config.block_k), and
  num_groups stride for both B and SFB TMA descriptors.

- tests/test_fp4.py: adds two tests:
  * test_m_grouped_contiguous: debug shapes (7 configs including
    uneven m_per_group exercising padding rows) + 4 production MoE
    shapes mirroring tests/generators.py:enumerate_m_grouped_contiguous
    (G=4/8, m_per_group 4096/8192, N=4096/7168, K=2048/7168). Reports
    perf via bench_kineto('sm100_fp4_gemm').
  * test_m_grouped_trtllm_comparable: 256-expert sweep matching the
    trtllm-gen DeepSeek-R1 MoE setup (numTokens 32..8192 for FC1+FC2),
    apples-to-apples vs trtllm-gen batch=M throughput baseline. CPU
    LUT reference is bit-exact for tokens<=256.

Heuristic (get_best_fp4_config) was already GemmType-aware so no
changes needed there. Device kernel (sm100_fp4_gemm_1d1d.cuh) already
had kGemmType branches mirroring FP8 — no device changes needed.

All FP4 tests pass with max_diff=0.0000.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Adds a swap_ab path to the FP4 kernel that swaps A/B operand roles at
MMA level so the epilogue can skip writing padding rows. For sparse
MoE (256 experts × small per-expert M), this is 20-25% faster than the
non-swap path and beats trtllm-gen NonSwap by 8-37% in the wave-plateau
region (tokens 32-2048). For dense shapes, the heuristic keeps swap_ab
off to avoid the 8x small-store overhead.

Implementation:

- deep_gemm/include/deep_gemm/impls/sm100_fp4_gemm_1d1d.cuh:
  Add bool kSwapAB template parameter (default false). Swap branches
  at three places:
    1. MMA path: instr_desc and mma::fma() take (b_dtype, a_dtype) and
       (b_desc, a_desc) when kSwapAB; SF column args also swap.
       UMMA_N = kSwapAB ? BLOCK_M : BLOCK_N.
    2. STORE_BLOCK_M = kSwapAB ? 16 : BLOCK_M; STORE_BLOCK_N =
       kSwapAB ? BLOCK_N : kSwizzleCDMode/sizeof(D). SMEM_CD size
       formula branches accordingly.
    3. Epilogue: `if constexpr (kSwapAB) { ... } else { existing ... }`.
       Swap-AB epilogue uses SM100_TMEM_LOAD_32dp32b8x (8 datapaths)
       and loops `effective_m / STORE_BLOCK_M` iterations, skipping
       padding rows entirely.

- deep_gemm/include/deep_gemm/common/scheduler.cuh:
  Add get_aligned_effective_m_in_block<kAlign>(m_block_idx). For
  MGroupedContiguous it linearly scans m_indices to find the first -1;
  for MGroupedMasked it returns min(BLOCK_M, masked_m[g] - prefix);
  for Normal/KGrouped it returns BLOCK_M.

- csrc/jit_kernels/heuristics/sm100.hpp:
  get_best_fp4_config gains an expected_m_per_group parameter (default
  INT_MAX). When m-grouped AND expected_m_per_group < BLOCK_M, set
  swap_ab=true, force BLOCK_N=128 (= LAYOUT_AD_M required by kernel
  assert), and disable multicast.

- csrc/jit_kernels/impls/sm100_fp4_gemm_1d1d.hpp:
  m_grouped wrapper computes useful_per_group via
  (m_indices >= 0).sum().item() / num_groups and passes it to the
  heuristic. TMA D descriptor uses store_block_m=16 when swap_ab is
  active. swap_ab flag is threaded through to the kernel template
  instantiation.

- csrc/jit_kernels/heuristics/common.hpp:
  GemmConfig gains bool swap_ab=false.

Tested:
- tests/test_fp4.py: all 9 test groups pass with max_diff=0.0000.
  Dense / multicast paths unchanged. m-grouped contiguous prod shapes
  (G=4/8, dense per-group) hit the OFF branch (~5% noise vs baseline).
  trtllm-comparable (256 experts) tokens 32-2048 hit the ON branch:
  FC2 0.458 -> 0.343 ms (-25%), beats trtllm-gen NonSwap baseline by
  8-37%. tokens 4096+ correctly fall back to non-swap path.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Adds the masked-layout grouped GEMM for FP4 (MoE decode path where each
expert's tokens are pre-grouped into a fixed [G, max_M, K] buffer and
masked_m[g] indicates the actual valid row count per group).

- csrc/apis/gemm.hpp: m_grouped_fp8_gemm_nt_masked detects FP4 packed
  (int32) inputs, relaxes dtype assertions (kInt a/b, kFloat d), and
  dispatches to the new FP4 wrapper.

- csrc/jit_kernels/impls/sm100_fp4_gemm_1d1d.hpp: new
  sm100_m_grouped_fp4_gemm_masked_1d1d. Mirrors the FP8 masked wrapper
  but handles FP4-specific quirks: K aligned to 32 int32 (zero-pad both
  A=[G, max_M, K] and B=[G, N, K] when needed), SF block_k=4, and
  num_groups stride for A/B/D/SFA/SFB TMA descriptors. Passes expected_m
  as expected_m_per_group to the heuristic.

- csrc/jit_kernels/heuristics/sm100.hpp: restricts swap_ab activation to
  MGroupedContiguous only. swap_ab + MGroupedMasked currently produces
  NaN/inf for partial-tile cases (max_m <= BLOCK_M with masked_m
  < BLOCK_M); needs follow-up debugging. Masked uses the non-swap
  kernel path which is fully correct.

- tests/test_fp4.py: adds test_m_grouped_masked with 7 shapes covering:
  small + production-scale (max_m up to 4096), varying utilization
  (25-100%), single-tile and multi-tile-per-group cases. Helper
  functions: pack_fp4_random_3d_ga (3D [G,max_M,K] FP4 packing),
  generate_mxf4_sfa_3d, fp4_reference_masked, run_kernel_grouped_masked.

All 10 FP4 test groups pass with max_diff=0.0000 bit-exact.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Bumps test coverage for the m-grouped contiguous and masked variants to
mirror tests/test_fp8.py's iteration depth and shape variety:

- test_m_grouped_contiguous:
  * 9 debug shapes (added (N=512,K=256) and (N=768,K=256) for BLOCK_N
    coverage) × 2 iters
  * 6 prod shapes (added (N=24576,K=1536) and (N=32768,K=512) from
    enumerate_normal for large-N variety) × 3 iters
  Each iter regenerates A/B/m_indices with fresh RNG to surface flaky
  data-dependent bugs.

- test_m_grouped_masked:
  * 5 debug shapes × 3 iters
  * 6 prod shapes mirroring FP8 enumerate_m_grouped_masked
    ((num_groups, m) ∈ {(1,1024), (2,512), (4,256)} × (n,k) ∈
    {(4096,7168), (7168,2048)}) × 10 iters each
  Matches FP8 test_m_grouped_gemm_masked's "for i in range(10)" pattern
  so flaky masked_m distributions are caught.

Total evaluation points: 36 → 129 (+3.6x).

All 129 evaluation points pass with max_diff=0.0000 bit-exact, including
the new large-N shapes (24576 / 32768) and the 60 random-masked_m
production-scale runs.

New perf data points (FP4 contiguous):
  G=4 m=8192 N=24576 K=1536:  937 us / 2640 TFLOPS / 3546 GB/s
  G=4 m=8192 N=32768 K= 512:  858 us / 1282 TFLOPS / 5057 GB/s

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant