
Commit 11c7304

[Deepseek V3.2] Change indexer weights_proj to fp32
1 parent 2bc7c5e

3 files changed, +90 -122 lines


docs/basic_usage/deepseek_v32.md

Lines changed: 66 additions & 0 deletions
@@ -129,6 +129,13 @@ Latency: 25.109 s
 Output throughput: 5226.235 token/s
 ```
 
+To test long-context accuracy, run gsm8k with `--num-shots 20`. The results are very close to the 8-shot results:
+```
+Accuracy: 0.956
+Invalid: 0.000
+Latency: 29.545 s
+Output throughput: 4418.617 token/s
+```
 
 ### Accuracy Test with `gpqa-diamond`
 
@@ -143,6 +150,65 @@ Repeat: 8, mean: 0.797
 Scores: ['0.808', '0.798', '0.808', '0.798', '0.783', '0.788', '0.803', '0.793']
 ```
 
+### Accuracy Test with `aime 2025`
+
+Prepare the environment by installing NeMo-Skills in the Docker container or in your own virtual environment:
+
+```
+pip install git+https://github.com/NVIDIA/NeMo-Skills.git --ignore-installed blinker
+```
+
+Modify the [`jinja chat_template`](https://huggingface.co/deepseek-ai/DeepSeek-V3.2-Exp/blob/main/tokenizer_config.json#L34) by replacing
+
+```
+{% set thinking = false %}
+```
+with
+```
+{% set thinking = true %}
+```
+and save it to `chat_template_thinking.jinja`.
+
+Launch the SGLang server with the modified chat-template file:
+```
+python -m sglang.launch_server --model deepseek-ai/DeepSeek-V3.2-Exp --tp 8 --dp 8 --enable-dp-attention --chat-template chat_template_thinking.jinja
+```
+
+Run the following script to evaluate AIME 2025:
+```
+#! /bin/bash
+export NEMO_SKILLS_DISABLE_UNCOMMITTED_CHANGES_CHECK=1
+
+ns prepare_data aime25
+
+PORT=30000
+BACKEND=sglang
+MODEL="deepseek-ai/DeepSeek-V3.2-Exp"
+MODEL_NAME="dsv32-fp8"
+
+echo "Starting AIME25 evaluation with model $MODEL on port $PORT using backend $BACKEND..."
+ns eval \
+    --benchmarks=aime25:4 \
+    --server_type=$BACKEND \
+    --model=$MODEL \
+    --server_address=http://localhost:${PORT}/v1 \
+    --output_dir=nemo_skills_aime25_${MODEL_NAME}_output_${BACKEND}_$(date +%Y%m%d_%H%M%S) \
+    ++max_concurrent_requests=512 \
+    ++server.api_key=dummy \
+    ++inference.tokens_to_generate=64000
+```
+
+Test results:
+
+| evaluation_mode  | num_entries | avg_tokens | gen_seconds | symbolic_correct | no_answer |
+|------------------|-------------|------------|-------------|------------------|-----------|
+| pass@1[avg-of-4] | 30          | 14410      | 1758        | 85.83% ± 4.19%   | 0.00%     |
+| majority@4       | 30          | 14410      | 1758        | 90.00%           | 0.00%     |
+| pass@4           | 30          | 14410      | 1758        | 93.33%           | 0.00%     |
+
+Note that problem #3 (id `aime25-2`) is marked as incorrect by nemo-skills even though it is actually correct, because nemo-skills fails to match the predicted_answer `016` with the expected_answer `16`. Adding 1/30 = 3.33% back to the score brings the pass@1[avg-of-4] result in line with the reference value of 89.3.
+
 
 ## DSA long sequence context parallel optimization(experimental)
 
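The `016`/`16` mismatch mentioned in the note above is a string-matching artifact in the grader rather than a model error. Below is a minimal sketch of the kind of integer-normalizing comparison that would treat the two answers as equal; it is a hypothetical helper for illustration only, not NeMo-Skills code.

```python
def answers_match(predicted: str, expected: str) -> bool:
    """Compare AIME-style answers numerically so that "016" matches "16"."""
    try:
        # AIME answers are integers in [0, 999]; compare them as numbers.
        return int(predicted.strip()) == int(expected.strip())
    except ValueError:
        # Fall back to exact string comparison for non-integer answers.
        return predicted.strip() == expected.strip()


assert answers_match("016", "16")       # leading zero no longer causes a miss
assert not answers_match("015", "16")   # genuinely wrong answers still fail
```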

python/sglang/srt/layers/attention/nsa/nsa_indexer.py

Lines changed: 24 additions & 51 deletions
@@ -109,7 +109,6 @@ def __init__(
         prefix: str = "",
         quant_config: Optional[QuantizationConfig] = None,
         alt_stream: Optional[torch.cuda.Stream] = None,
-        fuse_wk_and_weights_proj: bool = False,
     ):
         super().__init__()
         self.hidden_size = hidden_size
@@ -120,7 +119,6 @@ def __init__(
         self.q_lora_rank = q_lora_rank
         self.layer_id = layer_id
         self.alt_stream = alt_stream
-        self.fuse_wk_and_weights_proj = fuse_wk_and_weights_proj
         self.nsa_enable_prefill_cp = is_nsa_enable_prefill_cp()
         if self.nsa_enable_prefill_cp:
             self.cp_size = get_attention_tp_size()
@@ -139,28 +137,22 @@ def __init__(
             quant_config=quant_config,
             prefix=add_prefix("wq_b", prefix),
         )
-        if self.fuse_wk_and_weights_proj:
-            self.fused_wk_and_weights_proj = ReplicatedLinear(
-                self.hidden_size,
-                self.head_dim + self.n_heads,
-                bias=False,
-                prefix=add_prefix("fused_wk_and_weights_proj", prefix),
-            )
-        else:
-            self.wk = ReplicatedLinear(
-                self.hidden_size,
-                self.head_dim,
-                bias=False,
-                quant_config=quant_config,
-                prefix=add_prefix("wk", prefix),
-            )
-        # NOTE: weight_proj is not quantized
-        self.weights_proj = ReplicatedLinear(
-            self.hidden_size,
-            self.n_heads,
-            bias=False,
-            prefix=add_prefix("weights_proj", prefix),
-        )
+
+        self.wk = ReplicatedLinear(
+            self.hidden_size,
+            self.head_dim,
+            bias=False,
+            quant_config=quant_config,
+            prefix=add_prefix("wk", prefix),
+        )
+        # NOTE: weights_proj in the checkpoint is stored in bf16, while the parameters here are stored in fp32 for convenience
+        self.weights_proj = ReplicatedLinear(
+            self.hidden_size,
+            self.n_heads,
+            bias=False,
+            params_dtype=torch.float32,
+            prefix=add_prefix("weights_proj", prefix),
+        )
         self.k_norm = LayerNorm(self.head_dim, dtype=torch.float32)
         self.rotary_emb = get_rope_wrapper(
             rope_head_dim,
@@ -189,7 +181,6 @@ def _get_q_k_bf16(
         enable_dual_stream: bool,
         forward_batch: ForwardBatch,
     ):
-        weights = None
         if enable_dual_stream:
             current_stream = torch.cuda.current_stream()
             self.alt_stream.wait_stream(current_stream)
@@ -206,12 +197,7 @@ def _get_q_k_bf16(
             )
             with torch.cuda.stream(self.alt_stream):
                 # TODO we should also put DeepGEMM half SM here?
-                if self.fuse_wk_and_weights_proj:
-                    key, weights = self.fused_wk_and_weights_proj(x)[0].split(
-                        [self.head_dim, self.n_heads], dim=-1
-                    )
-                else:
-                    key, _ = self.wk(x)
+                key, _ = self.wk(x)
                 key = self.k_norm(key)
 
                 k_rope, _ = torch.split(
@@ -224,17 +210,10 @@ def _get_q_k_bf16(
         else:
             query, _ = self.wq_b(q_lora)
             query = rearrange(query, "l (h d) -> l h d", d=self.head_dim)
-
             q_rope, _ = torch.split(
                 query, [self.rope_head_dim, self.head_dim - self.rope_head_dim], dim=-1
             )
-
-            if self.fuse_wk_and_weights_proj:
-                key, weights = self.fused_wk_and_weights_proj(x)[0].split(
-                    [self.head_dim, self.n_heads], dim=-1
-                )
-            else:
-                key, _ = self.wk(x)
+            key, _ = self.wk(x)
             key = self.k_norm(key)
             k_rope, _ = torch.split(
                 key, [self.rope_head_dim, self.head_dim - self.rope_head_dim], dim=-1
@@ -266,21 +245,16 @@ def _get_q_k_bf16(
         query = rotate_activation(query)
         key = rotate_activation(key)
 
-        return query, key, weights
+        return query, key
 
     def _get_k_bf16(
         self,
        x: torch.Tensor,
        positions: torch.Tensor,
        enable_dual_stream: bool,
    ):
-        # Compute only key, skip query and weights (weights is discarded if fused)
-        if self.fuse_wk_and_weights_proj:
-            key, _ = self.fused_wk_and_weights_proj(x)[0].split(
-                [self.head_dim, self.n_heads], dim=-1
-            )
-        else:
-            key, _ = self.wk(x)
+        # Compute only key, skip query
+        key, _ = self.wk(x)
         key = self.k_norm(key)
         k_rope, _ = torch.split(
             key, [self.rope_head_dim, self.head_dim - self.rope_head_dim], dim=-1
@@ -779,7 +753,7 @@ def forward_cuda(
             return_indices,
         )
 
-        query, key, weights = self._get_q_k_bf16(
+        query, key = self._get_q_k_bf16(
             q_lora, x, positions, enable_dual_stream, forward_batch=forward_batch
         )
 
@@ -808,8 +782,7 @@ def forward_cuda(
             index_k_scale=k_scale,
         )
 
-        if not self.fuse_wk_and_weights_proj:
-            weights, _ = self.weights_proj(x)
+        weights, _ = self.weights_proj(x.float())
         weights = self._get_logits_head_gate(weights, q_scale)
 
         if is_cuda():
@@ -1037,7 +1010,7 @@ def forward_npu(
         past_key_states = forward_batch.token_to_kv_pool.get_index_k_buffer(layer_id)
 
         x = x.view(-1, self.hidden_size)
-        weights = self.weights_proj(x)[0]
+        weights = self.weights_proj(x.float())[0]
         block_table = (
             block_table[: actual_seq_lengths_q.size()[0]] if is_prefill else block_table
        )
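The core of this commit is the dtype split visible above: `weights_proj` now holds float32 parameters (`params_dtype=torch.float32`) and receives an input upcast via `x.float()`, while the other indexer projections keep their original (possibly quantized) dtypes. Below is a minimal, self-contained sketch of that pattern, using a plain `torch.nn.Linear` as a stand-in for SGLang's `ReplicatedLinear` and made-up sizes; it is illustrative only, not the actual indexer.

```python
import torch
import torch.nn as nn


class IndexerGateSketch(nn.Module):
    """Illustration only: an fp32 gating projection inside a bf16 module."""

    def __init__(self, hidden_size: int = 1024, head_dim: int = 128, n_heads: int = 16):
        super().__init__()
        # Main path stays in bf16 (stand-in for the real wk projection).
        self.wk = nn.Linear(hidden_size, head_dim, bias=False, dtype=torch.bfloat16)
        # Gating weights kept in fp32, mirroring params_dtype=torch.float32.
        self.weights_proj = nn.Linear(hidden_size, n_heads, bias=False, dtype=torch.float32)

    def forward(self, x: torch.Tensor):
        key = self.wk(x)                        # bf16 in, bf16 out
        weights = self.weights_proj(x.float())  # upcast input, fp32 matmul
        return key, weights


m = IndexerGateSketch()
x = torch.randn(4, 1024, dtype=torch.bfloat16)
key, weights = m(x)
print(key.dtype, weights.dtype)  # torch.bfloat16 torch.float32
```

Per the NOTE in the diff, the checkpoint still stores `weights_proj` in bf16; the parameters are simply held in fp32 for convenience, so the gating logits are computed at higher precision.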

python/sglang/srt/models/deepseek_v2.py

Lines changed: 0 additions & 71 deletions
@@ -238,17 +238,6 @@ def add_forward_absorb_core_attention_backend(backend_name):
     logger.info(f"Added {backend_name} to FORWARD_ABSORB_CORE_ATTENTION_BACKENDS.")
 
 
-def is_nsa_indexer_wk_and_weights_proj_fused(config, quant_config):
-    """
-    NSA Indexer wk and weights_proj can be fused in FP4 model because they are both in BF16
-    """
-    return (
-        is_deepseek_nsa(config)
-        and quant_config is not None
-        and quant_config.get_name() == "modelopt_fp4"
-    )
-
-
 class AttnForwardMethod(IntEnum):
     # Use multi-head attention
     MHA = auto()
@@ -1224,9 +1213,6 @@ def __init__(
             quant_config=quant_config,
             layer_id=layer_id,
             alt_stream=alt_stream,
-            fuse_wk_and_weights_proj=is_nsa_indexer_wk_and_weights_proj_fused(
-                config, quant_config
-            ),
         )
 
         self.kv_b_proj = ColumnParallelLinear(
@@ -3766,12 +3752,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]], is_nextn=Fal
             self.config.q_lora_rank is not None
         )
         cached_a_proj = {} if fuse_qkv_a_proj else None
-        # Fuse wk and weights_proj when NSA Indexer is enabled and quant_config is FP4. For nextn, fp4 is disabled so we cannot fuse.
-        fuse_wk_and_weights_proj = (
-            is_nsa_indexer_wk_and_weights_proj_fused(self.config, self.quant_config)
-            and not is_nextn
-        )
-        cached_wk_and_weights_proj = {} if fuse_wk_and_weights_proj else None
 
         if is_nextn:
             nextn_layer_prefix = f"model.layers.{nextn_layer_id}"
@@ -3957,57 +3937,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]], is_nextn=Fal
                         )
                         cached_a_proj.pop(q_a_proj_name)
                         cached_a_proj.pop(kv_a_proj_name)
-                elif fuse_wk_and_weights_proj and (
-                    "wk" in name or "weights_proj" in name
-                ):
-                    cached_wk_and_weights_proj[name] = loaded_weight
-                    wk_name = (
-                        name
-                        if "wk" in name
-                        else name.replace("weights_proj", "wk")
-                    )
-                    weights_proj_name = (
-                        name
-                        if "weights_proj" in name
-                        else name.replace("wk", "weights_proj")
-                    )
-
-                    # When both wk and weights_proj has been cached, load the fused weight to parameter
-                    if (
-                        wk_name in cached_wk_and_weights_proj
-                        and weights_proj_name in cached_wk_and_weights_proj
-                    ):
-                        wk_weight = cached_wk_and_weights_proj[wk_name]
-                        weights_proj_weight = cached_wk_and_weights_proj[
-                            weights_proj_name
-                        ]
-                        # todo dequantize wk for fp8
-                        assert wk_weight.dtype == weights_proj_weight.dtype
-                        fused_weight = torch.cat(
-                            [wk_weight, weights_proj_weight], dim=0
-                        )
-                        param_name = (
-                            name.replace("wk", "fused_wk_and_weights_proj")
-                            if "wk" in name
-                            else name.replace(
-                                "weights_proj",
-                                "fused_wk_and_weights_proj",
-                            )
-                        )
-                        param = params_dict[param_name]
-
-                        weight_loader = getattr(
-                            param, "weight_loader", default_weight_loader
-                        )
-                        maybe_executor_submit(
-                            executor=executor,
-                            futures=futures,
-                            use_async=use_async_loading,
-                            func=weight_loader,
-                            func_args=(param, fused_weight),
-                        )
-                        cached_wk_and_weights_proj.pop(wk_name)
-                        cached_wk_and_weights_proj.pop(weights_proj_name)
                 else:
                     if (
                         "k_scale" in name or "v_scale" in name

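With the fusion path removed, `wk` and `weights_proj` simply go through the regular per-parameter loading. The bf16 checkpoint tensor can land in the new fp32 parameter because `Tensor.copy_` casts dtypes on the fly, which the per-parameter loading is assumed here to rely on. A minimal sketch of that assumption follows; it is not the actual SGLang weight loader.

```python
import torch

# fp32 parameter, as created with params_dtype=torch.float32 in the indexer.
param = torch.nn.Parameter(torch.empty(16, 1024, dtype=torch.float32))

# weights_proj as stored in the checkpoint: bf16.
loaded_weight = torch.randn(16, 1024, dtype=torch.bfloat16)

# copy_ casts bf16 -> fp32, so no fusion- or dtype-specific handling is needed.
with torch.no_grad():
    param.copy_(loaded_weight)

print(param.dtype)   # torch.float32
print(param.shape)   # torch.Size([16, 1024])
```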