1 change: 1 addition & 0 deletions .gitignore
@@ -7,3 +7,4 @@ dist
.vscode
tmp/
requirements-musa.txt
CLAUDE.md
2 changes: 2 additions & 0 deletions lightllm/common/basemodel/basemodel.py
@@ -11,6 +11,7 @@

from lightllm.common.basemodel.layer_weights.hf_load_utils import load_hf_weights
from lightllm.common.basemodel.infer_struct import InferStateInfo
from lightllm.common.basemodel.routing_manager import reset_moe_layer_counter
from lightllm.common.kv_cache_mem_manager import MemoryManager
from lightllm.common.kv_cache_mem_manager.mem_utils import select_mem_manager_class
from lightllm.common.req_manager import ReqManager
@@ -164,6 +165,7 @@ def _init_quant(self):
logger.info(f"Initial quantization. " f"The default quantization method is {self.quant_cfg.quant_type}")

def _init_weights(self, start_layer_index=0):
reset_moe_layer_counter()
self.pre_post_weight = self.pre_and_post_weight_class(self.data_type, network_config=self.config)
self.trans_layers_weight = [
self.transformer_weight_class(
@@ -13,6 +13,7 @@
from lightllm.utils.envs_utils import get_redundancy_expert_ids, get_redundancy_expert_num, get_env_start_args
from lightllm.utils.dist_utils import get_global_world_size, get_global_rank
from lightllm.utils.log_utils import init_logger
from lightllm.common.basemodel.routing_manager import get_next_moe_layer_index

logger = init_logger(__name__)

@@ -35,6 +36,7 @@ def __init__(
network_config: Dict[str, Any] = None,
) -> None:
super().__init__(data_type=data_type)
self.moe_layer_index = get_next_moe_layer_index()
self.w1_weight_name = gate_proj_name
self.w2_weight_name = down_proj_name
self.w3_weight_name = up_proj_name
@@ -130,6 +132,7 @@ def experts(
topk_group: int,
num_expert_group: int,
is_prefill: Optional[bool] = None,
microbatch_index: int = 0,
) -> torch.Tensor:
"""Backward compatible method that routes to platform-specific implementation."""
return self.fuse_moe_impl(
@@ -145,6 +148,8 @@ def experts(
topk_group=topk_group,
num_expert_group=num_expert_group,
is_prefill=is_prefill,
moe_layer_index=self.moe_layer_index,
microbatch_index=microbatch_index,
)

def low_latency_dispatch(
@@ -8,6 +8,7 @@
from lightllm.common.quantization import Quantcfg
from lightllm.common.quantization.quantize_method import QuantizationMethod
from lightllm.utils.log_utils import init_logger
from lightllm.common.basemodel.routing_manager import g_routing_capture_manager

logger = init_logger(__name__)

@@ -144,10 +145,14 @@ def experts(
topk_group: int,
num_expert_group: int,
is_prefill: Optional[bool] = None,
microbatch_index: int = 0,
):

topk_weights, topk_ids = self._router(router_logits, top_k)

if g_routing_capture_manager is not None:
g_routing_capture_manager.capture(self.moe_layer_index, topk_ids, microbatch_index)

w1, w1_scale = self.w1
w2, w2_scale = self.w2
use_fp8_w8a8 = self.quant_method is not None
@@ -62,5 +62,7 @@ def __call__(
topk_group: int,
num_expert_group: int,
is_prefill: Optional[bool] = None,
moe_layer_index: Optional[int] = None,
microbatch_index: int = 0,
) -> torch.Tensor:
pass
@@ -3,6 +3,7 @@
from lightllm.common.quantization.no_quant import WeightPack
from lightllm.common.quantization.quantize_method import QuantizationMethod
from .base_impl import FuseMoeBaseImpl
from lightllm.common.basemodel.routing_manager import g_routing_capture_manager


class FuseMoeTriton(FuseMoeBaseImpl):
@@ -124,6 +125,8 @@ def __call__(
topk_group: int,
num_expert_group: int,
is_prefill: Optional[bool] = None,
moe_layer_index: Optional[int] = None,
microbatch_index: int = 0,
):
topk_weights, topk_ids = self._select_experts(
input_tensor=input_tensor,
@@ -136,6 +139,10 @@
num_expert_group=num_expert_group,
scoring_func=scoring_func,
)

if g_routing_capture_manager is not None and moe_layer_index is not None:
g_routing_capture_manager.capture(moe_layer_index, topk_ids, microbatch_index)

output = self._fused_experts(
input_tensor=input_tensor,
w13=w13,
215 changes: 215 additions & 0 deletions lightllm/common/basemodel/routing_manager.py
@@ -0,0 +1,215 @@
import torch
import numpy as np
from typing import Optional
from lightllm.utils.log_utils import init_logger
from lightllm.utils.dist_utils import get_current_rank_in_dp
from lightllm.server.router.dynamic_prompt.shared_arr import SharedArray
from lightllm.utils.envs_utils import get_unique_server_name

logger = init_logger(__name__)


class SharedRoutingConfig:
"""Shared MoE routing configuration across processes."""

def __init__(self):
service_name = get_unique_server_name()
self._shm = SharedArray(f"{service_name}_routing_config", shape=(2,), dtype=np.int32)

@property
def num_moe_layers(self) -> int:
return int(self._shm.arr[0])

@num_moe_layers.setter
def num_moe_layers(self, value: int):
self._shm.arr[0] = value

@property
def topk(self) -> int:
return int(self._shm.arr[1])

@topk.setter
def topk(self, value: int):
self._shm.arr[1] = value

def is_initialized(self) -> bool:
return self.num_moe_layers > 0 and self.topk > 0


_shared_routing_config: Optional[SharedRoutingConfig] = None


def get_shared_routing_config() -> SharedRoutingConfig:
"""Get or create the shared routing config."""
global _shared_routing_config
if _shared_routing_config is None:
_shared_routing_config = SharedRoutingConfig()
return _shared_routing_config


_moe_layer_counter: int = 0


def reset_moe_layer_counter() -> None:
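"""Reset the per-process MoE layer counter; called before model weights are (re)built."""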
global _moe_layer_counter
_moe_layer_counter = 0


def get_next_moe_layer_index() -> int:
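"""Return the next sequential MoE layer index, assigned to each MoE weight module in construction order."""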
global _moe_layer_counter
idx = _moe_layer_counter
_moe_layer_counter += 1
return idx


def get_moe_layer_count() -> int:
return _moe_layer_counter


class RoutingCaptureManager:
"""Captures MoE routing decisions"""

def __init__(
self,
num_moe_layers: int,
topk: int,
num_experts: int,
batch_max_tokens: int,
kv_cache_size: int,
enable_overlap: bool = False,
):
self.num_moe_layers = num_moe_layers
self.topk = topk
self.num_experts = num_experts
self.batch_max_tokens = batch_max_tokens
self.kv_cache_size = kv_cache_size

self.dtype = torch.int8 if num_experts <= 127 else torch.int16
dtype_bytes = 1 if self.dtype == torch.int8 else 2

self.num_slots = 2 if enable_overlap else 1

gpu_buffer_size = self.num_slots * num_moe_layers * batch_max_tokens * topk * dtype_bytes
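# GPU staging buffer: one slot per overlapped microbatch, holding the top-k expert ids of every MoE layer for the tokens in the current forward pass.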
self.gpu_buffer = torch.zeros(
(self.num_slots, num_moe_layers, batch_max_tokens, topk),
dtype=self.dtype,
device="cuda",
)

cpu_buffer_size = num_moe_layers * kv_cache_size * topk * dtype_bytes
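# Pinned CPU buffer indexed by KV-cache position, so each token's routing can be looked up after its forward pass completes.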
self.cpu_buffer = torch.zeros(
(num_moe_layers, kv_cache_size, topk),
dtype=self.dtype,
device="cpu",
pin_memory=True,
)

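# One dedicated copy stream and event per slot so the GPU-to-CPU flush can overlap with ongoing compute.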
self.flush_streams = [torch.cuda.Stream() for _ in range(self.num_slots)]
self.flush_events = [torch.cuda.Event() for _ in range(self.num_slots)]

dtype_name = "int8" if self.dtype == torch.int8 else "int16"
logger.info(
f"RoutingCaptureManager initialized: {num_moe_layers} MoE layers, topk={topk}, "
f"slots={self.num_slots}, GPU={gpu_buffer_size / 1024 / 1024:.2f}MB, "
f"CPU={cpu_buffer_size / 1024 / 1024:.2f}MB, dtype={dtype_name}"
)

def capture(self, moe_layer_index: int, topk_ids: torch.Tensor, microbatch_index: int = 0) -> None:
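"""Stage this layer's top-k expert ids for the given microbatch in the GPU buffer."""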
num_tokens = topk_ids.shape[0]
self.gpu_buffer[microbatch_index, moe_layer_index, :num_tokens, :] = topk_ids.to(self.dtype)

def flush_to_cpu_async(self, mem_indexes: torch.Tensor, microbatch_index: int) -> None:
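"""Asynchronously copy the staged routing ids for one microbatch into the pinned CPU buffer at the tokens' KV-cache positions (mem_indexes)."""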
num_tokens = mem_indexes.shape[0]
if num_tokens == 0:
return

slot = microbatch_index % self.num_slots
stream = self.flush_streams[slot]
event = self.flush_events[slot]

stream.wait_stream(torch.cuda.current_stream())

with torch.cuda.stream(stream):
cpu_indexes = mem_indexes.cpu()
self.cpu_buffer[:, cpu_indexes, :] = self.gpu_buffer[slot, :, :num_tokens, :].cpu()
event.record()

def sync_events(self) -> None:
"""Synchronize all flush events. Call once before batch extraction."""
for event in self.flush_events:
event.synchronize()

def extract_for_request(self, mem_indexes: torch.Tensor) -> np.ndarray:
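"""Return the captured routing ids for the given KV-cache positions, shape (num_moe_layers, num_tokens, topk), after waiting for pending flushes."""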
self.sync_events()
return self.cpu_buffer[:, mem_indexes, :].numpy()

def extract_for_request_no_sync(self, mem_indexes: torch.Tensor) -> np.ndarray:
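"""Same as extract_for_request, but assumes sync_events() has already been called by the caller."""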
return self.cpu_buffer[:, mem_indexes, :].numpy()


g_routing_capture_manager: Optional[RoutingCaptureManager] = None


def create_routing_capture_manager(
num_moe_layers: int,
topk: int,
num_experts: int,
batch_max_tokens: int,
kv_cache_size: int,
enable_overlap: bool = False,
) -> None:
global g_routing_capture_manager
assert g_routing_capture_manager is None, "RoutingCaptureManager already exists"
g_routing_capture_manager = RoutingCaptureManager(
num_moe_layers=num_moe_layers,
topk=topk,
num_experts=num_experts,
batch_max_tokens=batch_max_tokens,
kv_cache_size=kv_cache_size,
enable_overlap=enable_overlap,
)


def init_routing_capture(model) -> None:
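"""Create the global RoutingCaptureManager and publish the shared routing config when enable_return_routed_experts is set."""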
if not getattr(model.args, "enable_return_routed_experts", False):
return

if get_current_rank_in_dp() != 0:
logger.info("Skipping routing capture initialization on non-zero rank")
return

num_moe_layers = get_moe_layer_count()
if num_moe_layers == 0:
logger.warning(
"enable_return_routed_experts is set but no MoE layers found. " "Routing capture will not be enabled."
)
return

num_experts = model.config.get("n_routed_experts", model.config.get("num_experts", 0))
topk = model.config.get("num_experts_per_tok", 0)
assert num_experts > 0 and topk > 0
enable_overlap = getattr(model.args, "enable_decode_microbatch_overlap", False)

logger.info(
f"Initializing routing capture: num_moe_layers={num_moe_layers}, "
f"topk={topk}, num_experts={num_experts}, enable_overlap={enable_overlap}"
)

create_routing_capture_manager(
num_moe_layers=num_moe_layers,
topk=topk,
num_experts=num_experts,
batch_max_tokens=model.max_total_token_num,
kv_cache_size=model.mem_manager.size + 1,
enable_overlap=enable_overlap,
)

shared_config = get_shared_routing_config()
shared_config.num_moe_layers = num_moe_layers
shared_config.topk = topk
logger.info(f"Shared routing config set: num_moe_layers={num_moe_layers}, topk={topk}")


def flush_routing_capture(mem_indexes: torch.Tensor, microbatch_index: int = 0) -> None:
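"""Module-level convenience wrapper; no-op when routing capture is disabled."""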
if g_routing_capture_manager is not None:
g_routing_capture_manager.flush_to_cpu_async(mem_indexes, microbatch_index)
@@ -311,6 +311,7 @@ def _moe_ffn(
use_grouped_topk=self.n_group,
topk_group=self.topk_group,
num_expert_group=self.n_group,
microbatch_index=infer_state.microbatch_index,
)

if self.n_shared_experts is not None and layer_weight.num_fused_shared_experts == 0:
@@ -337,6 +338,7 @@ def _moe_ffn_edp(
topk_group=self.topk_group,
num_expert_group=self.n_group,
is_prefill=infer_state.is_prefill,
microbatch_index=infer_state.microbatch_index,
)

if self.n_shared_experts is not None:
2 changes: 2 additions & 0 deletions lightllm/models/deepseek2/model.py
@@ -6,6 +6,7 @@
from lightllm.models.deepseek2.infer_struct import Deepseek2InferStateInfo
from lightllm.models.llama.model import LlamaTpPartModel
from lightllm.common.kv_cache_mem_manager.mem_utils import select_mem_manager_class
from lightllm.common.basemodel.routing_manager import init_routing_capture
from lightllm.utils.log_utils import init_logger
from lightllm.utils.envs_utils import enable_env_vars, get_env_start_args, get_added_mtp_kv_layer_num
from lightllm.distributed.communication_op import dist_group_manager
@@ -48,6 +49,7 @@ def _init_some_value(self):
def _init_custom(self):
self._init_to_get_yarn_rotary()
dist_group_manager.new_deepep_group(self.config["n_routed_experts"], self.config["hidden_size"])
init_routing_capture(self)

def _verify_params(self):
return super()._verify_params()
@@ -51,6 +51,7 @@ def _ffn(self, input, infer_state, layer_weight: GptOssTransformerLayerWeight) -
use_grouped_topk=False,
topk_group=None,
num_expert_group=None,
microbatch_index=infer_state.microbatch_index,
)
return hidden_states.view(num_tokens, hidden_dim)

5 changes: 5 additions & 0 deletions lightllm/models/gpt_oss/model.py
@@ -2,6 +2,7 @@
from lightllm.models.gpt_oss.layer_weights.transformer_layer_weight import GptOssTransformerLayerWeight
from lightllm.models.llama.model import LlamaTpPartModel
from lightllm.models.registry import ModelRegistry
from lightllm.common.basemodel.routing_manager import init_routing_capture
from lightllm.utils.envs_utils import get_env_start_args
from lightllm.utils.log_utils import init_logger
from lightllm.common.basemodel.attention import get_prefill_att_backend_class, get_decode_att_backend_class
@@ -28,3 +29,7 @@ def _init_att_backend(self):
self.decode_att_backend: BaseAttBackend = get_decode_att_backend_class(index=0, priority_list=["fa3"])(
model=self
)

def _init_custom(self):
super()._init_custom()
init_routing_capture(self)