Skip to content

Commit 416cd01

Browse files
CSWYF3634076 and jinyouzhi
authored and committed
[Model] Add Ernie4.5 VL Model Support (vllm-project#22514)
Signed-off-by: wangyafeng <[email protected]>
1 parent da34215 commit 416cd01

File tree

11 files changed

+3370
-1
lines changed

11 files changed

+3370
-1
lines changed

docs/models/supported_models.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -529,6 +529,7 @@ Specified using `--task generate`.
529529
| `Blip2ForConditionalGeneration` | BLIP-2 | T + I<sup>E</sup> | `Salesforce/blip2-opt-2.7b`, `Salesforce/blip2-opt-6.7b`, etc. | | ✅︎ | ✅︎ |
530530
| `ChameleonForConditionalGeneration` | Chameleon | T + I | `facebook/chameleon-7b` etc. | | ✅︎ | ✅︎ |
531531
| `DeepseekVLV2ForCausalLM`<sup>^</sup> | DeepSeek-VL2 | T + I<sup>+</sup> | `deepseek-ai/deepseek-vl2-tiny`, `deepseek-ai/deepseek-vl2-small`, `deepseek-ai/deepseek-vl2` etc. | | ✅︎ | ✅︎ |
532+
| `Ernie4_5_VLMoeForConditionalGeneration` | Ernie4.5-VL | T + I<sup>+</sup>/ V<sup>+</sup> | `baidu/ERNIE-4.5-VL-28B-A3B-PT`, `baidu/ERNIE-4.5-VL-424B-A47B-PT` | | ✅︎ | ✅︎ |
532533
| `Florence2ForConditionalGeneration` | Florence-2 | T + I | `microsoft/Florence-2-base`, `microsoft/Florence-2-large` etc. | | | |
533534
| `FuyuForCausalLM` | Fuyu | T + I | `adept/fuyu-8b` etc. | | ✅︎ | ✅︎ |
534535
| `Gemma3ForConditionalGeneration` | Gemma 3 | T + I<sup>+</sup> | `google/gemma-3-4b-it`, `google/gemma-3-27b-it`, etc. | ✅︎ | ✅︎ | ⚠️ |

examples/offline_inference/vision_language.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,37 @@ def run_deepseek_vl2(questions: list[str], modality: str) -> ModelRequestData:
150150
)
151151

152152

153+
# Ernie4.5-VL
154+
def run_ernie45_vl(questions: list[str], modality: str) -> ModelRequestData:
    """Build the engine arguments and prompts for the Ernie4.5-VL example.

    Args:
        questions: Questions to ask; one prompt is produced per question.
        modality: ``"image"`` or ``"video"``. Selects the multimodal
            placeholder appended to each question and the key used in
            ``limit_mm_per_prompt``.

    Returns:
        A ``ModelRequestData`` holding the engine arguments and the
        formatted prompts.

    Raises:
        ValueError: If ``modality`` is neither ``"image"`` nor ``"video"``.
    """
    placeholders = {
        "image": "Picture 1:<|IMAGE_START|><|image@placeholder|><|IMAGE_END|>",
        "video": "Video 1:<|VIDEO_START|><|video@placeholder|><|VIDEO_END|>",
    }
    # Reject unsupported modalities before doing any work. Previously the
    # placeholder was simply left unbound, which surfaced later as an
    # UnboundLocalError while the prompts were being formatted.
    if modality not in placeholders:
        raise ValueError(
            f"Unsupported modality for Ernie4.5-VL: {modality!r}; "
            "expected 'image' or 'video'."
        )
    placeholder = placeholders[modality]

    model_name = "baidu/ERNIE-4.5-VL-28B-A3B-PT"

    engine_args = EngineArgs(
        model=model_name,
        max_model_len=4096,
        max_num_seqs=5,
        limit_mm_per_prompt={modality: 1},
        trust_remote_code=True,
    )

    prompts = [
        (
            f"<|begin_of_sentence|>User: {question}{placeholder}\n"
            "Assistant: <think></think>"
        )
        for question in questions
    ]

    return ModelRequestData(
        engine_args=engine_args,
        prompts=prompts,
    )
182+
183+
153184
# Florence2
154185
def run_florence2(questions: list[str], modality: str) -> ModelRequestData:
155186
assert modality == "image"
@@ -1115,6 +1146,7 @@ def run_skyworkr1v(questions: list[str], modality: str) -> ModelRequestData:
11151146
"blip-2": run_blip2,
11161147
"chameleon": run_chameleon,
11171148
"deepseek_vl_v2": run_deepseek_vl2,
1149+
"ernie45_vl": run_ernie45_vl,
11181150
"florence2": run_florence2,
11191151
"fuyu": run_fuyu,
11201152
"gemma3": run_gemma3,

requirements/test.in

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,4 +51,6 @@ numpy
5151
runai-model-streamer==0.11.0
5252
runai-model-streamer-s3==0.11.0
5353
fastsafetensors>=0.1.10
54-
pydantic>=2.10 # 2.9 leads to error on python 3.10
54+
pydantic>=2.10 # 2.9 leads to error on python 3.10
55+
terratorch==1.1rc2 # required for PrithviMAE test
56+
decord==0.6.0

requirements/test.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,8 @@ datasets==3.0.2
114114
# mteb
115115
decorator==5.1.1
116116
# via librosa
117+
decord==0.6.0
118+
# via -r requirements/test.in
117119
dill==0.3.8
118120
# via
119121
# datasets
@@ -356,6 +358,7 @@ numpy==1.26.4
356358
# contourpy
357359
# cupy-cuda12x
358360
# datasets
361+
# decord
359362
# einx
360363
# encodec
361364
# evaluate

tests/models/multimodal/processing/test_common.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -264,6 +264,7 @@ def _test_processing_correctness_one(
264264
"Salesforce/blip2-opt-2.7b",
265265
"facebook/chameleon-7b",
266266
"deepseek-ai/deepseek-vl2-tiny",
267+
"baidu/ERNIE-4.5-VL-28B-A3B-PT",
267268
"microsoft/Florence-2-base",
268269
"adept/fuyu-8b",
269270
"google/gemma-3-4b-it",

tests/models/registry.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -329,6 +329,8 @@ def check_available_online(
329329
max_transformers_version="4.48", # noqa: E501
330330
transformers_version_reason="HF model is not compatible.", # noqa: E501
331331
hf_overrides={"architectures": ["DeepseekVLV2ForCausalLM"]}), # noqa: E501
332+
"Ernie4_5_VLMoeForConditionalGeneration": _HfExamplesInfo("baidu/ERNIE-4.5-VL-28B-A3B-PT", # noqa: E501
333+
trust_remote_code=True),
332334
"FuyuForCausalLM": _HfExamplesInfo("adept/fuyu-8b"),
333335
"Gemma3ForConditionalGeneration": _HfExamplesInfo("google/gemma-3-4b-it"),
334336
"GraniteSpeechForConditionalGeneration": _HfExamplesInfo("ibm-granite/granite-speech-3.3-2b"), # noqa: E501
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3+
4+
from typing import Optional
5+
6+
import torch
7+
8+
from .common import apply_rotary_emb_dispatch
9+
from .mrope import MRotaryEmbedding
10+
11+
12+
class Ernie4_5_VLRotaryEmbedding(MRotaryEmbedding):
    """3D rotary positional embedding. 3D is t:time h:height w:width"""

    @staticmethod
    def _merge_thw(table: torch.Tensor, span_hw: int,
                   span_t: int) -> torch.Tensor:
        """Collapse a per-axis cos (or sin) table into a single table.

        ``table`` is indexed on its leading axis by position row: row 0
        supplies the trailing ``span_t`` channels, while rows 1 and 2 supply
        the even and odd channels of the leading ``span_hw`` channels.
        The last axis is laid out as ``[h w h w ... t t t ...]``.
        """
        part_t = table[..., -span_t:][0]
        part_h = table[..., :span_hw:2][1]
        part_w = table[..., 1:span_hw:2][2]
        # Re-interleave h and w channel by channel: (..., n, 2) -> (..., 2n).
        part_hw = torch.stack([part_h, part_w], dim=-1).flatten(-2)
        return torch.cat([part_hw, part_t], dim=-1)

    def _rotate(self, states: torch.Tensor, num_tokens: int,
                cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
        """Rotate the first ``rotary_dim`` channels of every head.

        The remaining channels pass through untouched and the result is
        returned in the original shape of ``states``.
        """
        original_shape = states.shape
        per_head = states.view(num_tokens, -1, self.head_size)
        rotated = apply_rotary_emb_dispatch(per_head[..., :self.rotary_dim],
                                            cos, sin, self.is_neox_style)
        untouched = per_head[..., self.rotary_dim:]
        return torch.cat((rotated, untouched),
                         dim=-1).reshape(original_shape)

    def forward(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Apply the rotary embedding to ``query`` and ``key``.

        ``positions`` must be 1-D or 2-D. In the 2-D case its leading axis
        is indexed with 0, 1, 2 for the t, h and w sections respectively.
        ``key`` is required despite its ``Optional`` annotation.
        """
        assert positions.ndim in (1, 2)
        assert key is not None

        num_tokens = positions.shape[-1]
        cos, sin = self.cos_sin_cache[positions].chunk(2, dim=-1)

        if positions.ndim == 2:
            assert self.mrope_section

            # Section sizes, e.g. 22 / 22 / 20.
            span_h, span_w, span_t = (self.mrope_section[0],
                                      self.mrope_section[1],
                                      self.mrope_section[2])
            assert span_h == span_w

            # Split according to [h w h w h w h w... t t t...]
            cos = self._merge_thw(cos, span_h + span_w, span_t)
            sin = self._merge_thw(sin, span_h + span_w, span_t)

        query = self._rotate(query, num_tokens, cos, sin)
        key = self._rotate(key, num_tokens, cos, sin)
        return query, key

0 commit comments

Comments
 (0)