Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
c732b99
feat: add Step-3.5-Flash support and fix MoE weight shuffling on gfx950
LJ-underdog Apr 23, 2026
26585d4
fix: pad inter_dim in UnquantizedFusedMoEMethod for gfx950 tp=4/8
LJ-underdog Apr 24, 2026
841dc4e
fix: pass correct block_shape in Fp8MoEMethod.get_fused_moe_quant_config
LJ-underdog Apr 24, 2026
ccb6462
fix: support FP8 block-quantized inference at tp=4 (inter_dim=320)
LJ-underdog Apr 25, 2026
270fee7
fix(moe): use align=64 for FP8 blockscale to remove inter_dim=320 pad…
LJ-underdog Apr 27, 2026
3696345
revert(moe): restore FP8 blockscale inter_dim padding align logic
LJ-underdog Apr 27, 2026
acff926
fix(moe): correct FP8 blockscale inter_dim padding align for all tp c…
LJ-underdog Apr 27, 2026
969d564
fix(moe): handle D < tp_size in fp8 _load_w13/_load_w2
LJ-underdog Apr 29, 2026
2d8cc80
style: black format and ruff fixes for PR #641
arthurliu1998 May 6, 2026
2c29406
fix(attn): stepfun-Flash-FP8 SWA per-layer kv-head workspace (num_hea…
LJ-underdog Jun 3, 2026
63be9c2
chore: remove debug instrumentation from step3p5
LJ-underdog Jun 5, 2026
bb06f0a
Merge remote-tracking branch 'origin/main' into feat-ep-pad-clean
LJ-underdog Jun 5, 2026
7c48c52
fix(step3p5-EP): resolve 3 merge regressions breaking Step-3.5 EP/pad…
LJ-underdog Jun 9, 2026
20c627d
Merge origin/main (eecc546a) into feat-ep-pad-clean
LJ-underdog Jun 9, 2026
17de92a
fix(model_runner): raise clear error when KV-head config fields missing
LJ-underdog Jun 10, 2026
f8f8f79
fix(loader): record mapped param name for fused-expert weights
LJ-underdog Jun 10, 2026
e807a49
fix(aiter_attention): complete num_attention_groups fallback in ubatc…
LJ-underdog Jun 10, 2026
a33508d
fix(metadata): assert non-zero SWA per-layer kv-head count
LJ-underdog Jun 10, 2026
a37dfdb
chore(step3p5): clarify fp32_gate default, MTP docstring, sink semantics
LJ-underdog Jun 10, 2026
94e031d
fix(moe): remove redundant all_reduce in Mxfp4MoEMethod.apply (double…
LJ-underdog Jun 10, 2026
5663a66
fix(moe): narrow expert_data to actual shard size on trailing partial…
LJ-underdog Jun 10, 2026
4ea9985
fix(attn/metadata/step3p5): harden KV-head/SWA/sink edge cases (Copilot)
LJ-underdog Jun 10, 2026
f3e1235
Merge branch 'main' into feat-ep-pad-clean
LJ-underdog Jun 10, 2026
eec21f3
Merge remote-tracking branch 'origin/main' into feat-ep-pad-clean
LJ-underdog Jun 15, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 46 additions & 7 deletions atom/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -1040,6 +1040,29 @@ def _set_cudagraph_sizes(self):
elif len(cuda_graph_sizes) > 1:
self.graph_bs = cuda_graph_sizes

def _uses_sliding_window(self) -> bool:
    """Return True iff the model uses sliding-window attention (SWA).

    Detection covers three signals read from ``self.hf_config``:

    1. a global ``sliding_window`` field holding a positive integer, unless
       the config explicitly sets ``use_sliding_window=False``;
    2. any entry of ``layer_types`` whose string form contains ``"sliding"``
       (interleaved SWA);
    3. any entry of ``architectures`` whose string form contains
       ``"DeepseekV4"``.

    Prefix caching's classical KV pool cannot restore the per-request SWA
    buffer on a cache hit, so SWA models must run with it disabled.

    Returns:
        bool: True when any of the signals above is present, else False.
    """
    hf = self.hf_config

    # Global sliding_window field (Step-3.5=512, Gemma, Mistral, Qwen-SWA, ...).
    # Some configs (e.g. Qwen2-style) carry a default ``sliding_window`` value
    # together with ``use_sliding_window=False``; in that case the window is
    # not active, so the field alone must not be treated as "uses SWA".
    # Only an explicit False disables the check; a missing or None flag keeps
    # the original behaviour.
    window_enabled = getattr(hf, "use_sliding_window", True) is not False
    sw = getattr(hf, "sliding_window", None)
    # bool is a subclass of int, so exclude it explicitly.
    if (
        window_enabled
        and isinstance(sw, int)
        and not isinstance(sw, bool)
        and sw > 0
    ):
        return True

    # Interleaved SWA via layer_types
    # (Step-3.5: ['full_attention', 'sliding_attention', ...]).
    layer_types = getattr(hf, "layer_types", None) or []
    if any("sliding" in str(t) for t in layer_types):
        return True

    # DeepSeek-V4: model_type is remapped to deepseek_v3, so detect SWA via the
    # preserved architectures name (kept as a fallback).
    arches = getattr(hf, "architectures", None) or []
    return any("DeepseekV4" in str(a) for a in arches)

def __post_init__(self):
if isinstance(self.compilation_config, dict):
self.compilation_config = CompilationConfig(**self.compilation_config)
Expand Down Expand Up @@ -1163,16 +1186,32 @@ def __post_init__(self):
v4_block_size = 128
if self.kv_cache_block_size != v4_block_size:
self.kv_cache_block_size = v4_block_size
# TODO: V4's per-request SWA buffer cannot be restored from the classical
# KV pool on prefix cache hit, so disable prefix caching silently.
if self.enable_prefix_caching:
import logging

logging.getLogger(__name__).warning(
"DeepSeek-V4 does not support prefix caching "
"(SWA buffer is not cacheable); disabling automatically."
# SWA models cannot restore the per-request sliding-window KV buffer from the
# classical KV pool on a prefix-cache hit, so disable prefix caching for any
# sliding-window model (DeepSeek-V4, Step-3.5, ...). Generalizes main's
# original V4-only guard, which left Step-3.5/SWA exposed to a merge default
# flip (enable_prefix_caching default False->True). Non-SWA models keep
# main's prefix-caching optimization.
if self._uses_sliding_window():
import logging

_log = logging.getLogger(__name__)
if self.enable_prefix_caching:
_log.warning(
"Model uses sliding-window attention (SWA buffer is not "
"cacheable); disabling prefix caching automatically."
)
self.enable_prefix_caching = False
if self.enable_chunked_prefill:
# Conservative: SWA + chunked prefill cross-chunk window correctness
# is unverified on GPU; restore the pre-merge-safe default (off).
# TODO: re-enable after SWA + chunked-prefill GPU validation.
_log.warning(
"Model uses sliding-window attention; disabling chunked "
"prefill (SWA + chunked prefill unverified)."
)
self.enable_chunked_prefill = False

def compute_hash(self) -> str:
"""
Expand Down
7 changes: 3 additions & 4 deletions atom/examples/simple_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,9 @@ def main():
engine_args = EngineArgs.from_cli_args(args)
llm = engine_args.create_engine()

tokenizer = AutoTokenizer.from_pretrained(args.model)
tokenizer = AutoTokenizer.from_pretrained(
args.model, trust_remote_code=getattr(args, "trust_remote_code", False)
)

sampling_params = SamplingParams(
temperature=args.temperature, max_tokens=args.max_tokens
Expand All @@ -70,9 +72,6 @@ def main():
for p in prompts
]
print("This is prompts:", prompts)
# print("Warming up...")
# _ = llm.generate(["warmup"], sampling_params)
# print("Warm up done")

print("\n" + "=" * 70)
print("Starting profiling...")
Expand Down
39 changes: 31 additions & 8 deletions atom/model_engine/model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@
"Qwen3_5MoeForConditionalGeneration": "atom.models.qwen3_5.Qwen3_5MoeMultimodalModel",
"KimiK25ForConditionalGeneration": "atom.models.kimi_k25.KimiK25ForCausalLM",
"MiniMaxM2ForCausalLM": "atom.models.minimax_m2.MiniMaxM2ForCausalLM",
"Step3p5ForCausalLM": "atom.models.step3p5.Step3p5ForCausalLM",
"MiMoV2ForCausalLM": "atom.models.mimo_v2.MiMoV2ForCausalLM",
"MiMoV2FlashForCausalLM": "atom.models.mimo_v2.MiMoV2ForCausalLM",
}
Expand Down Expand Up @@ -1200,11 +1201,22 @@ def allocate_forward_vars(self):
def _get_num_kv_heads(self):
"""Return the per-rank number of KV heads."""
hf_config = self.config.hf_config
if hf_config.num_key_value_heads >= self.world_size:
assert hf_config.num_key_value_heads % self.world_size == 0
return hf_config.num_key_value_heads // self.world_size
num_kv_heads_cfg = getattr(
hf_config,
"num_key_value_heads",
getattr(hf_config, "num_attention_groups", None),
)
if num_kv_heads_cfg is None:
raise ValueError(
"Model config has neither 'num_key_value_heads' nor "
"'num_attention_groups'; cannot determine number of KV heads "
f"for {getattr(hf_config, 'architectures', hf_config)}"
)
if num_kv_heads_cfg >= self.world_size:
assert num_kv_heads_cfg % self.world_size == 0
return num_kv_heads_cfg // self.world_size
else:
assert self.world_size % hf_config.num_key_value_heads == 0
assert self.world_size % num_kv_heads_cfg == 0
return 1

def _mrope_positions_view(self, num_tokens: int) -> torch.Tensor:
Expand Down Expand Up @@ -1453,11 +1465,22 @@ def allocate_kv_cache(self, num_kvcache_blocks):
self.num_physical_kvcache_blocks = (
num_kvcache_blocks * self.attn_metadata_builder.block_ratio
)
if hf_config.num_key_value_heads >= self.world_size:
assert hf_config.num_key_value_heads % self.world_size == 0
num_kv_heads = hf_config.num_key_value_heads // self.world_size
num_kv_heads_cfg = getattr(
hf_config,
"num_key_value_heads",
getattr(hf_config, "num_attention_groups", None),
)
if num_kv_heads_cfg is None:
raise ValueError(
"Model config has neither 'num_key_value_heads' nor "
"'num_attention_groups'; cannot determine number of KV heads "
f"for {getattr(hf_config, 'architectures', hf_config)}"
)
if num_kv_heads_cfg >= self.world_size:
assert num_kv_heads_cfg % self.world_size == 0
num_kv_heads = num_kv_heads_cfg // self.world_size
else:
assert self.world_size % hf_config.num_key_value_heads == 0
assert self.world_size % num_kv_heads_cfg == 0
num_kv_heads = 1
# Promote to self so attention builders' build_kv_cache_tensor()
# hooks can access it without re-deriving from hf_config.
Expand Down
24 changes: 23 additions & 1 deletion atom/model_loader/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -478,12 +478,27 @@ def _submit(fn, *args):
maybe_matching_name,
f"{module_prefix}experts.{hf_config.n_routed_experts}.",
)
# Check fused expert format before packed_modules_mapping to avoid
# expert weights (e.g. moe.gate_proj) being incorrectly matched
# by packed_modules_mapping entries (e.g. gate_proj -> gate_up_proj).
if detect_fused_expert_fn is not None and not is_fused_expert:
if detect_fused_expert_fn(name):
is_fused_expert = True
if get_fused_expert_mapping_fn is not None:
fused_expert_params_mapping = get_fused_expert_mapping_fn()
for k in packed_modules_mapping:
# We handle the experts below in expert_params_mapping
if (
"mlp.experts." in name or "ffn.experts." in name
) and name not in params_dict:
continue
# Skip fused expert weights — handled below in expert loading path
if (
is_fused_expert
and detect_fused_expert_fn is not None
and detect_fused_expert_fn(name)
):
continue
if k in name:
packed_value = packed_modules_mapping[k]
# Handle both tuple (fuse parameter) and list (shard parameter)
Expand Down Expand Up @@ -556,7 +571,14 @@ def _submit(fn, *args):
)

if matched:
loaded_weights_record.add(prefix + name)
# Record the MAPPED param name (e.g.
# moe.experts.w13_weight), not the ckpt name
# (e.g. moe.gate_proj.weight): the post-load
# verification below diffs against params_dict
# keys (param names), so recording the ckpt name
# makes fused-expert params (w13_weight/w2_weight)
# falsely show up as "NOT loaded".
loaded_weights_record.add(prefix + name_mapped)
break

if matched:
Expand Down
39 changes: 34 additions & 5 deletions atom/model_ops/attentions/aiter_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,18 @@ def __init__(
else:
max_qlen = 1

num_head_k = max(1, hf_config.num_key_value_heads // get_tp_group().world_size)
num_kv_heads_cfg = getattr(
hf_config,
"num_key_value_heads",
getattr(hf_config, "num_attention_groups", None),
)
if num_kv_heads_cfg is None:
raise ValueError(
"Model config has neither 'num_key_value_heads' nor "
"'num_attention_groups'; cannot determine number of KV heads "
f"for {getattr(hf_config, 'architectures', hf_config)}"
)
num_head_k = max(1, num_kv_heads_cfg // get_tp_group().world_size)
(
(work_meta_data_size, work_meta_data_type),
(work_indptr_size, work_indptr_type),
Expand Down Expand Up @@ -236,9 +247,18 @@ def set_aiter_persistent_worker_buffers(self, bs: int):
config = self.model_runner.config
hf_config = config.hf_config
num_query_heads = self.num_attention_heads
num_kv_heads = max(
1, hf_config.num_key_value_heads // get_tp_group().world_size
num_kv_heads_cfg = getattr(
hf_config,
"num_key_value_heads",
getattr(hf_config, "num_attention_groups", None),
)
if num_kv_heads_cfg is None:
raise ValueError(
"Model config has neither 'num_key_value_heads' nor "
"'num_attention_groups'; cannot determine number of KV heads "
f"for {getattr(hf_config, 'architectures', hf_config)}"
)
num_kv_heads = max(1, num_kv_heads_cfg // get_tp_group().world_size)
block_size = self.block_size

var = self.model_runner.forward_vars
Expand Down Expand Up @@ -884,9 +904,18 @@ def _set_ubatch_pa_buffers(self, padded_bs, max_q_len, ubatch_idx):
config = self.model_runner.config
hf_config = config.hf_config
num_query_heads = self.num_attention_heads
num_kv_heads = max(
1, hf_config.num_key_value_heads // get_tp_group().world_size
num_kv_heads_cfg = getattr(
hf_config,
"num_key_value_heads",
getattr(hf_config, "num_attention_groups", None),
)
if num_kv_heads_cfg is None:
raise ValueError(
"Model config has neither 'num_key_value_heads' nor "
"'num_attention_groups'; cannot determine number of KV heads "
f"for {getattr(hf_config, 'architectures', hf_config)}"
)
num_kv_heads = max(1, num_kv_heads_cfg // get_tp_group().world_size)
p = f"ub{ubatch_idx}_"
var = self.model_runner.forward_vars

Expand Down
10 changes: 10 additions & 0 deletions atom/model_ops/layernorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -667,6 +667,16 @@ def forward_cuda(
x: torch.Tensor,
residual: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
# Contiguity guard (merge-regression fix): the aiter HIP fused-kernel path
# below uses x.view(), which requires a contiguous input. QK-norm feeds a
# non-contiguous GQA slice here (step3p5 torch.split -> reshape keeps the
# qkv row stride, e.g. (1280,128,1) for an 8x128 q), so x.view(-1, head_dim)
# raises "Cannot view a tensor ...". Fall back to the pre-merge native math
# for non-contiguous inputs; contiguous callers keep main's fast HIP kernel.
# Under Dynamo, FakeTensor.is_contiguous() resolves to a concrete bool from
# static strides, so this short-circuits at trace time before the .view().
if not x.is_contiguous():
return self.forward_native(x, residual)
# Use the aiter HIP fused_qk_rmsnorm_group_quant kernel in no-quant mode
# (q_out_scale=None) to perform Gemma RMSNorm + optional residual add.
# Same math as the Triton kernel: out = rmsnorm(x [+ residual]) * (1 + w),
Expand Down
Loading
Loading