Skip to content

Commit c8c3632

Browse files
committed
[Deepseek V3.2] Change indexer weights_proj to fp32
1 parent 6448b4c commit c8c3632

File tree

3 files changed

+74
-124
lines changed

3 files changed

+74
-124
lines changed

docs/basic_usage/deepseek_v32.md

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,13 @@ Latency: 25.109 s
129129
Output throughput: 5226.235 token/s
130130
```
131131

132+
To test long-context accuracy, run gsm8k with `--num-shots 20`. The results are very close to the 8-shot results:
133+
```
134+
Accuracy: 0.956
135+
Invalid: 0.000
136+
Latency: 29.545 s
137+
Output throughput: 4418.617 token/s
138+
```
132139

133140
### Accuracy Test with `gpqa-diamond`
134141

@@ -142,3 +149,46 @@ The mean accuracy over 8 runs shows 0.797, which matches the number 79.9 in offi
142149
Repeat: 8, mean: 0.797
143150
Scores: ['0.808', '0.798', '0.808', '0.798', '0.783', '0.788', '0.803', '0.793']
144151
```
152+
153+
### Accuracy Test with `aime 2025`
154+
155+
Prepare the environment by installing NeMo-Skills in the Docker container or in your own virtual environment:
156+
157+
```
158+
pip install git+https://github.com/NVIDIA/NeMo-Skills.git --ignore-installed blinker
159+
```
160+
161+
Run the following script:
162+
```
163+
#! /bin/bash
164+
export NEMO_SKILLS_DISABLE_UNCOMMITTED_CHANGES_CHECK=1
165+
166+
ns prepare_data aime25
167+
168+
PORT=30000
169+
BACKEND=sglang
170+
MODEL="deepseek-ai/DeepSeek-V3.2-Exp"
171+
MODEL_NAME="dsv32-fp8"
172+
173+
echo "Starting AIME25 evaluation with model $MODEL on port $PORT using backend $BACKEND..."
174+
ns eval \
175+
--benchmarks=aime25:4 \
176+
--server_type=$BACKEND \
177+
--model=$MODEL \
178+
--server_address=http://localhost:${PORT}/v1 \
179+
--output_dir=nemo_skills_aime25_${MODEL_NAME}_output_${BACKEND}_$(date +%Y%m%d_%H%M%S) \
180+
++max_concurrent_requests=512 \
181+
++server.api_key=dummy \
182+
++inference.tokens_to_generate=64000
183+
```
184+
185+
Test results:
186+
187+
188+
| evaluation_mode | num_entries | avg_tokens | gen_seconds | symbolic_correct | no_answer |
189+
|--------------------|-------------|------------|-------------|-----------------------|-----------|
190+
| pass@1[avg-of-4] | 30 | 14410 | 1758 | 85.83% ± 4.19% | 0.00% |
191+
| majority@4 | 30 | 14410 | 1758 | 90.00% | 0.00% |
192+
| pass@4 | 30 | 14410 | 1758 | 93.33% | 0.00% |
193+
194+
Note that the result of problem #3 (id `aime25-2`) is marked as incorrect by NeMo-Skills but is actually correct: NeMo-Skills fails to match the predicted_answer `016` with the expected_answer `16`. If we add 1/30 = 3.33% to the results, the pass@1[avg-of-4] result matches the reference value of 89.3.

python/sglang/srt/layers/attention/nsa/nsa_indexer.py

Lines changed: 24 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,6 @@ def __init__(
100100
prefix: str = "",
101101
quant_config: Optional[QuantizationConfig] = None,
102102
alt_stream: Optional[torch.cuda.Stream] = None,
103-
fuse_wk_and_weights_proj: bool = False,
104103
):
105104
super().__init__()
106105
self.hidden_size = hidden_size
@@ -111,7 +110,6 @@ def __init__(
111110
self.q_lora_rank = q_lora_rank
112111
self.layer_id = layer_id
113112
self.alt_stream = alt_stream
114-
self.fuse_wk_and_weights_proj = fuse_wk_and_weights_proj
115113
if is_cuda():
116114
self.sm_count = deep_gemm.get_num_sms()
117115
self.half_device_sm_count = ceil_align(self.sm_count // 2, 8)
@@ -123,28 +121,22 @@ def __init__(
123121
quant_config=quant_config,
124122
prefix=add_prefix("wq_b", prefix),
125123
)
126-
if self.fuse_wk_and_weights_proj:
127-
self.fused_wk_and_weights_proj = ReplicatedLinear(
128-
self.hidden_size,
129-
self.head_dim + self.n_heads,
130-
bias=False,
131-
prefix=add_prefix("fused_wk_and_weights_proj", prefix),
132-
)
133-
else:
134-
self.wk = ReplicatedLinear(
135-
self.hidden_size,
136-
self.head_dim,
137-
bias=False,
138-
quant_config=quant_config,
139-
prefix=add_prefix("wk", prefix),
140-
)
141-
# NOTE: weight_proj is not quantized
142-
self.weights_proj = ReplicatedLinear(
143-
self.hidden_size,
144-
self.n_heads,
145-
bias=False,
146-
prefix=add_prefix("weights_proj", prefix),
147-
)
124+
125+
self.wk = ReplicatedLinear(
126+
self.hidden_size,
127+
self.head_dim,
128+
bias=False,
129+
quant_config=quant_config,
130+
prefix=add_prefix("wk", prefix),
131+
)
132+
# NOTE: weights_proj in the checkpoint is stored in bf16, while the parameters here are stored in fp32 for convenience
133+
self.weights_proj = ReplicatedLinear(
134+
self.hidden_size,
135+
self.n_heads,
136+
bias=False,
137+
params_dtype=torch.float32,
138+
prefix=add_prefix("weights_proj", prefix),
139+
)
148140
self.k_norm = LayerNorm(self.head_dim, dtype=torch.float32)
149141
self.rotary_emb = get_rope_wrapper(
150142
rope_head_dim,
@@ -172,7 +164,6 @@ def _get_q_k_bf16(
172164
positions: torch.Tensor,
173165
enable_dual_stream: bool,
174166
):
175-
weights = None
176167
if enable_dual_stream:
177168
current_stream = torch.cuda.current_stream()
178169
self.alt_stream.wait_stream(current_stream)
@@ -189,12 +180,7 @@ def _get_q_k_bf16(
189180
)
190181
with torch.cuda.stream(self.alt_stream):
191182
# TODO we should also put DeepGEMM half SM here?
192-
if self.fuse_wk_and_weights_proj:
193-
key, weights = self.fused_wk_and_weights_proj(x)[0].split(
194-
[self.head_dim, self.n_heads], dim=-1
195-
)
196-
else:
197-
key, _ = self.wk(x)
183+
key, _ = self.wk(x)
198184
key = self.k_norm(key)
199185

200186
k_rope, _ = torch.split(
@@ -207,17 +193,10 @@ def _get_q_k_bf16(
207193
else:
208194
query, _ = self.wq_b(q_lora)
209195
query = rearrange(query, "l (h d) -> l h d", d=self.head_dim)
210-
211196
q_rope, _ = torch.split(
212197
query, [self.rope_head_dim, self.head_dim - self.rope_head_dim], dim=-1
213198
)
214-
215-
if self.fuse_wk_and_weights_proj:
216-
key, weights = self.fused_wk_and_weights_proj(x)[0].split(
217-
[self.head_dim, self.n_heads], dim=-1
218-
)
219-
else:
220-
key, _ = self.wk(x)
199+
key, _ = self.wk(x)
221200
key = self.k_norm(key)
222201
k_rope, _ = torch.split(
223202
key, [self.rope_head_dim, self.head_dim - self.rope_head_dim], dim=-1
@@ -240,21 +219,16 @@ def _get_q_k_bf16(
240219
query = rotate_activation(query)
241220
key = rotate_activation(key)
242221

243-
return query, key, weights
222+
return query, key
244223

245224
def _get_k_bf16(
246225
self,
247226
x: torch.Tensor,
248227
positions: torch.Tensor,
249228
enable_dual_stream: bool,
250229
):
251-
# Compute only key, skip query and weights (weights is discarded if fused)
252-
if self.fuse_wk_and_weights_proj:
253-
key, _ = self.fused_wk_and_weights_proj(x)[0].split(
254-
[self.head_dim, self.n_heads], dim=-1
255-
)
256-
else:
257-
key, _ = self.wk(x)
230+
# Compute only key, skip query
231+
key, _ = self.wk(x)
258232
key = self.k_norm(key)
259233
k_rope, _ = torch.split(
260234
key, [self.rope_head_dim, self.head_dim - self.rope_head_dim], dim=-1
@@ -606,9 +580,7 @@ def forward_cuda(
606580
return_indices,
607581
)
608582

609-
query, key, weights = self._get_q_k_bf16(
610-
q_lora, x, positions, enable_dual_stream
611-
)
583+
query, key = self._get_q_k_bf16(q_lora, x, positions, enable_dual_stream)
612584

613585
if enable_dual_stream:
614586
current_stream = torch.cuda.current_stream()
@@ -635,8 +607,7 @@ def forward_cuda(
635607
index_k_scale=k_scale,
636608
)
637609

638-
if not self.fuse_wk_and_weights_proj:
639-
weights, _ = self.weights_proj(x)
610+
weights, _ = self.weights_proj(x.float())
640611
weights = self._get_logits_head_gate(weights, q_scale)
641612

642613
if is_cuda():
@@ -801,7 +772,7 @@ def forward_npu(
801772
past_key_states = forward_batch.token_to_kv_pool.get_index_k_buffer(layer_id)
802773

803774
x = x.view(-1, self.hidden_size)
804-
weights = self.weights_proj(x)[0]
775+
weights = self.weights_proj(x.float())[0]
805776
block_table = (
806777
block_table[: actual_seq_lengths_q.size()[0]] if is_prefill else block_table
807778
)

python/sglang/srt/models/deepseek_v2.py

Lines changed: 0 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -224,17 +224,6 @@ def add_forward_absorb_core_attention_backend(backend_name):
224224
logger.info(f"Added {backend_name} to FORWARD_ABSORB_CORE_ATTENTION_BACKENDS.")
225225

226226

227-
def is_nsa_indexer_wk_and_weights_proj_fused(config, quant_config):
228-
"""
229-
NSA Indexer wk and weights_proj can be fused in FP4 model because they are both in BF16
230-
"""
231-
return (
232-
is_deepseek_nsa(config)
233-
and quant_config is not None
234-
and quant_config.get_name() == "modelopt_fp4"
235-
)
236-
237-
238227
class AttnForwardMethod(IntEnum):
239228
# Use multi-head attention
240229
MHA = auto()
@@ -1189,9 +1178,6 @@ def __init__(
11891178
quant_config=quant_config,
11901179
layer_id=layer_id,
11911180
alt_stream=alt_stream,
1192-
fuse_wk_and_weights_proj=is_nsa_indexer_wk_and_weights_proj_fused(
1193-
config, quant_config
1194-
),
11951181
)
11961182

11971183
self.kv_b_proj = ColumnParallelLinear(
@@ -3610,12 +3596,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]], is_nextn=Fal
36103596
self.config.q_lora_rank is not None
36113597
)
36123598
cached_a_proj = {} if fuse_qkv_a_proj else None
3613-
# Fuse wk and weights_proj when NSA Indexer is enabled and quant_config is FP4. For nextn, fp4 is disabled so we cannot fuse.
3614-
fuse_wk_and_weights_proj = (
3615-
is_nsa_indexer_wk_and_weights_proj_fused(self.config, self.quant_config)
3616-
and not is_nextn
3617-
)
3618-
cached_wk_and_weights_proj = {} if fuse_wk_and_weights_proj else None
36193599

36203600
if is_nextn:
36213601
nextn_layer_prefix = f"model.layers.{nextn_layer_id}"
@@ -3801,57 +3781,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]], is_nextn=Fal
38013781
)
38023782
cached_a_proj.pop(q_a_proj_name)
38033783
cached_a_proj.pop(kv_a_proj_name)
3804-
elif fuse_wk_and_weights_proj and (
3805-
"wk" in name or "weights_proj" in name
3806-
):
3807-
cached_wk_and_weights_proj[name] = loaded_weight
3808-
wk_name = (
3809-
name
3810-
if "wk" in name
3811-
else name.replace("weights_proj", "wk")
3812-
)
3813-
weights_proj_name = (
3814-
name
3815-
if "weights_proj" in name
3816-
else name.replace("wk", "weights_proj")
3817-
)
3818-
3819-
# When both wk and weights_proj has been cached, load the fused weight to parameter
3820-
if (
3821-
wk_name in cached_wk_and_weights_proj
3822-
and weights_proj_name in cached_wk_and_weights_proj
3823-
):
3824-
wk_weight = cached_wk_and_weights_proj[wk_name]
3825-
weights_proj_weight = cached_wk_and_weights_proj[
3826-
weights_proj_name
3827-
]
3828-
# todo dequantize wk for fp8
3829-
assert wk_weight.dtype == weights_proj_weight.dtype
3830-
fused_weight = torch.cat(
3831-
[wk_weight, weights_proj_weight], dim=0
3832-
)
3833-
param_name = (
3834-
name.replace("wk", "fused_wk_and_weights_proj")
3835-
if "wk" in name
3836-
else name.replace(
3837-
"weights_proj",
3838-
"fused_wk_and_weights_proj",
3839-
)
3840-
)
3841-
param = params_dict[param_name]
3842-
3843-
weight_loader = getattr(
3844-
param, "weight_loader", default_weight_loader
3845-
)
3846-
maybe_executor_submit(
3847-
executor=executor,
3848-
futures=futures,
3849-
use_async=use_async_loading,
3850-
func=weight_loader,
3851-
func_args=(param, fused_weight),
3852-
)
3853-
cached_wk_and_weights_proj.pop(wk_name)
3854-
cached_wk_and_weights_proj.pop(weights_proj_name)
38553784
else:
38563785
if (
38573786
"k_scale" in name or "v_scale" in name

0 commit comments

Comments
 (0)