diff --git a/atom/config.py b/atom/config.py index d9b582d601..2d04688eef 100644 --- a/atom/config.py +++ b/atom/config.py @@ -1040,6 +1040,29 @@ def _set_cudagraph_sizes(self): elif len(cuda_graph_sizes) > 1: self.graph_bs = cuda_graph_sizes + def _uses_sliding_window(self) -> bool: + """True iff the model uses sliding-window attention (global or interleaved). + + Prefix caching's classical KV pool cannot restore the per-request SWA + buffer on a cache hit, so SWA models must run with it disabled. + """ + hf = self.hf_config + # Global sliding_window field (Step-3.5=512, Gemma, Mistral, Qwen-SWA, ...). + sw = getattr(hf, "sliding_window", None) + if isinstance(sw, int) and not isinstance(sw, bool) and sw > 0: + return True + # Interleaved SWA via layer_types + # (Step-3.5: ['full_attention', 'sliding_attention', ...]). + layer_types = getattr(hf, "layer_types", None) or [] + if any("sliding" in str(t) for t in layer_types): + return True + # DeepSeek-V4: model_type is remapped to deepseek_v3, so detect SWA via the + # preserved architectures name (kept as a fallback). + arches = getattr(hf, "architectures", None) or [] + if any("DeepseekV4" in str(a) for a in arches): + return True + return False + def __post_init__(self): if isinstance(self.compilation_config, dict): self.compilation_config = CompilationConfig(**self.compilation_config) @@ -1163,16 +1186,32 @@ def __post_init__(self): v4_block_size = 128 if self.kv_cache_block_size != v4_block_size: self.kv_cache_block_size = v4_block_size - # TODO: V4's per-request SWA buffer cannot be restored from the classical - # KV pool on prefix cache hit, so disable prefix caching silently. - if self.enable_prefix_caching: - import logging - logging.getLogger(__name__).warning( - "DeepSeek-V4 does not support prefix caching " - "(SWA buffer is not cacheable); disabling automatically." + # SWA models cannot restore the per-request sliding-window KV buffer from the + # classical KV pool on a prefix-cache hit, so disable prefix caching for any + # sliding-window model (DeepSeek-V4, Step-3.5, ...). Generalizes main's + # original V4-only guard, which left Step-3.5/SWA exposed to a merge default + # flip (enable_prefix_caching default False->True). Non-SWA models keep + # main's prefix-caching optimization. + if self._uses_sliding_window(): + import logging + + _log = logging.getLogger(__name__) + if self.enable_prefix_caching: + _log.warning( + "Model uses sliding-window attention (SWA buffer is not " + "cacheable); disabling prefix caching automatically." ) self.enable_prefix_caching = False + if self.enable_chunked_prefill: + # Conservative: SWA + chunked prefill cross-chunk window correctness + # is unverified on GPU; restore the pre-merge-safe default (off). + # TODO: re-enable after SWA + chunked-prefill GPU validation. + _log.warning( + "Model uses sliding-window attention; disabling chunked " + "prefill (SWA + chunked prefill unverified)." + ) + self.enable_chunked_prefill = False def compute_hash(self) -> str: """ diff --git a/atom/examples/simple_inference.py b/atom/examples/simple_inference.py index 73d51c2e69..a5ce4de525 100644 --- a/atom/examples/simple_inference.py +++ b/atom/examples/simple_inference.py @@ -58,7 +58,9 @@ def main(): engine_args = EngineArgs.from_cli_args(args) llm = engine_args.create_engine() - tokenizer = AutoTokenizer.from_pretrained(args.model) + tokenizer = AutoTokenizer.from_pretrained( + args.model, trust_remote_code=getattr(args, "trust_remote_code", False) + ) sampling_params = SamplingParams( temperature=args.temperature, max_tokens=args.max_tokens @@ -70,9 +72,6 @@ def main(): for p in prompts ] print("This is prompts:", prompts) - # print("Warming up...") - # _ = llm.generate(["warmup"], sampling_params) - # print("Warm up done") print("\n" + "=" * 70) print("Starting profiling...") diff --git a/atom/model_engine/model_runner.py b/atom/model_engine/model_runner.py index db018399e2..023fe47efa 100644 --- a/atom/model_engine/model_runner.py +++ b/atom/model_engine/model_runner.py @@ -72,6 +72,7 @@ "Qwen3_5MoeForConditionalGeneration": "atom.models.qwen3_5.Qwen3_5MoeMultimodalModel", "KimiK25ForConditionalGeneration": "atom.models.kimi_k25.KimiK25ForCausalLM", "MiniMaxM2ForCausalLM": "atom.models.minimax_m2.MiniMaxM2ForCausalLM", + "Step3p5ForCausalLM": "atom.models.step3p5.Step3p5ForCausalLM", "MiMoV2ForCausalLM": "atom.models.mimo_v2.MiMoV2ForCausalLM", "MiMoV2FlashForCausalLM": "atom.models.mimo_v2.MiMoV2ForCausalLM", } @@ -1200,11 +1201,22 @@ def allocate_forward_vars(self): def _get_num_kv_heads(self): """Return the per-rank number of KV heads.""" hf_config = self.config.hf_config - if hf_config.num_key_value_heads >= self.world_size: - assert hf_config.num_key_value_heads % self.world_size == 0 - return hf_config.num_key_value_heads // self.world_size + num_kv_heads_cfg = getattr( + hf_config, + "num_key_value_heads", + getattr(hf_config, "num_attention_groups", None), + ) + if num_kv_heads_cfg is None: + raise ValueError( + "Model config has neither 'num_key_value_heads' nor " + "'num_attention_groups'; cannot determine number of KV heads " + f"for {getattr(hf_config, 'architectures', hf_config)}" + ) + if num_kv_heads_cfg >= self.world_size: + assert num_kv_heads_cfg % self.world_size == 0 + return num_kv_heads_cfg // self.world_size else: - assert self.world_size % hf_config.num_key_value_heads == 0 + assert self.world_size % num_kv_heads_cfg == 0 return 1 def _mrope_positions_view(self, num_tokens: int) -> torch.Tensor: @@ -1453,11 +1465,22 @@ def allocate_kv_cache(self, num_kvcache_blocks): self.num_physical_kvcache_blocks = ( num_kvcache_blocks * self.attn_metadata_builder.block_ratio ) - if hf_config.num_key_value_heads >= self.world_size: - assert hf_config.num_key_value_heads % self.world_size == 0 - num_kv_heads = hf_config.num_key_value_heads // self.world_size + num_kv_heads_cfg = getattr( + hf_config, + "num_key_value_heads", + getattr(hf_config, "num_attention_groups", None), + ) + if num_kv_heads_cfg is None: + raise ValueError( + "Model config has neither 'num_key_value_heads' nor " + "'num_attention_groups'; cannot determine number of KV heads " + f"for {getattr(hf_config, 'architectures', hf_config)}" + ) + if num_kv_heads_cfg >= self.world_size: + assert num_kv_heads_cfg % self.world_size == 0 + num_kv_heads = num_kv_heads_cfg // self.world_size else: - assert self.world_size % hf_config.num_key_value_heads == 0 + assert self.world_size % num_kv_heads_cfg == 0 num_kv_heads = 1 # Promote to self so attention builders' build_kv_cache_tensor() # hooks can access it without re-deriving from hf_config. diff --git a/atom/model_loader/loader.py b/atom/model_loader/loader.py index ff4771f8a5..996bc8b984 100644 --- a/atom/model_loader/loader.py +++ b/atom/model_loader/loader.py @@ -478,12 +478,27 @@ def _submit(fn, *args): maybe_matching_name, f"{module_prefix}experts.{hf_config.n_routed_experts}.", ) + # Check fused expert format before packed_modules_mapping to avoid + # expert weights (e.g. moe.gate_proj) being incorrectly matched + # by packed_modules_mapping entries (e.g. gate_proj -> gate_up_proj). + if detect_fused_expert_fn is not None and not is_fused_expert: + if detect_fused_expert_fn(name): + is_fused_expert = True + if get_fused_expert_mapping_fn is not None: + fused_expert_params_mapping = get_fused_expert_mapping_fn() for k in packed_modules_mapping: # We handle the experts below in expert_params_mapping if ( "mlp.experts." in name or "ffn.experts." in name ) and name not in params_dict: continue + # Skip fused expert weights — handled below in expert loading path + if ( + is_fused_expert + and detect_fused_expert_fn is not None + and detect_fused_expert_fn(name) + ): + continue if k in name: packed_value = packed_modules_mapping[k] # Handle both tuple (fuse parameter) and list (shard parameter) @@ -556,7 +571,14 @@ def _submit(fn, *args): ) if matched: - loaded_weights_record.add(prefix + name) + # Record the MAPPED param name (e.g. + # moe.experts.w13_weight), not the ckpt name + # (e.g. moe.gate_proj.weight): the post-load + # verification below diffs against params_dict + # keys (param names), so recording the ckpt name + # makes fused-expert params (w13_weight/w2_weight) + # falsely show up as "NOT loaded". + loaded_weights_record.add(prefix + name_mapped) break if matched: diff --git a/atom/model_ops/attentions/aiter_attention.py b/atom/model_ops/attentions/aiter_attention.py index 078fed004f..6c6ca3cf53 100644 --- a/atom/model_ops/attentions/aiter_attention.py +++ b/atom/model_ops/attentions/aiter_attention.py @@ -86,7 +86,18 @@ def __init__( else: max_qlen = 1 - num_head_k = max(1, hf_config.num_key_value_heads // get_tp_group().world_size) + num_kv_heads_cfg = getattr( + hf_config, + "num_key_value_heads", + getattr(hf_config, "num_attention_groups", None), + ) + if num_kv_heads_cfg is None: + raise ValueError( + "Model config has neither 'num_key_value_heads' nor " + "'num_attention_groups'; cannot determine number of KV heads " + f"for {getattr(hf_config, 'architectures', hf_config)}" + ) + num_head_k = max(1, num_kv_heads_cfg // get_tp_group().world_size) ( (work_meta_data_size, work_meta_data_type), (work_indptr_size, work_indptr_type), @@ -236,9 +247,18 @@ def set_aiter_persistent_worker_buffers(self, bs: int): config = self.model_runner.config hf_config = config.hf_config num_query_heads = self.num_attention_heads - num_kv_heads = max( - 1, hf_config.num_key_value_heads // get_tp_group().world_size + num_kv_heads_cfg = getattr( + hf_config, + "num_key_value_heads", + getattr(hf_config, "num_attention_groups", None), ) + if num_kv_heads_cfg is None: + raise ValueError( + "Model config has neither 'num_key_value_heads' nor " + "'num_attention_groups'; cannot determine number of KV heads " + f"for {getattr(hf_config, 'architectures', hf_config)}" + ) + num_kv_heads = max(1, num_kv_heads_cfg // get_tp_group().world_size) block_size = self.block_size var = self.model_runner.forward_vars @@ -884,9 +904,18 @@ def _set_ubatch_pa_buffers(self, padded_bs, max_q_len, ubatch_idx): config = self.model_runner.config hf_config = config.hf_config num_query_heads = self.num_attention_heads - num_kv_heads = max( - 1, hf_config.num_key_value_heads // get_tp_group().world_size + num_kv_heads_cfg = getattr( + hf_config, + "num_key_value_heads", + getattr(hf_config, "num_attention_groups", None), ) + if num_kv_heads_cfg is None: + raise ValueError( + "Model config has neither 'num_key_value_heads' nor " + "'num_attention_groups'; cannot determine number of KV heads " + f"for {getattr(hf_config, 'architectures', hf_config)}" + ) + num_kv_heads = max(1, num_kv_heads_cfg // get_tp_group().world_size) p = f"ub{ubatch_idx}_" var = self.model_runner.forward_vars diff --git a/atom/model_ops/layernorm.py b/atom/model_ops/layernorm.py index 624b2b0f84..13bb2a1c10 100644 --- a/atom/model_ops/layernorm.py +++ b/atom/model_ops/layernorm.py @@ -667,6 +667,16 @@ def forward_cuda( x: torch.Tensor, residual: torch.Tensor | None = None, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: + # Contiguity guard (merge-regression fix): the aiter HIP fused-kernel path + # below uses x.view(), which requires a contiguous input. QK-norm feeds a + # non-contiguous GQA slice here (step3p5 torch.split -> reshape keeps the + # qkv row stride, e.g. (1280,128,1) for an 8x128 q), so x.view(-1, head_dim) + # raises "Cannot view a tensor ...". Fall back to the pre-merge native math + # for non-contiguous inputs; contiguous callers keep main's fast HIP kernel. + # Under Dynamo, FakeTensor.is_contiguous() resolves to a concrete bool from + # static strides, so this short-circuits at trace time before the .view(). + if not x.is_contiguous(): + return self.forward_native(x, residual) # Use the aiter HIP fused_qk_rmsnorm_group_quant kernel in no-quant mode # (q_out_scale=None) to perform Gemma RMSNorm + optional residual add. # Same math as the Triton kernel: out = rmsnorm(x [+ residual]) * (1 + w), diff --git a/atom/model_ops/moe.py b/atom/model_ops/moe.py index e002923348..f11f2b20f7 100644 --- a/atom/model_ops/moe.py +++ b/atom/model_ops/moe.py @@ -581,7 +581,43 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: layer.w13_weight = atom_parameter(self._maybe_pad_weight(layer.w13_weight.data)) layer.w2_weight = atom_parameter(self._maybe_pad_weight(layer.w2_weight.data)) - # reshaping weights is required for aiter moe kernel. + + # gfx950 CK a16w16 stage2 requires inter_dim % 64 == 0. + # For tp=4 (inter=320) and tp=8 (inter=160), pad inter_dim up to the + # next multiple of 64. Zero padding is safe because fused_moe clips + # routed-weight contributions and zero-padded rows contribute nothing. + # Verified 2026-04-24: cos_sim >= 0.9999 for inter=160->192 and + # inter=320->384 vs torch reference. + w13 = layer.w13_weight.data # [E, 2*inter, hidden] + w2 = layer.w2_weight.data # [E, hidden, inter] + inter_dim = w2.shape[2] + # Stage1 dispatch: inter<=192 uses NPerBlock=64, inter>192 uses NPerBlock=128. + # Stage2 dispatch: inter>192 uses KPerBlock=64. + # So required alignment: 64 when inter<=192, 128 when inter>192. + # (inter=160->192 satisfies 192%64=0; inter=320->384 satisfies 384%128=0 and 384%64=0) + align = 64 if inter_dim <= 192 else 128 + inter_pad = (inter_dim + align - 1) // align * align + if inter_pad != inter_dim: + E, _, hidden = w13.shape + # pad w13: gate half [E, inter, hidden] and up half [E, inter, hidden] + w13_new = torch.zeros( + E, 2 * inter_pad, hidden, dtype=w13.dtype, device=w13.device + ) + w13_new[:, :inter_dim, :] = w13[:, :inter_dim, :] # gate + w13_new[:, inter_pad : inter_pad + inter_dim, :] = w13[ + :, inter_dim:, : + ] # up + # pad w2: [E, hidden, inter_pad] + w2_new = torch.zeros(E, hidden, inter_pad, dtype=w2.dtype, device=w2.device) + w2_new[:, :, :inter_dim] = w2 + layer.w13_weight = atom_parameter(w13_new) + layer.w2_weight = atom_parameter(w2_new) + + # Shuffle weights for CK/ASM kernels. + # Previously skipped for gfx950 bf16 g1u1 on the assumption that the CK + # 2-stage preshuffle_off (NSwizzle=0) kernel expected un-shuffled weights. + # Verified 2026-04-23: preshuffle_off GEMM is wrong on gfx950; preshuffle_on + # (NSwizzle=1) is correct. Always shuffle so the right kernel path is used. shuffle_weights(layer.w13_weight, layer.w2_weight) def get_fused_moe_quant_config( @@ -1743,25 +1779,38 @@ def create_weights( block_n = 1 block_k = 32 tp_size = get_tp_group().world_size + # Pad intermediate_size_per_partition to the nearest multiple of + # block_n so that both the CK blockscale kernel alignment constraints + # and the block-quantization scale alignment constraints are satisfied. + # For tp=4, inter_dim=320 is not divisible by block_n=128; padding to + # 384 (=3×128) satisfies both stage1 NPerBlock=128 and block_n=128. + # The weight loader already supports partial loading into a padded + # buffer via the MXFP4 alignment path (_load_w13 / _load_w2). + padded_inter = ( + (intermediate_size_per_partition + block_n - 1) // block_n * block_n + ) # NOTE: To ensure proper alignment of the block-wise quantization # scales, the output_size of the weights for both the gate and up # layers must be divisible by block_n. # Required by column parallel or enabling merged weights - if intermediate_size_per_partition % block_n != 0: + if padded_inter % block_n != 0: raise ValueError( f"The output_size of gate's and up's weight = " - f"{intermediate_size_per_partition} is not divisible by " + f"{padded_inter} is not divisible by " f"weight quantization block_n = {block_n}." ) - if tp_size > 1 and intermediate_size_per_partition % block_k != 0: + if tp_size > 1 and padded_inter % block_k != 0: # Required by row parallel raise ValueError( f"The input_size of down's weight = " - f"{intermediate_size_per_partition} is not divisible by " + f"{padded_inter} is not divisible by " f"weight quantization block_k = {block_k}." ) # WEIGHTS + # Allocated at original (un-padded) size; inter_dim padding is applied + # later in _process_block_quant (after normalize, before shuffle_weights), + # mirroring the BF16 approach in UnquantizedFusedMoEMethod. w13_weight = atom_parameter( torch.empty( num_experts, @@ -1881,6 +1930,39 @@ def _process_block_quant(self, layer: nn.Module) -> None: assert self.quant_config.is_dynamic self._normalize_weights_and_scales(layer) + # Inter-dim padding for block-quantized FP8 (mirrors BF16 approach in + # UnquantizedFusedMoEMethod.process_weights_after_loading). + # When inter_dim is not a multiple of block_n (e.g. tp=4: 320 % 128 ≠ 0), + # zero-pad both weights to the nearest block_n multiple BEFORE shuffling. + # Padding area is zero so dequant(0, scale) = 0 is numerically safe. + # Scale tensors use ceil(inter/block_n) and are already shape-compatible. + inter_dim = layer.w2_weight.shape[-1] + block_n = 128 if self.quant_type == QuantType.per_1x128 else 32 + # FP8 blockscale stage2 requires KPerBlock=128 (gfx950 FP8 mfma KPack=32 constraint + # prevents KPerBlock=64). align must always be block_n(=128) so that inter_pad%128==0. + # Bug fix: previously used align=64 for inter<=192 (copied from BF16 path), but + # 192%128=64!=0 → stage2 kernel dispatch fails. Correct: always align to block_n. + # tp=8 inter=160 → 256 (3×128→no, ceil(160/128)*128=256); tp=4 inter=320 → 384. + align = block_n + inter_pad = (inter_dim + align - 1) // align * align + if inter_pad != inter_dim: + E = layer.w13_weight.shape[0] + hidden = layer.w13_weight.shape[-1] + w13 = layer.w13_weight.data + w13_new = torch.zeros( + E, 2 * inter_pad, hidden, dtype=w13.dtype, device=w13.device + ) + w13_new[:, :inter_dim, :] = w13[:, :inter_dim, :] # gate + w13_new[:, inter_pad : inter_pad + inter_dim, :] = w13[ + :, inter_dim:, : + ] # up + layer.w13_weight = atom_parameter(w13_new) + + w2 = layer.w2_weight.data + w2_new = torch.zeros(E, hidden, inter_pad, dtype=w2.dtype, device=w2.device) + w2_new[:, :, :inter_dim] = w2 + layer.w2_weight = atom_parameter(w2_new) + if not self.need_normalize_e4m3fn_to_e4m3fnuz: layer.w13_weight = atom_parameter(layer.w13_weight.data) layer.w13_weight_scale = atom_parameter(layer.w13_weight_scale.data) @@ -2726,10 +2808,38 @@ def _load_w13( expert_shard_size = expert_data.shape[shard_dim] // 2 # Derive shard size from loaded_weight (unpadded checkpoint) to avoid # out-of-bounds when expert_data is padded (e.g. MXFP4 alignment). - load_shard_size = loaded_weight.shape[shard_dim] // self.tp_size - loaded_weight = loaded_weight.narrow( - shard_dim, load_shard_size * tp_rank, load_shard_size - ) + # Use ceil so that the last partial scale block (e.g. per_1x128 with + # inter=1280 and tp=4: 10 blocks / 4 = 2.5 → ceil=3) is included. + # Without ceil, the 3rd scale block is never copied and stays at the + # torch.ones() initial value of 1.0, causing ~5000× dequant error. + load_shard_size = ( + loaded_weight.shape[shard_dim] + self.tp_size - 1 + ) // self.tp_size + start = load_shard_size * tp_rank + # When D < tp_size (e.g. per_1x128 scale block count smaller than + # tp_size, observed at tp=8 with inter=1280 → D=10), the ceil split + # gives some trailing ranks start >= D so they hold no slice of the + # loaded tensor. Skip narrow + copy_ for those ranks; the rank's + # slice of expert_data stays at its initialised value (0 for weight, + # 1.0 for scale) and the rank contributes a no-op to the column + # gather / row reduction. + if start >= loaded_weight.shape[shard_dim]: + # FP8 scale tensors are torch.ones() initialised. If we leave the + # trailing rank's slice at 1.0, the downstream FP8 dequant multiplies + # the (uninitialised) fp8 weight by 1.0 instead of the correct + # quantization scale, contaminating the column gather / row reduction. + # Zero the slot so dequant produces 0 and the rank contributes a + # true no-op (matches MXFP4 scale init at moe.py:776,813). + if expert_data.dtype == torch.float32: + if shard_id == "w1": + expert_data.narrow(shard_dim, 0, expert_shard_size).zero_() + else: + expert_data.narrow( + shard_dim, expert_shard_size, expert_shard_size + ).zero_() + return + size = min(load_shard_size, loaded_weight.shape[shard_dim] - start) + loaded_weight = loaded_weight.narrow(shard_dim, start, size) # Narrow parameter and load. # w1, gate_proj: Load into first logical weight of w13. if shard_id == "w1": @@ -2740,10 +2850,16 @@ def _load_w13( expert_data = expert_data.narrow( shard_dim, expert_shard_size, expert_shard_size ) - # When expert_data is padded beyond the actual weight size, narrow to - # the loaded weight size so the copy shape matches. - if load_shard_size != expert_shard_size: - expert_data = expert_data.narrow(shard_dim, 0, load_shard_size) + # Narrow expert_data to the actually-loaded `size` so copy_ matches + # loaded_weight, and zero any remainder of this rank's slot (the + # trailing partial rank where size < load_shard_size, or padded + # expert_data). Without this, a non-evenly-divisible split (e.g. tp=4 + # with D=10) hits a copy_ shape mismatch and leaves the tail at its + # init value. No-op for tp=8 (D divides evenly; size == slot). + slot = expert_data.shape[shard_dim] + if size < slot: + expert_data.narrow(shard_dim, size, slot - size).zero_() + expert_data = expert_data.narrow(shard_dim, 0, size) if expert_data.dtype != dtypes.fp4x2: # Dtype glue: V4 stores per-1x32 weight scales as float8_e8m0fnu but # FusedMoE allocates them as uint8 (raw byte storage). PyTorch's @@ -2782,12 +2898,34 @@ def _load_w2( # down_proj: "RowParallel" so tp sharding on input_dim # Narrow parameter and load. shard_size = expert_data.shape[shard_dim] - load_shard_size = loaded_weight.shape[shard_dim] // self.tp_size - loaded_weight = loaded_weight.narrow( - shard_dim, load_shard_size * tp_rank, load_shard_size - ) - if load_shard_size != shard_size: - expert_data = expert_data.narrow(shard_dim, 0, load_shard_size) + if not load_full: + # Derive shard size from loaded_weight (unpadded checkpoint) to + # avoid out-of-bounds when expert_data is padded (e.g. MXFP4). + # Use ceil (same reason as _load_w13: partial last scale block). + load_shard_size = ( + loaded_weight.shape[shard_dim] + self.tp_size - 1 + ) // self.tp_size + start = load_shard_size * tp_rank + # See _load_w13 comment above: when D < tp_size the ceil split + # leaves trailing ranks with no slice; skip narrow + copy_. + if start >= loaded_weight.shape[shard_dim]: + # Zero the scale slice so dequant=0 instead of multiplying by + # stale init=1.0; see _load_w13 comment for full rationale. + if expert_data.dtype == torch.float32: + if load_shard_size != shard_size: + expert_data.narrow(shard_dim, 0, load_shard_size).zero_() + else: + expert_data.zero_() + return + size = min(load_shard_size, loaded_weight.shape[shard_dim] - start) + loaded_weight = loaded_weight.narrow(shard_dim, start, size) + # Narrow expert_data to the actually-loaded `size` so copy_ matches + # loaded_weight, and zero any remainder of this rank's slot (see + # _load_w13 for full rationale). No-op for tp=8 (size == slot). + slot = expert_data.shape[shard_dim] + if size < slot: + expert_data.narrow(shard_dim, size, slot - size).zero_() + expert_data = expert_data.narrow(shard_dim, 0, size) # w2, down_proj: Load into only logical weight of w2. if expert_data.dtype == dtypes.fp4x2: expert_data.view(torch.uint8).copy_(loaded_weight.view(torch.uint8)) diff --git a/atom/models/step3p5.py b/atom/models/step3p5.py new file mode 100644 index 0000000000..2c279917cd --- /dev/null +++ b/atom/models/step3p5.py @@ -0,0 +1,890 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + +"""Inference-only Step-3.5 (Flash) model. + +Step-3.5 is a sparse MoE transformer with: + - 45 decoder layers, hidden_size=4096, head_dim=128 + - GQA with two attention configs: full_attention (64 q heads, 8 kv groups) + and sliding_attention (96 q heads, 8 kv groups, window=512) + - 3:1 sliding window pattern (1 full + 3 sliding) + - Per-layer rope_theta and partial_rotary_factor + - QK RMSNorm (zero-centered, i.e. weight * (1 + param)) + - Head-wise attention gating via g_proj (sigmoid) + - MoE on layers 3-44: 288 routed experts + 1 shared expert, top-8, + sigmoid routing with learnable router bias + - Dense MLP on layers 0-2 + - Per-layer SwiGLU clamp limits + - Multi-token prediction (MTP) config (num_nextn_predict_layers=3) is present + but NOT implemented here (not needed for standard inference). +""" + +import os +from typing import Optional, Union + +import torch +from aiter import ActivationType +from aiter.dist.parallel_state import get_pp_group, get_tensor_model_parallel_world_size +from aiter.rotary_embedding import get_rope +from atom.config import Config, QuantizationConfig +from atom.model_ops.activation import SiluAndMul +from atom.model_ops.base_attention import Attention +from atom.model_ops.embed_head import ParallelLMHead, VocabParallelEmbedding +from atom.model_ops.layernorm import GemmaRMSNorm as Step3p5RMSNorm +from atom.model_ops.linear import ( + ColumnParallelLinear, + MergedColumnParallelLinear, + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear, +) +from atom.model_ops.moe import FusedMoE +from atom.model_ops.topK import is_rocm_aiter_fusion_shared_expert_enabled +from atom.models.utils import ( + IntermediateTensors, + PPMissingLayer, + extract_layer_index, + make_empty_intermediate_tensors_factory, + make_layers, + maybe_prefix, +) +from atom.utils.decorators import support_torch_compile +from torch import nn +from transformers import PretrainedConfig + + +def _uses_swiglustep_at_layer( + config: PretrainedConfig, layer_idx: Optional[int] +) -> bool: + """Return True iff the routed FusedMoE at this layer needs the SwigluStep + activation (i.e. ``swiglu_limits[layer_idx] > 0``). + + The CK kernel hard-codes the clamp at 7.0; Step-3.5-Flash uses 7.0 at + layers 43 and 44, which is why the kernel is only valid at those layers. + Other layers must keep the plain Silu path. + """ + if layer_idx is None: + return False + # Toggle-off bit: ATOM_DISABLE_SWIGLUSTEP=1 forces plain Silu at every + # layer (verification helper only). + if os.environ.get("ATOM_DISABLE_SWIGLUSTEP"): + return False + swiglu_limits = getattr(config, "swiglu_limits", None) + if not swiglu_limits or layer_idx >= len(swiglu_limits): + return False + return bool(swiglu_limits[layer_idx]) + + +def _fuse_shared_at_layer(config: PretrainedConfig, layer_idx: Optional[int]) -> bool: + """Whether to fuse the shared expert into the routed FusedMoE at this layer. + + R5 mitigation: at SwigluStep layers the kernel clamps every expert at 7.0, + but the shared expert may use a different clamp (e.g. 16 at layer 44 or 0 + at layer 43). Therefore the shared expert MUST stay on the dense path at + every SwigluStep layer, even when the global aiter fusion is enabled. + """ + # ATOM_FORCE_FUSE_SHARED=1 always fuses the shared expert into the + # routed kernel (verification helper: bypass R5 mitigation). + if os.environ.get("ATOM_FORCE_FUSE_SHARED"): + return is_rocm_aiter_fusion_shared_expert_enabled() + return ( + is_rocm_aiter_fusion_shared_expert_enabled() + and not _uses_swiglustep_at_layer(config, layer_idx) + ) + + +# --------------------------------------------------------------------------- +# MLP (dense, used for first few layers and shared expert) +# --------------------------------------------------------------------------- + + +class Step3p5MLP(nn.Module): + """Dense SwiGLU MLP with optional activation clamping.""" + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + reduce_results: bool = True, + clamp_limit: Optional[float] = None, + ) -> None: + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + input_size=hidden_size, + output_sizes=[intermediate_size] * 2, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.gate_up_proj", + ) + self.down_proj = RowParallelLinear( + input_size=intermediate_size, + output_size=hidden_size, + bias=False, + quant_config=quant_config, + reduce_results=reduce_results, + prefix=f"{prefix}.down_proj", + ) + self.act_fn = SiluAndMul() + # 0.0 means no clamping (disabled), only apply if > 0 + self.clamp_limit = ( + clamp_limit if (clamp_limit is not None and clamp_limit > 0) else None + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.gate_up_proj(x) + if self.clamp_limit is not None: + # Match HF: clamp AFTER silu activation on gate, symmetric on up + # gate_proj output is first half, up_proj output is second half + half = x.shape[-1] // 2 + gate, up = x[..., :half], x[..., half:] + gate = torch.nn.functional.silu(gate).clamp(max=self.clamp_limit) + up = up.clamp(min=-self.clamp_limit, max=self.clamp_limit) + x = self.down_proj(gate * up) + else: + x = self.act_fn(x) + x = self.down_proj(x) + return x + + +# --------------------------------------------------------------------------- +# MoE block (routed experts + shared expert) +# --------------------------------------------------------------------------- + + +class Step3p5MoE(nn.Module): + """Sparse MoE block for Step-3.5. + + Checkpoint weight layout under ``layers.{i}.moe.*``: + - gate.weight (router linear) + - router_bias (learnable additive bias for sigmoid routing) + - gate_proj.weight (per-expert, shape [num_experts, intermediate, hidden]) + - up_proj.weight (per-expert) + - down_proj.weight (per-expert) + + The FusedMoE kernel maps these via ``get_expert_mapping``. + """ + + def __init__( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + + num_experts: int = config.moe_num_experts + top_k: int = config.moe_top_k + moe_intermediate_size: int = config.moe_intermediate_size + + # Per-layer SwiGLU clamp limit for routed experts. + # Step-3.5 applies clamp(x, -limit, limit) after gate_up_proj and + # before the SwiGLU activation inside each expert. The CK kernel + # implements this as ``ActivationType.SwigluStep`` with a hard-coded + # ±7 clamp; Step-3.5-Flash uses 7 at layers 43-44 only. + layer_idx = extract_layer_index(prefix) if prefix else None + self._layer_idx = layer_idx + swiglu_limits = getattr(config, "swiglu_limits", None) + self.clamp_limit = ( + swiglu_limits[layer_idx] + if ( + swiglu_limits and layer_idx is not None and swiglu_limits[layer_idx] > 0 + ) + else None + ) + self._uses_swiglustep = self.clamp_limit is not None + self._activation = ( + ActivationType.SwigluStep if self._uses_swiglustep else ActivationType.Silu + ) + + # Router --------------------------------------------------------- + self.gate = ReplicatedLinear( + self.hidden_size, + num_experts, + bias=False, + quant_config=None, + prefix=f"{prefix}.gate", + ) + + # Learnable router bias (added to sigmoid probs before top-k) + self.router_bias = nn.Parameter( + torch.zeros(num_experts, dtype=torch.float32), + requires_grad=False, + ) + + self.routed_scaling_factor = getattr(config, "moe_router_scaling_factor", 1.0) + self._need_fp32_gate = getattr(config, "need_fp32_gate", False) + + # Routed experts (fused MoE kernel) -------------------------------- + # R5 mitigation: at SwigluStep layers we MUST NOT fuse the shared + # expert into the routed FusedMoE (the kernel hard-codes ±7 clamp, + # but the shared expert uses a different limit, e.g. 16 at layer 44 + # or 0 at layer 43). Fall back to the dense Step3p5MLP path there. + self._fuse_shared = _fuse_shared_at_layer(config, layer_idx) + n_shared = 1 if self._fuse_shared else 0 + self._n_shared_fused = ( + n_shared # 1 when shared expert is fused as expert num_experts + ) + self.experts = FusedMoE( + num_experts=num_experts + n_shared, + top_k=top_k + n_shared, # +1 so kernel selects top_k routed + 1 shared + hidden_size=self.hidden_size, + intermediate_size=moe_intermediate_size, + reduce_results=True, + renormalize=True, + quant_config=quant_config, + prefix=f"{prefix}.experts", + custom_routing_function=self._routing_function, + config=config, + activation=self._activation, + ) + + def _routing_function( + self, + gating_output: torch.Tensor, + topk: int, + renormalize: bool, + hidden_states: Optional[torch.Tensor] = None, + ): + """Sigmoid routing with additive bias and scaling. + + ``hidden_states`` is accepted for compatibility with the framework's + custom-routing calling convention (used by per-token hash routing in + other models); Step-3.5 routing is computed from ``gating_output`` only + and ignores it. + + When the shared expert is fused (self._n_shared_fused == 1), topk is + top_k_routed + 1. We select top_k_routed routed experts and append + the shared expert (index num_routed_experts) with weight 1.0. + """ + n_shared = self._n_shared_fused + top_k_routed = topk - n_shared # number of routed experts to pick + + gate_prob = torch.sigmoid(gating_output.float()) + gate_prob_biased = gate_prob + self.router_bias.unsqueeze(0) + _, indices = torch.topk(gate_prob_biased, k=top_k_routed, dim=1) + topk_prob = torch.gather(gate_prob, 1, indices) + if renormalize: + topk_prob = topk_prob / (topk_prob.sum(dim=-1, keepdim=True) + 1e-20) + topk_prob = topk_prob * self.routed_scaling_factor + + if n_shared > 0: + # Append shared expert (always selected, weight=1.0) + T = gating_output.shape[0] + num_routed = gating_output.shape[1] # 288 + shared_ids = torch.full( + (T, n_shared), + num_routed, + dtype=torch.int32, + device=gating_output.device, + ) + shared_weights = torch.ones( + (T, n_shared), dtype=torch.float32, device=gating_output.device + ) + topk_prob = torch.cat([topk_prob, shared_weights], dim=1) + indices = torch.cat([indices.to(torch.int32), shared_ids], dim=1) + return topk_prob, indices + + return topk_prob, indices.to(torch.int32) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + orig_shape = hidden_states.shape + hidden_states = hidden_states.view(-1, self.hidden_size) + + # Router logits must be computed in fp32 (need_fp32_gate=True in config). + # self._need_fp32_gate is always set in __init__ (default False when the + # config field is absent), so read it directly. + if self._need_fp32_gate: + router_logits = torch.nn.functional.linear( + hidden_states.float(), self.gate.weight.float() + ) + else: + router_logits = self.gate(hidden_states) + + # Routed experts. At SwigluStep layers (43-44) the FusedMoE was + # constructed with ``activation=ActivationType.SwigluStep`` so the CK + # kernel applies ``silu(g).clamp(max=7) * up.clamp(±7)`` per expert. + routed_out = self.experts(hidden_states, router_logits) + + return routed_out.view(orig_shape) + + +# --------------------------------------------------------------------------- +# Attention +# --------------------------------------------------------------------------- + + +class Step3p5Attention(nn.Module): + """GQA attention for Step-3.5. + + Key differences from vanilla LLaMA attention: + - Per-layer rope_theta and partial_rotary_factor (from config lists). + - Two attention head configurations depending on full vs sliding. + - QK RMSNorm (zero-centered / GemmaRMSNorm style). + - Head-wise attention gating via g_proj (sigmoid). + """ + + def __init__( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + cache_config: str = "bf16", + prefix: str = "", + layer_num: int = 0, + ) -> None: + super().__init__() + layer_idx = extract_layer_index(prefix) + self.hidden_size = config.hidden_size + + # Determine layer type and head counts ---------------------------- + layer_types = getattr(config, "layer_types", []) + is_sliding = ( + layer_types[layer_idx] == "sliding_attention" if layer_types else False + ) + attn_other = getattr(config, "attention_other_setting", None) + + if is_sliding and attn_other is not None: + self.total_num_heads = attn_other["num_attention_heads"] + self.total_num_kv_heads = attn_other["num_attention_groups"] + else: + self.total_num_heads = config.num_attention_heads + self.total_num_kv_heads = config.num_attention_groups + + self.head_dim = getattr(config, "head_dim", 128) + + tp_size = get_tensor_model_parallel_world_size() + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + + # RoPE configuration ----------------------------------------------- + rope_theta_cfg = getattr(config, "rope_theta", 10000.0) + if isinstance(rope_theta_cfg, list): + rope_theta = rope_theta_cfg[layer_idx] + else: + rope_theta = rope_theta_cfg + + partial_rotary_factors = getattr(config, "partial_rotary_factors", None) + if partial_rotary_factors is not None: + partial_rotary_factor = partial_rotary_factors[layer_idx] + else: + partial_rotary_factor = 1.0 + + rotary_dim = int(self.head_dim * partial_rotary_factor) + + max_position_embeddings = getattr(config, "max_position_embeddings", 262144) + + # Determine rope_scaling for this layer + rope_scaling = getattr(config, "rope_scaling", None) + yarn_only_types = getattr(config, "yarn_only_types", None) + if yarn_only_types and layer_types: + layer_type = layer_types[layer_idx] + if layer_type not in yarn_only_types: + rope_scaling = None + + # Projections ------------------------------------------------------- + self.qkv_proj = QKVParallelLinear( + hidden_size=self.hidden_size, + head_size=self.head_dim, + total_num_heads=self.total_num_heads, + total_num_kv_heads=self.total_num_kv_heads, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + self.o_proj = RowParallelLinear( + input_size=self.total_num_heads * self.head_dim, + output_size=self.hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + ) + + # QK Norm (zero-centered RMSNorm) ----------------------------------- + rms_norm_eps = getattr(config, "rms_norm_eps", 1e-5) + self.q_norm = Step3p5RMSNorm(self.head_dim, eps=rms_norm_eps) + self.k_norm = Step3p5RMSNorm(self.head_dim, eps=rms_norm_eps) + + # Head-wise attention gate ------------------------------------------- + self.use_head_wise_attn_gate = getattr(config, "use_head_wise_attn_gate", False) + if self.use_head_wise_attn_gate: + self.g_proj = ColumnParallelLinear( + input_size=self.hidden_size, + output_size=self.total_num_heads, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.g_proj", + ) + + # Rotary embedding --------------------------------------------------- + # Note: rotary_dim is already computed as head_dim * partial_rotary_factor, + # so we do NOT pass partial_rotary_factor to get_rope (which would apply it twice). + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=rotary_dim, + max_position=max_position_embeddings, + base=rope_theta, + rope_scaling=rope_scaling, + is_neox_style=True, + ) + + # Sliding window and sink tokens per layer --------------------------- + sliding_window = None + sinks = None + if is_sliding: + sliding_window = getattr(config, "sliding_window", None) + # config.sink may be False (bool) or an int; only a positive int + # enables attention sinks (False/0 -> no sinks). + sink_size = getattr(config, "sink", 0) + if sink_size > 0: + sinks = nn.Parameter(torch.zeros(self.num_heads, requires_grad=False)) + + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + kv_cache_dtype=cache_config, + layer_num=layer_num, + per_layer_sliding_window=sliding_window, + sinks=sinks, + prefix=f"{prefix}.attn", + rotary_emb=self.rotary_emb, + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + qkv = self.qkv_proj(hidden_states) + q, k, v = torch.split(qkv, [self.q_size, self.kv_size, self.kv_size], dim=-1) + + # QK Norm – apply per-head RMSNorm + # Reshape to (..., num_heads, head_dim), apply norm, reshape back + q = self.q_norm(q.reshape(*q.shape[:-1], -1, self.head_dim)).flatten(-2) + k = self.k_norm(k.reshape(*k.shape[:-1], -1, self.head_dim)).flatten(-2) + + attn_output = self.attn(q, k, v, positions) + + # Head-wise gating + if self.use_head_wise_attn_gate: + gate = self.g_proj(hidden_states) # (tokens, num_heads_tp) + # gate: (tokens, num_heads_tp) -> (tokens, num_heads_tp, 1) + gate = torch.sigmoid(gate).unsqueeze(-1) + reshaped = attn_output.reshape(*attn_output.shape[:-1], -1, self.head_dim) + attn_output = (reshaped * gate).flatten(-2) + + output = self.o_proj(attn_output) + return output + + +# --------------------------------------------------------------------------- +# Decoder Layer +# --------------------------------------------------------------------------- + + +class Step3p5DecoderLayer(nn.Module): + """Single decoder layer for Step-3.5. + + - Layers 0-2: dense MLP + - Layers 3-44: MoE (288 routed + 1 shared) + """ + + def __init__( + self, + config: PretrainedConfig, + cache_config: str = "bf16", + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + layer_num: int = 0, + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + layer_idx = extract_layer_index(prefix) + + # Attention + self.self_attn = Step3p5Attention( + config=config, + quant_config=quant_config, + cache_config=cache_config, + prefix=f"{prefix}.self_attn", + layer_num=layer_num, + ) + + # FFN: dense MLP or MoE depending on layer index + moe_layers_enum = getattr(config, "moe_layers_enum", None) + if moe_layers_enum is not None: + if isinstance(moe_layers_enum, str): + moe_layers_idx = [int(i) for i in moe_layers_enum.strip().split(",")] + else: + moe_layers_idx = list(moe_layers_enum) + else: + moe_layers_idx = list(range(3, config.num_hidden_layers)) + + self.is_moe_layer = layer_idx in moe_layers_idx + + # Per-layer SwiGLU clamp limits + swiglu_limits_shared = getattr(config, "swiglu_limits_shared", None) + clamp_limit_shared = ( + swiglu_limits_shared[layer_idx] if swiglu_limits_shared else None + ) + + if self.is_moe_layer: + self.moe = Step3p5MoE( + config=config, + quant_config=quant_config, + prefix=f"{prefix}.moe", + ) + # Shared expert (always active, sibling of moe in checkpoint). + # Per-layer fuse decision: SwigluStep layers must keep the shared + # expert on the dense path because the routed CK kernel hard-codes + # the clamp at 7 (see _fuse_shared_at_layer). + if not _fuse_shared_at_layer(config, layer_idx): + self.share_expert = Step3p5MLP( + hidden_size=self.hidden_size, + intermediate_size=config.share_expert_dim, + quant_config=quant_config, + prefix=f"{prefix}.share_expert", + clamp_limit=clamp_limit_shared, + ) + else: + self.share_expert = None + else: + self.mlp = Step3p5MLP( + hidden_size=self.hidden_size, + intermediate_size=config.intermediate_size, + quant_config=quant_config, + prefix=f"{prefix}.mlp", + clamp_limit=clamp_limit_shared, # HF uses swiglu_limits_shared for dense MLP + ) + + # Layer norms (zero-centered RMSNorm) + rms_norm_eps = getattr(config, "rms_norm_eps", 1e-5) + self.input_layernorm = Step3p5RMSNorm(config.hidden_size, eps=rms_norm_eps) + self.post_attention_layernorm = Step3p5RMSNorm( + config.hidden_size, eps=rms_norm_eps + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + ) -> tuple[torch.Tensor, torch.Tensor]: + # Self Attention + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm(hidden_states, residual) + + hidden_states = self.self_attn(positions=positions, hidden_states=hidden_states) + + # FFN + hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) + + if self.is_moe_layer: + moe_output = self.moe(hidden_states) + if self.share_expert is not None: + shared_output = self.share_expert(hidden_states) + hidden_states = moe_output + shared_output + else: + hidden_states = moe_output + else: + hidden_states = self.mlp(hidden_states) + + return hidden_states, residual + + +# --------------------------------------------------------------------------- +# Full Model +# --------------------------------------------------------------------------- + + +@support_torch_compile +class Step3p5Model(nn.Module): + def __init__( + self, + atom_config: Config, + prefix: str = "", + ): + super().__init__() + config = atom_config.hf_config + self.config = config + cache_config = atom_config.kv_cache_dtype + quant_config = atom_config.quant_config + self.vocab_size = config.vocab_size + + if get_pp_group().is_first_rank or ( + getattr(config, "tie_word_embeddings", False) + and get_pp_group().is_last_rank + ): + self.embed_tokens = VocabParallelEmbedding( + self.vocab_size, + config.hidden_size, + ) + else: + self.embed_tokens = PPMissingLayer() + + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix, layer_num=None: Step3p5DecoderLayer( + config=config, + cache_config=cache_config, + quant_config=quant_config, + prefix=prefix, + layer_num=layer_num, + ), + prefix=f"{prefix}.layers", + layer_num_offset=0, + ) + + rms_norm_eps = getattr(config, "rms_norm_eps", 1e-5) + if get_pp_group().is_last_rank: + self.norm = Step3p5RMSNorm(config.hidden_size, eps=rms_norm_eps) + else: + self.norm = PPMissingLayer() + + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size + ) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: Optional[torch.Tensor], + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + + for layer in self.layers[self.start_layer : self.end_layer]: + hidden_states, residual = layer(positions, hidden_states, residual) + + if not get_pp_group().is_last_rank: + return IntermediateTensors( + {"hidden_states": hidden_states, "residual": residual} + ) + + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + +# --------------------------------------------------------------------------- +# CausalLM wrapper +# --------------------------------------------------------------------------- + + +class Step3p5ForCausalLM(nn.Module): + """Step-3.5 model with language modelling head.""" + + packed_modules_mapping = { + "q_proj": ("qkv_proj", "q"), + "k_proj": ("qkv_proj", "k"), + "v_proj": ("qkv_proj", "v"), + "gate_proj": ("gate_up_proj", 0), + "up_proj": ("gate_up_proj", 1), + } + + def __init__( + self, + atom_config: Config, + prefix: str = "", + ): + super().__init__() + config = atom_config.hf_config + self.config = config + + self.model = Step3p5Model( + atom_config=atom_config, + prefix=maybe_prefix(prefix, "model"), + ) + + if get_pp_group().is_last_rank: + self.unpadded_vocab_size = config.vocab_size + self.lm_head = ParallelLMHead( + self.unpadded_vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + prefix=maybe_prefix(prefix, "lm_head"), + ) + if getattr(config, "tie_word_embeddings", False): + self.lm_head.weight = self.model.embed_tokens.weight + else: + self.lm_head = PPMissingLayer() + + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors + ) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + model_output = self.model( + input_ids, positions, intermediate_tensors, inputs_embeds + ) + return model_output + + def compute_logits( + self, + hidden_states: torch.Tensor, + ) -> Optional[torch.Tensor]: + logits = self.lm_head(hidden_states) + return logits + + def detect_fused_expert_format(self, weight_name: str) -> bool: + """Step-3.5-Flash expert weights are flat: moe.gate_proj.weight [E, I, H]. + When shared expert fusion is enabled, share_expert weights are also loaded + as expert N in FusedMoE; otherwise they are loaded as a regular MLP. + + Per-layer override: at SwigluStep layers (43-44) the shared expert is + not fused, so its weights must take the dense path even when the + global aiter fusion flag is on. + """ + is_routed_expert = ( + ".moe.gate_proj" in weight_name + or ".moe.up_proj" in weight_name + or ".moe.down_proj" in weight_name + ) + if is_routed_expert: + return True + is_share_expert = ( + ".share_expert.gate_proj" in weight_name + or ".share_expert.up_proj" in weight_name + or ".share_expert.down_proj" in weight_name + ) + if is_share_expert: + layer_idx = extract_layer_index(weight_name) + return _fuse_shared_at_layer(self.config, layer_idx) + return False + + def get_fused_expert_mapping(self) -> list[tuple[str, str, str]]: + """Mapping from flat checkpoint names to FusedMoE parameter names. + + Weight names include the '.weight' suffix from the checkpoint so that + the replace() in loader.py produces the correct param name without the + extra '.weight' tail (e.g. 'moe.gate_proj.weight' -> 'moe.experts.w13_weight'). + """ + mapping = [ + ("moe.experts.w13_weight", "moe.gate_proj.weight", "w1"), + ("moe.experts.w13_weight", "moe.up_proj.weight", "w3"), + ("moe.experts.w2_weight", "moe.down_proj.weight", "w2"), + ] + if is_rocm_aiter_fusion_shared_expert_enabled(): + mapping += [ + ("moe.experts.w13_weight", "share_expert.gate_proj.weight", "w1"), + ("moe.experts.w13_weight", "share_expert.up_proj.weight", "w3"), + ("moe.experts.w2_weight", "share_expert.down_proj.weight", "w2"), + ] + return mapping + + def load_fused_expert_weights( + self, + original_name: str, + name: str, + params_dict: dict, + loaded_weight: torch.Tensor, + shard_id: str, + num_experts: int, + ) -> bool: + """Load flat expert weights [E, I, H] into FusedMoE per-expert params. + + For routed experts: loaded_weight is [num_experts, ...], loaded per-expert. + For shared expert: loaded_weight is [I, H] or [H, I], loaded as expert num_experts. + """ + # num_experts from loader may be 0 if hf_config uses non-standard attr name + if num_experts == 0: + num_experts = self.config.moe_num_experts + + if name not in params_dict: + return False + param = params_dict[name] + weight_loader = param.weight_loader + loaded_local_expert = False + + is_share_expert = "share_expert" in original_name + + if is_share_expert: + # Defensive: if this layer keeps the shared expert dense (e.g. + # SwigluStep layers 43-44), do not route it through FusedMoE. + layer_idx = extract_layer_index(original_name) + if not _fuse_shared_at_layer(self.config, layer_idx): + return False + # Shared expert is loaded as expert index num_experts (288) + expert_id = num_experts + try: + success = weight_loader( + param, + loaded_weight, + name, + shard_id, + expert_id, + return_success=True, + ) + if success: + loaded_local_expert = True + except TypeError: + weight_loader(param, loaded_weight, name, shard_id, expert_id) + loaded_local_expert = True + else: + for expert_id in range(num_experts): + try: + success = weight_loader( + param, + loaded_weight[expert_id], + name, + shard_id, + expert_id, + return_success=True, + ) + if success: + loaded_local_expert = True + except TypeError: + weight_loader( + param, loaded_weight[expert_id], name, shard_id, expert_id + ) + loaded_local_expert = True + + return loaded_local_expert + + def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: + """Return the expert parameter mapping for weight loading. + + Note: Step-3.5-Flash uses flat expert weights in the checkpoint + (moe.gate_proj.weight etc.), so get_expert_mapping is used only + as a sentinel to enable the expert loading path in loader.py. + The actual loading is handled by load_fused_expert_weights. + """ + return FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="gate_proj", + ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="up_proj", + num_experts=self.config.moe_num_experts + + (1 if is_rocm_aiter_fusion_shared_expert_enabled() else 0), + ) diff --git a/atom/plugin/vllm/attention/metadata.py b/atom/plugin/vllm/attention/metadata.py index d411514d48..5e68a20715 100644 --- a/atom/plugin/vllm/attention/metadata.py +++ b/atom/plugin/vllm/attention/metadata.py @@ -419,6 +419,14 @@ def __init__( self.block_ratio = 1 sliding_window_sizes: set[tuple[int, int] | None] = set() + # SWA per-layer KV head workaround: + # Initialize to 0 (not self.num_heads_kv): self.num_heads_kv comes from + # ModelConfig.get_num_kv_heads() which may not match per-layer + # Attention.num_kv_heads (e.g. stepfun-Flash-FP8 returns 32 from ModelConfig + # but per-layer = 4). Using self.num_heads_kv as a floor would mask the + # real per-layer values. We rely on the loop below to populate from each + # layer; defensive getattr fallback keeps it 0 if field missing. + max_per_layer_num_kv_heads = 0 layers = get_layers_from_vllm_config(config, AttentionLayerBase, layer_names) for layer in layers.values(): from atom.plugin.vllm.attention.layer import AttentionForVllmMHA @@ -431,6 +439,12 @@ def __init__( sliding_window_sizes.add(sliding_window) else: sliding_window_sizes.add((sliding_window - 1, 0)) + per_layer_kv = getattr(layer, "num_kv_heads", None) + if per_layer_kv is None: + per_layer_kv = getattr(layer.impl, "num_kv_heads", 0) + if per_layer_kv and per_layer_kv > max_per_layer_num_kv_heads: + max_per_layer_num_kv_heads = per_layer_kv + self.swa_max_num_heads_kv = max_per_layer_num_kv_heads while len(sliding_window_sizes) > 0: sliding_window_config = sliding_window_sizes.pop() @@ -585,8 +599,13 @@ def build( token_to_seq, swa_seqlen_for_extend ) fetched_shape = cu_seq_lens[-1].item() + if self.swa_max_num_heads_kv <= 0: + raise RuntimeError( + "SWA is enabled but no per-layer num_kv_heads was found on " + "any attention layer; swa_workspace would be zero-sized." + ) swa_workspace = torch.empty( - (2, fetched_shape, self.num_heads_kv, self.head_dim), + (2, fetched_shape, self.swa_max_num_heads_kv, self.head_dim), dtype=self.vllm_config.model_config.dtype, device=self.device, )