From 5dfcf8b4819cb02630fe0ac8d23b5e93e638becd Mon Sep 17 00:00:00 2001 From: sufubao Date: Tue, 27 Jan 2026 09:38:58 +0000 Subject: [PATCH 1/3] draft --- .gitignore | 1 + ...026-01-27-remove-moe-model-mixin-design.md | 0 lightllm/common/basemodel/basemodel.py | 2 + .../fused_moe/fused_moe_weight.py | 6 + .../fused_moe/gpt_oss_fused_moe_weight_tp.py | 5 + .../meta_weights/fused_moe/impl/base_impl.py | 2 + .../fused_moe/impl/triton_impl.py | 7 + lightllm/common/basemodel/routing_manager.py | 178 ++++++++++++++++++ lightllm/common/quantization/w8a8.py | 2 +- .../layer_infer/transformer_layer_infer.py | 2 + lightllm/models/deepseek2/model.py | 2 + .../layer_infer/transformer_layer_infer.py | 1 + lightllm/models/gpt_oss/model.py | 5 + lightllm/models/llama/model.py | 11 +- .../models/mixtral/layer_infer/_custom_ops.py | 46 ----- .../layer_infer/transformer_layer_infer.py | 33 +--- lightllm/models/mixtral/model.py | 2 + .../layer_infer/transformer_layer_infer.py | 2 + .../layer_weights/transformer_layer_weight.py | 1 - lightllm/models/qwen3_moe/model.py | 2 + lightllm/server/api_cli.py | 6 + lightllm/server/api_lightllm.py | 5 + lightllm/server/core/objs/req.py | 41 ++++ lightllm/server/core/objs/sampling_params.py | 1 + lightllm/server/httpserver/manager.py | 20 ++ .../server/router/model_infer/infer_batch.py | 17 ++ .../model_infer/mode_backend/base_backend.py | 4 + .../mode_backend/chunked_prefill/impl.py | 4 + .../mode_backend/diverse_backend/impl.py | 1 + .../mode_backend/dp_backend/impl.py | 12 ++ scripts/run_e2e_r3_test.sh | 109 +++++++++++ test_r3.py | 99 ++++++++++ unit_tests/__init__.py | 0 unit_tests/common/__init__.py | 0 unit_tests/common/basemodel/__init__.py | 0 .../basemodel/test_routing_capture_manager.py | 132 +++++++++++++ 36 files changed, 686 insertions(+), 75 deletions(-) create mode 100644 docs/plans/2026-01-27-remove-moe-model-mixin-design.md create mode 100644 lightllm/common/basemodel/routing_manager.py delete mode 100644 lightllm/models/mixtral/layer_infer/_custom_ops.py create mode 100755 scripts/run_e2e_r3_test.sh create mode 100644 test_r3.py create mode 100644 unit_tests/__init__.py create mode 100644 unit_tests/common/__init__.py create mode 100644 unit_tests/common/basemodel/__init__.py create mode 100644 unit_tests/common/basemodel/test_routing_capture_manager.py diff --git a/.gitignore b/.gitignore index 63408699f..3fb49db8b 100644 --- a/.gitignore +++ b/.gitignore @@ -7,3 +7,4 @@ dist .vscode tmp/ requirements-musa.txt +CLAUDE.md diff --git a/docs/plans/2026-01-27-remove-moe-model-mixin-design.md b/docs/plans/2026-01-27-remove-moe-model-mixin-design.md new file mode 100644 index 000000000..e69de29bb diff --git a/lightllm/common/basemodel/basemodel.py b/lightllm/common/basemodel/basemodel.py index e6405e4d7..36968471f 100755 --- a/lightllm/common/basemodel/basemodel.py +++ b/lightllm/common/basemodel/basemodel.py @@ -11,6 +11,7 @@ from lightllm.common.basemodel.layer_weights.hf_load_utils import load_hf_weights from lightllm.common.basemodel.infer_struct import InferStateInfo +from lightllm.common.basemodel.routing_manager import reset_moe_layer_counter from lightllm.common.kv_cache_mem_manager import MemoryManager from lightllm.common.kv_cache_mem_manager.mem_utils import select_mem_manager_class from lightllm.common.req_manager import ReqManager @@ -164,6 +165,7 @@ def _init_quant(self): logger.info(f"Initial quantization. 
" f"The default quantization method is {self.quant_cfg.quant_type}") def _init_weights(self, start_layer_index=0): + reset_moe_layer_counter() self.pre_post_weight = self.pre_and_post_weight_class(self.data_type, network_config=self.config) self.trans_layers_weight = [ self.transformer_weight_class( diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py index 6bcf7fc03..ada943f6c 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py @@ -13,6 +13,7 @@ from lightllm.utils.envs_utils import get_redundancy_expert_ids, get_redundancy_expert_num, get_env_start_args from lightllm.utils.dist_utils import get_global_world_size, get_global_rank from lightllm.utils.log_utils import init_logger +from lightllm.common.basemodel.routing_manager import g_routing_capture_manager, get_next_moe_layer_index logger = init_logger(__name__) @@ -35,6 +36,7 @@ def __init__( network_config: Dict[str, Any] = None, ) -> None: super().__init__(data_type=data_type) + self.moe_layer_index = get_next_moe_layer_index() self.w1_weight_name = gate_proj_name self.w2_weight_name = down_proj_name self.w3_weight_name = up_proj_name @@ -103,6 +105,7 @@ def _init_parallel_params(self): f"redundancy_expertids: {self.redundancy_expert_ids}" ) self.local_n_routed_experts = self.n_routed_experts // self.global_world_size + self.redundancy_expert_num + self.split_inter_size = self.moe_intermediate_size n_experts_per_rank = self.n_routed_experts // self.global_world_size start_expert_id = self.global_rank_ * n_experts_per_rank self.local_expert_ids = ( @@ -130,6 +133,7 @@ def experts( topk_group: int, num_expert_group: int, is_prefill: Optional[bool] = None, + microbatch_index: int = 0, ) -> torch.Tensor: """Backward compatible method that routes to platform-specific implementation.""" return self.fuse_moe_impl( @@ -145,6 +149,8 @@ def experts( topk_group=topk_group, num_expert_group=num_expert_group, is_prefill=is_prefill, + moe_layer_index=self.moe_layer_index, + microbatch_index=microbatch_index, ) def low_latency_dispatch( diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/gpt_oss_fused_moe_weight_tp.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/gpt_oss_fused_moe_weight_tp.py index 6ed0cef0b..5627a5925 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/gpt_oss_fused_moe_weight_tp.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/gpt_oss_fused_moe_weight_tp.py @@ -8,6 +8,7 @@ from lightllm.common.quantization import Quantcfg from lightllm.common.quantization.quantize_method import QuantizationMethod from lightllm.utils.log_utils import init_logger +from lightllm.common.basemodel.routing_manager import g_routing_capture_manager logger = init_logger(__name__) @@ -144,10 +145,14 @@ def experts( topk_group: int, num_expert_group: int, is_prefill: Optional[bool] = None, + microbatch_index: int = 0, ): topk_weights, topk_ids = self._router(router_logits, top_k) + if g_routing_capture_manager is not None: + g_routing_capture_manager.capture(self.moe_layer_index, topk_ids, microbatch_index) + w1, w1_scale = self.w1 w2, w2_scale = self.w2 use_fp8_w8a8 = self.quant_method is not None diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/base_impl.py 
b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/base_impl.py index 00587ac18..1c93cb13d 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/base_impl.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/base_impl.py @@ -62,5 +62,7 @@ def __call__( topk_group: int, num_expert_group: int, is_prefill: Optional[bool] = None, + moe_layer_index: Optional[int] = None, + microbatch_index: int = 0, ) -> torch.Tensor: pass diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/triton_impl.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/triton_impl.py index 8bcdb4bf9..77a9b1a45 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/triton_impl.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/triton_impl.py @@ -3,6 +3,7 @@ from lightllm.common.quantization.no_quant import WeightPack from lightllm.common.quantization.quantize_method import QuantizationMethod from .base_impl import FuseMoeBaseImpl +from lightllm.common.basemodel.routing_manager import g_routing_capture_manager class FuseMoeTriton(FuseMoeBaseImpl): @@ -124,6 +125,8 @@ def __call__( topk_group: int, num_expert_group: int, is_prefill: Optional[bool] = None, + moe_layer_index: Optional[int] = None, + microbatch_index: int = 0, ): topk_weights, topk_ids = self._select_experts( input_tensor=input_tensor, @@ -136,6 +139,10 @@ def __call__( num_expert_group=num_expert_group, scoring_func=scoring_func, ) + + if g_routing_capture_manager is not None and moe_layer_index is not None: + g_routing_capture_manager.capture(moe_layer_index, topk_ids, microbatch_index) + output = self._fused_experts( input_tensor=input_tensor, w13=w13, diff --git a/lightllm/common/basemodel/routing_manager.py b/lightllm/common/basemodel/routing_manager.py new file mode 100644 index 000000000..4a76dca95 --- /dev/null +++ b/lightllm/common/basemodel/routing_manager.py @@ -0,0 +1,178 @@ +import torch +import numpy as np +from typing import Optional +from lightllm.utils.log_utils import init_logger +from lightllm.utils.dist_utils import get_current_rank_in_dp + +logger = init_logger(__name__) + +# MoE layer counter for auto-incrementing moe_layer_index +_moe_layer_counter: int = 0 + + +def reset_moe_layer_counter() -> None: + global _moe_layer_counter + _moe_layer_counter = 0 + + +def get_next_moe_layer_index() -> int: + global _moe_layer_counter + idx = _moe_layer_counter + _moe_layer_counter += 1 + return idx + + +def get_moe_layer_count() -> int: + return _moe_layer_counter + + +class RoutingCaptureManager: + """Captures MoE routing decisions""" + + def __init__( + self, + num_moe_layers: int, + topk: int, + num_experts: int, + batch_max_tokens: int, + kv_cache_size: int, + enable_overlap: bool = False, + ): + self.num_moe_layers = num_moe_layers + self.topk = topk + self.num_experts = num_experts + self.batch_max_tokens = batch_max_tokens + self.kv_cache_size = kv_cache_size + + self.dtype = torch.int8 if num_experts <= 127 else torch.int16 + dtype_bytes = 1 if self.dtype == torch.int8 else 2 + + self.num_slots = 2 if enable_overlap else 1 + + gpu_buffer_size = self.num_slots * num_moe_layers * batch_max_tokens * topk * dtype_bytes + self.gpu_buffer = torch.zeros( + (self.num_slots, num_moe_layers, batch_max_tokens, topk), + dtype=self.dtype, + device="cuda", + ) + + cpu_buffer_size = num_moe_layers * kv_cache_size * topk * dtype_bytes + self.cpu_buffer = torch.zeros( + (num_moe_layers, 
kv_cache_size, topk), + dtype=self.dtype, + device="cpu", + pin_memory=True, + ) + + self.flush_streams = [torch.cuda.Stream() for _ in range(self.num_slots)] + self.flush_events = [torch.cuda.Event() for _ in range(self.num_slots)] + + dtype_name = "int8" if self.dtype == torch.int8 else "int16" + logger.info( + f"RoutingCaptureManager initialized: {num_moe_layers} MoE layers, topk={topk}, " + f"slots={self.num_slots}, GPU={gpu_buffer_size / 1024 / 1024:.2f}MB, " + f"CPU={cpu_buffer_size / 1024 / 1024:.2f}MB, dtype={dtype_name}" + ) + + def capture(self, moe_layer_index: int, topk_ids: torch.Tensor, microbatch_index: int = 0) -> None: + assert ( + 0 <= moe_layer_index < self.num_moe_layers + ), f"moe_layer_index {moe_layer_index} out of range [0, {self.num_moe_layers})" + slot = microbatch_index % self.num_slots + num_tokens = topk_ids.shape[0] + self.gpu_buffer[slot, moe_layer_index, :num_tokens, :] = topk_ids.to(self.dtype) + + def flush_to_cpu_async(self, mem_indexes: torch.Tensor, microbatch_index: int) -> None: + num_tokens = mem_indexes.shape[0] + if num_tokens == 0: + return + + slot = microbatch_index % self.num_slots + stream = self.flush_streams[slot] + event = self.flush_events[slot] + + stream.wait_stream(torch.cuda.current_stream()) + + with torch.cuda.stream(stream): + cpu_indexes = mem_indexes.cpu() + self.cpu_buffer[:, cpu_indexes, :] = self.gpu_buffer[slot, :, :num_tokens, :].cpu() + event.record() + + def extract_for_request(self, mem_indexes: torch.Tensor) -> np.ndarray: + for event in self.flush_events: + event.synchronize() + return self.cpu_buffer[:, mem_indexes, :].numpy() + + +g_routing_capture_manager: Optional[RoutingCaptureManager] = None + + +def create_routing_capture_manager( + num_moe_layers: int, + topk: int, + num_experts: int, + batch_max_tokens: int, + kv_cache_size: int, + enable_overlap: bool = False, +) -> None: + global g_routing_capture_manager + assert g_routing_capture_manager is None, "RoutingCaptureManager already exists" + g_routing_capture_manager = RoutingCaptureManager( + num_moe_layers=num_moe_layers, + topk=topk, + num_experts=num_experts, + batch_max_tokens=batch_max_tokens, + kv_cache_size=kv_cache_size, + enable_overlap=enable_overlap, + ) + + +def init_routing_capture(model) -> None: + if not getattr(model.args, "enable_return_routed_experts", False): + return + + # Only create routing capture manager on rank 0 + # Routing decisions are identical across all TP ranks, so we only need to capture on rank 0 + # which is the rank that communicates results back to the Router/HTTP server + if get_current_rank_in_dp() != 0: + logger.info("Skipping routing capture initialization on non-zero rank") + return + + num_moe_layers = get_moe_layer_count() + if num_moe_layers == 0: + logger.warning( + "enable_return_routed_experts is set but no MoE layers found. " "Routing capture will not be enabled." + ) + return + + n_routed_experts = model.config.get("n_routed_experts", model.config.get("num_experts", 0)) + if n_routed_experts == 0: + logger.warning( + "enable_return_routed_experts is set but n_routed_experts=0. " "Routing capture will not be enabled." 
+ ) + return + + topk = model.config.get("num_experts_per_tok", 1) + num_experts = n_routed_experts + + enable_overlap = getattr(model.args, "enable_decode_microbatch_overlap", False) + + logger.info( + f"Initializing routing capture: num_moe_layers={num_moe_layers}, " + f"topk={topk}, num_experts={num_experts}, enable_overlap={enable_overlap}" + ) + + create_routing_capture_manager( + num_moe_layers=num_moe_layers, + topk=topk, + num_experts=num_experts, + batch_max_tokens=model.max_total_token_num, + # Add 1 to handle potential edge case where mem_index == size + kv_cache_size=model.mem_manager.size + 1, + enable_overlap=enable_overlap, + ) + + +def flush_routing_capture(mem_indexes: torch.Tensor, microbatch_index: int = 0) -> None: + if g_routing_capture_manager is not None: + g_routing_capture_manager.flush_to_cpu_async(mem_indexes, microbatch_index) diff --git a/lightllm/common/quantization/w8a8.py b/lightllm/common/quantization/w8a8.py index 98626e1d3..3f4fec319 100644 --- a/lightllm/common/quantization/w8a8.py +++ b/lightllm/common/quantization/w8a8.py @@ -72,7 +72,7 @@ def quantize(self, weight: torch.Tensor, output: WeightPack) -> None: weight = weight.float().cuda(self.device_id_) scale = weight.abs().max(dim=-1)[0] / 127 weight = weight / scale.reshape(-1, 1) - weight = torch.round(weight.clamp(min=-127, max=127)).to(dtype=torch.int8) + weight = torch.round(weight.clamp(min=-128, max=127)).to(dtype=torch.int8) output.weight.copy_(weight) output.weight_scale.copy_(scale) return diff --git a/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py b/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py index e1e435cce..9b7a4274f 100644 --- a/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py @@ -311,6 +311,7 @@ def _moe_ffn( use_grouped_topk=self.n_group, topk_group=self.topk_group, num_expert_group=self.n_group, + microbatch_index=infer_state.microbatch_index, ) if self.n_shared_experts is not None and layer_weight.num_fused_shared_experts == 0: @@ -337,6 +338,7 @@ def _moe_ffn_edp( topk_group=self.topk_group, num_expert_group=self.n_group, is_prefill=infer_state.is_prefill, + microbatch_index=infer_state.microbatch_index, ) if self.n_shared_experts is not None: diff --git a/lightllm/models/deepseek2/model.py b/lightllm/models/deepseek2/model.py index f0739a8a8..2b8c48e65 100644 --- a/lightllm/models/deepseek2/model.py +++ b/lightllm/models/deepseek2/model.py @@ -6,6 +6,7 @@ from lightllm.models.deepseek2.infer_struct import Deepseek2InferStateInfo from lightllm.models.llama.model import LlamaTpPartModel from lightllm.common.kv_cache_mem_manager.mem_utils import select_mem_manager_class +from lightllm.common.basemodel.routing_manager import init_routing_capture from lightllm.utils.log_utils import init_logger from lightllm.utils.envs_utils import enable_env_vars, get_env_start_args, get_added_mtp_kv_layer_num from lightllm.distributed.communication_op import dist_group_manager @@ -48,6 +49,7 @@ def _init_some_value(self): def _init_custom(self): self._init_to_get_yarn_rotary() dist_group_manager.new_deepep_group(self.config["n_routed_experts"], self.config["hidden_size"]) + init_routing_capture(self) def _verify_params(self): return super()._verify_params() diff --git a/lightllm/models/gpt_oss/layer_infer/transformer_layer_infer.py b/lightllm/models/gpt_oss/layer_infer/transformer_layer_infer.py index d80eefd16..e5672f821 100644 --- 
a/lightllm/models/gpt_oss/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/gpt_oss/layer_infer/transformer_layer_infer.py @@ -51,6 +51,7 @@ def _ffn(self, input, infer_state, layer_weight: GptOssTransformerLayerWeight) - use_grouped_topk=False, topk_group=None, num_expert_group=None, + microbatch_index=infer_state.microbatch_index, ) return hidden_states.view(num_tokens, hidden_dim) diff --git a/lightllm/models/gpt_oss/model.py b/lightllm/models/gpt_oss/model.py index 9e9561eb2..f91e25690 100644 --- a/lightllm/models/gpt_oss/model.py +++ b/lightllm/models/gpt_oss/model.py @@ -2,6 +2,7 @@ from lightllm.models.gpt_oss.layer_weights.transformer_layer_weight import GptOssTransformerLayerWeight from lightllm.models.llama.model import LlamaTpPartModel from lightllm.models.registry import ModelRegistry +from lightllm.common.basemodel.routing_manager import init_routing_capture from lightllm.utils.envs_utils import get_env_start_args from lightllm.utils.log_utils import init_logger from lightllm.common.basemodel.attention import get_prefill_att_backend_class, get_decode_att_backend_class @@ -28,3 +29,7 @@ def _init_att_backend(self): self.decode_att_backend: BaseAttBackend = get_decode_att_backend_class(index=0, priority_list=["fa3"])( model=self ) + + def _init_custom(self): + super()._init_custom() + init_routing_capture(self) diff --git a/lightllm/models/llama/model.py b/lightllm/models/llama/model.py index c104ebccc..f5a66a6c3 100644 --- a/lightllm/models/llama/model.py +++ b/lightllm/models/llama/model.py @@ -74,14 +74,18 @@ def _init_custom(self): rope_scaling = self.config.get("rope_scaling", None) if rope_scaling is None: self._init_to_get_rotary() - return - - if "rope_type" in rope_scaling: + elif "rope_type" in rope_scaling: scaling_type = rope_scaling["rope_type"] + self._init_rotary_by_scaling_type(scaling_type, rope_scaling) elif "type" in rope_scaling: scaling_type = rope_scaling["type"] + self._init_rotary_by_scaling_type(scaling_type, rope_scaling) else: raise ValueError(f"Unknown RoPE scaling format {rope_scaling}") + super()._init_custom() + + def _init_rotary_by_scaling_type(self, scaling_type, rope_scaling): + """Initialize rotary embeddings based on scaling type.""" if scaling_type == "default" or "mrope_section" in rope_scaling: self._init_to_get_rotary() elif scaling_type == "yarn": @@ -96,7 +100,6 @@ def _init_custom(self): self._init_to_get_rotary() else: raise ValueError(f"Unknown RoPE scaling type {scaling_type}") - return def _init_to_get_rotary(self, default_base=10000): partial_head_dim = int(self.config.get("partial_rotary_factor", 1) * self.head_dim_) diff --git a/lightllm/models/mixtral/layer_infer/_custom_ops.py b/lightllm/models/mixtral/layer_infer/_custom_ops.py deleted file mode 100644 index b0e27ac1d..000000000 --- a/lightllm/models/mixtral/layer_infer/_custom_ops.py +++ /dev/null @@ -1,46 +0,0 @@ -import functools -import json -import os -from typing import Any, Dict, Optional, Tuple - -import torch -import triton -import triton.language as tl -from lightllm.utils.log_utils import init_logger - -logger = init_logger(__name__) - -# Pytorch version -# Triton version in progress -def topk_softmax( - topk_weights, - topk_ids, - token_expert_indicies, - gating_output, - topk=2, -): - scores = torch.softmax(gating_output, dim=-1) - topk_weights, topk_ids = torch.topk(scores, k=topk, dim=-1, sorted=False) - return topk_weights, topk_ids - - -def fused_topk( - hidden_states: torch.Tensor, - gating_output: torch.Tensor, - topk: int, - renormalize: bool, - 
alloc_tensor_func=torch.empty, -): - assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch" - - M, _ = hidden_states.shape - - topk_weights = alloc_tensor_func((M, topk), dtype=torch.float32, device=hidden_states.device) - topk_ids = alloc_tensor_func((M, topk), dtype=torch.int32, device=hidden_states.device) - token_expert_indicies = alloc_tensor_func((M, topk), dtype=torch.int32, device=hidden_states.device) - topk_weights, topk_ids = topk_softmax(topk_weights, topk_ids, token_expert_indicies, gating_output.float(), topk) - del token_expert_indicies # Not used. Will be used in the future. - - if renormalize: - topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) - return topk_weights, topk_ids diff --git a/lightllm/models/mixtral/layer_infer/transformer_layer_infer.py b/lightllm/models/mixtral/layer_infer/transformer_layer_infer.py index 44e66cff2..a2968f5ab 100644 --- a/lightllm/models/mixtral/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/mixtral/layer_infer/transformer_layer_infer.py @@ -1,9 +1,6 @@ -import os import torch -import torch.nn.functional as F from lightllm.common.basemodel.infer_struct import InferStateInfo from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer -from lightllm.models.mixtral.layer_infer._custom_ops import fused_topk from lightllm.models.mixtral.layer_weights.transformer_layer_weight import MixtralTransformerLayerWeight @@ -19,25 +16,15 @@ def _ffn(self, input, infer_state: InferStateInfo, layer_weight: MixtralTransfor hidden_states = input.view(-1, self.embed_dim_) num_tokens, hidden_dim = hidden_states.shape - router_logits = layer_weight.moe_gate.mm(input.view(-1, self.embed_dim_)) - topk_weights, topk_ids = fused_topk( - hidden_states=hidden_states, - gating_output=router_logits, - topk=self.num_experts_per_tok, + router_logits = layer_weight.moe_gate.mm(hidden_states) + layer_weight.experts.experts( + hidden_states, + router_logits=router_logits, + top_k=self.num_experts_per_tok, renormalize=self.renormalize, - alloc_tensor_func=self.alloc_tensor, - ) - from lightllm.common.fused_moe.grouped_fused_moe import fused_experts_impl - - return fused_experts_impl( - hidden_states=hidden_states, - w1=layer_weight.experts.w1[0], - w2=layer_weight.experts.w2[0], - topk_weights=topk_weights, - topk_ids=topk_ids, - inplace=True, - use_fp8_w8a8=False, - w1_scale=None, - w2_scale=None, - alloc_tensor_func=self.alloc_tensor, + use_grouped_topk=False, + topk_group=None, + num_expert_group=None, + microbatch_index=getattr(infer_state, "microbatch_index", 0), ) + return hidden_states.view(num_tokens, hidden_dim) diff --git a/lightllm/models/mixtral/model.py b/lightllm/models/mixtral/model.py index 3c2d7b4e8..76b18667d 100644 --- a/lightllm/models/mixtral/model.py +++ b/lightllm/models/mixtral/model.py @@ -2,6 +2,7 @@ import numpy as np from lightllm.models.registry import ModelRegistry from lightllm.common.basemodel.basemodel import TpPartBaseModel +from lightllm.common.basemodel.routing_manager import init_routing_capture from lightllm.common.kv_cache_mem_manager import MemoryManager from lightllm.models.llama.layer_infer.post_layer_infer import LlamaPostLayerInfer from lightllm.models.llama.layer_infer.pre_layer_infer import LlamaPreLayerInfer @@ -45,6 +46,7 @@ def _verify_params(self): def _init_custom(self): self._init_to_get_rotary() + init_routing_capture(self) return def _init_mem_manager(self): diff --git 
a/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py index 71b16cb34..43d15cc54 100644 --- a/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py @@ -128,6 +128,7 @@ def _moe_ffn( use_grouped_topk=False, topk_group=None, num_expert_group=None, + microbatch_index=infer_state.microbatch_index, ) return hidden_states.view(num_tokens, hidden_dim) @@ -148,6 +149,7 @@ def _moe_ffn_edp( topk_group=None, num_expert_group=None, is_prefill=infer_state.is_prefill, + microbatch_index=infer_state.microbatch_index, ) ep_output = ep_output.view(token_num, hidden_dim) diff --git a/lightllm/models/qwen3_moe/layer_weights/transformer_layer_weight.py b/lightllm/models/qwen3_moe/layer_weights/transformer_layer_weight.py index a889609d7..dd1cc6112 100644 --- a/lightllm/models/qwen3_moe/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/qwen3_moe/layer_weights/transformer_layer_weight.py @@ -1,4 +1,3 @@ -import os from lightllm.models.qwen3.layer_weights.transformer_layer_weight import Qwen3TransformerLayerWeight from lightllm.common.basemodel.layer_weights.meta_weights import ROWMMWeight, FusedMoeWeight diff --git a/lightllm/models/qwen3_moe/model.py b/lightllm/models/qwen3_moe/model.py index 10a505127..1dabc38a9 100644 --- a/lightllm/models/qwen3_moe/model.py +++ b/lightllm/models/qwen3_moe/model.py @@ -4,6 +4,7 @@ from lightllm.models.qwen3_moe.layer_infer.transformer_layer_infer import Qwen3MOETransformerLayerInfer from lightllm.models.qwen3_moe.layer_weights.transformer_layer_weight import Qwen3MOETransformerLayerWeight from lightllm.models.qwen3.model import Qwen3TpPartModel +from lightllm.common.basemodel.routing_manager import init_routing_capture from lightllm.utils.log_utils import init_logger from lightllm.distributed.communication_op import dist_group_manager @@ -26,3 +27,4 @@ def __init__(self, kvargs): def _init_custom(self): super()._init_custom() dist_group_manager.new_deepep_group(self.config["num_experts"], self.config["hidden_size"]) + init_routing_capture(self) diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py index e49b0cc67..fb2f0094d 100644 --- a/lightllm/server/api_cli.py +++ b/lightllm/server/api_cli.py @@ -629,4 +629,10 @@ def make_argument_parser() -> argparse.ArgumentParser: If the op is not implemented for the platform and the hardware support triton, it will use triton implementation.""", ) + parser.add_argument( + "--enable_return_routed_experts", + action="store_true", + default=False, + help="Enable returning routed expert indices for MoE models (R3 feature).", + ) return parser diff --git a/lightllm/server/api_lightllm.py b/lightllm/server/api_lightllm.py index d3592a5f5..5abd90815 100644 --- a/lightllm/server/api_lightllm.py +++ b/lightllm/server/api_lightllm.py @@ -53,6 +53,7 @@ async def lightllm_generate(request: Request, httpserver_manager: HttpServerMana prompt_token_ids = None is_first_metadata = True input_usage = None + routed_experts_data = None async for sub_req_id, request_output, metadata, finish_status in results_generator: # when set "--return_all_prompt_logprobs", the first token metadata will contains # prompt_logprobs and prompt_token_ids @@ -78,6 +79,8 @@ async def lightllm_generate(request: Request, httpserver_manager: HttpServerMana if finish_status.is_finished(): finish_reason_dict[sub_req_id] = finish_status + if "routed_experts" in metadata: + 
routed_experts_data = metadata["routed_experts"] n = sampling_params.n sub_ids = list(final_output_dict.keys())[:n] final_output_list = ["".join(final_output_dict[sub_id]) for sub_id in sub_ids] @@ -102,6 +105,8 @@ async def lightllm_generate(request: Request, httpserver_manager: HttpServerMana ret["prompt_logprobs"] = prompt_logprobs if input_usage is not None: ret["input_usage"] = input_usage + if routed_experts_data is not None: + ret["routed_experts"] = routed_experts_data return Response(content=json.dumps(ret, ensure_ascii=False).encode("utf-8")) diff --git a/lightllm/server/core/objs/req.py b/lightllm/server/core/objs/req.py index f489aac9c..128423e6e 100644 --- a/lightllm/server/core/objs/req.py +++ b/lightllm/server/core/objs/req.py @@ -122,6 +122,9 @@ class Req(ctypes.Structure): ("cpu_cache_match_page_indexes", CpuCachePageList), # 分块hash的块大小 ("cpu_cache_token_page_size", ctypes.c_int), + ("routing_data_num_moe_layers", ctypes.c_int), + ("routing_data_num_tokens", ctypes.c_int), + ("routing_data_topk", ctypes.c_int), ] def get_str(self): @@ -180,6 +183,10 @@ def init( self.stop_str_matched = False self.stop_str_matched_token_index = -1 + self.routing_data_num_moe_layers = 0 + self.routing_data_num_tokens = 0 + self.routing_data_topk = 0 + self.post_init() self.cpu_cache_token_page_size = get_env_start_args().cpu_cache_token_page_size @@ -227,6 +234,40 @@ def link_logprobs_shm_array(self): self.shm_logprobs.link_shm() return + def create_routing_data_shm_array(self, num_moe_layers: int, num_tokens: int, topk: int): + service_uni_name = get_unique_server_name() + name = f"{service_uni_name}_shm_routing_{self.index_in_shm_mem}" + shape = (num_moe_layers, num_tokens, topk) + self.shm_routing_data = ShmArray(name, shape, dtype=np.int32) + self.shm_routing_data.create_shm() + self.routing_data_num_moe_layers = num_moe_layers + self.routing_data_num_tokens = num_tokens + self.routing_data_topk = topk + return + + def link_routing_data_shm_array(self): + if self.routing_data_num_moe_layers == 0: + return + service_uni_name = get_unique_server_name() + name = f"{service_uni_name}_shm_routing_{self.index_in_shm_mem}" + shape = (self.routing_data_num_moe_layers, self.routing_data_num_tokens, self.routing_data_topk) + self.shm_routing_data = ShmArray(name, shape, dtype=np.int32) + self.shm_routing_data.link_shm() + return + + def get_routing_data(self): + if self.routing_data_num_moe_layers == 0 or not hasattr(self, "shm_routing_data"): + return None + if self.shm_routing_data is None: + return None + return self.shm_routing_data.arr + + def close_routing_data_shm_array(self): + if hasattr(self, "shm_routing_data") and self.shm_routing_data is not None: + self.shm_routing_data.close_shm() + self.shm_routing_data = None + return + def get_prompt_ids(self): return self.shm_prompt_ids.arr[: self.input_len].tolist() diff --git a/lightllm/server/core/objs/sampling_params.py b/lightllm/server/core/objs/sampling_params.py index d955aa6a8..7d65cd233 100644 --- a/lightllm/server/core/objs/sampling_params.py +++ b/lightllm/server/core/objs/sampling_params.py @@ -497,6 +497,7 @@ def to_dict(self): "add_spaces_between_special_tokens": self.add_spaces_between_special_tokens, "print_eos_token": self.print_eos_token, "disable_prompt_cache": self.disable_prompt_cache, + "return_routed_experts": self.return_routed_experts, } def to_origin_dict(self): diff --git a/lightllm/server/httpserver/manager.py b/lightllm/server/httpserver/manager.py index 212e037e9..bd92a6b7d 100644 --- 
a/lightllm/server/httpserver/manager.py +++ b/lightllm/server/httpserver/manager.py @@ -10,6 +10,7 @@ import hashlib import datetime import pickle +import base64 from frozendict import frozendict asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) @@ -686,6 +687,11 @@ async def recycle_resource_loop(self): for req_status in release_req_status: self.req_id_to_out_inf.pop(req_status.group_req_objs.group_req_id, None) for req in req_status.group_req_objs.shm_req_objs: + if hasattr(req, "shm_routing_data") and req.shm_routing_data is not None: + try: + req.close_routing_data_shm_array() + except Exception as e: + logger.debug(f"Failed to close routing data shm for req {req.request_id}: {e}") await self.shm_req_manager.async_put_back_req_obj(req) await self.shm_req_manager.async_release_req_index(req.index_in_shm_mem) await self._release_multimodal_resources(req_status.group_req_objs.multimodal_params) @@ -773,6 +779,20 @@ async def handle_loop(self): else: finish_status = FinishStatus(req.finish_status.status) + if req.sample_params.return_routed_experts and req.routing_data_num_moe_layers > 0: + try: + req.link_routing_data_shm_array() + routing_data = req.get_routing_data() + if routing_data is not None: + metadata["routed_experts"] = { + "shape": list(routing_data.shape), + "dtype": str(routing_data.dtype), + "data": base64.b64encode(routing_data.tobytes()).decode("ascii"), + } + req.close_routing_data_shm_array() + except Exception as e: + logger.warning(f"Failed to read routing data for req {req_id}: {e}") + token_list.append((req_id, text, metadata, finish_status)) else: break diff --git a/lightllm/server/router/model_infer/infer_batch.py b/lightllm/server/router/model_infer/infer_batch.py index 4b8b3c538..276f97c18 100644 --- a/lightllm/server/router/model_infer/infer_batch.py +++ b/lightllm/server/router/model_infer/infer_batch.py @@ -19,6 +19,7 @@ from lightllm.utils.envs_utils import get_env_start_args from lightllm.server.pd_io_struct import NIXLDecodeNodeInfo from lightllm.server.embed_cache.embed_cache_client import CpuEmbedCacheClient +from lightllm.common.basemodel.routing_manager import g_routing_capture_manager logger = init_logger(__name__) @@ -113,6 +114,17 @@ def add_reqs(self, requests: List[Tuple[int, int, Any, int]], init_prefix_cache: return req_objs + def _extract_routing_data(self, req: "InferReq"): + mem_indexes = self.req_manager.req_to_token_indexs[req.req_idx][0 : req.cur_kv_len] + num_moe_layers = g_routing_capture_manager.num_moe_layers + topk = g_routing_capture_manager.topk + num_tokens = req.cur_kv_len + logger.debug(f"R3: Extracting routing for req {req.req_id}: {num_moe_layers}x{num_tokens}x{topk}") + routing_data = g_routing_capture_manager.extract_for_request(mem_indexes.cpu()) + req.shm_req.create_routing_data_shm_array(num_moe_layers, num_tokens, topk) + req.shm_req.shm_routing_data.arr[:] = routing_data + logger.debug(f"R3: Successfully extracted routing data for req {req.req_id}") + def free_a_req_mem(self, free_token_index: List, req: "InferReq"): if self.radix_cache is None: free_token_index.append(self.req_manager.req_to_token_indexs[req.req_idx][0 : req.cur_kv_len]) @@ -155,6 +167,9 @@ def _filter(self, finished_request_ids: List[int]): req: InferReq = self.requests_mapping.pop(request_id) if self.args.diverse_mode: req.clear_master_slave_state() + + self._extract_routing_data(req) + self.free_a_req_mem(free_token_index, req) free_req_index.append(req.req_idx) @@ -580,6 +595,8 @@ def handle( shm_req.shm_cur_output_len = self.output_len if 
finish_status.is_finished(): + # Extract routing data before setting finish_status so HTTP server sees it + g_infer_context._extract_routing_data(req_obj) shm_req.finish_token_index = shm_req.input_len + self.output_len - 1 shm_req.finish_status = req_obj.finish_status diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py index 805c9b8e5..bab305a53 100644 --- a/lightllm/server/router/model_infer/mode_backend/base_backend.py +++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py @@ -40,6 +40,7 @@ from lightllm.server.router.model_infer.mode_backend.generic_post_process import sample from lightllm.common.basemodel.triton_kernel.gather_token_id import scatter_token from lightllm.server.pd_io_struct import NIXLChunckedTransTaskRet +from lightllm.common.basemodel.routing_manager import flush_routing_capture from .multi_level_kv_cache import MultiLevelKvCacheModule @@ -794,6 +795,9 @@ def _sample_and_scatter_token( ) return next_token_ids, next_token_ids_cpu, next_token_logprobs_cpu + def _flush_routing_after_sample(self, mem_indexes: torch.Tensor, microbatch_index: int = 0) -> None: + flush_routing_capture(mem_indexes, microbatch_index) + def _dp_all_gather_prefill_and_decode_req_num( self, prefill_reqs: List[InferReq], decode_reqs: List[InferReq] ) -> Tuple[np.ndarray, np.ndarray]: diff --git a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py index a8a5224eb..969964a2a 100644 --- a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py +++ b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py @@ -118,6 +118,7 @@ def prefill_normal( b_prefill_has_output_cpu=model_input.b_prefill_has_output_cpu, mask_func=self.prefill_mask_func, ) + self._flush_routing_after_sample(model_input.mem_indexes) sync_event = torch.cuda.Event() sync_event.record() @@ -156,6 +157,7 @@ def decode_normal( is_prefill=False, mask_func=self.decode_mask_func, ) + self._flush_routing_after_sample(model_input.mem_indexes) sync_event = torch.cuda.Event() sync_event.record() @@ -195,6 +197,7 @@ def prefill_mtp( b_prefill_has_output_cpu=model_input.b_prefill_has_output_cpu, mask_func=self.prefill_mask_func, ) + self._flush_routing_after_sample(model_input.mem_indexes) # mtp kv fill self._draft_prefill_forward( model_input=model_input, model_output=model_output, next_token_ids=next_token_ids @@ -279,6 +282,7 @@ def decode_mtp( next_token_ids=next_token_ids, mask=accepted_index == 1, ) + self._flush_routing_after_sample(model_input.mem_indexes) sync_event = torch.cuda.Event() sync_event.record() diff --git a/lightllm/server/router/model_infer/mode_backend/diverse_backend/impl.py b/lightllm/server/router/model_infer/mode_backend/diverse_backend/impl.py index 5a179cb62..1d93ea116 100644 --- a/lightllm/server/router/model_infer/mode_backend/diverse_backend/impl.py +++ b/lightllm/server/router/model_infer/mode_backend/diverse_backend/impl.py @@ -77,6 +77,7 @@ def beam_prefill(self, event_pack: OverlapEventPack, prefill_reqs: List[InferReq next_token_ids=next_token_ids, next_token_logprobs=next_token_logprobs ) + self._flush_routing_after_sample(model_input.mem_indexes) sync_event = torch.cuda.Event() sync_event.record() diff --git a/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py b/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py index bb0e848e7..877264e6e 100644 --- 
a/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py +++ b/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py @@ -155,6 +155,7 @@ def prefill_normal( b_prefill_has_output_cpu=model_input.b_prefill_has_output_cpu[:run_reqs_num], mask_func=None, ) + self._flush_routing_after_sample(model_input.mem_indexes) sync_event = torch.cuda.Event() sync_event.record() @@ -197,6 +198,7 @@ def decode_normal(self, event_pack: OverlapEventPack, decode_reqs: List[InferReq is_prefill=False, mask_func=None, ) + self._flush_routing_after_sample(model_input.mem_indexes) sync_event = torch.cuda.Event() sync_event.record() @@ -263,6 +265,8 @@ def prefill_overlap(self, event_pack: OverlapEventPack, prefill_reqs: List[Infer b_prefill_has_output_cpu=b_has_out_cpu, mask_func=None, ) + self._flush_routing_after_sample(model_input0.mem_indexes, microbatch_index=0) + self._flush_routing_after_sample(model_input1.mem_indexes, microbatch_index=1) sync_event = torch.cuda.Event() sync_event.record() @@ -326,6 +330,8 @@ def decode_overlap(self, event_pack: OverlapEventPack, decode_reqs: List[InferRe is_prefill=False, mask_func=None, ) + self._flush_routing_after_sample(model_input0.mem_indexes, microbatch_index=0) + self._flush_routing_after_sample(model_input1.mem_indexes, microbatch_index=1) sync_event = torch.cuda.Event() sync_event.record() @@ -374,6 +380,7 @@ def prefill_mtp(self, event_pack: OverlapEventPack, prefill_reqs: List[InferReq] b_prefill_has_output_cpu=b_has_out_cpu, mask_func=None, ) + self._flush_routing_after_sample(model_input.mem_indexes) # mtp kv fill draft_next_token_ids_gpu = torch.zeros((model_input.batch_size), dtype=torch.int64, device="cuda") @@ -471,6 +478,7 @@ def decode_mtp(self, event_pack: OverlapEventPack, decode_reqs: List[InferReq]): next_token_ids=next_token_ids, mask=accepted_index == 1, ) + self._flush_routing_after_sample(model_input.mem_indexes) sync_event = torch.cuda.Event() sync_event.record() @@ -652,6 +660,8 @@ def prefill_overlap_mtp(self, event_pack: OverlapEventPack, prefill_reqs: List[I is_prefill=True, b_prefill_has_output_cpu=b_has_out_cpu, ) + self._flush_routing_after_sample(model_input0.mem_indexes, microbatch_index=0) + self._flush_routing_after_sample(model_input1.mem_indexes, microbatch_index=1) # spec prefill: MTP draft_model_input0, draft_model_input1 = model_input0, model_input1 @@ -789,6 +799,8 @@ def decode_overlap_mtp(self, event_pack: OverlapEventPack, decode_reqs: List[Inf next_token_ids=next_token_ids, mask=accepted_index == 1, ) + self._flush_routing_after_sample(model_input0.mem_indexes, microbatch_index=0) + self._flush_routing_after_sample(model_input1.mem_indexes, microbatch_index=1) sync_event = torch.cuda.Event() sync_event.record() diff --git a/scripts/run_e2e_r3_test.sh b/scripts/run_e2e_r3_test.sh new file mode 100755 index 000000000..3100f18a7 --- /dev/null +++ b/scripts/run_e2e_r3_test.sh @@ -0,0 +1,109 @@ +#!/bin/bash +# E2E Test Script for R3 Routing Capture Feature +# +# This script starts a LightLLM server with routing capture enabled, +# runs the client test, and verifies the results. +# +# Requirements: +# - A MoE model (DeepSeek-V2/V3, Qwen-MoE, Mixtral, etc.) 
+# - At least 1 GPU with sufficient memory
+# - LightLLM installed
+#
+# Usage:
+#   ./scripts/run_e2e_r3_test.sh /path/to/moe/model [--tp N]
+
+set -e
+
+MODEL_DIR="${1:-}"
+if [ "$2" = "--tp" ]; then TP="${3:-1}"; else TP="${2:-1}"; fi
+PORT=8765
+
+if [ -z "$MODEL_DIR" ]; then
+    echo "Usage: $0 /path/to/moe/model [--tp N]"
+    echo ""
+    echo "Example:"
+    echo "  $0 /models/DeepSeek-V3 --tp 8"
+    echo "  $0 /models/Qwen-MoE-A14B --tp 4"
+    exit 1
+fi
+
+if [ ! -d "$MODEL_DIR" ]; then
+    echo "ERROR: Model directory not found: $MODEL_DIR"
+    exit 1
+fi
+
+echo "=========================================="
+echo "R3 E2E Test: Routing Capture Feature"
+echo "=========================================="
+echo "Model: $MODEL_DIR"
+echo "TP: $TP"
+echo "Port: $PORT"
+echo ""
+
+# Kill any existing server on the port
+pkill -f "lightllm.server.api_server.*--port $PORT" 2>/dev/null || true
+sleep 2
+
+# Start server in background
+echo "Starting LightLLM server..."
+python -m lightllm.server.api_server \
+    --model_dir "$MODEL_DIR" \
+    --tp "$TP" \
+    --port "$PORT" \
+    --enable_return_routed_experts \
+    --max_total_token_num 8000 \
+    --batch_max_tokens 4000 \
+    > /tmp/lightllm_r3_test.log 2>&1 &
+
+SERVER_PID=$!
+echo "Server PID: $SERVER_PID"
+echo "Log: /tmp/lightllm_r3_test.log"
+
+# Wait for server to be ready
+echo "Waiting for server to be ready..."
+MAX_WAIT=300
+WAITED=0
+while [ $WAITED -lt $MAX_WAIT ]; do
+    if curl -s "http://localhost:$PORT/health" > /dev/null 2>&1; then
+        echo "Server is ready!"
+        break
+    fi
+    sleep 5
+    WAITED=$((WAITED + 5))
+    echo "  Waited ${WAITED}s..."
+done
+
+if [ $WAITED -ge $MAX_WAIT ]; then
+    echo "ERROR: Server failed to start within ${MAX_WAIT}s"
+    echo "Server log:"
+    tail -50 /tmp/lightllm_r3_test.log
+    kill $SERVER_PID 2>/dev/null || true
+    exit 1
+fi
+
+# Run client test
+echo ""
+echo "Running R3 client test..."
+echo "=========================================="
+python test_r3.py --url "http://localhost:$PORT"
+TEST_RESULT=$?
+
+# Cleanup
+echo ""
+echo "Stopping server..."
+kill $SERVER_PID 2>/dev/null || true
+wait $SERVER_PID 2>/dev/null || true
+
+# Report result
+echo ""
+echo "=========================================="
+if [ $TEST_RESULT -eq 0 ]; then
+    echo "E2E TEST PASSED!"
+else
+    echo "E2E TEST FAILED!"
+    echo "Server log (last 30 lines):"
+    tail -30 /tmp/lightllm_r3_test.log
+fi
+echo "=========================================="
+
+exit $TEST_RESULT
diff --git a/test_r3.py b/test_r3.py
new file mode 100644
index 000000000..14157fab5
--- /dev/null
+++ b/test_r3.py
@@ -0,0 +1,99 @@
+"""
+R3 Client Test: Tests the routing capture export feature.
+ +This test requires a running LightLLM server with: +- A MoE model (e.g., DeepSeek-V2/V3) +- --enable_return_routed_experts flag + +Usage: + python test_r3.py [--url URL] +""" +import sys +import argparse +import requests +import base64 +import numpy as np + + +def test_routing_export(url: str = "http://localhost:8000"): + """Test the routing export feature.""" + print(f"Testing routing export at {url}") + print("-" * 50) + + try: + response = requests.post( + f"{url}/generate", + json={ + "inputs": "What is the capital of France?", + "parameters": { + "max_new_tokens": 50, + "return_routed_experts": True, + }, + }, + timeout=60, + ) + except requests.exceptions.ConnectionError: + print(f"ERROR: Cannot connect to server at {url}") + print("Make sure the LightLLM server is running with --enable_return_routed_experts") + return False + except requests.exceptions.Timeout: + print("ERROR: Request timed out") + return False + + print(f"Status: {response.status_code}") + + if response.status_code != 200: + print(f"ERROR: Request failed with status {response.status_code}") + print(f"Response: {response.text}") + return False + + res = response.json() + print(f"Generated text: {res.get('generated_text', 'N/A')[:100]}...") + + # Check for routed_experts in response + if "routed_experts" not in res or not res["routed_experts"]: + print("\nWARNING: No routed_experts in response.") + print("This could mean:") + print(" - The model is not a MoE model") + print(" - The server was not started with --enable_return_routed_experts") + print(" - The routing capture manager was not initialized") + return False + + # Decode routed_experts from base64 + routing_info = res["routed_experts"] + shape = routing_info["shape"] + dtype = np.dtype(routing_info["dtype"]) + data = base64.b64decode(routing_info["data"]) + routing_array = np.frombuffer(data, dtype=dtype).reshape(shape) + + print(f"\n{'=' * 50}") + print("ROUTING CAPTURE SUCCESS!") + print(f"{'=' * 50}") + print(f"Shape: {shape} # [num_moe_layers, num_tokens, topk]") + print(f"Dtype: {dtype}") + print(f"Num MoE layers: {shape[0]}") + print(f"Num tokens: {shape[1]}") + print(f"Top-K: {shape[2]}") + + # Show sample of routing data + print(f"\nSample routing (first layer, first 5 tokens):") + num_tokens_to_show = min(5, shape[1]) + for i in range(num_tokens_to_show): + print(f" Token {i}: experts {routing_array[0, i, :].tolist()}") + + # Validate data + if np.all(routing_array == 0): + print("\nWARNING: All routing data is zeros. 
Capture may not be working correctly.") + return False + + print("\nTest PASSED!") + return True + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Test R3 routing export feature") + parser.add_argument("--url", default="http://localhost:8000", help="Server URL") + args = parser.parse_args() + + success = test_routing_export(args.url) + sys.exit(0 if success else 1) diff --git a/unit_tests/__init__.py b/unit_tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/unit_tests/common/__init__.py b/unit_tests/common/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/unit_tests/common/basemodel/__init__.py b/unit_tests/common/basemodel/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/unit_tests/common/basemodel/test_routing_capture_manager.py b/unit_tests/common/basemodel/test_routing_capture_manager.py new file mode 100644 index 000000000..8d5a1d0bf --- /dev/null +++ b/unit_tests/common/basemodel/test_routing_capture_manager.py @@ -0,0 +1,132 @@ +import pytest +import torch +import numpy as np + + +def test_moe_layer_counter(): + """Counter increments and resets correctly.""" + from lightllm.common.basemodel.routing_manager import ( + reset_moe_layer_counter, + get_next_moe_layer_index, + get_moe_layer_count, + ) + + reset_moe_layer_counter() + assert get_moe_layer_count() == 0 + + assert get_next_moe_layer_index() == 0 + assert get_next_moe_layer_index() == 1 + assert get_next_moe_layer_index() == 2 + assert get_moe_layer_count() == 3 + + reset_moe_layer_counter() + assert get_moe_layer_count() == 0 + assert get_next_moe_layer_index() == 0 + + +class TestRoutingCaptureManager: + """Tests for the redesigned RoutingCaptureManager.""" + + def test_capture_explicit_layer_index(self): + """Capture stores data at explicit moe_layer_index.""" + from lightllm.common.basemodel.routing_manager import RoutingCaptureManager + + manager = RoutingCaptureManager( + num_moe_layers=4, + topk=8, + num_experts=64, + batch_max_tokens=128, + kv_cache_size=1024, + enable_overlap=False, + ) + + # Capture at layer 2 (not sequential) + topk_ids = torch.randint(0, 64, (10, 8), device="cuda") + manager.capture(moe_layer_index=2, topk_ids=topk_ids) + + # Verify data is at layer 2, not layer 0 + assert torch.equal(manager.gpu_buffer[0, 2, :10, :], topk_ids.to(manager.dtype)) + + def test_double_buffer_overlap_mode(self): + """Double buffer prevents race condition in overlap mode.""" + from lightllm.common.basemodel.routing_manager import RoutingCaptureManager + + manager = RoutingCaptureManager( + num_moe_layers=2, + topk=4, + num_experts=32, + batch_max_tokens=64, + kv_cache_size=256, + enable_overlap=True, + ) + + # Should have 2 buffer slots + assert manager.num_slots == 2 + assert manager.gpu_buffer.shape[0] == 2 + + # Capture to slot 0 (microbatch_index=0) + ids_0 = torch.ones((5, 4), dtype=torch.int64, device="cuda") + manager.capture(moe_layer_index=0, topk_ids=ids_0, microbatch_index=0) + + # Capture to slot 1 (microbatch_index=1) + ids_1 = torch.ones((5, 4), dtype=torch.int64, device="cuda") * 2 + manager.capture(moe_layer_index=0, topk_ids=ids_1, microbatch_index=1) + + # Both slots have different data + assert manager.gpu_buffer[0, 0, 0, 0].item() == 1 + assert manager.gpu_buffer[1, 0, 0, 0].item() == 2 + + def test_flush_and_extract(self): + """Flush transfers data to CPU, extract retrieves by mem_index.""" + from lightllm.common.basemodel.routing_manager import RoutingCaptureManager + + manager = 
RoutingCaptureManager( + num_moe_layers=2, + topk=4, + num_experts=32, + batch_max_tokens=64, + kv_cache_size=256, + enable_overlap=False, + ) + + # Capture some data (microbatch_index defaults to 0) + topk_ids = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]], device="cuda") + manager.capture(moe_layer_index=0, topk_ids=topk_ids) + manager.capture(moe_layer_index=1, topk_ids=topk_ids + 10) + + # Flush to mem_indexes 10 and 11 + mem_indexes = torch.tensor([10, 11], device="cuda") + manager.flush_to_cpu_async(mem_indexes, microbatch_index=0) + + # Extract + result = manager.extract_for_request(mem_indexes.cpu()) + + assert result.shape == (2, 2, 4) # [layers, tokens, topk] + assert result[0, 0, 0] == 1 + assert result[1, 0, 0] == 11 + + def test_dtype_selection(self): + """Uses int8 for <=127 experts, int16 otherwise.""" + from lightllm.common.basemodel.routing_manager import RoutingCaptureManager + + # Small expert count -> int8 + manager_small = RoutingCaptureManager( + num_moe_layers=1, + topk=2, + num_experts=64, + batch_max_tokens=32, + kv_cache_size=128, + enable_overlap=False, + ) + assert manager_small.dtype == torch.int8 + + # Large expert count -> int16 + manager_large = RoutingCaptureManager( + num_moe_layers=1, + topk=2, + num_experts=256, + batch_max_tokens=32, + kv_cache_size=128, + enable_overlap=False, + ) + assert manager_large.dtype == torch.int16 From 611b216f3201b16e4557cb1811cc55821fb00559 Mon Sep 17 00:00:00 2001 From: sufubao Date: Tue, 27 Jan 2026 12:31:25 +0000 Subject: [PATCH 2/3] clean --- .../fused_moe/fused_moe_weight.py | 3 +- lightllm/common/basemodel/routing_manager.py | 84 ++++++++++++++----- lightllm/common/quantization/w8a8.py | 2 +- lightllm/server/core/objs/req.py | 45 ++++++---- lightllm/server/core/objs/sampling_params.py | 1 - lightllm/server/core/objs/start_args_type.py | 2 + lightllm/server/httpserver/manager.py | 24 +++--- .../server/router/model_infer/infer_batch.py | 25 ++++-- 8 files changed, 128 insertions(+), 58 deletions(-) diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py index ada943f6c..2e723de44 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py @@ -13,7 +13,7 @@ from lightllm.utils.envs_utils import get_redundancy_expert_ids, get_redundancy_expert_num, get_env_start_args from lightllm.utils.dist_utils import get_global_world_size, get_global_rank from lightllm.utils.log_utils import init_logger -from lightllm.common.basemodel.routing_manager import g_routing_capture_manager, get_next_moe_layer_index +from lightllm.common.basemodel.routing_manager import get_next_moe_layer_index logger = init_logger(__name__) @@ -105,7 +105,6 @@ def _init_parallel_params(self): f"redundancy_expertids: {self.redundancy_expert_ids}" ) self.local_n_routed_experts = self.n_routed_experts // self.global_world_size + self.redundancy_expert_num - self.split_inter_size = self.moe_intermediate_size n_experts_per_rank = self.n_routed_experts // self.global_world_size start_expert_id = self.global_rank_ * n_experts_per_rank self.local_expert_ids = ( diff --git a/lightllm/common/basemodel/routing_manager.py b/lightllm/common/basemodel/routing_manager.py index 4a76dca95..2bc55d3be 100644 --- a/lightllm/common/basemodel/routing_manager.py +++ b/lightllm/common/basemodel/routing_manager.py @@ -3,9 +3,52 @@ from 
typing import Optional from lightllm.utils.log_utils import init_logger from lightllm.utils.dist_utils import get_current_rank_in_dp +from lightllm.server.router.dynamic_prompt.shared_arr import SharedArray +from lightllm.utils.envs_utils import get_unique_server_name logger = init_logger(__name__) + +class SharedRoutingConfig: + """Shared MoE routing configuration across processes.""" + + def __init__(self): + service_name = get_unique_server_name() + # Shape: [num_moe_layers, topk] + self._shm = SharedArray(f"{service_name}_routing_config", shape=(2,), dtype=np.int32) + + @property + def num_moe_layers(self) -> int: + return int(self._shm.arr[0]) + + @num_moe_layers.setter + def num_moe_layers(self, value: int): + self._shm.arr[0] = value + + @property + def topk(self) -> int: + return int(self._shm.arr[1]) + + @topk.setter + def topk(self, value: int): + self._shm.arr[1] = value + + def is_initialized(self) -> bool: + return self.num_moe_layers > 0 and self.topk > 0 + + +# Global shared routing config (lazy initialized) +_shared_routing_config: Optional[SharedRoutingConfig] = None + + +def get_shared_routing_config() -> SharedRoutingConfig: + """Get or create the shared routing config.""" + global _shared_routing_config + if _shared_routing_config is None: + _shared_routing_config = SharedRoutingConfig() + return _shared_routing_config + + # MoE layer counter for auto-incrementing moe_layer_index _moe_layer_counter: int = 0 @@ -75,12 +118,8 @@ def __init__( ) def capture(self, moe_layer_index: int, topk_ids: torch.Tensor, microbatch_index: int = 0) -> None: - assert ( - 0 <= moe_layer_index < self.num_moe_layers - ), f"moe_layer_index {moe_layer_index} out of range [0, {self.num_moe_layers})" - slot = microbatch_index % self.num_slots num_tokens = topk_ids.shape[0] - self.gpu_buffer[slot, moe_layer_index, :num_tokens, :] = topk_ids.to(self.dtype) + self.gpu_buffer[microbatch_index, moe_layer_index, :num_tokens, :] = topk_ids.to(self.dtype) def flush_to_cpu_async(self, mem_indexes: torch.Tensor, microbatch_index: int) -> None: num_tokens = mem_indexes.shape[0] @@ -98,9 +137,20 @@ def flush_to_cpu_async(self, mem_indexes: torch.Tensor, microbatch_index: int) - self.cpu_buffer[:, cpu_indexes, :] = self.gpu_buffer[slot, :, :num_tokens, :].cpu() event.record() - def extract_for_request(self, mem_indexes: torch.Tensor) -> np.ndarray: + def sync_events(self) -> None: + """Synchronize all flush events. Call once before batch extraction.""" for event in self.flush_events: event.synchronize() + + def extract_for_request(self, mem_indexes: torch.Tensor) -> np.ndarray: + self.sync_events() + return self.cpu_buffer[:, mem_indexes, :].numpy() + + def extract_for_request_no_sync(self, mem_indexes: torch.Tensor) -> np.ndarray: + """Extract routing data without synchronizing events. + + Call sync_events() once before using this method in a batch. 
+ """ return self.cpu_buffer[:, mem_indexes, :].numpy() @@ -132,8 +182,6 @@ def init_routing_capture(model) -> None: return # Only create routing capture manager on rank 0 - # Routing decisions are identical across all TP ranks, so we only need to capture on rank 0 - # which is the rank that communicates results back to the Router/HTTP server if get_current_rank_in_dp() != 0: logger.info("Skipping routing capture initialization on non-zero rank") return @@ -145,16 +193,9 @@ def init_routing_capture(model) -> None: ) return - n_routed_experts = model.config.get("n_routed_experts", model.config.get("num_experts", 0)) - if n_routed_experts == 0: - logger.warning( - "enable_return_routed_experts is set but n_routed_experts=0. " "Routing capture will not be enabled." - ) - return - - topk = model.config.get("num_experts_per_tok", 1) - num_experts = n_routed_experts - + num_experts = model.config.get("n_routed_experts", model.config.get("num_experts", 0)) + topk = model.config.get("num_experts_per_tok", 0) + assert num_experts > 0 and topk > 0 enable_overlap = getattr(model.args, "enable_decode_microbatch_overlap", False) logger.info( @@ -167,11 +208,16 @@ def init_routing_capture(model) -> None: topk=topk, num_experts=num_experts, batch_max_tokens=model.max_total_token_num, - # Add 1 to handle potential edge case where mem_index == size kv_cache_size=model.mem_manager.size + 1, enable_overlap=enable_overlap, ) + # Set shared routing config for cross-process access + shared_config = get_shared_routing_config() + shared_config.num_moe_layers = num_moe_layers + shared_config.topk = topk + logger.info(f"Shared routing config set: num_moe_layers={num_moe_layers}, topk={topk}") + def flush_routing_capture(mem_indexes: torch.Tensor, microbatch_index: int = 0) -> None: if g_routing_capture_manager is not None: diff --git a/lightllm/common/quantization/w8a8.py b/lightllm/common/quantization/w8a8.py index 3f4fec319..98626e1d3 100644 --- a/lightllm/common/quantization/w8a8.py +++ b/lightllm/common/quantization/w8a8.py @@ -72,7 +72,7 @@ def quantize(self, weight: torch.Tensor, output: WeightPack) -> None: weight = weight.float().cuda(self.device_id_) scale = weight.abs().max(dim=-1)[0] / 127 weight = weight / scale.reshape(-1, 1) - weight = torch.round(weight.clamp(min=-128, max=127)).to(dtype=torch.int8) + weight = torch.round(weight.clamp(min=-127, max=127)).to(dtype=torch.int8) output.weight.copy_(weight) output.weight_scale.copy_(scale) return diff --git a/lightllm/server/core/objs/req.py b/lightllm/server/core/objs/req.py index 128423e6e..442968594 100644 --- a/lightllm/server/core/objs/req.py +++ b/lightllm/server/core/objs/req.py @@ -1,6 +1,7 @@ import os import math import ctypes +import base64 import numpy as np import time from .sampling_params import SamplingParams @@ -122,9 +123,6 @@ class Req(ctypes.Structure): ("cpu_cache_match_page_indexes", CpuCachePageList), # 分块hash的块大小 ("cpu_cache_token_page_size", ctypes.c_int), - ("routing_data_num_moe_layers", ctypes.c_int), - ("routing_data_num_tokens", ctypes.c_int), - ("routing_data_topk", ctypes.c_int), ] def get_str(self): @@ -183,10 +181,6 @@ def init( self.stop_str_matched = False self.stop_str_matched_token_index = -1 - self.routing_data_num_moe_layers = 0 - self.routing_data_num_tokens = 0 - self.routing_data_topk = 0 - self.post_init() self.cpu_cache_token_page_size = get_env_start_args().cpu_cache_token_page_size @@ -240,25 +234,21 @@ def create_routing_data_shm_array(self, num_moe_layers: int, num_tokens: int, to shape = (num_moe_layers, 
num_tokens, topk) self.shm_routing_data = ShmArray(name, shape, dtype=np.int32) self.shm_routing_data.create_shm() - self.routing_data_num_moe_layers = num_moe_layers - self.routing_data_num_tokens = num_tokens - self.routing_data_topk = topk return - def link_routing_data_shm_array(self): - if self.routing_data_num_moe_layers == 0: + def link_routing_data_shm_array(self, num_moe_layers: int, topk: int): + if num_moe_layers == 0: return service_uni_name = get_unique_server_name() name = f"{service_uni_name}_shm_routing_{self.index_in_shm_mem}" - shape = (self.routing_data_num_moe_layers, self.routing_data_num_tokens, self.routing_data_topk) + # num_tokens equals shm_cur_kv_len at the time of creation + shape = (num_moe_layers, self.shm_cur_kv_len, topk) self.shm_routing_data = ShmArray(name, shape, dtype=np.int32) self.shm_routing_data.link_shm() return def get_routing_data(self): - if self.routing_data_num_moe_layers == 0 or not hasattr(self, "shm_routing_data"): - return None - if self.shm_routing_data is None: + if not hasattr(self, "shm_routing_data") or self.shm_routing_data is None: return None return self.shm_routing_data.arr @@ -268,6 +258,29 @@ def close_routing_data_shm_array(self): self.shm_routing_data = None return + def get_routing_metadata(self, num_moe_layers: int, topk: int): + """Safely extract routing data and format for API response. + + Returns a dict with shape, dtype, and base64-encoded data, or None if unavailable. + """ + if num_moe_layers == 0 or topk == 0: + return None + try: + self.link_routing_data_shm_array(num_moe_layers, topk) + routing_data = self.get_routing_data() + if routing_data is None: + return None + return { + "shape": list(routing_data.shape), + "dtype": str(routing_data.dtype), + "data": base64.b64encode(routing_data.tobytes()).decode("ascii"), + } + except Exception as e: + logger.warning(f"Failed to read routing data for req {self.request_id}: {e}") + return None + finally: + self.close_routing_data_shm_array() + def get_prompt_ids(self): return self.shm_prompt_ids.arr[: self.input_len].tolist() diff --git a/lightllm/server/core/objs/sampling_params.py b/lightllm/server/core/objs/sampling_params.py index 7d65cd233..d955aa6a8 100644 --- a/lightllm/server/core/objs/sampling_params.py +++ b/lightllm/server/core/objs/sampling_params.py @@ -497,7 +497,6 @@ def to_dict(self): "add_spaces_between_special_tokens": self.add_spaces_between_special_tokens, "print_eos_token": self.print_eos_token, "disable_prompt_cache": self.disable_prompt_cache, - "return_routed_experts": self.return_routed_experts, } def to_origin_dict(self): diff --git a/lightllm/server/core/objs/start_args_type.py b/lightllm/server/core/objs/start_args_type.py index 059cd739f..8abf05ccb 100644 --- a/lightllm/server/core/objs/start_args_type.py +++ b/lightllm/server/core/objs/start_args_type.py @@ -159,3 +159,5 @@ class StartArgs: # multi_modal enable_multimodal: bool = field(default=False) enable_multimodal_audio: bool = field(default=False) + + enable_return_routed_experts: bool = field(default=False) diff --git a/lightllm/server/httpserver/manager.py b/lightllm/server/httpserver/manager.py index bd92a6b7d..a0f074b21 100644 --- a/lightllm/server/httpserver/manager.py +++ b/lightllm/server/httpserver/manager.py @@ -10,7 +10,6 @@ import hashlib import datetime import pickle -import base64 from frozendict import frozendict asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) @@ -30,6 +29,7 @@ from lightllm.server.core.objs.shm_req_manager import ShmReqManager from 
lightllm.server.core.objs.atomic_array_lock import AtomicShmArrayLock, AsyncLock, AtomicLockItem from lightllm.server.router.dynamic_prompt.shared_arr import SharedInt +from lightllm.common.basemodel.routing_manager import get_shared_routing_config from lightllm.utils.log_utils import init_logger from lightllm.server.metrics.manager import MetricClient from lightllm.utils.statics_utils import MovingAverage @@ -115,6 +115,9 @@ def __init__( # If the timemark is not updated for a pre-set time, a prob request will be sent to the backend. self.latest_success_infer_time_mark = SharedInt(f"{get_unique_server_name()}_latest_success_infer_time_mark") self.latest_success_infer_time_mark.set_value(int(time.time())) + + # Cache routing config for MoE expert routing data extraction + self._routing_config = get_shared_routing_config() if args.enable_return_routed_experts else None return async def _alloc_resource(self, items, md5sums, token_nums, datas): @@ -779,19 +782,12 @@ async def handle_loop(self): else: finish_status = FinishStatus(req.finish_status.status) - if req.sample_params.return_routed_experts and req.routing_data_num_moe_layers > 0: - try: - req.link_routing_data_shm_array() - routing_data = req.get_routing_data() - if routing_data is not None: - metadata["routed_experts"] = { - "shape": list(routing_data.shape), - "dtype": str(routing_data.dtype), - "data": base64.b64encode(routing_data.tobytes()).decode("ascii"), - } - req.close_routing_data_shm_array() - except Exception as e: - logger.warning(f"Failed to read routing data for req {req_id}: {e}") + if self._routing_config is not None and self._routing_config.is_initialized(): + routing_meta = req.get_routing_metadata( + self._routing_config.num_moe_layers, self._routing_config.topk + ) + if routing_meta is not None: + metadata["routed_experts"] = routing_meta token_list.append((req_id, text, metadata, finish_status)) else: diff --git a/lightllm/server/router/model_infer/infer_batch.py b/lightllm/server/router/model_infer/infer_batch.py index 276f97c18..2ccddfb87 100644 --- a/lightllm/server/router/model_infer/infer_batch.py +++ b/lightllm/server/router/model_infer/infer_batch.py @@ -114,16 +114,25 @@ def add_reqs(self, requests: List[Tuple[int, int, Any, int]], init_prefix_cache: return req_objs - def _extract_routing_data(self, req: "InferReq"): + def _extract_routing_data(self, req: "InferReq", sync: bool = True): + """Extract MoE routing data for a completed request. + + Args: + req: The inference request to extract routing data for. + sync: If True, synchronize CUDA events before extraction. Set to False + when processing multiple requests in batch after calling + g_routing_capture_manager.sync_events() once. 
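A self-contained sketch of how a client can decode the routed_experts metadata attached above; the payload built here is fabricated, but it follows the shape/dtype/base64 layout produced by Req.get_routing_metadata:

import base64
import numpy as np

# Fabricated payload in the same layout: shape = [num_moe_layers, num_tokens, topk].
example = np.arange(2 * 3 * 4, dtype=np.int32).reshape(2, 3, 4)
payload = {
    "shape": list(example.shape),
    "dtype": str(example.dtype),
    "data": base64.b64encode(example.tobytes()).decode("ascii"),
}

# Client side: reverse the encoding to recover the per-token expert ids.
decoded = np.frombuffer(base64.b64decode(payload["data"]), dtype=np.dtype(payload["dtype"]))
decoded = decoded.reshape(payload["shape"])
assert decoded.shape == (2, 3, 4)
print(decoded[0, 0, :])  # experts chosen by the first MoE layer for the first token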
+ """ mem_indexes = self.req_manager.req_to_token_indexs[req.req_idx][0 : req.cur_kv_len] num_moe_layers = g_routing_capture_manager.num_moe_layers topk = g_routing_capture_manager.topk num_tokens = req.cur_kv_len - logger.debug(f"R3: Extracting routing for req {req.req_id}: {num_moe_layers}x{num_tokens}x{topk}") - routing_data = g_routing_capture_manager.extract_for_request(mem_indexes.cpu()) + if sync: + routing_data = g_routing_capture_manager.extract_for_request(mem_indexes.cpu()) + else: + routing_data = g_routing_capture_manager.extract_for_request_no_sync(mem_indexes.cpu()) req.shm_req.create_routing_data_shm_array(num_moe_layers, num_tokens, topk) req.shm_req.shm_routing_data.arr[:] = routing_data - logger.debug(f"R3: Successfully extracted routing data for req {req.req_id}") def free_a_req_mem(self, free_token_index: List, req: "InferReq"): if self.radix_cache is None: @@ -161,6 +170,11 @@ def _filter(self, finished_request_ids: List[int]): if len(finished_request_ids) == 0: return + # Optimization: sync CUDA events once for batch routing data extraction + need_routing_data = g_routing_capture_manager is not None + if need_routing_data: + g_routing_capture_manager.sync_events() + free_req_index = [] free_token_index = [] for request_id in finished_request_ids: @@ -168,7 +182,8 @@ def _filter(self, finished_request_ids: List[int]): if self.args.diverse_mode: req.clear_master_slave_state() - self._extract_routing_data(req) + if need_routing_data: + self._extract_routing_data(req, sync=False) self.free_a_req_mem(free_token_index, req) From 6ed1951551ef7ce3adc4451910895da1aa946f9b Mon Sep 17 00:00:00 2001 From: sufubao Date: Tue, 27 Jan 2026 13:26:37 +0000 Subject: [PATCH 3/3] clean --- ...026-01-27-remove-moe-model-mixin-design.md | 0 lightllm/common/basemodel/routing_manager.py | 9 -- lightllm/models/llama/model.py | 1 - lightllm/server/core/objs/req.py | 10 +- .../server/router/model_infer/infer_batch.py | 10 -- scripts/run_e2e_r3_test.sh | 109 ------------------ test_r3.py => test/test_api/test_r3.py | 17 +-- .../basemodel/test_routing_capture_manager.py | 20 +--- 8 files changed, 4 insertions(+), 172 deletions(-) delete mode 100644 docs/plans/2026-01-27-remove-moe-model-mixin-design.md delete mode 100755 scripts/run_e2e_r3_test.sh rename test_r3.py => test/test_api/test_r3.py (85%) diff --git a/docs/plans/2026-01-27-remove-moe-model-mixin-design.md b/docs/plans/2026-01-27-remove-moe-model-mixin-design.md deleted file mode 100644 index e69de29bb..000000000 diff --git a/lightllm/common/basemodel/routing_manager.py b/lightllm/common/basemodel/routing_manager.py index 2bc55d3be..f9bdc04fc 100644 --- a/lightllm/common/basemodel/routing_manager.py +++ b/lightllm/common/basemodel/routing_manager.py @@ -14,7 +14,6 @@ class SharedRoutingConfig: def __init__(self): service_name = get_unique_server_name() - # Shape: [num_moe_layers, topk] self._shm = SharedArray(f"{service_name}_routing_config", shape=(2,), dtype=np.int32) @property @@ -37,7 +36,6 @@ def is_initialized(self) -> bool: return self.num_moe_layers > 0 and self.topk > 0 -# Global shared routing config (lazy initialized) _shared_routing_config: Optional[SharedRoutingConfig] = None @@ -49,7 +47,6 @@ def get_shared_routing_config() -> SharedRoutingConfig: return _shared_routing_config -# MoE layer counter for auto-incrementing moe_layer_index _moe_layer_counter: int = 0 @@ -147,10 +144,6 @@ def extract_for_request(self, mem_indexes: torch.Tensor) -> np.ndarray: return self.cpu_buffer[:, mem_indexes, :].numpy() def 
extract_for_request_no_sync(self, mem_indexes: torch.Tensor) -> np.ndarray: - """Extract routing data without synchronizing events. - - Call sync_events() once before using this method in a batch. - """ return self.cpu_buffer[:, mem_indexes, :].numpy() @@ -181,7 +174,6 @@ def init_routing_capture(model) -> None: if not getattr(model.args, "enable_return_routed_experts", False): return - # Only create routing capture manager on rank 0 if get_current_rank_in_dp() != 0: logger.info("Skipping routing capture initialization on non-zero rank") return @@ -212,7 +204,6 @@ def init_routing_capture(model) -> None: enable_overlap=enable_overlap, ) - # Set shared routing config for cross-process access shared_config = get_shared_routing_config() shared_config.num_moe_layers = num_moe_layers shared_config.topk = topk diff --git a/lightllm/models/llama/model.py b/lightllm/models/llama/model.py index f5a66a6c3..63561a6b5 100644 --- a/lightllm/models/llama/model.py +++ b/lightllm/models/llama/model.py @@ -85,7 +85,6 @@ def _init_custom(self): super()._init_custom() def _init_rotary_by_scaling_type(self, scaling_type, rope_scaling): - """Initialize rotary embeddings based on scaling type.""" if scaling_type == "default" or "mrope_section" in rope_scaling: self._init_to_get_rotary() elif scaling_type == "yarn": diff --git a/lightllm/server/core/objs/req.py b/lightllm/server/core/objs/req.py index 442968594..6066acc98 100644 --- a/lightllm/server/core/objs/req.py +++ b/lightllm/server/core/objs/req.py @@ -241,7 +241,6 @@ def link_routing_data_shm_array(self, num_moe_layers: int, topk: int): return service_uni_name = get_unique_server_name() name = f"{service_uni_name}_shm_routing_{self.index_in_shm_mem}" - # num_tokens equals shm_cur_kv_len at the time of creation shape = (num_moe_layers, self.shm_cur_kv_len, topk) self.shm_routing_data = ShmArray(name, shape, dtype=np.int32) self.shm_routing_data.link_shm() @@ -259,14 +258,11 @@ def close_routing_data_shm_array(self): return def get_routing_metadata(self, num_moe_layers: int, topk: int): - """Safely extract routing data and format for API response. - - Returns a dict with shape, dtype, and base64-encoded data, or None if unavailable. - """ if num_moe_layers == 0 or topk == 0: return None try: - self.link_routing_data_shm_array(num_moe_layers, topk) + if not hasattr(self, "shm_routing_data") or self.shm_routing_data is None: + self.link_routing_data_shm_array(num_moe_layers, topk) routing_data = self.get_routing_data() if routing_data is None: return None @@ -278,8 +274,6 @@ def get_routing_metadata(self, num_moe_layers: int, topk: int): except Exception as e: logger.warning(f"Failed to read routing data for req {self.request_id}: {e}") return None - finally: - self.close_routing_data_shm_array() def get_prompt_ids(self): return self.shm_prompt_ids.arr[: self.input_len].tolist() diff --git a/lightllm/server/router/model_infer/infer_batch.py b/lightllm/server/router/model_infer/infer_batch.py index 2ccddfb87..955617e1f 100644 --- a/lightllm/server/router/model_infer/infer_batch.py +++ b/lightllm/server/router/model_infer/infer_batch.py @@ -115,14 +115,6 @@ def add_reqs(self, requests: List[Tuple[int, int, Any, int]], init_prefix_cache: return req_objs def _extract_routing_data(self, req: "InferReq", sync: bool = True): - """Extract MoE routing data for a completed request. - - Args: - req: The inference request to extract routing data for. - sync: If True, synchronize CUDA events before extraction. 
Set to False - when processing multiple requests in batch after calling - g_routing_capture_manager.sync_events() once. - """ mem_indexes = self.req_manager.req_to_token_indexs[req.req_idx][0 : req.cur_kv_len] num_moe_layers = g_routing_capture_manager.num_moe_layers topk = g_routing_capture_manager.topk @@ -170,7 +162,6 @@ def _filter(self, finished_request_ids: List[int]): if len(finished_request_ids) == 0: return - # Optimization: sync CUDA events once for batch routing data extraction need_routing_data = g_routing_capture_manager is not None if need_routing_data: g_routing_capture_manager.sync_events() @@ -610,7 +601,6 @@ def handle( shm_req.shm_cur_output_len = self.output_len if finish_status.is_finished(): - # Extract routing data before setting finish_status so HTTP server sees it g_infer_context._extract_routing_data(req_obj) shm_req.finish_token_index = shm_req.input_len + self.output_len - 1 shm_req.finish_status = req_obj.finish_status diff --git a/scripts/run_e2e_r3_test.sh b/scripts/run_e2e_r3_test.sh deleted file mode 100755 index 3100f18a7..000000000 --- a/scripts/run_e2e_r3_test.sh +++ /dev/null @@ -1,109 +0,0 @@ -#!/bin/bash -# E2E Test Script for R3 Routing Capture Feature -# -# This script starts a LightLLM server with routing capture enabled, -# runs the client test, and verifies the results. -# -# Requirements: -# - A MoE model (DeepSeek-V2/V3, Qwen-MoE, Mixtral, etc.) -# - At least 1 GPU with sufficient memory -# - LightLLM installed -# -# Usage: -# ./scripts/run_e2e_r3_test.sh /path/to/moe/model [--tp N] - -set -e - -MODEL_DIR="${1:-}" -TP="${2:-1}" -PORT=8765 - -if [ -z "$MODEL_DIR" ]; then - echo "Usage: $0 /path/to/moe/model [--tp N]" - echo "" - echo "Example:" - echo " $0 /models/DeepSeek-V3 --tp 8" - echo " $0 /models/Qwen-MoE-A14B --tp 4" - exit 1 -fi - -if [ ! -d "$MODEL_DIR" ]; then - echo "ERROR: Model directory not found: $MODEL_DIR" - exit 1 -fi - -echo "==========================================" -echo "R3 E2E Test: Routing Capture Feature" -echo "==========================================" -echo "Model: $MODEL_DIR" -echo "TP: $TP" -echo "Port: $PORT" -echo "" - -# Kill any existing server on the port -pkill -f "lightllm.server.api_server.*--port $PORT" 2>/dev/null || true -sleep 2 - -# Start server in background -echo "Starting LightLLM server..." -python -m lightllm.server.api_server \ - --model_dir "$MODEL_DIR" \ - --tp "$TP" \ - --port "$PORT" \ - --enable_return_routed_experts \ - --max_total_token_num 8000 \ - --batch_max_tokens 4000 \ - > /tmp/lightllm_r3_test.log 2>&1 & - -SERVER_PID=$! -echo "Server PID: $SERVER_PID" -echo "Log: /tmp/lightllm_r3_test.log" - -# Wait for server to be ready -echo "Waiting for server to be ready..." -MAX_WAIT=300 -WAITED=0 -while [ $WAITED -lt $MAX_WAIT ]; do - if curl -s "http://localhost:$PORT/health" > /dev/null 2>&1; then - echo "Server is ready!" - break - fi - sleep 5 - WAITED=$((WAITED + 5)) - echo " Waited ${WAITED}s..." -done - -if [ $WAITED -ge $MAX_WAIT ]; then - echo "ERROR: Server failed to start within ${MAX_WAIT}s" - echo "Server log:" - tail -50 /tmp/lightllm_r3_test.log - kill $SERVER_PID 2>/dev/null || true - exit 1 -fi - -# Run client test -echo "" -echo "Running R3 client test..." -echo "==========================================" -python test_r3.py --url "http://localhost:$PORT" -TEST_RESULT=$? - -# Cleanup -echo "" -echo "Stopping server..." 
-kill $SERVER_PID 2>/dev/null || true -wait $SERVER_PID 2>/dev/null || true - -# Report result -echo "" -echo "==========================================" -if [ $TEST_RESULT -eq 0 ]; then - echo "E2E TEST PASSED!" -else - echo "E2E TEST FAILED!" - echo "Server log (last 30 lines):" - tail -30 /tmp/lightllm_r3_test.log -fi -echo "==========================================" - -exit $TEST_RESULT diff --git a/test_r3.py b/test/test_api/test_r3.py similarity index 85% rename from test_r3.py rename to test/test_api/test_r3.py index 14157fab5..0ad1b67c6 100644 --- a/test_r3.py +++ b/test/test_api/test_r3.py @@ -1,13 +1,3 @@ -""" -R3 Client Test: Tests the routing capture export feature. - -This test requires a running LightLLM server with: -- A MoE model (e.g., DeepSeek-V2/V3) -- --enable_return_routed_experts flag - -Usage: - python test_r3.py [--url URL] -""" import sys import argparse import requests @@ -16,7 +6,6 @@ def test_routing_export(url: str = "http://localhost:8000"): - """Test the routing export feature.""" print(f"Testing routing export at {url}") print("-" * 50) @@ -50,7 +39,6 @@ def test_routing_export(url: str = "http://localhost:8000"): res = response.json() print(f"Generated text: {res.get('generated_text', 'N/A')[:100]}...") - # Check for routed_experts in response if "routed_experts" not in res or not res["routed_experts"]: print("\nWARNING: No routed_experts in response.") print("This could mean:") @@ -59,7 +47,6 @@ def test_routing_export(url: str = "http://localhost:8000"): print(" - The routing capture manager was not initialized") return False - # Decode routed_experts from base64 routing_info = res["routed_experts"] shape = routing_info["shape"] dtype = np.dtype(routing_info["dtype"]) @@ -69,19 +56,17 @@ def test_routing_export(url: str = "http://localhost:8000"): print(f"\n{'=' * 50}") print("ROUTING CAPTURE SUCCESS!") print(f"{'=' * 50}") - print(f"Shape: {shape} # [num_moe_layers, num_tokens, topk]") + print(f"Shape: {shape}") print(f"Dtype: {dtype}") print(f"Num MoE layers: {shape[0]}") print(f"Num tokens: {shape[1]}") print(f"Top-K: {shape[2]}") - # Show sample of routing data print(f"\nSample routing (first layer, first 5 tokens):") num_tokens_to_show = min(5, shape[1]) for i in range(num_tokens_to_show): print(f" Token {i}: experts {routing_array[0, i, :].tolist()}") - # Validate data if np.all(routing_array == 0): print("\nWARNING: All routing data is zeros. 
Capture may not be working correctly.") return False diff --git a/unit_tests/common/basemodel/test_routing_capture_manager.py b/unit_tests/common/basemodel/test_routing_capture_manager.py index 8d5a1d0bf..ef9390637 100644 --- a/unit_tests/common/basemodel/test_routing_capture_manager.py +++ b/unit_tests/common/basemodel/test_routing_capture_manager.py @@ -4,7 +4,6 @@ def test_moe_layer_counter(): - """Counter increments and resets correctly.""" from lightllm.common.basemodel.routing_manager import ( reset_moe_layer_counter, get_next_moe_layer_index, @@ -25,10 +24,7 @@ def test_moe_layer_counter(): class TestRoutingCaptureManager: - """Tests for the redesigned RoutingCaptureManager.""" - def test_capture_explicit_layer_index(self): - """Capture stores data at explicit moe_layer_index.""" from lightllm.common.basemodel.routing_manager import RoutingCaptureManager manager = RoutingCaptureManager( @@ -40,15 +36,12 @@ def test_capture_explicit_layer_index(self): enable_overlap=False, ) - # Capture at layer 2 (not sequential) topk_ids = torch.randint(0, 64, (10, 8), device="cuda") manager.capture(moe_layer_index=2, topk_ids=topk_ids) - # Verify data is at layer 2, not layer 0 assert torch.equal(manager.gpu_buffer[0, 2, :10, :], topk_ids.to(manager.dtype)) def test_double_buffer_overlap_mode(self): - """Double buffer prevents race condition in overlap mode.""" from lightllm.common.basemodel.routing_manager import RoutingCaptureManager manager = RoutingCaptureManager( @@ -60,24 +53,19 @@ def test_double_buffer_overlap_mode(self): enable_overlap=True, ) - # Should have 2 buffer slots assert manager.num_slots == 2 assert manager.gpu_buffer.shape[0] == 2 - # Capture to slot 0 (microbatch_index=0) ids_0 = torch.ones((5, 4), dtype=torch.int64, device="cuda") manager.capture(moe_layer_index=0, topk_ids=ids_0, microbatch_index=0) - # Capture to slot 1 (microbatch_index=1) ids_1 = torch.ones((5, 4), dtype=torch.int64, device="cuda") * 2 manager.capture(moe_layer_index=0, topk_ids=ids_1, microbatch_index=1) - # Both slots have different data assert manager.gpu_buffer[0, 0, 0, 0].item() == 1 assert manager.gpu_buffer[1, 0, 0, 0].item() == 2 def test_flush_and_extract(self): - """Flush transfers data to CPU, extract retrieves by mem_index.""" from lightllm.common.basemodel.routing_manager import RoutingCaptureManager manager = RoutingCaptureManager( @@ -89,27 +77,22 @@ def test_flush_and_extract(self): enable_overlap=False, ) - # Capture some data (microbatch_index defaults to 0) topk_ids = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]], device="cuda") manager.capture(moe_layer_index=0, topk_ids=topk_ids) manager.capture(moe_layer_index=1, topk_ids=topk_ids + 10) - # Flush to mem_indexes 10 and 11 mem_indexes = torch.tensor([10, 11], device="cuda") manager.flush_to_cpu_async(mem_indexes, microbatch_index=0) - # Extract result = manager.extract_for_request(mem_indexes.cpu()) - assert result.shape == (2, 2, 4) # [layers, tokens, topk] + assert result.shape == (2, 2, 4) assert result[0, 0, 0] == 1 assert result[1, 0, 0] == 11 def test_dtype_selection(self): - """Uses int8 for <=127 experts, int16 otherwise.""" from lightllm.common.basemodel.routing_manager import RoutingCaptureManager - # Small expert count -> int8 manager_small = RoutingCaptureManager( num_moe_layers=1, topk=2, @@ -120,7 +103,6 @@ def test_dtype_selection(self): ) assert manager_small.dtype == torch.int8 - # Large expert count -> int16 manager_large = RoutingCaptureManager( num_moe_layers=1, topk=2,