@@ -745,8 +745,7 @@ def __init__(self, model_runner: ModelRunner):
def init_forward_metadata(self, forward_batch: ForwardBatch):
metadata = self._forward_metadata(forward_batch)
self.forward_metadata = Mamba2Metadata.prepare_mixed(
-            metadata.query_start_loc,
-            metadata.mamba_cache_indices,
+            metadata,
self.mamba_chunk_size,
forward_batch,
)
@@ -762,8 +761,12 @@ def init_forward_metadata_capture_cuda_graph(
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
):
metadata = self._capture_metadata(bs, req_pool_indices, forward_mode, spec_info)
+        draft_token_num = spec_info.draft_token_num if spec_info is not None else 1
self.forward_metadata = Mamba2Metadata.prepare_decode(
-            metadata.query_start_loc, metadata.mamba_cache_indices, seq_lens
+            metadata,
+            seq_lens,
+            is_target_verify=forward_mode.is_target_verify(),
+            draft_token_num=draft_token_num,
)

def init_forward_metadata_replay_cuda_graph(
@@ -780,8 +783,12 @@ def init_forward_metadata_replay_cuda_graph(
metadata = self._replay_metadata(
bs, req_pool_indices, forward_mode, spec_info, seq_lens_cpu
)
+        draft_token_num = spec_info.draft_token_num if spec_info is not None else 1
self.forward_metadata = Mamba2Metadata.prepare_decode(
-            metadata.query_start_loc, metadata.mamba_cache_indices, seq_lens
+            metadata,
+            seq_lens,
+            is_target_verify=forward_mode.is_target_verify(),
+            draft_token_num=draft_token_num,
)

def forward(
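
Aside (not part of the diff): both CUDA-graph paths above now read draft_token_num off spec_info (falling back to 1 when there is no speculative info) and forward it, together with is_target_verify, into Mamba2Metadata.prepare_decode. A minimal runnable sketch of the resulting token accounting; the helper name is made up for illustration:

def decode_token_count(num_decodes: int, is_target_verify: bool, draft_token_num: int) -> int:
    # Under target verify every decode request carries draft_token_num tokens;
    # in ordinary decode each request contributes exactly one token.
    return num_decodes * draft_token_num if is_target_verify else num_decodes

assert decode_token_count(4, False, 8) == 4   # plain decode
assert decode_token_count(4, True, 8) == 32   # target verify
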
139 changes: 103 additions & 36 deletions python/sglang/srt/layers/attention/mamba/mamba.py
@@ -12,7 +12,6 @@
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
-from sglang.srt.distributed.utils import divide
from sglang.srt.layers.attention.mamba.mamba2_metadata import Mamba2Metadata
from sglang.srt.layers.attention.mamba.mixer2_rms_norm_gated import Mixer2RMSNormGated
from sglang.srt.layers.attention.mamba.ops import (
@@ -401,23 +400,28 @@ def forward(

num_prefills = metadata.num_prefills # request count
num_decodes = metadata.num_decodes # token count (=request)
+        num_decode_tokens = (
+            num_decodes * metadata.draft_token_num
+            if metadata.is_target_verify
+            else num_decodes
+        )
num_prefill_tokens = metadata.num_prefill_tokens # token count
has_prefill = num_prefills > 0
has_decode = num_decodes > 0
-        num_actual_tokens = num_prefill_tokens + num_decodes
+        num_actual_tokens = num_prefill_tokens + num_decode_tokens
assert num_actual_tokens == projected_states.shape[0]

# NOTE: V0 put prefill before decode
# Separate prefill and decode by splitting varlen input
# Split along token dimension
hidden_states_B_C_p, hidden_states_B_C_d = torch.split(
hidden_states_B_C,
-            [num_prefill_tokens, num_decodes],
+            [num_prefill_tokens, num_decode_tokens],
dim=0,
)
dt_p, dt_d = torch.split(
dt,
-            [num_prefill_tokens, num_decodes],
+            [num_prefill_tokens, num_decode_tokens],
dim=0,
)
# Split along batch dimension
@@ -441,7 +445,7 @@
)
preallocated_ssm_out_p, preallocated_ssm_out_d = torch.split(
preallocated_ssm_out,
-            [num_prefill_tokens, num_decodes],
+            [num_prefill_tokens, num_decode_tokens],
dim=0,
)

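
Aside (illustration only, with made-up sizes): the widened split sizes above assume the varlen layout keeps all prefill tokens ahead of the decode/verify tokens along dim 0, so enlarging the second slice to num_decode_tokens is enough to absorb the draft tokens:

import torch

num_prefill_tokens, num_decodes, draft_token_num = 10, 2, 4
num_decode_tokens = num_decodes * draft_token_num  # 8 verify tokens in total
hidden = torch.randn(num_prefill_tokens + num_decode_tokens, 16)

# Prefill tokens first, decode/verify tokens after them.
prefill_part, decode_part = torch.split(
    hidden, [num_prefill_tokens, num_decode_tokens], dim=0
)
assert prefill_part.shape[0] == num_prefill_tokens
assert decode_part.shape[0] == num_decode_tokens
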
@@ -520,20 +524,52 @@

# Process decode requests
if has_decode:
+            is_target_verify = metadata.is_target_verify
+
# 2. Convolution sequence transformation
-            ccu = (
-                causal_conv1d_update
-                if not use_triton_causal_conv
-                else causal_conv1d_update_triton
-            )
-            hidden_states_B_C_d = ccu(
-                hidden_states_B_C_d,
-                conv_state,
-                conv_weights,
-                self.conv1d.bias,
-                self.activation,
-                conv_state_indices=state_indices_tensor_d,
-            )
+            if is_target_verify:
+                assert (
+                    use_triton_causal_conv
+                ), "Speculative decoding requires use_triton_causal_conv=True for intermediate state support"
+                assert isinstance(
+                    layer_cache, MambaPool.SpeculativeState
+                ), "layer_cache must be SpeculativeState for speculative decoding"
+                draft_token_num = metadata.draft_token_num
+
+                # Reshape for batch processing
+                hidden_states_B_C_d_reshaped = hidden_states_B_C_d.view(
+                    num_decodes, draft_token_num, -1
+                ).transpose(1, 2)
+
+                hidden_states_B_C_d_processed = causal_conv1d_update_triton(
+                    hidden_states_B_C_d_reshaped,
+                    conv_state,
+                    conv_weights,
+                    self.conv1d.bias,
+                    self.activation,
+                    conv_state_indices=state_indices_tensor_d[:num_decodes],
+                    intermediate_conv_window=layer_cache.intermediate_conv_window[0],
+                    retrieve_next_token=metadata.retrieve_next_token,
+                    retrieve_next_sibling=metadata.retrieve_next_sibling,
+                    retrieve_parent_token=metadata.retrieve_parent_token,
+                )
+                hidden_states_B_C_d = hidden_states_B_C_d_processed.transpose(
+                    1, 2
+                ).view(num_decode_tokens, -1)
+            else:
+                ccu = (
+                    causal_conv1d_update
+                    if not use_triton_causal_conv
+                    else causal_conv1d_update_triton
+                )
+                hidden_states_B_C_d = ccu(
+                    hidden_states_B_C_d,
+                    conv_state,
+                    conv_weights,
+                    self.conv1d.bias,
+                    self.activation,
+                    conv_state_indices=state_indices_tensor_d,
+                )

hidden_states_d, B_d, C_d = split_hidden_states_B_C_fn(hidden_states_B_C_d)

@@ -553,24 +589,55 @@
-1, self.num_heads // self.tp_size, self.head_dim
)

-            # - the hidden is reshaped into (bs, num_heads, head_dim)
-            # - layer_state.ssm_state's slots will be selected
-            #   using state_indices_tensor_d
-            # NOTE: final output is an in-place update of out tensor
-            selective_state_update(
-                ssm_state,
-                hidden_states_d,
-                dt_d,
-                A_d,
-                B_d,
-                C_d,
-                D_d,
-                z=None,
-                dt_bias=dt_bias,
-                dt_softplus=True,
-                state_batch_indices=state_indices_tensor_d,
-                out=preallocated_ssm_out_d.view(num_decodes, -1, self.head_dim),
-            )
+            if is_target_verify:
+                selective_state_update(
+                    ssm_state,
+                    hidden_states_d.view(
+                        num_decodes,
+                        draft_token_num,
+                        self.num_heads // self.tp_size,
+                        self.head_dim,
+                    ),
+                    dt_d.view(
+                        num_decodes,
+                        draft_token_num,
+                        self.num_heads // self.tp_size,
+                        self.head_dim,
+                    ),
+                    A_d,
+                    B_d.view(num_decodes, draft_token_num, n_groups, -1),
+                    C_d.view(num_decodes, draft_token_num, n_groups, -1),
+                    D_d,
+                    z=None,
+                    dt_bias=dt_bias,
+                    dt_softplus=True,
+                    state_batch_indices=state_indices_tensor_d[:num_decodes],
+                    out=preallocated_ssm_out_d.view(
+                        num_decodes,
+                        draft_token_num,
+                        self.num_heads // self.tp_size,
+                        self.head_dim,
+                    ),
+                    disable_state_update=True,
+                    intermediate_states_buffer=layer_cache.intermediate_ssm,
+                    cache_steps=draft_token_num,
+                    retrieve_parent_token=metadata.retrieve_parent_token,
+                )
+            else:
+                selective_state_update(
+                    ssm_state,
+                    hidden_states_d,
+                    dt_d,
+                    A_d,
+                    B_d,
+                    C_d,
+                    D_d,
+                    z=None,
+                    dt_bias=dt_bias,
+                    dt_softplus=True,
+                    state_batch_indices=state_indices_tensor_d,
+                    out=preallocated_ssm_out_d.view(num_decodes, -1, self.head_dim),
+                )

# 4. gated MLP
# GatedRMSNorm internally applying SiLU to the gate
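
Aside (not part of the diff; shapes are illustrative): the target-verify branch regroups the flat decode tokens into one row of draft tokens per request, hands that layout to the conv/SSM kernels, and then flattens back. A minimal sketch of the layout round-trip:

import torch

num_decodes, draft_token_num, dim = 2, 4, 8
num_decode_tokens = num_decodes * draft_token_num
flat = torch.randn(num_decode_tokens, dim)              # tokens as split from the varlen batch

grouped = flat.view(num_decodes, draft_token_num, dim)  # (bs, draft_token_num, dim)
channels_first = grouped.transpose(1, 2)                # (bs, dim, draft_token_num) for the conv update

restored = channels_first.transpose(1, 2).reshape(num_decode_tokens, dim)
assert torch.equal(flat, restored)                      # round-trip preserves token order
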
44 changes: 35 additions & 9 deletions python/sglang/srt/layers/attention/mamba/mamba2_metadata.py
@@ -30,6 +30,8 @@ class ForwardMetadata:
retrieve_next_token: Optional[torch.Tensor] = None
retrieve_next_sibling: Optional[torch.Tensor] = None
retrieve_parent_token: Optional[torch.Tensor] = None
+    is_target_verify: bool = False
+    draft_token_num: int = 1


@dataclass(kw_only=True)
@@ -141,31 +143,45 @@ def _query_start_loc_to_chunk_indices_offsets(

@staticmethod
def prepare_decode(
-        query_start_loc: torch.Tensor,
-        mamba_cache_indices: torch.Tensor,
+        forward_metadata: ForwardMetadata,
seq_lens: torch.Tensor,
+        *,
+        is_target_verify: bool,
+        draft_token_num: int,
) -> "Mamba2Metadata":
"""This path is run during CUDA graph capture, i.e. decode only, so `num_prefills` is 0"""
return Mamba2Metadata(
-            query_start_loc=query_start_loc,
-            mamba_cache_indices=mamba_cache_indices,
+            query_start_loc=forward_metadata.query_start_loc,
+            mamba_cache_indices=forward_metadata.mamba_cache_indices,
+            retrieve_next_token=forward_metadata.retrieve_next_token,
+            retrieve_next_sibling=forward_metadata.retrieve_next_sibling,
+            retrieve_parent_token=forward_metadata.retrieve_parent_token,
num_decodes=len(seq_lens),
num_prefills=0,
num_prefill_tokens=0,
+            is_target_verify=is_target_verify,
+            draft_token_num=draft_token_num,
)

@classmethod
def prepare_mixed(
cls,
-        query_start_loc: torch.Tensor,
-        mamba_cache_indices: torch.Tensor,
+        forward_metadata: ForwardMetadata,
chunk_size: int,
forward_batch: ForwardBatch,
) -> "Mamba2Metadata":
"""This path cannot run with CUDA graph, as it contains extend requests."""
if forward_batch.extend_num_tokens is None:
+            draft_token_num = (
+                forward_batch.spec_info.draft_token_num
+                if forward_batch.spec_info is not None
+                else 1
+            )
return cls.prepare_decode(
-                query_start_loc, mamba_cache_indices, forward_batch.seq_lens
+                forward_metadata,
+                forward_batch.seq_lens,
+                is_target_verify=forward_batch.forward_mode.is_target_verify(),
+                draft_token_num=draft_token_num,
)
num_prefills = len(forward_batch.extend_seq_lens)
num_prefill_tokens = forward_batch.extend_num_tokens
@@ -176,7 +192,7 @@ def prepare_mixed(
has_initial_states = context_lens_tensor > 0
prep_initial_states = torch.any(has_initial_states[:num_prefills]).item()

-        query_start_loc = query_start_loc[: num_prefills + 1]
+        query_start_loc = forward_metadata.query_start_loc[: num_prefills + 1]
seq_idx = torch.repeat_interleave(
torch.arange(
num_prefills, dtype=torch.int32, device=query_start_loc.device
@@ -197,12 +213,22 @@
)
)

+        draft_token_num = (
+            getattr(forward_batch.spec_info, "draft_token_num", 1)
+            if forward_batch.spec_info is not None
+            else 1
+        )
return Mamba2Metadata(
query_start_loc=query_start_loc,
-            mamba_cache_indices=mamba_cache_indices,
+            mamba_cache_indices=forward_metadata.mamba_cache_indices,
+            retrieve_next_token=forward_metadata.retrieve_next_token,
+            retrieve_next_sibling=forward_metadata.retrieve_next_sibling,
+            retrieve_parent_token=forward_metadata.retrieve_parent_token,
num_prefills=num_prefills,
num_prefill_tokens=num_prefill_tokens,
num_decodes=num_decodes,
+            is_target_verify=forward_batch.forward_mode.is_target_verify(),
+            draft_token_num=draft_token_num,
mixed_metadata=cls.MixedMetadata(
has_initial_states=has_initial_states,
prep_initial_states=prep_initial_states,
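
Aside (not part of the diff): prepare_mixed reads draft_token_num with getattr(..., "draft_token_num", 1), presumably so that spec_info objects without that attribute still resolve to 1, while the CUDA-graph paths access spec_info.draft_token_num directly. A small runnable sketch of the mixed-path fallback, using placeholder objects:

from types import SimpleNamespace

def mixed_draft_token_num(spec_info) -> int:
    # Mirrors the prepare_mixed fallback: no spec_info, or spec_info without a
    # draft_token_num attribute, both resolve to a single token per request.
    if spec_info is None:
        return 1
    return getattr(spec_info, "draft_token_num", 1)

assert mixed_draft_token_num(None) == 1
assert mixed_draft_token_num(SimpleNamespace()) == 1                   # no draft_token_num field
assert mixed_draft_token_num(SimpleNamespace(draft_token_num=8)) == 8  # verify-style input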