Conversation

@roikoren755
Contributor

Motivation

Enable running a model with a Mamba2 layer as the target model in a speculative decoding setup.

Modifications

Updated the selective_scan_update Triton kernel so that it supports:

  • multiple tokens per sequence,
  • writing intermediate cache states to a provided tensor,
  • reading the initial state from a different location for EAGLE top-k, and
  • skipping the write of the final state to the cache,

similar to what is done in the fused_recurrent_gated_delta_rule_update Triton kernel.

Added support to the MambaMixer2 module for target verification batches.

Enabled speculative decoding support for Mamba2 models in the model runner.
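For reviewers unfamiliar with the kernel, the per-token recurrence it computes can be sketched as a pure NumPy reference (single head, single group; the function name and shapes are illustrative, not the kernel's actual signature):

```python
import numpy as np

def selective_state_update_ref(state, x, dt, A, B, C, D=None):
    """Reference multi-token selective state update (NumPy sketch).

    Assumed, simplified shapes (one head, one group):
      state: (batch, headdim, dstate)   running SSM state, updated in place
      x:     (batch, seqlen, headdim)   inputs, one row per draft/verify token
      dt:    (batch, seqlen, headdim)   per-token step sizes
      A:     (headdim, dstate)          state decay parameters
      B, C:  (batch, seqlen, dstate)    input/output projections
    Returns y: (batch, seqlen, headdim).
    """
    batch, seqlen, headdim = x.shape
    y = np.empty_like(x)
    for t in range(seqlen):  # draft tokens are processed sequentially
        dA = np.exp(dt[:, t, :, None] * A)                             # decay factor
        dBx = dt[:, t, :, None] * x[:, t, :, None] * B[:, t, None, :]  # input term
        state[...] = state * dA + dBx            # recurrence: h_t = h_{t-1}*dA + dt*B*x
        y[:, t] = (state * C[:, t, None, :]).sum(axis=-1)  # output: y_t = C . h_t
        if D is not None:
            y[:, t] += D * x[:, t]               # skip connection
    return y
```

The Triton kernel fuses this loop and adds the cache read/write options listed above; the sketch is only meant to show what "multiple tokens per sequence" does to the running state.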

Accuracy Tests

Running the few_shot_gsm8k test with 200 questions gave the following results.

nvidia/NVIDIA-Nemotron-Nano-9B-v2 - branch, with standalone speculative decoding (5 speculative steps, EAGLE top-k of 3, 8 draft tokens, and a very weak draft model):

Accuracy: 0.885
Invalid: 0.005
Latency: 76.692 s
Output throughput: 316.615 token/s

nvidia/NVIDIA-Nemotron-Nano-9B-v2 - branch, without standalone speculative decoding:

Accuracy: 0.885
Invalid: 0.005
Latency: 37.193 s
Output throughput: 646.170 token/s

nvidia/NVIDIA-Nemotron-Nano-9B-v2 - main, without standalone speculative decoding:

Accuracy: 0.885
Invalid: 0.005
Latency: 37.123 s
Output throughput: 652.453 token/s

Benchmarking and Profiling

Benchmarked nvidia/NVIDIA-Nemotron-Nano-9B-v2 without speculative decoding, using the following command:
python -m sglang.bench_one_batch --model-path nvidia/NVIDIA-Nemotron-Nano-9B-v2 --max-running-requests 32 --batch 32 --input-len 256 --output-len 32

Branch:

Benchmark ...
Prefill. latency: 0.43751 s, throughput:  18724.34 token/s
Decode 0. Batch size: 32, latency: 0.01288 s, throughput:   2484.93 token/s
Decode 1. Batch size: 32, latency: 0.01210 s, throughput:   2644.09 token/s
Decode 2. Batch size: 32, latency: 0.01208 s, throughput:   2648.23 token/s
Decode 3. Batch size: 32, latency: 0.01207 s, throughput:   2651.73 token/s
Decode 4. Batch size: 32, latency: 0.01208 s, throughput:   2649.72 token/s
Decode.  median latency: 0.01206 s, median throughput:   2653.56 token/s
Total. latency:  0.812 s, throughput:  11351.80 token/s

Main:

Benchmark ...
Prefill. latency: 0.43763 s, throughput:  18719.14 token/s
Decode 0. Batch size: 32, latency: 0.01209 s, throughput:   2646.85 token/s
Decode 1. Batch size: 32, latency: 0.01201 s, throughput:   2664.41 token/s
Decode 2. Batch size: 32, latency: 0.01200 s, throughput:   2666.00 token/s
Decode 3. Batch size: 32, latency: 0.01201 s, throughput:   2663.96 token/s
Decode 4. Batch size: 32, latency: 0.01199 s, throughput:   2667.90 token/s
Decode.  median latency: 0.01197 s, median throughput:   2673.59 token/s
Total. latency:  0.809 s, throughput:  11393.65 token/s

@gemini-code-assist
Contributor

Summary of Changes

Hello @roikoren755, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces comprehensive support for Mamba2 models to act as target models in a speculative decoding framework. The changes primarily focus on adapting the underlying Triton kernels and model metadata to handle the multi-token and intermediate state requirements of speculative verification, aiming to leverage the efficiency benefits of this technique for Mamba2 architectures.

Highlights

  • Mamba2 Speculative Decoding Support: Enabled Mamba2 models to function as target models within a speculative decoding setup, allowing for more efficient inference.
  • Triton Kernel Enhancements: The selective_scan_update Triton kernel has been significantly updated to support multiple tokens per sequence, write intermediate cache states, read initial states for EAGLE top-k, and conditionally skip writing the final state to cache.
  • MambaMixer2 Module Adaptation: The MambaMixer2 module now includes support for target verification batches, which is crucial for the speculative decoding process.
  • Metadata Handling for Speculative Decoding: The Mamba2Metadata structure and its preparation methods (prepare_decode, prepare_mixed) have been extended to properly manage is_target_verify and draft_token_num flags, ensuring correct state management during speculative decoding.
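The flag handling described above can be pictured with a minimal sketch (hypothetical class and field layout; this is not the actual Mamba2Metadata definition, which carries many more fields):

```python
from dataclasses import dataclass

@dataclass
class SpecDecodeMetadata:
    """Illustrative sketch of the speculative-decoding fields discussed above.

    Only is_target_verify and draft_token_num come from the PR description;
    the class name and method are hypothetical.
    """
    is_target_verify: bool = False  # True when the batch is a target verification pass
    draft_token_num: int = 0        # number of draft tokens per sequence in that pass

    def tokens_per_seq(self) -> int:
        # A verification pass runs the target over all draft tokens at once;
        # a normal decode step processes a single token per sequence.
        return self.draft_token_num if self.is_target_verify else 1
```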


@gemini-code-assist (bot) left a comment


Code Review

This pull request introduces comprehensive support for Mamba2 layers within the speculative decoding framework. The changes involve significant refactoring of metadata handling, particularly centralizing is_target_verify and draft_token_num into the Mamba2Metadata object. A major part of this work is the enhancement of the _selective_scan_update_kernel Triton kernel to support multi-token processing, intermediate state caching, and EAGLE tree attention masks, which is crucial for the performance and correctness of speculative decoding with Mamba2. The MambaMixer2.forward method has been adapted to correctly handle these new capabilities, including appropriate tensor reshaping and parameter passing. Overall, the changes are well-aligned with the goal of enabling Mamba2 in target models for speculative decoding.

Comment on lines 239 to +333
x = x.unsqueeze(1)
if x.dim() == 3:
x = x.unsqueeze(1)
Severity: medium

The current logic for unsqueezing dimensions can be a bit confusing due to the sequential if statements. If x.dim() is 2, it will be unsqueezed once to 3 dimensions, and then immediately again to 4 dimensions by the next if x.dim() == 3 check. While this achieves the desired 4D shape (batch, 1, 1, dim) for an initial 2D input (batch, dim), it could be clearer. Consider using elif or combining the logic to explicitly target the final 4D shape.

Suggested change
-    x = x.unsqueeze(1)
-if x.dim() == 3:
-    x = x.unsqueeze(1)
+if x.dim() == 2:
+    x = x.unsqueeze(1).unsqueeze(1)
+elif x.dim() == 3:
+    x = x.unsqueeze(1)

Comment on lines 241 to +337
dt = dt.unsqueeze(1)
if dt.dim() == 3:
dt = dt.unsqueeze(1)
Severity: medium

Similar to the comment for x, the sequential if statements for dt can be made more explicit. If dt.dim() is 2, it will be unsqueezed twice. Using elif or combining the logic would improve clarity.

Suggested change
-    dt = dt.unsqueeze(1)
-if dt.dim() == 3:
-    dt = dt.unsqueeze(1)
+if dt.dim() == 2:
+    dt = dt.unsqueeze(1).unsqueeze(1)
+elif dt.dim() == 3:
+    dt = dt.unsqueeze(1)

Comment on lines 245 to +343
B = B.unsqueeze(1)
if B.dim() == 3:
B = B.unsqueeze(1)
Severity: medium

The same suggestion applies here for B. Consider using elif or combining the logic for better readability.

Suggested change
-    B = B.unsqueeze(1)
-if B.dim() == 3:
-    B = B.unsqueeze(1)
+if B.dim() == 2:
+    B = B.unsqueeze(1).unsqueeze(1)
+elif B.dim() == 3:
+    B = B.unsqueeze(1)

Comment on lines 247 to +347
C = C.unsqueeze(1)
if C.dim() == 3:
C = C.unsqueeze(1)
Severity: medium

The same suggestion applies here for C. Consider using elif or combining the logic for better readability.

Suggested change
-    C = C.unsqueeze(1)
-if C.dim() == 3:
-    C = C.unsqueeze(1)
+if C.dim() == 2:
+    C = C.unsqueeze(1).unsqueeze(1)
+elif C.dim() == 3:
+    C = C.unsqueeze(1)

Comment on lines +351 to +354
if z.dim() == 2:
z = z.unsqueeze(1)
if z.dim() == 3:
z = z.unsqueeze(1)
Severity: medium

The same suggestion applies here for z. Consider using elif or combining the logic for better readability.

Suggested change
-if z.dim() == 2:
-    z = z.unsqueeze(1)
-if z.dim() == 3:
-    z = z.unsqueeze(1)
+if z.dim() == 2:
+    z = z.unsqueeze(1).unsqueeze(1)
+elif z.dim() == 3:
+    z = z.unsqueeze(1)

Comment on lines 255 to +360
out = out.unsqueeze(1)
if out.dim() == 3:
out = out.unsqueeze(1)
Severity: medium

The same suggestion applies here for out. Consider using elif or combining the logic for better readability.

Suggested change
-    out = out.unsqueeze(1)
-if out.dim() == 3:
-    out = out.unsqueeze(1)
+if out.dim() == 2:
+    out = out.unsqueeze(1).unsqueeze(1)
+elif out.dim() == 3:
+    out = out.unsqueeze(1)
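Since the same fix repeats for x, dt, B, C, z, and out, a small shared helper along the lines the reviewer suggests could remove the repetition (hypothetical name; NumPy stands in for torch here purely for illustration):

```python
import numpy as np

def ensure_4d(t: np.ndarray) -> np.ndarray:
    """Expand a 2D (batch, dim) or 3D (batch, seqlen, dim) tensor to 4D by
    inserting singleton axes at position 1, matching the elif form suggested
    in the review. A 4D input is returned unchanged."""
    if t.ndim == 2:
        return t[:, None, None, :]  # (batch, dim) -> (batch, 1, 1, dim)
    elif t.ndim == 3:
        return t[:, None, :, :]     # (batch, seqlen, dim) -> (batch, 1, seqlen, dim)
    return t
```

In torch the same shape logic would use `Tensor.unsqueeze(1)` as in the suggestions; the point is that one explicit 2D/3D branch is easier to follow than two sequential `if` checks that can both fire.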

@yizhang2077 (Collaborator) left a comment
LGTM overall. Could we add some unit tests in test_nvidia_nemotron_nano_v2.py?

@roikoren755
Contributor Author

LGTM overall. Could we add some unit tests in test_nvidia_nemotron_nano_v2.py?

You mean adding a configuration to run it in a speculative decoding scenario there? I'm not sure we can do that easily, as we don't yet have a public EAGLE or MTP model that matches it (or any other Mamba-based model), and running a standalone draft model requires finding one that doesn't crash. I'll try looking into it.

@roikoren755
Contributor Author

@yizhang2077 added a test 👍

@yizhang2077 (Collaborator) left a comment
LGTM

@yizhang2077
Collaborator

Could you resolve these conflicts?

@roikoren755
Contributor Author

Could you resolve these conflicts?

Done 👍

@roikoren755
Contributor Author

@yizhang2077 apologies, I found an issue in the code for the case where the SSM cache dtype is different from FP32. Fixed it and added another test to verify.
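For context, a common way such dtype issues are handled (sketched here in NumPy; this is not the actual fix from this PR) is to upcast the cached state to FP32 for the recurrence and downcast only on the final write:

```python
import numpy as np

def update_state_mixed_dtype(cache_state, dA, dBx):
    """Sketch: accumulate the SSM recurrence in FP32 even when the cache is
    stored in a lower-precision dtype, casting back only on the final write.
    (Illustrative pattern only; names and the exact fix are assumptions.)"""
    acc = cache_state.astype(np.float32)                      # upcast once
    acc = acc * dA.astype(np.float32) + dBx.astype(np.float32)  # FP32 accumulate
    return acc.astype(cache_state.dtype)                      # downcast on write
```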

@Fridge003
Collaborator

/rerun-failed-ci

@Fridge003 Fridge003 closed this Nov 25, 2025
@Fridge003 Fridge003 reopened this Nov 25, 2025
@ispobock
Collaborator

/tag-and-rerun-ci

@Fridge003
Collaborator

@roikoren755 Please fix this error
https://github.com/sgl-project/sglang/actions/runs/19690909467/job/56425994326?pr=13434

@roikoren755
Contributor Author

@roikoren755 Please fix this error https://github.com/sgl-project/sglang/actions/runs/19690909467/job/56425994326?pr=13434

I reverted the changes to the GDN attention backend; things should work now 🤞

@netanel-haber
Contributor

/tag-and-rerun-ci

@ispobock ispobock merged commit 889b46e into sgl-project:main Dec 5, 2025
249 of 262 checks passed
@roikoren755 roikoren755 deleted the feat/mamba_mtp branch December 5, 2025 21:18