[gfx1250][gemm] Add PTPC FP8/A8W4 and non-tile-aligned M support#649
Draft
aoli26 wants to merge 7 commits into
Draft
[gfx1250][gemm] Add PTPC FP8/A8W4 and non-tile-aligned M support#649aoli26 wants to merge 7 commits into
aoli26 wants to merge 7 commits into
Conversation
5fea303 to
49a21d4
Compare
- kernel: m_oob_clip + m_oob_store {buffer, tdm_tail}. A/A-scale load clip via
TDM tensor_dim1, C-store clips via buffer num_records, split-K via per-lane
(row < M) predicate on the atomic path.
- tdm_ops: make_tensor_descriptor_2d gains oob_outer_bound. It sets only
tensor_dim1 (HW OOB field); tile_dim1 stays the full per-warp tile. Accepts
int|index|i32, raises otherwise. None keeps the original (byte-identical) path.
- tests: M-pad coverage (M=16..1000 x buffer/tdm_tail x bf16/f32 + split-K).
Remove the m_oob_store parameter from compile_fp8fp4_gemm / compile_ptpc_gemm
and pick the non-aligned-M output clip internally:
tdm_tail when use_tdm_store and split_k == 1 (full tiles keep the fast TDM
store; the <=1 partial last M-tile falls back to buffer num_records)
buffer otherwise (whole-output num_records clip; split_k>1 uses the
per-lane row < M atomic predicate)
A whole-output buffer clip regressed aligned production prefill by +15%..+82%,
while tdm_tail stays within ~2% of the no-clip path, so a static buffer default
was wrong. The choice is fully derivable from use_tdm_store/split_k, so cache_tag
drops m_oob_store too (no collision).
Tests: the mxscale mpad test now parametrizes use_tdm_store to cover both auto
branches (tdm_tail / buffer); the atomic branch stays covered by the split-k mpad
test.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Motivation
Add per-token per-channel (PTPC) scaling to the gfx1250 GEMM kernel, where scales are per-token
sa[M]and per-channelsb[N](constant along K) fp32 data and thus applied once in the epilogue rather than per K-block. Also add non-tile-aligned M (M-OOB) support so the host can pass an unpadded runtime M directly, dropping the per-call A/C pad alloc + memcpy.Technical Details
PTPC FP8 runs the unscaled WMMA in the K-loop while A8W4 uses the scaled f8f6f4 op with an identity scale, and
sa*sbis applied in fp32 in the epilogue (split-K supported via per-chunk scale + atomic add). All changes are compile-time gated to PTPC so the mxscale path is untouched; PTPC additionally skips scale TDM/LDS (only 2 loader waves needed) and prefetches the epiloguesa/sbloads behind the last WMMAs.M-OOB: A/A-scale loads skip rows ≥ M via the TDM descriptor
tensor_dim1(fault-safe, no OOB fetch), and the output clip is auto-selected —tdm_tail(TDM store for full tiles, buffernum_recordsfor the partial last M-tile) on the TDM-store path, else a whole-output buffer clip or a per-lanerow < Mpredicate on the split-K atomic path.make_tensor_descriptor_2dgainsoob_outer_bound; aligned-M is mostly unaffected by this change.Test Plan
pytest tests/kernels/test_gemm_fp8fp4_gfx1250.py -k 'ptpc or mpad', plus ISA inspection of the PTPC kernels.Test Result
All PTPC (FP8 + A8W4 + split-K) and M-pad (M=16..1000, buffer/tdm/atomic output paths) tests pass; ISA confirms scale TDM removal and epilogue prefetch with lower VGPR count and 0 spill.
Submission Checklist