
Commit 5ae0ac4

kaixih and Fridge003 authored
[NVIDIA] Fix use case of SGLANG_ENABLE_FLASHINFER_GEMM (#13274)
Co-authored-by: Baizhou Zhang <[email protected]>
1 parent 22f641a commit 5ae0ac4

File tree

4 files changed, +13 -7 lines changed

docs/references/environment_variables.md

Lines changed: 1 addition & 1 deletion
@@ -68,7 +68,7 @@ SGLang supports various environment variables that can be used to configure its
 | `SGLANG_INT4_WEIGHT` | Enable INT4 weight quantization | `false` |
 | `SGLANG_MOE_PADDING` | Enable MoE padding (sets padding size to 128 if value is `1`, often set to `1` in Docker builds) | `0` |
 | `SGLANG_FORCE_FP8_MARLIN` | Force using FP8 MARLIN kernels even if other FP8 kernels are available | `false` |
-| `SGLANG_ENABLE_FLASHINFER_GEMM` | Use flashinfer kernels when running blockwise fp8 GEMM on Blackwell GPUs | `false` |
+| `SGLANG_ENABLE_FLASHINFER_FP8_GEMM` | Use flashinfer kernels when running blockwise fp8 GEMM on Blackwell GPUs | `false` |
 | `SGLANG_FLASHINFER_FP4_GEMM_BACKEND` | Select backend for `mm_fp4` on Blackwell GPUS | `` |
 | `SGLANG_SUPPORT_CUTLASS_BLOCK_FP8` | Use Cutlass kernels when running blockwise fp8 GEMM on Hopper or Blackwell GPUs | `false` |
 | `SGLANG_CUTLASS_MOE` (deprecated) | Use Cutlass FP8 MoE kernel on Blackwell GPUs (deprecated, use --moe-runner-backend=cutlass) | `false` |
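
The renamed flag is an ordinary boolean environment variable, so opting in only requires setting it before the server process reads its environment table. A minimal, illustrative Python snippet (not part of this commit):

import os

# Illustrative only: opt in to the flashinfer blockwise-FP8 GEMM path
# (Blackwell GPUs) by setting the renamed flag before sglang reads it.
# The old SGLANG_ENABLE_FLASHINFER_GEMM name is now routed through the
# deprecation helper shown in environ.py below.
os.environ.setdefault("SGLANG_ENABLE_FLASHINFER_FP8_GEMM", "1")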

python/sglang/srt/environ.py

Lines changed: 4 additions & 1 deletion
@@ -206,7 +206,7 @@ class Envs:
 
     # Flashinfer
     SGLANG_IS_FLASHINFER_AVAILABLE = EnvBool(True)
-    SGLANG_ENABLE_FLASHINFER_GEMM = EnvBool(False)
+    SGLANG_ENABLE_FLASHINFER_FP8_GEMM = EnvBool(False)
     # Default to the pick from flashinfer
     SGLANG_FLASHINFER_FP4_GEMM_BACKEND = EnvStr("")
     SGLANG_FLASHINFER_WORKSPACE_SIZE = EnvInt(384 * 1024 * 1024)
@@ -307,6 +307,9 @@ def _print_deprecated_env(new_name: str, old_name: str):
 
 
 def _convert_SGL_to_SGLANG():
     _print_deprecated_env("SGLANG_LOG_GC", "SGLANG_GC_LOG")
+    _print_deprecated_env(
+        "SGLANG_ENABLE_FLASHINFER_FP8_GEMM", "SGLANG_ENABLE_FLASHINFER_GEMM"
+    )
 
     for key, value in os.environ.items():
         if key.startswith("SGL_"):
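
The body of _print_deprecated_env is not shown in this diff. A plausible sketch of what such a shim does, assuming it warns and mirrors the deprecated value onto the new name (sglang's actual implementation may differ):

import os
import warnings

def _print_deprecated_env(new_name: str, old_name: str) -> None:
    # Assumed behavior: if only the deprecated variable is set, emit a
    # deprecation notice and copy its value to the new name so later
    # EnvBool/EnvStr lookups pick it up.
    if old_name in os.environ and new_name not in os.environ:
        warnings.warn(
            f"{old_name} is deprecated, please use {new_name} instead.",
            DeprecationWarning,
        )
        os.environ[new_name] = os.environ[old_name]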

python/sglang/srt/layers/quantization/fp8_utils.py

Lines changed: 5 additions & 4 deletions
@@ -2,6 +2,7 @@
 
 import torch
 
+from sglang.srt.environ import envs
 from sglang.srt.layers import deep_gemm_wrapper
 from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_group_quant_fp8
 from sglang.srt.layers.quantization.mxfp4_tensor import MXFP4QuantizeUtil
@@ -127,17 +128,17 @@ def cutlass_block_fp8_supported() -> bool:
 
 
 CUTLASS_BLOCK_FP8_SUPPORTED = cutlass_block_fp8_supported()
-ENABLE_FLASHINFER_GEMM = (
-    get_bool_env_var("SGLANG_ENABLE_FLASHINFER_GEMM")
+ENABLE_FLASHINFER_FP8_GEMM = (
+    envs.SGLANG_ENABLE_FLASHINFER_FP8_GEMM.get()
     and is_blackwell_supported()
     and is_flashinfer_available()
 )
-if ENABLE_FLASHINFER_GEMM:
+if ENABLE_FLASHINFER_FP8_GEMM:
     from flashinfer.gemm import gemm_fp8_nt_groupwise
 
 
 def dispatch_w8a8_block_fp8_linear() -> Callable:
-    if ENABLE_FLASHINFER_GEMM:
+    if ENABLE_FLASHINFER_FP8_GEMM:
         return flashinfer_gemm_w8a8_block_fp8_linear
     elif CUTLASS_BLOCK_FP8_SUPPORTED:
         return cutlass_w8a8_block_fp8_linear_with_fallback
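
The module-level gate above follows a common pattern: an environment flag combined with hardware and library checks selects a GEMM implementation once at import time, and the dispatcher returns the matching callable. A generic sketch of that dispatch logic (function name, parameter names, and the "default" fallback key are placeholders, not sglang APIs):

from typing import Callable, Dict

def pick_w8a8_block_fp8_backend(
    flashinfer_fp8_enabled: bool,
    cutlass_block_fp8_supported: bool,
    impls: Dict[str, Callable],
) -> Callable:
    # Highest priority: flashinfer groupwise FP8 GEMM (opt-in via the env
    # flag, Blackwell only). Next: Cutlass blockwise FP8. Otherwise fall
    # back to a default implementation.
    if flashinfer_fp8_enabled:
        return impls["flashinfer"]
    if cutlass_block_fp8_supported:
        return impls["cutlass"]
    return impls["default"]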

python/sglang/srt/models/deepseek_v2.py

Lines changed: 3 additions & 1 deletion
@@ -89,6 +89,7 @@
     per_token_group_quant_mla_deep_gemm_masked_fp8,
 )
 from sglang.srt.layers.quantization.fp8_utils import (
+    ENABLE_FLASHINFER_FP8_GEMM,
     block_quant_dequant,
     block_quant_to_tensor_quant,
     channel_quant_to_tensor_quant,
@@ -3420,7 +3421,8 @@ def post_load_weights(self, is_nextn=False, weight_names=None):
                 self_attn.use_deep_gemm_bmm = True
 
             if (
-                deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
+                not ENABLE_FLASHINFER_FP8_GEMM
+                and deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
                 and deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
                 and hasattr(self.quant_config, "weight_block_size")
                 and self.quant_config.weight_block_size is not None
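
The added not ENABLE_FLASHINFER_FP8_GEMM guard skips the DeepGEMM UE8M0 branch in post_load_weights when the flashinfer FP8 GEMM path is selected, presumably so the block-quantized weights are not rewritten into DeepGEMM's scale format when flashinfer kernels will consume them instead. A condensed, standalone restatement of the predicate (the surrounding method context is omitted):

def should_requant_for_deep_gemm(
    enable_flashinfer_fp8_gemm: bool,
    enable_jit_deepgemm: bool,
    deepgemm_scale_ue8m0: bool,
    weight_block_size,
) -> bool:
    # Mirrors the condition in the diff: only take the DeepGEMM UE8M0 path
    # when the flashinfer FP8 GEMM path is not in use.
    return (
        not enable_flashinfer_fp8_gemm
        and enable_jit_deepgemm
        and deepgemm_scale_ue8m0
        and weight_block_size is not None
    )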
