diff --git a/eval_protocol/pytest/evaluation_test_utils.py b/eval_protocol/pytest/evaluation_test_utils.py index 48f8a015..64f0c8b3 100644 --- a/eval_protocol/pytest/evaluation_test_utils.py +++ b/eval_protocol/pytest/evaluation_test_utils.py @@ -371,7 +371,7 @@ async def execute_row_with_backoff_retry(row: EvaluationRow) -> EvaluationRow: retry_config = replace(config, kwargs={**(config.kwargs or {}), "start_server": False}) retry_tasks = rollout_processor([row], retry_config) result = await retry_tasks[0] - + # Apply post-processing quality checks if configured # This must be inside the retry function so ResponseQualityError can trigger retries if config.post_processor is not None: @@ -380,7 +380,7 @@ async def execute_row_with_backoff_retry(row: EvaluationRow) -> EvaluationRow: except ResponseQualityError as quality_error: # Re-raise ResponseQualityError to trigger retry logic raise quality_error - + return result async def execute_row_with_backoff(task: asyncio.Task[EvaluationRow], row: EvaluationRow) -> EvaluationRow: @@ -464,6 +464,7 @@ async def execute_row_with_backoff_and_log( yield result finally: + await rollout_processor.acleanup() rollout_processor.cleanup() diff --git a/eval_protocol/pytest/remote_rollout_processor.py b/eval_protocol/pytest/remote_rollout_processor.py index aa1c5d44..374978e1 100644 --- a/eval_protocol/pytest/remote_rollout_processor.py +++ b/eval_protocol/pytest/remote_rollout_processor.py @@ -1,14 +1,10 @@ import asyncio import time -from typing import Any, Dict, List, Optional +from typing import List, Optional -import requests +import aiohttp from eval_protocol.models import EvaluationRow, Status -from eval_protocol.data_loader.dynamic_data_loader import DynamicDataLoader -from eval_protocol.types.remote_rollout_processor import ( - DataLoaderConfig, -) from eval_protocol.adapters.fireworks_tracing import FireworksTracingAdapter from eval_protocol.exceptions import exception_for_status_code @@ -51,6 +47,12 @@ def __init__( self._poll_interval = poll_interval self._timeout_seconds = timeout_seconds self._tracing_adapter = FireworksTracingAdapter(base_url=self._model_base_url) + self._session: Optional[aiohttp.ClientSession] = None + + def _get_or_create_session(self) -> aiohttp.ClientSession: + if self._session is None or self._session.closed: + self._session = aiohttp.ClientSession() + return self._session def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) -> List[asyncio.Task[EvaluationRow]]: tasks: List[asyncio.Task[EvaluationRow]] = [] @@ -88,48 +90,26 @@ async def _process_row(row: EvaluationRow) -> EvaluationRow: init_payload = build_init_request(row, config, model_base_url) # Fire-and-poll - def _post_init() -> None: - url = f"{remote_base_url}/init" - try: - r = requests.post(url, json=init_payload.model_dump(), timeout=300) - r.raise_for_status() - except requests.exceptions.Timeout: - raise TimeoutError( - f"The /init endpoint tried {url} with {init_payload.model_dump()} but timed out after 300 seconds." - ) - - await asyncio.to_thread(_post_init) + init_url = f"{remote_base_url}/init" + + timeout_init = aiohttp.ClientTimeout(total=300) + + try: + session = self._get_or_create_session() + async with session.post(init_url, json=init_payload.model_dump(), timeout=timeout_init) as resp: + if resp.status >= 400: + body = await resp.text() + raise RuntimeError(f"Remote /init failed (HTTP {resp.status}): {body}") + resp.raise_for_status() + await resp.read() # Drain the response body and release the connection back to the pool + except asyncio.TimeoutError: + raise TimeoutError( + f"The /init endpoint tried {init_url} with {init_payload.model_dump()} but timed out after 300 seconds." + ) - terminated = False deadline = time.time() + timeout_seconds - def _get_status() -> Dict[str, Any]: - url = f"{remote_base_url}/status" - r = requests.get(url, params={"rollout_id": row.execution_metadata.rollout_id}, timeout=15) - r.raise_for_status() - return r.json() - - continue_polling_status = True while time.time() < deadline: - try: - if continue_polling_status: - status = await asyncio.to_thread(_get_status) - terminated = bool(status.get("terminated", False)) - if terminated: - break - except requests.exceptions.HTTPError as e: - if e.response is not None and e.response.status_code == 404: - # 404 means server doesn't implement /status endpoint, stop polling - logger.debug( - f"Server doesn't implement /status endpoint (404), stopping status polling for rollout {row.execution_metadata.rollout_id}" - ) - continue_polling_status = False - else: - raise - except Exception: - # For all other exceptions, raise them - raise - # Search Fireworks tracing logs for completion (run in thread to avoid blocking event loop) completed_logs = await asyncio.to_thread( self._tracing_adapter.search_logs, tags=[f"rollout_id:{row.execution_metadata.rollout_id}"] @@ -200,5 +180,21 @@ async def _sem_wrapper(r: EvaluationRow) -> EvaluationRow: tasks = [asyncio.create_task(_sem_wrapper(row)) for row in rows] return tasks + async def acleanup(self) -> None: + """Async cleanup - preferred when you can await.""" + if self._session and not self._session.closed: + await self._session.close() + def cleanup(self) -> None: - return None + """Sync cleanup - best-effort, schedules close if event loop is running.""" + if self._session and not self._session.closed: + try: + loop = asyncio.get_running_loop() + loop.create_task(self._session.close()) + except RuntimeError: + # No running event loop - can't safely close the session. + # The session will be garbage collected eventually, but warn about it. + logger.warning( + "RemoteRolloutProcessor.cleanup() called outside of async context. " + "Session may not be properly closed. Use `await processor.acleanup()` when possible." + ) diff --git a/eval_protocol/pytest/rollout_processor.py b/eval_protocol/pytest/rollout_processor.py index 95fbfa1b..c15413d1 100644 --- a/eval_protocol/pytest/rollout_processor.py +++ b/eval_protocol/pytest/rollout_processor.py @@ -19,6 +19,10 @@ def __call__(self, rows: list[EvaluationRow], config: RolloutProcessorConfig) -> """Process evaluation rows and return async tasks. Must be implemented by subclasses.""" pass + async def acleanup(self) -> None: + """Async cleanup - preferred when you can await.""" + pass + def cleanup(self) -> None: """Cleanup resources. Override in subclasses if cleanup is needed.""" pass diff --git a/eval_protocol/training/gepa_trainer.py b/eval_protocol/training/gepa_trainer.py index d91efe67..d8625bf2 100644 --- a/eval_protocol/training/gepa_trainer.py +++ b/eval_protocol/training/gepa_trainer.py @@ -503,6 +503,7 @@ async def evaluate_with_ep( } finally: + await rollout_processor.acleanup() rollout_processor.cleanup() def run_ep_evaluation( diff --git a/tests/pytest/test_utils.py b/tests/pytest/test_utils.py index 0176c279..09378fb7 100644 --- a/tests/pytest/test_utils.py +++ b/tests/pytest/test_utils.py @@ -1,5 +1,5 @@ import asyncio -from unittest.mock import MagicMock +from unittest.mock import AsyncMock, MagicMock import pytest from eval_protocol.pytest.evaluation_test_utils import rollout_processor_with_retry @@ -16,6 +16,7 @@ def mock_rollout_processor(self): """Create a mock rollout processor that returns async tasks.""" processor = MagicMock() processor.cleanup = MagicMock() + processor.acleanup = AsyncMock() # async cleanup method return processor @pytest.fixture @@ -71,8 +72,8 @@ async def mock_task(): assert mock_config.logger.log.call_count == 1 mock_config.logger.log.assert_called_once_with(results[0]) - # Verify cleanup was called - mock_rollout_processor.cleanup.assert_called_once() + # Verify async cleanup was called (aclose is preferred over cleanup) + mock_rollout_processor.acleanup.assert_awaited_once() @pytest.mark.asyncio async def test_logger_called_on_failed_execution(self, mock_rollout_processor, mock_config, sample_dataset): @@ -97,8 +98,8 @@ async def failing_task(): assert results[0].rollout_status.code == 13 # INTERNAL error code assert "Test error" in results[0].rollout_status.message - # Verify cleanup was called - mock_rollout_processor.cleanup.assert_called_once() + # Verify async cleanup was called (aclose is preferred over cleanup) + mock_rollout_processor.acleanup.assert_awaited_once() @pytest.mark.asyncio async def test_logger_called_on_retry_execution(self, mock_rollout_processor, mock_config, sample_dataset): @@ -134,8 +135,8 @@ async def flaky_task(): assert mock_config.logger.log.call_count == 1 mock_config.logger.log.assert_called_once_with(results[0]) - # Verify cleanup was called - mock_rollout_processor.cleanup.assert_called_once() + # Verify async cleanup was called (aclose is preferred over cleanup) + mock_rollout_processor.acleanup.assert_awaited_once() @pytest.mark.asyncio async def test_logger_called_for_multiple_rows(self, mock_rollout_processor, mock_config): @@ -182,8 +183,8 @@ async def mock_task(): assert mock_config.logger.log.call_count == 2 assert len(results) == 2 - # Verify cleanup was called - mock_rollout_processor.cleanup.assert_called_once() + # Verify async cleanup was called (aclose is preferred over cleanup) + mock_rollout_processor.acleanup.assert_awaited_once() @pytest.mark.asyncio async def test_logger_called_even_when_processor_fails_to_initialize( @@ -198,5 +199,5 @@ async def test_logger_called_even_when_processor_fails_to_initialize( async for result in rollout_processor_with_retry(mock_rollout_processor, sample_dataset, mock_config): pass - # Verify cleanup was called even though the function failed - mock_rollout_processor.cleanup.assert_called_once() + # Verify async cleanup was called even though the function failed + mock_rollout_processor.acleanup.assert_awaited_once()