Skip to content

Commit d8afd20

Browse files
rogeryounghxuebi
authored and committed
Support moe topk sigmoid kernel (#13049)
Co-authored-by: xuebi <[email protected]>
1 parent a49553b commit d8afd20

File tree

10 files changed

+992
-0
lines changed

10 files changed

+992
-0
lines changed

sgl-kernel/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -323,6 +323,7 @@ set(SOURCES
323323
"csrc/moe/moe_sum.cu"
324324
"csrc/moe/moe_sum_reduce.cu"
325325
"csrc/moe/moe_topk_softmax_kernels.cu"
326+
"csrc/moe/moe_topk_sigmoid_kernels.cu"
326327
"csrc/moe/nvfp4_blockwise_moe.cu"
327328
"csrc/moe/fp8_blockwise_moe_kernel.cu"
328329
"csrc/moe/prepare_moe_input.cu"
Lines changed: 171 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,171 @@
1+
import itertools
2+
import os
3+
4+
import pytest
5+
import torch
6+
import triton
7+
from sgl_kernel import topk_sigmoid
8+
9+
# Running under CI when either the generic CI flag or the GitHub Actions flag
# is set to "true" (case-insensitive); an unset variable counts as "false".
IS_CI = any(
    os.getenv(flag, "false").lower() == "true"
    for flag in ("CI", "GITHUB_ACTIONS")
)
14+
15+
16+
def torch_topk_sigmoid_native(
    gating_output: torch.Tensor,
    topk: int,
    renormalize: bool,
    correction_bias: torch.Tensor = None,
):
    """Reference top-k sigmoid routing built from plain PyTorch ops.

    Args:
        gating_output: Router logits; experts are laid out along the last
            dimension. Any number of leading dimensions is accepted.
        topk: Number of experts to select per row.
        renormalize: If True, divide the selected weights by their sum along
            the last dimension so each row sums to 1.
        correction_bias: Optional 1-D per-expert bias. It is added to the
            sigmoid scores only to decide which experts are selected; the
            returned weights are the unbiased sigmoid scores.

    Returns:
        ``(topk_weights, topk_indices)``, each shaped like ``gating_output``
        with the last dimension replaced by ``topk``. Indices come straight
        from ``torch.topk``.
    """
    scores = gating_output.sigmoid()

    if correction_bias is None:
        topk_weights, topk_indices = torch.topk(scores, k=topk, dim=-1)
    else:
        # Broadcast the bias over the expert (last) dimension and gather along
        # that same dimension, so selection and gathering always agree on the
        # layout regardless of how many leading dimensions the input has.
        scores_for_choice = scores + correction_bias
        _, topk_indices = torch.topk(scores_for_choice, k=topk, dim=-1)
        topk_weights = scores.gather(-1, topk_indices)

    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)

    return topk_weights, topk_indices
37+
38+
39+
def sglang_topk_sigmoid(
    gating_output: torch.Tensor,
    topk: int,
    renormalize: bool,
    correction_bias: torch.Tensor = None,
):
    """Run the sgl_kernel ``topk_sigmoid`` op and return its outputs.

    Args:
        gating_output: 2-D tensor of router logits, one row per token.
        topk: Number of experts to select per token; sets the width of the
            output buffers.
        renormalize: Forwarded to the kernel unchanged.
        correction_bias: Optional bias tensor forwarded to the kernel.

    Returns:
        ``(topk_weights, topk_indices)``: a float32 and an int32 tensor of
        shape ``(num_tokens, topk)``, allocated here and passed to the kernel
        as its output arguments.
    """
    # Unpacking two values keeps the implicit "input must be 2-D" check.
    num_tokens, _ = gating_output.shape

    # Allocate on the input's device rather than the current default CUDA
    # device, so the buffers and the logits cannot end up on different GPUs.
    device = gating_output.device
    topk_weights = torch.empty((num_tokens, topk), dtype=torch.float32, device=device)
    topk_indices = torch.empty((num_tokens, topk), dtype=torch.int32, device=device)

    topk_sigmoid(
        topk_weights,
        topk_indices,
        gating_output,
        renormalize=renormalize,
        correction_bias=correction_bias,
    )

    return topk_weights, topk_indices
59+
60+
61+
def get_topk_sigmoid_input(num_tokens, num_experts, device="cuda"):
    """Create random float32 inputs for the top-k sigmoid routines.

    Args:
        num_tokens: Number of rows in the generated logits.
        num_experts: Number of columns in the logits and length of the bias.
        device: Device to allocate on. Defaults to ``"cuda"``, matching the
            previous hard-coded behaviour.

    Returns:
        ``(gating_output, correction_bias)`` with shapes
        ``(num_tokens, num_experts)`` and ``(num_experts,)``.
    """
    gating_output = torch.randn(
        (num_tokens, num_experts), dtype=torch.float32, device=device
    )
    # Explicit one-element tuple: "(num_experts)" is just a parenthesised int.
    correction_bias = torch.randn((num_experts,), dtype=torch.float32, device=device)
    return gating_output, correction_bias
67+
68+
69+
def calculate_diff(num_tokens, num_experts, topk):
    """Compare the PyTorch reference with the SGLang kernel and print a verdict.

    Both implementations run with renormalization enabled and a correction
    bias, on random inputs of the requested size. A match requires the
    weights to agree within ``atol=1e-3`` / ``rtol=1e-3`` and the indices to
    be equal according to ``torch.equal``.
    """
    gating_output, correction_bias = get_topk_sigmoid_input(num_tokens, num_experts)

    # Each implementation receives its own copies of the inputs.
    ref_weights, ref_indices = torch_topk_sigmoid_native(
        gating_output.clone(), topk, True, correction_bias.clone()
    )
    out_weights, out_indices = sglang_topk_sigmoid(
        gating_output.clone(), topk, True, correction_bias.clone()
    )

    weights_diff = torch.abs(ref_weights - out_weights).mean().item()
    indices_match = torch.equal(ref_indices, out_indices)
    weights_close = torch.allclose(ref_weights, out_weights, atol=1e-3, rtol=1e-3)

    if not (weights_close and indices_match):
        print(
            f"❌ Implementations differ: Weights diff={weights_diff}, Indices match={indices_match}"
        )
        return

    print("✅ Torch and SGLang topk_sigmoid implementations match")
97+
98+
99+
# Benchmark sweep: a single small point under CI, the full grid otherwise.
if IS_CI:
    num_tokens_range, num_experts_range, topk_range = [128], [32], [2]
else:
    num_tokens_range = [128, 512, 1024, 2048, 4096, 8192, 16384, 32768]
    num_experts_range = [32, 64, 128, 256, 12, 512]
    topk_range = [1, 2, 4, 8]

# Every (num_tokens, num_experts, topk) combination becomes one benchmark row.
configs = list(itertools.product(num_tokens_range, num_experts_range, topk_range))


# Providers compared by the benchmark, with their legend labels and line styles.
line_vals = ["sglang", "torch"]
line_names = ["SGLang", "Torch"]
styles = [("blue", "-"), ("green", "-")]
116+
117+
118+
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["num_tokens", "num_experts", "topk"],
        x_vals=configs,
        line_arg="provider",
        line_vals=line_vals,
        line_names=line_names,
        styles=styles,
        ylabel="Latency (us)",
        plot_name="topk-sigmoid-performance",
        args={},
    )
)
def benchmark(num_tokens, num_experts, topk, provider):
    """Time one provider on one ``(num_tokens, num_experts, topk)`` point.

    Args:
        num_tokens: Number of rows in the generated logits.
        num_experts: Number of experts (columns) in the generated logits.
        topk: Number of experts selected per token.
        provider: ``"torch"`` for the PyTorch reference or ``"sglang"`` for
            the sgl_kernel op (the legacy ``"torch1"`` / ``"sglang1"``
            spellings are still accepted).

    Returns:
        ``(median, low, high)`` latency in microseconds, where low and high
        are the 0.2 and 0.8 quantiles.

    Raises:
        ValueError: If ``provider`` is not one of the known names.
    """
    gating_output, correction_bias = get_topk_sigmoid_input(num_tokens, num_experts)

    if provider in ("torch", "torch1"):

        def fn():
            return torch_topk_sigmoid_native(
                gating_output,
                topk,
                True,
                correction_bias,
            )

    elif provider in ("sglang", "sglang1"):

        def fn():
            return sglang_topk_sigmoid(gating_output, topk, True, correction_bias)

    else:
        # Fail loudly instead of hitting an unbound `fn` further down.
        raise ValueError(f"Unknown provider: {provider!r}")

    quantiles = [0.5, 0.2, 0.8]
    ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(fn, quantiles=quantiles)

    # perf_report unpacks the result as (y, y_min, y_max), so the low quantile
    # must come before the high one. Values are converted from ms to us.
    return 1000 * ms, 1000 * min_ms, 1000 * max_ms
153+
154+
155+
if __name__ == "__main__":
    # Correctness spot-checks run first; CI only exercises one small case.
    test_configs = (
        [(20, 32, 2)]
        if IS_CI
        else [
            (20, 256, 4),
            (20, 256, 8),
            (20, 12, 4),
            (20, 12, 1),
            (20, 512, 4),
            (20, 512, 1),
        ]
    )

    for case in test_configs:
        calculate_diff(*case)

    # Then run the latency sweep and print the resulting table.
    benchmark.run(print_data=True)

sgl-kernel/csrc/common_extension.cc

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -230,6 +230,11 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
230230
"moe_softcapping, Tensor? correction_bias) -> ()");
231231
m.impl("topk_softmax", torch::kCUDA, &topk_softmax);
232232

233+
m.def(
234+
"topk_sigmoid(Tensor! topk_weights, Tensor! topk_indices, Tensor gating_output, bool renormalize, Tensor? "
235+
"correction_bias) -> ()");
236+
m.impl("topk_sigmoid", torch::kCUDA, &topk_sigmoid);
237+
233238
m.def("moe_sum_reduce(Tensor input, Tensor output, float routed_scaling_factor) -> ()");
234239
m.impl("moe_sum_reduce", torch::kCUDA, &moe_sum_reduce);
235240

sgl-kernel/csrc/common_extension_rocm.cc

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,11 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
100100
"moe_softcapping, Tensor? correction_bias) -> ()");
101101
m.impl("topk_softmax", torch::kCUDA, &topk_softmax);
102102

103+
m.def(
104+
"topk_sigmoid(Tensor! topk_weights, Tensor! topk_indices, Tensor gating_output, bool renormalize, Tensor? "
105+
"correction_bias) -> ()");
106+
m.impl("topk_sigmoid", torch::kCUDA, &topk_sigmoid);
107+
103108
/*
104109
* From csrc/speculative
105110
*/

0 commit comments

Comments
 (0)