[Spec] Mamba2 support in target models #13434
Conversation
Summary of Changes (Gemini Code Assist): This pull request introduces comprehensive support for Mamba2 models acting as target models in a speculative decoding framework. The changes primarily adapt the underlying Triton kernels and model metadata to handle the multi-token and intermediate-state requirements of speculative verification, aiming to bring the efficiency benefits of this technique to Mamba2 architectures.
Code Review
This pull request introduces comprehensive support for Mamba2 layers within the speculative decoding framework. The changes involve significant refactoring of metadata handling, particularly centralizing is_target_verify and draft_token_num into the Mamba2Metadata object. A major part of this work is the enhancement of the _selective_scan_update_kernel Triton kernel to support multi-token processing, intermediate state caching, and EAGLE tree attention masks, which is crucial for the performance and correctness of speculative decoding with Mamba2. The MambaMixer2.forward method has been adapted to correctly handle these new capabilities, including appropriate tensor reshaping and parameter passing. Overall, the changes are well-aligned with the goal of enabling Mamba2 in target models for speculative decoding.
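The review mentions centralizing `is_target_verify` and `draft_token_num` into the `Mamba2Metadata` object. A minimal sketch of what such a centralized metadata object might look like — the field names come from the review text, but the exact class layout here is an assumption, not the PR's actual code:

```python
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class Mamba2Metadata:
    """Per-batch metadata shared by all Mamba2 layers (illustrative sketch).

    Centralizing the speculative-decoding flags here means each layer reads
    one object instead of threading extra arguments through every forward().
    """

    # True when this batch is a target-verify pass over draft tokens.
    is_target_verify: bool = False
    # Number of draft tokens per sequence during verification.
    draft_token_num: int = 0
    # Hypothetical buffer for intermediate SSM states written by the kernel
    # (a torch.Tensor in practice; typed loosely here to stay self-contained).
    intermediate_ssm_states: Optional[Any] = None
```

With this shape, a layer's forward path can branch on `metadata.is_target_verify` rather than receiving separate flags per call.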
```python
if x.dim() == 2:
    x = x.unsqueeze(1)
if x.dim() == 3:
    x = x.unsqueeze(1)
```
The current logic for unsqueezing dimensions can be a bit confusing due to the sequential if statements. If x.dim() is 2, it will be unsqueezed once to 3 dimensions, and then immediately again to 4 dimensions by the next if x.dim() == 3 check. While this achieves the desired 4D shape (batch, 1, 1, dim) for an initial 2D input (batch, dim), it could be clearer. Consider using elif or combining the logic to explicitly target the final 4D shape.
Suggested change:

```python
if x.dim() == 2:
    x = x.unsqueeze(1).unsqueeze(1)
elif x.dim() == 3:
    x = x.unsqueeze(1)
```
```python
if dt.dim() == 2:
    dt = dt.unsqueeze(1)
if dt.dim() == 3:
    dt = dt.unsqueeze(1)
```
Similar to the comment for x, the sequential if statements for dt can be made more explicit. If dt.dim() is 2, it will be unsqueezed twice. Using elif or combining the logic would improve clarity.
Suggested change:

```python
if dt.dim() == 2:
    dt = dt.unsqueeze(1).unsqueeze(1)
elif dt.dim() == 3:
    dt = dt.unsqueeze(1)
```
```python
if B.dim() == 2:
    B = B.unsqueeze(1)
if B.dim() == 3:
    B = B.unsqueeze(1)
```
```python
if C.dim() == 2:
    C = C.unsqueeze(1)
if C.dim() == 3:
    C = C.unsqueeze(1)
```
```python
if z.dim() == 2:
    z = z.unsqueeze(1)
if z.dim() == 3:
    z = z.unsqueeze(1)
```
```python
if out.dim() == 2:
    out = out.unsqueeze(1)
if out.dim() == 3:
    out = out.unsqueeze(1)
```
yizhang2077
left a comment
LGTM overall, could we add some unit tests in test_nvidia_nemotron_nano_v2.py?
You mean adding configuration to run it in a speculative decoding scenario there? I'm not sure we can easily do that, as we don't yet have a public EAGLE or MTP model that matches it (or any other Mamba-based model), and running a standalone draft model requires finding one that doesn't crash. I'll try looking into it.
@yizhang2077 added a test 👍
yizhang2077
left a comment
LGTM
Could you resolve these conflicts?
Done 👍
@yizhang2077 apologies, I found an issue in the code for the case where the SSM cache dtype is different from FP32; fixed it and added another test to verify.
/rerun-failed-ci |
/tag-and-rerun-ci |
I reverted the changes to the GDN attention backend, things should work now 🤞
Force-pushed from f37929c to 39f4923.
/tag-and-rerun-ci |
Motivation
Enable running a model with a Mamba2 layer as the target model in a speculative decoding setup.
Modifications
- Updated the `selective_scan_update` Triton kernel so that it supports multiple tokens per sequence, writing intermediate cache states to a provided tensor, reading the initial state from a different location for EAGLE top-k, and skipping the final state write to cache, similar to what is done in the `fused_recurrent_gated_delta_rule_update` Triton kernel.
- Added support to the MambaMixer2 module for target verification batches.
- Enabled speculative decoding support for Mamba2 models in the model runner.
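The actual kernel in the PR is written in Triton; as a plain NumPy reference (not the PR's code), the multi-token recurrence it implements — with the intermediate-state capture that target verification needs so rejected draft tokens can be rolled back — might look like the sketch below. The function name, the diagonal-A simplification, and all shapes here are illustrative assumptions:

```python
import numpy as np


def selective_scan_update_ref(state, x, dt, A, B, C, save_intermediate=False):
    """Reference multi-token selective-state update (diagonal-A sketch).

    state: (dim, dstate)    running SSM state for one sequence
    x, dt: (tokens, dim)    inputs and per-token step sizes
    A:     (dim, dstate)    (negative) decay parameters
    B, C:  (tokens, dstate) per-token input/output projections

    Returns outputs of shape (tokens, dim); when save_intermediate is True,
    also returns the state after every token, which is what lets a verifier
    restore the cache to the last accepted token's state.
    """
    outs, intermediates = [], []
    for t in range(x.shape[0]):
        dA = np.exp(dt[t][:, None] * A)                        # (dim, dstate) decay
        dBx = dt[t][:, None] * B[t][None, :] * x[t][:, None]   # discretized input
        state = dA * state + dBx                               # recurrent update
        outs.append(state @ C[t])                              # (dim,) readout
        if save_intermediate:
            intermediates.append(state.copy())
    return np.stack(outs), (np.stack(intermediates) if save_intermediate else None)
```

A single-token call with `save_intermediate=False` degenerates to the ordinary decode-time state update; the multi-token loop plus saved intermediates is the extra capability verification requires.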
Accuracy Tests
Running the `few_shot_gsm8k` test with 200 questions gave the following results.

nvidia/NVIDIA-Nemotron-Nano-9B-v2 - branch, with standalone speculative decoding (5 speculative steps, EAGLE top-k 3 and 8 draft tokens, very bad draft model):
nvidia/NVIDIA-Nemotron-Nano-9B-v2 - branch, without standalone speculative decoding:
nvidia/NVIDIA-Nemotron-Nano-9B-v2 - main, without standalone speculative decoding:
Benchmarking and Profiling
Benchmark nvidia/NVIDIA-Nemotron-Nano-9B-v2, without speculative decoding. Running the following command:
```shell
python -m sglang.bench_one_batch --model-path nvidia/NVIDIA-Nemotron-Nano-9B-v2 --max-running-requests 32 --batch 32 --input-len 256 --output-len 32
```

Branch:
Main:
Checklist