Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 11 additions & 5 deletions litellm/llms/custom_httpx/llm_http_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1804,15 +1804,21 @@ async def async_anthropic_messages_handler(
Optional[litellm.types.utils.ProviderSpecificHeader],
kwargs.get("provider_specific_header", None),
)
extra_headers = ProviderSpecificHeaderUtils.get_provider_specific_headers(
provider_specific_headers = ProviderSpecificHeaderUtils.get_provider_specific_headers(
provider_specific_header=provider_specific_header,
custom_llm_provider=custom_llm_provider,
)
forwarded_headers = kwargs.get("headers", None)
if forwarded_headers and extra_headers:
merged_headers = {**forwarded_headers, **extra_headers}
else:
merged_headers = forwarded_headers or extra_headers
# Also check for extra_headers in kwargs (from config or direct calls)
extra_headers_from_kwargs = kwargs.get("extra_headers", None)
# Merge all header sources: forwarded < extra_headers < provider_specific
merged_headers = {}
if forwarded_headers:
merged_headers.update(forwarded_headers)
if extra_headers_from_kwargs:
merged_headers.update(extra_headers_from_kwargs)
if provider_specific_headers:
merged_headers.update(provider_specific_headers)
(
headers,
api_base,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,8 @@ def test_converse_transformation_anthropic_beta(self):
assert "additionalModelRequestFields" in result
additional_fields = result["additionalModelRequestFields"]
assert "anthropic_beta" in additional_fields
assert additional_fields["anthropic_beta"] == ["context-1m-2025-08-07", "interleaved-thinking-2025-05-14"]
# Sort both arrays before comparing to avoid flakiness from ordering differences
assert sorted(additional_fields["anthropic_beta"]) == sorted(["context-1m-2025-08-07", "interleaved-thinking-2025-05-14"])

def test_messages_transformation_anthropic_beta(self):
"""Test that Messages API transformation includes anthropic_beta in request."""
Expand All @@ -96,7 +97,8 @@ def test_messages_transformation_anthropic_beta(self):
)

assert "anthropic_beta" in result
assert result["anthropic_beta"] == ["output-128k-2025-02-19"]
# Sort both arrays before comparing to avoid flakiness from ordering differences
assert sorted(result["anthropic_beta"]) == sorted(["output-128k-2025-02-19"])

def test_converse_computer_use_compatibility(self):
"""Test that user anthropic_beta headers work with computer use tools."""
Expand Down Expand Up @@ -287,4 +289,4 @@ def test_prompt_caching_with_other_beta_headers(self):
assert "prompt-caching-2024-07-31" not in result["anthropic_beta"]
else:
# If no beta headers, that's also fine
assert True
assert True
150 changes: 145 additions & 5 deletions tests/test_litellm/llms/custom_httpx/test_llm_http_handler.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,14 @@
import io
import os
import pathlib
import ssl
import sys
from unittest.mock import MagicMock
from unittest.mock import AsyncMock, Mock, patch

import pytest

sys.path.insert(
0, os.path.abspath("../../../..")
) # Adds the parent directory to the system path
import litellm
from litellm.llms.custom_httpx.llm_http_handler import BaseLLMHTTPHandler
from litellm.types.router import GenericLiteLLMParams


def test_prepare_fake_stream_request():
Expand Down Expand Up @@ -75,3 +72,146 @@ def test_prepare_fake_stream_request():
assert "stream" not in result_data
assert result_data["model"] == "gpt-4"
assert result_data["messages"] == [{"role": "user", "content": "Hello"}]


@pytest.mark.asyncio
async def test_async_anthropic_messages_handler_extra_headers():
"""
Test that async_anthropic_messages_handler correctly extracts and merges
extra_headers from kwargs with proper priority.
"""
handler = BaseLLMHTTPHandler()

# Mock the config
mock_config = Mock()
mock_config.validate_anthropic_messages_environment = Mock(
return_value=({"x-api-key": "test-key"}, "https://api.anthropic.com")
)
mock_config.transform_anthropic_messages_request = Mock(
return_value={"model": "claude-3-opus-20240229", "messages": []}
)

# Mock the client
mock_client = AsyncMock()
mock_response = Mock()
mock_response.status_code = 200
mock_response.json.return_value = {
"id": "msg_123",
"type": "message",
"role": "assistant",
"content": [{"type": "text", "text": "Hello!"}],
"model": "claude-3-opus-20240229",
"stop_reason": "end_turn",
}
mock_client.post = AsyncMock(return_value=mock_response)

# Mock logging object
mock_logging_obj = Mock()
mock_logging_obj.update_environment_variables = Mock()
mock_logging_obj.model_call_details = {}
mock_logging_obj.stream = False

# Test case 1: Only extra_headers in kwargs
kwargs = {
"extra_headers": {
"X-Custom-Header": "from-kwargs",
"X-Auth-Token": "token123",
}
}

with patch(
"litellm.litellm_core_utils.get_provider_specific_headers.ProviderSpecificHeaderUtils.get_provider_specific_headers"
) as mock_provider_headers:
mock_provider_headers.return_value = None

# Capture what headers are passed to validate_anthropic_messages_environment
captured_headers = {}
def capture_validate(*args, **kwargs):
captured_headers.update(kwargs.get("headers", {}))
return ({"x-api-key": "test-key"}, "https://api.anthropic.com")

mock_config.validate_anthropic_messages_environment = capture_validate

try:
await handler.async_anthropic_messages_handler(
model="claude-3-opus-20240229",
messages=[{"role": "user", "content": "Hello"}],
anthropic_messages_provider_config=mock_config,
anthropic_messages_optional_request_params={},
custom_llm_provider="anthropic",
litellm_params=GenericLiteLLMParams(),
logging_obj=mock_logging_obj,
client=mock_client,
kwargs=kwargs,
)
except Exception:
pass # We're testing header extraction, not the full flow

# Verify extra_headers were extracted and merged
assert "X-Custom-Header" in captured_headers
assert captured_headers["X-Custom-Header"] == "from-kwargs"
assert "X-Auth-Token" in captured_headers
assert captured_headers["X-Auth-Token"] == "token123"


@pytest.mark.asyncio
async def test_async_anthropic_messages_handler_header_priority():
"""
Test that async_anthropic_messages_handler respects header priority:
forwarded < extra_headers < provider_specific
"""
handler = BaseLLMHTTPHandler()

# Mock the config
mock_config = Mock()
mock_client = AsyncMock()
mock_logging_obj = Mock()
mock_logging_obj.update_environment_variables = Mock()
mock_logging_obj.model_call_details = {}
mock_logging_obj.stream = False

# Test with all three header sources
kwargs = {
"headers": {"X-Priority": "forwarded", "X-Forwarded-Only": "keep"},
"extra_headers": {"X-Priority": "extra", "X-Extra-Only": "also-keep"},
}

with patch(
"litellm.litellm_core_utils.get_provider_specific_headers.ProviderSpecificHeaderUtils.get_provider_specific_headers"
) as mock_provider_headers:
mock_provider_headers.return_value = {
"X-Priority": "provider",
"X-Provider-Only": "keep-this-too"
}

captured_headers = {}
def capture_validate(*args, **kwargs):
captured_headers.update(kwargs.get("headers", {}))
return ({"x-api-key": "test-key"}, "https://api.anthropic.com")

mock_config.validate_anthropic_messages_environment = capture_validate
mock_config.transform_anthropic_messages_request = Mock(
return_value={"model": "claude-3-opus-20240229", "messages": []}
)

try:
await handler.async_anthropic_messages_handler(
model="claude-3-opus-20240229",
messages=[{"role": "user", "content": "Hello"}],
anthropic_messages_provider_config=mock_config,
anthropic_messages_optional_request_params={},
custom_llm_provider="anthropic",
litellm_params=GenericLiteLLMParams(),
logging_obj=mock_logging_obj,
client=mock_client,
kwargs=kwargs,
)
except Exception:
pass

# Verify priority: provider_specific should win
assert captured_headers["X-Priority"] == "provider"
# Verify all unique headers from different sources are present
assert captured_headers["X-Forwarded-Only"] == "keep"
assert captured_headers["X-Extra-Only"] == "also-keep"
assert captured_headers["X-Provider-Only"] == "keep-this-too"
Loading