Klavis Sandbox on Fireworks EP #388
Open: zihaolin96 wants to merge 2 commits into eval-protocol:main from zihaolin96:zihao/klavis_sandbox
eval_protocol/pytest/default_klavis_sandbox_rollout_processor.py (174 additions & 0 deletions)
```python
import asyncio
import json
import logging
import os
import tempfile
import time
from typing import Any, Callable, Dict, List, Optional

from pydantic import BaseModel, Field

from eval_protocol.models import EvaluationRow
from eval_protocol.pytest.rollout_processor import RolloutProcessor
from eval_protocol.pytest.types import RolloutProcessorConfig

from eval_protocol.pytest.default_agent_rollout_processor import Agent
from klavis import Klavis
from klavis.types import CreateSandboxResponse, SandboxMcpServer
from openai.types import CompletionUsage

logger = logging.getLogger(__name__)


class KlavisSandboxRolloutProcessor(RolloutProcessor):
    def __init__(
        self,
        server_name: str,
        initialize_data_factory: Optional[Callable[[EvaluationRow], Dict[str, Any]]] = None,
    ):
        super().__init__()
        self.server_name = server_name
        self.initialize_data_factory = initialize_data_factory
        self.klavis_client = Klavis(api_key=os.environ.get("KLAVIS_API_KEY"))

    def _init_sandbox(self) -> CreateSandboxResponse:
        try:
            server_name_enum = SandboxMcpServer(self.server_name)
            return self.klavis_client.sandbox.create_sandbox(server_name=server_name_enum)
        except Exception as e:
            logger.error(f"Error creating sandbox: {str(e)}", exc_info=True)
            raise

    @staticmethod
    def create_mcp_config(server_url: str, server_key: str = "main", auth_token: str | None = None) -> str:
        """Create a temporary MCP config file and return its path."""
        config = {
            "mcpServers": {
                server_key: {
                    "url": server_url,
                    "transport": "streamable_http",
                    **({"authorization": f"Bearer {auth_token}"} if auth_token else {}),
                }
            }
        }

        # Create a temp file that persists for the session
        fd, path = tempfile.mkstemp(suffix=".json", prefix="mcp_config_")
        with os.fdopen(fd, "w") as f:
            json.dump(config, f)
        return path

    def __call__(
        self, rows: List[EvaluationRow], config: RolloutProcessorConfig
    ) -> List[asyncio.Task[EvaluationRow]]:
        """Process evaluation rows with Klavis sandbox lifecycle management."""
        semaphore = config.semaphore

        async def process_row(row: EvaluationRow) -> EvaluationRow:
            """Process a single row with the complete sandbox lifecycle."""
            start_time = time.perf_counter()
            agent: Agent | None = None
            temp_config_path: str | None = None
            sandbox: CreateSandboxResponse | None = None

            try:
                # Step 0: Create a sandbox for this row
                sandbox = self._init_sandbox()
                logger.info(f"Sandbox created: {sandbox}")

                # Step 1: Initialize data in the sandbox
                init_data: Dict[str, Any] | None = None
                if self.initialize_data_factory:
                    init_data = self.initialize_data_factory(row)
                else:
                    # Allow datasets to provide the initialization payload directly
                    init_data = (
                        (row.input_metadata.session_data or {}).get("initialize_data")
                        if row.input_metadata is not None
                        else None
                    )

                if init_data:
                    logger.info(f"Initializing {self.server_name} sandbox {sandbox.sandbox_id}")
                    initialize_method = getattr(
                        self.klavis_client.sandbox, f"initialize_{sandbox.server_name.value}_sandbox"
                    )
                    init_response = initialize_method(sandbox_id=sandbox.sandbox_id, **init_data)
                    logger.info(f"Initialization response: {init_response}")

                # Step 2: Create temporary MCP config with sandbox URL
                temp_config_path = self.create_mcp_config(
                    server_url=sandbox.server_url, server_key=sandbox.server_name.value
                )
                logger.info(f"MCP config created: {temp_config_path}")

                # Step 3: Run agent with sandbox MCP server
                logger.info(
                    f"Running agent for row {row.execution_metadata.rollout_id} with {self.server_name} sandbox"
                )
                agent = Agent(
                    model=row.input_metadata.completion_params["model"],
                    row=row,
                    config_path=temp_config_path,
                    logger=config.logger,
                )
                await agent.setup()
                await agent.call_agent()

                # Update usage metadata
                row.execution_metadata.usage = CompletionUsage(
                    prompt_tokens=agent.usage.get("prompt_tokens", 0),
                    completion_tokens=agent.usage.get("completion_tokens", 0),
                    total_tokens=agent.usage.get("total_tokens", 0),
                )
                row = agent.evaluation_row
                logger.info(f"Agent execution completed for row {row.execution_metadata.rollout_id}")

                # Step 4: Export sandbox data
                dump_method = getattr(self.klavis_client.sandbox, f"dump_{sandbox.server_name.value}_sandbox")
                dump_response = dump_method(sandbox_id=sandbox.sandbox_id)
                sandbox_data = dump_response.data
                logger.info(f"Sandbox data: {sandbox_data}")

                # Store sandbox data in row metadata for evaluation
                if not row.execution_metadata.extra:
                    row.execution_metadata.extra = {}
                row.execution_metadata.extra["sandbox_data"] = sandbox_data
                row.execution_metadata.extra["sandbox_id"] = sandbox.sandbox_id
                row.execution_metadata.extra["server_name"] = self.server_name

            except Exception as e:
                logger.error(f"Error processing row {row.execution_metadata.rollout_id}: {str(e)}", exc_info=True)
                if not row.execution_metadata.extra:
                    row.execution_metadata.extra = {}
                row.execution_metadata.extra["error"] = str(e)
                raise

            finally:
                # Cleanup agent MCP client and temp config
                if agent and agent.mcp_client:
                    await agent.mcp_client.cleanup()
                if temp_config_path and os.path.exists(temp_config_path):
                    os.unlink(temp_config_path)

                # Release sandbox
                if sandbox and sandbox.sandbox_id:
                    try:
                        self.klavis_client.sandbox.delete_sandbox(
                            server_name=sandbox.server_name, sandbox_id=sandbox.sandbox_id
                        )
                        logger.info(f"Sandbox {sandbox.sandbox_id} released successfully")
                    except Exception as e:
                        logger.error(f"Error releasing sandbox {sandbox.sandbox_id}: {str(e)}", exc_info=True)

                row.execution_metadata.rollout_duration_seconds = time.perf_counter() - start_time

            return row

        async def _sem_wrapper(r: EvaluationRow) -> EvaluationRow:
            async with semaphore:
                result = await process_row(r)
                return result

        # Create and return tasks
        tasks = [asyncio.create_task(_sem_wrapper(row)) for row in rows]
        return tasks
```
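
For reference, a minimal sketch of exercising the `create_mcp_config` helper above; the sandbox URL is a hypothetical placeholder, since real URLs come from `CreateSandboxResponse.server_url` at runtime:

```python
import json

from eval_protocol.pytest.default_klavis_sandbox_rollout_processor import KlavisSandboxRolloutProcessor

# Hypothetical URL for illustration only.
path = KlavisSandboxRolloutProcessor.create_mcp_config(
    server_url="https://sandbox.example.com/mcp",
    server_key="gmail",
)
with open(path) as f:
    print(json.dumps(json.load(f), indent=2))
# Per the helper above, the file contains:
# {"mcpServers": {"gmail": {"url": "https://sandbox.example.com/mcp", "transport": "streamable_http"}}}
```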
tests/pytest/datasets/klavis_gmail_sandbox_test.jsonl (2 additions & 0 deletions)

```jsonl
{"initialize_data": {"messages": [{"subject": "Project Update", "to": "[email protected]", "body": "The project is progressing well. We should have the final deliverables by next week.", "cc": "", "bcc": "", "from": "[email protected]", "reply_to": "", "labels": ["INBOX"]}, {"subject": "Spam Newsletter", "to": "[email protected]", "body": "Check out our amazing deals! Click here now!", "cc": "", "bcc": "", "from": "[email protected]", "reply_to": "", "labels": ["INBOX"]}], "drafts": []}, "messages": "Please delete the email with subject \"Spam Newsletter\" from my inbox.", "ground_truth": {"messages": [{"subject": "Project Update", "to": "[email protected]", "body": "The project is progressing well. We should have the final deliverables by next week.", "cc": "", "bcc": "", "from": "[email protected]", "reply_to": "", "labels": ["INBOX"]}], "drafts": []}}
{"initialize_data": {"messages": [], "drafts": []}, "messages": "Please directly send an email to [email protected] with subject \"Meeting Tomorrow\" and body \"Hi Zihao, just confirming our meeting tomorrow at 2pm. Best regards.\"", "ground_truth": {"messages": [{"subject": "Meeting Tomorrow", "to": "[email protected]", "body": "Hi Zihao, just confirming our meeting tomorrow at 2pm. Best regards.", "cc": "", "bcc": "", "from": "", "reply_to": "", "labels": ["SENT"]}], "drafts": []}}
```
Test file (142 additions & 0 deletions):

```python
import json
import logging
import os

from eval_protocol.models import EvaluateResult, EvaluationRow, Message
from eval_protocol.pytest import KlavisSandboxRolloutProcessor, evaluation_test
from openai import AsyncOpenAI
from pydantic import BaseModel

logger = logging.getLogger(__name__)


class ResponseFormat(BaseModel):
    score: float


def klavis_gmail_sandbox_dataset_adapter(rows: list[dict]) -> list[EvaluationRow]:
    """Dataset adapter for sandbox JSONL rows.

    Supports the new schema:
    - initialize_data: dict (passed to the Klavis sandbox initializer)
    - messages: str (task instruction)
    - ground_truth: dict (expected final sandbox state)
    """
    adapted: list[EvaluationRow] = []
    system_prompt = (
        "You are a helpful assistant with access to Gmail. "
        "You can send emails, draft emails, and manage messages, etc."
    )

    for r in rows:
        if isinstance(r.get("messages"), str) and "initialize_data" in r:
            init_data = r.get("initialize_data") or {}
            task = r.get("messages") or ""
            ground_truth = r.get("ground_truth")

            row = EvaluationRow(
                messages=[
                    Message(role="system", content=system_prompt),
                    Message(role="user", content=task),
                ],
                ground_truth=ground_truth,
            )
            row.input_metadata.session_data = {
                "initialize_data": init_data,
                "task": task,
            }
            adapted.append(row)
        else:
            adapted.append(EvaluationRow(**r))

    return adapted


@evaluation_test(
    input_dataset=["tests/pytest/datasets/klavis_gmail_sandbox_test.jsonl"],
    rollout_processor=KlavisSandboxRolloutProcessor(
        server_name="gmail",
    ),
    completion_params=[{"model": "fireworks_ai/accounts/fireworks/models/kimi-k2-thinking"}],
    mode="pointwise",
    dataset_adapter=klavis_gmail_sandbox_dataset_adapter,
)
async def test_pytest_gmail_sandbox(row: EvaluationRow) -> EvaluationRow:
    """Evaluate Gmail sandbox results by comparing with ground truth using an LLM judge.

    The sandbox data is exported after agent execution and compared with the expected output.
    Sandbox data is available in row.execution_metadata.extra["sandbox_data"].
    """
    ground_truth = row.ground_truth
    sandbox_data = row.execution_metadata.extra.get("sandbox_data", {}) if row.execution_metadata.extra else {}
    final_message = row.messages[-1].content if row.messages else ""
    initialize_data = (row.input_metadata.session_data or {}).get("initialize_data", {})
    task = (row.input_metadata.session_data or {}).get("task", "")

    logger.info(f"Evaluating row {row.execution_metadata.rollout_id}")
    logger.info(f"Final message: {final_message}")
    logger.info(f"Sandbox data: {json.dumps(sandbox_data, indent=2, default=str)}")
    logger.info(f"Ground truth: {ground_truth}")

    async with AsyncOpenAI(
        api_key=os.environ["FIREWORKS_API_KEY"], base_url="https://api.fireworks.ai/inference/v1"
    ) as client:
        evaluation_prompt = f"""You are evaluating an AI agent's performance on a Gmail sandbox task.

Task:
{task or (row.messages[-1].content if row.messages else 'N/A')}

Initial Gmail Sandbox State (initialize_data):
{json.dumps(initialize_data, indent=2, default=str)}

Expected Final Gmail Sandbox State (ground_truth):
{json.dumps(ground_truth, indent=2, default=str)}

Gmail Sandbox State After Execution:
{json.dumps(sandbox_data, indent=2, default=str)}

Evaluate whether the agent successfully completed the task by checking:
1. Does the final sandbox state match the expected ground_truth state?
2. If there are small formatting differences, judge semantically.
3. Use the initial state only as context; the key is whether the correct changes happened.

Return:
- score: 1.0 if the task completed successfully, 0.5 if partially completed, 0.0 if failed
"""

        try:
            response = await client.chat.completions.create(
                model="accounts/fireworks/models/kimi-k2-thinking",
                messages=[
                    {
                        "role": "system",
                        "content": "You are a precise evaluator of AI agent performance. Analyze the task, execution, and results carefully.",
                    },
                    {"role": "user", "content": evaluation_prompt},
                ],
                response_format={
                    "type": "json_schema",
                    "json_schema": {"name": "ResponseFormat", "schema": ResponseFormat.model_json_schema()},
                },
                temperature=0.0,
            )

            response_text = response.choices[0].message.content
            logger.info(f"LLM judge response: {response_text}")

            parsed = json.loads(response_text or "{}")
            score = parsed.get("score", 0.0)

            row.evaluation_result = EvaluateResult(score=score)
        except Exception as e:
            logger.error(f"Error during LLM evaluation: {str(e)}", exc_info=True)
            row.evaluation_result = EvaluateResult(
                score=0.0,
                reason=f"Evaluation error: {str(e)}",
            )

    return row
```
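
As an aside, the processor also accepts an `initialize_data_factory` instead of reading the payload from `session_data`. A minimal sketch of wiring one in; the factory body here is illustrative, not from the PR:

```python
from typing import Any, Dict

from eval_protocol.models import EvaluationRow
from eval_protocol.pytest import KlavisSandboxRolloutProcessor


def gmail_init_factory(row: EvaluationRow) -> Dict[str, Any]:
    # Reuse whatever payload the dataset adapter stashed on the row;
    # any per-row transformation could happen here instead.
    return (row.input_metadata.session_data or {}).get("initialize_data", {})


processor = KlavisSandboxRolloutProcessor(
    server_name="gmail",
    initialize_data_factory=gmail_init_factory,
)
```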
Review comment: Missing try/except for optional klavis dependency import

`KlavisSandboxRolloutProcessor` is imported unconditionally at line 3, but it depends on the optional `klavis` package. This will cause an `ImportError` for any user who imports from `eval_protocol.pytest` without having `klavis` installed. Other optional-dependency imports like `PydanticAgentRolloutProcessor` and `LangGraphRolloutProcessor` are correctly wrapped in try/except blocks (lines 16-22 and 25-31), but `KlavisSandboxRolloutProcessor` lacks this protection. Both the import and the `__all__` export at line 35 need to be made conditional, like the other optional dependencies.

Additional Locations (1)
eval_protocol/pytest/__init__.py#L34-L35
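
A sketch of the guarded import the reviewer is asking for. The module and class names come from this PR; the surrounding `__init__.py` contents and the exact fallback style are assumptions, mirroring how the other optional processors are described as being wrapped:

```python
# eval_protocol/pytest/__init__.py (sketch; surrounding exports elided)
try:
    from .default_klavis_sandbox_rollout_processor import KlavisSandboxRolloutProcessor

    _HAS_KLAVIS = True
except ImportError:
    # klavis is an optional extra; keep the name defined but unusable.
    KlavisSandboxRolloutProcessor = None  # type: ignore[assignment]
    _HAS_KLAVIS = False

__all__ = []  # ... existing exports would be listed here ...
if _HAS_KLAVIS:
    __all__.append("KlavisSandboxRolloutProcessor")
```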