diff --git a/.gitignore b/.gitignore
index 63408699f..3fb49db8b 100644
--- a/.gitignore
+++ b/.gitignore
@@ -7,3 +7,4 @@ dist
 .vscode
 tmp/
 requirements-musa.txt
+CLAUDE.md
diff --git a/lightllm/common/basemodel/basemodel.py b/lightllm/common/basemodel/basemodel.py
index e6405e4d7..36968471f 100755
--- a/lightllm/common/basemodel/basemodel.py
+++ b/lightllm/common/basemodel/basemodel.py
@@ -11,6 +11,7 @@
 from lightllm.common.basemodel.layer_weights.hf_load_utils import load_hf_weights
 from lightllm.common.basemodel.infer_struct import InferStateInfo
+from lightllm.common.basemodel.routing_manager import reset_moe_layer_counter
 from lightllm.common.kv_cache_mem_manager import MemoryManager
 from lightllm.common.kv_cache_mem_manager.mem_utils import select_mem_manager_class
 from lightllm.common.req_manager import ReqManager
@@ -164,6 +165,7 @@ def _init_quant(self):
         logger.info(f"Initial quantization. " f"The default quantization method is {self.quant_cfg.quant_type}")
 
     def _init_weights(self, start_layer_index=0):
+        reset_moe_layer_counter()
         self.pre_post_weight = self.pre_and_post_weight_class(self.data_type, network_config=self.config)
         self.trans_layers_weight = [
            self.transformer_weight_class(
diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py
index 6bcf7fc03..2e723de44 100644
--- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py
+++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py
@@ -13,6 +13,7 @@
 from lightllm.utils.envs_utils import get_redundancy_expert_ids, get_redundancy_expert_num, get_env_start_args
 from lightllm.utils.dist_utils import get_global_world_size, get_global_rank
 from lightllm.utils.log_utils import init_logger
+from lightllm.common.basemodel.routing_manager import get_next_moe_layer_index
 
 logger = init_logger(__name__)
 
@@ -35,6 +36,7 @@ def __init__(
         network_config: Dict[str, Any] = None,
     ) -> None:
         super().__init__(data_type=data_type)
+        self.moe_layer_index = get_next_moe_layer_index()
         self.w1_weight_name = gate_proj_name
         self.w2_weight_name = down_proj_name
         self.w3_weight_name = up_proj_name
@@ -130,6 +132,7 @@ def experts(
         topk_group: int,
         num_expert_group: int,
         is_prefill: Optional[bool] = None,
+        microbatch_index: int = 0,
     ) -> torch.Tensor:
         """Backward compatible method that routes to platform-specific implementation."""
         return self.fuse_moe_impl(
@@ -145,6 +148,8 @@ def experts(
             topk_group=topk_group,
             num_expert_group=num_expert_group,
             is_prefill=is_prefill,
+            moe_layer_index=self.moe_layer_index,
+            microbatch_index=microbatch_index,
         )
 
     def low_latency_dispatch(
diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/gpt_oss_fused_moe_weight_tp.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/gpt_oss_fused_moe_weight_tp.py
index 6ed0cef0b..5627a5925 100644
--- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/gpt_oss_fused_moe_weight_tp.py
+++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/gpt_oss_fused_moe_weight_tp.py
@@ -8,6 +8,7 @@
 from lightllm.common.quantization import Quantcfg
 from lightllm.common.quantization.quantize_method import QuantizationMethod
 from lightllm.utils.log_utils import init_logger
+from lightllm.common.basemodel.routing_manager import g_routing_capture_manager
 
 logger = init_logger(__name__)
 
@@ -144,10 +145,14 @@ def experts(
         topk_group: int,
         num_expert_group: int,
         is_prefill: Optional[bool] = None,
+        microbatch_index: int = 0,
     ):
         topk_weights, topk_ids = self._router(router_logits, top_k)
 
+        if g_routing_capture_manager is not None:
+            g_routing_capture_manager.capture(self.moe_layer_index, topk_ids, microbatch_index)
+
         w1, w1_scale = self.w1
         w2, w2_scale = self.w2
         use_fp8_w8a8 = self.quant_method is not None
diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/base_impl.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/base_impl.py
index 00587ac18..1c93cb13d 100644
--- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/base_impl.py
+++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/base_impl.py
@@ -62,5 +62,7 @@ def __call__(
         topk_group: int,
         num_expert_group: int,
         is_prefill: Optional[bool] = None,
+        moe_layer_index: Optional[int] = None,
+        microbatch_index: int = 0,
     ) -> torch.Tensor:
         pass
diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/triton_impl.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/triton_impl.py
index 8bcdb4bf9..77a9b1a45 100644
--- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/triton_impl.py
+++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/triton_impl.py
@@ -3,6 +3,7 @@
 from lightllm.common.quantization.no_quant import WeightPack
 from lightllm.common.quantization.quantize_method import QuantizationMethod
 from .base_impl import FuseMoeBaseImpl
+from lightllm.common.basemodel.routing_manager import g_routing_capture_manager
 
 
 class FuseMoeTriton(FuseMoeBaseImpl):
@@ -124,6 +125,8 @@ def __call__(
         topk_group: int,
         num_expert_group: int,
         is_prefill: Optional[bool] = None,
+        moe_layer_index: Optional[int] = None,
+        microbatch_index: int = 0,
     ):
         topk_weights, topk_ids = self._select_experts(
             input_tensor=input_tensor,
@@ -136,6 +139,10 @@ def __call__(
             num_expert_group=num_expert_group,
             scoring_func=scoring_func,
         )
+
+        if g_routing_capture_manager is not None and moe_layer_index is not None:
+            g_routing_capture_manager.capture(moe_layer_index, topk_ids, microbatch_index)
+
         output = self._fused_experts(
             input_tensor=input_tensor,
             w13=w13,
diff --git a/lightllm/common/basemodel/routing_manager.py b/lightllm/common/basemodel/routing_manager.py
new file mode 100644
index 000000000..f9bdc04fc
--- /dev/null
+++ b/lightllm/common/basemodel/routing_manager.py
@@ -0,0 +1,215 @@
+import torch
+import numpy as np
+from typing import Optional
+from lightllm.utils.log_utils import init_logger
+from lightllm.utils.dist_utils import get_current_rank_in_dp
+from lightllm.server.router.dynamic_prompt.shared_arr import SharedArray
+from lightllm.utils.envs_utils import get_unique_server_name
+
+logger = init_logger(__name__)
+
+
+class SharedRoutingConfig:
+    """Shared MoE routing configuration across processes."""
+
+    def __init__(self):
+        service_name = get_unique_server_name()
+        self._shm = SharedArray(f"{service_name}_routing_config", shape=(2,), dtype=np.int32)
+
+    @property
+    def num_moe_layers(self) -> int:
+        return int(self._shm.arr[0])
+
+    @num_moe_layers.setter
+    def num_moe_layers(self, value: int):
+        self._shm.arr[0] = value
+
+    @property
+    def topk(self) -> int:
+        return int(self._shm.arr[1])
+
+    @topk.setter
+    def topk(self, value: int):
+        self._shm.arr[1] = value
+
+    def is_initialized(self) -> bool:
+        return self.num_moe_layers > 0 and self.topk > 0
+
+
+_shared_routing_config: Optional[SharedRoutingConfig] = None
+
+
+def get_shared_routing_config() -> SharedRoutingConfig:
+    """Get or create the shared routing config."""
+    global _shared_routing_config
+    if _shared_routing_config is None:
+        _shared_routing_config = SharedRoutingConfig()
+    return _shared_routing_config
+
+
+_moe_layer_counter: int = 0
+
+
+def reset_moe_layer_counter() -> None:
+    global _moe_layer_counter
+    _moe_layer_counter = 0
+
+
+def get_next_moe_layer_index() -> int:
+    global _moe_layer_counter
+    idx = _moe_layer_counter
+    _moe_layer_counter += 1
+    return idx
+
+
+def get_moe_layer_count() -> int:
+    return _moe_layer_counter
+
+
+class RoutingCaptureManager:
+    """Captures MoE routing decisions"""
+
+    def __init__(
+        self,
+        num_moe_layers: int,
+        topk: int,
+        num_experts: int,
+        batch_max_tokens: int,
+        kv_cache_size: int,
+        enable_overlap: bool = False,
+    ):
+        self.num_moe_layers = num_moe_layers
+        self.topk = topk
+        self.num_experts = num_experts
+        self.batch_max_tokens = batch_max_tokens
+        self.kv_cache_size = kv_cache_size
+
+        self.dtype = torch.int8 if num_experts <= 127 else torch.int16
+        dtype_bytes = 1 if self.dtype == torch.int8 else 2
+
+        self.num_slots = 2 if enable_overlap else 1
+
+        gpu_buffer_size = self.num_slots * num_moe_layers * batch_max_tokens * topk * dtype_bytes
+        self.gpu_buffer = torch.zeros(
+            (self.num_slots, num_moe_layers, batch_max_tokens, topk),
+            dtype=self.dtype,
+            device="cuda",
+        )
+
+        cpu_buffer_size = num_moe_layers * kv_cache_size * topk * dtype_bytes
+        self.cpu_buffer = torch.zeros(
+            (num_moe_layers, kv_cache_size, topk),
+            dtype=self.dtype,
+            device="cpu",
+            pin_memory=True,
+        )
+
+        self.flush_streams = [torch.cuda.Stream() for _ in range(self.num_slots)]
+        self.flush_events = [torch.cuda.Event() for _ in range(self.num_slots)]
+
+        dtype_name = "int8" if self.dtype == torch.int8 else "int16"
+        logger.info(
+            f"RoutingCaptureManager initialized: {num_moe_layers} MoE layers, topk={topk}, "
+            f"slots={self.num_slots}, GPU={gpu_buffer_size / 1024 / 1024:.2f}MB, "
+            f"CPU={cpu_buffer_size / 1024 / 1024:.2f}MB, dtype={dtype_name}"
+        )
+
+    def capture(self, moe_layer_index: int, topk_ids: torch.Tensor, microbatch_index: int = 0) -> None:
+        num_tokens = topk_ids.shape[0]
+        self.gpu_buffer[microbatch_index, moe_layer_index, :num_tokens, :] = topk_ids.to(self.dtype)
+
+    def flush_to_cpu_async(self, mem_indexes: torch.Tensor, microbatch_index: int) -> None:
+        num_tokens = mem_indexes.shape[0]
+        if num_tokens == 0:
+            return
+
+        slot = microbatch_index % self.num_slots
+        stream = self.flush_streams[slot]
+        event = self.flush_events[slot]
+
+        stream.wait_stream(torch.cuda.current_stream())
+
+        with torch.cuda.stream(stream):
+            cpu_indexes = mem_indexes.cpu()
+            self.cpu_buffer[:, cpu_indexes, :] = self.gpu_buffer[slot, :, :num_tokens, :].cpu()
+            event.record()
+
+    def sync_events(self) -> None:
+        """Synchronize all flush events. Call once before batch extraction."""
+        for event in self.flush_events:
+            event.synchronize()
+
+    def extract_for_request(self, mem_indexes: torch.Tensor) -> np.ndarray:
+        self.sync_events()
+        return self.cpu_buffer[:, mem_indexes, :].numpy()
+
+    def extract_for_request_no_sync(self, mem_indexes: torch.Tensor) -> np.ndarray:
+        return self.cpu_buffer[:, mem_indexes, :].numpy()
+
+
+g_routing_capture_manager: Optional[RoutingCaptureManager] = None
+
+
+def create_routing_capture_manager(
+    num_moe_layers: int,
+    topk: int,
+    num_experts: int,
+    batch_max_tokens: int,
+    kv_cache_size: int,
+    enable_overlap: bool = False,
+) -> None:
+    global g_routing_capture_manager
+    assert g_routing_capture_manager is None, "RoutingCaptureManager already exists"
+    g_routing_capture_manager = RoutingCaptureManager(
+        num_moe_layers=num_moe_layers,
+        topk=topk,
+        num_experts=num_experts,
+        batch_max_tokens=batch_max_tokens,
+        kv_cache_size=kv_cache_size,
+        enable_overlap=enable_overlap,
+    )
+
+
+def init_routing_capture(model) -> None:
+    if not getattr(model.args, "enable_return_routed_experts", False):
+        return
+
+    if get_current_rank_in_dp() != 0:
+        logger.info("Skipping routing capture initialization on non-zero rank")
+        return
+
+    num_moe_layers = get_moe_layer_count()
+    if num_moe_layers == 0:
+        logger.warning(
+            "enable_return_routed_experts is set but no MoE layers found. "
+            "Routing capture will not be enabled."
+        )
+        return
+
+    num_experts = model.config.get("n_routed_experts", model.config.get("num_experts", 0))
+    topk = model.config.get("num_experts_per_tok", 0)
+    assert num_experts > 0 and topk > 0
+    enable_overlap = getattr(model.args, "enable_decode_microbatch_overlap", False)
+
+    logger.info(
+        f"Initializing routing capture: num_moe_layers={num_moe_layers}, "
+        f"topk={topk}, num_experts={num_experts}, enable_overlap={enable_overlap}"
+    )
+
+    create_routing_capture_manager(
+        num_moe_layers=num_moe_layers,
+        topk=topk,
+        num_experts=num_experts,
+        batch_max_tokens=model.max_total_token_num,
+        kv_cache_size=model.mem_manager.size + 1,
+        enable_overlap=enable_overlap,
+    )
+
+    shared_config = get_shared_routing_config()
+    shared_config.num_moe_layers = num_moe_layers
+    shared_config.topk = topk
+    logger.info(f"Shared routing config set: num_moe_layers={num_moe_layers}, topk={topk}")
+
+
+def flush_routing_capture(mem_indexes: torch.Tensor, microbatch_index: int = 0) -> None:
+    if g_routing_capture_manager is not None:
+        g_routing_capture_manager.flush_to_cpu_async(mem_indexes, microbatch_index)
diff --git a/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py b/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py
index e1e435cce..9b7a4274f 100644
--- a/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py
+++ b/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py
@@ -311,6 +311,7 @@ def _moe_ffn(
             use_grouped_topk=self.n_group,
             topk_group=self.topk_group,
             num_expert_group=self.n_group,
+            microbatch_index=infer_state.microbatch_index,
         )
 
         if self.n_shared_experts is not None and layer_weight.num_fused_shared_experts == 0:
@@ -337,6 +338,7 @@ def _moe_ffn_edp(
             topk_group=self.topk_group,
             num_expert_group=self.n_group,
             is_prefill=infer_state.is_prefill,
+            microbatch_index=infer_state.microbatch_index,
         )
 
         if self.n_shared_experts is not None:
diff --git a/lightllm/models/deepseek2/model.py b/lightllm/models/deepseek2/model.py
index f0739a8a8..2b8c48e65 100644
--- a/lightllm/models/deepseek2/model.py
+++ b/lightllm/models/deepseek2/model.py
@@ -6,6 +6,7 @@
 from lightllm.models.deepseek2.infer_struct import Deepseek2InferStateInfo
 from lightllm.models.llama.model import LlamaTpPartModel
 from lightllm.common.kv_cache_mem_manager.mem_utils import select_mem_manager_class
+from lightllm.common.basemodel.routing_manager import init_routing_capture
 from lightllm.utils.log_utils import init_logger
 from lightllm.utils.envs_utils import enable_env_vars, get_env_start_args, get_added_mtp_kv_layer_num
 from lightllm.distributed.communication_op import dist_group_manager
@@ -48,6 +49,7 @@ def _init_some_value(self):
 
     def _init_custom(self):
         self._init_to_get_yarn_rotary()
         dist_group_manager.new_deepep_group(self.config["n_routed_experts"], self.config["hidden_size"])
+        init_routing_capture(self)
 
     def _verify_params(self):
         return super()._verify_params()
diff --git a/lightllm/models/gpt_oss/layer_infer/transformer_layer_infer.py b/lightllm/models/gpt_oss/layer_infer/transformer_layer_infer.py
index d80eefd16..e5672f821 100644
--- a/lightllm/models/gpt_oss/layer_infer/transformer_layer_infer.py
+++ b/lightllm/models/gpt_oss/layer_infer/transformer_layer_infer.py
@@ -51,6 +51,7 @@ def _ffn(self, input, infer_state, layer_weight: GptOssTransformerLayerWeight) -
             use_grouped_topk=False,
             topk_group=None,
             num_expert_group=None,
+            microbatch_index=infer_state.microbatch_index,
         )
         return hidden_states.view(num_tokens, hidden_dim)
diff --git a/lightllm/models/gpt_oss/model.py b/lightllm/models/gpt_oss/model.py
index 9e9561eb2..f91e25690 100644
--- a/lightllm/models/gpt_oss/model.py
+++ b/lightllm/models/gpt_oss/model.py
@@ -2,6 +2,7 @@
 from lightllm.models.gpt_oss.layer_weights.transformer_layer_weight import GptOssTransformerLayerWeight
 from lightllm.models.llama.model import LlamaTpPartModel
 from lightllm.models.registry import ModelRegistry
+from lightllm.common.basemodel.routing_manager import init_routing_capture
 from lightllm.utils.envs_utils import get_env_start_args
 from lightllm.utils.log_utils import init_logger
 from lightllm.common.basemodel.attention import get_prefill_att_backend_class, get_decode_att_backend_class
@@ -28,3 +29,7 @@ def _init_att_backend(self):
         self.decode_att_backend: BaseAttBackend = get_decode_att_backend_class(index=0, priority_list=["fa3"])(
             model=self
         )
+
+    def _init_custom(self):
+        super()._init_custom()
+        init_routing_capture(self)
diff --git a/lightllm/models/llama/model.py b/lightllm/models/llama/model.py
index c104ebccc..63561a6b5 100644
--- a/lightllm/models/llama/model.py
+++ b/lightllm/models/llama/model.py
@@ -74,14 +74,17 @@ def _init_custom(self):
         rope_scaling = self.config.get("rope_scaling", None)
         if rope_scaling is None:
             self._init_to_get_rotary()
-            return
-
-        if "rope_type" in rope_scaling:
+        elif "rope_type" in rope_scaling:
             scaling_type = rope_scaling["rope_type"]
+            self._init_rotary_by_scaling_type(scaling_type, rope_scaling)
         elif "type" in rope_scaling:
             scaling_type = rope_scaling["type"]
+            self._init_rotary_by_scaling_type(scaling_type, rope_scaling)
         else:
             raise ValueError(f"Unknown RoPE scaling format {rope_scaling}")
+        super()._init_custom()
+
+    def _init_rotary_by_scaling_type(self, scaling_type, rope_scaling):
         if scaling_type == "default" or "mrope_section" in rope_scaling:
             self._init_to_get_rotary()
         elif scaling_type == "yarn":
@@ -96,7 +99,6 @@ def _init_custom(self):
             self._init_to_get_rotary()
         else:
             raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
-        return
 
     def _init_to_get_rotary(self, default_base=10000):
         partial_head_dim = int(self.config.get("partial_rotary_factor", 1) * self.head_dim_)
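Note: the pieces added above compose in a fixed order: _init_weights() resets the MoE layer counter, every FusedMoeWeight claims a stable index via get_next_moe_layer_index(), init_routing_capture() sizes the capture buffers from that count, the fused-MoE kernels call capture() during forward, and the sampled tokens' routing rows are later flushed and read back per request. The following is a minimal sketch of that flow using only the helpers defined in routing_manager.py above; it requires a CUDA device, and the sizes and tensor shapes are illustrative rather than the values LightLLM actually passes in.

    import torch
    from lightllm.common.basemodel.routing_manager import (
        RoutingCaptureManager,
        reset_moe_layer_counter,
        get_next_moe_layer_index,
    )

    # Assign stable indices to two hypothetical MoE layers, as _init_weights() does.
    reset_moe_layer_counter()
    layer_ids = [get_next_moe_layer_index() for _ in range(2)]  # -> [0, 1]

    manager = RoutingCaptureManager(
        num_moe_layers=2, topk=4, num_experts=32,
        batch_max_tokens=64, kv_cache_size=256, enable_overlap=False,
    )

    # During a forward pass each MoE layer records its per-token top-k expert ids.
    topk_ids = torch.randint(0, 32, (8, 4), device="cuda")
    for moe_layer_index in layer_ids:
        manager.capture(moe_layer_index, topk_ids)

    # After sampling, copy this step's rows into the pinned CPU buffer at the
    # tokens' kv-cache slots, then read them back for a finished request.
    mem_indexes = torch.arange(8, device="cuda")
    manager.flush_to_cpu_async(mem_indexes, microbatch_index=0)
    routing = manager.extract_for_request(mem_indexes.cpu())  # (num_moe_layers, num_tokens, topk)

The GPU buffer is indexed by position in the current batch while the pinned CPU buffer is indexed by kv-cache slot, which is what lets a request recover exactly its own rows later by mem index.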
diff --git a/lightllm/models/mixtral/layer_infer/_custom_ops.py b/lightllm/models/mixtral/layer_infer/_custom_ops.py
deleted file mode 100644
index b0e27ac1d..000000000
--- a/lightllm/models/mixtral/layer_infer/_custom_ops.py
+++ /dev/null
@@ -1,46 +0,0 @@
-import functools
-import json
-import os
-from typing import Any, Dict, Optional, Tuple
-
-import torch
-import triton
-import triton.language as tl
-from lightllm.utils.log_utils import init_logger
-
-logger = init_logger(__name__)
-
-# Pytorch version
-# Triton version in progress
-def topk_softmax(
-    topk_weights,
-    topk_ids,
-    token_expert_indicies,
-    gating_output,
-    topk=2,
-):
-    scores = torch.softmax(gating_output, dim=-1)
-    topk_weights, topk_ids = torch.topk(scores, k=topk, dim=-1, sorted=False)
-    return topk_weights, topk_ids
-
-
-def fused_topk(
-    hidden_states: torch.Tensor,
-    gating_output: torch.Tensor,
-    topk: int,
-    renormalize: bool,
-    alloc_tensor_func=torch.empty,
-):
-    assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
-
-    M, _ = hidden_states.shape
-
-    topk_weights = alloc_tensor_func((M, topk), dtype=torch.float32, device=hidden_states.device)
-    topk_ids = alloc_tensor_func((M, topk), dtype=torch.int32, device=hidden_states.device)
-    token_expert_indicies = alloc_tensor_func((M, topk), dtype=torch.int32, device=hidden_states.device)
-    topk_weights, topk_ids = topk_softmax(topk_weights, topk_ids, token_expert_indicies, gating_output.float(), topk)
-    del token_expert_indicies  # Not used. Will be used in the future.
-
-    if renormalize:
-        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
-    return topk_weights, topk_ids
diff --git a/lightllm/models/mixtral/layer_infer/transformer_layer_infer.py b/lightllm/models/mixtral/layer_infer/transformer_layer_infer.py
index 44e66cff2..a2968f5ab 100644
--- a/lightllm/models/mixtral/layer_infer/transformer_layer_infer.py
+++ b/lightllm/models/mixtral/layer_infer/transformer_layer_infer.py
@@ -1,9 +1,6 @@
-import os
 import torch
-import torch.nn.functional as F
 from lightllm.common.basemodel.infer_struct import InferStateInfo
 from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer
-from lightllm.models.mixtral.layer_infer._custom_ops import fused_topk
 from lightllm.models.mixtral.layer_weights.transformer_layer_weight import MixtralTransformerLayerWeight
 
 
@@ -19,25 +16,15 @@ def _ffn(self, input, infer_state: InferStateInfo, layer_weight: MixtralTransfor
         hidden_states = input.view(-1, self.embed_dim_)
         num_tokens, hidden_dim = hidden_states.shape
-        router_logits = layer_weight.moe_gate.mm(input.view(-1, self.embed_dim_))
-        topk_weights, topk_ids = fused_topk(
-            hidden_states=hidden_states,
-            gating_output=router_logits,
-            topk=self.num_experts_per_tok,
+        router_logits = layer_weight.moe_gate.mm(hidden_states)
+        layer_weight.experts.experts(
+            hidden_states,
+            router_logits=router_logits,
+            top_k=self.num_experts_per_tok,
             renormalize=self.renormalize,
-            alloc_tensor_func=self.alloc_tensor,
-        )
-        from lightllm.common.fused_moe.grouped_fused_moe import fused_experts_impl
-
-        return fused_experts_impl(
-            hidden_states=hidden_states,
-            w1=layer_weight.experts.w1[0],
-            w2=layer_weight.experts.w2[0],
-            topk_weights=topk_weights,
-            topk_ids=topk_ids,
-            inplace=True,
-            use_fp8_w8a8=False,
-            w1_scale=None,
-            w2_scale=None,
-            alloc_tensor_func=self.alloc_tensor,
+            use_grouped_topk=False,
+            topk_group=None,
+            num_expert_group=None,
+            microbatch_index=getattr(infer_state, "microbatch_index", 0),
         )
+        return hidden_states.view(num_tokens, hidden_dim)
diff --git a/lightllm/models/mixtral/model.py b/lightllm/models/mixtral/model.py
index 3c2d7b4e8..76b18667d 100644
--- a/lightllm/models/mixtral/model.py
+++ b/lightllm/models/mixtral/model.py
@@ -2,6 +2,7 @@
 import numpy as np
 from lightllm.models.registry import ModelRegistry
 from lightllm.common.basemodel.basemodel import TpPartBaseModel
+from lightllm.common.basemodel.routing_manager import init_routing_capture
 from lightllm.common.kv_cache_mem_manager import MemoryManager
 from lightllm.models.llama.layer_infer.post_layer_infer import LlamaPostLayerInfer
 from lightllm.models.llama.layer_infer.pre_layer_infer import LlamaPreLayerInfer
@@ -45,6 +46,7 @@ def _verify_params(self):
 
     def _init_custom(self):
         self._init_to_get_rotary()
+        init_routing_capture(self)
         return
 
     def _init_mem_manager(self):
diff --git a/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py
index 71b16cb34..43d15cc54 100644
--- a/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py
+++ b/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py
@@ -128,6 +128,7 @@ def _moe_ffn(
             use_grouped_topk=False,
             topk_group=None,
             num_expert_group=None,
+            microbatch_index=infer_state.microbatch_index,
         )
 
         return hidden_states.view(num_tokens, hidden_dim)
@@ -148,6 +149,7 @@ def _moe_ffn_edp(
             topk_group=None,
             num_expert_group=None,
             is_prefill=infer_state.is_prefill,
+            microbatch_index=infer_state.microbatch_index,
         )
 
         ep_output = ep_output.view(token_num, hidden_dim)
diff --git a/lightllm/models/qwen3_moe/layer_weights/transformer_layer_weight.py b/lightllm/models/qwen3_moe/layer_weights/transformer_layer_weight.py
index a889609d7..dd1cc6112 100644
--- a/lightllm/models/qwen3_moe/layer_weights/transformer_layer_weight.py
+++ b/lightllm/models/qwen3_moe/layer_weights/transformer_layer_weight.py
@@ -1,4 +1,3 @@
-import os
 from lightllm.models.qwen3.layer_weights.transformer_layer_weight import Qwen3TransformerLayerWeight
 from lightllm.common.basemodel.layer_weights.meta_weights import ROWMMWeight, FusedMoeWeight
 
diff --git a/lightllm/models/qwen3_moe/model.py b/lightllm/models/qwen3_moe/model.py
index 10a505127..1dabc38a9 100644
--- a/lightllm/models/qwen3_moe/model.py
+++ b/lightllm/models/qwen3_moe/model.py
@@ -4,6 +4,7 @@
 from lightllm.models.qwen3_moe.layer_infer.transformer_layer_infer import Qwen3MOETransformerLayerInfer
 from lightllm.models.qwen3_moe.layer_weights.transformer_layer_weight import Qwen3MOETransformerLayerWeight
 from lightllm.models.qwen3.model import Qwen3TpPartModel
+from lightllm.common.basemodel.routing_manager import init_routing_capture
 from lightllm.utils.log_utils import init_logger
 from lightllm.distributed.communication_op import dist_group_manager
 
@@ -26,3 +27,4 @@ def __init__(self, kvargs):
     def _init_custom(self):
         super()._init_custom()
         dist_group_manager.new_deepep_group(self.config["num_experts"], self.config["hidden_size"])
+        init_routing_capture(self)
diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py
index e49b0cc67..fb2f0094d 100644
--- a/lightllm/server/api_cli.py
+++ b/lightllm/server/api_cli.py
@@ -629,4 +629,10 @@ def make_argument_parser() -> argparse.ArgumentParser:
             If the op is not implemented for the platform and the hardware support triton,
             it will use triton implementation.""",
     )
+    parser.add_argument(
+        "--enable_return_routed_experts",
+        action="store_true",
+        default=False,
+        help="Enable returning routed expert indices for MoE models (R3 feature).",
+    )
     return parser
diff --git a/lightllm/server/api_lightllm.py b/lightllm/server/api_lightllm.py
index d3592a5f5..5abd90815 100644
--- a/lightllm/server/api_lightllm.py
+++ b/lightllm/server/api_lightllm.py
@@ -53,6 +53,7 @@ async def lightllm_generate(request: Request, httpserver_manager: HttpServerMana
     prompt_token_ids = None
     is_first_metadata = True
     input_usage = None
+    routed_experts_data = None
     async for sub_req_id, request_output, metadata, finish_status in results_generator:
         # when set "--return_all_prompt_logprobs", the first token metadata will contains
         # prompt_logprobs and prompt_token_ids
@@ -78,6 +79,8 @@ async def lightllm_generate(request: Request, httpserver_manager: HttpServerMana
 
         if finish_status.is_finished():
             finish_reason_dict[sub_req_id] = finish_status
+            if "routed_experts" in metadata:
+                routed_experts_data = metadata["routed_experts"]
     n = sampling_params.n
     sub_ids = list(final_output_dict.keys())[:n]
     final_output_list = ["".join(final_output_dict[sub_id]) for sub_id in sub_ids]
@@ -102,6 +105,8 @@ async def lightllm_generate(request: Request, httpserver_manager: HttpServerMana
         ret["prompt_logprobs"] = prompt_logprobs
     if input_usage is not None:
         ret["input_usage"] = input_usage
+    if routed_experts_data is not None:
+        ret["routed_experts"] = routed_experts_data
 
     return Response(content=json.dumps(ret, ensure_ascii=False).encode("utf-8"))
diff --git a/lightllm/server/core/objs/req.py b/lightllm/server/core/objs/req.py
index f489aac9c..6066acc98 100644
--- a/lightllm/server/core/objs/req.py
+++ b/lightllm/server/core/objs/req.py
@@ -1,6 +1,7 @@
 import os
 import math
 import ctypes
+import base64
 import numpy as np
 import time
 from .sampling_params import SamplingParams
@@ -227,6 +228,53 @@ def link_logprobs_shm_array(self):
         self.shm_logprobs.link_shm()
         return
 
+    def create_routing_data_shm_array(self, num_moe_layers: int, num_tokens: int, topk: int):
+        service_uni_name = get_unique_server_name()
+        name = f"{service_uni_name}_shm_routing_{self.index_in_shm_mem}"
+        shape = (num_moe_layers, num_tokens, topk)
+        self.shm_routing_data = ShmArray(name, shape, dtype=np.int32)
+        self.shm_routing_data.create_shm()
+        return
+
+    def link_routing_data_shm_array(self, num_moe_layers: int, topk: int):
+        if num_moe_layers == 0:
+            return
+        service_uni_name = get_unique_server_name()
+        name = f"{service_uni_name}_shm_routing_{self.index_in_shm_mem}"
+        shape = (num_moe_layers, self.shm_cur_kv_len, topk)
+        self.shm_routing_data = ShmArray(name, shape, dtype=np.int32)
+        self.shm_routing_data.link_shm()
+        return
+
+    def get_routing_data(self):
+        if not hasattr(self, "shm_routing_data") or self.shm_routing_data is None:
+            return None
+        return self.shm_routing_data.arr
+
+    def close_routing_data_shm_array(self):
+        if hasattr(self, "shm_routing_data") and self.shm_routing_data is not None:
+            self.shm_routing_data.close_shm()
+            self.shm_routing_data = None
+        return
+
+    def get_routing_metadata(self, num_moe_layers: int, topk: int):
+        if num_moe_layers == 0 or topk == 0:
+            return None
+        try:
+            if not hasattr(self, "shm_routing_data") or self.shm_routing_data is None:
+                self.link_routing_data_shm_array(num_moe_layers, topk)
+            routing_data = self.get_routing_data()
+            if routing_data is None:
+                return None
+            return {
+                "shape": list(routing_data.shape),
+                "dtype": str(routing_data.dtype),
+                "data": base64.b64encode(routing_data.tobytes()).decode("ascii"),
+            }
+        except Exception as e:
+            logger.warning(f"Failed to read routing data for req {self.request_id}: {e}")
+            return None
+
     def get_prompt_ids(self):
         return self.shm_prompt_ids.arr[: self.input_len].tolist()
diff --git a/lightllm/server/core/objs/start_args_type.py b/lightllm/server/core/objs/start_args_type.py
index 059cd739f..8abf05ccb 100644
--- a/lightllm/server/core/objs/start_args_type.py
+++ b/lightllm/server/core/objs/start_args_type.py
@@ -159,3 +159,5 @@ class StartArgs:
     # multi_modal
     enable_multimodal: bool = field(default=False)
     enable_multimodal_audio: bool = field(default=False)
+
+    enable_return_routed_experts: bool = field(default=False)
diff --git a/lightllm/server/httpserver/manager.py b/lightllm/server/httpserver/manager.py
index 212e037e9..a0f074b21 100644
--- a/lightllm/server/httpserver/manager.py
+++ b/lightllm/server/httpserver/manager.py
@@ -29,6 +29,7 @@
 from lightllm.server.core.objs.shm_req_manager import ShmReqManager
 from lightllm.server.core.objs.atomic_array_lock import AtomicShmArrayLock, AsyncLock, AtomicLockItem
 from lightllm.server.router.dynamic_prompt.shared_arr import SharedInt
+from lightllm.common.basemodel.routing_manager import get_shared_routing_config
 from lightllm.utils.log_utils import init_logger
 from lightllm.server.metrics.manager import MetricClient
 from lightllm.utils.statics_utils import MovingAverage
@@ -114,6 +115,9 @@ def __init__(
         # If the timemark is not updated for a pre-set time, a prob request will be sent to the backend.
         self.latest_success_infer_time_mark = SharedInt(f"{get_unique_server_name()}_latest_success_infer_time_mark")
         self.latest_success_infer_time_mark.set_value(int(time.time()))
+
+        # Cache routing config for MoE expert routing data extraction
+        self._routing_config = get_shared_routing_config() if args.enable_return_routed_experts else None
         return
 
     async def _alloc_resource(self, items, md5sums, token_nums, datas):
@@ -686,6 +690,11 @@ async def recycle_resource_loop(self):
                 for req_status in release_req_status:
                     self.req_id_to_out_inf.pop(req_status.group_req_objs.group_req_id, None)
                     for req in req_status.group_req_objs.shm_req_objs:
+                        if hasattr(req, "shm_routing_data") and req.shm_routing_data is not None:
+                            try:
+                                req.close_routing_data_shm_array()
+                            except Exception as e:
+                                logger.debug(f"Failed to close routing data shm for req {req.request_id}: {e}")
                         await self.shm_req_manager.async_put_back_req_obj(req)
                         await self.shm_req_manager.async_release_req_index(req.index_in_shm_mem)
                     await self._release_multimodal_resources(req_status.group_req_objs.multimodal_params)
@@ -773,6 +782,13 @@ async def handle_loop(self):
                     else:
                         finish_status = FinishStatus(req.finish_status.status)
 
+                    if self._routing_config is not None and self._routing_config.is_initialized():
+                        routing_meta = req.get_routing_metadata(
+                            self._routing_config.num_moe_layers, self._routing_config.topk
+                        )
+                        if routing_meta is not None:
+                            metadata["routed_experts"] = routing_meta
+
                     token_list.append((req_id, text, metadata, finish_status))
                 else:
                     break
diff --git a/lightllm/server/router/model_infer/infer_batch.py b/lightllm/server/router/model_infer/infer_batch.py
index 4b8b3c538..955617e1f 100644
--- a/lightllm/server/router/model_infer/infer_batch.py
+++ b/lightllm/server/router/model_infer/infer_batch.py
@@ -19,6 +19,7 @@
 from lightllm.utils.envs_utils import get_env_start_args
 from lightllm.server.pd_io_struct import NIXLDecodeNodeInfo
 from lightllm.server.embed_cache.embed_cache_client import CpuEmbedCacheClient
+from lightllm.common.basemodel.routing_manager import g_routing_capture_manager
 
 logger = init_logger(__name__)
 
@@ -113,6 +114,18 @@ def add_reqs(self, requests: List[Tuple[int, int, Any, int]], init_prefix_cache:
 
         return req_objs
 
+    def _extract_routing_data(self, req: "InferReq", sync: bool = True):
+        mem_indexes = self.req_manager.req_to_token_indexs[req.req_idx][0 : req.cur_kv_len]
+        num_moe_layers = g_routing_capture_manager.num_moe_layers
+        topk = g_routing_capture_manager.topk
+        num_tokens = req.cur_kv_len
+        if sync:
+            routing_data = g_routing_capture_manager.extract_for_request(mem_indexes.cpu())
+        else:
+            routing_data = g_routing_capture_manager.extract_for_request_no_sync(mem_indexes.cpu())
+        req.shm_req.create_routing_data_shm_array(num_moe_layers, num_tokens, topk)
+        req.shm_req.shm_routing_data.arr[:] = routing_data
+
     def free_a_req_mem(self, free_token_index: List, req: "InferReq"):
         if self.radix_cache is None:
             free_token_index.append(self.req_manager.req_to_token_indexs[req.req_idx][0 : req.cur_kv_len])
@@ -149,12 +162,20 @@ def _filter(self, finished_request_ids: List[int]):
         if len(finished_request_ids) == 0:
             return
 
+        need_routing_data = g_routing_capture_manager is not None
+        if need_routing_data:
+            g_routing_capture_manager.sync_events()
+
         free_req_index = []
         free_token_index = []
         for request_id in finished_request_ids:
             req: InferReq = self.requests_mapping.pop(request_id)
             if self.args.diverse_mode:
                 req.clear_master_slave_state()
+
+            if need_routing_data:
+                self._extract_routing_data(req, sync=False)
+
             self.free_a_req_mem(free_token_index, req)
             free_req_index.append(req.req_idx)
@@ -580,6 +601,8 @@ def handle(
             shm_req.shm_cur_output_len = self.output_len
 
             if finish_status.is_finished():
+                if g_routing_capture_manager is not None:
+                    g_infer_context._extract_routing_data(req_obj)
                 shm_req.finish_token_index = shm_req.input_len + self.output_len - 1
                 shm_req.finish_status = req_obj.finish_status
diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py
index 805c9b8e5..bab305a53 100644
--- a/lightllm/server/router/model_infer/mode_backend/base_backend.py
+++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py
@@ -40,6 +40,7 @@
 from lightllm.server.router.model_infer.mode_backend.generic_post_process import sample
 from lightllm.common.basemodel.triton_kernel.gather_token_id import scatter_token
 from lightllm.server.pd_io_struct import NIXLChunckedTransTaskRet
+from lightllm.common.basemodel.routing_manager import flush_routing_capture
 from .multi_level_kv_cache import MultiLevelKvCacheModule
 
@@ -794,6 +795,9 @@ def _sample_and_scatter_token(
         )
         return next_token_ids, next_token_ids_cpu, next_token_logprobs_cpu
 
+    def _flush_routing_after_sample(self, mem_indexes: torch.Tensor, microbatch_index: int = 0) -> None:
+        flush_routing_capture(mem_indexes, microbatch_index)
+
     def _dp_all_gather_prefill_and_decode_req_num(
         self, prefill_reqs: List[InferReq], decode_reqs: List[InferReq]
     ) -> Tuple[np.ndarray, np.ndarray]:
diff --git a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py
index a8a5224eb..969964a2a 100644
--- a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py
+++ b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py
@@ -118,6 +118,7 @@ def prefill_normal(
             b_prefill_has_output_cpu=model_input.b_prefill_has_output_cpu,
             mask_func=self.prefill_mask_func,
         )
+        self._flush_routing_after_sample(model_input.mem_indexes)
 
         sync_event = torch.cuda.Event()
         sync_event.record()
@@ -156,6 +157,7 @@ def decode_normal(
             is_prefill=False,
             mask_func=self.decode_mask_func,
         )
+        self._flush_routing_after_sample(model_input.mem_indexes)
 
         sync_event = torch.cuda.Event()
         sync_event.record()
@@ -195,6 +197,7 @@ def prefill_mtp(
             b_prefill_has_output_cpu=model_input.b_prefill_has_output_cpu,
             mask_func=self.prefill_mask_func,
         )
+        self._flush_routing_after_sample(model_input.mem_indexes)
         # mtp kv fill
         self._draft_prefill_forward(
             model_input=model_input, model_output=model_output, next_token_ids=next_token_ids
         )
@@ -279,6 +282,7 @@ def decode_mtp(
                 next_token_ids=next_token_ids,
                 mask=accepted_index == 1,
             )
+        self._flush_routing_after_sample(model_input.mem_indexes)
 
         sync_event = torch.cuda.Event()
         sync_event.record()
diff --git a/lightllm/server/router/model_infer/mode_backend/diverse_backend/impl.py b/lightllm/server/router/model_infer/mode_backend/diverse_backend/impl.py
index 5a179cb62..1d93ea116 100644
--- a/lightllm/server/router/model_infer/mode_backend/diverse_backend/impl.py
+++ b/lightllm/server/router/model_infer/mode_backend/diverse_backend/impl.py
@@ -77,6 +77,7 @@ def beam_prefill(self, event_pack: OverlapEventPack, prefill_reqs: List[InferReq
             next_token_ids=next_token_ids, next_token_logprobs=next_token_logprobs
         )
 
+        self._flush_routing_after_sample(model_input.mem_indexes)
         sync_event = torch.cuda.Event()
         sync_event.record()
diff --git a/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py b/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py
index bb0e848e7..877264e6e 100644
--- a/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py
+++ b/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py
@@ -155,6 +155,7 @@ def prefill_normal(
             b_prefill_has_output_cpu=model_input.b_prefill_has_output_cpu[:run_reqs_num],
             mask_func=None,
         )
+        self._flush_routing_after_sample(model_input.mem_indexes)
 
         sync_event = torch.cuda.Event()
         sync_event.record()
@@ -197,6 +198,7 @@ def decode_normal(self, event_pack: OverlapEventPack, decode_reqs: List[InferReq
             is_prefill=False,
             mask_func=None,
         )
+        self._flush_routing_after_sample(model_input.mem_indexes)
 
         sync_event = torch.cuda.Event()
         sync_event.record()
@@ -263,6 +265,8 @@ def prefill_overlap(self, event_pack: OverlapEventPack, prefill_reqs: List[Infer
             b_prefill_has_output_cpu=b_has_out_cpu,
             mask_func=None,
         )
+        self._flush_routing_after_sample(model_input0.mem_indexes, microbatch_index=0)
+        self._flush_routing_after_sample(model_input1.mem_indexes, microbatch_index=1)
 
         sync_event = torch.cuda.Event()
         sync_event.record()
@@ -326,6 +330,8 @@ def decode_overlap(self, event_pack: OverlapEventPack, decode_reqs: List[InferRe
             is_prefill=False,
             mask_func=None,
         )
+        self._flush_routing_after_sample(model_input0.mem_indexes, microbatch_index=0)
+        self._flush_routing_after_sample(model_input1.mem_indexes, microbatch_index=1)
 
         sync_event = torch.cuda.Event()
         sync_event.record()
@@ -374,6 +380,7 @@ def prefill_mtp(self, event_pack: OverlapEventPack, prefill_reqs: List[InferReq]
             b_prefill_has_output_cpu=b_has_out_cpu,
             mask_func=None,
         )
+        self._flush_routing_after_sample(model_input.mem_indexes)
 
         # mtp kv fill
         draft_next_token_ids_gpu = torch.zeros((model_input.batch_size), dtype=torch.int64, device="cuda")
@@ -471,6 +478,7 @@ def decode_mtp(self, event_pack: OverlapEventPack, decode_reqs: List[InferReq]):
                 next_token_ids=next_token_ids,
                 mask=accepted_index == 1,
             )
+        self._flush_routing_after_sample(model_input.mem_indexes)
 
         sync_event = torch.cuda.Event()
         sync_event.record()
@@ -652,6 +660,8 @@ def prefill_overlap_mtp(self, event_pack: OverlapEventPack, prefill_reqs: List[I
             is_prefill=True,
             b_prefill_has_output_cpu=b_has_out_cpu,
         )
+        self._flush_routing_after_sample(model_input0.mem_indexes, microbatch_index=0)
+        self._flush_routing_after_sample(model_input1.mem_indexes, microbatch_index=1)
 
         # spec prefill: MTP
         draft_model_input0, draft_model_input1 = model_input0, model_input1
@@ -789,6 +799,8 @@ def decode_overlap_mtp(self, event_pack: OverlapEventPack, decode_reqs: List[Inf
                 next_token_ids=next_token_ids,
                 mask=accepted_index == 1,
             )
+        self._flush_routing_after_sample(model_input0.mem_indexes, microbatch_index=0)
+        self._flush_routing_after_sample(model_input1.mem_indexes, microbatch_index=1)
 
         sync_event = torch.cuda.Event()
         sync_event.record()
diff --git a/test/test_api/test_r3.py b/test/test_api/test_r3.py
new file mode 100644
index 000000000..0ad1b67c6
--- /dev/null
+++ b/test/test_api/test_r3.py
@@ -0,0 +1,84 @@
+import sys
+import argparse
+import requests
+import base64
+import numpy as np
+
+
+def test_routing_export(url: str = "http://localhost:8000"):
+    print(f"Testing routing export at {url}")
+    print("-" * 50)
+
+    try:
+        response = requests.post(
+            f"{url}/generate",
+            json={
+                "inputs": "What is the capital of France?",
+                "parameters": {
+                    "max_new_tokens": 50,
+                    "return_routed_experts": True,
+                },
+            },
+            timeout=60,
+        )
+    except requests.exceptions.ConnectionError:
+        print(f"ERROR: Cannot connect to server at {url}")
+        print("Make sure the LightLLM server is running with --enable_return_routed_experts")
+        return False
+    except requests.exceptions.Timeout:
+        print("ERROR: Request timed out")
+        return False
+
+    print(f"Status: {response.status_code}")
+
+    if response.status_code != 200:
+        print(f"ERROR: Request failed with status {response.status_code}")
+        print(f"Response: {response.text}")
+        return False
+
+    res = response.json()
+    print(f"Generated text: {res.get('generated_text', 'N/A')[:100]}...")
+
+    if "routed_experts" not in res or not res["routed_experts"]:
+        print("\nWARNING: No routed_experts in response.")
+        print("This could mean:")
+        print("  - The model is not a MoE model")
+        print("  - The server was not started with --enable_return_routed_experts")
+        print("  - The routing capture manager was not initialized")
+        return False
+
+    routing_info = res["routed_experts"]
+    shape = routing_info["shape"]
+    dtype = np.dtype(routing_info["dtype"])
+    data = base64.b64decode(routing_info["data"])
+    routing_array = np.frombuffer(data, dtype=dtype).reshape(shape)
+
+    print(f"\n{'=' * 50}")
+    print("ROUTING CAPTURE SUCCESS!")
+    print(f"{'=' * 50}")
+    print(f"Shape: {shape}")
+    print(f"Dtype: {dtype}")
+    print(f"Num MoE layers: {shape[0]}")
+    print(f"Num tokens: {shape[1]}")
+    print(f"Top-K: {shape[2]}")
+
+    print(f"\nSample routing (first layer, first 5 tokens):")
+    num_tokens_to_show = min(5, shape[1])
+    for i in range(num_tokens_to_show):
+        print(f"  Token {i}: experts {routing_array[0, i, :].tolist()}")
+
+    if np.all(routing_array == 0):
+        print("\nWARNING: All routing data is zeros. Capture may not be working correctly.")
+        return False
+
+    print("\nTest PASSED!")
+    return True
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="Test R3 routing export feature")
+    parser.add_argument("--url", default="http://localhost:8000", help="Server URL")
+    args = parser.parse_args()
+
+    success = test_routing_export(args.url)
+    sys.exit(0 if success else 1)
diff --git a/unit_tests/__init__.py b/unit_tests/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/unit_tests/common/__init__.py b/unit_tests/common/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/unit_tests/common/basemodel/__init__.py b/unit_tests/common/basemodel/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/unit_tests/common/basemodel/test_routing_capture_manager.py b/unit_tests/common/basemodel/test_routing_capture_manager.py
new file mode 100644
index 000000000..ef9390637
--- /dev/null
+++ b/unit_tests/common/basemodel/test_routing_capture_manager.py
@@ -0,0 +1,114 @@
+import pytest
+import torch
+import numpy as np
+
+
+def test_moe_layer_counter():
+    from lightllm.common.basemodel.routing_manager import (
+        reset_moe_layer_counter,
+        get_next_moe_layer_index,
+        get_moe_layer_count,
+    )
+
+    reset_moe_layer_counter()
+    assert get_moe_layer_count() == 0
+
+    assert get_next_moe_layer_index() == 0
+    assert get_next_moe_layer_index() == 1
+    assert get_next_moe_layer_index() == 2
+    assert get_moe_layer_count() == 3
+
+    reset_moe_layer_counter()
+    assert get_moe_layer_count() == 0
+    assert get_next_moe_layer_index() == 0
+
+
+class TestRoutingCaptureManager:
+    def test_capture_explicit_layer_index(self):
+        from lightllm.common.basemodel.routing_manager import RoutingCaptureManager
+
+        manager = RoutingCaptureManager(
+            num_moe_layers=4,
+            topk=8,
+            num_experts=64,
+            batch_max_tokens=128,
+            kv_cache_size=1024,
+            enable_overlap=False,
+        )
+
+        topk_ids = torch.randint(0, 64, (10, 8), device="cuda")
+        manager.capture(moe_layer_index=2, topk_ids=topk_ids)
+
+        assert torch.equal(manager.gpu_buffer[0, 2, :10, :], topk_ids.to(manager.dtype))
+
+    def test_double_buffer_overlap_mode(self):
+        from lightllm.common.basemodel.routing_manager import RoutingCaptureManager
+
+        manager = RoutingCaptureManager(
+            num_moe_layers=2,
+            topk=4,
+            num_experts=32,
+            batch_max_tokens=64,
+            kv_cache_size=256,
+            enable_overlap=True,
+        )
+
+        assert manager.num_slots == 2
+        assert manager.gpu_buffer.shape[0] == 2
+
+        ids_0 = torch.ones((5, 4), dtype=torch.int64, device="cuda")
+        manager.capture(moe_layer_index=0, topk_ids=ids_0, microbatch_index=0)
+
+        ids_1 = torch.ones((5, 4), dtype=torch.int64, device="cuda") * 2
+        manager.capture(moe_layer_index=0, topk_ids=ids_1, microbatch_index=1)
+
+        assert manager.gpu_buffer[0, 0, 0, 0].item() == 1
+        assert manager.gpu_buffer[1, 0, 0, 0].item() == 2
+
+    def test_flush_and_extract(self):
+        from lightllm.common.basemodel.routing_manager import RoutingCaptureManager
+
+        manager = RoutingCaptureManager(
+            num_moe_layers=2,
+            topk=4,
+            num_experts=32,
+            batch_max_tokens=64,
+            kv_cache_size=256,
+            enable_overlap=False,
+        )
+
+        topk_ids = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]], device="cuda")
+        manager.capture(moe_layer_index=0, topk_ids=topk_ids)
+        manager.capture(moe_layer_index=1, topk_ids=topk_ids + 10)
+
+        mem_indexes = torch.tensor([10, 11], device="cuda")
+        manager.flush_to_cpu_async(mem_indexes, microbatch_index=0)
+
+        result = manager.extract_for_request(mem_indexes.cpu())
+
+        assert result.shape == (2, 2, 4)
+        assert result[0, 0, 0] == 1
+        assert result[1, 0, 0] == 11
+
+    def test_dtype_selection(self):
+        from lightllm.common.basemodel.routing_manager import RoutingCaptureManager
+
+        manager_small = RoutingCaptureManager(
+            num_moe_layers=1,
+            topk=2,
+            num_experts=64,
+            batch_max_tokens=32,
+            kv_cache_size=128,
+            enable_overlap=False,
+        )
+        assert manager_small.dtype == torch.int8
+
+        manager_large = RoutingCaptureManager(
+            num_moe_layers=1,
+            topk=2,
+            num_experts=256,
+            batch_max_tokens=32,
+            kv_cache_size=128,
+            enable_overlap=False,
+        )
+        assert manager_large.dtype == torch.int16
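A short client-side sketch of consuming the new routed_experts field. It assumes a MoE model served with --enable_return_routed_experts and reachable at localhost:8000; the request and response shapes mirror test/test_api/test_r3.py above, and the histogram at the end is just one illustrative use of the decoded array, not part of the API.

    import base64
    import numpy as np
    import requests

    resp = requests.post(
        "http://localhost:8000/generate",
        json={
            "inputs": "hello",
            "parameters": {"max_new_tokens": 16, "return_routed_experts": True},
        },
        timeout=60,
    ).json()

    # The server returns the routing tensor as base64 bytes plus its shape and dtype.
    info = resp["routed_experts"]
    routing = np.frombuffer(base64.b64decode(info["data"]), dtype=np.dtype(info["dtype"]))
    routing = routing.reshape(info["shape"])  # (num_moe_layers, num_tokens, topk)

    # Example analysis: how often each expert was selected in the first MoE layer.
    counts = np.bincount(routing[0].ravel())
    print(counts)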