diff --git a/src/megatron/bridge/models/glm_moe_dsa/cross_layer_dsa.py b/src/megatron/bridge/models/glm_moe_dsa/cross_layer_dsa.py new file mode 100644 index 0000000000..52f63926ac --- /dev/null +++ b/src/megatron/bridge/models/glm_moe_dsa/cross_layer_dsa.py @@ -0,0 +1,276 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. + +"""GLM-5.2 DSA *cross-layer index sharing* for the Megatron-Bridge path. + +GLM-5.2 keeps GLM-5.1's ``glm_moe_dsa`` skeleton (MLA + lightning indexer + MoE) but adds +DSA cross-layer index sharing: only "computing"/anchor layers carry indexer weights and +compute the sparse top-k; "skip" layers reuse the most recent computing layer's top-k. +Activated by the HF config field ``index_topk_freq > 1`` (+ ``index_skip_topk_offset``); +GLM-5.1 lacks these fields (freq defaults to 1 -> every layer computes -> plain DSA). + +Implemented entirely on the Megatron-Bridge side (no megatron-core edits): a ``DSAttention`` +subclass that, on skip layers, drops its indexer and reuses the anchor's top-k; plus a cloned +spec builder pointing ``module=`` at the subclass. Mirrors the slime reference +(``slime_plugins/models/glm5/glm5.py``: ``is_skip_topk_layer`` / ``source_compute_layer`` / +the per-microbatch top-k holder / the skip-layer ``delattr``). +""" + +import threading + +import torch +from megatron.core.transformer.enums import AttnMaskType +from megatron.core.transformer.experimental_attention_variant.dsa import ( + DSAIndexerLossAutoScaler, + DSAIndexerLossLoggingHelper, + DSAttention, + FusedDSAIndexerLoss, + unfused_dsa_fn, +) + + +# ---- computing-layer schedule (mirrors slime glm5.py:37-52) ---- +def is_skip_topk_layer(layer_number: int, skip_topk_offset: int, topk_freq: int) -> bool: + """1-indexed Megatron ``layer_number`` reuses a previous layer's top-k when True. + + A layer *computes* its own top-k iff ``max(layer_number - offset, 0) % freq == 0``. + """ + return (max(layer_number - skip_topk_offset, 0) % topk_freq) != 0 + + +def source_compute_layer(layer_number: int, skip_topk_offset: int, topk_freq: int) -> int: + """The computing layer whose ``topk_indices`` a skip layer reuses (walk downward).""" + layer = layer_number + while is_skip_topk_layer(layer, skip_topk_offset, topk_freq): + layer -= 1 + return layer + + +def assert_pp_stage_starts_on_computing_layer(config, vp_stage=None) -> None: + """Build-time guard: a (virtual) pipeline stage must not START on a skip layer. + + The per-microbatch top-k holder does NOT cross pipeline boundaries, so a skip layer's source + computing layer must live in the same PP stage. If a stage's first layer is a skip layer, its + source is on a previous stage -> cross-PP top-k sharing (unsupported). Mirrors slime's + ``get_glm5_spec`` build-time check so a bad PP layout fails at model construction with a clear + message, instead of only at the first forward (the runtime guard in ``CrossLayerDSAttention``). + + No-op unless cross-layer sharing is active (``dsa_index_topk_freq > 1``). If the layer layout + cannot be determined (e.g. parallel state not yet initialised), this silently returns and + leaves the runtime guard as the backstop. + """ + freq = getattr(config, "dsa_index_topk_freq", 1) or 1 + if freq <= 1: + return + offset = getattr(config, "dsa_index_skip_topk_offset", 0) or 0 + try: + from megatron.core.transformer.transformer_block import get_transformer_layer_offset + + layer_offset = get_transformer_layer_offset(config, vp_stage=vp_stage) + except Exception: # noqa: BLE001 - layout not determinable; runtime guard still applies + return + first_layer_number = layer_offset + 1 # Megatron layer_number is 1-indexed + if is_skip_topk_layer(first_layer_number, offset, freq): + src = source_compute_layer(first_layer_number, offset, freq) + raise AssertionError( + "DSA cross-layer index-share: this pipeline stage starts at global " + f"layer_number={first_layer_number}, which is a SKIP layer whose source computing " + f"layer={src} is on a previous pipeline stage. The per-microbatch top-k holder does " + "not cross PP boundaries -- choose a pipeline layout where every (virtual) stage " + f"begins on a computing layer (dsa_index_topk_freq={freq}, " + f"dsa_index_skip_topk_offset={offset})." + ) + + +# Per-microbatch top-k holder. Preferred carrier is the ``packed_seq_params`` object (thd: +# fresh per microbatch + closure-captured by activation-checkpoint custom_forward => recompute +# safe under PP 1F1B), matching slime. With ``--qkv-format bshd`` packed_seq_params is None, so +# we fall back to a thread-local dict keyed by layer_number. The fallback is correct for +# sequentially-executed micro-batches WITHOUT full activation recompute (each micro-batch's +# forward writes the anchor before the in-stage skip layer reads it, and the next micro-batch +# overwrites before its own skip reads). bshd + activation recompute is UNSAFE and is rejected +# at forward time (see the recompute guard in CrossLayerDSAttention.forward) -- use thd there. +# (Skip layers always have their source anchor in the same PP stage; see the runtime assert.) +_HOLDER_ATTR = "_dsa_index_share_topk_holder" +_TLS = threading.local() + + +def _holder(packed_seq_params): + if packed_seq_params is not None: + h = getattr(packed_seq_params, _HOLDER_ATTR, None) + if h is None: + h = {} + setattr(packed_seq_params, _HOLDER_ATTR, h) + return h + h = getattr(_TLS, "holder", None) + if h is None: + h = {} + _TLS.holder = h + return h + + +class CrossLayerDSAttention(DSAttention): + """``DSAttention`` with GLM-5.2 cross-layer index sharing. + + Anchor (computing) layers behave like the base class but also publish their ``topk_indices`` + to a per-microbatch holder. Skip layers carry no indexer (dropped in ``__init__``) and reuse + the most recent computing layer's ``topk_indices`` for the sparse-attention kernel. + """ + + def __init__(self, config, submodules, layer_number, *args, **kwargs): + super().__init__(config, submodules, layer_number, *args, **kwargs) + self._index_topk_freq = getattr(config, "dsa_index_topk_freq", 1) or 1 + self._skip_topk_offset = getattr(config, "dsa_index_skip_topk_offset", 0) or 0 + self._index_share = self._index_topk_freq > 1 + self._skip_topk = self._index_share and is_skip_topk_layer( + layer_number, self._skip_topk_offset, self._index_topk_freq + ) + self._source_layer = ( + source_compute_layer(layer_number, self._skip_topk_offset, self._index_topk_freq) + if self._index_share + else layer_number + ) + # Skip layers carry NO indexer params: drop the module the base class built so the + # parameter set matches the GLM-5.2 checkpoint (indexer weights only on computing + # layers) and HF export / LoRA target-matching naturally omit them here. + if self._skip_topk and hasattr(self, "indexer"): + del self.indexer + # The bshd holder fallback (thread-local) is NOT recompute-safe (see ``forward``); the + # thd carrier on ``packed_seq_params`` is. Remember whether activation recompute is on so + # the forward can reject the unsafe bshd + recompute + cross-layer combination loudly. + self._recompute_active = self._index_share and (getattr(config, "recompute_granularity", None) is not None) + + def forward( + self, + query, + key, + value, + attention_mask, + x, + qr, + attn_mask_type=None, + attention_bias=None, + packed_seq_params=None, + ): + # GLM-5.1 / no sharing -> identical to the base class. + if not self._index_share: + return super().forward( + query, + key, + value, + attention_mask, + x, + qr, + attn_mask_type, + attention_bias, + packed_seq_params, + ) + + # bshd (``packed_seq_params is None``) uses the thread-local holder fallback, which is + # NOT recompute-safe: under activation recompute a skip layer's recompute can read a stale + # anchor top-k (the thread-local dict is not closure-captured per microbatch the way + # ``packed_seq_params`` is). Fail loudly instead of silently producing wrong gradients. + if packed_seq_params is None and self._recompute_active: + raise AssertionError( + "DSA cross-layer index-share is not recompute-safe in the bshd layout: " + "packed_seq_params is None, so the per-microbatch top-k holder falls back to a " + "thread-local dict that activation recompute can read stale. Use --qkv-format thd " + "(the holder rides on packed_seq_params and is recompute-safe), or disable " + f"activation recompute (recompute_granularity=" + f"{getattr(self.config, 'recompute_granularity', None)})." + ) + + holder = _holder(packed_seq_params) + + # ---- skip layer: reuse the anchor's top-k, no indexer compute, no indexer loss ---- + if self._skip_topk: + if self._source_layer not in holder: + raise AssertionError( + f"DSA index-share: skip layer (layer_number={self.layer_number}) needs the " + f"top-k of its source computing layer (layer_number={self._source_layer}), " + f"which did not run in this pipeline stage's forward. Ensure every PP stage " + f"starts on a computing layer (index_topk_freq={self._index_topk_freq}, " + f"index_skip_topk_offset={self._skip_topk_offset}). Holder has {sorted(holder)}." + ) + topk_indices = holder[self._source_layer] + return unfused_dsa_fn(query, key, value, topk_indices, self.softmax_scale) + + # ---- anchor layer: compute top-k (base-class logic) + publish to holder ---- + sq, b, np, hn = query.size() + skv = key.size(0) + x = x.detach() + qr = qr.detach() + + if attn_mask_type is not None: + assert attn_mask_type == AttnMaskType.causal, "Only causal mask is supported for now" + float_mask = torch.triu( + torch.full((sq, skv), float("-inf"), dtype=torch.float32, device=x.device), + diagonal=1, + ) + else: + assert attention_mask.shape == (b, 1, sq, skv), "attention_mask shape mismatch" + mask = attention_mask.squeeze() + float_mask = torch.zeros_like(mask, dtype=torch.float32).masked_fill(mask, float("-inf")) + + if self.training and torch.is_grad_enabled(): + q, k, weights = self.indexer.forward_before_topk(x, qr, packed_seq_params) + indexer_loss_coeff = getattr(self.config, "dsa_indexer_loss_coeff", 0.0) + topk_indices, indexer_loss = FusedDSAIndexerLoss.apply( + q, + weights, + k, + query.detach(), + key.detach(), + self.softmax_scale, + self.indexer.index_topk, + indexer_loss_coeff, + float_mask, + getattr(self.config, "dsa_indexer_use_sparse_loss", False), + self.indexer.pg_collection, + ) + if indexer_loss_coeff > 0: + DSAIndexerLossLoggingHelper.save_loss_to_tracker( + loss=indexer_loss, + layer_number=self.layer_number, + num_layers=self.config.num_layers, + ) + holder[self.layer_number] = topk_indices + output = unfused_dsa_fn(query, key, value, topk_indices, self.softmax_scale) + output = DSAIndexerLossAutoScaler.apply(output, indexer_loss) + else: + _, topk_indices = self.indexer.forward_with_scores( + x, qr, mask=float_mask, packed_seq_params=packed_seq_params + ) + holder[self.layer_number] = topk_indices + output = unfused_dsa_fn(query, key, value, topk_indices, self.softmax_scale) + + return output + + +def get_glm5_crosslayer_dsa_spec(config, backend=None): + """megatron-core's *exact* DSA MLA spec, with the core-attention module swapped to + :class:`CrossLayerDSAttention`. + + Rather than hand-clone ``get_dsa_module_spec_for_backend`` (which is easy to get subtly + wrong -- e.g. it fuses the qk-layernorm into the q/kv up-projections via + ``column_parallel_layer_norm_linear`` and sets ``q_layernorm = kv_layernorm = IdentityOp``, + so the MLA tensor dims match the checkpoint), we call it and mutate only the one thing that + differs: ``submodules.core_attention.module``. The indexer ModuleSpec is uniform across + layers; skip layers drop it in ``CrossLayerDSAttention.__init__``. + + ``metainfo['fuse_input_layernorm']=False`` is set here because this path bypasses the + dispatcher (``get_experimental_attention_variant_module_spec``) that would otherwise set it; + this mirrors the GLM-5.1 fallback in ``glm5_bridge._build_glm5_dsa_block_spec``. + """ + from megatron.core.models.gpt import experimental_attention_variant_module_specs as _eav + + if backend is None: + backend = _eav._get_backend_spec_provider(config=config) + spec = _eav.get_dsa_module_spec_for_backend(config=config, backend=backend) + spec.submodules.core_attention.module = CrossLayerDSAttention + if spec.metainfo is None: + spec.metainfo = {} + spec.metainfo.setdefault("fuse_input_layernorm", False) + return spec diff --git a/src/megatron/bridge/models/glm_moe_dsa/glm5_bridge.py b/src/megatron/bridge/models/glm_moe_dsa/glm5_bridge.py index 77b3c519c3..6500a60f65 100644 --- a/src/megatron/bridge/models/glm_moe_dsa/glm5_bridge.py +++ b/src/megatron/bridge/models/glm_moe_dsa/glm5_bridge.py @@ -31,6 +31,76 @@ logger = logging.getLogger(__name__) +def _build_glm5_dsa_block_spec(config, *args, **kwargs): + """``transformer_layer_spec`` for GLM-5 / GLM-5.1 DSA (feature-detected, self-disabling). + + Older megatron-core (e.g. 0.16.0rc0): its experimental-attention dispatcher + (``get_experimental_attention_variant_module_spec``) only natively wires + ``"gated_delta_net"`` and raises ``ValueError`` for ``"dsa"``, and its DSA builder + (``get_dsa_module_spec_for_backend``) omits the ``metainfo`` the variant + layer-builder reads. Newer megatron-core handles ``"dsa"`` natively (the dispatcher + gained a ``== "dsa"`` branch and the DSA builder sets ``metainfo`` itself). + + So this wraps the dispatcher to PREFER megatron-core's own handling, and only when it + raises for ``"dsa"`` (old megatron-core) back-fills via the shipped DSA builder + sets + ``metainfo["fuse_input_layernorm"]=False`` (MLA-based DSA keeps a separate, non-fused + input layernorm, like the DeepSeek-V4 ``dsv4`` spec; ``gated_delta_net`` uses ``True``). + => On newer megatron-core this is a transparent no-op; once the runtime's megatron-core + handles ``"dsa"``, this whole helper can be deleted. No megatron-core source change. + """ + # GLM-5.2 cross-layer: fail early at build time if this (virtual) pipeline stage would start + # on a skip layer -- the per-microbatch top-k holder does not cross PP boundaries. No-op for + # GLM-5.1 (index_topk_freq=1) and when the layout can't be determined (runtime guard backs it). + if getattr(config, "experimental_attention_variant", None) == "dsa" and ( + (getattr(config, "dsa_index_topk_freq", 1) or 1) > 1 + ): + from megatron.bridge.models.glm_moe_dsa.cross_layer_dsa import ( + assert_pp_stage_starts_on_computing_layer, + ) + + assert_pp_stage_starts_on_computing_layer(config, vp_stage=kwargs.get("vp_stage")) + + from megatron.core.models.gpt import experimental_attention_variant_module_specs as _eav + + _orig = _eav.get_experimental_attention_variant_module_spec + + def _patched(config, backend=None): + # GLM-5.2 DSA cross-layer index sharing: when index_topk_freq>1, build our own + # CrossLayerDSAttention spec (megatron-core's DSA -- native or shimmed -- is per-layer + # only and cannot share top-k across layers). GLM-5.1 (no freq) falls through below. + if getattr(config, "experimental_attention_variant", None) == "dsa" and ( + (getattr(config, "dsa_index_topk_freq", 1) or 1) > 1 + ): + if backend is None: + backend = _eav._get_backend_spec_provider(config=config) + from megatron.bridge.models.glm_moe_dsa.cross_layer_dsa import ( + get_glm5_crosslayer_dsa_spec, + ) + + return get_glm5_crosslayer_dsa_spec(config, backend) + # Prefer megatron-core's native handling (works as-is on newer megatron-core). + try: + return _orig(config, backend) + except ValueError: + # Old megatron-core: dispatcher doesn't know "dsa". Don't mask genuine errors + # for other variants -- only back-fill the dsa case. + if getattr(config, "experimental_attention_variant", None) != "dsa": + raise + if backend is None: + backend = _eav._get_backend_spec_provider(config=config) + spec = _eav.get_dsa_module_spec_for_backend(config=config, backend=backend) + if spec.metainfo is None: + spec.metainfo = {} + spec.metainfo.setdefault("fuse_input_layernorm", False) + return spec + + _eav.get_experimental_attention_variant_module_spec = _patched + try: + return _eav.get_transformer_block_with_experimental_attention_variant_spec(config, *args, **kwargs) + finally: + _eav.get_experimental_attention_variant_module_spec = _orig + + @MegatronModelBridge.register_bridge( source=GlmMoeDsaForCausalLM, target=GPTModel, provider=MLAModelProvider, model_type="glm_moe_dsa" ) @@ -57,13 +127,14 @@ def provider_bridge(self, hf_pretrained: PreTrainedCausalLM) -> MLAModelProvider provider = super().provider_bridge(hf_pretrained) hf_config = hf_pretrained.config - # Use experimental-attention spec for DSA + # Use experimental-attention spec for DSA. megatron-core's dispatcher raises for + # "dsa", so route it through _build_glm5_dsa_block_spec (which makes the DSA + # variant buildable + supplies the metainfo). This makes the GLM-5/5.1 bridge + # self-contained for both LoRA and full-FT builds (no caller-side monkey-patch). try: - from megatron.core.models.gpt.experimental_attention_variant_module_specs import ( - get_transformer_block_with_experimental_attention_variant_spec, - ) + import megatron.core.models.gpt.experimental_attention_variant_module_specs # noqa: F401 - provider.transformer_layer_spec = get_transformer_block_with_experimental_attention_variant_spec + provider.transformer_layer_spec = _build_glm5_dsa_block_spec except (ImportError, ModuleNotFoundError): logger.warning("DSA spec not available; falling back to standard GPT decoder block spec.") @@ -99,13 +170,41 @@ def provider_bridge(self, hf_pretrained: PreTrainedCausalLM) -> MLAModelProvider ) provider.moe_shared_expert_intermediate_size = hf_config.moe_intermediate_size * hf_config.n_shared_experts - # GLM5-specific: rotary_base is nested in rope_parameters - provider.rotary_base = hf_config.rope_parameters["rope_theta"] + # GLM5-specific: rope_theta is nested in rope_parameters (transformers 5.x) or flat + # (older / GLM-5.2 = 8e6). Handle both shapes robustly. + _rope_params = getattr(hf_config, "rope_parameters", None) + provider.rotary_base = ( + (_rope_params.get("rope_theta") if isinstance(_rope_params, dict) else None) + or getattr(hf_config, "rope_theta", None) + or 10000 + ) # GLM5 uses default rope (no YaRN scaling) provider.rotary_scaling_factor = 1.0 provider.mscale = 1.0 provider.mscale_all_dim = 1.0 + # GLM-5.2 / transformers>=5.12 mis-parses qk_rope_head_dim as head_dim (192) rather than + # the config.json value (64); the base config-mapping then sizes MLA's decoupled-rope key + # by 192, giving linear_kv_down_proj = kv_lora_rank + 192 = 704. The checkpoint is ground + # truth: kv_a_proj_with_mqa = kv_lora_rank + qk_rope_head_dim = 576 = 512 + 64, and MLA + # applies rotary over qk_pos_emb_head_dim. Read the rope dim straight from config.json so + # the dims match the weights for both GLM-5.1 (64) and GLM-5.2 (64). No-op when correct. + import json as _json + import os as _os + + _cfg_dir = getattr(hf_config, "_name_or_path", "") or "" + _cfg_json = _os.path.join(_cfg_dir, "config.json") + if _os.path.isfile(_cfg_json): + _rope = _json.load(open(_cfg_json)).get("qk_rope_head_dim") + if _rope and _rope != provider.qk_pos_emb_head_dim: + logger.info( + "GLM5 bridge: overriding qk_pos_emb_head_dim %s -> %s from config.json " + "(transformers mis-parse of qk_rope_head_dim)", + provider.qk_pos_emb_head_dim, + _rope, + ) + provider.qk_pos_emb_head_dim = _rope + # DSA indexer params provider.experimental_attention_variant = "dsa" provider.dsa_indexer_head_dim = hf_config.index_head_dim @@ -113,6 +212,11 @@ def provider_bridge(self, hf_pretrained: PreTrainedCausalLM) -> MLAModelProvider provider.dsa_indexer_topk = hf_config.index_topk provider.dsa_indexer_loss_coeff = 0.001 provider.dsa_indexer_use_sparse_loss = True + # GLM-5.2 DSA cross-layer index sharing. Absent in GLM-5.1 (-> freq=1 -> every layer + # computes its own top-k = plain DSA). When >1, CrossLayerDSAttention builds the indexer + # only on computing layers and skip layers reuse the most recent computing layer's top-k. + provider.dsa_index_topk_freq = getattr(hf_config, "index_topk_freq", 1) or 1 + provider.dsa_index_skip_topk_offset = getattr(hf_config, "index_skip_topk_offset", 0) or 0 return provider