Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions eval_protocol/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,12 @@ class ScoreInvalidError(EvalProtocolError):
status_code = 102


class ResponseQualityError(EvalProtocolError):
    """Raised when a response fails its quality check.

    Corresponds to Status.Code.RESPONSE_QUALITY_ERROR (103).
    """

    status_code = 103


# Convenience mapping from status codes to exception classes
# Only actual error conditions should raise exceptions
STATUS_CODE_TO_EXCEPTION = {
Expand All @@ -157,6 +163,7 @@ class ScoreInvalidError(EvalProtocolError):
100: None, # FINISHED - success, no exception
101: None, # RUNNING - in progress, no exception
102: None, # SCORE_INVALID - success, no exception
103: ResponseQualityError, # RESPONSE_QUALITY_ERROR - quality check failed
}


Expand Down
8 changes: 8 additions & 0 deletions eval_protocol/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ class Code(int, Enum):
FINISHED = 100
RUNNING = 101
SCORE_INVALID = 102
RESPONSE_QUALITY_ERROR = 103

@classmethod
def rollout_running(cls) -> "Status":
Expand Down Expand Up @@ -367,6 +368,13 @@ def score_invalid(
"""Create a status indicating the score is invalid."""
return cls(code=cls.Code.SCORE_INVALID, message=message, details=details or [])

@classmethod
def response_quality_error(
    cls,
    message: str = "Response quality check failed",
    details: Optional[List[Dict[str, Any]]] = None,
) -> "Status":
    """Build a status carrying the RESPONSE_QUALITY_ERROR code.

    Args:
        message: Text stored on the status.
        details: Optional detail entries; a falsy value becomes an empty list.

    Returns:
        A new instance of ``cls`` with code ``Code.RESPONSE_QUALITY_ERROR``.
    """
    detail_entries = details or []
    return cls(
        code=cls.Code.RESPONSE_QUALITY_ERROR,
        message=message,
        details=detail_entries,
    )

def is_running(self) -> bool:
    """Return whether this status carries the RUNNING code."""
    running_code = self.Code.RUNNING
    return self.code == running_code
Expand Down
3 changes: 3 additions & 0 deletions eval_protocol/pytest/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from .evaluation_test import evaluation_test
from .exception_config import ExceptionHandlerConfig, BackoffConfig, get_default_exception_handler_config
from .rollout_processor import RolloutProcessor
from .rollout_result_post_processor import RolloutResultPostProcessor, NoOpRolloutResultPostProcessor
from .types import RolloutProcessorConfig

# Conditional import for optional dependencies
Expand Down Expand Up @@ -42,6 +43,8 @@
"ExceptionHandlerConfig",
"BackoffConfig",
"get_default_exception_handler_config",
"RolloutResultPostProcessor",
"NoOpRolloutResultPostProcessor",
]

# Only add to __all__ if available
Expand Down
23 changes: 21 additions & 2 deletions eval_protocol/pytest/evaluation_test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
ServerMode,
)
from eval_protocol.pytest.exception_config import get_default_exception_handler_config
from eval_protocol.exceptions import ResponseQualityError

import logging
import json
async def execute_row_with_backoff_retry(row: EvaluationRow) -> EvaluationRow:
    """Execute the rollout for a single row again and apply the optional quality check.

    Args:
        row: The row whose rollout is re-executed.

    Returns:
        The row produced by the re-executed rollout.

    Raises:
        ResponseQualityError: Propagated unchanged from
            ``config.post_processor.process`` when the quality check fails.
    """
    # Re-run with the caller's kwargs, forcing start_server=False.
    retry_config = replace(config, kwargs={**(config.kwargs or {}), "start_server": False})
    retry_tasks = rollout_processor([row], retry_config)
    result = await retry_tasks[0]

    # Apply post-processing quality checks if configured.
    # This must be inside the retry function so ResponseQualityError can trigger
    # retries. No try/except is needed: the exception propagates on its own.
    if config.post_processor is not None:
        config.post_processor.process(result)

    return result

async def execute_row_with_backoff(task: asyncio.Task[EvaluationRow], row: EvaluationRow) -> EvaluationRow:
"""Execute a single row task with backoff retry."""
Expand All @@ -372,6 +384,13 @@ async def execute_row_with_backoff(task: asyncio.Task[EvaluationRow], row: Evalu
# Try original task first
result = await task # pyright: ignore[reportUnknownVariableType]

# Apply post-processing quality checks if configured
if config.post_processor is not None:
try:
config.post_processor.process(result)
except ResponseQualityError as quality_error:
raise quality_error

_set_rollout_status_to_finished(result)

return result # pyright: ignore[reportUnknownVariableType]
Expand All @@ -384,9 +403,9 @@ async def execute_row_with_backoff(task: asyncio.Task[EvaluationRow], row: Evalu

if is_retryable and not should_giveup:
# Use shared backoff function for retryable exceptions
# Note: post-processing is handled inside execute_row_with_backoff_retry
try:
result = await execute_row_with_backoff_retry(row)

_set_rollout_status_to_finished(result)

return result
Expand Down
13 changes: 10 additions & 3 deletions eval_protocol/pytest/exception_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import os
from dataclasses import dataclass, field
from typing import Callable, Set, Type, Union
from typing import Callable, Dict, Set, Type, Union

import backoff

Expand Down Expand Up @@ -47,6 +47,7 @@
eval_protocol.exceptions.UnavailableError,
eval_protocol.exceptions.UnauthenticatedError,
eval_protocol.exceptions.ResourceExhaustedError,
eval_protocol.exceptions.ResponseQualityError,
}


Expand Down Expand Up @@ -79,7 +80,11 @@ class BackoffConfig:
giveup_func: Callable[[Exception], bool] = lambda e: False

def get_backoff_decorator(self, exceptions: Set[Type[Exception]]):
"""Get the appropriate backoff decorator based on configuration."""
"""Get the appropriate backoff decorator based on configuration.

Args:
exceptions: Set of exception types to retry
"""
if not exceptions:
# If no exceptions specified, return a no-op decorator
def no_op_decorator(func):
Expand Down Expand Up @@ -136,7 +141,9 @@ def __post_init__(self):

def get_backoff_decorator(self):
    """Get the backoff decorator configured for this exception handler.

    Delegates to ``self.backoff_config.get_backoff_decorator`` and passes this
    handler's ``retryable_exceptions``.

    Returns:
        Whatever ``backoff_config.get_backoff_decorator`` returns for that set.
    """
    return self.backoff_config.get_backoff_decorator(self.retryable_exceptions)


def get_default_exception_handler_config() -> ExceptionHandlerConfig:
Expand Down
57 changes: 57 additions & 0 deletions eval_protocol/pytest/rollout_result_post_processor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
"""
Rollout result post-processing plugin for quality checks.

This module provides an abstract base class for post-processing rollout results
to guard response quality. Post-processors can validate results and raise
ResponseQualityError if quality checks fail.
"""

from abc import ABC, abstractmethod

from eval_protocol.models import EvaluationRow


class RolloutResultPostProcessor(ABC):
    """
    Plugin interface for quality checks on rollout results.

    A subclass inspects a rollout result and may raise ResponseQualityError
    when the result does not pass its checks. Users supply their own subclass
    to customize the quality guard.
    """

    @abstractmethod
    def process(self, result: EvaluationRow) -> None:
        """
        Validate a single rollout result.

        Implementations run their quality checks on ``result`` and raise
        ResponseQualityError, with a suitable message, when a check fails.

        Args:
            result: The EvaluationRow result from the rollout

        Raises:
            ResponseQualityError: If quality checks fail
        """


class NoOpRolloutResultPostProcessor(RolloutResultPostProcessor):
    """
    Post-processor that accepts every rollout result.

    No quality checks are run, so processing never fails. Intended as the
    default when no post-processing is needed.
    """

    def process(self, result: EvaluationRow) -> None:
        """
        Accept ``result`` without inspecting it.

        Args:
            result: The EvaluationRow result from the rollout
        """
        return None

2 changes: 2 additions & 0 deletions eval_protocol/pytest/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

from ..models import CompletionParams, EvaluationRow, Message
from .exception_config import ExceptionHandlerConfig
from .rollout_result_post_processor import RolloutResultPostProcessor

ModelParam = str # gpt-4o, gpt-4o-mini, accounts/fireworks/models/llama-3.1-8b-instruct
DatasetPathParam = str
Expand Down Expand Up @@ -75,3 +76,4 @@ class RolloutProcessorConfig:
default_factory=dict
) # any additional kwargs to pass to the rollout processor
exception_handler_config: ExceptionHandlerConfig | None = None # configuration for exception handling with backoff
post_processor: RolloutResultPostProcessor | None = None # optional post-processor for quality checks
114 changes: 114 additions & 0 deletions tests/test_exception_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
"""
Unit tests for exception_config module.

Tests the BackoffConfig and ExceptionHandlerConfig classes, including:
1. Backoff decorator creation
2. No-op decorator when no exceptions are given
3. ValueError on an unknown backoff strategy
4. ResponseQualityError being retryable by default
"""

import pytest
from eval_protocol.pytest.exception_config import BackoffConfig, ExceptionHandlerConfig, DEFAULT_RETRYABLE_EXCEPTIONS
from eval_protocol.exceptions import ResponseQualityError


def test_backoff_config_no_exceptions():
    """An empty exception set yields a decorator that returns the function unchanged."""
    no_op = BackoffConfig().get_backoff_decorator(set())

    def sample():
        return "test"

    wrapped = no_op(sample)
    assert wrapped() == "test"
    # Identity check: the no-op decorator must not wrap the function at all.
    assert wrapped is sample


def test_backoff_config_no_overrides():
    """A constant-strategy config builds one decorator covering several exception types."""
    retry_on = {ConnectionError, TimeoutError}
    backoff_config = BackoffConfig(strategy="constant", base_delay=0.1, max_tries=2)

    decorator = backoff_config.get_backoff_decorator(retry_on)
    assert decorator is not None

    def always_fails():
        raise ConnectionError("test")

    # Only the wrapping step is checked; the wrapped function is never called.
    wrapped = decorator(always_fails)
    assert callable(wrapped)


def test_exception_handler_config_default_response_quality_error():
    """A default ExceptionHandlerConfig treats ResponseQualityError as retryable."""
    handler_config = ExceptionHandlerConfig()
    retryable = handler_config.retryable_exceptions

    assert ResponseQualityError in retryable


def test_exception_handler_config_get_backoff_decorator():
    """ExceptionHandlerConfig.get_backoff_decorator() returns a decorator that can wrap a function."""
    decorator = ExceptionHandlerConfig().get_backoff_decorator()

    assert decorator is not None
    assert callable(decorator)

    def always_fails():
        raise ConnectionError("test")

    # Only the wrapping step is checked; the wrapped function is never called.
    wrapped = decorator(always_fails)
    assert callable(wrapped)


def test_backoff_config_expo_strategy():
    """The "expo" strategy produces a decorator that can wrap a function."""
    expo_config = BackoffConfig(strategy="expo", base_delay=1.0, max_tries=2)

    decorator = expo_config.get_backoff_decorator({ConnectionError})
    assert decorator is not None

    def always_fails():
        raise ConnectionError("test")

    wrapped = decorator(always_fails)
    assert callable(wrapped)


def test_backoff_config_constant_strategy():
    """The "constant" strategy produces a decorator that can wrap a function."""
    constant_config = BackoffConfig(strategy="constant", base_delay=0.1, max_tries=2)

    decorator = constant_config.get_backoff_decorator({ConnectionError})
    assert decorator is not None

    def always_fails():
        raise ConnectionError("test")

    wrapped = decorator(always_fails)
    assert callable(wrapped)


def test_backoff_config_invalid_strategy():
    """Requesting a decorator for an unknown strategy raises ValueError."""
    bad_config = BackoffConfig(strategy="invalid", base_delay=1.0, max_tries=2)
    retry_on = {ConnectionError}

    # The error surfaces when the decorator is requested, not at construction.
    with pytest.raises(ValueError, match="Unknown backoff strategy"):
        bad_config.get_backoff_decorator(retry_on)


def test_exception_handler_config_response_quality_error_in_defaults():
    """DEFAULT_RETRYABLE_EXCEPTIONS contains ResponseQualityError."""
    default_retryable = DEFAULT_RETRYABLE_EXCEPTIONS
    assert ResponseQualityError in default_retryable


Loading
Loading