…m data back to global.
…from tmem to shared memory in the epilogue.
1. FP4-specific block_n heuristic (get_best_fp4_config):
- Uses stages²×bn composite score to balance pipeline depth vs tile size
- Replaces generic FP8 heuristic which picked suboptimal block_n for FP4
- Supports arbitrary block_n (no N%bn divisibility requirement)
2. B-multicast support (2CTA, UMMA_M=256):
- Enabled for M>=512 when M-tiles are even and GemmType is Normal/KGroupedContiguous
- Each CTA loads full A, B is split (LOAD_BLOCK_N = BLOCK_N/2)
- UTCOMMA instruction supports M=128 (1CTA) and M=256 (2CTA)
- Previously M>=512 shapes crashed; now functional at 77-88% of cuBLAS
3. SFB SMEM zero-fill bugfix (fill_sfb_missing_k_groups):
- Old code replicated first 32 SF elements across the 128-aligned block,
corrupting valid data for BLOCK_N in (32, 128)
- New code zero-fills only positions [BLOCK_N, 128) before warp-transpose
- Fixes correctness for non-power-of-2 block_n with random scale factors
4. SF SMEM/TMEM calculations now depend on block_k (not hardcoded to 2)
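The zero-fill fix in item 3 can be illustrated on plain Python lists. This is only a sketch of the behavior as the commit message describes it: the real fill_sfb_missing_k_groups operates on shared memory, and the function names and the exact replication pattern of the old code are assumptions made here for illustration.

```python
ALIGNED_N = 128  # SF entries are laid out in 128-aligned blocks before the warp transpose


def fill_old(sf: list) -> list:
    """Old behavior (simplified): replicate the first 32 entries across the block."""
    return [sf[i % 32] for i in range(ALIGNED_N)]


def fill_new(sf: list, block_n: int) -> list:
    """New behavior: keep [0, block_n) untouched and zero only [block_n, 128)."""
    return [sf[i] if i < block_n else 0 for i in range(ALIGNED_N)]


# Example: BLOCK_N = 96 with distinct (non-uniform) scale factors.
block_n = 96
sf = list(range(1, ALIGNED_N + 1))
old = fill_old(sf)
new = fill_new(sf, block_n)
```

With BLOCK_N = 96, `old` overwrites the valid entries at positions 32..95, while `new` preserves all 96 valid entries. With uniform scale factors the two agree on the valid range, which would explain why only tests with random scale factors exposed the problem.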
Performance vs cuBLAS NVFP4 (DG/CB ratio, higher=better):
1CTA (M<512): 0.69→0.78 (+13% relative improvement)
2CTA (M>=512): CRASH→0.77-0.88 (new functionality)
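Items 1 and 2 above can be sketched in a few lines of Python. Only the score form (stages squared times block_n), the M>=512 and even-M-tile conditions, the allowed GemmTypes, and LOAD_BLOCK_N = BLOCK_N/2 come from the commit message. The stage counts in the example, the default block_m of 128, and all function names are assumptions of this sketch, not the actual get_best_fp4_config code.

```python
from enum import Enum, auto


class GemmType(Enum):
    NORMAL = auto()
    K_GROUPED_CONTIGUOUS = auto()
    M_GROUPED_CONTIGUOUS = auto()


def fp4_score(num_stages: int, block_n: int) -> int:
    """Composite score: pipeline depth is weighted quadratically against tile width."""
    return num_stages ** 2 * block_n


def pick_block_n(stages_for_bn: dict) -> int:
    """Pick the block_n with the best score.

    `stages_for_bn` maps block_n -> achievable pipeline stages. How the real
    heuristic derives the stage count is not shown in the source, so the
    caller supplies it here.
    """
    return max(stages_for_bn, key=lambda bn: fp4_score(stages_for_bn[bn], bn))


def can_multicast_b(m: int, gemm_type: GemmType, block_m: int = 128) -> bool:
    """B-multicast gate: M >= 512, an even number of M tiles, and a supported GemmType."""
    num_m_tiles = (m + block_m - 1) // block_m
    return (
        m >= 512
        and num_m_tiles % 2 == 0
        and gemm_type in (GemmType.NORMAL, GemmType.K_GROUPED_CONTIGUOUS)
    )


def load_block_n(block_n: int, multicast: bool) -> int:
    """With multicast each CTA loads half of B's tile width."""
    return block_n // 2 if multicast else block_n


# Hypothetical stage counts: wider tiles leave room for fewer stages.
best = pick_block_n({64: 8, 128: 6, 256: 3})
```

In this made-up example the scores are 4096, 4608 and 2304, so the middle tile wins even though it has neither the deepest pipeline nor the widest tile.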
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
The FP4 heuristic in get_best_fp4_config() previously picked the
highest-score BLOCK_N within the min-wave tie, which often landed on
BLOCK_N=256. At BN=256 the kernel's TMEM budget (2*BN + SF cols > 512)
forces kNumEpilogueStages=1, serializing MMA and epilogue.
Add a tiebreaker that prefers BNs satisfying the 2-epi-stage constraint,
but only when num_waves >= 2. Single-wave shapes see no benefit from
TMEM double-buffering (the kernel alternates accum_stage_idx across tile
iterations, so there must be >=2 tiles per SM to overlap MMA and
epilogue), so they fall back to pure score selection.
Measured on B200 (148 SMs):
4096x4096x7168: 3647 -> 3925 TFLOPs (+7.6%)
256x7168x2048: 289 -> 305 TFLOPs (+5.6%)
512x2048x7168: 566 -> 594 TFLOPs (+5.0%)
1024x4096x7168: 2292 -> 2301 TFLOPs (unchanged, 1 wave, gate keeps BN=256)
Other 8 shapes: within +/- 2% noise, no regressions.
Also adds:
- tests/bench_fp4_vs_cublas.py: perf + correctness sweep vs cuBLAS NVFP4
- tests/probe_tmem_doublebuf.py: per-shape BN sweep that validated the gate
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
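The wave-gated tiebreaker described in this commit can be sketched as follows. The 512-column TMEM budget, the 2*BN + SF-columns constraint, the num_waves >= 2 gate and the 148-SM count are taken from the message. The block_m of 128, the SF column count of 16, the example scores, and the function names are assumptions for illustration only.

```python
TMEM_COLUMNS = 512  # TMEM column budget quoted in the commit message


def ceil_div(a: int, b: int) -> int:
    return (a + b - 1) // b


def num_epilogue_stages(block_n: int, sf_cols: int) -> int:
    """Two accumulator stages fit only if 2*BN plus the SF columns stay within TMEM."""
    return 2 if 2 * block_n + sf_cols <= TMEM_COLUMNS else 1


def num_waves(m: int, n: int, block_m: int, block_n: int, num_sms: int = 148) -> int:
    """Number of passes needed to cover all output tiles with one tile per SM per pass."""
    tiles = ceil_div(m, block_m) * ceil_div(n, block_n)
    return ceil_div(tiles, num_sms)


def pick_block_n(tied_scores: dict, waves: int, sf_cols: int) -> int:
    """Choose among block_n candidates that are already tied on minimal wave count.

    `tied_scores` maps block_n -> heuristic score. With two or more waves,
    candidates that allow two epilogue stages are preferred; otherwise the
    choice falls back to the highest score.
    """
    if waves >= 2:
        two_stage = {
            bn: score
            for bn, score in tied_scores.items()
            if num_epilogue_stages(bn, sf_cols) == 2
        }
        if two_stage:
            return max(two_stage, key=two_stage.get)
    return max(tied_scores, key=tied_scores.get)
```

Under the assumed block_m of 128, the 1024x4096 shape at BN=256 has 8 x 16 = 128 tiles, which is a single wave on 148 SMs, so the gate leaves BN=256 in place. The 4096x4096 shape has 512 tiles (four waves), so a candidate that keeps two epilogue stages is preferred.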
Lay the groundwork for Cluster Launch Control (CLC) dynamic tile
scheduling. This commit does not change runtime behavior: kUseCLC
defaults to false and the host wrapper hardcodes it to false.
What's here:
- scheduler_clc.cuh: standalone header with PTX wrappers
(clc_try_cancel / clc_query_cancel) and a SchedulerCLC<N, C>
mailbox pipeline class modeled on CUTLASS 3.x
PersistentTileSchedulerSm100.
- sm100_fp4_gemm_1d1d.cuh: add `bool kUseCLC = false` template
param at the end of the parameter list; static_assert restricts
CLC to kNumMulticast == 1 (cluster-wide sync for 2-CTA is a
follow-up).
- sm100_fp4_gemm_1d1d.hpp: pass `use_clc = false` to the kernel
instantiation string.
What's NOT here (follow-ups):
- Kernel SMEM allocation for CLC mailbox + barriers when kUseCLC.
- Warp 3 scheduler role and work-warp consumption of mailbox.
- Host launch attributes to enable CLC queue population.
- Heuristic selection of kUseCLC for shapes that would benefit
(small-M single-wave cases at ~74-79% of cuBLAS).
Correctness: all 7 tests (constant/random/sweep/asymmetric/
uniform_sf/random_sf/multicast) pass after JIT cache wipe.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This reverts commit 0630cb2.
The previous `bench_deepgemm` used `bench()` which measures end-to-end
wall time of `fp8_gemm_nt`, including two launches of the host-side
`transpose_and_pack_fp32_into_ue8m0` kernel (~10 us fixed overhead per
GEMM call for SFA+SFB float32->UE8M0 conversion).
cuBLAS NVFP4 and FlashInfer MXF4 benchmarks measure only their main GEMM
kernel (SF quantization runs once outside the bench loop), as does
DeepGEMM's own FP8 benchmark in tests/test_fp8.py which uses
`bench_kineto(fn, 'fp8_gemm')` to filter to the GEMM kernel only.
Switching `bench_deepgemm` to `bench_kineto(fn, 'sm100_fp4_gemm')` makes
the comparison apples-to-apples: the ~10 us SF transform overhead is now
excluded, which is also the realistic deployment scenario since SF in
real inference is pre-quantized as part of the model weights.
Measured impact (B200): most shapes flip from behind cuBLAS to at or
above cuBLAS. Small-M shapes (M<=256, K=2048) see the biggest swing
because the fixed ~10 us SF transform was dominating their ~20-30 us
total call.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
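The size of the swing for small-M shapes follows from simple arithmetic on the figures quoted above (~10 us fixed overhead inside a ~20-30 us call). The helper names below are hypothetical and the numbers are the approximate ones from the message, not new measurements.

```python
def overhead_share(overhead_us: float, total_us: float) -> float:
    """Fraction of the end-to-end call spent in the SF transform."""
    return overhead_us / total_us


def apparent_slowdown(overhead_us: float, total_us: float) -> float:
    """Factor by which end-to-end timing overstates the GEMM kernel time."""
    return total_us / (total_us - overhead_us)


# Approximate bounds from the commit message: 10 us overhead, 20-30 us total.
share_low = overhead_share(10.0, 30.0)
share_high = overhead_share(10.0, 20.0)
slowdown_low = apparent_slowdown(10.0, 30.0)
slowdown_high = apparent_slowdown(10.0, 20.0)
```

Under these rough figures the SF transform accounts for about a third to a half of the measured time, so end-to-end timing makes the GEMM kernel look roughly 1.5x to 2x slower than it is.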
The cuBLAS path in bench_fp4_vs_cublas.py was producing bf16 output
while DG hardcodes fp32, which inflated DG's apparent bandwidth. Switch
the cuBLAS call to fp32 output so both sides move the same bytes. Also
flip the ratio so >1.0 means DG is faster, and add an 8192^3 square
shape.
Caveat: this bench routes cuBLAS through flashinfer, which adds wrapper
overhead and inflates DG's reported speedup vs a direct cuBLAS call. Use
the cublasTest-direct path for an honest baseline.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
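The byte mismatch between bf16 and fp32 output can be estimated with a simple traffic model. This sketch assumes FP4 operands packed two values per byte and ignores scale-factor traffic; the function name is hypothetical and the model is not taken from the benchmark script itself.

```python
def gemm_bytes(m: int, n: int, k: int, out_bytes_per_elem: int) -> int:
    """Approximate bytes moved by an FP4 GEMM (scale factors ignored)."""
    a_bytes = m * k // 2  # FP4: two values per byte
    b_bytes = n * k // 2
    d_bytes = m * n * out_bytes_per_elem
    return a_bytes + b_bytes + d_bytes


# The 8192^3 square shape added by this commit.
bf16_total = gemm_bytes(8192, 8192, 8192, 2)
fp32_total = gemm_bytes(8192, 8192, 8192, 4)
ratio = fp32_total / bf16_total
```

For this shape the fp32-output side moves about 1.67x the bytes of the bf16-output side under this model, which is why bandwidth figures were not comparable until both sides wrote fp32.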
Extends the existing FP4 (MXFP4) kernel with grouped-along-M support:
- csrc/apis/gemm.hpp: m_grouped_fp8_gemm_nt_contiguous detects FP4
packed (int32) inputs and dispatches to the new FP4 path. Relaxes
dtype assertions (kInt for a/b, kFloat for d) accordingly.
- csrc/jit_kernels/impls/sm100_fp4_gemm_1d1d.hpp: new
sm100_m_grouped_fp4_gemm_contiguous_1d1d. Mirrors the FP8 grouped
contiguous wrapper but handles FP4-specific quirks: K aligned to
32 int32 elements, zero-pad A (2D) and B (3D [G,N,K]) when k is
unaligned, SF block_k = 4 (int32 unit, not config.block_k), and
num_groups stride for both B and SFB TMA descriptors.
- tests/test_fp4.py: adds two tests:
* test_m_grouped_contiguous: debug shapes (7 configs including
uneven m_per_group exercising padding rows) + 4 production MoE
shapes mirroring tests/generators.py:enumerate_m_grouped_contiguous
(G=4/8, m_per_group 4096/8192, N=4096/7168, K=2048/7168). Reports
perf via bench_kineto('sm100_fp4_gemm').
* test_m_grouped_trtllm_comparable: 256-expert sweep matching the
trtllm-gen DeepSeek-R1 MoE setup (numTokens 32..8192 for FC1+FC2),
apples-to-apples vs trtllm-gen batch=M throughput baseline. CPU
LUT reference is bit-exact for tokens<=256.
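The K-alignment rule from the wrapper description (K aligned to 32 int32 elements, zero-padding A and B when unaligned) can be sketched on nested lists. The helper names are hypothetical, and the real wrapper pads tensors rather than Python lists.

```python
K_ALIGN = 32  # alignment in int32 elements; each int32 packs eight 4-bit values


def aligned_k(k: int) -> int:
    """Round k (in int32 elements) up to the next multiple of K_ALIGN."""
    return (k + K_ALIGN - 1) // K_ALIGN * K_ALIGN


def pad_rows(rows: list, k_padded: int) -> list:
    """Zero-pad every row of a 2D [M, K] operand to k_padded columns."""
    return [row + [0] * (k_padded - len(row)) for row in rows]


def pad_groups(groups: list, k_padded: int) -> list:
    """Zero-pad a 3D [G, N, K] operand by padding each group's rows."""
    return [pad_rows(group, k_padded) for group in groups]


# Example: K = 40 int32 elements is padded to 64.
a = [[1] * 40 for _ in range(2)]       # A: [M=2, K=40]
b = [[[2] * 40] for _ in range(3)]     # B: [G=3, N=1, K=40]
a_padded = pad_rows(a, aligned_k(40))
b_padded = pad_groups(b, aligned_k(40))
```

Zero padding contributes nothing to the dot products along K, so the padded operands should produce the same result as the originals.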
Heuristic (get_best_fp4_config) was already GemmType-aware so no
changes needed there. Device kernel (sm100_fp4_gemm_1d1d.cuh) already
had kGemmType branches mirroring FP8 — no device changes needed.
All FP4 tests pass with max_diff=0.0000.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Adds a swap_ab path to the FP4 kernel that swaps A/B operand roles at
MMA level so the epilogue can skip writing padding rows. For sparse
MoE (256 experts × small per-expert M), this is 20-25% faster than the
non-swap path and beats trtllm-gen NonSwap by 8-37% in the wave-plateau
region (tokens 32-2048). For dense shapes, the heuristic keeps swap_ab
off to avoid the 8x small-store overhead.
Implementation:
- deep_gemm/include/deep_gemm/impls/sm100_fp4_gemm_1d1d.cuh:
Add bool kSwapAB template parameter (default false). Swap branches
at three places:
1. MMA path: instr_desc and mma::fma() take (b_dtype, a_dtype) and
(b_desc, a_desc) when kSwapAB; SF column args also swap.
UMMA_N = kSwapAB ? BLOCK_M : BLOCK_N.
2. STORE_BLOCK_M = kSwapAB ? 16 : BLOCK_M; STORE_BLOCK_N =
kSwapAB ? BLOCK_N : kSwizzleCDMode/sizeof(D). SMEM_CD size
formula branches accordingly.
3. Epilogue: `if constexpr (kSwapAB) { ... } else { existing ... }`.
Swap-AB epilogue uses SM100_TMEM_LOAD_32dp32b8x (32 datapath lanes,
8 columns per load) and loops `effective_m / STORE_BLOCK_M`
iterations, skipping padding rows entirely.
- deep_gemm/include/deep_gemm/common/scheduler.cuh:
Add get_aligned_effective_m_in_block<kAlign>(m_block_idx). For
MGroupedContiguous it linearly scans m_indices to find the first -1;
for MGroupedMasked it returns min(BLOCK_M, masked_m[g] - prefix);
for Normal/KGrouped it returns BLOCK_M.
- csrc/jit_kernels/heuristics/sm100.hpp:
get_best_fp4_config gains an expected_m_per_group parameter (default
INT_MAX). When m-grouped AND expected_m_per_group < BLOCK_M, set
swap_ab=true, force BLOCK_N=128 (= LAYOUT_AD_M required by kernel
assert), and disable multicast.
- csrc/jit_kernels/impls/sm100_fp4_gemm_1d1d.hpp:
m_grouped wrapper computes useful_per_group via
(m_indices >= 0).sum().item() / num_groups and passes it to the
heuristic. TMA D descriptor uses store_block_m=16 when swap_ab is
active. swap_ab flag is threaded through to the kernel template
instantiation.
- csrc/jit_kernels/heuristics/common.hpp:
GemmConfig gains bool swap_ab=false.
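The three cases of get_aligned_effective_m_in_block described above can be sketched in Python. The case logic (scan for the first -1, min(BLOCK_M, masked_m[g] - prefix), or BLOCK_M) follows the commit message. How kAlign is applied (rounding up, then clamping to BLOCK_M) and the meaning of `prefix` as the rows of the group covered by earlier blocks are assumptions of this sketch.

```python
def align_up(x: int, align: int) -> int:
    return (x + align - 1) // align * align


def effective_m_contiguous(m_indices: list, m_block_idx: int, block_m: int, align: int) -> int:
    """MGroupedContiguous: count rows in the block before the first -1 (padding) entry."""
    start = m_block_idx * block_m
    rows = m_indices[start:start + block_m]
    valid = next((i for i, g in enumerate(rows) if g == -1), len(rows))
    return min(block_m, align_up(valid, align))


def effective_m_masked(masked_m_g: int, prefix: int, block_m: int, align: int) -> int:
    """MGroupedMasked: remaining valid rows of group g, clamped to one block."""
    return min(block_m, align_up(masked_m_g - prefix, align))


def effective_m_normal(block_m: int) -> int:
    """Normal / KGrouped: every row of the block is valid."""
    return block_m


# Example: one 128-row block holding 20 valid rows followed by padding.
STORE_BLOCK_M = 16
m_indices = [0] * 20 + [-1] * 108
eff = effective_m_contiguous(m_indices, 0, 128, STORE_BLOCK_M)
store_iters = eff // STORE_BLOCK_M
```

In this example the swap-AB epilogue would run 2 store iterations instead of the 8 needed for a full 128-row tile, which is the padding-row saving the commit describes.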
Tested:
- tests/test_fp4.py: all 9 test groups pass with max_diff=0.0000.
Dense / multicast paths unchanged. m-grouped contiguous prod shapes
(G=4/8, dense per-group) hit the OFF branch (~5% noise vs baseline).
trtllm-comparable (256 experts) tokens 32-2048 hit the ON branch:
FC2 0.458 -> 0.343 ms (-25%), beats trtllm-gen NonSwap baseline by
8-37%. tokens 4096+ correctly fall back to non-swap path.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Adds the masked-layout grouped GEMM for FP4 (MoE decode path where each
expert's tokens are pre-grouped into a fixed [G, max_M, K] buffer and
masked_m[g] indicates the actual valid row count per group).
- csrc/apis/gemm.hpp: m_grouped_fp8_gemm_nt_masked detects FP4 packed
(int32) inputs, relaxes dtype assertions (kInt a/b, kFloat d), and
dispatches to the new FP4 wrapper.
- csrc/jit_kernels/impls/sm100_fp4_gemm_1d1d.hpp: new
sm100_m_grouped_fp4_gemm_masked_1d1d. Mirrors the FP8 masked wrapper
but handles FP4-specific quirks: K aligned to 32 int32 (zero-pad both
A=[G, max_M, K] and B=[G, N, K] when needed), SF block_k=4, and
num_groups stride for A/B/D/SFA/SFB TMA descriptors. Passes expected_m
as expected_m_per_group to the heuristic.
- csrc/jit_kernels/heuristics/sm100.hpp: restricts swap_ab activation to
MGroupedContiguous only. swap_ab + MGroupedMasked currently produces
NaN/inf for partial-tile cases (max_m <= BLOCK_M with masked_m <
BLOCK_M); needs follow-up debugging. Masked uses the non-swap kernel
path which is fully correct.
- tests/test_fp4.py: adds test_m_grouped_masked with 7 shapes covering:
small + production-scale (max_m up to 4096), varying utilization
(25-100%), single-tile and multi-tile-per-group cases. Helper
functions: pack_fp4_random_3d_ga (3D [G,max_M,K] FP4 packing),
generate_mxf4_sfa_3d, fp4_reference_masked, run_kernel_grouped_masked.
All 10 FP4 test groups pass with max_diff=0.0000 bit-exact.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
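The swap_ab gate, including the restriction to MGroupedContiguous introduced here, can be sketched as below. The conditions (m-grouped contiguous, expected_m_per_group < BLOCK_M, forcing BLOCK_N=128, disabling multicast, INT_MAX default) come from the commit messages. The `Config` dataclass is a hypothetical stand-in for GemmConfig, and representing "multicast disabled" as num_multicast=1 is an assumption.

```python
from dataclasses import dataclass, replace

INT_MAX = 2**31 - 1  # default meaning "no per-group estimate available"


@dataclass(frozen=True)
class Config:
    """Hypothetical stand-in for the fields of GemmConfig used by the gate."""
    block_m: int
    block_n: int
    num_multicast: int
    swap_ab: bool = False


def apply_swap_ab_gate(cfg: Config, gemm_type: str, expected_m_per_group: int = INT_MAX) -> Config:
    """Enable swap_ab only for MGroupedContiguous with small per-group M."""
    if gemm_type == "MGroupedContiguous" and expected_m_per_group < cfg.block_m:
        return replace(cfg, swap_ab=True, block_n=128, num_multicast=1)
    return cfg


def useful_rows_per_group(m_indices: list, num_groups: int) -> float:
    """Average count of non-padding rows per group (entries >= 0 in m_indices)."""
    return sum(1 for g in m_indices if g >= 0) / num_groups


base = Config(block_m=128, block_n=256, num_multicast=2)
sparse = apply_swap_ab_gate(base, "MGroupedContiguous", expected_m_per_group=16)
masked = apply_swap_ab_gate(base, "MGroupedMasked", expected_m_per_group=16)
```

Here `sparse` takes the swap path, while `masked` keeps the original configuration even though its per-group M is small, matching the restriction this commit adds while the NaN/inf issue is open.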
Bumps test coverage for the m-grouped contiguous and masked variants to
mirror tests/test_fp8.py's iteration depth and shape variety:
- test_m_grouped_contiguous:
* 9 debug shapes (added (N=512,K=256) and (N=768,K=256) for BLOCK_N
coverage) × 2 iters
* 6 prod shapes (added (N=24576,K=1536) and (N=32768,K=512) from
enumerate_normal for large-N variety) × 3 iters
Each iter regenerates A/B/m_indices with fresh RNG to surface flaky
data-dependent bugs.
- test_m_grouped_masked:
* 5 debug shapes × 3 iters
* 6 prod shapes mirroring FP8 enumerate_m_grouped_masked
((num_groups, m) ∈ {(1,1024), (2,512), (4,256)} × (n,k) ∈
{(4096,7168), (7168,2048)}) × 10 iters each
Matches FP8 test_m_grouped_gemm_masked's "for i in range(10)" pattern
so flaky masked_m distributions are caught.
Total evaluation points: 36 → 129 (3.6x).
All 129 evaluation points pass with max_diff=0.0000 bit-exact, including
the new large-N shapes (24576 / 32768) and the 60 random-masked_m
production-scale runs.
New perf data points (FP4 contiguous):
G=4 m=8192 N=24576 K=1536: 937 us / 2640 TFLOPS / 3546 GB/s
G=4 m=8192 N=32768 K= 512: 858 us / 1282 TFLOPS / 5057 GB/s
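The TFLOPS and GB/s figures above can be approximately reproduced from the shapes and times. This sketch assumes m is the per-group row count, FP4 operands packed two values per byte, fp32 output, and no scale-factor traffic; the function name is hypothetical. Small differences from the reported numbers are expected because the times are rounded to whole microseconds.

```python
def grouped_fp4_throughput(g: int, m: int, n: int, k: int, time_us: float) -> tuple:
    """Return (TFLOPS, GB/s) for an m-grouped contiguous FP4 GEMM."""
    flops = 2 * g * m * n * k
    bytes_moved = (
        g * m * k // 2      # A: FP4, two values per byte
        + g * n * k // 2    # B: [G, N, K] in FP4
        + g * m * n * 4     # D: fp32 output
    )
    seconds = time_us * 1e-6
    return flops / seconds / 1e12, bytes_moved / seconds / 1e9


tflops_a, gbps_a = grouped_fp4_throughput(4, 8192, 24576, 1536, 937)
tflops_b, gbps_b = grouped_fp4_throughput(4, 8192, 32768, 512, 858)
```

Under these assumptions both rows land within a few units of the reported values, which suggests the reported bandwidth is dominated by the fp32 output writes.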
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>