Skip to content

Commit 768094d

Browse files
author
qiuxuan.lzw
committed
support external custom models
1 parent 1dcde53 commit 768094d

File tree

7 files changed

+71
-1
lines changed

7 files changed

+71
-1
lines changed

python/sglang/srt/configs/model_config.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -957,6 +957,9 @@ def is_generation_model(model_architectures: List[str], is_embedding: bool = Fal
957957
"DeepseekOCRForCausalLM",
958958
]
959959

960+
if envs.SGLANG_EXTERNAL_MM_MODEL_ARCH.value:
961+
multimodal_model_archs.append(envs.SGLANG_EXTERNAL_MM_MODEL_ARCH.value)
962+
960963

961964
def is_multimodal_model(model_architectures: List[str]):
962965
if any(

python/sglang/srt/environ.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -295,6 +295,11 @@ class Envs:
295295
# Health Check
296296
SGLANG_ENABLE_HEALTH_ENDPOINT_GENERATION = EnvBool(True)
297297

298+
# External models
299+
SGLANG_EXTERNAL_MODEL_PACKAGE = EnvStr("")
300+
SGLANG_EXTERNAL_MM_MODEL_ARCH = EnvStr("")
301+
SGLANG_EXTERNAL_MM_PROCESSOR_PACKAGE = EnvStr("")
302+
298303
# fmt: on
299304

300305

python/sglang/srt/managers/multimodal_processor.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
PROCESSOR_MAPPING = {}
1313

1414

15-
def import_processors(package_name: str):
15+
def import_processors(package_name: str, overwrite: bool = False):
1616
package = importlib.import_module(package_name)
1717
for _, name, ispkg in pkgutil.iter_modules(package.__path__, package_name + "."):
1818
if not ispkg:
@@ -32,6 +32,11 @@ def import_processors(package_name: str):
3232
):
3333
assert hasattr(cls, "models")
3434
for arch in getattr(cls, "models"):
35+
if overwrite:
36+
for model_cls, processor_cls in PROCESSOR_MAPPING.items():
37+
if model_cls.__name__ == arch.__name__:
38+
del PROCESSOR_MAPPING[model_cls]
39+
break
3540
PROCESSOR_MAPPING[arch] = cls
3641

3742

python/sglang/srt/managers/tokenizer_manager.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141

4242
from sglang.srt.configs.model_config import ModelConfig
4343
from sglang.srt.disaggregation.utils import DisaggregationMode
44+
from sglang.srt.environ import envs
4445
from sglang.srt.lora.lora_registry import LoRARegistry
4546
from sglang.srt.managers.async_dynamic_batch_tokenizer import AsyncDynamicbatchTokenizer
4647
from sglang.srt.managers.async_mm_data_processor import AsyncMMDataProcessor
@@ -210,6 +211,10 @@ def __init__(
210211
# Initialize tokenizer and processor
211212
if self.model_config.is_multimodal:
212213
import_processors("sglang.srt.multimodal.processors")
214+
if envs.SGLANG_EXTERNAL_MM_PROCESSOR_PACKAGE.value:
215+
import_processors(
216+
envs.SGLANG_EXTERNAL_MM_PROCESSOR_PACKAGE.value, overwrite=True
217+
)
213218
try:
214219
_processor = get_processor(
215220
server_args.tokenizer_path,

python/sglang/srt/models/registry.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,3 +123,6 @@ def import_model_classes(package_name: str):
123123

124124
ModelRegistry = _ModelRegistry()
125125
ModelRegistry.register("sglang.srt.models")
126+
127+
if envs.SGLANG_EXTERNAL_MODEL_PACKAGE.value:
128+
ModelRegistry.register(envs.SGLANG_EXTERNAL_MODEL_PACKAGE.value, overwrite=True)
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
from sglang.srt.models.qwen2_vl import (
2+
Qwen2VLForConditionalGeneration as OriginalQwen2VLForConditionalGeneration,
3+
)
4+
from sglang.srt.multimodal.processors.qwen_vl import QwenVLImageProcessor
5+
6+
7+
class Qwen2VLForConditionalGeneration(OriginalQwen2VLForConditionalGeneration):
    """Qwen2-VL subclass that reports on stdout when it is constructed.

    It shares its class name with the parent it derives from; the only added
    behavior is the print in ``__init__``.
    """

    def __init__(self, config, quant_config, prefix: str = "") -> None:
        """Run the parent constructor, then print this class's name."""
        super().__init__(config, quant_config, prefix)
        class_name = self.__class__.__name__
        print("init custom model:", class_name)
11+
12+
13+
class CustomProcessor(QwenVLImageProcessor):
    """Qwen-VL image processor subclass that reports on stdout when built."""

    # Model classes this processor is associated with.
    models = [Qwen2VLForConditionalGeneration]

    def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
        """Forward every argument to the parent, then print this class's name."""
        super().__init__(hf_config, server_args, _processor, *args, **kwargs)
        class_name = self.__class__.__name__
        print("init custom processor:", class_name)
19+
20+
21+
# Module-level alias for the custom model class defined in this file.
# NOTE(review): presumably this is the name the model registry looks up when
# it scans a package for model classes — confirm against the registry code.
EntryClass = Qwen2VLForConditionalGeneration

test/srt/test_external_models.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
import os
2+
import unittest
3+
4+
import sglang as sgl
5+
from sglang.test.test_utils import CustomTestCase
6+
7+
8+
class TestExternalModels(CustomTestCase):
    """Smoke test for loading model and processor code from an external package."""

    def test_external_model(self):
        """Start an engine with the external-package variables set and generate.

        Both ``SGLANG_EXTERNAL_MODEL_PACKAGE`` and
        ``SGLANG_EXTERNAL_MM_PROCESSOR_PACKAGE`` point at ``external_models``.
        The test passes when generation returns non-empty text.
        """
        # Imported here so the fix does not depend on a new file-level import.
        from unittest import mock

        external_env = {
            "SGLANG_EXTERNAL_MODEL_PACKAGE": "external_models",
            "SGLANG_EXTERNAL_MM_PROCESSOR_PACKAGE": "external_models",
        }
        prompt = "Today is a sunny day and I like"
        model_path = "Qwen/Qwen2-VL-2B-Instruct"

        # patch.dict restores os.environ on exit, so the external-package
        # settings do not leak into other tests run in the same process.
        with mock.patch.dict(os.environ, external_env):
            engine = sgl.Engine(
                model_path=model_path,
                cuda_graph_max_bs=1,
                max_total_tokens=64,
                enable_multimodal=True,
            )
            try:
                out = engine.generate(prompt)["text"]
            finally:
                # Always shut the engine down, even when generation fails.
                engine.shutdown()

        self.assertGreater(len(out), 0)
25+
26+
27+
if __name__ == "__main__":
28+
unittest.main()

0 commit comments

Comments
 (0)