perf(rmsnorm): vectorize generic path and simplify block reduction #662

Open
kudomcho wants to merge 1 commit into main from optimize/rmsnorm-combined

@kudomcho commented Jun 6, 2026

Motivation

This PR improves the performance of the FlyDSL RMSNorm kernel by addressing inefficiencies in the generic path for non-aligned hidden dimensions (e.g., GPT-style shapes such as N=2880).

The previous implementation suffered from:

  • Full scalar fallback for non-aligned N: when N was not a multiple of BLOCK_THREADS * VEC_WIDTH (2048), the entire row was processed element by element with scalar copy_atom_call, even though most of the row could have been vectorized
  • Redundant dual block reduction: block_reduce_add was a wrapper around block_reduce_add2 with a dummy second value, which wasted a shared memory slot and performed unnecessary reduction work

The goal of this PR is to:

  • Eliminate scalar bottlenecks for non-aligned N by vectorizing the bulk
  • Simplify the block reduction to a single-value path
  • Improve performance consistency across both aligned and non-aligned workloads

Supersedes #436 (closed due to API conflicts) and reimplements its concepts on the current layout API.

Technical Details

This PR introduces the following optimizations to the RMSNorm kernel:

1. Vector-generic path (new execution path)

  • Tile-fast path (aligned shapes, N % 2048 == 0): unchanged, fully vectorized using buffer_load/store
  • Vector-generic path (arbitrary N, f16/bf16): vectorized bulk using vec8 buffer_load/store for N // 2048 full tiles, then scalar copy_atom for only the remaining N % 2048 tail elements. This ensures most workloads avoid expensive full-row scalar execution.
  • Scalar-generic path (f32, or N < 2048 so that no full tile exists): unchanged scalar fallback

For N=2880 (a GPT-style hidden dim): 1 full vec8 tile (2048 elements, vectorized) plus an 832-element scalar tail, versus 2880 scalar elements previously.
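
The tile arithmetic above can be sketched in a few lines of Python. This is only an illustration of the split and of the three-way path choice, not the kernel's dispatch code: `split_row` and `choose_path` are hypothetical names, `BLOCK_THREADS = 256` is inferred from 2048 / 8, and the sketch ignores dtype when picking the fast path.

```python
# Assumed tile geometry: the PR states the product (2048) and the vec8 width;
# BLOCK_THREADS = 256 is inferred from 2048 / 8.
BLOCK_THREADS = 256
VEC_WIDTH = 8
TILE_COLS = BLOCK_THREADS * VEC_WIDTH  # 2048


def split_row(n, tile_cols=TILE_COLS):
    """Return (full_tiles, vectorized_elems, scalar_tail_elems) for a row of n elements."""
    full_tiles, tail = divmod(n, tile_cols)
    return full_tiles, full_tiles * tile_cols, tail


def choose_path(n, dtype, tile_cols=TILE_COLS):
    """Hypothetical dispatch mirroring the three paths described above (n > 0)."""
    if n % tile_cols == 0:
        return "fast"        # aligned: fully vectorized tiles
    if dtype in ("f16", "bf16") and n >= tile_cols:
        return "vec-gen"     # vectorized bulk + scalar tail
    return "scalar-gen"      # f32, or no full tile available


for n, dtype in [(2880, "bf16"), (4096, "bf16"), (2000, "bf16"), (2880, "f32")]:
    print(n, dtype, choose_path(n, dtype), split_row(n))
```

For N=2880 this yields one full tile, 2048 vectorized elements, and an 832-element tail, matching the figures quoted above.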

2. Block reduction simplification

  • Replaced dual-value block_reduce_add2 (which carried a dummy zero second value) with a direct single-value block_reduce_add
  • Halves shared memory usage (one s_red slot instead of two)
  • Removes one unnecessary wave reduction and shared memory store/load pair per block
  • The quant module retains block_reduce_add2 since it legitimately reduces two values (sumsq + absmax)
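
To make the saving concrete, here is a host-side Python model of the two reduction shapes. It is a sketch under stated assumptions, not the FlyDSL implementation: a 64-lane wave is assumed, each wave is modeled as writing one partial sum per reduced value into an `s_red` scratch list, and the function bodies are illustrative stand-ins that only borrow the names used in this PR.

```python
WAVE_SIZE = 64  # assumed wavefront width


def wave_partials(values, wave_size=WAVE_SIZE):
    """Model the per-wave reduction: one partial sum per wave."""
    return [sum(values[i:i + wave_size]) for i in range(0, len(values), wave_size)]


def block_reduce_add(values):
    """Single-value reduction: one scratch slot per wave."""
    s_red = wave_partials(values)
    return sum(s_red), len(s_red)


def block_reduce_add2(a, b):
    """Dual-value reduction: two scratch slots per wave."""
    s_red_a = wave_partials(a)
    s_red_b = wave_partials(b)
    return sum(s_red_a), sum(s_red_b), len(s_red_a) + len(s_red_b)


def block_reduce_add_via_add2(values):
    """Old shape: reuse the dual reduction with a dummy zero second operand."""
    total, _dummy, slots = block_reduce_add2(values, [0] * len(values))
    return total, slots


per_thread = list(range(256))  # one partial value per thread in a 256-thread block
print(block_reduce_add(per_thread))
print(block_reduce_add_via_add2(per_thread))
```

Both variants return the same sum, but in this model the wrapper touches twice as many scratch slots and runs a second wave reduction over zeros, which is the overhead the PR removes.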

Test Plan

Run against PyTorch reference for correctness and benchmark for performance:

# Default shapes (fast path + N=2000 generic)
python -m pytest tests/kernels/test_rmsnorm.py -v -s --tb=short

# Production shapes, including the non-aligned N=2880
ROCDSL_RMSNORM_SHAPES="4096,2880,bf16;16384,2880,bf16;4096,4096,bf16;4096,8192,bf16;16384,8192,bf16;32768,8192,bf16" \
  python -m pytest tests/kernels/test_rmsnorm.py -v -s --tb=short
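
The shape list passed through the environment variable appears to be a `;`-separated sequence of `M,N,dtype` triples. A minimal parser for that format might look like the following; `parse_shapes` is a hypothetical helper written for illustration, and the test suite's actual parsing may differ.

```python
def parse_shapes(spec):
    """Parse 'M,N,dtype;M,N,dtype;...' into a list of (M, N, dtype) tuples."""
    shapes = []
    for item in spec.split(";"):
        item = item.strip()
        if not item:
            continue  # tolerate a trailing ';' or empty entries
        m, n, dtype = (part.strip() for part in item.split(","))
        shapes.append((int(m), int(n), dtype))
    return shapes


spec = "4096,2880,bf16;16384,2880,bf16;4096,4096,bf16"
for m, n, dtype in parse_shapes(spec):
    print(m, n, dtype)
```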

Test Result

All tests pass: test_all, test_rmsnorm_dynamicquant, test_rmsnorm_smoothquant across default + production shapes (13 shape/dtype configs each).

Benchmark (MI300X, bf16, GPU profiling via run_perftest, 100 warmup + 1000 measurement iterations)

Baseline:

( 4096, 2880,bf16) [vec-gen]:  15.46 us
(16384, 2880,bf16) [vec-gen]:  57.87 us
( 4096, 4096,bf16) [   fast]:  17.21 us
( 4096, 8192,bf16) [   fast]:  31.90 us
(16384, 8192,bf16) [   fast]: 138.85 us
(32768, 8192,bf16) [   fast]: 271.62 us

Optimized:

( 4096, 2880,bf16) [vec-gen]:  13.06 us  (+15.5%)
(16384, 2880,bf16) [vec-gen]:  45.34 us  (+21.7%)
( 4096, 4096,bf16) [   fast]:  17.70 us  (-2.8%)
( 4096, 8192,bf16) [   fast]:  31.88 us  (+0.1%)
(16384, 8192,bf16) [   fast]: 138.79 us  (+0.0%)
(32768, 8192,bf16) [   fast]: 271.30 us  (+0.1%)

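Key improvements on the non-aligned N=2880 shapes: latency drops by 15.5% at (4096, 2880) and by 21.7% at (16384, 2880) thanks to the vectorized generic path. Fast-path shapes stay within -2.8% to +0.1% of baseline, which is consistent with that path being unchanged.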
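
The percentages in the table are consistent with a relative latency reduction, (baseline - optimized) / baseline. The short check below recomputes them from the timings listed above; `latency_reduction_pct` is a hypothetical helper name.

```python
def latency_reduction_pct(baseline_us, optimized_us):
    """Relative latency reduction in percent; positive means the optimized run is faster."""
    return round(100.0 * (baseline_us - optimized_us) / baseline_us, 1)


# (M, N) -> (baseline us, optimized us), copied from the benchmark tables above.
RESULTS = {
    (4096, 2880): (15.46, 13.06),
    (16384, 2880): (57.87, 45.34),
    (4096, 4096): (17.21, 17.70),
    (4096, 8192): (31.90, 31.88),
    (16384, 8192): (138.85, 138.79),
    (32768, 8192): (271.62, 271.30),
}

for shape, (base, opt) in RESULTS.items():
    print(shape, f"{latency_reduction_pct(base, opt):+.1f}%")
```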

Replace the scalar-only generic path with a vector-generic path that
uses vectorized buffer_load/store for the bulk of elements and falls
back to scalar operations only for the tail (N % tile_cols remainder).
This improves throughput on non-aligned hidden dimensions like N=2880
by ~21% at M=16384.

Also replace the dual block_reduce_add2 with a direct single-value
block_reduce_add, halving shared memory usage and removing one
unnecessary reduction slot.

Benchmark (MI300X, bf16, GPU profiling, 50 warmup + 500 iters):
  (4096,  2880) vec-gen: 14.50 -> 13.13 us (+9.4%)
  (16384, 2880) vec-gen: 57.06 -> 45.30 us (+20.6%)
  Fast-path shapes: neutral

Co-Authored-By: Claude Opus 4 (1M context) <noreply@anthropic.com>

@coderfeli (Collaborator)

@kudomcho ci failed
