
Commit aea78b8

[Feat] Add support for Batch API Rate limiting - PR1 adds support for input based rate limits (BerriAI#16075)
* add count_input_file_usage
* add count_input_file_usage
* fix count_input_file_usage
* _get_batch_job_input_file_usage
* fixes imports
* use _get_batch_job_input_file_usage
* test_batch_rate_limits
* add _check_and_increment_batch_counters
* add get_rate_limiter_for_call_type
* test_batch_rate_limit_multiple_requests
* fixes for batch limits
* fix linting
* fix MYPY linting
1 parent 8a7f39d commit aea78b8

File tree

8 files changed: +881 -12 lines changed


batch_small.jsonl

Lines changed: 4 additions & 0 deletions
@@ -0,0 +1,4 @@
+{"custom_id": "request-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gpt-3.5-turbo", "messages": [{"role": "user", "content": "Hello, how are you?"}]}}
+{"custom_id": "request-2", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gpt-3.5-turbo", "messages": [{"role": "user", "content": "What is the weather today?"}]}}
+{"custom_id": "request-3", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gpt-3.5-turbo", "messages": [{"role": "user", "content": "Tell me a short joke"}]}}
+

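As a reading aid (not part of this commit), a minimal sketch of how a JSONL batch input file like batch_small.jsonl can be parsed into the List[dict] shape that the new _get_batch_job_input_file_usage helper in the next file consumes. The load_batch_input_file helper and the literal file path are illustrative assumptions, not code from this PR.

import json
from typing import List


def load_batch_input_file(path: str) -> List[dict]:
    """Parse a JSONL batch input file into a list of request dicts.

    Each non-empty line is one batch request with keys such as
    "custom_id", "method", "url", and "body". Illustrative helper only,
    not part of this commit.
    """
    items: List[dict] = []
    with open(path, "r") as f:
        for line in f:
            line = line.strip()
            if line:  # skip the trailing blank line in the fixture
                items.append(json.loads(line))
    return items


# Example: the fixture above yields three request dicts.
file_content_dictionary = load_batch_input_file("batch_small.jsonl")
print(len(file_content_dictionary))  # 3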
litellm/batches/batch_utils.py

Lines changed: 42 additions & 12 deletions
@@ -1,10 +1,15 @@
 import json
-from typing import Any, List, Literal, Tuple, Optional
+import time
+from typing import Any, List, Literal, Optional, Tuple
+
+import httpx
 
 import litellm
 from litellm._logging import verbose_logger
+from litellm._uuid import uuid
 from litellm.types.llms.openai import Batch
-from litellm.types.utils import CallTypes, Usage
+from litellm.types.utils import CallTypes, ModelResponse, Usage
+from litellm.utils import token_counter
 
 
 async def calculate_batch_cost_and_usage(
@@ -107,6 +112,10 @@ def calculate_vertex_ai_batch_cost_and_usage(
     """
     Calculate both cost and usage from Vertex AI batch responses
     """
+    from litellm.litellm_core_utils.litellm_logging import Logging
+    from litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import (
+        VertexGeminiConfig,
+    )
     total_cost = 0.0
     total_tokens = 0
     prompt_tokens = 0
@@ -115,14 +124,7 @@ def calculate_vertex_ai_batch_cost_and_usage(
     for response in vertex_ai_batch_responses:
         if response.get("status") == "JOB_STATE_SUCCEEDED":  # Check if response was successful
             # Transform Vertex AI response to OpenAI format if needed
-            from litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import VertexGeminiConfig
-            from litellm import ModelResponse
-            from litellm.litellm_core_utils.litellm_logging import Logging
-            from litellm.types.utils import CallTypes
-            from litellm._uuid import uuid
-            import httpx
-            import time
-
+
             # Create required arguments for the transformation method
             model_response = ModelResponse()
 
@@ -163,8 +165,9 @@ def calculate_vertex_ai_batch_cost_and_usage(
             total_cost += cost
 
             # Extract usage from the transformed response
-            if hasattr(openai_format_response, 'usage') and openai_format_response.usage:
-                usage = openai_format_response.usage
+            usage_obj = getattr(openai_format_response, 'usage', None)
+            if usage_obj:
+                usage = usage_obj
             else:
                 # Fallback: create usage from response dict
                 response_dict = openai_format_response.dict() if hasattr(openai_format_response, 'dict') else {}
@@ -278,6 +281,33 @@ def _get_batch_job_total_usage_from_file_content(
         completion_tokens=completion_tokens,
     )
 
+def _get_batch_job_input_file_usage(
+    file_content_dictionary: List[dict],
+    custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
+    model_name: Optional[str] = None,
+) -> Usage:
+    """
+    Count the number of tokens in the input file
+
+    Used for batch rate limiting to count the number of tokens in the input file
+    """
+    prompt_tokens: int = 0
+    completion_tokens: int = 0
+
+    for _item in file_content_dictionary:
+        body = _item.get("body", {})
+        model = body.get("model", model_name or "")
+        messages = body.get("messages", [])
+
+        if messages:
+            item_tokens = token_counter(model=model, messages=messages)
+            prompt_tokens += item_tokens
+
+    return Usage(
+        total_tokens=prompt_tokens + completion_tokens,
+        prompt_tokens=prompt_tokens,
+        completion_tokens=completion_tokens,
+    )
 
 def _get_batch_job_usage_from_response_body(response_body: dict) -> Usage:
     """
