2 changes: 2 additions & 0 deletions eval_protocol/pytest/__init__.py
@@ -1,5 +1,6 @@
from .default_agent_rollout_processor import AgentRolloutProcessor
from .default_dataset_adapter import default_dataset_adapter
from .default_klavis_sandbox_rollout_processor import KlavisSandboxRolloutProcessor
Missing try/except for optional klavis dependency import

The KlavisSandboxRolloutProcessor is imported unconditionally at line 3, but it depends on the optional klavis package. This will cause an ImportError for any user who imports from eval_protocol.pytest without having klavis installed. Other optional dependency imports like PydanticAgentRolloutProcessor and LangGraphRolloutProcessor are correctly wrapped in try/except blocks (lines 16-22 and 25-31), but KlavisSandboxRolloutProcessor lacks this protection. The import and __all__ export at line 35 both need to be made conditional like the other optional dependencies.
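A minimal sketch of the guarded-import pattern the comment describes, assuming it should mirror the try/except style already used for the other optional processors (the exact fallback value and __all__ handling in the real module may differ):

try:
    from .default_klavis_sandbox_rollout_processor import KlavisSandboxRolloutProcessor
except ImportError:  # optional "klavis" extra not installed
    KlavisSandboxRolloutProcessor = None  # type: ignore[assignment]

# ...and later, instead of listing it unconditionally in __all__:
if KlavisSandboxRolloutProcessor is not None:
    __all__.append("KlavisSandboxRolloutProcessor")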

from .default_mcp_gym_rollout_processor import MCPGymRolloutProcessor
from .default_no_op_rollout_processor import NoOpRolloutProcessor
from .default_single_turn_rollout_process import SingleTurnRolloutProcessor
@@ -31,6 +32,7 @@

__all__ = [
"AgentRolloutProcessor",
"KlavisSandboxRolloutProcessor",
"MCPGymRolloutProcessor",
"RolloutProcessor",
"SingleTurnRolloutProcessor",
174 changes: 174 additions & 0 deletions eval_protocol/pytest/default_klavis_sandbox_rollout_processor.py
@@ -0,0 +1,174 @@
import asyncio
import json
import logging
import os
import tempfile
import time
from typing import Any, Callable, Dict, List, Optional

from pydantic import BaseModel, Field

from eval_protocol.models import EvaluationRow
from eval_protocol.pytest.rollout_processor import RolloutProcessor
from eval_protocol.pytest.types import RolloutProcessorConfig

from eval_protocol.pytest.default_agent_rollout_processor import Agent
from klavis import Klavis
from klavis.types import CreateSandboxResponse, SandboxMcpServer
from openai.types import CompletionUsage

logger = logging.getLogger(__name__)


class KlavisSandboxRolloutProcessor(RolloutProcessor):
def __init__(
self,
server_name: str,
initialize_data_factory: Optional[Callable[[EvaluationRow], Dict[str, Any]]] = None,
):
super().__init__()
self.server_name = server_name
self.initialize_data_factory = initialize_data_factory
self.klavis_client = Klavis(api_key=os.environ.get("KLAVIS_API_KEY"))

def _init_sandbox(self) -> CreateSandboxResponse:
try:
server_name_enum = SandboxMcpServer(self.server_name)
return self.klavis_client.sandbox.create_sandbox(server_name=server_name_enum)
except Exception as e:
logger.error(f"Error creating sandbox: {str(e)}", exc_info=True)
raise

@staticmethod
def create_mcp_config(server_url: str, server_key: str = "main", auth_token: str | None = None) -> str:
"""Create a temporary MCP config file and return its path."""
config = {
"mcpServers": {
server_key: {
"url": server_url,
"transport": "streamable_http",
**({"authorization": f"Bearer {auth_token}"} if auth_token else {})
}
}
}

# Create a temp file that persists for the session
fd, path = tempfile.mkstemp(suffix=".json", prefix="mcp_config_")
with os.fdopen(fd, 'w') as f:
json.dump(config, f)
return path

def __call__(
self, rows: List[EvaluationRow], config: RolloutProcessorConfig
) -> List[asyncio.Task[EvaluationRow]]:
"""Process evaluation rows with Klavis sandbox lifecycle management"""
semaphore = config.semaphore

async def process_row(row: EvaluationRow) -> EvaluationRow:
"""Process a single row with complete sandbox lifecycle"""

start_time = time.perf_counter()
agent: Agent | None = None
temp_config_path: str | None = None
sandbox: CreateSandboxResponse | None = None

try:
# Step 0: Create a sandbox for this row
sandbox = self._init_sandbox()
logger.info(f"Sandbox created: {sandbox}")

# Step 1: Initialize data in the sandbox
init_data: Dict[str, Any] | None = None
if self.initialize_data_factory:
init_data = self.initialize_data_factory(row)
else:
# Allow datasets to provide initialization payload directly
init_data = (
(row.input_metadata.session_data or {}).get("initialize_data")
if row.input_metadata is not None
else None
)

if init_data:
logger.info(f"Initializing {self.server_name} sandbox {sandbox.sandbox_id}")
initialize_method = getattr(
self.klavis_client.sandbox, f"initialize_{sandbox.server_name.value}_sandbox"
)
init_response = initialize_method(sandbox_id=sandbox.sandbox_id, **init_data)
logger.info(f"Initialization response: {init_response}")

# Step 2: Create temporary MCP config with sandbox URL
temp_config_path = self.create_mcp_config(
server_url=sandbox.server_url, server_key=sandbox.server_name.value
)
logger.info(f"MCP config created: {temp_config_path}")

# Step 3: Run agent with sandbox MCP server
logger.info(f"Running agent for row {row.execution_metadata.rollout_id} with {self.server_name} sandbox")
agent = Agent(
model=row.input_metadata.completion_params["model"],
row=row,
config_path=temp_config_path,
logger=config.logger,
)
await agent.setup()
await agent.call_agent()

# Update usage metadata
row.execution_metadata.usage = CompletionUsage(
prompt_tokens=agent.usage.get("prompt_tokens", 0),
completion_tokens=agent.usage.get("completion_tokens", 0),
total_tokens=agent.usage.get("total_tokens", 0),
)
row = agent.evaluation_row
logger.info(f"Agent execution completed for row {row.execution_metadata.rollout_id}")

# Step 4: Export sandbox data
dump_method = getattr(self.klavis_client.sandbox, f"dump_{sandbox.server_name.value}_sandbox")
dump_response = dump_method(sandbox_id=sandbox.sandbox_id)
sandbox_data = dump_response.data
logger.info(f"Sandbox data: {sandbox_data}")

# Store sandbox data in row metadata for evaluation
if not row.execution_metadata.extra:
row.execution_metadata.extra = {}
row.execution_metadata.extra["sandbox_data"] = sandbox_data
row.execution_metadata.extra["sandbox_id"] = sandbox.sandbox_id
row.execution_metadata.extra["server_name"] = self.server_name

except Exception as e:
logger.error(f"Error processing row {row.execution_metadata.rollout_id}: {str(e)}", exc_info=True)
if not row.execution_metadata.extra:
row.execution_metadata.extra = {}
row.execution_metadata.extra["error"] = str(e)
raise

finally:
# Cleanup agent MCP client and temp config
if agent and agent.mcp_client:
await agent.mcp_client.cleanup()
if temp_config_path and os.path.exists(temp_config_path):
os.unlink(temp_config_path)

# Release sandbox
if sandbox and sandbox.sandbox_id:
try:
self.klavis_client.sandbox.delete_sandbox(
server_name=sandbox.server_name, sandbox_id=sandbox.sandbox_id
)
logger.info(f"Sandbox {sandbox.sandbox_id} released successfully")
except Exception as e:
logger.error(f"Error releasing sandbox {sandbox.sandbox_id}: {str(e)}", exc_info=True)

row.execution_metadata.rollout_duration_seconds = time.perf_counter() - start_time

return row

async def _sem_wrapper(r: EvaluationRow) -> EvaluationRow:
async with semaphore:
result = await process_row(r)
return result

# Create and return tasks
tasks = [asyncio.create_task(_sem_wrapper(row)) for row in rows]
return tasks
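As a usage note, here is a hedged sketch of supplying per-row initialization data through initialize_data_factory instead of the dataset's session_data; the constructor signature is taken from the file above, and the Gmail payload shape follows the JSONL fixture but is otherwise an assumption:

from typing import Any, Dict

from eval_protocol.models import EvaluationRow
from eval_protocol.pytest import KlavisSandboxRolloutProcessor


def gmail_init_factory(row: EvaluationRow) -> Dict[str, Any]:
    # Illustrative only: seed every sandbox with an empty mailbox.
    return {"messages": [], "drafts": []}


processor = KlavisSandboxRolloutProcessor(
    server_name="gmail",
    initialize_data_factory=gmail_init_factory,
)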
3 changes: 3 additions & 0 deletions pyproject.toml
@@ -134,6 +134,9 @@ openenv = [
dspy = [
"dspy>=3.0.0",
]
klavis = [
"klavis>=2.18.0",
]

# Optional deps for LangGraph example/tests
langgraph = [
2 changes: 2 additions & 0 deletions tests/pytest/datasets/klavis_gmail_sandbox_test.jsonl
@@ -0,0 +1,2 @@
{"initialize_data": {"messages": [{"subject": "Project Update", "to": "[email protected]", "body": "The project is progressing well. We should have the final deliverables by next week.", "cc": "", "bcc": "", "from": "[email protected]", "reply_to": "", "labels": ["INBOX"]}, {"subject": "Spam Newsletter", "to": "[email protected]", "body": "Check out our amazing deals! Click here now!", "cc": "", "bcc": "", "from": "[email protected]", "reply_to": "", "labels": ["INBOX"]}], "drafts": []}, "messages": "Please delete the email with subject \"Spam Newsletter\" from my inbox.", "ground_truth": {"messages": [{"subject": "Project Update", "to": "[email protected]", "body": "The project is progressing well. We should have the final deliverables by next week.", "cc": "", "bcc": "", "from": "[email protected]", "reply_to": "", "labels": ["INBOX"]}], "drafts": []}}
{"initialize_data": {"messages": [], "drafts": []}, "messages": "Please directly send an email to [email protected] with subject \"Meeting Tomorrow\" and body \"Hi Zihao, just confirming our meeting tomorrow at 2pm. Best regards.\"", "ground_truth": {"messages": [{"subject": "Meeting Tomorrow", "to": "[email protected]", "body": "Hi Zihao, just confirming our meeting tomorrow at 2pm. Best regards.", "cc": "", "bcc": "", "from": "", "reply_to": "", "labels": ["SENT"]}], "drafts": []}}
142 changes: 142 additions & 0 deletions tests/pytest/test_pytest_klavis_sandbox.py
@@ -0,0 +1,142 @@
import json
import logging
import os

from eval_protocol.models import EvaluateResult, EvaluationRow, Message
from eval_protocol.pytest import KlavisSandboxRolloutProcessor, evaluation_test
from openai import AsyncOpenAI
from pydantic import BaseModel

logger = logging.getLogger(__name__)


class ResponseFormat(BaseModel):
score: float


def klavis_gmail_sandbox_dataset_adapter(rows: list[dict]) -> list[EvaluationRow]:
"""Dataset adapter for sandbox JSONL rows.

Supports the new schema:
- initialize_data: dict (passed to Klavis sandbox initializer)
- messages: str (task instruction)
- ground_truth: dict (expected final sandbox state)

"""
adapted: list[EvaluationRow] = []
system_prompt = (
"You are a helpful assistant with access to Gmail. "
"You can send emails, draft emails, and manage messages, etc."
)

for r in rows:
if isinstance(r.get("messages"), str) and "initialize_data" in r:
init_data = r.get("initialize_data") or {}
task = r.get("messages") or ""
ground_truth = r.get("ground_truth")

row = EvaluationRow(
messages=[
Message(role="system", content=system_prompt),
Message(role="user", content=task),
],
ground_truth=ground_truth,
)
row.input_metadata.session_data = {
"initialize_data": init_data,
"task": task,
}
adapted.append(row)
else:
adapted.append(EvaluationRow(**r))

return adapted


@evaluation_test(
input_dataset=["tests/pytest/datasets/klavis_gmail_sandbox_test.jsonl"],
rollout_processor=KlavisSandboxRolloutProcessor(
server_name="gmail",
),
completion_params=[{"model": "fireworks_ai/accounts/fireworks/models/kimi-k2-thinking"}],
mode="pointwise",
dataset_adapter=klavis_gmail_sandbox_dataset_adapter,
)
async def test_pytest_gmail_sandbox(row: EvaluationRow) -> EvaluationRow:
"""
Evaluate Gmail sandbox results by comparing with ground truth using LLM judge.

The sandbox data is exported after agent execution and compared with expected output.
Sandbox data is available in row.execution_metadata.extra["sandbox_data"].
"""
ground_truth = row.ground_truth
sandbox_data = row.execution_metadata.extra.get("sandbox_data", {}) if row.execution_metadata.extra else {}
final_message = row.messages[-1].content if row.messages else ""
initialize_data = (row.input_metadata.session_data or {}).get("initialize_data", {})
task = (row.input_metadata.session_data or {}).get("task", "")

logger.info(f"Evaluating row {row.execution_metadata.rollout_id}")
logger.info(f"Final message: {final_message}")
logger.info(f"Sandbox data: {json.dumps(sandbox_data, indent=2, default=str)}")
logger.info(f"Ground truth: {ground_truth}")

async with AsyncOpenAI(
api_key=os.environ["FIREWORKS_API_KEY"], base_url="https://api.fireworks.ai/inference/v1"
) as client:

evaluation_prompt = f"""You are evaluating an AI agent's performance on a Gmail sandbox task.

Task:
{task or (row.messages[-1].content if row.messages else 'N/A')}

Initial Gmail Sandbox State (initialize_data):
{json.dumps(initialize_data, indent=2, default=str)}

Expected Final Gmail Sandbox State (ground_truth):
{json.dumps(ground_truth, indent=2, default=str)}

Gmail Sandbox State After Execution:
{json.dumps(sandbox_data, indent=2, default=str)}

Evaluate whether the agent successfully completed the task by checking:
1. Does the final sandbox state match the expected ground_truth state?
2. If there are small formatting differences, judge them semantically.
3. Use the initial state only as context; the key is whether the correct changes happened.

Return:
- score: 1.0 if task completed successfully, 0.5 if partially completed, 0.0 if failed

"""

try:
response = await client.chat.completions.create(
model="accounts/fireworks/models/kimi-k2-thinking",
messages=[
{
"role": "system",
"content": "You are a precise evaluator of AI agent performance. Analyze the task, execution, and results carefully.",
},
{"role": "user", "content": evaluation_prompt},
],
response_format={
"type": "json_schema",
"json_schema": {"name": "ResponseFormat", "schema": ResponseFormat.model_json_schema()},
},
temperature=0.0,
)

response_text = response.choices[0].message.content
logger.info(f"LLM judge response: {response_text}")

parsed = json.loads(response_text or "{}")
score = parsed.get("score", 0.0)

row.evaluation_result = EvaluateResult(score=score)
except Exception as e:
logger.error(f"Error during LLM evaluation: {str(e)}", exc_info=True)
row.evaluation_result = EvaluateResult(
score=0.0,
reason=f"Evaluation error: {str(e)}",
)

return row
21 changes: 20 additions & 1 deletion uv.lock

Some generated files are not rendered by default.