23 changes: 13 additions & 10 deletions litellm/proxy/common_request_processing.py
@@ -536,7 +536,11 @@ async def base_process_llm_request(

responses = await llm_responses

response = responses[1]
# Guardrails (pre/during-call) can inject a mock response to short-circuit the LLM call.
# Prefer it when present so blocked/filtered output is returned instead of the model response.
response = self.data.get("mock_response")
if response is None:
response = responses[1]

hidden_params = getattr(response, "_hidden_params", {}) or {}
model_id = hidden_params.get("model_id", None) or ""
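Below is a minimal, illustrative sketch of the pattern this hunk relies on (not the actual LiteLLM hook API; the hook name, guardrail class, and response shape are assumptions): a during-call guardrail drops a mock response into the shared request data, and the request processing prefers it over the model output.

# Illustrative sketch only. Assumes a guardrail hook receives the mutable
# request `data` dict and may set data["mock_response"]; names are hypothetical.
from typing import Any, Dict, Optional


class BlockedContentGuardrail:
    """Hypothetical guardrail that blocks prompts containing a banned term."""

    async def async_moderation_hook(self, data: Dict[str, Any]) -> None:
        if "forbidden" in str(data.get("messages", "")):
            # Short-circuit: the proxy should return this instead of the model output.
            data["mock_response"] = {
                "choices": [{"message": {"content": "Blocked by guardrail."}}]
            }


def pick_response(data: Dict[str, Any], llm_response: Any) -> Any:
    # Mirrors the diff above: prefer a guardrail-injected mock response when present.
    mock: Optional[Any] = data.get("mock_response")
    return mock if mock is not None else llm_response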
@@ -804,7 +808,7 @@ async def _handle_llm_api_exception(
# This matches the original behavior before the refactor in commit 511d435f6f
error_body = await e.response.aread()
error_text = error_body.decode("utf-8")

raise HTTPException(
status_code=e.response.status_code,
detail={"error": error_text},
@@ -1072,9 +1076,9 @@ def _inject_cost_into_usage_dict(obj: dict, model_name: str) -> Optional[dict]:

# Add cache-related fields to **params (handled by Usage.__init__)
if cache_creation_input_tokens is not None:
usage_kwargs["cache_creation_input_tokens"] = (
cache_creation_input_tokens
)
usage_kwargs[
"cache_creation_input_tokens"
] = cache_creation_input_tokens
if cache_read_input_tokens is not None:
usage_kwargs["cache_read_input_tokens"] = cache_read_input_tokens

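As a side note, a hedged sketch of how those cache-related counts can be forwarded to litellm.Usage (the token values are made up; the kwargs pass-through is what the hunk's comment describes):

# Sketch only: cache-related token counts are forwarded as keyword arguments,
# which Usage.__init__ handles via **params (per the comment in the hunk above).
import litellm

usage_kwargs = {"prompt_tokens": 120, "completion_tokens": 30, "total_tokens": 150}
cache_creation_input_tokens = 80  # hypothetical provider-reported value
cache_read_input_tokens = 40      # hypothetical provider-reported value

if cache_creation_input_tokens is not None:
    usage_kwargs["cache_creation_input_tokens"] = cache_creation_input_tokens
if cache_read_input_tokens is not None:
    usage_kwargs["cache_read_input_tokens"] = cache_read_input_tokens

usage = litellm.Usage(**usage_kwargs)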
@@ -1093,7 +1097,9 @@ def _inject_cost_into_usage_dict(obj: dict, model_name: str) -> Optional[dict]:
return obj
return None

def maybe_get_model_id(self, _logging_obj: Optional[LiteLLMLoggingObj]) -> Optional[str]:
def maybe_get_model_id(
self, _logging_obj: Optional[LiteLLMLoggingObj]
) -> Optional[str]:
"""
Get model_id from logging object or request metadata.

@@ -1103,10 +1109,7 @@ def maybe_get_model_id(self, _logging_obj: Optional[LiteLLMLoggingObj]) -> Optional[str]:
model_id = None
if _logging_obj:
# 1. Try getting from litellm_params (updated during call)
if (
hasattr(_logging_obj, "litellm_params")
and _logging_obj.litellm_params
):
if hasattr(_logging_obj, "litellm_params") and _logging_obj.litellm_params:
# First check direct model_info path (set by router.py with selected deployment)
model_info = _logging_obj.litellm_params.get("model_info") or {}
model_id = model_info.get("id", None)
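For reference, a small sketch of the first lookup step this method performs (using a stand-in logging object; only the litellm_params["model_info"]["id"] path visible in the hunk is exercised):

# Sketch of the lookup shown above: prefer the selected deployment's id from
# litellm_params["model_info"]; return None when it is absent.
from types import SimpleNamespace
from typing import Any, Optional


def resolve_model_id(logging_obj: Optional[Any]) -> Optional[str]:
    model_id = None
    if logging_obj and getattr(logging_obj, "litellm_params", None):
        model_info = logging_obj.litellm_params.get("model_info") or {}
        model_id = model_info.get("id", None)
    return model_id


# router.py sets model_info with the chosen deployment, e.g.:
obj = SimpleNamespace(litellm_params={"model_info": {"id": "deployment-123"}})
assert resolve_model_id(obj) == "deployment-123"
assert resolve_model_id(None) is None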
99 changes: 98 additions & 1 deletion tests/test_litellm/proxy/test_common_request_processing.py
@@ -1,11 +1,13 @@
import copy
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock

import pytest
from fastapi import Request, status
from fastapi import Request, Response, status
from fastapi.responses import StreamingResponse

import litellm
import litellm.proxy.common_request_processing as common_request_processing
from litellm._uuid import uuid
from litellm.integrations.opentelemetry import UserAPIKeyAuth
from litellm.proxy.common_request_processing import (
@@ -75,6 +77,101 @@ async def mock_common_processing_pre_call_logic(
pytest.fail("litellm_call_id is not a valid UUID")
assert data_passed["litellm_call_id"] == returned_data["litellm_call_id"]

@pytest.mark.asyncio
async def test_base_process_llm_request_prefers_guardrail_mock_response(
self, monkeypatch
):
processing_obj = ProxyBaseLLMRequestProcessing(
data={
"messages": [],
"metadata": {},
"litellm_metadata": {"model_info": {"id": "fallback-model"}},
}
)

guardrail_response = litellm.ModelResponse(
model="bedrock-guardrail",
hidden_params={"model_id": "guardrail-model"},
)
llm_response = litellm.ModelResponse(
model="real-model",
hidden_params={"model_id": "real-model"},
)

async def mock_common_processing(self, *args, **kwargs):
logging_obj = SimpleNamespace(litellm_call_id="test-call-id")
self.data["litellm_call_id"] = "test-call-id"
self.data["litellm_logging_obj"] = logging_obj
return self.data, logging_obj

monkeypatch.setattr(
ProxyBaseLLMRequestProcessing,
"common_processing_pre_call_logic",
mock_common_processing,
)

async def mock_route_request(*args, **kwargs):
async def _inner():
return llm_response

return _inner()

monkeypatch.setattr(
common_request_processing,
"route_request",
mock_route_request,
)

check_response_size_is_safe_mock = AsyncMock()
monkeypatch.setattr(
common_request_processing,
"check_response_size_is_safe",
check_response_size_is_safe_mock,
)

async def mock_during_call_hook(*args, **kwargs):
kwargs["data"]["mock_response"] = guardrail_response

proxy_logging_obj = MagicMock(spec=ProxyLogging)
proxy_logging_obj.during_call_hook = AsyncMock(
side_effect=mock_during_call_hook
)
proxy_logging_obj.update_request_status = AsyncMock(return_value=None)
proxy_logging_obj.post_call_success_hook = AsyncMock(
return_value=guardrail_response
)

user_api_key_dict = MagicMock(spec=UserAPIKeyAuth)
user_api_key_dict.tpm_limit = None
user_api_key_dict.rpm_limit = None
user_api_key_dict.max_budget = None
user_api_key_dict.spend = 0
user_api_key_dict.allowed_model_region = None

fastapi_response = Response()
proxy_config = MagicMock(spec=ProxyConfig)

result = await processing_obj.base_process_llm_request(
request=MagicMock(spec=Request),
fastapi_response=fastapi_response,
user_api_key_dict=user_api_key_dict,
route_type="acompletion",
proxy_logging_obj=proxy_logging_obj,
general_settings={},
proxy_config=proxy_config,
select_data_generator=lambda **kwargs: None,
)

assert result is guardrail_response
assert (
proxy_logging_obj.post_call_success_hook.await_args.kwargs["response"]
is guardrail_response
)
assert (
check_response_size_is_safe_mock.await_args.kwargs["response"]
is guardrail_response
)

@pytest.mark.asyncio
async def test_stream_timeout_header_processing(self):
"""