Commit 251c3ff

feat: preserve logprobs from chat completions API in ModelResponse
The SDK already accepts `top_logprobs` in ModelSettings and passes it to the API, but the logprobs returned in the response were discarded during conversion. This change:

1. Adds an optional `logprobs` field to the `ModelResponse` dataclass
2. Extracts logprobs from `choice.logprobs.content` in the chat completions model and includes them in the `ModelResponse`

This enables use cases like RLHF training, confidence scoring, and uncertainty estimation that require access to token-level log probabilities.
Parent: db68d1c
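A minimal usage sketch of what this enables (not part of the diff): assuming the SDK's `Agent`, `Runner`, and `ModelSettings` entry points, and that the run result exposes the raw `ModelResponse` objects via `raw_responses`, the new field could be read like this once `top_logprobs` is set:

    import asyncio

    from agents import Agent, ModelSettings, Runner

    async def main() -> None:
        agent = Agent(
            name="assistant",
            # logprobs are only populated by the chat completions model
            model_settings=ModelSettings(top_logprobs=3),
        )
        result = await Runner.run(agent, "Say hello.")
        for response in result.raw_responses:
            # response.logprobs is None unless top_logprobs was set
            for entry in response.logprobs or []:
                print(entry["token"], entry["logprob"])

    asyncio.run(main())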

File tree: 2 files changed (+12, −0 lines)


src/agents/items.py

Lines changed: 7 additions & 0 deletions
@@ -356,6 +356,13 @@ class ModelResponse:
     be passed to `Runner.run`.
     """

+    logprobs: list[Any] | None = None
+    """Token log probabilities from the model response.
+    Only populated when using the chat completions API with `top_logprobs` set in ModelSettings.
+    Each element corresponds to a token and contains the token string, log probability, and
+    optionally the top alternative tokens with their log probabilities.
+    """
+
     def to_input_items(self) -> list[TResponseInputItem]:
         """Convert the output into a list of input items suitable for passing to the model."""
         # We happen to know that the shape of the Pydantic output items are the same as the
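Each element of the new `logprobs` list is a `ChatCompletionTokenLogprob` serialized via `model_dump()` (see the converter change below), so a single entry is a plain dict shaped roughly like this (values illustrative):

    {
        "token": "Hello",
        "logprob": -0.0135,
        "bytes": [72, 101, 108, 108, 111],
        "top_logprobs": [
            {"token": "Hello", "logprob": -0.0135, "bytes": [72, 101, 108, 108, 111]},
            {"token": "Hi", "logprob": -4.82, "bytes": [72, 105]},
        ],
    }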

src/agents/models/openai_chatcompletions.py

Lines changed: 5 additions & 0 deletions
@@ -129,10 +129,15 @@ async def get_response(

         items = Converter.message_to_output_items(message) if message is not None else []

+        logprobs_data = None
+        if first_choice and first_choice.logprobs and first_choice.logprobs.content:
+            logprobs_data = [lp.model_dump() for lp in first_choice.logprobs.content]
+
         return ModelResponse(
             output=items,
             usage=usage,
             response_id=None,
+            logprobs=logprobs_data,
         )

     async def stream_response(
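Since the commit message cites confidence scoring as a motivating use case, here is a minimal sketch (an illustration, not SDK code) of deriving a confidence signal from the dumped entries:

    import math

    def sequence_confidence(logprobs: list[dict]) -> tuple[float, float]:
        """Return (mean token logprob, perplexity) for one response."""
        values = [entry["logprob"] for entry in logprobs]
        mean = sum(values) / len(values)  # assumes a non-empty list
        return mean, math.exp(-mean)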
