Skip to content

Commit aac01ef

Browse files
fix(usage): Normalize None token detail objects on Usage initialization
Extends #2034 to handle providers that return None for entire input_tokens_details and output_tokens_details objects (not just the fields within them). This affects non-streaming responses. Related to #1179 (which fixed the streaming case). Some providers like llama-stack return null for these optional fields in their JSON responses. The OpenAI SDK maps these to None in Python. Previously, passing None to the Usage constructor would fail Pydantic validation before __post_init__ could normalize them. This PR uses Pydantic's BeforeValidator to normalize None values at the field level, before Pydantic's type validation runs. The validators also convert Chat Completions API types (PromptTokensDetails, CompletionTokensDetails) to Responses API types (InputTokensDetails, OutputTokensDetails).
1 parent a05af4b commit aac01ef

File tree

3 files changed

+86
-21
lines changed

3 files changed

+86
-21
lines changed

src/agents/models/openai_chatcompletions.py

Lines changed: 3 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
from openai.types.chat.chat_completion import Choice
1212
from openai.types.responses import Response
1313
from openai.types.responses.response_prompt_param import ResponsePromptParam
14-
from openai.types.responses.response_usage import InputTokensDetails, OutputTokensDetails
1514

1615
from .. import _debug
1716
from ..agent_output import AgentOutputSchemaBase
@@ -102,18 +101,9 @@ async def get_response(
102101
input_tokens=response.usage.prompt_tokens,
103102
output_tokens=response.usage.completion_tokens,
104103
total_tokens=response.usage.total_tokens,
105-
input_tokens_details=InputTokensDetails(
106-
cached_tokens=getattr(
107-
response.usage.prompt_tokens_details, "cached_tokens", 0
108-
)
109-
or 0,
110-
),
111-
output_tokens_details=OutputTokensDetails(
112-
reasoning_tokens=getattr(
113-
response.usage.completion_tokens_details, "reasoning_tokens", 0
114-
)
115-
or 0,
116-
),
104+
# BeforeValidator in Usage normalizes these from Chat Completions types
105+
input_tokens_details=response.usage.prompt_tokens_details, # type: ignore[arg-type]
106+
output_tokens_details=response.usage.completion_tokens_details, # type: ignore[arg-type]
117107
)
118108
if response.usage
119109
else Usage()

src/agents/usage.py

Lines changed: 34 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,36 @@
1+
from __future__ import annotations
2+
13
from dataclasses import field
4+
from typing import Annotated
25

6+
from openai.types.completion_usage import CompletionTokensDetails, PromptTokensDetails
37
from openai.types.responses.response_usage import InputTokensDetails, OutputTokensDetails
8+
from pydantic import BeforeValidator
49
from pydantic.dataclasses import dataclass
510

611

12+
def _normalize_input_tokens_details(
    v: InputTokensDetails | PromptTokensDetails | None,
) -> InputTokensDetails:
    """Coerce a raw input-token-detail value to the Responses API type.

    A Chat Completions ``PromptTokensDetails`` is converted by copying its
    ``cached_tokens`` (a ``None`` count becomes 0); a ``None`` object becomes
    a zeroed ``InputTokensDetails``; anything else passes through unchanged.
    Runs as a Pydantic ``BeforeValidator``, i.e. before type validation.
    """
    if isinstance(v, PromptTokensDetails):
        return InputTokensDetails(cached_tokens=v.cached_tokens or 0)
    if v is None:
        return InputTokensDetails(cached_tokens=0)
    return v
21+
22+
23+
def _normalize_output_tokens_details(
    v: OutputTokensDetails | CompletionTokensDetails | None,
) -> OutputTokensDetails:
    """Coerce a raw output-token-detail value to the Responses API type.

    A Chat Completions ``CompletionTokensDetails`` is converted by copying its
    ``reasoning_tokens`` (a ``None`` count becomes 0); a ``None`` object
    becomes a zeroed ``OutputTokensDetails``; anything else passes through
    unchanged. Runs as a Pydantic ``BeforeValidator``, i.e. before type
    validation.
    """
    if isinstance(v, CompletionTokensDetails):
        return OutputTokensDetails(reasoning_tokens=v.reasoning_tokens or 0)
    if v is None:
        return OutputTokensDetails(reasoning_tokens=0)
    return v
32+
33+
734
@dataclass
835
class RequestUsage:
936
"""Usage details for a single API request."""
@@ -32,16 +59,16 @@ class Usage:
3259
input_tokens: int = 0
3360
"""Total input tokens sent, across all requests."""
3461

35-
input_tokens_details: InputTokensDetails = field(
36-
default_factory=lambda: InputTokensDetails(cached_tokens=0)
37-
)
62+
input_tokens_details: Annotated[
63+
InputTokensDetails, BeforeValidator(_normalize_input_tokens_details)
64+
] = field(default_factory=lambda: InputTokensDetails(cached_tokens=0))
3865
"""Details about the input tokens, matching responses API usage details."""
3966
output_tokens: int = 0
4067
"""Total output tokens received, across all requests."""
4168

42-
output_tokens_details: OutputTokensDetails = field(
43-
default_factory=lambda: OutputTokensDetails(reasoning_tokens=0)
44-
)
69+
output_tokens_details: Annotated[
70+
OutputTokensDetails, BeforeValidator(_normalize_output_tokens_details)
71+
] = field(default_factory=lambda: OutputTokensDetails(reasoning_tokens=0))
4572
"""Details about the output tokens, matching responses API usage details."""
4673

4774
total_tokens: int = 0
@@ -70,7 +97,7 @@ def __post_init__(self) -> None:
7097
if self.output_tokens_details.reasoning_tokens is None:
7198
self.output_tokens_details = OutputTokensDetails(reasoning_tokens=0)
7299

73-
def add(self, other: "Usage") -> None:
100+
def add(self, other: Usage) -> None:
74101
"""Add another Usage object to this one, aggregating all fields.
75102
76103
This method automatically preserves request_usage_entries.

tests/test_usage.py

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from openai.types.completion_usage import CompletionTokensDetails, PromptTokensDetails
12
from openai.types.responses.response_usage import InputTokensDetails, OutputTokensDetails
23

34
from agents.usage import RequestUsage, Usage
@@ -270,7 +271,24 @@ def test_anthropic_cost_calculation_scenario():
270271

271272

272273
def test_usage_normalizes_none_token_details():
273-
# Some providers don't populate optional fields, resulting in None values
274+
# Some providers don't populate optional token detail fields
275+
# (cached_tokens, reasoning_tokens), and the OpenAI SDK's generated
276+
# code can bypass Pydantic validation (e.g., via model_construct),
277+
# allowing None values. We normalize these to 0 to prevent TypeErrors.
278+
279+
# Test entire objects being None (BeforeValidator)
280+
usage = Usage(
281+
requests=1,
282+
input_tokens=100,
283+
input_tokens_details=None, # type: ignore[arg-type]
284+
output_tokens=50,
285+
output_tokens_details=None, # type: ignore[arg-type]
286+
total_tokens=150,
287+
)
288+
assert usage.input_tokens_details.cached_tokens == 0
289+
assert usage.output_tokens_details.reasoning_tokens == 0
290+
291+
# Test fields within objects being None (__post_init__)
274292
input_details = InputTokensDetails(cached_tokens=0)
275293
input_details.__dict__["cached_tokens"] = None
276294

@@ -289,3 +307,33 @@ def test_usage_normalizes_none_token_details():
289307
# __post_init__ should normalize None to 0
290308
assert usage.input_tokens_details.cached_tokens == 0
291309
assert usage.output_tokens_details.reasoning_tokens == 0
310+
311+
312+
def test_usage_normalizes_chat_completions_types():
    """Usage's BeforeValidators convert Chat Completions detail types
    (PromptTokensDetails / CompletionTokensDetails) into the Responses API
    types (InputTokensDetails / OutputTokensDetails), keeping only the
    fields the Responses API tracks (cached_tokens / reasoning_tokens)."""
    usage = Usage(
        requests=1,
        input_tokens=200,
        input_tokens_details=PromptTokensDetails(  # type: ignore[arg-type]
            audio_tokens=10, cached_tokens=50
        ),
        output_tokens=150,
        output_tokens_details=CompletionTokensDetails(  # type: ignore[arg-type]
            accepted_prediction_tokens=5,
            audio_tokens=10,
            reasoning_tokens=100,
            rejected_prediction_tokens=2,
        ),
        total_tokens=350,
    )

    # The validators should have swapped in Responses API types and
    # carried over the relevant counts.
    assert isinstance(usage.input_tokens_details, InputTokensDetails)
    assert usage.input_tokens_details.cached_tokens == 50

    assert isinstance(usage.output_tokens_details, OutputTokensDetails)
    assert usage.output_tokens_details.reasoning_tokens == 100

0 commit comments

Comments
 (0)