EP+pad support for Step-3.5-Flash-FP8 #1091
Add Step3p5ForCausalLM model support for the Step-3.5-Flash architecture,
and fix a critical MoE correctness bug on gfx950 (MI350X).
Core MoE fix (atom/model_ops/moe.py):
Previously skipped shuffle_weights() for gfx950 BF16 g1u1 based on the
incorrect assumption that the CK 2-stage preshuffle_off (NSwizzle=0)
kernel expects un-shuffled weights. Verified: preshuffle_off GEMM is
wrong on gfx950; preshuffle_on (NSwizzle=1) is correct. Always call
shuffle_weights() so the correct kernel path is selected.
Step-3.5-Flash model support (atom/models/step3p5.py):
- Mixed full/sliding window attention (per layer_types config)
- 288 routed + 1 shared expert MoE with sigmoid routing
- Per-layer SwigluStep activation: layers with swiglu_limits[i]>0 use
ActivationType.SwigluStep (CK kernel applies silu(g).clamp(7)*up.clamp(±7));
other layers use plain Silu. Shared expert at SwigluStep layers is kept
on the dense MLP path (kernel clamp is routed-expert-only).
- Fused expert loading (flat [E,I,H] checkpoint format)
- clamp_limit applied to dense MLP and shared expert via Step3p5MLP
atom/model_engine/model_runner.py:
- Register Step3p5ForCausalLM architecture
- Handle num_attention_groups config key (Step-3.5 uses this instead of
num_key_value_heads) in KV head count calculations
atom/model_loader/loader.py:
- Fix fused expert detection order: check before packed_modules_mapping
to prevent moe.gate_proj being matched as gate_up_proj
atom/model_ops/attentions/aiter_attention.py:
- Handle num_attention_groups config key for KV head count
atom/examples/simple_inference.py:
- Add --max-tokens arg and trust_remote_code support
Verified: tp=2 Step-3.5-Flash inference, 4 prompts, no NaN/crash,
coherent output (with ATOM_STEP3P5_NO_SLIDING=1 workaround for
pa_decode_gluon bug on gfx950, tracked separately).
Co-Authored-By: Jun Lin <junlin12@amd.com>
CK 2-stage MoE kernel (gemm_moe_ck2stages.cu L98) computes stage1 N as
w1.size(1)/2 = inter_dim. The stage1 dispatch selects NPerBlock based on
the inter_dim range:
- inter <= 192: NPerBlock = 64 -> need inter % 64 == 0
- inter > 192: NPerBlock = 128 -> need inter % 128 == 0
Step-3.5-Flash with tp=4 gives inter=320 (320%128=64 != 0, crash) and
with tp=8 gives inter=160 (160%64=32 != 0, crash).
Fix: in process_weights_after_loading, pad inter_dim before
shuffle_weights() using alignment = 64 if inter<=192 else 128:
- inter=160 -> 192 (tp=8, 192%64=0)
- inter=320 -> 384 (tp=4, 384%128=0, 384%64=0)
Zero-padding is safe: padded rows carry zero weight, so they contribute
nothing to the fused_moe output.
Verified 2026-04-24 on gfx950 (MI350X):
- cos_sim >= 0.9999 vs torch reference (M=1..256)
- tp=4 inference: 4 prompts complete, no crash, output correct
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
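The alignment rule in this commit message can be checked with a few lines of arithmetic. This is a minimal sketch, not the code in moe.py: the function name `pad_inter_dim_bf16` is hypothetical, and only the rule `64 if inter <= 192 else 128` is taken from the message above.

```python
def pad_inter_dim_bf16(inter_dim: int) -> int:
    """Round inter_dim up to the NPerBlock that the stage1 dispatch would pick.

    Hypothetical helper; only the alignment rule comes from the commit message.
    """
    # inter <= 192 selects NPerBlock = 64, larger values select NPerBlock = 128.
    align = 64 if inter_dim <= 192 else 128
    # Integer round-up to the next multiple of align.
    return (inter_dim + align - 1) // align * align


if __name__ == "__main__":
    for tp, inter in [(8, 160), (4, 320)]:
        print(f"tp={tp}: inter={inter} -> {pad_inter_dim_bf16(inter)}")
```

Values that are already aligned pass through unchanged, so the padding only costs extra compute on the unaligned TP configurations.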
The else branch in get_fused_moe_quant_config was shared between block_quant (per_1x128/per_1x32) and per_tensor paths, hardcoding block_shape=None for all. Block-quantized FP8 models should receive block_shape=[128,128] (per_1x128) or [1,32] (per_1x32) to correctly configure the quant config, particularly for EP paths. Split the else branch into explicit per_1x128/per_1x32/fallback cases and unify the fp8_w8a8_moe_quant_config call.
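The branch split described above amounts to a three-way selection of block_shape. The sketch below assumes quant types are represented as plain strings and uses a hypothetical function name; the real code dispatches on QuantType members and passes the result on to fp8_w8a8_moe_quant_config.

```python
from typing import List, Optional


def select_block_shape(quant_type: str) -> Optional[List[int]]:
    """Return the block_shape for a quant type (hypothetical helper).

    The string keys stand in for the QuantType enum used by the real code.
    """
    if quant_type == "per_1x128":
        return [128, 128]
    if quant_type == "per_1x32":
        return [1, 32]
    # Fallback (for example per_tensor): no block structure.
    return None


if __name__ == "__main__":
    for qt in ("per_1x128", "per_1x32", "per_tensor"):
        print(qt, "->", select_block_shape(qt))
```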
Three coordinated fixes in Fp8MoEMethod for per_1x128 block scale:
1. create_weights: make ValueError check padding-aware
Compute padded_inter = ceil(inter/block_n)*block_n and check against padded_inter instead of raw inter, allowing tp=4 (inter=320) to pass while preserving the guard for truly unaligned cases.
2. _process_block_quant: zero-pad weights before shuffle_weights
After normalize and before shuffle, zero-pad w13 from [E,2*320,H] to [E,2*384,H] and w2 from [E,H,320] to [E,H,384], mirroring the BF16 approach in UnquantizedFusedMoEMethod.process_weights_after_loading. Padding zeros contribute 0 to GEMM output (dequant(0, scale)=0). Scale tensors already use ceil(inter/block_n) and need no change.
3. _load_w13 / _load_w2: fix scale TP sharding floor→ceil (root cause)
The per_1x128 scale for full inter=1280 has 10 N-blocks. TP=4 sharding with floor gives 10//4=2 blocks per rank; the 3rd (partial) block is never copied and stays at the torch.ones() init value of 1.0. With scale=1.0 instead of ~0.0002, dequant amplifies by ~5000×, causing complete garbage output despite correct weight loading. Fix: use ceil division and add narrow() bounds protection for the last rank, which may have fewer elements than the ceil size. Safe for tp=2 (10/2=5 exact, ceil==floor) and tp=1 (no sharding).
Verification:
- FP8 tp=4: 4 prompts, TTFT=92ms, TPOT=14ms, coherent output ✅
- BF16 tp=4 regression: TTFT=76-77ms, coherent output ✅
- FP8 tp=2 regression: TTFT=86ms, coherent output ✅
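The floor versus ceil difference in fix 3 is easiest to see as index arithmetic over the 10 scale blocks. The sketch below models only that arithmetic with a hypothetical `shard_ranges` helper; it does not reproduce the tensor handling in _load_w13 / _load_w2.

```python
def shard_ranges(num_blocks: int, tp_size: int, use_ceil: bool):
    """Return (start, size) per TP rank for a 1-D split of num_blocks.

    Hypothetical helper that mirrors the floor/ceil split described above.
    """
    if use_ceil:
        per_rank = -(-num_blocks // tp_size)  # ceil division
    else:
        per_rank = num_blocks // tp_size      # floor division
    ranges = []
    for rank in range(tp_size):
        start = per_rank * rank
        # Bounds protection: the last rank may get fewer blocks than per_rank.
        size = max(0, min(per_rank, num_blocks - start))
        ranges.append((start, size))
    return ranges


def covered(ranges) -> int:
    """Count how many blocks are assigned to some rank."""
    return sum(size for _, size in ranges)


if __name__ == "__main__":
    print("floor:", shard_ranges(10, 4, use_ceil=False))
    print("ceil: ", shard_ranges(10, 4, use_ceil=True))
```

With floor, only 8 of the 10 blocks are ever assigned, which is consistent with a destination slot keeping its initial value of 1.0; with ceil, all 10 are assigned and the last rank gets a single block.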
…ding
With NPerBlock=64 CK kernel support, inter_dim=320 (tp=4) is 64-aligned and no longer requires zero-padding to 384. Changed align from '64 if inter<=192 else block_n' to always 64, so:
- tp=4 (inter=320): 320%64=0 -> no padding (was 320->384, saved 17% compute)
- tp=8 (inter=160): 160%64=32 -> pad to 192 (unchanged)
- tp=2 (inter=640): 640%64=0 -> no padding (unchanged)
Scale tensor shape (ceil(320/128)=3) unchanged; no re-quantization needed.
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Stage2 KPerBlock=64 is not compilable on gfx950 (FP8 mfma KPack=32 constraint). Since stage1 output and stage2 weight K must match, both w13 and w2 require the same inter_dim padding. Restoring: align = 64 if inter_dim <= 192 else block_n (=128) Added comment explaining why full no-padding is currently blocked. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
…onfigs
_process_block_quant used 'align = 64 if inter_dim <= 192 else block_n', copied from the BF16 path. For FP8 blockscale this is wrong:
- FP8 stage2 only has KPerBlock=128 (KPack=32 mfma constraint prevents KPerBlock=64)
- align=64 gives inter_pad=192 for tp=8 (inter=160), but 192 % 128 = 64 != 0
- device_moe_gemm_blockscale.hpp L448 rejects K % KPerBlock != 0 → kernel fails
Fix: always use align = block_n (=128 for per_1x128), so inter_pad is always a multiple of 128 and stage2 KPerBlock=128 dispatch succeeds:
- tp=2: inter=640 → 640 (no padding, unchanged)
- tp=4: inter=320 → 384 (unchanged)
- tp=8: inter=160 → 256 (was 192, now correctly aligned)
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
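The difference between the copied BF16 rule and the FP8 rule can be compared directly. This is a sketch with hypothetical function names; only the two alignment rules and the KPerBlock=128 requirement come from the commit message above.

```python
def pad_old_rule(inter_dim: int, block_n: int = 128) -> int:
    """Rule copied from the BF16 path: 64-alignment for small inter_dim."""
    align = 64 if inter_dim <= 192 else block_n
    return (inter_dim + align - 1) // align * align


def pad_fp8_blockscale(inter_dim: int, block_n: int = 128) -> int:
    """Rule after the fix: always align to block_n."""
    return (inter_dim + block_n - 1) // block_n * block_n


if __name__ == "__main__":
    k_per_block = 128  # stage2 KPerBlock named in the commit message
    for tp, inter in [(2, 640), (4, 320), (8, 160)]:
        old, new = pad_old_rule(inter), pad_fp8_blockscale(inter)
        print(f"tp={tp}: old={old} (rem {old % k_per_block}), "
              f"new={new} (rem {new % k_per_block})")
```

Only the tp=8 case differs between the two rules, which matches the message: 192 leaves a remainder of 64 against KPerBlock=128, while 256 divides evenly.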
When the per_1x128 scale block count is smaller than tp_size (observed on Step-3.5-Flash-FP8 at tp=8 with inter_dim=1280 → D=10), the ceil split leaves trailing ranks with start >= D so narrow(start, size) hits size<0 and crashes weight load. Skip narrow + copy_ for those ranks.
For fp8 scale tensors (torch.ones() initialised in Fp8MoEMethod._create_weights), additionally zero the rank's slot before the early return. Otherwise the downstream fp8 dequant multiplies the (uninitialised) fp8 weight by stale 1.0 instead of the correct quantization scale, contaminating the column gather / row reduction and producing garbled output. Matches MXFP4 scale init (moe.py:776,813).
Verified on stepfun-ai/Step-3.5-Flash-FP8 (gfx942 / MI308X):
- tp=8 A1/A2/A4 PASS: 4/4 prompts coherent (was: weight-load crash pre-patch; was: garbled output with early-return-only)
- tp=2/tp=4 A1/A2/A3 PASS: no regression, zero-trigger confirmed (D=10, starts=[0,3,6,9] for tp=4, starts=[0,5] for tp=2, all < D)
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
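The early return with zeroing can be illustrated on plain Python lists. The sketch below is an assumption-laden model, not the loader itself: `load_scale_shard` is a hypothetical name, lists stand in for tensors, and zeroing the remainder of a partial rank follows the later review fix discussed in this thread.

```python
def load_scale_shard(full_scale, tp_size, tp_rank, dest):
    """Copy this rank's slice of full_scale into dest (hypothetical helper).

    dest models the per-rank scale slot, initialised to 1.0 like torch.ones().
    """
    dim = len(full_scale)
    per_rank = -(-dim // tp_size)  # ceil split
    start = per_rank * tp_rank
    if start >= dim:
        # Trailing rank with no slice: zero the slot instead of leaving the
        # stale 1.0 init, then return without narrowing or copying.
        for i in range(len(dest)):
            dest[i] = 0.0
        return dest
    size = min(per_rank, dim - start)
    dest[:size] = full_scale[start:start + size]
    # Zero whatever the partial last rank did not receive.
    for i in range(size, len(dest)):
        dest[i] = 0.0
    return dest


if __name__ == "__main__":
    scales = [float(i) for i in range(1, 11)]  # D = 10 placeholder values
    for rank in range(8):
        print(rank, load_scale_shard(scales, 8, rank, [1.0, 1.0]))
```

At tp=8 the ceil split gives 2 blocks per rank, so ranks 0 to 4 receive data and ranks 5 to 7 end up with an all-zero slot.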
- black: reformat 6 PR-touched files - ruff (atom/models/step3p5.py): drop unused `Any` import, un-f-string two no-placeholder debug prints, remove unused `clamp_limit` and `swiglu_limits` locals in `Step3p5DecoderLayer.__init__`
…ds_kv 32 -> per-layer 4); fixes EP/TP e2e on PR#641
# Conflicts:
# atom/examples/simple_inference.py
# atom/model_ops/moe.py
# atom/plugin/vllm/attention/metadata.py
… after main merge
- layernorm GemmaRMSNorm.forward_cuda: contiguity guard for non-contiguous QK-norm input (regressed by main 28f9702 which kernelized forward_cuda).
- config SWA guard: disable prefix_caching/chunked_prefill for sliding-window models (Step-3.5), generalizing the DeepSeek-V4-only guard (main default flip).
- moe expert_mask: aiter rocm_aiter_fused_moe expects a binary 0/1 mask; feed (expert_map > -1) instead of the index-map expert_map. Root cause = main e33e678 (#875 enable-EP DeepSeek-V4) changed forward_impl to pass index-map. Also applied defensively at CompressedTensorsFp8MoEMethod + modular_kernel (same pattern, GPU-untested on those non-Step-3.5 paths).
Verified on merged HEAD: EP 4/4 + pad 4/4 coherent, 0 HIP illegal access, SWA/custom-routing/swiglustep all green.
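The expert_mask conversion in the third bullet is a simple elementwise test. The sketch below uses plain lists instead of tensors and a hypothetical helper name; the only part taken from the commit message is the `expert_map > -1` comparison, with -1 assumed to mark experts that are not local to this rank.

```python
def expert_mask_from_map(expert_map):
    """Turn an index-map (global expert -> local index, -1 if absent)
    into a binary 0/1 mask. Hypothetical list-based stand-in for
    the tensor expression (expert_map > -1).
    """
    return [1 if local_idx > -1 else 0 for local_idx in expert_map]


if __name__ == "__main__":
    # Illustrative layout: 8 global experts, this rank owns experts 2 and 3.
    expert_map = [-1, -1, 0, 1, -1, -1, -1, -1]
    print(expert_mask_from_map(expert_map))
```

Passing the index-map itself would hand the kernel values such as -1 or local indices greater than 1 where it expects only 0 or 1, which is the mismatch the bullet describes.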
Resolve 2 conflicts (atom/model_ops/moe.py, atom/model_ops/fused_moe/modular_kernel.py) by adopting main's binary-expert_mask handling, which converged on the same fix as our EP regression patch:
- moe.py rocm_aiter_fused_moe call sites: take main's `expert_mask=layer.expert_mask` (binary 0/1) over our equivalent `(expert_map > -1)`. main fixed the index-map-into-mask bug upstream (#1076/#1087/#1132), so our conversion is now redundant.
- modular_kernel.py fused_moe call: take main's version which now has a dedicated binary `expert_mask` param (apply() passes layer.expert_mask) + forwards model-specific `moe_extra_args` (e.g. DeepSeek-V4 MXFP4 gate_mode/swiglu_limit).
Invariant preserved: aiter fused_moe / rocm_aiter_fused_moe expert_mask slot receives a binary 0/1 mask, never the index-map expert_map -> no ck_moe_stage1 OOB.
Our other fixes survive the merge: q_norm contiguity guard (layernorm.py), SWA prefix/chunked-prefill guard (config.py _uses_sliding_window), Step-3.5 support.
In _get_num_kv_heads() and the kv-cache block computation, the num_key_value_heads -> num_attention_groups getattr fallback returns None when a model config has neither field, which then hits `None >= world_size` and raises an opaque TypeError. Add an explicit None check that raises a ValueError naming both missing fields and the model architectures, so the failure is actionable. No behavior change for models that define either field (Step-3.5 uses num_attention_groups=8).
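The fallback and the explicit None check can be sketched on a stand-in config object. The function names and the error wording below are hypothetical; the two config keys, the getattr fallback order, and the `max(1, n // world_size)` per-rank form come from this PR.

```python
from types import SimpleNamespace


def resolve_num_kv_heads(hf_config) -> int:
    """Read the KV head count, falling back to num_attention_groups."""
    n = getattr(
        hf_config,
        "num_key_value_heads",
        getattr(hf_config, "num_attention_groups", None),
    )
    if n is None:
        # Fail early with an actionable message instead of a TypeError later.
        raise ValueError(
            "config defines neither num_key_value_heads nor num_attention_groups"
        )
    return n


def kv_heads_per_rank(hf_config, world_size: int) -> int:
    """Per-rank KV heads, never less than one."""
    return max(1, resolve_num_kv_heads(hf_config) // world_size)


if __name__ == "__main__":
    step35_like = SimpleNamespace(num_attention_groups=8)  # stand-in config
    print(kv_heads_per_rank(step35_like, world_size=2))
```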
The fused-expert load path recorded the checkpoint name (e.g. moe.gate_proj.weight) in loaded_weights_record, but the post-load verification diffs against params_dict keys (param names, e.g. moe.experts.w13_weight / w2_weight). The name-space mismatch made fused-expert params always appear in the "NOT loaded" list - a false positive (pad e2e arithmetic is correct, proving the weights are in fact loaded). Record the already-computed name_mapped (guaranteed to be a valid params_dict key) so the verification matches, eliminating the false positive while preserving detection of genuinely missing params. Only affects models that expose detect_fused_expert_format; no change for other models.
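The false positive comes from diffing two different name spaces, which a set difference makes visible. In this sketch `find_unloaded` is a hypothetical helper, and the up_proj and down_proj checkpoint names are assumed for illustration; only moe.gate_proj.weight and the w13_weight / w2_weight param names appear in the commit message.

```python
def find_unloaded(param_names, loaded_record):
    """Return param names that never appear in the load record."""
    return sorted(set(param_names) - set(loaded_record))


if __name__ == "__main__":
    params = ["moe.experts.w13_weight", "moe.experts.w2_weight"]

    # Before the fix: checkpoint names are recorded, so nothing matches.
    record_ckpt_names = [
        "moe.gate_proj.weight",
        "moe.up_proj.weight",    # assumed name, for illustration only
        "moe.down_proj.weight",  # assumed name, for illustration only
    ]
    print("before:", find_unloaded(params, record_ckpt_names))

    # After the fix: the mapped param name is recorded instead.
    record_mapped = ["moe.experts.w13_weight", "moe.experts.w2_weight"]
    print("after: ", find_unloaded(params, record_mapped))
```

A param that is genuinely missing from the record still shows up in the result, so the check keeps its purpose.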
…h path c732b99 added the num_key_value_heads -> num_attention_groups getattr fallback to two of the three KV-head computations in this file, but missed the one in _set_ubatch_pa_buffers (block_size==1024 ubatch path). Step-3.5 configs have no num_key_value_heads, so that path would AttributeError when ubatch + block_size==1024 is used. Apply the same fallback for consistency. Not exercised by the current e2e (block_size=16) but closes the gap.
bf66cd1 to
e807a49
swa_workspace uses swa_max_num_heads_kv for its head dim; if no attention layer exposes num_kv_heads it would be 0, producing a zero-sized workspace and silently broken SWA. Assert it is > 0 when SWA is enabled so the failure surfaces at build time. Step-3.5 always has >=1 so this is a guard for future SWA models.
Readability-only, no behavior change: - forward() now reads self._need_fp32_gate directly (set in __init__) instead of getattr(..., True), removing the dead/contradictory default. - Note in the module docstring that MTP (num_nextn_predict_layers) is config metadata only and not implemented (not needed for standard inference). - Comment the config.sink bool/int semantics (only a positive int enables sinks).
Pull request overview
Adds end-to-end inference support for stepfun Step-3.5-Flash-FP8 on ATOM (including EP and pure-TP+padding paths), with additional fixes for sliding-window attention workspace sizing and MoE FP8/block-quant weight handling.
Changes:
- Introduce a new Step3p5ForCausalLM model implementation (MoE + per-layer attention configuration, custom routing, SWA/sinks support).
- Extend MoE kernels/weight-loading to handle intermediate-dim padding/alignment and TP sharding edge-cases for FP8 block-quant (and related BF16 padding/shuffle behavior).
- Improve runtime configuration compatibility for SWA and KV-head derivation (fallback to num_attention_groups, disable prefix caching/chunked prefill for SWA models).
Reviewed changes
Copilot reviewed 9 out of 9 changed files in this pull request and generated 7 comments.
| File | Description |
|---|---|
| atom/plugin/vllm/attention/metadata.py | Sizes SWA workspace using per-layer KV-head count (max across layers). |
| atom/models/step3p5.py | New Step-3.5-Flash(-FP8) inference-only model definition and fused-expert loading hooks. |
| atom/model_ops/moe.py | Adds inter-dim padding for MoE weights, block-quant padding logic, and TP sharded loading fixes. |
| atom/model_ops/layernorm.py | Contiguity guard to avoid .view() failures on non-contiguous Q/K slices. |
| atom/model_ops/attentions/aiter_attention.py | KV-head config fallback to num_attention_groups for models lacking num_key_value_heads. |
| atom/model_loader/loader.py | Prevents fused expert weights from being mis-mapped by packed module rules; fixes loaded-weight bookkeeping. |
| atom/model_engine/model_runner.py | KV-head derivation now supports num_attention_groups and raises a clear error if missing. |
| atom/examples/simple_inference.py | Passes trust_remote_code through to tokenizer loading. |
| atom/config.py | Detects SWA usage and disables prefix caching/chunked prefill accordingly. |
Comments suppressed due to low confidence (1)
atom/model_ops/moe.py:2714
_load_w13 uses ceil sharding and narrows loaded_weight to size, but it still narrows expert_data to load_shard_size. For the last TP rank when the sharded dimension isn't divisible by tp_size (common for block-scale tensors, e.g. 10 blocks over TP4), this makes expert_data larger than loaded_weight and copy_() will fail. Also, the remainder of the scale slice should be zeroed so padded blocks are a true no-op.
size = min(load_shard_size, loaded_weight.shape[shard_dim] - start)
loaded_weight = loaded_weight.narrow(shard_dim, start, size)
# Narrow parameter and load.
# w1, gate_proj: Load into first logical weight of w13.
if shard_id == "w1":
    expert_data = expert_data.narrow(shard_dim, 0, expert_shard_size)
# w3, up_proj: Load into second logical weight of w13.
else:
    assert shard_id == "w3"
    expert_data = expert_data.narrow(
        shard_dim, expert_shard_size, expert_shard_size
    )
# When expert_data is padded beyond the actual weight size, narrow to
# the loaded weight size so the copy shape matches.
if load_shard_size != expert_shard_size:
    expert_data = expert_data.narrow(shard_dim, 0, load_shard_size)
if expert_data.dtype != dtypes.fp4x2:
_tp_size = (
    getattr(self, "tp_size", getattr(self.moe, "tp_size", 1))
    if hasattr(self, "moe")
    else getattr(self, "tp_size", 1)
)
_ep_size = (
    getattr(self, "ep_size", getattr(self.moe, "ep_size", 1))
    if hasattr(self, "moe")
    else getattr(self, "ep_size", 1)
)
if layer.reduce_results and (_tp_size > 1 or _ep_size > 1):
    from aiter.dist.parallel_state import get_tp_group

    _moe_result = get_tp_group().all_reduce(
        _moe_result, ca_fp8_quant=False
    )
return _moe_result
Fixed in 94e031d: removed the redundant all_reduce in Mxfp4MoEMethod.apply — other apply() methods don't reduce internally, forward_impl already all-reduces when reduce_results is set, so reduction now happens once. Good catch
if not load_full:
    # Derive shard size from loaded_weight (unpadded checkpoint) to
    # avoid out-of-bounds when expert_data is padded (e.g. MXFP4).
    # Use ceil (same reason as _load_w13: partial last scale block).
    load_shard_size = (
        loaded_weight.shape[shard_dim] + self.tp_size - 1
    ) // self.tp_size
    start = load_shard_size * tp_rank
    # See _load_w13 comment above: when D < tp_size the ceil split
    # leaves trailing ranks with no slice; skip narrow + copy_.
    if start >= loaded_weight.shape[shard_dim]:
        # Zero the scale slice so dequant=0 instead of multiplying by
        # stale init=1.0; see _load_w13 comment for full rationale.
        if expert_data.dtype == torch.float32:
            if load_shard_size != shard_size:
                expert_data.narrow(shard_dim, 0, load_shard_size).zero_()
            else:
                expert_data.zero_()
        return
    size = min(load_shard_size, loaded_weight.shape[shard_dim] - start)
    loaded_weight = loaded_weight.narrow(shard_dim, start, size)
    if load_shard_size != shard_size:
        expert_data = expert_data.narrow(shard_dim, 0, load_shard_size)
Fixed in 5663a66: expert_data now narrowed to actual shard size with remainder zeroed — trailing partial rank (e.g. tp=4, D=10) no longer hits copy_() shape mismatch. No change for tp=8
num_head_k = max(
    1,
    getattr(
        hf_config,
        "num_key_value_heads",
        getattr(hf_config, "num_attention_groups", None),
    )
    // get_tp_group().world_size,
)
Fixed in 4ea9985: raises a clear ValueError when neither num_key_value_heads nor num_attention_groups is present (mirrors model_runner), instead of a None // world_size TypeError.
if sink_size > 0:
    sinks = nn.Parameter(torch.empty(self.num_heads, requires_grad=False))
Fixed in 4ea9985: sinks now zero-initialized (torch.empty → torch.zeros)
… reduce) FusedMoE.forward_impl/forward_impl_graph already all-reduce the MoE output when reduce_results is set, like the other quant methods' apply() which do not reduce internally. Mxfp4MoEMethod.apply additionally all-reduced, so a reduce_results mxfp4 layer reduced twice (output scaled by group size). Drop the in-apply all_reduce so reduction happens once in forward_impl (#1091 Copilot review). Affects mxfp4 TP/EP only.
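The effect of reducing twice can be reproduced with a toy sum all-reduce. The sketch below simulates ranks as list entries and uses a hypothetical `all_reduce_sum` helper; it does not call any aiter or torch.distributed API.

```python
def all_reduce_sum(per_rank_values):
    """Simulated sum all-reduce: every rank ends up with the global sum."""
    total = sum(per_rank_values)
    return [total] * len(per_rank_values)


if __name__ == "__main__":
    partials = [1.0, 2.0, 3.0, 4.0]  # one partial MoE output per rank
    once = all_reduce_sum(partials)
    twice = all_reduce_sum(once)     # the extra reduce inside apply()
    print("once: ", once)
    print("twice:", twice)
    print("ratio:", twice[0] / once[0])
```

After the first reduce every rank already holds the full sum, so a second reduce multiplies it by the number of ranks, which is the scaling by group size described above.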
… rank In _load_w13/_load_w2, when the per-rank scale-block count does not divide evenly (e.g. tp=4 with D=10), the trailing partial rank had loaded_weight narrowed to `size` but expert_data still narrowed to load_shard_size, causing a copy_() shape mismatch and leaving the remainder uninitialized. Narrow expert_data to the same `size` and zero the remainder. No change for tp=8 (D divides evenly; trailing ranks already zero+return) (#1091 Copilot review).
- aiter_attention: raise a clear ValueError when neither num_key_value_heads nor num_attention_groups is present (mirror model_runner) at all 3 KV-head computations, instead of a None // world_size TypeError.
- metadata: replace the SWA swa_max_num_heads_kv assert with an explicit raise (assert is stripped under python -O).
- step3p5: zero-initialize attention sinks (torch.empty -> torch.zeros) to avoid uninitialized data if a checkpoint omits them.
(#1091 Copilot review.)
inter_dim = layer.w2_weight.shape[-1]
block_n = 128 if self.quant_type == QuantType.per_1x128 else 32
# FP8 blockscale stage2 requires KPerBlock=128 (gfx950 FP8 mfma KPack=32 constraint
# prevents KPerBlock=64). align must always be block_n(=128) so that inter_pad%128==0.
# Bug fix: previously used align=64 for inter<=192 (copied from BF16 path), but
# 192%128=64!=0 → stage2 kernel dispatch fails. Correct: always align to block_n.
# tp=8 inter=160 → 256 (3×128→no, ceil(160/128)*128=256); tp=4 inter=320 → 384.
align = block_n
inter_pad = (inter_dim + align - 1) // align * align
"Qwen3_5MoeForConditionalGeneration": "atom.models.qwen3_5.Qwen3_5MoeMultimodalModel",
"KimiK25ForConditionalGeneration": "atom.models.kimi_k25.KimiK25ForCausalLM",
"MiniMaxM2ForCausalLM": "atom.models.minimax_m2.MiniMaxM2ForCausalLM",
"Step3p5ForCausalLM": "atom.models.step3p5.Step3p5ForCausalLM",
"MiMoV2FlashForCausalLM": "atom.models.mimo_v2_flash.MiMoV2FlashForCausalLM",
# Conflicts:
# atom/model_engine/model_runner.py
# atom/model_ops/moe.py
Motivation
Enable end-to-end serving of stepfun Step-3.5-Flash-FP8 (FP8 block-quantized MoE) on ATOM at TP8 on AMD gfx942 (MI308X), across the two production parallelism paths: Expert-Parallel (EP) and pure Tensor-Parallel with padding.
Technical Details
This branch adds the support required to load and run Step-3.5-Flash-FP8:
- step3p5 model definition and MoE weight shuffling for the routed-expert layout.
- … 1x128 block scaling (block_shape = [128, 128]), with correct inter_dim padding-alignment across all TP configurations, plumbed through Fp8MoEMethod.get_fused_moe_quant_config.
- … D < tp_size in the fp8 _load_w13/_load_w2 shard loaders so the routed-expert weights load correctly at TP8.
- … sliding_window. ModelConfig.get_num_kv_heads() returns the model-level value (32), but the per-layer attention modules use 4 KV heads; the SWA workspace is now sized from the true per-layer num_kv_heads (max across layers), so the workspace head dimension is correct. Touches only atom/plugin/attention.py.
- … --enable-expert-parallel): experts sharded, inter_dim = 1280.
- … inter_dim sharded 1280/8 = 160, padded to 256.
Test Plan
End-to-end correctness via the ATOM-native simple_inference example on 8×gfx942 (TP8, FP8 block-quantized MoE, FP8 KV cache), over the example's 4 prompts, for both production paths:
- … --enable-expert-parallel.
Test Result
Both paths 4/4 coherent (exit 0), non-garbled, with natural EOS and no faults:
- … 1 + 2 + 3 = 6).
- … 1 + 2 + 3 = 6).
Submission Checklist
- … main).
Related
- … af7118e). Note: the SwigluStep host codegen is not yet in aiter main, so this stack requires that aiter branch.