Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions eval_protocol/pytest/evaluation_test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,7 +371,7 @@ async def execute_row_with_backoff_retry(row: EvaluationRow) -> EvaluationRow:
retry_config = replace(config, kwargs={**(config.kwargs or {}), "start_server": False})
retry_tasks = rollout_processor([row], retry_config)
result = await retry_tasks[0]

# Apply post-processing quality checks if configured
# This must be inside the retry function so ResponseQualityError can trigger retries
if config.post_processor is not None:
Expand All @@ -380,7 +380,7 @@ async def execute_row_with_backoff_retry(row: EvaluationRow) -> EvaluationRow:
except ResponseQualityError as quality_error:
# Re-raise ResponseQualityError to trigger retry logic
raise quality_error

return result

async def execute_row_with_backoff(task: asyncio.Task[EvaluationRow], row: EvaluationRow) -> EvaluationRow:
Expand Down Expand Up @@ -464,6 +464,7 @@ async def execute_row_with_backoff_and_log(
yield result

finally:
await rollout_processor.acleanup()
rollout_processor.cleanup()


Expand Down
86 changes: 41 additions & 45 deletions eval_protocol/pytest/remote_rollout_processor.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,10 @@
import asyncio
import time
from typing import Any, Dict, List, Optional
from typing import List, Optional

import requests
import aiohttp

from eval_protocol.models import EvaluationRow, Status
from eval_protocol.data_loader.dynamic_data_loader import DynamicDataLoader
from eval_protocol.types.remote_rollout_processor import (
DataLoaderConfig,
)
from eval_protocol.adapters.fireworks_tracing import FireworksTracingAdapter
from eval_protocol.exceptions import exception_for_status_code

Expand Down Expand Up @@ -51,6 +47,12 @@ def __init__(
self._poll_interval = poll_interval
self._timeout_seconds = timeout_seconds
self._tracing_adapter = FireworksTracingAdapter(base_url=self._model_base_url)
self._session: Optional[aiohttp.ClientSession] = None

def _get_or_create_session(self) -> aiohttp.ClientSession:
    """Return the shared aiohttp session, lazily (re)creating it when absent or closed."""
    needs_new = self._session is None or self._session.closed
    if needs_new:
        self._session = aiohttp.ClientSession()
    return self._session

def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) -> List[asyncio.Task[EvaluationRow]]:
tasks: List[asyncio.Task[EvaluationRow]] = []
Expand Down Expand Up @@ -88,48 +90,26 @@ async def _process_row(row: EvaluationRow) -> EvaluationRow:
init_payload = build_init_request(row, config, model_base_url)

# Fire-and-poll
def _post_init() -> None:
url = f"{remote_base_url}/init"
try:
r = requests.post(url, json=init_payload.model_dump(), timeout=300)
r.raise_for_status()
except requests.exceptions.Timeout:
raise TimeoutError(
f"The /init endpoint tried {url} with {init_payload.model_dump()} but timed out after 300 seconds."
)

await asyncio.to_thread(_post_init)
init_url = f"{remote_base_url}/init"

timeout_init = aiohttp.ClientTimeout(total=300)

try:
session = self._get_or_create_session()
async with session.post(init_url, json=init_payload.model_dump(), timeout=timeout_init) as resp:
if resp.status >= 400:
body = await resp.text()
raise RuntimeError(f"Remote /init failed (HTTP {resp.status}): {body}")
resp.raise_for_status()
await resp.read() # Drain the response body and release the connection back to the pool
except asyncio.TimeoutError:
raise TimeoutError(
f"The /init endpoint tried {init_url} with {init_payload.model_dump()} but timed out after 300 seconds."
)

terminated = False
deadline = time.time() + timeout_seconds

def _get_status() -> Dict[str, Any]:
url = f"{remote_base_url}/status"
r = requests.get(url, params={"rollout_id": row.execution_metadata.rollout_id}, timeout=15)
r.raise_for_status()
return r.json()

continue_polling_status = True
while time.time() < deadline:
try:
if continue_polling_status:
status = await asyncio.to_thread(_get_status)
terminated = bool(status.get("terminated", False))
if terminated:
break
except requests.exceptions.HTTPError as e:
if e.response is not None and e.response.status_code == 404:
# 404 means server doesn't implement /status endpoint, stop polling
logger.debug(
f"Server doesn't implement /status endpoint (404), stopping status polling for rollout {row.execution_metadata.rollout_id}"
)
continue_polling_status = False
else:
raise
except Exception:
# For all other exceptions, raise them
raise

# Search Fireworks tracing logs for completion (run in thread to avoid blocking event loop)
completed_logs = await asyncio.to_thread(
self._tracing_adapter.search_logs, tags=[f"rollout_id:{row.execution_metadata.rollout_id}"]
Expand Down Expand Up @@ -200,5 +180,21 @@ async def _sem_wrapper(r: EvaluationRow) -> EvaluationRow:
tasks = [asyncio.create_task(_sem_wrapper(row)) for row in rows]
return tasks

async def acleanup(self) -> None:
    """Close the pooled aiohttp session if one is currently open.

    Preferred over ``cleanup()`` whenever the caller can await.
    """
    session = self._session
    if session and not session.closed:
        await session.close()

def cleanup(self) -> None:
    """Sync cleanup - best-effort, schedules close if event loop is running.

    Prefer ``await processor.acleanup()`` when running inside an event loop;
    this sync variant can only *schedule* the close, not await its completion.
    """
    # Fix: the body previously began with an unreachable `return None` ahead of
    # the docstring, which made the docstring a dead bare-string statement and
    # turned the whole method into a no-op that leaked the aiohttp session.
    if self._session and not self._session.closed:
        try:
            # Inside a running loop we cannot block on close(); schedule it as
            # a task on that loop instead (best effort).
            loop = asyncio.get_running_loop()
            loop.create_task(self._session.close())
        except RuntimeError:
            # No running event loop - can't safely close the session.
            # The session will be garbage collected eventually, but warn about it.
            logger.warning(
                "RemoteRolloutProcessor.cleanup() called outside of async context. "
                "Session may not be properly closed. Use `await processor.acleanup()` when possible."
            )
4 changes: 4 additions & 0 deletions eval_protocol/pytest/rollout_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@ def __call__(self, rows: list[EvaluationRow], config: RolloutProcessorConfig) ->
"""Process evaluation rows and return async tasks. Must be implemented by subclasses."""
pass

async def acleanup(self) -> None:
    """Release resources asynchronously.

    Default implementation is a no-op; subclasses holding resources override
    this. Preferred over ``cleanup()`` when the caller can await.
    """
    return None

def cleanup(self) -> None:
    """Release resources synchronously.

    Default implementation is a no-op; subclasses override when they hold
    resources that need explicit teardown.
    """
    return None
1 change: 1 addition & 0 deletions eval_protocol/training/gepa_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -503,6 +503,7 @@ async def evaluate_with_ep(
}

finally:
await rollout_processor.acleanup()
rollout_processor.cleanup()

def run_ep_evaluation(
Expand Down
23 changes: 12 additions & 11 deletions tests/pytest/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import asyncio
from unittest.mock import MagicMock
from unittest.mock import AsyncMock, MagicMock
import pytest

from eval_protocol.pytest.evaluation_test_utils import rollout_processor_with_retry
Expand All @@ -16,6 +16,7 @@ def mock_rollout_processor(self):
"""Create a mock rollout processor that returns async tasks."""
processor = MagicMock()
processor.cleanup = MagicMock()
processor.acleanup = AsyncMock() # async cleanup method
return processor

@pytest.fixture
Expand Down Expand Up @@ -71,8 +72,8 @@ async def mock_task():
assert mock_config.logger.log.call_count == 1
mock_config.logger.log.assert_called_once_with(results[0])

# Verify cleanup was called
mock_rollout_processor.cleanup.assert_called_once()
# Verify async cleanup was called (aclose is preferred over cleanup)
mock_rollout_processor.acleanup.assert_awaited_once()

@pytest.mark.asyncio
async def test_logger_called_on_failed_execution(self, mock_rollout_processor, mock_config, sample_dataset):
Expand All @@ -97,8 +98,8 @@ async def failing_task():
assert results[0].rollout_status.code == 13 # INTERNAL error code
assert "Test error" in results[0].rollout_status.message

# Verify cleanup was called
mock_rollout_processor.cleanup.assert_called_once()
# Verify async cleanup was called (aclose is preferred over cleanup)
mock_rollout_processor.acleanup.assert_awaited_once()

@pytest.mark.asyncio
async def test_logger_called_on_retry_execution(self, mock_rollout_processor, mock_config, sample_dataset):
Expand Down Expand Up @@ -134,8 +135,8 @@ async def flaky_task():
assert mock_config.logger.log.call_count == 1
mock_config.logger.log.assert_called_once_with(results[0])

# Verify cleanup was called
mock_rollout_processor.cleanup.assert_called_once()
# Verify async cleanup was called (aclose is preferred over cleanup)
mock_rollout_processor.acleanup.assert_awaited_once()

@pytest.mark.asyncio
async def test_logger_called_for_multiple_rows(self, mock_rollout_processor, mock_config):
Expand Down Expand Up @@ -182,8 +183,8 @@ async def mock_task():
assert mock_config.logger.log.call_count == 2
assert len(results) == 2

# Verify cleanup was called
mock_rollout_processor.cleanup.assert_called_once()
# Verify async cleanup was called (aclose is preferred over cleanup)
mock_rollout_processor.acleanup.assert_awaited_once()

@pytest.mark.asyncio
async def test_logger_called_even_when_processor_fails_to_initialize(
Expand All @@ -198,5 +199,5 @@ async def test_logger_called_even_when_processor_fails_to_initialize(
async for result in rollout_processor_with_retry(mock_rollout_processor, sample_dataset, mock_config):
pass

# Verify cleanup was called even though the function failed
mock_rollout_processor.cleanup.assert_called_once()
# Verify async cleanup was called even though the function failed
mock_rollout_processor.acleanup.assert_awaited_once()
Loading