Skip to content

[gfx1250][gemm] Add PTPC FP8/A8W4 and non-tile-aligned M support#649

Draft
aoli26 wants to merge 7 commits into
mainfrom
gfx1250/gemm_ptpc
Draft

[gfx1250][gemm] Add PTPC FP8/A8W4 and non-tile-aligned M support#649
aoli26 wants to merge 7 commits into
mainfrom
gfx1250/gemm_ptpc

Conversation

@aoli26
Copy link
Copy Markdown
Contributor

@aoli26 aoli26 commented Jun 3, 2026

Motivation

Add per-token per-channel (PTPC) scaling to the gfx1250 GEMM kernel, where scales are per-token sa[M] and per-channel sb[N] (constant along K) fp32 data and thus applied once in the epilogue rather than per K-block. Also add non-tile-aligned M (M-OOB) support so the host can pass an unpadded runtime M directly, dropping the per-call A/C pad alloc + memcpy.

Technical Details

PTPC FP8 runs the unscaled WMMA in the K-loop while A8W4 uses the scaled f8f6f4 op with an identity scale, and sa*sb is applied in fp32 in the epilogue (split-K supported via per-chunk scale + atomic add). All changes are compile-time gated to PTPC so the mxscale path is untouched; PTPC additionally skips scale TDM/LDS (only 2 loader waves needed) and prefetches the epilogue sa/sb loads behind the last WMMAs.

M-OOB: A/A-scale loads skip rows ≥ M via the TDM descriptor tensor_dim1 (fault-safe, no OOB fetch), and the output clip is auto-selected — tdm_tail (TDM store for full tiles, buffer num_records for the partial last M-tile) on the TDM-store path, else a whole-output buffer clip or a per-lane row < M predicate on the split-K atomic path. make_tensor_descriptor_2d gains oob_outer_bound; aligned-M is mostly unaffected by this change.

Test Plan

pytest tests/kernels/test_gemm_fp8fp4_gfx1250.py -k 'ptpc or mpad', plus ISA inspection of the PTPC kernels.

Test Result

All PTPC (FP8 + A8W4 + split-K) and M-pad (M=16..1000, buffer/tdm/atomic output paths) tests pass; ISA confirms scale TDM removal and epilogue prefetch with lower VGPR count and 0 spill.

Submission Checklist

@aoli26 aoli26 force-pushed the gfx1250/gemm_ptpc branch 2 times, most recently from 5fea303 to 49a21d4 Compare June 6, 2026 03:05
@aoli26 aoli26 changed the title [gfx1250][gemm] Add PTPC FP8/A8W4 support [gfx1250][gemm] Add PTPC FP8/A8W4 and non-tile-aligned M support Jun 6, 2026
Base automatically changed from gfx1250/gemm_fp8_opt to main June 6, 2026 14:15
aoli26 added 7 commits June 6, 2026 15:32
- kernel: m_oob_clip + m_oob_store {buffer, tdm_tail}. A/A-scale load clip via
  TDM tensor_dim1, C-store clips via buffer num_records, split-K via per-lane
  (row < M) predicate on the atomic path.
- tdm_ops: make_tensor_descriptor_2d gains oob_outer_bound. It sets only
  tensor_dim1 (HW OOB field); tile_dim1 stays the full per-warp tile. Accepts
  int|index|i32, raises otherwise. None keeps the original (byte-identical) path.
- tests: M-pad coverage (M=16..1000 x buffer/tdm_tail x bf16/f32 + split-K).
Remove the m_oob_store parameter from compile_fp8fp4_gemm / compile_ptpc_gemm
and pick the non-aligned-M output clip internally:
  tdm_tail  when use_tdm_store and split_k == 1 (full tiles keep the fast TDM
            store; the <=1 partial last M-tile falls back to buffer num_records)
  buffer    otherwise (whole-output num_records clip; split_k>1 uses the
            per-lane row < M atomic predicate)

A whole-output buffer clip regressed aligned production prefill by +15%..+82%,
while tdm_tail stays within ~2% of the no-clip path, so a static buffer default
was wrong. The choice is fully derivable from use_tdm_store/split_k, so cache_tag
drops m_oob_store too (no collision).

Tests: the mxscale mpad test now parametrizes use_tdm_store to cover both auto
branches (tdm_tail / buffer); the atomic branch stays covered by the split-k mpad
test.
@aoli26 aoli26 force-pushed the gfx1250/gemm_ptpc branch from 1b5e5a5 to 542641d Compare June 6, 2026 15:35
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant