From a5f3cb34a76f937933f7ba9a30b1fdf4769a32b9 Mon Sep 17 00:00:00 2001
From: Zihao Lin
Date: Thu, 25 Dec 2025 10:20:33 -0800
Subject: [PATCH 1/2] init

---
 eval_protocol/pytest/__init__.py               |   2 +
 ...efault_klavis_sandbox_rollout_processor.py  | 157 ++++++++++++++++++
 pyproject.toml                                 |   3 +
 .../datasets/klavis_gmail_sandbox_test.jsonl   |   2 +
 tests/pytest/test_pytest_klavis_sandbox.py     | 104 ++++++++++++
 uv.lock                                        |  21 ++-
 6 files changed, 288 insertions(+), 1 deletion(-)
 create mode 100644 eval_protocol/pytest/default_klavis_sandbox_rollout_processor.py
 create mode 100644 tests/pytest/datasets/klavis_gmail_sandbox_test.jsonl
 create mode 100644 tests/pytest/test_pytest_klavis_sandbox.py

diff --git a/eval_protocol/pytest/__init__.py b/eval_protocol/pytest/__init__.py
index 26485f43..c600dc71 100644
--- a/eval_protocol/pytest/__init__.py
+++ b/eval_protocol/pytest/__init__.py
@@ -1,5 +1,6 @@
 from .default_agent_rollout_processor import AgentRolloutProcessor
 from .default_dataset_adapter import default_dataset_adapter
+from .default_klavis_sandbox_rollout_processor import KlavisSandboxRolloutProcessor
 from .default_mcp_gym_rollout_processor import MCPGymRolloutProcessor
 from .default_no_op_rollout_processor import NoOpRolloutProcessor
 from .default_single_turn_rollout_process import SingleTurnRolloutProcessor
@@ -31,6 +32,7 @@ __all__ = [
     "AgentRolloutProcessor",
+    "KlavisSandboxRolloutProcessor",
     "MCPGymRolloutProcessor",
     "RolloutProcessor",
     "SingleTurnRolloutProcessor",
diff --git a/eval_protocol/pytest/default_klavis_sandbox_rollout_processor.py b/eval_protocol/pytest/default_klavis_sandbox_rollout_processor.py
new file mode 100644
index 00000000..c865f3ae
--- /dev/null
+++ b/eval_protocol/pytest/default_klavis_sandbox_rollout_processor.py
@@ -0,0 +1,157 @@
+import asyncio
+import json
+import logging
+import os
+import tempfile
+import time
+from typing import Any, Callable, Dict, List, Optional
+
+from pydantic import BaseModel, Field
+
+from eval_protocol.models import EvaluationRow
+from eval_protocol.pytest.rollout_processor import RolloutProcessor
+from eval_protocol.pytest.types import RolloutProcessorConfig
+
+from eval_protocol.pytest.default_agent_rollout_processor import Agent
+from klavis import Klavis
+from klavis.types import CreateSandboxResponse, SandboxMcpServer
+from openai.types import CompletionUsage
+
+logger = logging.getLogger(__name__)
+
+
+class KlavisSandboxRolloutProcessor(RolloutProcessor):
+    def __init__(
+        self,
+        server_name: str,
+        initialize_data_factory: Optional[Callable[[EvaluationRow], Dict[str, Any]]] = None,
+    ):
+        super().__init__()
+        self.server_name = server_name
+        self.initialize_data_factory = initialize_data_factory
+        self.klavis_client = Klavis(api_key=os.environ.get("KLAVIS_API_KEY"))
+        self.sandbox = self._init_sandbox()
+
+    def _init_sandbox(self) -> CreateSandboxResponse:
+        try:
+            server_name_enum = SandboxMcpServer(self.server_name)
+            return self.klavis_client.sandbox.create_sandbox(server_name=server_name_enum)
+        except Exception as e:
+            logger.error(f"Error creating sandbox: {str(e)}", exc_info=True)
+            raise
+
+    @staticmethod
+    def create_mcp_config(server_url: str, server_key: str = "main", auth_token: str | None = None) -> str:
+        """Create a temporary MCP config file and return its path."""
+        config = {
+            "mcpServers": {
+                server_key: {
+                    "url": server_url,
+                    "transport": "streamable_http",
+                    **({"authorization": f"Bearer {auth_token}"} if auth_token else {})
+                }
+            }
+        }
+
+        # Create a temp file that persists for the session
+        fd, path = tempfile.mkstemp(suffix=".json", prefix="mcp_config_")
+        with os.fdopen(fd, 'w') as f:
+            json.dump(config, f)
+        return path
+
+    def __call__(
+        self, rows: List[EvaluationRow], config: RolloutProcessorConfig
+    ) -> List[asyncio.Task[EvaluationRow]]:
+        """Process evaluation rows with Klavis sandbox lifecycle management"""
+        if not self.sandbox:
+            raise RuntimeError("Sandbox not initialized")
+
+        semaphore = config.semaphore
+
+        async def process_row(row: EvaluationRow) -> EvaluationRow:
+            """Process a single row with complete sandbox lifecycle"""
+
+            start_time = time.perf_counter()
+
+            try:
+                # Step 1: Initialize data in the sandbox
+                if self.initialize_data_factory:
+                    logger.info(f"Initializing {self.server_name} sandbox {self.sandbox.sandbox_id}")
+                    init_data = self.initialize_data_factory(row)
+                    initialize_method = getattr(self.klavis_client.sandbox, f"initialize_{self.sandbox.server_name}_sandbox")
+                    initialize_method(sandbox_id=self.sandbox.sandbox_id, **init_data)
+                    logger.info(f"Sandbox initialized successfully")
+
+                # Step 2: Create temporary MCP config with sandbox URL
+                temp_config_path = self.create_mcp_config(server_url=self.sandbox.server_url, server_key=self.sandbox.server_name)
+
+                # Step 3: Run agent with sandbox MCP server
+                logger.info(f"Running agent for row {row.execution_metadata.rollout_id} with {self.server_name} sandbox")
+                agent = Agent(
+                    model=row.input_metadata.completion_params["model"],
+                    row=row,
+                    config_path=temp_config_path,
+                    logger=config.logger,
+                )
+                await agent.setup()
+                await agent.call_agent()
+
+                # Update usage metadata
+                row.execution_metadata.usage = CompletionUsage(
+                    prompt_tokens=agent.usage.get("prompt_tokens", 0),
+                    completion_tokens=agent.usage.get("completion_tokens", 0),
+                    total_tokens=agent.usage.get("total_tokens", 0),
+                )
+                row = agent.evaluation_row
+                logger.info(f"Agent execution completed for row {row.execution_metadata.rollout_id}")
+
+                # Step 4: Export sandbox data
+                logger.info(f"Exporting {self.server_name} sandbox data")
+                dump_method = getattr(self.klavis_client.sandbox, f"dump_{self.sandbox.server_name}_sandbox")
+                dump_response = dump_method(sandbox_id=self.sandbox.sandbox_id)
+                sandbox_data = dump_response.data
+
+                # Store sandbox data in row metadata for evaluation
+                if not row.execution_metadata.extra:
+                    row.execution_metadata.extra = {}
+                row.execution_metadata.extra["sandbox_data"] = sandbox_data
+                row.execution_metadata.extra["sandbox_id"] = self.sandbox.sandbox_id
+                row.execution_metadata.extra["server_name"] = self.server_name
+
+            except Exception as e:
+                logger.error(f"Error processing row {row.execution_metadata.rollout_id}: {str(e)}", exc_info=True)
+                if not row.execution_metadata.extra:
+                    row.execution_metadata.extra = {}
+                row.execution_metadata.extra["error"] = str(e)
+                raise
+
+            finally:
+                # Cleanup agent MCP client and temp config
+                if agent and agent.mcp_client:
+                    await agent.mcp_client.cleanup()
+                if temp_config_path and os.path.exists(temp_config_path):
+                    os.unlink(temp_config_path)
+
+                # Release sandbox
+                if self.sandbox.sandbox_id:
+                    try:
+                        logger.info(f"Releasing {self.server_name} sandbox {self.sandbox.sandbox_id}")
+                        self.klavis_client.sandbox.delete_sandbox(
+                            server_name=self.sandbox.server_name, sandbox_id=self.sandbox.sandbox_id
+                        )
+                        logger.info(f"Sandbox {self.sandbox.sandbox_id} released successfully")
+                    except Exception as e:
+                        logger.error(f"Error releasing sandbox {self.sandbox.sandbox_id}: {str(e)}", exc_info=True)
+
+                row.execution_metadata.rollout_duration_seconds = time.perf_counter() - start_time
+
+            return row
+
+        async def _sem_wrapper(r: EvaluationRow) -> EvaluationRow:
+            async with semaphore:
+                result = await process_row(r)
+                return result
+
+        # Create and return tasks
+        tasks = [asyncio.create_task(_sem_wrapper(row)) for row in rows]
+        return tasks
diff --git a/pyproject.toml b/pyproject.toml
index 400e8d40..216709ed 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -134,6 +134,9 @@ openenv = [
 dspy = [
     "dspy>=3.0.0",
 ]
+klavis = [
+    "klavis>=2.18.0",
+]
 
 # Optional deps for LangGraph example/tests
 langgraph = [
diff --git a/tests/pytest/datasets/klavis_gmail_sandbox_test.jsonl b/tests/pytest/datasets/klavis_gmail_sandbox_test.jsonl
new file mode 100644
index 00000000..088fbc60
--- /dev/null
+++ b/tests/pytest/datasets/klavis_gmail_sandbox_test.jsonl
@@ -0,0 +1,2 @@
+{"messages": [{"role": "system", "content": "You are a helpful assistant with access to Gmail. You can send emails, draft emails, and manage messages."}, {"role": "user", "content": "Send an email to john@example.com with subject 'Meeting Tomorrow' and body 'Hi John, Just confirming our meeting tomorrow at 2pm. Best regards.'"}], "ground_truth": "One email sent to john@example.com with subject 'Meeting Tomorrow' containing meeting confirmation"}
+{"messages": [{"role": "system", "content": "You are a helpful assistant with access to Gmail. You can send emails, draft emails, and manage messages."}, {"role": "user", "content": "Draft an email to sarah@company.com with subject 'Project Update' and body 'Hi Sarah, The project is progressing well. I will send you the detailed report by Friday.'"}], "ground_truth": "One draft email created for sarah@company.com with subject 'Project Update' about project progress"}
diff --git a/tests/pytest/test_pytest_klavis_sandbox.py b/tests/pytest/test_pytest_klavis_sandbox.py
new file mode 100644
index 00000000..79431b59
--- /dev/null
+++ b/tests/pytest/test_pytest_klavis_sandbox.py
@@ -0,0 +1,104 @@
+import json
+import logging
+import os
+
+from eval_protocol.models import EvaluateResult, EvaluationRow
+from eval_protocol.pytest import KlavisSandboxRolloutProcessor, evaluation_test
+from openai import AsyncOpenAI
+from pydantic import BaseModel
+
+logger = logging.getLogger(__name__)
+
+
+class ResponseFormat(BaseModel):
+    score: float
+    reasoning: str
+
+
+@evaluation_test(
+    input_dataset=["tests/pytest/datasets/klavis_gmail_sandbox_test.jsonl"],
+    rollout_processor=KlavisSandboxRolloutProcessor(
+        server_name="gmail",
+        # Optional: provide custom initialization data factory
+        # initialize_data_factory=lambda row: {"messages": [], "drafts": []},
+    ),
+    completion_params=[{"model": "fireworks_ai/accounts/fireworks/models/deepseek-v3p2"}],
+    mode="pointwise",
+)
+async def test_pytest_gmail_sandbox(row: EvaluationRow) -> EvaluationRow:
+    """
+    Evaluate Gmail sandbox results by comparing with ground truth using LLM judge.
+
+    The sandbox data is exported after agent execution and compared with expected output.
+    Sandbox data is available in row.execution_metadata.metadata["sandbox_data"].
+    """
+    ground_truth = row.ground_truth
+    sandbox_data = row.execution_metadata.extra.get("sandbox_data", {}) if row.execution_metadata.extra else {}
+    final_message = row.messages[-1].content if row.messages else ""
+
+    logger.info(f"Evaluating row {row.execution_metadata.rollout_id}")
+    logger.info(f"Final message: {final_message}")
+    logger.info(f"Sandbox data: {json.dumps(sandbox_data, indent=2, default=str)}")
+    logger.info(f"Ground truth: {ground_truth}")
+
+    async with AsyncOpenAI(
+        api_key=os.environ["FIREWORKS_API_KEY"], base_url="https://api.fireworks.ai/inference/v1"
+    ) as client:
+        # Use LLM to judge if the sandbox data matches the ground truth
+        evaluation_prompt = f"""You are evaluating an AI agent's performance on a Gmail task.
+
+Task: {row.messages[0].content if row.messages else 'N/A'}
+
+Ground Truth: {ground_truth}
+
+Agent's Final Response: {final_message}
+
+Gmail Sandbox State After Execution:
+{json.dumps(sandbox_data, indent=2, default=str)}
+
+Evaluate whether the agent successfully completed the task by checking:
+1. Did the agent understand and attempt the task?
+2. Does the sandbox data reflect the expected outcome described in the ground truth?
+3. Are there any emails sent/drafted that match the task requirements?
+
+Return:
+- score: 1.0 if task completed successfully, 0.5 if partially completed, 0.0 if failed
+- reasoning: Explain your evaluation in 1-2 sentences
+"""
+
+        try:
+            response = await client.chat.completions.create(
+                model="accounts/fireworks/models/deepseek-v3p2",
+                messages=[
+                    {
+                        "role": "system",
+                        "content": "You are a precise evaluator of AI agent performance. Analyze the task, execution, and results carefully.",
+                    },
+                    {"role": "user", "content": evaluation_prompt},
+                ],
+                response_format={
+                    "type": "json_schema",
+                    "json_schema": {"name": "ResponseFormat", "schema": ResponseFormat.model_json_schema()},
+                },
+                temperature=0.0,
+            )
+
+            response_text = response.choices[0].message.content
+            logger.info(f"LLM judge response: {response_text}")
+
+            parsed = json.loads(response_text or "{}")
+            score = parsed.get("score", 0.0)
+            reasoning = parsed.get("reasoning", "No reasoning provided")
+
+            row.evaluation_result = EvaluateResult(
+                score=score,
+                reason=reasoning,
+            )
+        except Exception as e:
+            logger.error(f"Error during LLM evaluation: {str(e)}", exc_info=True)
+            row.evaluation_result = EvaluateResult(
+                score=0.0,
+                reason=f"Evaluation error: {str(e)}",
+            )
+
+    return row
diff --git a/uv.lock b/uv.lock
index 972f90b1..03ac7096 100644
--- a/uv.lock
+++ b/uv.lock
@@ -1236,6 +1236,9 @@ huggingface = [
     { name = "datasets" },
     { name = "transformers" },
 ]
+klavis = [
+    { name = "klavis" },
+]
 langchain = [
     { name = "langchain-core" },
 ]
@@ -1319,6 +1322,7 @@ requires-dist = [
     { name = "hydra-core", specifier = ">=1.3.2" },
     { name = "ipykernel", marker = "extra == 'dev'", specifier = ">=6.30.0" },
     { name = "jupyter", marker = "extra == 'dev'", specifier = ">=1.1.1" },
+    { name = "klavis", marker = "extra == 'klavis'", specifier = ">=2.18.0" },
     { name = "langchain", marker = "extra == 'langgraph-tools'", specifier = ">=0.3.0" },
     { name = "langchain-core", marker = "extra == 'langchain'", specifier = ">=0.3.0" },
     { name = "langchain-core", marker = "extra == 'langgraph'", specifier = ">=0.3.75" },
@@ -1379,7 +1383,7 @@ requires-dist = [
     { name = "websockets", specifier = ">=15.0.1" },
     { name = "werkzeug", marker = "extra == 'dev'", specifier = ">=2.0.0" },
 ]
-provides-extras = ["dev", "trl", "openevals", "box2d", "langfuse", "huggingface", "langsmith", "bigquery", "svgbench", "pydantic", "supabase", "chinook", "langchain", "braintrust", "openenv", "dspy", "langgraph", "langgraph-tools", "proxy"]
+provides-extras = ["dev", "trl", "openevals", "box2d", "langfuse", "huggingface", "langsmith", "bigquery", "svgbench", "pydantic", "supabase", "chinook", "langchain", "braintrust", "openenv", "dspy", "klavis", "langgraph", "langgraph-tools", "proxy"]
 
 [package.metadata.requires-dev]
 dev = [
@@ -2911,6 +2915,21 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/d3/32/da7f44bcb1105d3e88a0b74ebdca50c59121d2ddf71c9e34ba47df7f3a56/keyring-25.6.0-py3-none-any.whl", hash = "sha256:552a3f7af126ece7ed5c89753650eec89c7eaae8617d0aa4d9ad2b75111266bd", size = 39085, upload-time = "2024-12-25T15:26:44.377Z" },
 ]
 
+[[package]]
+name = "klavis"
+version = "2.18.0"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+    { name = "httpx" },
+    { name = "pydantic" },
+    { name = "pydantic-core" },
+    { name = "typing-extensions" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/40/36/14fbb41e800b4af91ddca293e54428b7be1dd51503272ef8e77347922868/klavis-2.18.0.tar.gz", hash = "sha256:5dd6e8ab3523008889729e0095fb2055d427e3bb91c9da2d2e2594db685bbb08", size = 149469, upload-time = "2025-12-12T20:35:56.224Z" }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/1a/c6/8eccfcf538c18c6bd4a5aa3a62af8ba3d9ea22d799df06a604ed87e84c67/klavis-2.18.0-py3-none-any.whl", hash = "sha256:33e7d855bd1b8526859391341c44808fb7da5beb2ad31d7c4a8e383ca1e3846b", size = 335786, upload-time = "2025-12-12T20:35:54.631Z" },
+]
+
 [[package]]
 name = "langchain"
 version = "0.3.27"

From 7f35e2487eccdb8badcf6c765b3d3ca2d134b734 Mon Sep 17 00:00:00 2001
From: Zihao Lin
Date: Thu, 25 Dec 2025 12:41:00 -0800
Subject: [PATCH 2/2] update logic

---
 ...efault_klavis_sandbox_rollout_processor.py | 55 ++++++++-----
 .../datasets/klavis_gmail_sandbox_test.jsonl  |  4 +-
 tests/pytest/test_pytest_klavis_sandbox.py    | 80 ++++++++++++-----
 3 files changed, 97 insertions(+), 42 deletions(-)

diff --git a/eval_protocol/pytest/default_klavis_sandbox_rollout_processor.py b/eval_protocol/pytest/default_klavis_sandbox_rollout_processor.py
index c865f3ae..27d44b80 100644
--- a/eval_protocol/pytest/default_klavis_sandbox_rollout_processor.py
+++ b/eval_protocol/pytest/default_klavis_sandbox_rollout_processor.py
@@ -30,7 +30,6 @@ def __init__(
         self.server_name = server_name
         self.initialize_data_factory = initialize_data_factory
         self.klavis_client = Klavis(api_key=os.environ.get("KLAVIS_API_KEY"))
-        self.sandbox = self._init_sandbox()
 
     def _init_sandbox(self) -> CreateSandboxResponse:
         try:
@@ -63,27 +62,46 @@ def __call__(
         self, rows: List[EvaluationRow], config: RolloutProcessorConfig
     ) -> List[asyncio.Task[EvaluationRow]]:
         """Process evaluation rows with Klavis sandbox lifecycle management"""
-        if not self.sandbox:
-            raise RuntimeError("Sandbox not initialized")
-
         semaphore = config.semaphore
 
         async def process_row(row: EvaluationRow) -> EvaluationRow:
             """Process a single row with complete sandbox lifecycle"""
 
             start_time = time.perf_counter()
+            agent: Agent | None = None
+            temp_config_path: str | None = None
+            sandbox: CreateSandboxResponse | None = None
 
             try:
+                # Step 0: Create a sandbox for this row
+                sandbox = self._init_sandbox()
+                logger.info(f"Sandbox created: {sandbox}")
+
                 # Step 1: Initialize data in the sandbox
+                init_data: Dict[str, Any] | None = None
                 if self.initialize_data_factory:
-                    logger.info(f"Initializing {self.server_name} sandbox {self.sandbox.sandbox_id}")
                     init_data = self.initialize_data_factory(row)
-                    initialize_method = getattr(self.klavis_client.sandbox, f"initialize_{self.sandbox.server_name}_sandbox")
-                    initialize_method(sandbox_id=self.sandbox.sandbox_id, **init_data)
-                    logger.info(f"Sandbox initialized successfully")
-
+                else:
+                    # Allow datasets to provide initialization payload directly
+                    init_data = (
+                        (row.input_metadata.session_data or {}).get("initialize_data")
+                        if row.input_metadata is not None
+                        else None
+                    )
+
+                if init_data:
+                    logger.info(f"Initializing {self.server_name} sandbox {sandbox.sandbox_id}")
+                    initialize_method = getattr(
+                        self.klavis_client.sandbox, f"initialize_{sandbox.server_name.value}_sandbox"
+                    )
+                    init_response = initialize_method(sandbox_id=sandbox.sandbox_id, **init_data)
+                    logger.info(f"Initialization response: {init_response}")
+
                 # Step 2: Create temporary MCP config with sandbox URL
-                temp_config_path = self.create_mcp_config(server_url=self.sandbox.server_url, server_key=self.sandbox.server_name)
+                temp_config_path = self.create_mcp_config(
+                    server_url=sandbox.server_url, server_key=sandbox.server_name.value
+                )
+                logger.info(f"MCP config created: {temp_config_path}")
 
                 # Step 3: Run agent with sandbox MCP server
                 logger.info(f"Running agent for row {row.execution_metadata.rollout_id} with {self.server_name} sandbox")
@@ -106,16 +124,16 @@ async def process_row(row: EvaluationRow) -> EvaluationRow:
                 logger.info(f"Agent execution completed for row {row.execution_metadata.rollout_id}")
 
                 # Step 4: Export sandbox data
-                logger.info(f"Exporting {self.server_name} sandbox data")
-                dump_method = getattr(self.klavis_client.sandbox, f"dump_{self.sandbox.server_name}_sandbox")
-                dump_response = dump_method(sandbox_id=self.sandbox.sandbox_id)
+                dump_method = getattr(self.klavis_client.sandbox, f"dump_{sandbox.server_name.value}_sandbox")
+                dump_response = dump_method(sandbox_id=sandbox.sandbox_id)
                 sandbox_data = dump_response.data
+                logger.info(f"Sandbox data: {sandbox_data}")
 
                 # Store sandbox data in row metadata for evaluation
                 if not row.execution_metadata.extra:
                     row.execution_metadata.extra = {}
                 row.execution_metadata.extra["sandbox_data"] = sandbox_data
-                row.execution_metadata.extra["sandbox_id"] = self.sandbox.sandbox_id
+                row.execution_metadata.extra["sandbox_id"] = sandbox.sandbox_id
                 row.execution_metadata.extra["server_name"] = self.server_name
 
             except Exception as e:
@@ -133,15 +151,14 @@ async def process_row(row: EvaluationRow) -> EvaluationRow:
                     os.unlink(temp_config_path)
 
                 # Release sandbox
-                if self.sandbox.sandbox_id:
+                if sandbox and sandbox.sandbox_id:
                     try:
-                        logger.info(f"Releasing {self.server_name} sandbox {self.sandbox.sandbox_id}")
                         self.klavis_client.sandbox.delete_sandbox(
-                            server_name=self.sandbox.server_name, sandbox_id=self.sandbox.sandbox_id
+                            server_name=sandbox.server_name, sandbox_id=sandbox.sandbox_id
                         )
-                        logger.info(f"Sandbox {self.sandbox.sandbox_id} released successfully")
+                        logger.info(f"Sandbox {sandbox.sandbox_id} released successfully")
                     except Exception as e:
-                        logger.error(f"Error releasing sandbox {self.sandbox.sandbox_id}: {str(e)}", exc_info=True)
+                        logger.error(f"Error releasing sandbox {sandbox.sandbox_id}: {str(e)}", exc_info=True)
 
                 row.execution_metadata.rollout_duration_seconds = time.perf_counter() - start_time
diff --git a/tests/pytest/datasets/klavis_gmail_sandbox_test.jsonl b/tests/pytest/datasets/klavis_gmail_sandbox_test.jsonl
index 088fbc60..8aefb76e 100644
--- a/tests/pytest/datasets/klavis_gmail_sandbox_test.jsonl
+++ b/tests/pytest/datasets/klavis_gmail_sandbox_test.jsonl
@@ -1,2 +1,2 @@
-{"messages": [{"role": "system", "content": "You are a helpful assistant with access to Gmail. You can send emails, draft emails, and manage messages."}, {"role": "user", "content": "Send an email to john@example.com with subject 'Meeting Tomorrow' and body 'Hi John, Just confirming our meeting tomorrow at 2pm. Best regards.'"}], "ground_truth": "One email sent to john@example.com with subject 'Meeting Tomorrow' containing meeting confirmation"}
-{"messages": [{"role": "system", "content": "You are a helpful assistant with access to Gmail. You can send emails, draft emails, and manage messages."}, {"role": "user", "content": "Draft an email to sarah@company.com with subject 'Project Update' and body 'Hi Sarah, The project is progressing well. I will send you the detailed report by Friday.'"}], "ground_truth": "One draft email created for sarah@company.com with subject 'Project Update' about project progress"}
+{"initialize_data": {"messages": [{"subject": "Project Update", "to": "zihao@klavisai.com", "body": "The project is progressing well. We should have the final deliverables by next week.", "cc": "", "bcc": "", "from": "sarah@klavisai.com", "reply_to": "", "labels": ["INBOX"]}, {"subject": "Spam Newsletter", "to": "zihao@klavisai.com", "body": "Check out our amazing deals! Click here now!", "cc": "", "bcc": "", "from": "marketing@spammy.com", "reply_to": "", "labels": ["INBOX"]}], "drafts": []}, "messages": "Please delete the email with subject \"Spam Newsletter\" from my inbox.", "ground_truth": {"messages": [{"subject": "Project Update", "to": "zihao@klavisai.com", "body": "The project is progressing well. We should have the final deliverables by next week.", "cc": "", "bcc": "", "from": "sarah@klavisai.com", "reply_to": "", "labels": ["INBOX"]}], "drafts": []}}
+{"initialize_data": {"messages": [], "drafts": []}, "messages": "Please directly send an email to zihao@klavisai.com with subject \"Meeting Tomorrow\" and body \"Hi Zihao, just confirming our meeting tomorrow at 2pm. Best regards.\"", "ground_truth": {"messages": [{"subject": "Meeting Tomorrow", "to": "zihao@klavisai.com", "body": "Hi Zihao, just confirming our meeting tomorrow at 2pm. Best regards.", "cc": "", "bcc": "", "from": "", "reply_to": "", "labels": ["SENT"]}], "drafts": []}}
diff --git a/tests/pytest/test_pytest_klavis_sandbox.py b/tests/pytest/test_pytest_klavis_sandbox.py
index 79431b59..7ae84bc3 100644
--- a/tests/pytest/test_pytest_klavis_sandbox.py
+++ b/tests/pytest/test_pytest_klavis_sandbox.py
@@ -2,7 +2,7 @@
 import logging
 import os
 
-from eval_protocol.models import EvaluateResult, EvaluationRow
+from eval_protocol.models import EvaluateResult, EvaluationRow, Message
 from eval_protocol.pytest import KlavisSandboxRolloutProcessor, evaluation_test
 from openai import AsyncOpenAI
 from pydantic import BaseModel
@@ -12,29 +12,68 @@
 
 class ResponseFormat(BaseModel):
     score: float
-    reasoning: str
+
+
+def klavis_gmail_sandbox_dataset_adapter(rows: list[dict]) -> list[EvaluationRow]:
+    """Dataset adapter for sandbox JSONL rows.
+
+    Supports the new schema:
+    - initialize_data: dict (passed to Klavis sandbox initializer)
+    - messages: str (task instruction)
+    - ground_truth: dict (expected final sandbox state)
+
+    """
+    adapted: list[EvaluationRow] = []
+    system_prompt = (
+        "You are a helpful assistant with access to Gmail. "
+        "You can send emails, draft emails, and manage messages, etc."
+    )
+
+    for r in rows:
+        if isinstance(r.get("messages"), str) and "initialize_data" in r:
+            init_data = r.get("initialize_data") or {}
+            task = r.get("messages") or ""
+            ground_truth = r.get("ground_truth")
+
+            row = EvaluationRow(
+                messages=[
+                    Message(role="system", content=system_prompt),
+                    Message(role="user", content=task),
+                ],
+                ground_truth=ground_truth,
+            )
+            row.input_metadata.session_data = {
+                "initialize_data": init_data,
+                "task": task,
+            }
+            adapted.append(row)
+        else:
+            adapted.append(EvaluationRow(**r))
+
+    return adapted
 
 
 @evaluation_test(
     input_dataset=["tests/pytest/datasets/klavis_gmail_sandbox_test.jsonl"],
     rollout_processor=KlavisSandboxRolloutProcessor(
         server_name="gmail",
-        # Optional: provide custom initialization data factory
-        # initialize_data_factory=lambda row: {"messages": [], "drafts": []},
     ),
-    completion_params=[{"model": "fireworks_ai/accounts/fireworks/models/deepseek-v3p2"}],
+    completion_params=[{"model": "fireworks_ai/accounts/fireworks/models/kimi-k2-thinking"}],
     mode="pointwise",
+    dataset_adapter=klavis_gmail_sandbox_dataset_adapter,
 )
 async def test_pytest_gmail_sandbox(row: EvaluationRow) -> EvaluationRow:
     """
     Evaluate Gmail sandbox results by comparing with ground truth using LLM judge.
 
     The sandbox data is exported after agent execution and compared with expected output.
-    Sandbox data is available in row.execution_metadata.metadata["sandbox_data"].
+    Sandbox data is available in row.execution_metadata.extra["sandbox_data"].
     """
     ground_truth = row.ground_truth
     sandbox_data = row.execution_metadata.extra.get("sandbox_data", {}) if row.execution_metadata.extra else {}
     final_message = row.messages[-1].content if row.messages else ""
+    initialize_data = (row.input_metadata.session_data or {}).get("initialize_data", {})
+    task = (row.input_metadata.session_data or {}).get("task", "")
 
     logger.info(f"Evaluating row {row.execution_metadata.rollout_id}")
     logger.info(f"Final message: {final_message}")
@@ -44,31 +83,34 @@ async def test_pytest_gmail_sandbox(row: EvaluationRow) -> EvaluationRow:
     async with AsyncOpenAI(
         api_key=os.environ["FIREWORKS_API_KEY"], base_url="https://api.fireworks.ai/inference/v1"
     ) as client:
-        # Use LLM to judge if the sandbox data matches the ground truth
-        evaluation_prompt = f"""You are evaluating an AI agent's performance on a Gmail task.
-Task: {row.messages[0].content if row.messages else 'N/A'}
+        evaluation_prompt = f"""You are evaluating an AI agent's performance on a Gmail sandbox task.
+
+Task:
+{task or (row.messages[-1].content if row.messages else 'N/A')}
 
-Ground Truth: {ground_truth}
+Initial Gmail Sandbox State (initialize_data):
+{json.dumps(initialize_data, indent=2, default=str)}
 
-Agent's Final Response: {final_message}
+Expected Final Gmail Sandbox State (ground_truth):
+{json.dumps(ground_truth, indent=2, default=str)}
 
 Gmail Sandbox State After Execution:
 {json.dumps(sandbox_data, indent=2, default=str)}
 
 Evaluate whether the agent successfully completed the task by checking:
-1. Did the agent understand and attempt the task?
-2. Does the sandbox data reflect the expected outcome described in the ground truth?
-3. Are there any emails sent/drafted that match the task requirements?
+1. Does the final sandbox state match the expected ground_truth state?
+2. If there are small formatting differences, judge semantically
+3. Use the initial state only as context; the key is whether the correct changes happened.
 
 Return:
 - score: 1.0 if task completed successfully, 0.5 if partially completed, 0.0 if failed
-- reasoning: Explain your evaluation in 1-2 sentences
+
 """
 
         try:
             response = await client.chat.completions.create(
-                model="accounts/fireworks/models/deepseek-v3p2",
+                model="accounts/fireworks/models/kimi-k2-thinking",
                 messages=[
                     {
                         "role": "system",
                         "content": "You are a precise evaluator of AI agent performance. Analyze the task, execution, and results carefully.",
                     },
                     {"role": "user", "content": evaluation_prompt},
                 ],
                 response_format={
                     "type": "json_schema",
                     "json_schema": {"name": "ResponseFormat", "schema": ResponseFormat.model_json_schema()},
                 },
                 temperature=0.0,
             )
 
             response_text = response.choices[0].message.content
             logger.info(f"LLM judge response: {response_text}")
 
             parsed = json.loads(response_text or "{}")
             score = parsed.get("score", 0.0)
-            reasoning = parsed.get("reasoning", "No reasoning provided")
-
-            row.evaluation_result = EvaluateResult(
-                score=score,
-                reason=reasoning,
-            )
+            row.evaluation_result = EvaluateResult(score=score)
         except Exception as e:
             logger.error(f"Error during LLM evaluation: {str(e)}", exc_info=True)
             row.evaluation_result = EvaluateResult(