Skip to content

Commit 2471602

Browse files
committed
Added support for TwelveLabs Pegasus
1 parent 7227747 commit 2471602

File tree

5 files changed

+224
-0
lines changed

5 files changed

+224
-0
lines changed

litellm/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1222,6 +1222,9 @@ def add_known_models():
12221222
from .llms.bedrock.chat.invoke_transformations.amazon_titan_transformation import (
12231223
AmazonTitanConfig,
12241224
)
1225+
from .llms.bedrock.chat.invoke_transformations.amazon_twelvelabs_pegasus_transformation import (
1226+
AmazonTwelveLabsPegasusConfig,
1227+
)
12251228
from .llms.bedrock.chat.invoke_transformations.base_invoke_transformation import (
12261229
AmazonInvokeConfig,
12271230
)

litellm/constants.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -851,6 +851,7 @@
851851
"nova",
852852
"deepseek_r1",
853853
"qwen3",
854+
"twelvelabs",
854855
]
855856

856857
BEDROCK_EMBEDDING_PROVIDERS_LITERAL = Literal[
Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
"""
2+
Transforms OpenAI-style requests into TwelveLabs Pegasus 1.2 requests for Bedrock.
3+
4+
Reference:
5+
https://docs.twelvelabs.io/docs/models/pegasus
6+
"""
7+
8+
from typing import Any, Dict, List, Optional
9+
10+
from litellm.llms.base_llm.base_utils import type_to_response_format_param
11+
from litellm.llms.base_llm.chat.transformation import BaseConfig
12+
from litellm.llms.bedrock.chat.invoke_transformations.base_invoke_transformation import (
13+
AmazonInvokeConfig,
14+
)
15+
from litellm.types.llms.openai import AllMessageValues
16+
from litellm.utils import get_base64_str
17+
18+
19+
class AmazonTwelveLabsPegasusConfig(AmazonInvokeConfig, BaseConfig):
    """
    Handles transforming OpenAI-style requests into Bedrock InvokeModel requests for
    `twelvelabs.pegasus-1-2-v1:0`.

    Pegasus 1.2 requires an `inputPrompt` and a `mediaSource` that either references
    an S3 object or a base64-encoded clip. Optional OpenAI params (temperature,
    response_format, max_tokens) are translated to the TwelveLabs schema.

    Reference: https://docs.twelvelabs.io/docs/models/pegasus
    """

    def get_supported_openai_params(self, model: str) -> List[str]:
        """Return the OpenAI params this config knows how to translate."""
        return [
            "max_tokens",
            "max_completion_tokens",
            "temperature",
            "response_format",
        ]

    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ) -> dict:
        """
        Translate supported OpenAI params into the TwelveLabs schema.

        Mapping: max_tokens / max_completion_tokens -> maxOutputTokens,
        temperature -> temperature, response_format -> responseFormat
        (normalized to a plain dict). ``optional_params`` is mutated in
        place and also returned.
        """
        for param, value in non_default_params.items():
            # max_tokens and max_completion_tokens are aliases for the same field.
            if param in {"max_tokens", "max_completion_tokens"}:
                optional_params["maxOutputTokens"] = value
            elif param == "temperature":
                optional_params["temperature"] = value
            elif param == "response_format":
                optional_params["responseFormat"] = self._normalize_response_format(
                    value
                )
        return optional_params

    def _normalize_response_format(self, value: Any) -> Any:
        """
        Coerce a response_format value into a dict.

        Dicts pass through untouched; anything else (e.g. a pydantic model
        class) goes through type_to_response_format_param, falling back to
        the raw value when that returns a falsy result.
        """
        if isinstance(value, dict):
            return value
        return type_to_response_format_param(response_format=value) or value

    def transform_request(
        self,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        headers: dict,
    ) -> dict:
        """
        Build the Pegasus InvokeModel request body.

        Always emits ``inputPrompt``; attaches ``mediaSource`` when a media
        reference can be derived from ``optional_params``, and copies through
        any previously-mapped generation settings.
        """
        input_prompt = self._convert_messages_to_prompt(messages=messages)
        request_data: Dict[str, Any] = {"inputPrompt": input_prompt}

        media_source = self._build_media_source(optional_params)
        if media_source is not None:
            request_data["mediaSource"] = media_source

        # Forward the generation settings produced by map_openai_params.
        for key in ("temperature", "maxOutputTokens", "responseFormat"):
            if key in optional_params:
                request_data[key] = optional_params[key]
        return request_data

    def _build_media_source(self, optional_params: dict) -> Optional[dict]:
        """
        Derive the Pegasus ``mediaSource`` payload from optional params.

        Precedence: explicit mediaSource dict > base64 input > S3 URI.
        Returns None when no media reference is present.
        """
        direct_source = optional_params.get("mediaSource") or optional_params.get(
            "media_source"
        )
        if isinstance(direct_source, dict):
            return direct_source

        base64_input = optional_params.get("video_base64") or optional_params.get(
            "base64_string"
        )
        if base64_input:
            # get_base64_str strips any data-URI prefix
            # (e.g. "data:video/mp4;base64,") before sending the payload.
            return {"base64String": get_base64_str(base64_input)}

        s3_uri = (
            optional_params.get("video_s3_uri")
            or optional_params.get("s3_uri")
            or optional_params.get("media_source_s3_uri")
        )
        if s3_uri:
            s3_location: Dict[str, Any] = {"uri": s3_uri}
            bucket_owner = (
                optional_params.get("video_s3_bucket_owner")
                or optional_params.get("s3_bucket_owner")
                or optional_params.get("media_source_bucket_owner")
            )
            if bucket_owner:
                # bucketOwner enables cross-account S3 access on Bedrock.
                s3_location["bucketOwner"] = bucket_owner
            return {"s3Location": s3_location}
        return None

    def _convert_messages_to_prompt(self, messages: List[AllMessageValues]) -> str:
        """
        Flatten chat messages into a single "role: content" prompt string,
        one line per message.

        Multimodal list content is reduced to its text fragments, with
        placeholder tags for image/video/audio parts; unknown part types
        are silently dropped.
        """
        prompt_parts: List[str] = []
        for message in messages:
            role = message.get("role", "user")
            # Guard against an explicit `content: None` (valid in OpenAI
            # messages) leaking the literal string "None" into the prompt.
            content = message.get("content") or ""
            if isinstance(content, list):
                text_fragments = []
                for item in content:
                    if isinstance(item, dict):
                        item_type = item.get("type")
                        if item_type == "text":
                            text_fragments.append(item.get("text", ""))
                        elif item_type == "image_url":
                            text_fragments.append("<image>")
                        elif item_type == "video_url":
                            text_fragments.append("<video>")
                        elif item_type == "audio_url":
                            text_fragments.append("<audio>")
                    elif isinstance(item, str):
                        text_fragments.append(item)
                content = " ".join(text_fragments)
            prompt_parts.append(f"{role}: {content}")
        return "\n".join(part for part in prompt_parts if part).strip()
133+

litellm/llms/bedrock/common_utils.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -616,6 +616,8 @@ def get_bedrock_chat_config(model: str):
616616
return litellm.AmazonInvokeNovaConfig()
617617
elif bedrock_invoke_provider == "qwen3":
618618
return litellm.AmazonQwen3Config()
619+
elif bedrock_invoke_provider == "twelvelabs":
620+
return litellm.AmazonTwelveLabsPegasusConfig()
619621
else:
620622
return litellm.AmazonInvokeConfig()
621623

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
from litellm.llms.bedrock.chat.invoke_transformations.amazon_twelvelabs_pegasus_transformation import (
2+
AmazonTwelveLabsPegasusConfig,
3+
)
4+
5+
6+
def _make_messages() -> list[dict]:
    """Minimal system + user conversation shared by the tests below."""
    system_msg = {"role": "system", "content": "You are an assistant"}
    user_msg = {"role": "user", "content": "Summarize the attached video."}
    return [system_msg, user_msg]
11+
12+
13+
def test_supported_openai_params():
    """Pegasus advertises max_tokens, temperature and response_format support."""
    supported = AmazonTwelveLabsPegasusConfig().get_supported_openai_params(
        "twelvelabs.pegasus-1-2-v1:0"
    )
    for expected_param in ("max_tokens", "temperature", "response_format"):
        assert expected_param in supported
19+
20+
21+
def test_map_openai_params_translates_fields():
    """OpenAI param names are renamed to their TwelveLabs equivalents."""
    config = AmazonTwelveLabsPegasusConfig()
    response_format = {
        "type": "json_schema",
        "json_schema": {"name": "video_schema", "schema": {"type": "object"}},
    }
    mapped: dict = {}
    config.map_openai_params(
        non_default_params={
            "max_tokens": 20,
            "temperature": 0.6,
            "response_format": response_format,
        },
        optional_params=mapped,
        model="twelvelabs.pegasus-1-2-v1:0",
        drop_params=False,
    )

    assert mapped["maxOutputTokens"] == 20
    assert mapped["temperature"] == 0.6
    assert "responseFormat" in mapped
    assert mapped["responseFormat"]["json_schema"]["name"] == "video_schema"
42+
43+
44+
def test_transform_request_includes_base64_media():
    """A data-URI base64 clip is emitted as mediaSource.base64String (prefix stripped)."""
    config = AmazonTwelveLabsPegasusConfig()
    # map_openai_params mutates and returns the same dict, so seeding it with
    # the media key up front is equivalent to assigning afterwards.
    optional_params = config.map_openai_params(
        non_default_params={"max_tokens": 10},
        optional_params={"video_base64": "data:video/mp4;base64,AAA"},
        model="twelvelabs.pegasus-1-2-v1:0",
        drop_params=False,
    )

    request = config.transform_request(
        model="twelvelabs.pegasus-1-2-v1:0",
        messages=_make_messages(),
        optional_params=optional_params,
        litellm_params={},
        headers={},
    )

    assert request["inputPrompt"].startswith("system:")
    assert request["mediaSource"]["base64String"] == "AAA"
    assert request["maxOutputTokens"] == 10
65+
66+
67+
def test_transform_request_includes_s3_media():
    """An S3 URI plus bucket owner is forwarded as mediaSource.s3Location."""
    request = AmazonTwelveLabsPegasusConfig().transform_request(
        model="twelvelabs.pegasus-1-2-v1:0",
        messages=_make_messages(),
        optional_params={
            "video_s3_uri": "s3://test-bucket/video.mp4",
            "video_s3_bucket_owner": "123456789012",
        },
        litellm_params={},
        headers={},
    )

    location = request["mediaSource"]["s3Location"]
    assert location["uri"] == "s3://test-bucket/video.mp4"
    assert location["bucketOwner"] == "123456789012"
85+

0 commit comments

Comments
 (0)