Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 47 additions & 0 deletions eval_protocol/pytest/default_single_turn_rollout_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,53 @@ async def process_row(row: EvaluationRow) -> EvaluationRow:
)
)

# Synchronously extract token_ids, routing_matrix, and logprobs from the provider response.
try:
token_ids = []
routing_matrix = []
logprobs_obj = getattr(response.choices[0], "logprobs", None)

if logprobs_obj is not None:
if isinstance(logprobs_obj, dict):
content = logprobs_obj.get("content", [])
else:
content = getattr(logprobs_obj, "content", [])

if isinstance(content, list):
for item in content:
if isinstance(item, dict):
tid = item.get("token_id")
rm = item.get("routing_matrix")
else:
tid = getattr(item, "token_id", None)
rm = getattr(item, "routing_matrix", None)

if tid is not None:
token_ids.append(tid)
if rm is not None:
routing_matrix.append(rm)

logger.info(
"[SingleTurnRolloutProcessor] Extracted %d token_ids and %d routing_matrix entries from logprobs",
len(token_ids),
len(routing_matrix),
)

# Store as 1D lists directly for SingleTurn (no step dimension needed)
if token_ids or routing_matrix or logprobs_obj is not None:
if not row.execution_metadata.extra:
row.execution_metadata.extra = {}
if token_ids:
row.execution_metadata.extra["token_ids"] = token_ids
if routing_matrix:
row.execution_metadata.extra["routing_matrix"] = routing_matrix
if logprobs_obj is not None:
row.execution_metadata.extra["logprobs"] = logprobs_obj
except Exception as e:
logger.warning(
"[SingleTurnRolloutProcessor] Failed to extract token_ids/routing_matrix/logprobs: %s", e
)

row.messages = messages

row.execution_metadata.duration_seconds = time.perf_counter() - start_time
Expand Down
56 changes: 56 additions & 0 deletions tests/pytest/test_pytest_input_messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,59 @@ def test_input_messages_in_decorator(rows: List[EvaluationRow]) -> List[Evaluati
for row in rows:
row.evaluation_result = EvaluateResult(score=0.0, reason="Dummy evaluation result")
return rows


@pytest.mark.parametrize(
"completion_params",
[
{
"model": "fireworks_ai/accounts/fireworks/models/qwen3-30b-a3b",
"logprobs": True,
# "include_routing_matrix": True, # Requires --enable-moe-stats on server
"temperature": 0.6,
"max_tokens": 256,
}
],
)
@evaluation_test(
input_messages=[
[
[
Message(role="user", content="What is 2+2?"),
]
]
],
rollout_processor=SingleTurnRolloutProcessor(),
mode="all",
)
def test_single_turn_with_logprobs_and_routing_matrix(rows: List[EvaluationRow]) -> List[EvaluationRow]:
"""Test SingleTurnRolloutProcessor with logprobs and routing_matrix extraction."""
for row in rows:
# Check if extra metadata was extracted
extra = row.execution_metadata.extra
print("\n=== DEBUG: execution_metadata.extra ===")
print(f"extra type: {type(extra)}")
print(f"extra keys: {extra.keys() if isinstance(extra, dict) else 'N/A'}")

if isinstance(extra, dict):
if "token_ids" in extra:
token_ids = extra["token_ids"]
print(f"token_ids: found, len={len(token_ids)}, first 10 ids={token_ids[:10]}")
else:
print("token_ids: NOT FOUND")

if "routing_matrix" in extra:
routing_matrix = extra["routing_matrix"]
print(f"routing_matrix: found, len={len(routing_matrix)}")
else:
print("routing_matrix: NOT FOUND")

if "logprobs" in extra:
print("logprobs: found")
else:
print("logprobs: NOT FOUND")

print("=" * 50)

row.evaluation_result = EvaluateResult(score=1.0, reason="Test passed")
return rows
Loading