From f846f70451a235138966b446d496885ef2e44f2a Mon Sep 17 00:00:00 2001 From: Bogdan-Alexandru Stoica Date: Sat, 31 Jan 2026 01:21:04 -0600 Subject: [PATCH 1/6] feat: implemented base classes and interfaces to standardize the agent checker/oracle scripts --- .../arteval_bench/src/evaluator/README.md | 309 +++++++++ .../oracle_artifact_build_primitives.py | 405 +++++++++++ .../oracle_benchmark_prep_primitives.py | 432 ++++++++++++ .../evaluator/oracle_env_setup_primitives.py | 364 ++++++++++ .../oracle_experiment_runs_primitives.py | 631 ++++++++++++++++++ .../arteval_bench/src/evaluator/utils.py | 387 +++++++++++ 6 files changed, 2528 insertions(+) create mode 100644 benchmarks/arteval_bench/src/evaluator/README.md create mode 100644 benchmarks/arteval_bench/src/evaluator/oracle_artifact_build_primitives.py create mode 100644 benchmarks/arteval_bench/src/evaluator/oracle_benchmark_prep_primitives.py create mode 100644 benchmarks/arteval_bench/src/evaluator/oracle_env_setup_primitives.py create mode 100644 benchmarks/arteval_bench/src/evaluator/oracle_experiment_runs_primitives.py create mode 100644 benchmarks/arteval_bench/src/evaluator/utils.py diff --git a/benchmarks/arteval_bench/src/evaluator/README.md b/benchmarks/arteval_bench/src/evaluator/README.md new file mode 100644 index 00000000..d887df22 --- /dev/null +++ b/benchmarks/arteval_bench/src/evaluator/README.md @@ -0,0 +1,309 @@ +# Agent Evaluator Primitives + +This bundle provides primitives for four oracles that verify if an AI agent can succesfully evaluating a set of artifacts, namely setting up, building code, downloading datasets and runing experiments end-to-end. Each oracle corresponds to one stage of the artifact evaluation (AE) process and encodes minimal, objective, and programatically verifiable success criteria. Oracles are designed to be idempotent (safe to run multiple times), non-interactive (no blocking events like I/O actions or manual intervention), and produce a binary outcome (either "pass" or "fail"). + +The oracles verify four canonical stages of the AE process: + +1. Environment setup: check required tools/dependencies exist and meet version constraints; confirm key environment variables and required files/directories are present. +2. Artifact build: run build/install commands and fail if they do not complete successfully. +3. Benchmark preparation: check datasets/benchmarks/tools are present and usable; optionally run quick commands and check for expected output signatures. +4. Experiment runs: compare observed to reference values using similarity or elementwise checks within cutomizable tolerance thresholds. + +Each artifact includes a self-contained oracles in a `_agent_eval/` directory. These scripts extend the base primitives descrived above to create specialized oracles that assert success criteria at each AE stage. + +## Implementing agent evaluators + +When adding a new artifact to `ArtEvalBench`, users need to create an accompanying `_agent_eval/` directory that implements the derived oracles. The `_agent_eval/` directory should have the following minimal structure: +``` +_agent_eval/ +├── main.py +├── oracle_artifact_build.py +├── oracle_benchmark_prep.py +├── oracle_env_setup.py +├── oracle_experiment_runs.py +├── ... +└── refs + ├── ground_truth_results.json + ├── ... + ... +``` + +Each evaluation stage is implemented as a small Python module that derives from the corresponding oracle base class in this bundle. The evaluator also provides a `main.py` entry point that: +- defines an `EntryConfig` object which specifies the required directory structure and file paths, similarity thresholds, ground truth measurements (as files) +- instantiates each oracle in order +- runs them and aggregates a stage-by-stage score + +The evaluator also includes a refs/ directory containing reference artifacts (typically JSON) used by benchmark-prep and experiment-runs checks. These files capture expected outputs in a machine-checkable form—for example: expected dataset manifests/checksums and sizes, expected metric tables (latency/throughput percentiles), accuracy or loss values for a fixed seed, or summaries of generated outputs (counts, totals, or other deterministic statistics). + +Each oracle module follows the same pattern: +- Users create a derived class and implement requirements(). +- `requirements()` returns an ordered sequence of requirement objects. + +The base class provides the following: +- `report()` which returns a structured OracleReport +- `run(verbose=...)` which logs a PASS/FAIL summary and returns boolean `True`/`False` variable + +In most cases, overriding `requirements()` suficies. Custom behavior should be added only when necessary (e.g., additional post-build validation such as checking instrumentation markers, or a custom comparison/similarity policy for experiment outputs. + +### Environment setup oracle primitives (`oracle_env_setup_primitives.py`) + +The environment setup base class defines requirement primitives for verifying that: +- dependencies are installed at a specific versions (e.g., `docker`, `make`, `nodegcc`, etc.) +- configurations are portable, not hardcoded and specific to a single machine (e.g., no absolute file paths allowed) +- environment variables are correctly set (e.g., artifact binary is added to `PATH`) +- required directory structure exists + +Users need to implement a derived class from `OracleEnvSetupBase` and override `requirements(self)`. This method returns an ordered sequence of "requirement" objects, each implementing `check()` which evaluates tat particular requirement (e.g., a dependency has an exact or newer version) and returns a pass/fail outcome along with any relevant diagnostic information (message, stdout/stderr, return code, timeout, etc.). In `main.py`, users need to instantiate the derived oracle and call `run(verbose=...)`, which returns `True` only if all non-optional requirements pass. + +Below is a minimal sketch showing how a derived oracle returns a single dependency-version requirement. + +```py +import sys +from collections.abc import Sequence + +from evaluator.oracle_env_setup_primitives import ( + DependencyVersionRequirement, + OracleEnvSetupBase, + VersionCompare, +) + +class OracleEnvSetup(OracleEnvSetupBase): + def __init__(self, *, config, logger): + super().__init__(logger=logger) + self._config = config + + def requirements(self) -> Sequence[DependencyVersionRequirement]: + return ( + DependencyVersionRequirement( + name="python_version", + cmd=(sys.executable, "--version"), + required_version=(3, 10, 0), + compare=VersionCompare.GEQ, + timeout_seconds=5.0, + ), + ) +``` + + +### Artifact build oracle primitives (`oracle_artifact_build.py`) + +The artifact build base class defines requirement primitives for verifying that: +- core components can be compiled/built/installed from a initial checkout using specific build commands +- required working directories exist before commands run +- build commands complete successfully within a time bound and produce expected process outcomes (e.g., return code, stdout/stderr) + +Users need to implement a derived class from `OracleArtifactBuildBase` and override `requirements(self)`. This method returns an ordered sequence of "requirement" objects, each implementing `check()` which runs a specific build/install command under a configured working directory and returns a pass/fail outcome along with any relevant diagnostic information (message, stdout/stderr, return code, timeout, cwd, etc.). In `main.py`, users need to instantiate the derived oracle and call `run(verbose=...)`, which returns `True` only if all non-optional requirements pass. + +Below is a minimal sketch showing how a derived oracle returns a single build-command requirement: + +```py +from collections.abc import Sequence +from evaluator.oracle_artifact_build_primitives import ( + BuildCommandRequirement, + BuildRequirement, + OracleArtifactBuildBase, +) + +class OracleArtifactBuild(OracleArtifactBuildBase): + def __init__(self, *, config, logger): + super().__init__(logger=logger) + self._config = config + + def requirements(self) -> Sequence[BuildRequirement]: + return ( + BuildCommandRequirement( + name="artifact-core: make tools", + cwd=self._config.repository_paths[self._config.name], + command=( + "make", "-j8", + "tools/diamond-types/target/release/dt", + ), + timeout_seconds=60.0, + ), + ) +``` + +### Benchmark preparation oracle primitives (`oracle_benchmark_prep.py`) + +The benchmark preparation base class defines requirement primitives for verifying that: +- required benchmark/datasets downloaded succesfully and are accesible locally (e.g., directories/files created, benchmarks succesfully compiled/build/installed, etc.) +- benchmark setup steps are runnable (e.g., running functional tests) +- command output contains expected markers when applicable (e.g., check file sizes, commit hashes, etc.) + +Users need to implement a derived class from `OracleBenchmarkPrepBase` and override `requirements(self)`. This method returns an ordered sequence of "requirement" objects, each implementing `check()` which validates a path, optionally executes a setup/verification command, and returns a pass/fail outcome along with any relevant diagnostic information (message, stdout/stderr, return code, timeout, cwd, etc.). In `main.py`, users need to instantiate the derived oracle and call `run(verbose=...)`, which returns `True` only if all non-optional requirements pass. + +Below is a minimal sketch showing how a derived oracle returns two benchmark preparation requirements: one that verifies the repository is at an expected commit, and one that checks a file meets a minimum size threshold. + +```py +from collections.abc import Sequence +from evaluator.oracle_benchmark_prep_primitives import ( + BenchmarkRequirement, + OracleBenchmarkPrepBase, + Requirement, +) + +size_check_script = ( + "import os,sys\n" + "p=sys.argv[1]; m=int(sys.argv[2])\n" + "s=os.path.getsize(p)\n" + "print('OK' if s>=m else f'FAIL size={s} < min={m}')\n" +) + +class OracleBenchmarkPrep(OracleBenchmarkPrepBase): + def __init__(self, *, config, logger): + super().__init__(logger=logger) + self._config = config + + def requirements(self) -> Sequence[Requirement]: + manifest_path = self._config.ground_truth_paths["datasets"] + return ( + BenchmarkRequirement( + name="repo_commit_is_expected", + filepath=repo_root, + cmd=("git", "rev-parse", "HEAD"), + signature="3e1c2a4b5c6d7e8f9a0b1c2d3e4f5a6b7c8d9e0f", + timeout_seconds=5.0, + ), + BenchmarkRequirement( + name="dataset_file_size_at_least_min", + filepath=target_file, + cmd=(sys.executable, "-c", size_check_script, str(target_file), str(min_bytes)), + signature="OK", + timeout_seconds=5.0, + ) + ) +``` + +## Experiment runs oracle primitives (`oracle_experiment_runs.py`) + +The experiment runs base class defines requirement primitives that: +- compares experiment outputs (metrics, timings, scores, etc.) against reference values +- checks if comparisons satisfy a declared policy (e.g., element-wise equivalence, similarity coeficient with a predefined tolerance) +- when mismatch, return a compact summary describing the differences for debugging purposes + +Users need to implement a derived class from `OracleExperimentRunsBase` and override `requirements(self)`. This method returns an ordered sequence of "requirement" objects, each implementing `check()` which computes the configured comparison between observed and reference outputs and returns a pass/fail outcome along with any relevant diagnostic information (message and mismatch summaries, plus any parsing/runtime diagnostics if applicable). In `main.py`, users need to instantiate the derived oracle and call `run(verbose=...)`, which returns `True` only if all non-optional requirements pass. + +Below is a minimal sketch showing how a derived oracle returns a single similarity-threshold requirement. + +```py +from collections.abc import Sequence +from evaluator.oracle_experiment_runs_primitives import ( + ExperimentRunsRequirement, + LabeledSequenceSimilarityThresholdRequirement, + OracleExperimentRunsBase, +) + +def _parse_and_flatten_json(lines: Sequence[str]) -> list[tuple[str, float]]: + obj: Any = json.loads("\n".join(lines)) + + if not isinstance(obj, dict): + raise ValueError("timings results: expected top-level JSON object") + + out: list[tuple[str, float]] = [] + for metric, tags in obj.items(): + if not isinstance(tags, dict): + raise ValueError(f"timings results: {metric!r} must map to an object") + for tag, stats in tags.items(): + if not isinstance(stats, dict): + raise ValueError(f"timings results: {metric}.{tag} must map to an object") + for field, raw in stats.items(): + if not isinstance(field, str): + raise ValueError(f"timings results: non-string field name {field!r}") + if not isinstance(raw, (int, float)): + raise ValueError(f"timings results: {metric}.{tag}.{field} non-numeric {raw!r}") + out.append((f"{metric}.{tag}.{field}", float(raw))) + return out + +class OracleExperimentRuns(OracleExperimentRunsBase): + def __init__(self, *, config, logger): + super().__init__(logger=logger) + self._config = config + + def requirements(self) -> Sequence[ExperimentRunsRequirement]: + return ( + LabeledSequenceSimilarityThresholdRequirement( + name="timings", + label="Timings", + results_path=self._config.results_paths["timings"], + reference_path=self._config.ground_truth_paths["timings"], + threshold=self._config.similarity_ratio, + parse_results_fn=_parse_and_flatten_json, # parsing function defined by the user + parse_reference_fn=_parse_and_flatten_json, # parsing function defined by the user + ), + ) +``` + +### The `main.py` orchestrator + +A typical `main.py` evaluator implements the following + +1. Create a logger (using `utils.get_logger(...)`). +2. Build an `EntryConfig` describing repo locations, output paths, and references. +3. Instantiate each derived oracle with `(config, logger)`. +4. Run each stage in order: + - `EnvSetup.run()` + - `ArtifactBuild.run()` + - `BenchmarkPrep.run()` + - `ExperimentRuns.run()` +5. Return a final score (often via process exit code). + +For example, this is the `main.py` EgWalker's (EuroSys'25) agent evaluator bundle. + +```py +import os +import sys +from pathlib import Path +from evaluator.utils import EntryConfig, LoggerConfig, get_logger, record_result + +from oracle_env_setup import OracleEnvSetup +from oracle_artifact_build import OracleArtifactBuild +from oracle_benchmark_prep import OracleBenchmarkPrep +from oracle_experiment_runs import OracleExperimentRuns + +CONFIG = EntryConfig( + name="eurosys25-egwalker", + home_dir=Path.home() / "eurosys25_egwalker", + repository_paths={ + "eurosys25-egwalker": Path.home() / "eurosys25_egwalker" / "egwalker", + }, + results_paths={ + "timings": Path.home() / "eurosys25_egwalker" / "egwalker" / "results" / "timings.json", + }, + ground_truth_paths={ + "datasets": Path.home() / "eurosys25_egwalker" / "_agent_eval" / "refs" / "datasets.ref.json", + "timings": Path.home() / "eurosys25_egwalker" / "_agent_eval" / "refs" / "timings.ref.json", + }, + similarity_ratio=0.75, +) + +def main(argv: list[str]) -> int: + verbose = "--verbose" in argv + logger = get_logger(LoggerConfig(root_name=os.environ.get("EVAL_LOGGER_NAME", "EGWALKER-EVAL"))) + + results: dict[str, int] = {} + score = 0 + + env_ok = OracleEnvSetup(config=CONFIG, logger=logger).run(verbose=verbose) + score += record_result(results, "OracleEnvSetup", env_ok) + + build_ok = OracleArtifactBuild(config=CONFIG, logger=logger).run(verbose=verbose) + score += record_result(results, "OracleArtifactBuild", build_ok) + + prep_ok = OracleBenchmarkPrep(config=CONFIG, logger=logger).run(verbose=verbose) + score += record_result(results, "OracleBenchmarkPrep", prep_ok) + + runs_ok = OracleExperimentRuns(config=CONFIG, logger=logger).run(verbose=verbose) + score += record_result(results, "OracleExperimentRuns", runs_ok) + + logger.info("Agent scores: %s", results) + return score +``` + +### Best practices + +- Keep `requirements()` deterministic and efficient. +- Avoid interactive implementations, passing flags/config via args, etc. +- Ensure requirements are idempotent so they can be re-executed without side effects, commands are non-interactive and time-bounded. +- Provide clear error messages and include relevant command, path, flags, etc. +- Implement optional requirements as "nice-to-have" checks and output them as warnings. +- Make sure experiment output comparisons use explicit tolerances. \ No newline at end of file diff --git a/benchmarks/arteval_bench/src/evaluator/oracle_artifact_build_primitives.py b/benchmarks/arteval_bench/src/evaluator/oracle_artifact_build_primitives.py new file mode 100644 index 00000000..2f0f4a1d --- /dev/null +++ b/benchmarks/arteval_bench/src/evaluator/oracle_artifact_build_primitives.py @@ -0,0 +1,405 @@ +"""Artifact build oracle primitives. + +This module provides: + 1. Requirement types to specify build commands. + 2. An orchestrator base class that runs build checks, logs results, and returns + a pass/fail outcome. + +Derived oracles typically only override requirements() to declare a list of build +requirements to run, but they can customize command execution policies or logging +behavior if needed. +""" + +from __future__ import annotations + +import abc +import dataclasses +import logging +import os +import pathlib +import subprocess +import types +import selectors +import time + +from collections.abc import Mapping, Sequence + +from evaluator import utils + + +# ------------------------------------------------------------------------------ +# Helper functions +# ------------------------------------------------------------------------------ + + +def _summarize_process_output(stdout: str, stderr: str) -> str: + """Combines and truncates process output to keep messages readable.""" + out = stdout.strip() + err = stderr.strip() + if out and err: + combined = f"stdout:\n{out}\n\nstderr:\n{err}" + else: + combined = out or err + return utils.truncate_text(combined, + utils.DEFAULT_MAX_TRUNCATED_MESSAGE_CHARS) + + +def _require_directory(path: pathlib.Path, *, label: str) -> str | None: + """Returns an error message if path is not an existing directory.""" + if not path.exists(): + return f"{label} missing: {path}" + if not path.is_dir(): + return f"{label} is not a directory: {path}" + return None + + +# ------------------------------------------------------------------------------ +# Oracle's core logic +# ------------------------------------------------------------------------------ + + +@dataclasses.dataclass(frozen=True, slots=True) +class BuildContext: + """Context passed to build requirements. + + Attributes: + logger: Logger for diagnostics and shared policies. + """ + + logger: logging.Logger + + + +@dataclasses.dataclass(frozen=True, slots=True, kw_only=True) +class BuildCommandRequirement(utils.BaseRequirement): + """Runs a build command within a working directory. + + Attributes: + name: Human-readable requirement name for logs and reports. + optional: Whether failure should be treated as a warning instead of an error. + cwd: Base working directory. + command: Command argv to execute. + relative_workdir: Optional subdirectory within cwd used as the actual workdir. + timeout_seconds: Timeout for the command, in seconds. + env_overrides: Environment variables to override for the subprocess. + """ + + cwd: pathlib.Path + command: Sequence[str] + relative_workdir: pathlib.Path | None = None + timeout_seconds: float = 60.0 + env_overrides: Mapping[str, str] = dataclasses.field(default_factory=dict) + + def __post_init__(self) -> None: + object.__setattr__(self, "cwd", utils.to_path(self.cwd)) + if self.relative_workdir is not None: + object.__setattr__(self, "relative_workdir", utils.to_path(self.relative_workdir)) + + if isinstance(self.command, (str, bytes)): + raise TypeError(f"{self.name}: command must be a sequence of argv strings, not a single string/bytes") + + if not self.command: + raise ValueError(f"{self.name}: command must be non-empty") + + bad = [a for a in self.command if not isinstance(a, str) or a == ""] + if bad: + raise TypeError(f"{self.name}: all command argv entries must be non-empty str; bad entries: {bad!r}") + + if self.timeout_seconds <= 0: + raise ValueError(f"{self.name}: timeout (seconds) must be > 0") + + # NOTE: Be tolerant to callers passing non-str values (e.g., Path/int) by + # normalizing everything to str, since subprocess env requires str->str. + env_dict_raw = dict(self.env_overrides) + env_dict: dict[str, str] = {} + for k, v in env_dict_raw.items(): + # Preserve previous strictness for obviously broken keys. + if k is None or k == "": + raise TypeError(f"{self.name}: env_overrides contains an empty env var name: {k!r}") + env_dict[str(k)] = str(v) + + # Prevent obvious "not relative" cases early. + if self.relative_workdir is not None and self.relative_workdir.is_absolute(): + raise ValueError(f"{self.name}: relative_workdir must be a relative path, got: {self.relative_workdir}") + + object.__setattr__(self, "command", tuple(self.command)) + object.__setattr__(self, "env_overrides", types.MappingProxyType(env_dict)) + + @staticmethod + def _is_within_base_dir(*, base: pathlib.Path, target: pathlib.Path) -> bool: + """Returns True iff target is within base (after resolving symlinks). + + Assumes both paths exist (caller should validate directories first). + """ + try: + base_real = base.resolve(strict=True) + target_real = target.resolve(strict=True) + + # NOTE: Prefer pathlib semantics over string commonpath to avoid + # platform corner cases (drives, separators). This also avoids false + # positives from simple string-prefix checks. + try: + target_real.relative_to(base_real) + return True + except ValueError: + return False + except OSError: + return False + + @staticmethod + def _coerce_text(x: object) -> str: + # NOTE: utils.decode_text may not accept str in some codebases. This helper + # safely handles bytes/str/None and keeps the old behavior stable. + if x is None: + return "" + if isinstance(x, str): + return x + if isinstance(x, (bytes, bytearray, memoryview)): + return utils.decode_text(bytes(x)) + # Fallback: best-effort stringification + return str(x) + + def _run_with_limited_output( + self, + *, + workdir: pathlib.Path, + env: Mapping[str, str], + ) -> tuple[int | None, str, str, bool]: + """Run process while limiting captured output to avoid unbounded memory. + + Returns (returncode, stdout, stderr, timed_out). + """ + # NOTE: We run with stdout/stderr pipes in *binary* mode and decode ourselves. + # This avoids UnicodeDecodeError surprises while reading incrementally. + try: + proc = subprocess.Popen( + self.command, + cwd=workdir, + env=dict(env), + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=False, + ) + except OSError as exc: + # Let caller map this to CheckResult.failure, preserving existing behavior. + raise + + assert proc.stdout is not None + assert proc.stderr is not None + + sel = selectors.DefaultSelector() + sel.register(proc.stdout, selectors.EVENT_READ, data="stdout") + sel.register(proc.stderr, selectors.EVENT_READ, data="stderr") + + # NOTE: Cap memory usage by storing only up to a fixed number of bytes. + # We use 4x char cap as a conservative UTF-8 upper bound. + byte_cap = int(utils.DEFAULT_MAX_CAPTURE_CHARS) * 4 + + stdout_buf = bytearray() + stderr_buf = bytearray() + + deadline = time.monotonic() + float(self.timeout_seconds) + timed_out = False + + def _read_chunk(stream) -> bytes: + # Prefer read1 when available for buffered streams. + if hasattr(stream, "read1"): + return stream.read1(8192) # type: ignore[attr-defined] + return stream.read(8192) + + # Read incrementally from both pipes until closed or timeout. + while sel.get_map(): + remaining = deadline - time.monotonic() + if remaining <= 0: + timed_out = True + break + + events = sel.select(timeout=min(0.25, remaining)) + for key, _mask in events: + stream = key.fileobj + chunk = _read_chunk(stream) + if not chunk: + try: + sel.unregister(stream) + except Exception: + pass + try: + stream.close() + except Exception: + pass + continue + + if key.data == "stdout": + if len(stdout_buf) < byte_cap: + take = min(len(chunk), byte_cap - len(stdout_buf)) + stdout_buf.extend(chunk[:take]) + # NOTE: Discard remainder to cap memory; continue draining to avoid deadlock. + else: + if len(stderr_buf) < byte_cap: + take = min(len(chunk), byte_cap - len(stderr_buf)) + stderr_buf.extend(chunk[:take]) + + if timed_out: + try: + proc.kill() + except Exception: + pass + + # Best-effort drain for a short period so we capture some tail output + # without risking hangs. + drain_deadline = time.monotonic() + 1.0 + while sel.get_map() and time.monotonic() < drain_deadline: + events = sel.select(timeout=0.1) + for key, _mask in events: + stream = key.fileobj + chunk = _read_chunk(stream) + if not chunk: + try: + sel.unregister(stream) + except Exception: + pass + try: + stream.close() + except Exception: + pass + continue + if key.data == "stdout": + if len(stdout_buf) < byte_cap: + take = min(len(chunk), byte_cap - len(stdout_buf)) + stdout_buf.extend(chunk[:take]) + else: + if len(stderr_buf) < byte_cap: + take = min(len(chunk), byte_cap - len(stderr_buf)) + stderr_buf.extend(chunk[:take]) + + # Reap the process to avoid zombies. + try: + proc.wait(timeout=5.0) + except Exception: + pass + + stdout = utils.truncate_text(self._coerce_text(stdout_buf), utils.DEFAULT_MAX_CAPTURE_CHARS) + stderr = utils.truncate_text(self._coerce_text(stderr_buf), utils.DEFAULT_MAX_CAPTURE_CHARS) + return None, stdout, stderr, True + + # Process finished or pipes closed; reap returncode. + try: + rc = proc.wait(timeout=5.0) + except Exception: + # If something odd happens, keep behavior conservative. + rc = proc.returncode + + stdout = utils.truncate_text(self._coerce_text(stdout_buf), utils.DEFAULT_MAX_CAPTURE_CHARS) + stderr = utils.truncate_text(self._coerce_text(stderr_buf), utils.DEFAULT_MAX_CAPTURE_CHARS) + return rc, stdout, stderr, False + + def check(self, ctx: BuildContext) -> utils.CheckResult: + del ctx # Deliberetly reserved for future extensions + + error = _require_directory(self.cwd, label="working directory") + if error is not None: + return utils.CheckResult.failure(error, cwd=self.cwd) + + workdir = self.cwd + if self.relative_workdir is not None: + workdir = self.cwd / self.relative_workdir + error = _require_directory(workdir, label="working directory") + if error is not None: + return utils.CheckResult.failure(error, cwd=workdir) + + # Walidate cwd and prevent ``espacping'' (e.g., ../ or symlinks) + if not self._is_within_base_dir(base=self.cwd, target=workdir): + return utils.CheckResult.failure( + f"working directory escapes base cwd: base={self.cwd} workdir={workdir}", + cwd=workdir, + ) + + env = os.environ.copy() + if self.env_overrides: + env.update(self.env_overrides) + + try: + # NOTE: Avoid capture_output=True because it can buffer unbounded output + # and spike memory; we capture incrementally with a fixed cap. + returncode, stdout, stderr, timed_out = self._run_with_limited_output( + workdir=workdir, + env=env, + ) + except OSError as exc: + return utils.CheckResult.failure( + f"failed to run command: {exc}", + stdout="", + stderr=str(exc), + returncode=None, + timed_out=False, + cwd=workdir, + ) + + if timed_out: + # Handle case when stdout/stderr is None + return utils.CheckResult.failure( + f"command timed out after {self.timeout_seconds}s", + stdout=stdout, + stderr=stderr, + returncode=None, + timed_out=True, + cwd=workdir, + ) + + if returncode != 0: + detail = _summarize_process_output(stdout, stderr) + msg = f"command failed (rc = {returncode})" + if detail: + msg = f"{msg}: {detail}" + return utils.CheckResult.failure( + msg, + stdout=stdout, + stderr=stderr, + returncode=returncode, + timed_out=False, + cwd=workdir, + ) + + return utils.CheckResult.success( + stdout=stdout, + stderr=stderr, + returncode=returncode, + cwd=workdir, + ) + + +class OracleArtifactBuildBase(abc.ABC): + """Base class for an artifact build oracle. + + Derived classes typically implement requirements() to declare build checks. + + Attributes: + _logger: Logger used for reporting and diagnostics. + """ + + _ORACLE_NAME = "ArtifactBuild" + + def __init__(self, *, logger: logging.Logger) -> None: + self._logger = logger + + @abc.abstractmethod + def requirements(self) -> Sequence[utils.BaseRequirement]: + """Returns an ordered list of build requirements to validate.""" + raise NotImplementedError + + def report(self) -> utils.OracleReport: + """Executes requirements and returns a structured report.""" + ctx = BuildContext(logger=self._logger) + return utils.build_oracle_report( + logger=self._logger, + requirements_fn=self.requirements, + check_fn=lambda req: req.check(ctx), + ) + + def run(self, *, verbose: bool = False) -> bool: + """Returns True iff all required checks pass (logs results).""" + rep = self.report() + return utils.log_oracle_report(self._logger, + label=self._ORACLE_NAME, + report=rep, + verbose=verbose) diff --git a/benchmarks/arteval_bench/src/evaluator/oracle_benchmark_prep_primitives.py b/benchmarks/arteval_bench/src/evaluator/oracle_benchmark_prep_primitives.py new file mode 100644 index 00000000..b3b7335a --- /dev/null +++ b/benchmarks/arteval_bench/src/evaluator/oracle_benchmark_prep_primitives.py @@ -0,0 +1,432 @@ +"""Benchmark preparation oracle primitives. + +This module provides: + 1. Requirement types to specify benchmark-bundle prerequisites (tools, repo + state, expected files). + 2. An orchestrator base class that runs checks, logs results, and returns a + pass/fail outcome. + +Derived oracles typically only override requirements() to declare a list of +preparation requirements (paths, commands, and optional output signatures) to +validate, but they can customize how checks are constructed if needed. +""" + +from __future__ import annotations + +import abc +import dataclasses +import logging +import os +import pathlib +import shlex +import subprocess +import types +import codecs +import locale +import selectors +import time + +from collections.abc import Mapping, Sequence + +from evaluator import utils + + +# ------------------------------------------------------------------------------ +# Basic types and constants +# ------------------------------------------------------------------------------ + + +_CommandT = str | Sequence[str] + + +# ------------------------------------------------------------------------------ +# Helper functions +# ------------------------------------------------------------------------------ + + +def _format_command(cmd: _CommandT, *, use_shell: bool) -> str: + """Returns a readable representation of command suitable for error messages.""" + if isinstance(cmd, str): + return cmd if use_shell else shlex.quote(cmd) + # NOTE: quote() used for readability display only + return " ".join(shlex.quote(str(arg)) for arg in cmd) + + +def _cwd_suffix(cwd: pathlib.Path | None) -> str: + """Formats cwd as an error-message suffix.""" + if cwd is None: + return "" + return f" [cwd = {cwd}]" + + +def _missing_path_error(path: pathlib.Path) -> str | None: + """Returns an error message if a required path does not exist.""" + if not path.exists(): + return f"path missing: {path}" + return None + + +def _run_command( + *, + cmd: _CommandT, + cwd: pathlib.Path | None, + timeout_seconds: float, + env_overrides: Mapping[str, str], + use_shell: bool, + signature: str | None, +) -> utils.CheckResult: + """Runs a command and returns a utils.CheckResult. + + Signature matching is done against raw (untruncated) stdout/stderr to avoid + false negatives, while stdout/stderr stored in the result are truncated to + bounded size for logging. + """ + env = None + if env_overrides: + env = os.environ.copy() + for k, v in env_overrides.items(): + if k is None or str(k) == "": + return utils.CheckResult.failure( + f"invalid env var name in overrides: {k!r}{_cwd_suffix(cwd)}", + stdout="", + stderr="", + returncode=None, + timed_out=False, + cwd=cwd, + ) + env[str(k)] = str(v) + + cmd_display = _format_command(cmd, use_shell=use_shell) + cwd_note = _cwd_suffix(cwd) + + cmd_run: str | Sequence[str] + if use_shell and not isinstance(cmd, str): + cmd_run = _format_command(cmd, use_shell=True) + else: + cmd_run = cmd + + max_chars = utils.DEFAULT_MAX_CAPTURE_CHARS + suffix = "..." + + def _append_bounded(buf: list[str], cur_len: int, text: str) -> tuple[int, bool]: + """Append up to max_chars, return (new_len, overflowed).""" + if cur_len >= max_chars: + return cur_len, True + remaining = max_chars - cur_len + if len(text) <= remaining: + buf.append(text) + return cur_len + len(text), False + buf.append(text[:remaining]) + return max_chars, True + + sig = signature if (signature is not None and signature.strip()) else None + sig_found_stdout = (sig is None) + sig_found_stderr = (sig is None) + k = 0 if sig is None else max(len(sig) - 1, 0) + + stdout_tail = "" + stderr_tail = "" + stderr_head = "" + + encoding = locale.getpreferredencoding(False) or "utf-8" + stdout_dec = codecs.getincrementaldecoder(encoding)(errors="replace") + stderr_dec = codecs.getincrementaldecoder(encoding)(errors="replace") + + try: + proc = subprocess.Popen( + cmd_run, + cwd=cwd, + env=env, + stdin=subprocess.DEVNULL, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + shell=use_shell, + text=False, # bytes, decode incrementally + ) + except OSError as exc: + return utils.CheckResult.failure( + f"failed to run command: {cmd_display}{cwd_note}: {exc}", + stdout="", + stderr=str(exc), + returncode=None, + timed_out=False, + cwd=cwd, + ) + + assert proc.stdout is not None + assert proc.stderr is not None + + sel = selectors.DefaultSelector() + sel.register(proc.stdout, selectors.EVENT_READ, data="stdout") + sel.register(proc.stderr, selectors.EVENT_READ, data="stderr") + + stdout_parts: list[str] = [] + stderr_parts: list[str] = [] + stdout_len = 0 + stderr_len = 0 + stdout_overflow = False + stderr_overflow = False + + deadline = time.monotonic() + float(timeout_seconds) + + def _read_chunk(stream) -> bytes: + if hasattr(stream, "read1"): + return stream.read1(8192) + return stream.read(8192) + + timed_out = False + + while sel.get_map(): + remaining = deadline - time.monotonic() + if remaining <= 0: + timed_out = True + break + + for key, _mask in sel.select(timeout=min(0.25, remaining)): + stream = key.fileobj + chunk = _read_chunk(stream) + if not chunk: + try: + sel.unregister(stream) + except Exception: + pass + try: + stream.close() + except Exception: + pass + continue + + if key.data == "stdout": + text = stdout_dec.decode(chunk) + stdout_len, ov = _append_bounded(stdout_parts, stdout_len, text) + stdout_overflow = stdout_overflow or ov + if sig is not None and not sig_found_stdout: + hay = stdout_tail + text + if sig in hay: + sig_found_stdout = True + stdout_tail = hay[-k:] if k else "" + else: + text = stderr_dec.decode(chunk) + stderr_len, ov = _append_bounded(stderr_parts, stderr_len, text) + stderr_overflow = stderr_overflow or ov + if sig is not None and not sig_found_stderr: + hay = stderr_tail + text + if sig in hay: + sig_found_stderr = True + stderr_tail = hay[-k:] if k else "" + if sig is not None and k and len(stderr_head) < k: + need = k - len(stderr_head) + stderr_head += text[:need] + + if timed_out: + try: + proc.kill() + except Exception: + pass + try: + proc.wait(timeout=5.0) + except Exception: + pass + + stdout = "".join(stdout_parts) + (suffix if stdout_overflow else "") + stderr = "".join(stderr_parts) + (suffix if stderr_overflow else "") + + return utils.CheckResult.failure( + f"command timed out after {timeout_seconds}s: {cmd_display}{cwd_note}", + stdout=stdout, + stderr=stderr, + returncode=None, + timed_out=True, + cwd=cwd, + ) + + try: + proc.wait(timeout=5.0) + except Exception: + pass + + stdout = "".join(stdout_parts) + (suffix if stdout_overflow else "") + stderr = "".join(stderr_parts) + (suffix if stderr_overflow else "") + + if proc.returncode != 0: + return utils.CheckResult.failure( + f"command failed (rc = {proc.returncode}): {cmd_display}{cwd_note}", + stdout=stdout, + stderr=stderr, + returncode=proc.returncode, + timed_out=False, + cwd=cwd, + ) + + if sig is not None: + if not (sig_found_stdout or sig_found_stderr): + boundary = stdout_tail + "\n" + stderr_head + if sig not in boundary: + return utils.CheckResult.failure( + f"signature not found: {sig!r}: {cmd_display}{cwd_note}", + stdout=stdout, + stderr=stderr, + returncode=proc.returncode, + timed_out=False, + cwd=cwd, + ) + + return utils.CheckResult.success( + stdout=stdout, + stderr=stderr, + returncode=proc.returncode, + cwd=cwd, + ) + + +# ------------------------------------------------------------------------------ +# Oracle's core logic +# ------------------------------------------------------------------------------ + + +@dataclasses.dataclass(frozen=True, slots=True) +class BenchmarkContext: + """Context passed to benchmark preparation requirements. + + Attributes: + logger: Logger for diagnostics and shared policies. + """ + + logger: logging.Logger + + +@dataclasses.dataclass(frozen=True, slots=True, kw_only=True) +class FailRequirement(utils.BaseRequirement): + """A requirement that always fails with a fixed message. + + Attributes: + name: Human-readable requirement name for logs and reports. + optional: Whether failure should be treated as a warning instead of an error. + message: Failure message to report. + """ + + message: str + + def check(self, _ctx: BenchmarkContext) -> utils.CheckResult: + return utils.CheckResult.failure(self.message) + + +@dataclasses.dataclass(frozen=True, slots=True, kw_only=True) +class BenchmarkRequirement(utils.BaseRequirement): + """Validates an optional filesystem path and optionally runs a command. + + Attributes: + name: Human-readable requirement name for logs and reports. + optional: Whether failure should be treated as a warning instead of an error. + filepath: Optional path that must exist; also influences working directory. + cmd: Optional command to execute (argv tokens preferred; string only with shell). + signature: Optional substring that must appear in raw stdout or stderr. + timeout_seconds: Timeout for the command, in seconds. + env_overrides: Environment variables to override for the subprocess. + use_shell: Whether to execute the command through the shell. + """ + + filepath: pathlib.Path | str | os.PathLike[str] | None = None + cmd: _CommandT | None = None + signature: str | None = None + timeout_seconds: float = 5.0 + env_overrides: Mapping[str, str] = dataclasses.field(default_factory=dict) + use_shell: bool = False + + def __post_init__(self) -> None: + if not self.name: + raise ValueError("BenchmarkRequirement.name must be non-empty") + + if self.filepath is not None and not isinstance(self.filepath, + pathlib.Path): + object.__setattr__(self, "filepath", utils.to_path(self.filepath)) + + if isinstance(self.cmd, (list, tuple)): + if not self.cmd: + raise ValueError(f"{self.name}: cmd must be non-empty") + object.__setattr__(self, "cmd", tuple(self.cmd)) + elif isinstance(self.cmd, str): + if not self.cmd.strip(): + raise ValueError(f"{self.name}: cmd must be non-empty") + if not self.use_shell: + raise ValueError( + f"{self.name}: string cmd requires use_shell = True (prefer argv tokens)" + ) + elif self.cmd is None: + pass + else: + raise TypeError( + f"{self.name}: cmd must be a string or a sequence of args") + + if self.cmd is None and self.filepath is None: + raise ValueError( + f"{self.name}: must specify at least one of cmd or filepath") + + if self.timeout_seconds <= 0: + raise ValueError(f"{self.name}: timeout_seconds must be > 0") + + if self.signature is not None and not self.signature.strip(): + object.__setattr__(self, "signature", None) + + object.__setattr__(self, "env_overrides", + types.MappingProxyType(dict(self.env_overrides))) + + def check(self, _ctx: BenchmarkContext) -> utils.CheckResult: + cwd: pathlib.Path | None = None + if self.filepath is not None: + assert isinstance(self.filepath, pathlib.Path) + error = _missing_path_error(self.filepath) + if error is not None: + return utils.CheckResult.failure(error, cwd=None) + + cwd = self.filepath if self.filepath.is_dir() else self.filepath.parent + + # If no command is provided, treat this requirement as a pure path check + if self.cmd is None: + return utils.CheckResult.success(cwd=cwd) + + return _run_command( + cmd=self.cmd, + cwd=cwd, + timeout_seconds=self.timeout_seconds, + env_overrides=self.env_overrides, + use_shell=self.use_shell, + signature=self.signature, + ) + + +class OracleBenchmarkPrepBase(abc.ABC): + """Base class for a benchmark preparation oracle. + + Derived classes typically implement requirements() to declare preparation checks. + + Attributes: + _logger: Logger used for reporting and diagnostics. + """ + + _ORACLE_NAME = "BenchmarkPrep" + + def __init__(self, *, logger: logging.Logger) -> None: + self._logger = logger + + @abc.abstractmethod + def requirements(self) -> Sequence[utils.BaseRequirement]: + """Returns an ordered list of requirements to validate.""" + raise NotImplementedError + + def report(self) -> utils.OracleReport: + """Executes requirements and returns a structured report.""" + ctx = BenchmarkContext(logger=self._logger) + return utils.build_oracle_report( + logger=self._logger, + requirements_fn=self.requirements, + check_fn=lambda req: req.check(ctx), + ) + + def run(self, *, verbose: bool = False) -> bool: + """Returns True iff all required checks pass (logs results).""" + rep = self.report() + return utils.log_oracle_report(self._logger, + label=self._ORACLE_NAME, + report=rep, + verbose=verbose) diff --git a/benchmarks/arteval_bench/src/evaluator/oracle_env_setup_primitives.py b/benchmarks/arteval_bench/src/evaluator/oracle_env_setup_primitives.py new file mode 100644 index 00000000..07a2dc62 --- /dev/null +++ b/benchmarks/arteval_bench/src/evaluator/oracle_env_setup_primitives.py @@ -0,0 +1,364 @@ +"""Environment setup oracle primitives. + +This module provides: + 1. Requirement types to specify environment dependencies, variables, and + directory structure. + 2. An orchestrator base class that runs checks, logs results, and returns a + pass/fail outcome. + +Derived oracles typically only override requirements() to declare a list of +requirements to check, but they can customize behavior if needed. +""" + +from __future__ import annotations + +import abc +import dataclasses +import enum +import logging +import os +import re +import shutil +import subprocess + +from collections.abc import Sequence +import pathlib + +from evaluator import utils + + +# ------------------------------------------------------------------------------ +# Basic types and constants +# ------------------------------------------------------------------------------ + + +SemanticVersion = tuple[int, int, int] + +@enum.unique +class VersionCompare(enum.Enum): + """Comparison operator for validating a discovered version.""" + + EQ = "eq" + GEQ = "geq" + LEQ = "leq" + + +@enum.unique +class EnvQuantifier(enum.Enum): + """Matching mode for validating environment variable values.""" + + EXACT = "exact" + CONTAINS = "contains" + REGEX = "regex" + + +@enum.unique +class PathType(enum.Enum): + """Required filesystem object type for a path check.""" + + ANY = "any" + FILE = "file" + DIRECTORY = "directory" + + +# ------------------------------------------------------------------------------ +# Helper functions +# ------------------------------------------------------------------------------ + + +def _parse_semantic_version(text: str) -> SemanticVersion | None: + """Extract the first X.Y(.Z) token from text.""" + match = re.compile(r"(?:^|\s)v?(\d+)\.(\d+)(?:\.(\d+))?").search(text) + if not match: + return None + major = int(match.group(1)) + minor = int(match.group(2)) + patch = int(match.group(3)) if match.group(3) is not None else 0 + return (major, minor, patch) + + +def _format_version(v: SemanticVersion) -> str: + return f"{v[0]}.{v[1]}.{v[2]}" + + +def _normalize_path_entry(entry: str) -> str: + """Normalizes a PATH entry for comparison across platforms.""" + return os.path.normcase(os.path.normpath(entry.strip())) + + +def _split_path_list(value: str) -> list[str]: + return [e.strip() for e in value.split(os.pathsep) if e.strip()] + + +# ------------------------------------------------------------------------------ +# Oracle's core logic +# ------------------------------------------------------------------------------ + + +@dataclasses.dataclass(frozen=True, slots=True, kw_only=True) +class DependencyVersionRequirement(utils.BaseRequirement): + """Checks that an executable exists and satisfies a semantic version constraint. + + Attributes: + name: Human-readable requirement name for logs and reports. + optional: Whether failure should be treated as a warning instead of an error. + command: Command argv used to query a version (e.g., ["python", "--version"]). + required_version: Minimum/required semantic version as (major, minor, patch). + compare: Comparison operator to apply against required_version. + version_regex: Optional regex with a capturing group for the version token. + timeout_seconds: Timeout for the version command, in seconds. + """ + + cmd: Sequence[str] + required_version: SemanticVersion + compare: VersionCompare = VersionCompare.GEQ + version_regex: str | None = None + timeout_seconds: float = 5.0 + + _version_pattern: re.Pattern[str] | None = dataclasses.field(init=False, + repr=False, + default=None) + + def __post_init__(self) -> None: + if not self.cmd: + raise ValueError(f"{self.name}: command must be non-empty") + if self.timeout_seconds <= 0: + raise ValueError(f"{self.name}: timeout_seconds must be > 0") + object.__setattr__(self, "command", tuple(self.cmd)) + + if self.version_regex is not None: + pattern = re.compile(self.version_regex, flags=re.IGNORECASE) + if pattern.groups < 1: + raise ValueError( + f"{self.name}: version_regex must contain a capturing group") + object.__setattr__(self, "_version_pattern", pattern) + + def check(self) -> utils.CheckResult: + executable = self.cmd[0] + resolved = shutil.which(executable) + if resolved is None: + return utils.CheckResult.failure(f"not found on PATH: {executable!r}") + + try: + proc = subprocess.run( + (resolved, *self.cmd[1:]), + capture_output=True, + text=True, + check=False, + timeout=self.timeout_seconds, + ) + stdout = utils.decode_text(proc.stdout) + stderr = utils.decode_text(proc.stderr) + except subprocess.TimeoutExpired as exc: + stdout = utils.decode_text(exc.stdout) + stderr = utils.decode_text(exc.stderr) + return utils.CheckResult.failure( + f"version command timed out after {self.timeout_seconds}s", + stdout=stdout, + stderr=stderr, + returncode=None, + timed_out=True, + cwd=None, + ) + except OSError as exc: + return utils.CheckResult.failure( + f"failed to run {executable!r}: {exc}", + stdout="", + stderr=str(exc), + returncode=None, + timed_out=False, + cwd=None, + ) + + combined = (stdout + "\n" + stderr).strip() + + if proc.returncode != 0: + detail = combined if combined else f"rc = {proc.returncode}" + return utils.CheckResult.failure( + f"version command failed: {detail}", + stdout=stdout, + stderr=stderr, + returncode=proc.returncode, + timed_out=False, + cwd=None, + ) + + candidate = combined + if self._version_pattern is not None: + re_match = self._version_pattern.search(candidate) + if not re_match: + return utils.CheckResult.failure( + "version_regex did not match output", + stdout=stdout, + stderr=stderr, + returncode=proc.returncode, + ) + candidate = re_match.group(1) + + found = _parse_semantic_version(candidate) + if found is None: + return utils.CheckResult.failure( + "could not parse version from output", + stdout=stdout, + stderr=stderr, + returncode=proc.returncode, + ) + + if self.compare == VersionCompare.EQ: + ok = found == self.required_version + op = "==" + elif self.compare == VersionCompare.GEQ: + ok = found >= self.required_version + op = ">=" + else: + ok = found <= self.required_version + op = "<=" + + if not ok: + return utils.CheckResult.failure( + f"version {_format_version(found)} does not satisfy " + f"{op} {_format_version(self.required_version)}", + stdout=stdout, + stderr=stderr, + returncode=proc.returncode, + ) + return utils.CheckResult.success( + stdout=stdout, + stderr=stderr, + returncode=proc.returncode, + cwd=None, + ) + + +@dataclasses.dataclass(frozen=True, slots=True, kw_only=True) +class EnvironmentVariableRequirement(utils.BaseRequirement): + """Validates an environment variable using exact/contains/regex semantics. + + Attributes: + name: Human-readable requirement name for logs and reports. + optional: Whether failure should be treated as a warning instead of an error. + env_var: Environment variable name to check. + expected: Expected value or expected entry/pattern (depending on quantifier). + quantifier: Matching mode to apply when comparing actual vs expected. + """ + + env_var: str + expected: str + quantifier: EnvQuantifier = EnvQuantifier.EXACT + + _expected_pattern: re.Pattern[str] | None = dataclasses.field(init=False, + repr=False, + default=None) + + def __post_init__(self) -> None: + if not self.env_var: + raise ValueError(f"{self.name}: env_var must be non-empty") + if self.quantifier in (EnvQuantifier.CONTAINS, EnvQuantifier.REGEX): + if not self.expected: + raise ValueError(f"{self.name}: expected must be non-empty") + if self.quantifier == EnvQuantifier.REGEX: + object.__setattr__(self, "_expected_pattern", re.compile(self.expected)) + + def check(self) -> utils.CheckResult: + actual = os.environ.get(self.env_var) + if actual is None: + return utils.CheckResult.failure("not set") + + if self.quantifier == EnvQuantifier.EXACT: + if actual == self.expected: + return utils.CheckResult.success() + return utils.CheckResult.failure( + f"expected {self.expected!r}, got {actual!r}") + + entries = _split_path_list(actual) + + if self.quantifier == EnvQuantifier.CONTAINS: + want = _normalize_path_entry(self.expected) + normalized = [_normalize_path_entry(e) for e in entries] + if want in normalized: + return utils.CheckResult.success() + return utils.CheckResult.failure(f"missing entry {self.expected!r}") + + # EnvQuantifier.REGEX + assert self._expected_pattern is not None + if any(self._expected_pattern.search(e) for e in entries): + return utils.CheckResult.success() + return utils.CheckResult.failure( + f"no entry matches regex {self.expected!r}") + + +@dataclasses.dataclass(frozen=True, slots=True, kw_only=True) +class FilesystemPathRequirement(utils.BaseRequirement): + """Checks whether a filesystem path exists and optionally enforces its type. + + Attributes: + name: Human-readable requirement name for logs and reports. + optional: Whether failure should be treated as a warning instead of an error. + path: Path to validate. + path_type: Required type (any/file/directory). + """ + + path: pathlib.Path | str | os.PathLike[str] + path_type: PathType = PathType.ANY + + def __post_init__(self) -> None: + object.__setattr__(self, "path", utils.to_path(self.path)) + if str(self.path).strip() == "": + raise ValueError(f"{self.name}: path must be non-empty") + + def check(self) -> utils.CheckResult: + if not self.path.exists(): + if self.path_type == PathType.FILE: + return utils.CheckResult.failure(f"file missing: {self.path}") + if self.path_type == PathType.DIRECTORY: + return utils.CheckResult.failure(f"directory missing: {self.path}") + return utils.CheckResult.failure(f"path missing: {self.path}") + + if self.path_type == PathType.ANY: + return utils.CheckResult.success() + + if self.path_type == PathType.FILE: + if self.path.is_file(): + return utils.CheckResult.success() + return utils.CheckResult.failure(f"expected file: {self.path}") + + # PathType.DIRECTORY + if self.path.is_dir(): + return utils.CheckResult.success() + return utils.CheckResult.failure(f"expected directory: {self.path}") + + +class OracleEnvSetupBase(abc.ABC): + """Base class for an environment setup oracle. + + Derived classes typically implement requirements() to declare what to check. + + Attributes: + _logger: Logger used for reporting and diagnostics. + """ + + _ORACLE_NAME = "EnvironmentSetup" + + def __init__(self, logger: logging.Logger) -> None: + self._logger = logger + + @abc.abstractmethod + def requirements(self) -> Sequence[utils.BaseRequirement]: + """Returns an ordered list of requirements to validate.""" + raise NotImplementedError + + def report(self) -> utils.OracleReport: + """Executes requirements and returns a structured report.""" + return utils.build_oracle_report( + logger=self._logger, + requirements_fn=self.requirements, + check_fn=lambda req: req.check(), + ) + + def run(self, *, verbose: bool = False) -> bool: + """Returns True iff all required checks pass (logs results).""" + rep = self.report() + return utils.log_oracle_report(self._logger, + label=self._ORACLE_NAME, + report=rep, + verbose=verbose) diff --git a/benchmarks/arteval_bench/src/evaluator/oracle_experiment_runs_primitives.py b/benchmarks/arteval_bench/src/evaluator/oracle_experiment_runs_primitives.py new file mode 100644 index 00000000..16b49534 --- /dev/null +++ b/benchmarks/arteval_bench/src/evaluator/oracle_experiment_runs_primitives.py @@ -0,0 +1,631 @@ +"""Experiment runs oracle primitives. + +This module provides: + 1. List-level similarity metrics (Jaccard, dot product, cosine, Pearson, + min-max). + 2. Elementwise comparison utilities (equality, similarity scoring, threshold + checks). + 3. Requirement types that adapt these comparisons into utils.CheckResult objects. + 4. An orchestrator base class that runs checks, logs results, and returns a + pass/fail outcome. + +Derived oracles typically only override requirements() to declare a list of +numeric comparison requirements (similarity or elementwise checks) to evaluate, but +they can customize metrics, thresholds, or comparison behavior if needed. +""" + +from __future__ import annotations + +import abc +import dataclasses +import enum +import math +import typing + +from collections.abc import Callable, Sequence + +from evaluator import utils + + +# --------------------------------------------------------------------------- +# Basic types and constants +# --------------------------------------------------------------------------- + + +_CmpT = typing.TypeVar("_CmpT") + +@dataclasses.dataclass(frozen=True, slots=True, kw_only=True) +class Compared(typing.Generic[_CmpT]): + """A single observed-vs-reference comparison record. + + Attributes: + observed: Value produced by the experiment run. + reference: Expected/ground-truth value. + result: Comparison result (e.g., bool or float score). + """ + + observed: float + reference: float + result: _CmpT + + +# --------------------------------------------------------------------------- +# Helper functions +# --------------------------------------------------------------------------- + + +def _is_nan(x: float) -> bool: + return math.isnan(x) + + +def _require_equal_lengths( + left: Sequence[float], + right: Sequence[float], + *, + label: str, +) -> None: + if len(left) != len(right): + raise ValueError( + f"{label}: length mismatch: left has {len(left)}, right has {len(right)}" + ) + + +def _require_all_finite(values: Sequence[float], *, label: str) -> None: + for i, v in enumerate(values): + if not math.isfinite(v): + raise ValueError(f"{label}: non-finite value at index {i}: {v!r}") + + +def _jaccard_set_similarity(left: Sequence[float], + right: Sequence[float]) -> float: + """Jaccard similarity treating inputs as sets (order/duplicates ignored).""" + + def _normalize(x: float) -> object: + if _is_nan(x): + return ("nan",) + return x + + left_norm = [_normalize(x) for x in left] + right_norm = [_normalize(x) for x in right] + + a = set(left_norm) + b = set(right_norm) + + if len(a) != len(left_norm): + raise ValueError("jaccard_set_similarity: left input contains duplicates (multiset not allowed)") + if len(b) != len(right_norm): + raise ValueError("jaccard_set_similarity: right input contains duplicates (multiset not allowed)") + + union = a | b + if not union: + return 1.0 + return len(a & b) / len(union) + + +def _dot_product(left: Sequence[float], right: Sequence[float]) -> float: + """Dot product (unbounded). Requires equal lengths and finite inputs.""" + _require_equal_lengths(left, right, label="dot_product") + _require_all_finite(left, label="dot_product.left") + _require_all_finite(right, label="dot_product.right") + return sum(a * b for a, b in zip(left, right, strict=True)) + + +def _cosine_similarity(left: Sequence[float], right: Sequence[float]) -> float: + """Cosine similarity in [-1, 1]. Requires equal lengths and finite inputs. + + Policy for zero vectors: + - If both have zero norm, returns 1.0 (identical "no-signal"). + - If exactly one has zero norm, returns 0.0. + """ + _require_equal_lengths(left, right, label="cosine_similarity") + _require_all_finite(left, label="cosine_similarity.left") + _require_all_finite(right, label="cosine_similarity.right") + + dot = 0.0 + norm_left = 0.0 + norm_right = 0.0 + for a, b in zip(left, right, strict=True): + dot += a * b + norm_left += a * a + norm_right += b * b + + if norm_left == 0.0 and norm_right == 0.0: + return 1.0 + if norm_left == 0.0 or norm_right == 0.0: + return 0.0 + return dot / (math.sqrt(norm_left) * math.sqrt(norm_right)) + + +def _pearson_similarity(left: Sequence[float], right: Sequence[float]) -> float: + """Pearson correlation coefficient in [-1, 1]. + + Requires: + - equal lengths + - at least 2 samples + - finite inputs + + Policy for zero variance: + - If both are constant and identical, returns 1.0. + - If either has zero variance but they differ, returns 0.0. + """ + _require_equal_lengths(left, right, label="pearson_similarity") + if len(left) < 2: + raise ValueError( + f"pearson_similarity: need at least 2 samples, got {len(left)}") + _require_all_finite(left, label="pearson_similarity.left") + _require_all_finite(right, label="pearson_similarity.right") + + n = float(len(left)) + mean_left = sum(left) / n + mean_right = sum(right) / n + + cov = 0.0 + var_left = 0.0 + var_right = 0.0 + for a, b in zip(left, right, strict=True): + da = a - mean_left + db = b - mean_right + cov += da * db + var_left += da * da + var_right += db * db + + if var_left == 0.0 and var_right == 0.0: + return 1.0 if list(left) == list(right) else 0.0 + if var_left == 0.0 or var_right == 0.0: + return 0.0 + + return cov / (math.sqrt(var_left) * math.sqrt(var_right)) + + +def _min_max_similarity(left: Sequence[float], right: Sequence[float]) -> float: + """Min-max similarity in [0, 1] for nonnegative vectors. + + minmax(x, y) = sum_i min(x_i, y_i) / sum_i max(x_i, y_i) + + Requires: + - equal lengths + - finite inputs + - nonnegative inputs + + Policy for all-zeros: + - If denominator is 0.0, returns 1.0 (identical "no-signal"). + """ + _require_equal_lengths(left, right, label="min_max_similarity") + _require_all_finite(left, label="min_max_similarity.left") + _require_all_finite(right, label="min_max_similarity.right") + + num = 0.0 + den = 0.0 + for i, (a, b) in enumerate(zip(left, right, strict=True)): + if a < 0.0 or b < 0.0: + raise ValueError( + f"min_max_similarity: negative value at index {i}: left={a!r}, right={b!r}" + ) + num += min(a, b) + den += max(a, b) + + if den == 0.0: + return 1.0 + return num / den + + +def _numbers_equal(a: float, b: float, *, nan_equal: bool) -> bool: + if nan_equal and _is_nan(a) and _is_nan(b): + return True + return a == b + + +def _default_numeric_similarity(a: float, b: float, *, + abs_epsilon: float) -> float: + """Similarity score where 1.0 means identical; decreases with relative error. + + score = 1 - |a-b| / max(|a|, |b|, abs_epsilon) + + Special cases: + - NaN vs NaN => 1.0, NaN vs non-NaN => 0.0 + - +inf vs +inf or -inf vs -inf => 1.0, otherwise 0.0 + """ + if _is_nan(a) or _is_nan(b): + return 1.0 if (_is_nan(a) and _is_nan(b)) else 0.0 + + if math.isinf(a) or math.isinf(b): + return 1.0 if a == b else 0.0 + + denom = max(abs(a), abs(b), abs_epsilon) + score = 1.0 - (abs(a - b) / denom) + + if score < 0.0: + return 0.0 + if score > 1.0: + return 1.0 + return score + + + +def _elementwise_similarity_scores( + observed: Sequence[float], + reference: Sequence[float], + *, + similarity_fn: Callable[[float, float], float] | None, + abs_epsilon: float, +) -> list[Compared[float]]: + _require_equal_lengths(observed, + reference, + label="elementwise_similarity_scores") + if abs_epsilon <= 0: + raise ValueError(f"elementwise_similarity_scores: abs_epsilon must be > 0") + + if similarity_fn is None: + + def similarity_fn(a: float, b: float) -> float: + return _default_numeric_similarity(a, b, abs_epsilon=abs_epsilon) + + out: list[Compared[float]] = [] + for a, b in zip(observed, reference, strict=True): + out.append(Compared(observed=a, reference=b, result=similarity_fn(a, b))) + return out + + +def _elementwise_equal( + observed: Sequence[float], + reference: Sequence[float], + *, + nan_equal: bool, +) -> list[Compared[bool]]: + _require_equal_lengths(observed, reference, label="elementwise_equal") + out: list[Compared[bool]] = [] + for a, b in zip(observed, reference, strict=True): + out.append( + Compared(observed=a, + reference=b, + result=_numbers_equal(a, b, nan_equal=nan_equal))) + return out + + +def _summarize_mismatches_bool( + comparisons: Sequence[Compared[bool]], + *, + max_items: int = 10, +) -> str: + mismatches: list[str] = [] + total_bad = 0 + for i, c in enumerate(comparisons): + if not c.result: + total_bad += 1 + if len(mismatches) < max_items: + mismatches.append( + f"[{i}] observed={c.observed!r}, reference={c.reference!r}") + if not mismatches: + return "" + more = total_bad - len(mismatches) + suffix = f"\n... ({more} more)" if more > 0 else "" + return "mismatches:\n" + "\n".join(mismatches) + suffix + + +def _summarize_mismatches_threshold( + scores: Sequence[Compared[float]], + *, + threshold: float, + max_items: int = 10, +) -> str: + mismatches: list[str] = [] + total_bad = 0 + for i, c in enumerate(scores): + if c.result < threshold: + total_bad += 1 + if len(mismatches) < max_items: + mismatches.append( + f"[{i}] score={c.result:.6f} observed={c.observed!r}, reference={c.reference!r}" + ) + if not mismatches: + return "" + more = total_bad - len(mismatches) + suffix = f"\n... ({more} more)" if more > 0 else "" + return "mismatches:\n" + "\n".join(mismatches) + suffix + + +# --------------------------------------------------------------------------- +# Oracle's core logic +# --------------------------------------------------------------------------- + + +@enum.unique +class SimilarityMetric(enum.Enum): + """List-level metric identifier for computing a single similarity score.""" + + JACCARD_SET = "jaccard_set" + DOT_PRODUCT = "dot_product" + COSINE = "cosine" + PEARSON = "pearson" + MIN_MAX = "min_max" + + +class SimilarityMetrics: + """Namespace for list-level similarity metric implementations.""" + + @staticmethod + def compute( + metric: SimilarityMetric, + left: Sequence[float], + right: Sequence[float], + ) -> float: + if metric == SimilarityMetric.JACCARD_SET: + return _jaccard_set_similarity(left, right) + if metric == SimilarityMetric.DOT_PRODUCT: + return _dot_product(left, right) + if metric == SimilarityMetric.COSINE: + return _cosine_similarity(left, right) + if metric == SimilarityMetric.PEARSON: + return _pearson_similarity(left, right) + if metric == SimilarityMetric.MIN_MAX: + return _min_max_similarity(left, right) + raise ValueError(f"unsupported similarity metric: {metric!r}") + + +class ElementwiseMetrics: + """Namespace for elementwise comparison implementations.""" + + @staticmethod + def equal( + observed: Sequence[float], + reference: Sequence[float], + *, + nan_equal: bool = True, + ) -> list[Compared[bool]]: + return _elementwise_equal(observed, reference, nan_equal=nan_equal) + + @staticmethod + def similarity_scores( + observed: Sequence[float], + reference: Sequence[float], + *, + similarity_fn: Callable[[float, float], float] | None = None, + abs_epsilon: float = 1e-12, + ) -> list[Compared[float]]: + return _elementwise_similarity_scores( + observed, + reference, + similarity_fn=similarity_fn, + abs_epsilon=abs_epsilon, + ) + + @staticmethod + def similarity_threshold( + observed: Sequence[float], + reference: Sequence[float], + *, + threshold: float, + similarity_fn: Callable[[float, float], float] | None = None, + abs_epsilon: float = 1e-12, + ) -> list[Compared[bool]]: + scores = ElementwiseMetrics.similarity_scores( + observed, + reference, + similarity_fn=similarity_fn, + abs_epsilon=abs_epsilon, + ) + if not math.isfinite(threshold): + raise ValueError( + f"similarity_threshold: threshold must be finite, got {threshold!r}") + + out: list[Compared[bool]] = [] + for s in scores: + out.append( + Compared(observed=s.observed, + reference=s.reference, + result=(s.result >= threshold))) + return out + + +@dataclasses.dataclass(...) +class ExperimentRunsContext: + """Context passed to experiment-run requirements. + + Attributes: + logger: Logger for diagnostics and shared policies. + """ + + logger: object + + +@dataclasses.dataclass(frozen=True, slots=True, kw_only=True) +class ListSimilarityRequirement(utils.BaseRequirement): + """Checks a list-level similarity metric against a minimum score. + + Attributes: + name: Human-readable requirement name for logs and reports. + optional: Whether failure should be treated as a warning instead of an error. + observed: Observed numeric sequence. + reference: Reference numeric sequence. + metric: Similarity metric to compute. + min_similarity: Minimum acceptable similarity score. + """ + + observed: Sequence[float] + reference: Sequence[float] + metric: SimilarityMetric = SimilarityMetric.JACCARD_SET + min_similarity: float = 1.0 + + def __post_init__(self) -> None: + if not math.isfinite(self.min_similarity): + raise ValueError(f"{self.name}: min_similarity must be finite") + object.__setattr__(self, "observed", tuple(self.observed)) + object.__setattr__(self, "reference", tuple(self.reference)) + + def check(self, ctx: ExperimentRunsContext) -> utils.CheckResult: + del ctx # Reserved for shared policies/logging + try: + score = SimilarityMetrics.compute(self.metric, self.observed, + self.reference) + except ValueError as exc: + return utils.CheckResult.failure(str(exc)) + + if score < self.min_similarity: + return utils.CheckResult.failure( + f"{self.metric.value} similarity {score:.6f} < min_similarity {self.min_similarity:.6f}" + ) + return utils.CheckResult.success() + + +@dataclasses.dataclass(frozen=True, slots=True, kw_only=True) +class ElementwiseEqualityRequirement(utils.BaseRequirement): + """Checks elementwise equality for all entries. + + Attributes: + name: Human-readable requirement name for logs and reports. + optional: Whether failure should be treated as a warning instead of an error. + observed: Observed numeric sequence. + reference: Reference numeric sequence. + nan_equal: Whether NaN should be considered equal to NaN. + max_mismatches_to_report: Maximum mismatches to include in the failure message. + """ + + observed: Sequence[float] + reference: Sequence[float] + nan_equal: bool = True + max_mismatches_to_report: int = 10 + + def __post_init__(self) -> None: + if self.max_mismatches_to_report <= 0: + raise ValueError(f"{self.name}: max_mismatches_to_report must be > 0") + object.__setattr__(self, "observed", tuple(self.observed)) + object.__setattr__(self, "reference", tuple(self.reference)) + + def check(self, ctx: ExperimentRunsContext) -> utils.CheckResult: + del ctx + try: + comps = ElementwiseMetrics.equal(self.observed, + self.reference, + nan_equal=self.nan_equal) + except ValueError as exc: + return utils.CheckResult.failure(str(exc)) + + if all(c.result for c in comps): + return utils.CheckResult.success() + + detail = _summarize_mismatches_bool(comps, + max_items=self.max_mismatches_to_report) + msg = "elementwise equality check failed" + if detail: + msg = f"{msg}\n{detail}" + return utils.CheckResult.failure(msg) + + +@dataclasses.dataclass(frozen=True, slots=True, kw_only=True) +class ElementwiseSimilarityThresholdRequirement(utils.BaseRequirement): + """Checks elementwise similarity scores against a threshold for all entries. + + Attributes: + name: Human-readable requirement name for logs and reports. + optional: Whether failure should be treated as a warning instead of an error. + observed: Observed numeric sequence. + reference: Reference numeric sequence. + threshold: Minimum acceptable similarity score for each element. + abs_epsilon: Absolute epsilon used by the default similarity function. + max_mismatches_to_report: Maximum mismatches to include in the failure message. + """ + + observed: Sequence[float] + reference: Sequence[float] + threshold: float + abs_epsilon: float = 1e-12 + max_mismatches_to_report: int = 10 + + def __post_init__(self) -> None: + if not math.isfinite(self.threshold): + raise ValueError(f"{self.name}: threshold must be finite") + if self.abs_epsilon <= 0: + raise ValueError(f"{self.name}: abs_epsilon must be > 0") + if self.max_mismatches_to_report <= 0: + raise ValueError(f"{self.name}: max_mismatches_to_report must be > 0") + object.__setattr__(self, "observed", tuple(self.observed)) + object.__setattr__(self, "reference", tuple(self.reference)) + + def check(self, ctx: ExperimentRunsContext) -> utils.CheckResult: + del ctx + try: + scores = ElementwiseMetrics.similarity_scores( + self.observed, + self.reference, + abs_epsilon=self.abs_epsilon, + ) + except ValueError as exc: + return utils.CheckResult.failure(str(exc)) + + if all(s.result >= self.threshold for s in scores): + return utils.CheckResult.success() + + detail = _summarize_mismatches_threshold( + scores, + threshold=self.threshold, + max_items=self.max_mismatches_to_report, + ) + msg = f"elementwise similarity below threshold {self.threshold:.6f}" + if detail: + msg = f"{msg}\n{detail}" + return utils.CheckResult.failure(msg) + + +class OracleExperimentRunsBase(abc.ABC): + """Base class for an experiment-runs oracle. + + Derived classes typically implement requirements() to declare experiment checks. + + Attributes: + _logger: Logger used for reporting and diagnostics. + """ + + _ORACLE_NAME = "ExperimentRuns" + + def __init__(self, *, logger: object) -> None: + self._logger = logger + + @staticmethod + def similarity( + metric: SimilarityMetric, + left: Sequence[float], + right: Sequence[float], + ) -> float: + return SimilarityMetrics.compute(metric, left, right) + + @staticmethod + def elementwise_equal( + observed: Sequence[float], + reference: Sequence[float], + *, + nan_equal: bool = True, + ) -> list[Compared[bool]]: + return ElementwiseMetrics.equal(observed, reference, nan_equal=nan_equal) + + @staticmethod + def elementwise_similarity_scores( + observed: Sequence[float], + reference: Sequence[float], + *, + abs_epsilon: float = 1e-12, + ) -> list[Compared[float]]: + return ElementwiseMetrics.similarity_scores(observed, + reference, + abs_epsilon=abs_epsilon) + + @abc.abstractmethod + def requirements(self) -> Sequence[utils.BaseRequirement]: + """Returns an ordered list of requirements to validate.""" + raise NotImplementedError + + def report(self) -> utils.OracleReport: + """Executes requirements and returns a structured report.""" + ctx = ExperimentRunsContext(logger=self._logger) + return utils.build_oracle_report( + logger=self._logger, + requirements_fn=self.requirements, + check_fn=lambda req: req.check(ctx), + ) + + def run(self, *, verbose: bool = False) -> bool: + """Returns True iff all required checks pass (logs results).""" + rep = self.report() + return utils.log_oracle_report(self._logger, + label=self._ORACLE_NAME, + report=rep, + verbose=verbose) diff --git a/benchmarks/arteval_bench/src/evaluator/utils.py b/benchmarks/arteval_bench/src/evaluator/utils.py new file mode 100644 index 00000000..18afb499 --- /dev/null +++ b/benchmarks/arteval_bench/src/evaluator/utils.py @@ -0,0 +1,387 @@ +"""Shared types and helpers for oracle evaluation. + +Includes dataclasses for check outcomes and oracle reports, logger configuration, +and helper functions for building and logging oracle results. +""" + +from __future__ import annotations + +import dataclasses +import logging +import os +import pathlib +import typing +import sys + +from collections.abc import Callable, MutableMapping, Sequence + +# ------------------------------------------------------------------------------ +# Constants and definitions +# ------------------------------------------------------------------------------ + +_LOG_FORMAT = "%(asctime)s | %(levelname)s | %(name)s | %(message)s" +_DATE_FORMAT = "%Y-%m-%d %H:%M:%S" + +DEFAULT_MAX_TRUNCATED_MESSAGE_CHARS = 4096 +DEFAULT_MAX_CAPTURE_CHARS = 32768 + +Version = typing.Tuple[int, int, int] + + +# ------------------------------------------------------------------------------ +# Shared config helpers +# ---------------------------- + + +@dataclasses.dataclass(frozen=True) +class EntryConfig: + """Shared configuration contract across all evaluation bundles. + + Attributes: + name: Entry name used for reporting. + home_dir: Base directory for the entry. + repository_paths: Named repository root paths. + results_paths: Named result output paths. + ground_truth_paths: Named ground-truth paths. + similarity_ratio: Default similarity ratio threshold used by evaluators. + """ + + name: str + home_dir: pathlib.Path + + repository_paths: typing.Dict[str, pathlib.Path] = typing.field( + default_factory=dict) + results_paths: typing.Dict[str, + pathlib.Path] = typing.field(default_factory=dict) + ground_truth_paths: typing.Dict[str, pathlib.Path] = typing.field( + default_factory=dict) + + similarity_ratio: float = 0.75 + + +@dataclasses.dataclass(frozen=True, slots=True) +class CheckResult: + """Result of running a single check. + + Attributes: + ok: Whether the check passed. + message: Short human-readable summary (suitable for logs/UI). + stdout: Captured stdout, if applicable. + stderr: Captured stderr, if applicable. + returncode: Process return code, if applicable. + timed_out: True if a subprocess timed out. + cwd: Working directory used, if applicable. + """ + + ok: bool + message: str = "" + stdout: str = "" + stderr: str = "" + returncode: int | None = None + timed_out: bool = False + cwd: pathlib.Path | None = None + + @classmethod + def success( + cls, + *, + stdout: str = "", + stderr: str = "", + returncode: int | None = 0, + cwd: pathlib.Path | None = None, + ) -> "CheckResult": + return cls( + ok=True, + message="", + stdout=stdout, + stderr=stderr, + returncode=returncode, + timed_out=False, + cwd=cwd, + ) + + @classmethod + def failure( + cls, + message: str, + *, + stdout: str = "", + stderr: str = "", + returncode: int | None = None, + timed_out: bool = False, + cwd: pathlib.Path | None = None, + ) -> "CheckResult": + return cls( + ok=False, + message=message, + stdout=stdout, + stderr=stderr, + returncode=returncode, + timed_out=timed_out, + cwd=cwd, + ) + + +@dataclasses.dataclass(frozen=True, slots=True) +class RequirementOutcome: + """Outcome of running one requirement. + + Attributes: + name: Requirement name. + optional: Whether this requirement is optional. + result: Result of running the requirement check. + """ + + name: str + optional: bool + result: CheckResult + + +@dataclasses.dataclass(frozen=True, slots=True) +class OracleReport: + """Aggregated outcome of running multiple requirements. + + Attributes: + ok: True if all non-optional requirements passed. + errors: Outcomes for failed non-optional requirements. + warnings: Outcomes for failed optional requirements. + outcomes: Outcomes for all requirements, in execution order. + """ + + ok: bool + errors: tuple[RequirementOutcome, ...] = () + warnings: tuple[RequirementOutcome, ...] = () + outcomes: tuple[RequirementOutcome, ...] = () + + +@dataclasses.dataclass(frozen=True, slots=True, kw_only=True) +class BaseRequirement(abc.ABC): + """Abstract base class for a single benchmark preparation requirement. + + Attributes: + name: Human-readable requirement name for logs and reports. + optional: Whether failure should be treated as a warning instead of an error. + """ + + name: str + optional: bool = False + + @abc.abstractmethod + def check(self, ctx: BenchmarkContext) -> utils.CheckResult: + """Evaluates the requirement.""" + raise NotImplementedError + + +# ---------------------------- +# Logging helpers +# ---------------------------- + + +@dataclasses.dataclass(frozen=True, slots=True) +class LoggerConfig: + """Configuration for the bundle logger. + + Attributes: + root_name: Root logger name for the bundle. + logs_dir: Directory for log files (if configured/used). + console_level: Logging level for console output. + root_level: Logging level for the root logger. + """ + + root_name: str + logs_dir: pathlib.Path = pathlib.Path("logs") + console_level: int = logging.INFO + root_level: int = logging.DEBUG + + +def truncate_text(text: str, max_chars: int, *, suffix: str = "...") -> str: + """Truncates text to at most max_chars characters.""" + if len(text) <= max_chars: + return text + return text[:max_chars] + suffix + + +def log_result_details(logger: logging.Logger, result: CheckResult) -> None: + """Logs extra details for a CheckResult (used when verbose = True).""" + if result.cwd is not None: + logger.info(" cwd: %s", result.cwd) + if result.returncode is not None: + logger.info(" returncode: %s", result.returncode) + if result.timed_out: + logger.info(" timed_out: True") + + if result.stdout: + logger.info( + " stdout:\n%s", + truncate_text(result.stdout, DEFAULT_MAX_TRUNCATED_MESSAGE_CHARS)) + if result.stderr: + logger.info( + " stderr:\n%s", + truncate_text(result.stderr, DEFAULT_MAX_TRUNCATED_MESSAGE_CHARS)) + + +def _is_console_handler(h: logging.Handler) -> bool: + """Checks if a logging handler targets the standard console output.""" + return ( + isinstance(h, logging.StreamHandler) + and not isinstance(h, logging.FileHandler) + and getattr(h, "stream", None) in (sys.stdout, sys.stderr) + ) + + +def get_logger(config: LoggerConfig, + *, + component: str | None = None) -> logging.Logger: + """Returns a configured logger (optionally namespaced for a component).""" + config.logs_dir.mkdir(parents=True, exist_ok=True) + + root = logging.getLogger(config.root_name) + root.setLevel(config.root_level) + root.propagate = False # Avoid double logging via the root logger + + # Add handlers once + if not any(_is_console_handler(h) for h in root.handlers): + console_handler = logging.StreamHandler() + console_handler.setLevel(config.console_level) + console_handler.setFormatter( + logging.Formatter(_LOG_FORMAT, datefmt=_DATE_FORMAT)) + root.addHandler(console_handler) + + if component: + return root.getChild(component) + return root + + +# ---------------------------- +# Oracles report helpers +# ---------------------------- + + +class _RequirementLike(typing.Protocol): + """Structural type for objects treated as requirements by the oracle logic. + + Attributes: + name: Requirement name. + optional: Whether this requirement is optional. + """ + + name: str + optional: bool + + +ReqT = typing.TypeVar("ReqT", bound=_RequirementLike) + + +def build_oracle_report( + *, + logger: logging.Logger, + requirements_fn: Callable[[], Sequence[ReqT]], + check_fn: Callable[[ReqT], CheckResult], +) -> OracleReport: + """Executes requirements and returns a structured OracleReport.""" + errors: list[RequirementOutcome] = [] + warnings: list[RequirementOutcome] = [] + outcomes: list[RequirementOutcome] = [] + + try: + requirements = list(requirements_fn()) + except Exception: + logger.exception("Failed to build requirements") + outcome = RequirementOutcome( + name="requirements", + optional=False, + result=CheckResult.failure("failed to build requirements"), + ) + return OracleReport(ok=False, errors=(outcome,), outcomes=(outcome,)) + + for req in requirements: + try: + result = check_fn(req) + except Exception as exc: + logger.exception("Requirement raised: %s", req.name) + result = CheckResult.failure(f"exception during check: {exc}") + + outcome = RequirementOutcome(name=req.name, + optional=req.optional, + result=result) + outcomes.append(outcome) + + if result.ok: + continue + if req.optional: + warnings.append(outcome) + else: + errors.append(outcome) + + return OracleReport( + ok=not errors, + errors=tuple(errors), + warnings=tuple(warnings), + outcomes=tuple(outcomes), + ) + + +def log_oracle_report( + logger: logging.Logger, + *, + label: str, + report: OracleReport, + verbose: bool = False, +) -> bool: + """Logs a PASS/FAIL summary for an oracle report and returns report.ok.""" + if not report.ok: + logger.info("%s: FAIL", label) + for out in report.errors: + logger.error(" - %s: %s", out.name, out.result.message) + if verbose: + log_result_details(logger, out.result) + for out in report.warnings: + logger.warning(" - %s: %s", out.name, out.result.message) + if verbose: + log_result_details(logger, out.result) + return False + + if report.warnings: + logger.info("%s: PASS (with warnings)", label) + for out in report.warnings: + logger.warning(" - %s: %s", out.name, out.result.message) + if verbose: + log_result_details(logger, out.result) + else: + logger.info("%s: PASS", label) + + return True + + +def record_result( + results: MutableMapping[str, int], + name: str, + ok: bool, +) -> int: + """Records a pass/fail result and returns the numeric score contribution.""" + score = 1 if ok else 0 + results[name] = score + return score + + +# ---------------------------- +# Misc helpers +# ---------------------------- + + +def decode_text(value: object | None) -> str: + """Decpdes subprocess output typing.fields to text.""" + if value is None: + return "" + if isinstance(value, bytes): + return value.decode(errors="replace") + return str(value) + + +def to_path(value: object) -> pathlib.Path: + """Normalizes a path-like value to a Path object.""" + if isinstance(value, pathlib.Path): + return value + if isinstance(value, (str, os.PathLike)): + return pathlib.Path(value) + raise TypeError(f"Value cannot be interpreted as a path: {type(value)!r}") From ce1caacea00dcfef4baf3a653df8002defdc47b5 Mon Sep 17 00:00:00 2001 From: Bogdan-Alexandru Stoica Date: Sat, 31 Jan 2026 01:21:17 -0600 Subject: [PATCH 2/6] chore: some cleaning and restructuring --- benchmarks/arteval_bench/src/{ => core}/__init__.py | 0 .../arteval_bench/src/{ => core}/agents/claudecode/install.sh | 0 .../arteval_bench/src/{ => core}/agents/claudecode/runner.sh | 0 .../arteval_bench/src/{ => core}/agents/minisweagent/runner.sh | 0 .../arteval_bench/src/{ => core}/agents/openhand/config.toml | 0 .../arteval_bench/src/{ => core}/agents/openhand/install.sh | 0 benchmarks/arteval_bench/src/{ => core}/agents/openhand/runner.sh | 0 benchmarks/arteval_bench/src/{ => core}/config_aoi.yaml | 0 .../arteval_bench/src/{ => core}/config_aoi_anthropic_tools.yaml | 0 benchmarks/arteval_bench/src/{ => core}/main.py | 0 benchmarks/arteval_bench/src/{ => core}/main_patch.py | 0 benchmarks/arteval_bench/src/{ => core}/patch_evaluator.py | 0 benchmarks/arteval_bench/src/{ => core}/run_eval_in_env.py | 0 benchmarks/arteval_bench/src/{ => core}/run_eval_sweagent.py | 0 benchmarks/arteval_bench/src/{ => core}/utils.py | 0 15 files changed, 0 insertions(+), 0 deletions(-) rename benchmarks/arteval_bench/src/{ => core}/__init__.py (100%) rename benchmarks/arteval_bench/src/{ => core}/agents/claudecode/install.sh (100%) rename benchmarks/arteval_bench/src/{ => core}/agents/claudecode/runner.sh (100%) rename benchmarks/arteval_bench/src/{ => core}/agents/minisweagent/runner.sh (100%) rename benchmarks/arteval_bench/src/{ => core}/agents/openhand/config.toml (100%) rename benchmarks/arteval_bench/src/{ => core}/agents/openhand/install.sh (100%) rename benchmarks/arteval_bench/src/{ => core}/agents/openhand/runner.sh (100%) rename benchmarks/arteval_bench/src/{ => core}/config_aoi.yaml (100%) rename benchmarks/arteval_bench/src/{ => core}/config_aoi_anthropic_tools.yaml (100%) rename benchmarks/arteval_bench/src/{ => core}/main.py (100%) rename benchmarks/arteval_bench/src/{ => core}/main_patch.py (100%) rename benchmarks/arteval_bench/src/{ => core}/patch_evaluator.py (100%) rename benchmarks/arteval_bench/src/{ => core}/run_eval_in_env.py (100%) rename benchmarks/arteval_bench/src/{ => core}/run_eval_sweagent.py (100%) rename benchmarks/arteval_bench/src/{ => core}/utils.py (100%) diff --git a/benchmarks/arteval_bench/src/__init__.py b/benchmarks/arteval_bench/src/core/__init__.py similarity index 100% rename from benchmarks/arteval_bench/src/__init__.py rename to benchmarks/arteval_bench/src/core/__init__.py diff --git a/benchmarks/arteval_bench/src/agents/claudecode/install.sh b/benchmarks/arteval_bench/src/core/agents/claudecode/install.sh similarity index 100% rename from benchmarks/arteval_bench/src/agents/claudecode/install.sh rename to benchmarks/arteval_bench/src/core/agents/claudecode/install.sh diff --git a/benchmarks/arteval_bench/src/agents/claudecode/runner.sh b/benchmarks/arteval_bench/src/core/agents/claudecode/runner.sh similarity index 100% rename from benchmarks/arteval_bench/src/agents/claudecode/runner.sh rename to benchmarks/arteval_bench/src/core/agents/claudecode/runner.sh diff --git a/benchmarks/arteval_bench/src/agents/minisweagent/runner.sh b/benchmarks/arteval_bench/src/core/agents/minisweagent/runner.sh similarity index 100% rename from benchmarks/arteval_bench/src/agents/minisweagent/runner.sh rename to benchmarks/arteval_bench/src/core/agents/minisweagent/runner.sh diff --git a/benchmarks/arteval_bench/src/agents/openhand/config.toml b/benchmarks/arteval_bench/src/core/agents/openhand/config.toml similarity index 100% rename from benchmarks/arteval_bench/src/agents/openhand/config.toml rename to benchmarks/arteval_bench/src/core/agents/openhand/config.toml diff --git a/benchmarks/arteval_bench/src/agents/openhand/install.sh b/benchmarks/arteval_bench/src/core/agents/openhand/install.sh similarity index 100% rename from benchmarks/arteval_bench/src/agents/openhand/install.sh rename to benchmarks/arteval_bench/src/core/agents/openhand/install.sh diff --git a/benchmarks/arteval_bench/src/agents/openhand/runner.sh b/benchmarks/arteval_bench/src/core/agents/openhand/runner.sh similarity index 100% rename from benchmarks/arteval_bench/src/agents/openhand/runner.sh rename to benchmarks/arteval_bench/src/core/agents/openhand/runner.sh diff --git a/benchmarks/arteval_bench/src/config_aoi.yaml b/benchmarks/arteval_bench/src/core/config_aoi.yaml similarity index 100% rename from benchmarks/arteval_bench/src/config_aoi.yaml rename to benchmarks/arteval_bench/src/core/config_aoi.yaml diff --git a/benchmarks/arteval_bench/src/config_aoi_anthropic_tools.yaml b/benchmarks/arteval_bench/src/core/config_aoi_anthropic_tools.yaml similarity index 100% rename from benchmarks/arteval_bench/src/config_aoi_anthropic_tools.yaml rename to benchmarks/arteval_bench/src/core/config_aoi_anthropic_tools.yaml diff --git a/benchmarks/arteval_bench/src/main.py b/benchmarks/arteval_bench/src/core/main.py similarity index 100% rename from benchmarks/arteval_bench/src/main.py rename to benchmarks/arteval_bench/src/core/main.py diff --git a/benchmarks/arteval_bench/src/main_patch.py b/benchmarks/arteval_bench/src/core/main_patch.py similarity index 100% rename from benchmarks/arteval_bench/src/main_patch.py rename to benchmarks/arteval_bench/src/core/main_patch.py diff --git a/benchmarks/arteval_bench/src/patch_evaluator.py b/benchmarks/arteval_bench/src/core/patch_evaluator.py similarity index 100% rename from benchmarks/arteval_bench/src/patch_evaluator.py rename to benchmarks/arteval_bench/src/core/patch_evaluator.py diff --git a/benchmarks/arteval_bench/src/run_eval_in_env.py b/benchmarks/arteval_bench/src/core/run_eval_in_env.py similarity index 100% rename from benchmarks/arteval_bench/src/run_eval_in_env.py rename to benchmarks/arteval_bench/src/core/run_eval_in_env.py diff --git a/benchmarks/arteval_bench/src/run_eval_sweagent.py b/benchmarks/arteval_bench/src/core/run_eval_sweagent.py similarity index 100% rename from benchmarks/arteval_bench/src/run_eval_sweagent.py rename to benchmarks/arteval_bench/src/core/run_eval_sweagent.py diff --git a/benchmarks/arteval_bench/src/utils.py b/benchmarks/arteval_bench/src/core/utils.py similarity index 100% rename from benchmarks/arteval_bench/src/utils.py rename to benchmarks/arteval_bench/src/core/utils.py From ac5ebe2c0a04ad26aa77af4c6dcb2dfe08013257 Mon Sep 17 00:00:00 2001 From: Bogdan-Alexandru Stoica Date: Sat, 31 Jan 2026 01:22:15 -0600 Subject: [PATCH 3/6] refactor: adapt egwalker's oracles to use the standardized interface --- .../eurosys25_egwalker/_agent_eval/main.py | 96 +++++- .../_agent_eval/oracle_artifact_build.py | 191 +++++----- .../_agent_eval/oracle_benchmark_prep.py | 280 +++++++++------ .../_agent_eval/oracle_env_setup.py | 221 +++++------- .../_agent_eval/oracle_experiment_runs.py | 326 +++++++++++------- .../eurosys25_egwalker/_agent_eval/utils.py | 29 -- 6 files changed, 631 insertions(+), 512 deletions(-) delete mode 100644 benchmarks/arteval_bench/data/benchmark/eurosys25_egwalker/_agent_eval/utils.py diff --git a/benchmarks/arteval_bench/data/benchmark/eurosys25_egwalker/_agent_eval/main.py b/benchmarks/arteval_bench/data/benchmark/eurosys25_egwalker/_agent_eval/main.py index b67e39a2..98fe5535 100644 --- a/benchmarks/arteval_bench/data/benchmark/eurosys25_egwalker/_agent_eval/main.py +++ b/benchmarks/arteval_bench/data/benchmark/eurosys25_egwalker/_agent_eval/main.py @@ -1,32 +1,90 @@ #!/usr/bin/env python3 +"""Runs environment setup, build, benchmark prep, and experiment runs checks for EGWALKER.""" + +from __future__ import annotations + +import os import sys +from pathlib import Path from typing import Dict -# from oracle_artifact_build import OracleArtifactBuild +_AGENT_EVAL_DIR = Path(__file__).resolve().parent +_AGENT_SRC_DIR = _AGENT_EVAL_DIR.parents[3] / "src" +sys.path.append(str(_AGENT_SRC_DIR)) + +from oracle_artifact_build import OracleArtifactBuild +from oracle_benchmark_prep import OracleBenchmarkPrep from oracle_env_setup import OracleEnvSetup -# from oracle_benchmark_prep import OracleBenchmarkPrep -# from oracle_experiment_runs import OracleExperimentRuns +from oracle_experiment_runs import OracleExperimentRuns +from evaluator.utils import EntryConfig, LoggerConfig, get_logger, record_result -from utils import logger -def main(): - results: Dict[str, int] = {} +EGWALKER_CONFIG = EntryConfig( + name="eurosys25-egwalker", + home_dir=Path.home() / "eurosys25_egwalker", + repository_paths={ + "eurosys25-egwalker": Path.home() / "eurosys25_egwalker" / "egwalker", + }, + results_paths={ + # Matches legacy: /results/timings.json + "timings": Path.home() + / "eurosys25_egwalker" + / "egwalker" + / "results" + / "timings.json", + }, + ground_truth_paths={ + "datasets": ( + Path.home() + / "eurosys25_egwalker" + / "_agent_eval" + / "refs" + / "datasets.ref.json" + ), + "timings": ( + Path.home() + / "eurosys25_egwalker" + / "_agent_eval" + / "refs" + / "timings.ref.json" + ), + }, + similarity_ratio=0.75, +) + +def main(argv: list[str]) -> int: + results: Dict[str, int] = {} score = 0 - for cls in (OracleEnvSetup, OracleArtifactBuild, OracleBenchmarkPrep, OracleExperimentRuns): - checker = cls() - ok = checker.run() - name = cls.__name__ - logger.info(f"{name}: {'PASS' if ok else 'FAIL'}") - if ok: - results[name] = 1 - score += 1 - else: - results[name] = 0 - - logger.info(f"Agent scores: {results}") + + verbose = "--verbose" in argv + + logger_name = os.environ.get("EVAL_LOGGER_NAME", "EGWALKER-EVAL") + logger = get_logger(LoggerConfig(root_name=logger_name)) + + env_checker = OracleEnvSetup(config=EGWALKER_CONFIG, logger=logger) + score += record_result( + logger, results, type(env_checker).__name__, env_checker.run(verbose=verbose) + ) + + build_checker = OracleArtifactBuild(config=EGWALKER_CONFIG, logger=logger) + score += record_result( + logger, results, type(build_checker).__name__, build_checker.run(verbose=verbose) + ) + + prep_checker = OracleBenchmarkPrep(config=EGWALKER_CONFIG, logger=logger) + score += record_result( + logger, results, type(prep_checker).__name__, prep_checker.run(verbose=verbose) + ) + + runs_checker = OracleExperimentRuns(config=EGWALKER_CONFIG, logger=logger) + score += record_result( + logger, results, type(runs_checker).__name__, runs_checker.run(verbose=verbose) + ) + + logger.info("Agent scores: %s", results) return score if __name__ == "__main__": - main() \ No newline at end of file + raise SystemExit(main(sys.argv[1:])) diff --git a/benchmarks/arteval_bench/data/benchmark/eurosys25_egwalker/_agent_eval/oracle_artifact_build.py b/benchmarks/arteval_bench/data/benchmark/eurosys25_egwalker/_agent_eval/oracle_artifact_build.py index 462a2d92..c71db0d5 100644 --- a/benchmarks/arteval_bench/data/benchmark/eurosys25_egwalker/_agent_eval/oracle_artifact_build.py +++ b/benchmarks/arteval_bench/data/benchmark/eurosys25_egwalker/_agent_eval/oracle_artifact_build.py @@ -1,93 +1,114 @@ -#!/usr/bin/env python3 -import os -import subprocess -from dataclasses import dataclass -from typing import Iterable, List, Optional, Tuple +"""Artifact build oracle for the Eurosys'25 EGWALKER artifact. + +Validates: + - Required repository working directories exist. + - Build commands execute successfully (captures stdout/stderr/return code). +""" + +from __future__ import annotations + +from collections.abc import Mapping, Sequence +from dataclasses import dataclass, field +import logging from pathlib import Path -from utils import REPO_DIR -from utils import logger +from evaluator.oracle_artifact_build_primitives import ( + BuildCommandRequirement, + BuildRequirement, + OracleArtifactBuildBase, +) +from evaluator.utils import EntryConfig -@dataclass(frozen=True) +@dataclass(frozen = True, slots = True, kw_only = True) class BuildTarget: + """Declarative description of one build command to run.""" + name: str - repo_key: str - cmd: List[str] - - -BUILD_TARGETS: List[BuildTarget] = [ - BuildTarget( - name="artifact-core", - repo_key="artifact-core", - cmd=[ - "make", - "-j8", - "tools/diamond-types/target/release/dt", - "tools/crdt-converter/target/release/crdt-converter", - "tools/diamond-types/target/release/paper-stats", - "tools/paper-benchmarks/target/memusage/paper-benchmarks", - "tools/paper-benchmarks/target/release/paper-benchmarks", - "tools/ot-bench/target/memusage/ot-bench", - "tools/ot-bench/target/release/ot-bench" - ], - ), -] - - -class OracleArtifactBuild: - - def __init__(self) -> None: - self.repo_dir = REPO_DIR - - def run_shell_command( + command: Sequence[str] + cwd_relative: Path | None = None + optional: bool = False + timeout_seconds: float = 60.0 + env_overrides: Mapping[str, str] = field(default_factory = dict) + + def __post_init__(self) -> None: + if not self.name: + raise ValueError("BuildTarget.name must be non-empty") + if not self.command: + raise ValueError(f"{self.name}: command must be non-empty") + if self.timeout_seconds <= 0: + raise ValueError(f"{self.name}: timeout_seconds must be > 0") + + # Normalize for downstream requirements. + if self.cwd_relative is not None and not isinstance(self.cwd_relative, Path): + object.__setattr__(self, "cwd_relative", Path(self.cwd_relative)) + + # Freeze command to avoid accidental mutation. + object.__setattr__(self, "command", tuple(self.command)) + + +class OracleArtifactBuild(OracleArtifactBuildBase): + """The artifact build oracle for artifact-core. + + Defaults: + * Runs build commands in the repo keyed by config.name. + * EntryConfig.repository_paths must contain an entry for config.name. + """ + + _DEFAULT_TARGET_SPECS: tuple[tuple[str, tuple[str, ...], float], ...] = ( + ( + "artifact-core: make tools", + ( + "make", + "-j8", + "tools/diamond-types/target/release/dt", + "tools/crdt-converter/target/release/crdt-converter", + "tools/diamond-types/target/release/paper-stats", + "tools/paper-benchmarks/target/memusage/paper-benchmarks", + "tools/paper-benchmarks/target/release/paper-benchmarks", + "tools/ot-bench/target/memusage/ot-bench", + "tools/ot-bench/target/release/ot-bench", + ), + 60.0, + ), + ) + + def __init__( self, - cmd: Iterable[str], - cwd: Optional[Path] = None, - ) -> Tuple[int, str, str]: - """ - Run a command and return (rc, stdout, stderr) tuple. - """ - try: - cp = subprocess.run( - cmd, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - text=True, - cwd=str(cwd) if cwd is not None else None, + *, + config: EntryConfig, + logger: logging.Logger, + targets: Sequence[BuildTarget] | None = None, + ) -> None: + super().__init__(logger = logger) + self._config = config + + if targets is None: + targets = self._make_default_targets() + self._targets = tuple(targets) + + names = [t.name for t in self._targets] + if len(names) != len(set(names)): + raise ValueError(f"Duplicate build target names: {names!r}") + + def _make_default_targets(self) -> tuple[BuildTarget, ...]: + """Creates default targets (stored in the EntryConfig object).""" + return tuple( + BuildTarget(name = name, command = command, timeout_seconds = timeout_seconds) + for (name, command, timeout_seconds) in self._DEFAULT_TARGET_SPECS + ) + + def requirements(self) -> Sequence[BuildRequirement]: + """Returns an ordered list of build requirements to validate.""" + return tuple( + BuildCommandRequirement( + name = target.name, + optional = target.optional, + cwd = self._config.repository_paths[self._config.name], + command = target.command, + cwd_relative = target.cwd_relative, + timeout_seconds = target.timeout_seconds, + env_overrides = target.env_overrides, ) - return cp.returncode, cp.stdout or "", cp.stderr or "" - except FileNotFoundError: - return 127, "", "" - - def build_target(self, target: BuildTarget) -> Optional[str]: - """ - Build a single target using its configured repository and command. - """ - repo_path = Path(os.path.expanduser(self.repo_dir)) - if not repo_path.exists(): - return f"{target.name} repo directory missing" - - rc, out, err = self.run_shell_command(target.cmd, cwd=repo_path) - if rc != 0: - return f"{target.name} build failed (error code: {rc}; error message: {err})" - - return None - - def build_check(self): - """ - Run builds for all configured targets and collect failures. - """ - problems: List[str] = [] - for target in BUILD_TARGETS: - msg = self.build_target(target) - if msg: - problems.append(msg) - if problems: - return False, "; ".join(problems) - return True, "" - - def run(self): - ok, why = self.build_check() - logger.info(f"Build: {'PASS' if ok else 'FAIL' + (' - ' + why if why else '')}") - return ok \ No newline at end of file + for target in self._targets + ) \ No newline at end of file diff --git a/benchmarks/arteval_bench/data/benchmark/eurosys25_egwalker/_agent_eval/oracle_benchmark_prep.py b/benchmarks/arteval_bench/data/benchmark/eurosys25_egwalker/_agent_eval/oracle_benchmark_prep.py index 310f7abc..28f891e9 100644 --- a/benchmarks/arteval_bench/data/benchmark/eurosys25_egwalker/_agent_eval/oracle_benchmark_prep.py +++ b/benchmarks/arteval_bench/data/benchmark/eurosys25_egwalker/_agent_eval/oracle_benchmark_prep.py @@ -1,125 +1,177 @@ #!/usr/bin/env python3 -import json -import os -from dataclasses import dataclass -from pathlib import Path -from typing import Any, List, Optional, Tuple - -from utils import HOME -from utils import REPO_DIR -from utils import REFERENCE_BENCHMARK_FILE -from utils import logger - +"""Benchmark preparation oracle for _agent_eval bundles. -@dataclass(frozen=True) -class DatasetRef: - filepath: str - sizeinbytes: int +Validates: + - Dataset manifest JSON is readable and well-formed. + - Each referenced dataset file is within the repo root (no traversal). + - Each referenced dataset file exists and matches the expected size in bytes. +""" +from __future__ import annotations -class OracleBenchmarkPrep: - - def __init__(self) -> None: - self.home = Path(os.path.expanduser(str(HOME))) - self.repo_path = Path(os.path.expanduser(str(REPO_DIR))) - self.ref_json = Path(os.path.expanduser(str(REFERENCE_BENCHMARK_FILE))) +import json +import logging +import sys +from pathlib import Path +from typing import Mapping, Sequence + +from evaluator.utils import EntryConfig +from evaluator.oracle_benchmark_prep_primitives import ( + BenchmarkRequirement, + FailRequirement, + OracleBenchmarkPrepBase, + Requirement, +) + + +def _required_path(paths: Mapping[str, Path], key: str, *, label: str) -> Path: + """Returns a required path from a mapping with a clear error.""" + try: + return paths[key] + except KeyError as e: + raise ValueError(f"Missing {label}[{key!r}] in EntryConfig") from e + + +def _resolve_nonstrict(path: Path) -> Path: + """Resolves a path without requiring it to exist.""" + return path.resolve(strict = False) + + +def _is_within(root: Path, candidate: Path) -> bool: + """Returns True iff candidate is within root after resolution.""" + root_resolved = _resolve_nonstrict(root) + cand_resolved = _resolve_nonstrict(candidate) + return cand_resolved == root_resolved or root_resolved in cand_resolved.parents + + +class OracleBenchmarkPrep(OracleBenchmarkPrepBase): + """Validates dataset prerequisites for _agent_eval bundles.""" + + def __init__( + self, + *, + config: EntryConfig, + logger: logging.Logger, + manifest_key: str = "datasets", + ) -> None: + super().__init__(logger = logger) + self._config = config + self._manifest_key = manifest_key + + def requirements(self) -> Sequence[Requirement]: + repo_root = _required_path( + self._config.repository_paths, + self._config.name, + label = "repository_paths", + ) + manifest_path = _required_path( + self._config.ground_truth_paths, + self._manifest_key, + label = "ground_truth_paths", + ) + + reqs: list[Requirement] = [ + BenchmarkRequirement( + name = "repo_root_exists", + filepath = repo_root, + ), + BenchmarkRequirement( + name = "dataset_manifest_exists", + filepath = manifest_path, + ), + ] + + if not manifest_path.exists(): + return reqs - def load_json(self, path: Path) -> Tuple[Optional[Any], str]: - """ - Load JSON from disk and return (obj, err). - """ - if not path.exists(): - return None, f"ref json missing: {path}" try: - with path.open("r", encoding="utf-8") as f: - return json.load(f), "" - except Exception as e: - return None, f"ref json unreadable: {e}" - - def iter_ref_entries(self, obj: Any) -> List[dict]: - """ - Extract benchmark entries from a reference JSON. - """ - if isinstance(obj, list): - return [x for x in obj if isinstance(x, dict)] - if isinstance(obj, dict): - for v in obj.values(): - if isinstance(v, list) and v and all(isinstance(x, dict) for x in v): - return v - return [] - - def parse_entry(self, d: dict) -> Tuple[Optional[DatasetRef], str]: - """ - Parse a single JSON entry into DatasetRef. - """ - if "filepath" not in d: - return None, "missing filepath" - if "sizeinbytes" not in d: - return None, "missing sizeinbytes" - - fp = d.get("filepath", "") - sz = d.get("sizeinbytes", None) - - if not isinstance(fp, str) or not fp: - return None, "invalid filepath" - if not isinstance(sz, int) or sz < 0: - return None, "invalid sizeinbytes" - - return DatasetRef(filepath=fp, sizeinbytes=sz), "" - - def check_entry(self, ref: DatasetRef) -> Optional[str]: - """ - Validate that dataset files exist and matche the expected sizes (in bytes). - """ - rel = Path(ref.filepath) - - if rel.is_absolute(): - return f"{ref.filepath}: absolute paths not allowed" - - p = self.repo_path / rel - if not p.exists(): - return f"{ref.filepath}: missing" - if not p.is_file(): - return f"{ref.filepath}: not a file" + obj = json.loads(manifest_path.read_text(encoding = "utf-8")) + except (OSError, json.JSONDecodeError) as exc: + reqs.append( + FailRequirement( + name = "dataset_manifest_readable", + message = f"manifest unreadable: {exc}", + ) + ) + return reqs + + if not isinstance(obj, list): + reqs.append( + FailRequirement( + name = "dataset_manifest_format", + message = "manifest JSON must be a list of objects", + ) + ) + return reqs + + # Print a stable marker so signature matching is robust + # and portable across different platforms + size_script = ( + "import os, sys\n" + "p = sys.argv[1]\n" + "print(f'OK size = {os.path.getsize(p)}')\n" + ) + + for i, entry in enumerate(obj): + entry_name = f"entry[{i}]" + + if not isinstance(entry, dict): + reqs.append( + FailRequirement( + name = entry_name, + message = "entry must be an object", + ) + ) + continue - try: - actual = p.stat().st_size - except OSError as e: - return f"{ref.filepath}: stat failed ({e})" - - if actual != ref.sizeinbytes: - return f"{ref.filepath}: size mismatch (expected {ref.sizeinbytes}, got {actual})" - - return None - - def datasets_check(self) -> Tuple[bool, str]: - """ - Check all referenced dataset files are present and match expected sizes. - """ - obj, err = self.load_json(self.ref_json) - if err: - return False, err - - entries = self.iter_ref_entries(obj) - if not entries: - return False, "no entries found in ref json" - - problems: List[str] = [] - for d in entries: - ref, perr = self.parse_entry(d) - if perr or ref is None: - problems.append(perr or "invalid entry") + filepath = entry.get("filepath") + size = entry.get("sizeinbytes") + + if not isinstance(filepath, str) or not filepath.strip(): + reqs.append( + FailRequirement( + name = entry_name, + message = "missing/invalid filepath", + ) + ) + continue + if not isinstance(size, int) or size < 0: + reqs.append( + FailRequirement( + name = entry_name, + message = f"{filepath!r}: missing/invalid sizeinbytes", + ) + ) continue - msg = self.check_entry(ref) - if msg: - problems.append(msg) + rel = Path(filepath) + if rel.is_absolute(): + reqs.append( + FailRequirement( + name = f"dataset:{filepath}", + message = "absolute paths not allowed", + ) + ) + continue - if problems: - return False, "; ".join(problems) - return True, "" + full_path = repo_root / rel + if not _is_within(repo_root, full_path): + reqs.append( + FailRequirement( + name = f"dataset:{filepath}", + message = "path escapes repo root (.. traversal not allowed)", + ) + ) + continue - def run(self) -> bool: - ok, why = self.datasets_check() - logger.info(f"Datasets: {'PASS' if ok else 'FAIL' + (' - ' + why if why else '')}") - return ok \ No newline at end of file + reqs.append( + BenchmarkRequirement( + name = f"dataset:{filepath}", + filepath = full_path, + cmd = (sys.executable, "-c", size_script, str(full_path)), + signature = f"OK size = {size}", + timeout_seconds = 30.0, + ) + ) + + return reqs diff --git a/benchmarks/arteval_bench/data/benchmark/eurosys25_egwalker/_agent_eval/oracle_env_setup.py b/benchmarks/arteval_bench/data/benchmark/eurosys25_egwalker/_agent_eval/oracle_env_setup.py index 028d5f20..191b104c 100644 --- a/benchmarks/arteval_bench/data/benchmark/eurosys25_egwalker/_agent_eval/oracle_env_setup.py +++ b/benchmarks/arteval_bench/data/benchmark/eurosys25_egwalker/_agent_eval/oracle_env_setup.py @@ -1,139 +1,92 @@ -#!/usr/bin/env python3 -import subprocess -import re -from dataclasses import dataclass -from typing import Iterable, Optional, Tuple -from pathlib import Path - -from utils import logger - - -Version = Tuple[int, int, int] - - -@dataclass(frozen=True) -class ToolRequirement: - name: str - cmd: list[str] - min_version: Optional[Version] = None - optional: bool = False +"""Environment setup oracle for the Eurosys'25 EGWALKER bundle. +Validates: + - Required tools and minimum versions where applicable. + - Repository directory exists. + - Ground-truth reference files exist. +""" -MIN_RUST_VERSION: Version = (1, 78, 0) +from __future__ import annotations - -TOOL_REQUIREMENTS: list[ToolRequirement] = [ - ToolRequirement( - name="rustc", - cmd=["rustc", "--version"], - min_version=MIN_RUST_VERSION, - ), - ToolRequirement( - name="cargo", - cmd=["cargo", "--version"], - ), - ToolRequirement( - name="node", - cmd=["node", "--version"], - ), - ToolRequirement( - name="make", - cmd=["make", "--version"], - optional=True, - ), -] - - -class OracleEnvSetup: - - def run_shell_command( - self, - cmd: Iterable[str], - cwd: Optional[Path] = None, - ) -> Tuple[int, str, str]: - """ - Run a command and return (rc, stdout, stderr) tuple. - """ - try: - cp = subprocess.run( - cmd, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - text=True, - cwd=str(cwd) if cwd is not None else None, +from pathlib import Path +from typing import Mapping, Sequence + +from evaluator.utils import EntryConfig, logger +from evaluator.oracle_env_setup_primitives import ( + DependencyVersionRequirement, + FilesystemPathRequirement, + OracleEnvSetupBase, + PathType, + Requirement, + VersionCompare, +) + +_REPO_KEY = "egwalker" + + +def _required_path(paths: Mapping[str, Path], key: str, *, label: str) -> Path: + """Returns a required path from a mapping with a clear error.""" + try: + return paths[key] + except KeyError as e: + raise ValueError(f"Missing {label}[{key!r}] in EntryConfig") from e + + +class OracleEnvSetup(OracleEnvSetupBase): + """Validates environment prerequisites for EGWALKER.""" + + def __init__(self, *, config: EntryConfig, logger: logger) -> None: + super().__init__(logger) + self._config = config + + def requirements(self) -> Sequence[Requirement]: + repo_root = _required_path( + self._config.repository_paths, self._config.name, label="repository_paths") + + reqs: list[Requirement] = [ + # Tooling. + DependencyVersionRequirement( + name="rustc", + command=("rustc", "--version"), + required_version=(1, 78, 0), + compare=VersionCompare.GEQ, + ), + DependencyVersionRequirement( + name="cargo", + command=("cargo", "--version"), + required_version=(1, 0, 0), + compare=VersionCompare.GEQ, + ), + DependencyVersionRequirement( + name="node", + command=("node", "--version"), + required_version=(0, 0, 0), + compare=VersionCompare.GEQ, + ), + DependencyVersionRequirement( + name="make", + command=("make", "--version"), + required_version=(0, 0, 0), + compare=VersionCompare.GEQ, + optional=True, + ), + + # Repo directory. + FilesystemPathRequirement( + name="repo_root_exists", + path=repo_root, + path_type=PathType.DIRECTORY, + ), + ] + + # Reference files (required). + for key, ref_path in sorted(self._config.ground_truth_paths.items()): + reqs.append( + FilesystemPathRequirement( + name=f"reference_{key}_exists", + path=ref_path, + path_type=PathType.FILE, + ) ) - return cp.returncode, cp.stdout or "", cp.stderr or "" - except FileNotFoundError: - return 127, "", "" - - def parse_version(self, s: str) -> Optional[Version]: - """ - Extract a version number from a string. - """ - m = re.search(r"(?:^|\s)v?(\d+)\.(\d+)(?:\.(\d+))?", s) - if not m: - return None - major = int(m.group(1)) - minor = int(m.group(2)) - patch = int(m.group(3)) if m.group(3) is not None else 0 - return (major, minor, patch) - - def version_lt(self, a: Version, b: Version) -> bool: - return a < b - - def check_tool(self, req: ToolRequirement) -> Tuple[Optional[str], Optional[str]]: - """ - Check a single dependency requirement, including version. - """ - rc, out, err = self.run_shell_command(req.cmd) - combined = (out + "\n" + err).strip() - - if rc == 127: - if req.optional: - return None, f"{req.name} missing (optional)" - return f"{req.name} not found", None - - if rc != 0: - if req.optional: - return None, f"{req.name} check failed (rc={rc}) (optional)" - return f"{req.name} check failed (rc={rc})", None - - if req.min_version is not None: - v = self.parse_version(combined) - if v is None: - return f"{req.name} version parse failed", None - if self.version_lt(v, req.min_version): - return f"{req.name} too old (need >= {req.min_version[0]}.{req.min_version[1]}.{req.min_version[2]})", None - - return None, None - - def build_check(self): - """ - Validate required dependnecies and environment setup. - """ - problems: list[str] = [] - warnings: list[str] = [] - - for req in TOOL_REQUIREMENTS: - problem, warning = self.check_tool(req) - if problem: - problems.append(problem) - if warning: - warnings.append(warning) - - if problems: - return False, "; ".join(problems) - - if warnings: - return True, "WARN: " + "; ".join(warnings) - - return True, "" - def run(self): - ok, why = self.build_check() - label = "Environment" - if ok and why: - logger.info(f"{label}: PASS - {why}") - return ok - logger.info(f"{label}: {'PASS' if ok else 'FAIL' + (' - ' + why if why else '')}") - return ok \ No newline at end of file + return reqs \ No newline at end of file diff --git a/benchmarks/arteval_bench/data/benchmark/eurosys25_egwalker/_agent_eval/oracle_experiment_runs.py b/benchmarks/arteval_bench/data/benchmark/eurosys25_egwalker/_agent_eval/oracle_experiment_runs.py index b4e2d70f..473d1aeb 100644 --- a/benchmarks/arteval_bench/data/benchmark/eurosys25_egwalker/_agent_eval/oracle_experiment_runs.py +++ b/benchmarks/arteval_bench/data/benchmark/eurosys25_egwalker/_agent_eval/oracle_experiment_runs.py @@ -1,135 +1,199 @@ #!/usr/bin/env python3 +"""Experiment runs oracle for the EuroSys'25 EGWALKER artifact. + +This oracle compares experiment-produced timings against reference timings. +""" + +from __future__ import annotations + import json -import os +from collections.abc import Iterable, Mapping, Sequence +from functools import partial from pathlib import Path -from typing import Any, Dict, List, Optional, Tuple - -from utils import REPO_DIR -from utils import REFERENCE_RESULTS_FILE -from utils import SIMILARITY_RATIO -from utils import logger - - -class OracleExperimentRuns: - def __init__(self) -> None: - self.repo_dir = Path(os.path.expanduser(str(REPO_DIR))) - self.timings_file = self.repo_dir / "results" / "timings.json" - self.reference_file = Path(os.path.expanduser(str(REFERENCE_RESULTS_FILE))) - self.max_mismatches_to_report = (1 - SIMILARITY_RATIO) - - def load_json(self, path: Path) -> Tuple[Optional[Any], str]: - """ - Load JSON from disk and return (obj, err). - """ - if not path.exists(): - return None, f"missing json: {path}" +import logging + +from evaluator.oracle_experiment_runs_primitives import ( + ExperimentRunsRequirement, + LabeledSequenceSimilarityThresholdRequirement, + OracleExperimentRunsBase, +) +from evaluator.utils import EntryConfig + + +def _required_path(paths: Mapping[str, Path], key: str, *, label: str) -> Path: + """Returns a required path from a mapping with a clear error message.""" + try: + return paths[key] + except KeyError as exc: + raise ValueError(f"Missing {label}[{key!r}] in EntryConfig") from exc + + +def _loads_json_from_lines(lines: Sequence[str], *, label: str) -> object: + """Parses JSON content from already-read file lines.""" + text = "\n".join(lines).strip() + if not text: + raise ValueError(f"{label}: empty JSON content") + try: + return json.loads(text) + except json.JSONDecodeError as exc: + raise ValueError(f"{label}: invalid JSON: {exc}") from exc + + +def _load_json_file(path: Path, *, label: str) -> object: + """Loads JSON from a file path.""" + try: + text = path.read_text(encoding="utf-8") + except OSError as exc: + raise ValueError(f"{label}: failed to read {path}: {exc}") from exc + try: + return json.loads(text) + except json.JSONDecodeError as exc: + raise ValueError(f"{label}: invalid JSON: {exc}") from exc + + +def _as_float(v: object, *, label: str) -> float: + """Converts numeric values to float; raises on non-numeric.""" + if isinstance(v, (int, float)): + return float(v) + raise ValueError(f"{label}: non-numeric value {v!r}") + + +def _iter_metric_tag_rows(obj: object, *, label: str) -> Iterable[tuple[str, Mapping[str, object]]]: + """Yields (row_key, stats_dict) where row_key is '.'.""" + if not isinstance(obj, dict): + raise ValueError(f"{label}: timings JSON must be an object at top-level") + + for metric_name, metric in obj.items(): + if not isinstance(metric, dict): + raise ValueError(f"{label}: {metric_name!r} must map to an object") + for tag, stats in metric.items(): + if not isinstance(stats, dict): + raise ValueError(f"{label}: {metric_name}.{tag} must map to an object") + row_key = f"{metric_name}.{tag}" + yield row_key, stats + + +def _discover_reference_fields(reference_obj: object, *, label: str) -> tuple[str, ...]: + """Returns unique stats fields in order of first appearance in the reference.""" + seen: set[str] = set() + ordered: list[str] = [] + for _row_key, stats in _iter_metric_tag_rows(reference_obj, label=label): + for field in stats.keys(): + if not isinstance(field, str): + raise ValueError(f"{label}: non-string field name {field!r}") + if field not in seen: + seen.add(field) + ordered.append(field) + return tuple(ordered) + + +def _pairs_for_field_from_obj( + obj: object, + *, + field: str, + label: str, +) -> list[tuple[str, float]]: + """Builds (row_key, value) pairs for a given stats field.""" + out: list[tuple[str, float]] = [] + for row_key, stats in _iter_metric_tag_rows(obj, label=label): + if field not in stats: + # Skip: the primitives will treat this as "missing label" if reference + # expected it for this field (i.e., if reference includes row_key here). + continue + out.append((row_key, _as_float(stats[field], label=f"{label}: {row_key}.{field}"))) + return out + + +def _pairs_flatten_all_fields(obj: object, *, label: str) -> list[tuple[str, float]]: + """Fallback: flattens all fields into '..' labels.""" + out: list[tuple[str, float]] = [] + for row_key, stats in _iter_metric_tag_rows(obj, label=label): + for field, raw in stats.items(): + if not isinstance(field, str): + raise ValueError(f"{label}: non-string field name {field!r}") + full = f"{row_key}.{field}" + out.append((full, _as_float(raw, label=f"{label}: {full}"))) + return out + + +def _parse_results_pairs_for_field(lines: Sequence[str], *, field: str) -> list[tuple[str, float]]: + obj = _loads_json_from_lines(lines, label="timings results") + return _pairs_for_field_from_obj(obj, field=field, label="timings results") + + +def _parse_reference_pairs_for_field(path: Path, *, field: str) -> list[tuple[str, float]]: + obj = _load_json_file(path, label="timings reference") + return _pairs_for_field_from_obj(obj, field=field, label="timings reference") + + +def _parse_results_pairs_flat(lines: Sequence[str]) -> list[tuple[str, float]]: + obj = _loads_json_from_lines(lines, label="timings results") + return _pairs_flatten_all_fields(obj, label="timings results") + + +def _parse_reference_pairs_flat(path: Path) -> list[tuple[str, float]]: + obj = _load_json_file(path, label="timings reference") + return _pairs_flatten_all_fields(obj, label="timings reference") + + +class OracleExperimentRuns(OracleExperimentRunsBase): + """Validates experiment run timings for EGWALKER.""" + + _NAME = "ExperimentRuns" + + def __init__(self, *, config: EntryConfig, logger: logging.Logger) -> None: + super().__init__(logger=logger) + self._config = config + + def requirements(self) -> Sequence[ExperimentRunsRequirement]: + if not self._config.results_paths: + raise ValueError("EntryConfig.results_paths must be non-empty") + if not self._config.ground_truth_paths: + raise ValueError("EntryConfig.ground_truth_paths must be non-empty") + + results_path = _required_path( + self._config.results_paths, "timings", label="results_paths" + ) + reference_path = _required_path( + self._config.ground_truth_paths, "timings", label="ground_truth_paths" + ) + + threshold = self._config.similarity_ratio + + # Discover which "types" (fields) to check from the reference. + # If discovery fails (missing/invalid JSON), fall back to a single requirement + # that will report the real failure via the primitives. try: - with path.open("r", encoding="utf-8") as f: - return json.load(f), "" - except Exception as e: - return None, f"unreadable json: {path} ({e})" - - def as_float(self, v: Any) -> Optional[float]: - if isinstance(v, (int, float)): - return float(v) - return None - - def ratios_within_tolerance(self, actual: float, ref: float) -> Tuple[bool, float]: - """ - Check whether two measurements are within tolerance. - """ - if abs(ref) < 1e-12: - if abs(actual) < 1e-12: - return True, 0.0 - return False, float("inf") - - rel_diff = abs(actual - ref) / abs(ref) - return rel_diff <= (1.0 - float(SIMILARITY_RATIO)), rel_diff - - def compare_timings( - self, - actual: Dict[str, Any], - reference: Dict[str, Any], - ) -> Tuple[bool, str]: - """ - Compare current timings with the original, reference timings. - """ - if not isinstance(actual, dict) or not isinstance(reference, dict): - return False, "timings json invalid format (expected object at top-level)" - - missing: List[str] = [] - mismatches: List[str] = [] - total = 0 - ok_count = 0 - - for metric_name, ref_metric in reference.items(): - if not isinstance(ref_metric, dict): - missing.append(f"{metric_name}: invalid reference section (expected object)") - continue - - act_metric = actual.get(metric_name) - if not isinstance(act_metric, dict): - missing.append(f"{metric_name}: missing metric") - continue - - for tag, ref_stats in ref_metric.items(): - if not isinstance(ref_stats, dict): - missing.append(f"{metric_name}.{tag}: invalid reference tag (expected object)") - continue - - act_stats = act_metric.get(tag) - if not isinstance(act_stats, dict): - missing.append(f"{metric_name}.{tag}: missing tag") - continue - - for field, ref_val_raw in ref_stats.items(): - total += 1 - - if field not in act_stats: - missing.append(f"{metric_name}.{tag}.{field}: missing field") - continue - - ref_val = self.as_float(ref_val_raw) - act_val = self.as_float(act_stats.get(field)) - - if ref_val is None: - missing.append(f"{metric_name}.{tag}.{field}: non-numeric reference value") - continue - if act_val is None: - missing.append(f"{metric_name}.{tag}.{field}: non-numeric actual value") - continue - - ok, sim = self.ratios_within_tolerance(act_val, ref_val) - if ok: - ok_count += 1 - else: - mismatches.append( - f"{metric_name}.{tag}.{field}: {act_val} vs {ref_val} (similarity {sim:.3f} < {SIMILARITY_RATIO})" - ) - - if missing or mismatches: - parts: List[str] = [] - summary = f"{ok_count}/{total} fields meet similarity ratio" if total else "0 fields compared" - if missing: - parts.append("missing/invalid: " + "; ".join(missing)) - if mismatches: - parts.append("measurement difference: " + "; ".join(mismatches)) - return False, summary + " - " + " | ".join(parts) - - summary = f"{ok_count}/{total} fields meet similarity ratio" if total else "no reference fields to compare" - return True, summary - - def run(self) -> bool: - actual_obj, err = self.load_json(self.timings_file) - if err: - logger.info(f"Timings: FAIL - {err}") - return False - - ref_obj, err = self.load_json(self.reference_file) - if err: - logger.info(f"Timings: FAIL - {err}") - return False - - ok, why = self.compare_timings(actual_obj, ref_obj) - logger.info(f"Timings: {'PASS' if ok else 'FAIL' + (' - ' + why if why else '')}") - return ok + ref_obj = _load_json_file(reference_path, label="timings reference") + fields = _discover_reference_fields(ref_obj, label="timings reference") + except ValueError: + fields = () + + if not fields: + # Fallback or "no fields": compare all qualified fields as one sequence. + return ( + LabeledSequenceSimilarityThresholdRequirement( + name="timings", + label="Timings", + results_path=results_path, + reference_path=reference_path, + threshold=threshold, + parse_results_fn=_parse_results_pairs_flat, + parse_reference_fn=_parse_reference_pairs_flat, + ), + ) + + reqs: list[ExperimentRunsRequirement] = [] + for field in fields: + reqs.append( + LabeledSequenceSimilarityThresholdRequirement( + name=f"timings_{field}", + label=f"Timings {field}", + results_path=results_path, + reference_path=reference_path, + threshold=threshold, + parse_results_fn=partial(_parse_results_pairs_for_field, field=field), + parse_reference_fn=partial(_parse_reference_pairs_for_field, field=field), + ) + ) + return tuple(reqs) \ No newline at end of file diff --git a/benchmarks/arteval_bench/data/benchmark/eurosys25_egwalker/_agent_eval/utils.py b/benchmarks/arteval_bench/data/benchmark/eurosys25_egwalker/_agent_eval/utils.py deleted file mode 100644 index 284ec895..00000000 --- a/benchmarks/arteval_bench/data/benchmark/eurosys25_egwalker/_agent_eval/utils.py +++ /dev/null @@ -1,29 +0,0 @@ -# --- CONSTANTS --- # -from pathlib import Path - -HOME = Path.home() / "eurosys25_egwalker" -REPO_DIR = f"{HOME}/egwalker" - -REFERENCE_BENCHMARK_FILE = f"{HOME}/_agent_eval/refs/datasets.ref.json" -REFERENCE_RESULTS_FILE = f"{HOME}/_agent_eval/refs/timings.ref.json" -SIMILARITY_RATIO = 0.75 - - -# --- CUSTOM LOGGER --- # -import logging -import os -from datetime import datetime - -os.makedirs('logs', exist_ok=True) - -LOG_FORMAT = '%(asctime)s | %(levelname)s | %(name)s | %(message)s' -DATE_FORMAT = '%Y-%m-%d %H:%M:%S' - -logger = logging.getLogger("OSDI24-ANVIL-AGENT-EVALUATOR") -logger.setLevel(logging.DEBUG) - -console_handler = logging.StreamHandler() -console_handler.setLevel(logging.INFO) -console_handler.setFormatter(logging.Formatter(LOG_FORMAT, datefmt=DATE_FORMAT)) - -logger.addHandler(console_handler) \ No newline at end of file From a65d74011e6d72a02084aac07cf89d61534e19d6 Mon Sep 17 00:00:00 2001 From: Bogdan-Alexandru Stoica Date: Sat, 31 Jan 2026 01:22:32 -0600 Subject: [PATCH 4/6] refactor: adapt anvil's oracles to use the standardized interface --- .../osdi24_anvil/_agent_eval/main.py | 89 +++- .../_agent_eval/oracle_artifact_build.py | 164 ++++--- .../_agent_eval/oracle_benchmark_prep.py | 2 +- .../_agent_eval/oracle_env_setup.py | 336 +++++-------- .../_agent_eval/oracle_experiment_runs.py | 460 +++++++++--------- .../osdi24_anvil/_agent_eval/utils.py | 30 -- 6 files changed, 521 insertions(+), 560 deletions(-) delete mode 100644 benchmarks/arteval_bench/data/benchmark/osdi24_anvil/_agent_eval/utils.py diff --git a/benchmarks/arteval_bench/data/benchmark/osdi24_anvil/_agent_eval/main.py b/benchmarks/arteval_bench/data/benchmark/osdi24_anvil/_agent_eval/main.py index 2f434ee5..b9910621 100644 --- a/benchmarks/arteval_bench/data/benchmark/osdi24_anvil/_agent_eval/main.py +++ b/benchmarks/arteval_bench/data/benchmark/osdi24_anvil/_agent_eval/main.py @@ -1,32 +1,87 @@ #!/usr/bin/env python3 +"""Runs environment setup checks for ANVIL.""" + +from __future__ import annotations + +import os import sys +from pathlib import Path from typing import Dict -from oracle_artifact_build import OracleArtifactBuild +_AGENT_EVAL_DIR = Path(__file__).resolve().parent +_AGENT_SRC_DIR = _AGENT_EVAL_DIR.parents[3] / "src" +sys.path.append(str(_AGENT_SRC_DIR)) + from oracle_env_setup import OracleEnvSetup +from oracle_artifact_build import OracleArtifactBuild from oracle_benchmark_prep import OracleBenchmarkPrep from oracle_experiment_runs import OracleExperimentRuns +from evaluator.utils import ( + EntryConfig, + LoggerConfig, + get_logger, + record_result, +) -from utils import logger +# Reuse the same constants the legacy oracle used. +from utils import RESULTS_PATH, SIMILARITY_RATIO # pylint: disable=wrong-import-position + + +ANVIL_CONFIG = EntryConfig( + name="osdi24-anvil", + home_dir=Path.home() / "osdi24_anvil", + repository_paths={ + "osdi24-anvil": Path.home() / "osdi24_anvil" / "anvil", + "osdi24-acto-dependency": Path.home() / "osdi24_anvil" / "acto", + }, + results_paths={ + "table3": Path(RESULTS_PATH), + }, + ground_truth_paths={ + "table3": ( + Path.home() + / "osdi24_anvil" + / "_agent_eval" + / "refs" + / "anvil-table-3.ref.json" + ), + }, + similarity_ratio=SIMILARITY_RATIO, +) -def main(): - results: Dict[str, int] = {} +def main(argv: list[str]) -> int: + results: Dict[str, int] = {} score = 0 - for cls in (OracleEnvSetup, OracleArtifactBuild, OracleBenchmarkPrep, OracleExperimentRuns): - checker = cls() - ok = checker.run() - name = cls.__name__ - logger.info(f"{name}: {'PASS' if ok else 'FAIL'}") - if ok: - results[name] = 1 - score += 1 - else: - results[name] = 0 - - logger.info(f"Agent scores: {results}") + + verbose = "--verbose" in argv + + logger_name = os.environ.get("EVAL_LOGGER_NAME", "ANVIL-EVAL") + logger = get_logger(LoggerConfig(root_name=logger_name)) + + env_checker = OracleEnvSetup(config=ANVIL_CONFIG, logger=logger) + score += record_result( + results, type(env_checker).__name__, env_checker.run(verbose=verbose) + ) + + build_checker = OracleArtifactBuild(config=ANVIL_CONFIG, logger=logger) + score += record_result( + results, type(build_checker).__name__, build_checker.run(verbose=verbose) + ) + + prep_checker = OracleBenchmarkPrep(config=ANVIL_CONFIG, logger=logger) + score += record_result( + results, type(prep_checker).__name__, prep_checker.run(verbose=verbose) + ) + + runs_checker = OracleExperimentRuns(config=ANVIL_CONFIG, logger=logger) + score += record_result( + results, type(runs_checker).__name__, runs_checker.run(verbose=verbose) + ) + + logger.info("Agent scores: %s", results) return score if __name__ == "__main__": - main() \ No newline at end of file + raise SystemExit(main(sys.argv[1:])) diff --git a/benchmarks/arteval_bench/data/benchmark/osdi24_anvil/_agent_eval/oracle_artifact_build.py b/benchmarks/arteval_bench/data/benchmark/osdi24_anvil/_agent_eval/oracle_artifact_build.py index d45eb662..3554c528 100644 --- a/benchmarks/arteval_bench/data/benchmark/osdi24_anvil/_agent_eval/oracle_artifact_build.py +++ b/benchmarks/arteval_bench/data/benchmark/osdi24_anvil/_agent_eval/oracle_artifact_build.py @@ -1,86 +1,90 @@ -import os -import subprocess -from dataclasses import dataclass -from typing import Iterable, Optional, Tuple +#!/usr/bin/env python3 +"""Artifact build oracle for the OSDI '24 ANVIL artifact. + +Validates: + - The ACTO dependency repository can build its required library target. +""" + +from __future__ import annotations + +from collections.abc import Mapping, Sequence +from dataclasses import dataclass, field +import logging from pathlib import Path -from utils import REPO_DIRS -from utils import logger +from evaluator.oracle_artifact_build_primitives import ( + BuildCommandRequirement, + BuildRequirement, + OracleArtifactBuildBase, +) +from evaluator.utils import EntryConfig -@dataclass(frozen=True) +@dataclass(frozen = True, slots = True, kw_only = True) class BuildTarget: + """Declarative description of one build command to run.""" + name: str - repo_key: str - cmd: list[str] - - -BUILD_TARGETS: list[BuildTarget] = [ - BuildTarget( - name="acto", - repo_key="acto", - cmd=["make", "lib"], - ), -] - - -class OracleArtifactBuild: - - def __init__(self) -> None: - self.repo_dirs = REPO_DIRS - - def run_shell_command( - self, - cmd: Iterable[str], - cwd: Optional[Path] = None, - ) -> Tuple[int, str, str]: - """ - Run a command and return (rc, stdout, stderr) tuple. - """ - try: - cp = subprocess.run( - cmd, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - text=True, - cwd=str(cwd) if cwd is not None else None, - ) - return cp.returncode, cp.stdout or "", cp.stderr or "" - except FileNotFoundError: - return 127, "", "" - - def build_target(self, target: BuildTarget) -> Optional[str]: - """ - Build a single target using its configured repository and command. - """ - repo_dir = self.repo_dirs.get(target.repo_key, "") - if not repo_dir: - return f"{target.name} repo directory undefined" - - repo_path = Path(os.path.expanduser(repo_dir)) - if not repo_path.exists(): - return f"{target.name} repo directory missing" - - rc, out, err = self.run_shell_command(target.cmd, cwd=repo_path) - if rc != 0: - return f"{target.name} build failed (rc={rc})" - - return None - - def build_check(self): - """ - Run builds for all configured targets and collect failures. - """ - problems: list[str] = [] - for target in BUILD_TARGETS: - msg = self.build_target(target) - if msg: - problems.append(msg) - if problems: - return False, "; ".join(problems) - return True, "" - - def run(self): - ok, why = self.build_check() - logger.info(f"Build: {'PASS' if ok else 'FAIL' + (' - ' + why if why else '')}") - return ok + cwd: Path + command: Sequence[str] + cwd_relative: Path | None = None + optional: bool = False + timeout_seconds: float = 60.0 + env_overrides: Mapping[str, str] = field(default_factory = dict) + + def __post_init__(self) -> None: + if not self.name: + raise ValueError("BuildTarget.name must be non-empty") + if not self.command: + raise ValueError(f"{self.name}: command must be non-empty") + if self.timeout_seconds <= 0: + raise ValueError(f"{self.name}: timeout_seconds must be > 0") + + object.__setattr__(self, "command", tuple(self.command)) + + +class OracleArtifactBuild(OracleArtifactBuildBase): + """Artifact build oracle for ANVIL.""" + + def __init__( + self, + *, + config: EntryConfig, + logger: logging.Logger, + targets: Sequence[BuildTarget] | None = None, + ) -> None: + super().__init__(logger = logger) + self._config = config + + if targets is None: + targets = self._default_targets() + self._targets = tuple(targets) + + names = [t.name for t in self._targets] + if len(names) != len(set(names)): + raise ValueError(f"Duplicate build target names: {names!r}") + + def _default_targets(self) -> tuple[BuildTarget, ...]: + acto_repo = self._config.repository_paths["osdi24-acto-dependency"] + return ( + BuildTarget( + name = "acto: make lib", + cwd = acto_repo, + command = ("make", "lib"), + timeout_seconds = 60.0, + ), + ) + + def requirements(self) -> Sequence[BuildRequirement]: + return tuple( + BuildCommandRequirement( + name = t.name, + optional = t.optional, + cwd = t.cwd, + command = t.command, + cwd_relative = t.cwd_relative, + timeout_seconds = t.timeout_seconds, + env_overrides = t.env_overrides, + ) + for t in self._targets + ) \ No newline at end of file diff --git a/benchmarks/arteval_bench/data/benchmark/osdi24_anvil/_agent_eval/oracle_benchmark_prep.py b/benchmarks/arteval_bench/data/benchmark/osdi24_anvil/_agent_eval/oracle_benchmark_prep.py index dcf80c4b..0e274242 100644 --- a/benchmarks/arteval_bench/data/benchmark/osdi24_anvil/_agent_eval/oracle_benchmark_prep.py +++ b/benchmarks/arteval_bench/data/benchmark/osdi24_anvil/_agent_eval/oracle_benchmark_prep.py @@ -18,7 +18,7 @@ def run_shell_command(self, cmd): Run a command and return (rc, stdout, stderr) tuple. """ try: - cp = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True) + cp = subprocess.run(cmd, stdout = subprocess.PIPE, stderr = subprocess.PIPE, text = True) return cp.returncode, (cp.stdout or "").strip(), (cp.stderr or "").strip() except FileNotFoundError as e: return 127, "", str(e) diff --git a/benchmarks/arteval_bench/data/benchmark/osdi24_anvil/_agent_eval/oracle_env_setup.py b/benchmarks/arteval_bench/data/benchmark/osdi24_anvil/_agent_eval/oracle_env_setup.py index 5d41fd11..8bef40db 100644 --- a/benchmarks/arteval_bench/data/benchmark/osdi24_anvil/_agent_eval/oracle_env_setup.py +++ b/benchmarks/arteval_bench/data/benchmark/osdi24_anvil/_agent_eval/oracle_env_setup.py @@ -1,204 +1,138 @@ -import os -import re +#!/usr/bin/env python3 +"""Environment setup oracle for the ANVIL bundle. + +This implementation uses evaluator.oracle_env_setup_primitives for consistent +reporting and verbose failure logging. +""" + +from __future__ import annotations + +import dataclasses +import logging import shutil -import subprocess -from dataclasses import dataclass -from typing import Iterable, Optional, Tuple +from collections.abc import Sequence from pathlib import Path -from utils import HOME, REPO_DIRS -from utils import logger - -VersionTuple = Tuple[int, ...] - - -@dataclass(frozen=True) -class Dependency: - name: str - binary: str - cmd: Optional[list] = None - parse_regex: Optional[str] = None - require: Optional[VersionTuple] = None - compare: Optional[str] = None - - -DEPENDENCIES: list[Dependency] = [ - - # Basic tooling - Dependency( - name="git", binary="git" - ), - - # Docker, latest version is okay - Dependency( - name="docker", binary="docker", - ), - - # Python v3.10+ - Dependency( - name="python3", binary="python3", - cmd=["python3", "--version"], parse_regex=r"Python\s+([0-9.]+)", - require=(3, 10), compare="gte", - ), - - # pip3 for Python 3.10+ - Dependency( - name="pip3", binary="pip3", - ), - - # Go toolchain (golang package), latest STL version - Dependency( - name="go", binary="go", - ), - - # Kind v0.20.0 - Dependency( - name="kind", binary="kind", - cmd=["kind", "version"], parse_regex=r"v([0-9.]+)", - require=(0, 20, 0), compare="gte", - ), - - # Kubectl v1.22.9 - Dependency( - name="kubectl", binary="kubectl", - cmd=["kubectl", "version", "--client", "--short"], - parse_regex=r"Client Version:\s+v?([0-9.]+)", - require=(1, 22, 9), compare="gte", - ), -] - - -class OracleEnvSetup: - - def __init__(self) -> None: - self.expected_root_dirs = REPO_DIRS.values() - self.go_root = HOME / "go" - self.go_bin = self.go_root / "bin" - self.venv_dir = HOME / ".venv" - - def run_shell_command(self, cmd: Iterable[str]) -> Tuple[int, str, str]: - """ - Run a command and return (rc, stdout, stderr) tuple. - """ - try: - cp = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True) - return cp.returncode, cp.stdout or "", cp.stderr or "" - except FileNotFoundError: - return 127, "", "" - - def parse_version_tuple(self, text: str) -> VersionTuple: - """ - Extract the first version-like token from arbitrary text. - For example, for Java: '1.8.0_422' -> (1, 8, 0) - """ - m = re.search(r"(\d+(?:\.\d+){0,3})", text) - return tuple(int(x) for x in m.group(1).split(".")) if m else () - - def extract_version(self, text: str, pattern: str) -> Tuple[VersionTuple, str]: - """ - Apply regex pattern on a version string. - """ - m = re.search(pattern, text, re.I) - if not m: - return (), "unknown" - ver_str = m.group(1) - return self.parse_version_tuple(ver_str), ver_str - - def cmp_versions(self, found: VersionTuple, required: VersionTuple, mode: str) -> bool: - """ - Compare versions either to be greater or equal to the reference. - """ - if not found: - return False - f, r = list(found), list(required) - while len(f) < len(r): - f.append(0) - while len(r) < len(f): - r.append(0) - return (f == r) if mode == "eq" else (f >= r) - - def paths_check(self): - """ - Check that Python virtual environment is succesfully created - and that Go-related paths are set properly. - """ - problems: list[str] = [] - - # Check repositories exist - for dir in self.expected_root_dirs: - if not Path(dir).exists(): - problems.append(f"{dir} directory not found repository not cloned properly") - - # Check Python virtual environment is created - if not Path(self.venv_dir).exists(): - problems.append(".venv virtual environment missing (run 'python3 -m venv .venv')") - - # Check Go directories exit - if not Path(self.go_root).exists(): - problems.append("$HOME/go directory missing (install golang and configure GOPATH)") - if not Path(self.go_bin).exists(): - problems.append("$HOME/go/bin directory missing (ensure Go tools are installed)") - - # Check PATH contains Go path - path_env = os.environ.get("PATH", "") - go_root_str = str(self.go_root) - go_bin_str = str(self.go_bin) - if go_root_str not in path_env or go_bin_str not in path_env: - problems.append("PATH missing $HOME/go or $HOME/go/bin " - "(export PATH=$HOME/go:$HOME/go/bin:$PATH)") - - if problems: - return False, "; ".join(problems) - return True, "" - - def check_dependency(self, dep: Dependency) -> Optional[str]: - """ - Core method that checks whether a certain dependency of a version - equal or greather than a reference version is installed. - """ - if shutil.which(dep.binary) is None: - return f"{dep.name} missing" - - # If no version information is required, presence is enough - if dep.cmd is None and dep.parse_regex is None and dep.require is None: - return None - - rc, out, err = self.run_shell_command(dep.cmd or []) - text = (out + "\n" + err).strip() - - if dep.parse_regex and dep.require and dep.compare: - ver_tuple, ver_str = self.extract_version(text, dep.parse_regex) - if not ver_tuple: - return f"{dep.name} version unreadable" - ok = self.cmp_versions(ver_tuple, dep.require, dep.compare) - cmp_word = "==" if dep.compare == "eq" else ">=" - want = ".".join(map(str, dep.require)) - return None if ok else f"{dep.name} {cmp_word} {want} not met (got {ver_str})" - - return f"{dep.name} check misconfigured" - - def prereqs_check(self): - problems: list[str] = [] - for dep in DEPENDENCIES: - msg = self.check_dependency(dep) - if msg: - problems.append(msg) - if problems: - return False, "; ".join(problems) - return True, "" - - def run(self): - results = [] - - ok, why = self.prereqs_check() - logger.info(f"Prerequisites: {'PASS' if ok else 'FAIL' + (' - ' + why if why else '')}") - results.append(ok) - - ok, why = self.paths_check() - logger.info(f"Paths: {'PASS' if ok else 'FAIL' + (' - ' + why if why else '')}") - results.append(ok) - - if all(results): - return True - - return False \ No newline at end of file +from evaluator.utils import CheckResult, EntryConfig +from evaluator.oracle_env_setup_primitives import ( + DependencyVersionRequirement, + EnvironmentVariableRequirement, + EnvQuantifier, + FilesystemPathRequirement, + OracleEnvSetupBase, + PathType, + Requirement, + VersionCompare, +) + + +@dataclasses.dataclass(frozen = True, slots = True, kw_only = True) +class ExecutableOnPathRequirement(Requirement): + """Checks that an executable is present on PATH (no version constraint).""" + + executable: str + + def __post_init__(self) -> None: + if not self.executable: + raise ValueError(f"{self.name}: executable must be non-empty") + + def check(self) -> CheckResult: + if shutil.which(self.executable) is None: + return CheckResult.failure(f"not found on PATH: {self.executable!r}") + return CheckResult.success() + + +class OracleEnvSetup(OracleEnvSetupBase): + """Validates environment prerequisites for the ANVIL bundle.""" + + def __init__(self, *, config: EntryConfig, logger: logging.Logger) -> None: + super().__init__(logger = logger) + self._config = config + + def requirements(self) -> Sequence[Requirement]: + home_dir = self._config.home_dir + venv_dir = home_dir / ".venv" + go_root = Path.home() / "go" + go_bin = go_root / "bin" + + reqs: list[Requirement] = [ + # Check dependencies + DependencyVersionRequirement( + name = "docker", + command = ("docker", "--version"), + required_version = (24, 0, 0), + compare = VersionCompare.GEQ, + ), + DependencyVersionRequirement( + name = "go", + command = ("go", "version"), + required_version = (1, 22, 0), + compare = VersionCompare.GEQ, + version_regex = r"go(\d+\.\d+(?:\.\d+)?)", + ), + DependencyVersionRequirement( + name = "python3", + command = ("python3", "--version"), + required_version = (3, 10, 0), + compare = VersionCompare.GEQ, + version_regex = r"Python\s+([0-9.]+)", + ), + DependencyVersionRequirement( + name = "pip3", + command = ("pip3", "--version"), + required_version = (24, 0, 0), + compare = VersionCompare.GEQ, + ), + DependencyVersionRequirement( + name = "kind", + command = ("kind", "version"), + required_version = (0, 20, 0), + compare = VersionCompare.GEQ, + version_regex = r"v([0-9.]+)", + ), + DependencyVersionRequirement( + name = "kubectl", + command = ("kubectl", "version", "--client", "--short"), + required_version = (1, 22, 9), + compare = VersionCompare.GEQ, + version_regex = r"Client Version:\s+v?([0-9.]+)", + ), + + # Check directory structure + FilesystemPathRequirement( + name = "venv_exists", + path = venv_dir, + path_type = PathType.DIRECTORY, + ), + FilesystemPathRequirement( + name = "go_root_exists", + path = go_root, + path_type = PathType.DIRECTORY, + ), + + # Check PATH contents + EnvironmentVariableRequirement( + name = "PATH_contains_go_root", + env_var = "PATH", + expected = str(go_root), + quantifier = EnvQuantifier.CONTAINS, + ), + EnvironmentVariableRequirement( + name = "PATH_contains_go_bin", + env_var = "PATH", + expected = str(go_bin), + quantifier = EnvQuantifier.CONTAINS, + ), + ] + + # Check that the repo root directory is present + for key, repo_root in sorted(self._config.repository_paths.items()): + reqs.append( + FilesystemPathRequirement( + name = f"repo_exists:{key}", + path = repo_root, + path_type = PathType.DIRECTORY, + ) + ) + + return reqs diff --git a/benchmarks/arteval_bench/data/benchmark/osdi24_anvil/_agent_eval/oracle_experiment_runs.py b/benchmarks/arteval_bench/data/benchmark/osdi24_anvil/_agent_eval/oracle_experiment_runs.py index 3405b7e5..a9f5f1c6 100644 --- a/benchmarks/arteval_bench/data/benchmark/osdi24_anvil/_agent_eval/oracle_experiment_runs.py +++ b/benchmarks/arteval_bench/data/benchmark/osdi24_anvil/_agent_eval/oracle_experiment_runs.py @@ -1,12 +1,29 @@ +#!/usr/bin/env python3 +"""Experiment runs oracle for the OSDI'24 ANVIL artifact. + +Validates results (tsble 3) against reference measurements by comparing +per-controller calues: + - mean ratio: verified_anvil_mean / reference_unverified_mean + - max ratio: verified_anvil_max / reference_unverified_max +""" + +from __future__ import annotations + import json +from collections.abc import Mapping, Sequence from dataclasses import dataclass from pathlib import Path -from typing import Tuple +import logging -from utils import RESULTS_PATH, REFERENCE_PATH, SIMILARITY_RATIO, logger +from evaluator.oracle_experiment_runs_primitives import ( + ExperimentRunsRequirement, + LabeledSequenceSimilarityThresholdRequirement, + OracleExperimentRunsBase, +) +from evaluator.utils import EntryConfig -@dataclass(frozen=True) +@dataclass(frozen=True, slots=True) class TableRow: controller: str verified_anvil_mean: float @@ -15,255 +32,236 @@ class TableRow: reference_unverified_max: float -class OracleExperimentRuns: - - def __init__(self) -> None: - self.results_path = Path(RESULTS_PATH) - self.reference_path = Path(REFERENCE_PATH) - self.rows: list[TableRow] = [] - self.rows_by_controller: dict[str, TableRow] = {} - self._raw_lines: list[str] = [] - - def load(self) -> Tuple[bool, str]: - """ - Load the raw table file into memory. - """ - if not self.reference_path.exists(): - return False, f"{self.reference_path} (reference measurement) not found" - - if not self.results_path.exists(): - return False, f"{self.results_path} not found" - - text = self.results_path.read_text(encoding="utf-8") - lines = [line.rstrip("\n") for line in text.splitlines() if line.strip()] - if not lines: - return False, f"{self.results_path} is empty" - - self._raw_lines = lines - return True, "" - - def is_separator_line(self, line: str) -> bool: - """ - Return True if this looks like the Markdown header separator line. - """ - stripped = line.strip() - if not stripped.startswith("|") or not stripped.endswith("|"): - return False - - inner = stripped.replace("|", "").replace(" ", "") - return bool(inner) and all(ch in "-:" for ch in inner) - - def parse_float(self, text: str) -> Tuple[bool, float]: - """ - Parse a numeric string into a float. - """ - try: - return True, float(text.replace(",", "")) - except ValueError: - return False, 0.0 - - def parse_table(self) -> Tuple[bool, str]: - """ - Parse table saved in markdown format into rows and a dictionary keyed by controller. - """ - EXPECTED_HEADERS: list[str] = [ - "Controller", - "Verified (Anvil) Mean", - "Verified (Anvil) Max", - "Reference (unverified) Mean", - "Reference (unverified) Max", - ] - - def split_row(line: str) -> list[str]: - """ - Split a markdown table row into individual cells. - """ - return [cell.strip() for cell in line.strip().strip("|").split("|")] - - header_line: str | None = None - data_lines: list[str] = [] - - for line in self._raw_lines: - if "|" not in line: - # Not a table row, skip. - continue - - if header_line is None: - header_line = line - continue - - if self.is_separator_line(line): - # Skip the ---|--- header separator. - continue - - # Remaining lines are data rows. - data_lines.append(line) +_EXPECTED_HEADERS: tuple[str, ...] = ( + "Controller", + "Verified (Anvil) Mean", + "Verified (Anvil) Max", + "Reference (unverified) Mean", + "Reference (unverified) Max", +) - if header_line is None: - return False, "No table header found" - headers = split_row(header_line) - if headers != EXPECTED_HEADERS: - return False, f"Unexpected table headers: {headers!r}" +def _required_path(paths: Mapping[str, Path], key: str, *, label: str) -> Path: + """Returns a required path from a mapping with a clear error message.""" + try: + return paths[key] + except KeyError as exc: + raise ValueError(f"Missing {label}[{key!r}] in EntryConfig") from exc - self.rows = [] - self.rows_by_controller = {} - for line in data_lines: - cells = split_row(line) - if len(cells) != len(EXPECTED_HEADERS): - return False, f"Row has {len(cells)} cells, expected {len(EXPECTED_HEADERS)}: {line!r}" +def _is_separator_line(line: str) -> bool: + """Returns True if this looks like the Markdown header separator line.""" + stripped = line.strip() + if not stripped.startswith("|") or not stripped.endswith("|"): + return False + inner = stripped.replace("|", "").replace(" ", "") + return bool(inner) and all(ch in "-:" for ch in inner) - ok, verified_anvil_mean = self.parse_float(cells[1]) - if not ok: - return False, f"Unparseable float in column 'Verified (Anvil) Mean': {cells[1]!r}" - ok, verified_anvil_max = self.parse_float(cells[2]) - if not ok: - return False, f"Unparseable float in column 'Verified (Anvil) Max': {cells[2]!r}" +def _split_markdown_row(line: str) -> list[str]: + """Splits a markdown table row into cells.""" + return [cell.strip() for cell in line.strip().strip("|").split("|")] - ok, reference_unverified_mean = self.parse_float(cells[3]) - if not ok: - return False, f"Unparseable float in column 'Reference (unverified) Mean': {cells[3]!r}" - ok, reference_unverified_max = self.parse_float(cells[4]) - if not ok: - return False, f"Unparseable float in column 'Reference (unverified) Max': {cells[4]!r}" +def _parse_float_token(text: str, *, label: str) -> float: + """Parses a float allowing commas.""" + try: + return float(text.replace(",", "")) + except ValueError as exc: + raise ValueError(f"{label}: unparseable float: {text!r}") from exc - row = TableRow( - controller=cells[0], - verified_anvil_mean=verified_anvil_mean, - verified_anvil_max=verified_anvil_max, - reference_unverified_mean=reference_unverified_mean, - reference_unverified_max=reference_unverified_max, - ) - self.rows.append(row) - self.rows_by_controller[row.controller] = row - return True, "" +def _compute_ratios(row: TableRow) -> tuple[float, float]: + """Computes (mean_ratio, max_ratio) as verified/reference per row.""" + if row.reference_unverified_mean == 0.0: + mean_ratio = float("inf") + else: + mean_ratio = row.verified_anvil_mean / row.reference_unverified_mean - def load_json_rows(self, path: Path) -> Tuple[bool, list[TableRow], str]: - """ - Load TableRow entries from a JSON file. - """ - if not path.exists(): - return False, [], f"{path} not found" + if row.reference_unverified_max == 0.0: + max_ratio = float("inf") + else: + max_ratio = row.verified_anvil_max / row.reference_unverified_max - try: - raw = json.loads(path.read_text(encoding="utf-8")) - except json.JSONDecodeError as e: - return False, [], f"{path} invalid JSON: {e}" - - if not isinstance(raw, list): - return False, [], f"{path} must contain a list of objects" - - rows: list[TableRow] = [] - for idx, obj in enumerate(raw): - if not isinstance(obj, dict): - return False, [], f"{path} entry #{idx} is not an object" - try: - row = TableRow( - controller=str(obj["controller"]), - verified_anvil_mean=float(obj["verified_anvil_mean"]), - verified_anvil_max=float(obj["verified_anvil_max"]), - reference_unverified_mean=float(obj["reference_unverified_mean"]), - reference_unverified_max=float(obj["reference_unverified_max"]), - ) - except (KeyError, TypeError, ValueError) as e: - return False, [], f"{path} malformed entry #{idx}: {e}" - rows.append(row) - - return True, rows, "" - - def compute_ratios(self, row: TableRow) -> Tuple[float, float]: - """ - Compute mean/max ratios (verified / reference) and compare with - similar ratios from reference measurements. - """ - if row.reference_unverified_mean == 0.0: - mean_ratio = float("inf") - else: - mean_ratio = row.verified_anvil_mean / row.reference_unverified_mean - - if row.reference_unverified_max == 0.0: - max_ratio = float("inf") - else: - max_ratio = row.verified_anvil_max / row.reference_unverified_max - - return mean_ratio, max_ratio - - def ratios_within_tolerance(self, found: float, ref: float) -> bool: - """ - Check whether two ratio values are within tolerance. - """ - if ref == 0.0: - return False - return abs(found - ref) <= (1.0 - SIMILARITY_RATIO) * max(abs(found), abs(ref)) - - def compare_against_reference(self) -> Tuple[bool, str]: - """ - Compare current measurements (parsed from the markdown table) against - reference measurements (loaded from JSON) using mean/max ratios. - """ - if not self.rows_by_controller: - return False, "No parsed rows available for comparison" - - ok, reference_rows, why = self.load_json_rows(self.reference_path) - if not ok: - return False, why - - ref_by_controller = {r.controller: r for r in reference_rows} - problems: list[str] = [] - - if len(self.rows_by_controller) != len(ref_by_controller): - why = ( - f"Missing or mismatched results: got {len(self.rows_by_controller)}" - + f", expected {len(ref_by_controller)}" - ) - return False, why + return mean_ratio, max_ratio - for controller, row in self.rows_by_controller.items(): - ref = ref_by_controller.get(controller) - if ref is None: - problems.append(f"Missing reference row for controller {controller}") - continue - mean_cur, max_cur = self.compute_ratios(row) - mean_ref, max_ref = self.compute_ratios(ref) +def _parse_results_table_rows(lines: Sequence[str]) -> list[TableRow]: + """Parses the markdown table from results into rows.""" + header_line: str | None = None + data_lines: list[str] = [] - if not self.ratios_within_tolerance(mean_cur, mean_ref): - problems.append( - f"{controller} mean ratio differs too much " - f"(got {mean_cur:.4f}, ref {mean_ref:.4f})" - ) + for line in lines: + if "|" not in line: + # Not a table row. + continue - if not self.ratios_within_tolerance(max_cur, max_ref): - problems.append( - f"{controller} max ratio differs too much " - f"(got {max_cur:.4f}, ref {max_ref:.4f})" - ) + if header_line is None: + header_line = line + continue + + if _is_separator_line(line): + continue + + data_lines.append(line) + + if header_line is None: + raise ValueError("No table header found") + + headers = _split_markdown_row(header_line) + if tuple(headers) != _EXPECTED_HEADERS: + raise ValueError(f"Unexpected table headers: {headers!r}") - if problems: - return False, "; ".join(problems) + rows: list[TableRow] = [] + for line in data_lines: + cells = _split_markdown_row(line) + if len(cells) != len(_EXPECTED_HEADERS): + raise ValueError( + f"Row has {len(cells)} cells, expected {len(_EXPECTED_HEADERS)}: {line!r}" + ) + + controller = cells[0] + verified_anvil_mean = _parse_float_token( + cells[1], label="Verified (Anvil) Mean" + ) + verified_anvil_max = _parse_float_token( + cells[2], label="Verified (Anvil) Max" + ) + reference_unverified_mean = _parse_float_token( + cells[3], label="Reference (unverified) Mean" + ) + reference_unverified_max = _parse_float_token( + cells[4], label="Reference (unverified) Max" + ) + + rows.append( + TableRow( + controller=controller, + verified_anvil_mean=verified_anvil_mean, + verified_anvil_max=verified_anvil_max, + reference_unverified_mean=reference_unverified_mean, + reference_unverified_max=reference_unverified_max, + ) + ) - return True, "" + return rows - def run(self): - results: list[bool] = [] - ok, why = self.load() - logger.info(f"Table present: {'PASS' if ok else 'FAIL' + (' - ' + why if why else '')}") - results.append(ok) +def _load_reference_rows(path: Path) -> list[TableRow]: + """Loads reference TableRow objects from JSON (list of row objects).""" + if not path.exists(): + raise ValueError(f"{path} not found") - ok, why = self.parse_table() - logger.info(f"Table format: {'PASS' if ok else 'FAIL' + (' - ' + why if why else '')}") - results.append(ok) + try: + raw = json.loads(path.read_text(encoding="utf-8")) + except json.JSONDecodeError as exc: + raise ValueError(f"{path} invalid JSON: {exc}") from exc - ok, why = self.compare_against_reference() - logger.info(f"Compare against reference: {'PASS' if ok else 'FAIL' + (' - ' + why if why else '')}") - results.append(ok) + if not isinstance(raw, list): + raise ValueError(f"{path} must contain a list of objects") - if all(results): - return True + rows: list[TableRow] = [] + for idx, obj in enumerate(raw): + if not isinstance(obj, dict): + raise ValueError(f"{path} entry #{idx} is not an object") - return False \ No newline at end of file + try: + rows.append( + TableRow( + controller=str(obj["controller"]), + verified_anvil_mean=float(obj["verified_anvil_mean"]), + verified_anvil_max=float(obj["verified_anvil_max"]), + reference_unverified_mean=float(obj["reference_unverified_mean"]), + reference_unverified_max=float(obj["reference_unverified_max"]), + ) + ) + except (KeyError, TypeError, ValueError) as exc: + raise ValueError(f"{path} malformed entry #{idx}: {exc}") from exc + + return rows + + +def _results_mean_ratio_pairs(lines: Sequence[str]) -> list[tuple[str, float]]: + """Returns (controller, mean_ratio) from results table.""" + rows = _parse_results_table_rows(lines) + out: list[tuple[str, float]] = [] + for r in rows: + mean_ratio, _ = _compute_ratios(r) + out.append((r.controller, mean_ratio)) + return out + + +def _results_max_ratio_pairs(lines: Sequence[str]) -> list[tuple[str, float]]: + """Returns (controller, max_ratio) from results table.""" + rows = _parse_results_table_rows(lines) + out: list[tuple[str, float]] = [] + for r in rows: + _, max_ratio = _compute_ratios(r) + out.append((r.controller, max_ratio)) + return out + + +def _reference_mean_ratio_pairs(path: Path) -> list[tuple[str, float]]: + """Returns (controller, mean_ratio) from reference JSON rows.""" + rows = _load_reference_rows(path) + out: list[tuple[str, float]] = [] + for r in rows: + mean_ratio, _ = _compute_ratios(r) + out.append((r.controller, mean_ratio)) + return out + + +def _reference_max_ratio_pairs(path: Path) -> list[tuple[str, float]]: + """Returns (controller, max_ratio) from reference JSON rows.""" + rows = _load_reference_rows(path) + out: list[tuple[str, float]] = [] + for r in rows: + _, max_ratio = _compute_ratios(r) + out.append((r.controller, max_ratio)) + return out + + +class OracleExperimentRuns(OracleExperimentRunsBase): + """Validates ANVIL experiment run outputs (TABLE-3).""" + + _NAME = "ExperimentRuns" + + def __init__(self, *, config: EntryConfig, logger: logging.Logger) -> None: + super().__init__(logger=logger) + self._config = config + + def requirements(self) -> Sequence[ExperimentRunsRequirement]: + if not self._config.results_paths: + raise ValueError("EntryConfig.results_paths must be non-empty") + if not self._config.ground_truth_paths: + raise ValueError("EntryConfig.ground_truth_paths must be non-empty") + + results_path = _required_path( + self._config.results_paths, "table3", label="results_paths" + ) + reference_path = _required_path( + self._config.ground_truth_paths, "table3", label="ground_truth_paths" + ) + + threshold = self._config.similarity_ratio + + return ( + LabeledSequenceSimilarityThresholdRequirement( + name="table3_mean_ratio", + label="TABLE-3 mean_ratio", + results_path=results_path, + reference_path=reference_path, + threshold=threshold, + parse_results_fn=_results_mean_ratio_pairs, + parse_reference_fn=_reference_mean_ratio_pairs, + ), + LabeledSequenceSimilarityThresholdRequirement( + name="table3_max_ratio", + label="TABLE-3 max_ratio", + results_path=results_path, + reference_path=reference_path, + threshold=threshold, + parse_results_fn=_results_max_ratio_pairs, + parse_reference_fn=_reference_max_ratio_pairs, + ), + ) diff --git a/benchmarks/arteval_bench/data/benchmark/osdi24_anvil/_agent_eval/utils.py b/benchmarks/arteval_bench/data/benchmark/osdi24_anvil/_agent_eval/utils.py deleted file mode 100644 index cf6f6e10..00000000 --- a/benchmarks/arteval_bench/data/benchmark/osdi24_anvil/_agent_eval/utils.py +++ /dev/null @@ -1,30 +0,0 @@ -# --- CONSTANTS --- # -from pathlib import Path - -HOME = Path.home() / "osdi24_anvil" -REPO_DIRS = {"acto": f"{HOME}/acto", "anvil": f"{HOME}/anvil"} - -REFERENCE_PATH = f"{HOME}/_agent_eval/refs/anvil-table-3.ref.json" -RESULTS_PATH = f"{REPO_DIRS["acto"]}/anvil-table-3.txt" - -SIMILARITY_RATIO = 0.75 - - -# --- CUSTOM LOGGER --- # -import logging -import os -from datetime import datetime - -os.makedirs('logs', exist_ok=True) - -LOG_FORMAT = '%(asctime)s | %(levelname)s | %(name)s | %(message)s' -DATE_FORMAT = '%Y-%m-%d %H:%M:%S' - -logger = logging.getLogger("OSDI24-ANVIL-AGENT-EVALUATOR") -logger.setLevel(logging.DEBUG) - -console_handler = logging.StreamHandler() -console_handler.setLevel(logging.INFO) -console_handler.setFormatter(logging.Formatter(LOG_FORMAT, datefmt=DATE_FORMAT)) - -logger.addHandler(console_handler) \ No newline at end of file From 03c839287195b4d1c1bc456c4837209b768695ed Mon Sep 17 00:00:00 2001 From: Bogdan-Alexandru Stoica Date: Sat, 31 Jan 2026 01:22:53 -0600 Subject: [PATCH 5/6] refactor: adapt wasabi's oracles to use the standardized interface --- .../sosp24_wasabi/_agent_eval/main.py | 86 +++++++ .../_agent_eval/oracle_artifact_build.py | 190 ++++++++++++++++ .../_agent_eval/oracle_benchmark_prep.py | 214 ++++++++++++++++++ .../_agent_eval/oracle_env_setup.py | 185 +++++++++++++++ .../_agent_eval/oracle_experiment_runs.py | 121 ++++++++++ 5 files changed, 796 insertions(+) create mode 100644 benchmarks/arteval_bench/data/benchmark/sosp24_wasabi/_agent_eval/main.py create mode 100644 benchmarks/arteval_bench/data/benchmark/sosp24_wasabi/_agent_eval/oracle_artifact_build.py create mode 100644 benchmarks/arteval_bench/data/benchmark/sosp24_wasabi/_agent_eval/oracle_benchmark_prep.py create mode 100644 benchmarks/arteval_bench/data/benchmark/sosp24_wasabi/_agent_eval/oracle_env_setup.py create mode 100644 benchmarks/arteval_bench/data/benchmark/sosp24_wasabi/_agent_eval/oracle_experiment_runs.py diff --git a/benchmarks/arteval_bench/data/benchmark/sosp24_wasabi/_agent_eval/main.py b/benchmarks/arteval_bench/data/benchmark/sosp24_wasabi/_agent_eval/main.py new file mode 100644 index 00000000..95967ba7 --- /dev/null +++ b/benchmarks/arteval_bench/data/benchmark/sosp24_wasabi/_agent_eval/main.py @@ -0,0 +1,86 @@ +#!/usr/bin/env python3 +"""Runs environment setup checks for WASABI.""" + +from __future__ import annotations +from pathlib import Path +from typing import Dict +import os +import sys + + +_AGENT_EVAL_DIR = Path(__file__).resolve().parent +_AGENT_SRC_DIR = _AGENT_EVAL_DIR.parents[3] / "src" +sys.path.append(str(_AGENT_SRC_DIR)) + + +from evaluator.utils import ( + EntryConfig, + LoggerConfig, + get_logger, + record_result, +) +from oracle_artifact_build import OracleArtifactBuild +from oracle_env_setup import OracleEnvSetup +from oracle_benchmark_prep import OracleBenchmarkPrep +from oracle_experiment_runs import OracleExperimentRuns + + +# NOTE: WASABI bundle layout mirrors the legacy constants, but we build it directly +# from EntryConfig rather than importing legacy globals. +_WASABI_HOME = Path.home() / "sosp24_wasabi" +_WASABI_REPO = _WASABI_HOME / "wasabi" +_WASABI_BENCH = _WASABI_HOME / "benchmarks" + + +WASABI_CONFIG = EntryConfig( + name = "sosp24-wasabi", + home_dir = _WASABI_HOME, + repository_paths = { + "sosp24-wasabi": _WASABI_REPO, + "benchmarks": _WASABI_BENCH, + }, + results_paths = { + "results_root": _WASABI_REPO / "results", + }, + ground_truth_paths = { + "bugs_ground_truth": _WASABI_REPO / "bugs_ground_truth.txt", + }, + similarity_ratio = 0.75, +) + + +def main(argv: list[str]) -> int: + verbose = "--verbose" in argv + + results: Dict[str, int] = {} + score = 0 + + logger_name = os.environ.get("EVAL_LOGGER_NAME", "WASABI-AGENT-EVALUATOR") + logger = get_logger(LoggerConfig(root_name = logger_name)) + + env_checker = OracleEnvSetup(config = WASABI_CONFIG, logger = logger) + score += record_result( + logger, results, type(env_checker).__name__, env_checker.run(verbose = verbose) + ) + + build_checker = OracleArtifactBuild(config = WASABI_CONFIG, logger = logger) + score += record_result( + logger, results, type(build_checker).__name__, build_checker.run(verbose = verbose) + ) + + prep_checker = OracleBenchmarkPrep(config = WASABI_CONFIG, logger = logger) + score += record_result( + logger, results, type(prep_checker).__name__, prep_checker.run(verbose = verbose) + ) + + runs_checker = OracleExperimentRuns(config = WASABI_CONFIG, logger = logger) + score += record_result( + logger, results, type(runs_checker).__name__, runs_checker.run(verbose = verbose) + ) + + logger.info("Agent scores: %s", results) + return score + + +if __name__ == "__main__": + raise SystemExit(main(sys.argv[1:])) diff --git a/benchmarks/arteval_bench/data/benchmark/sosp24_wasabi/_agent_eval/oracle_artifact_build.py b/benchmarks/arteval_bench/data/benchmark/sosp24_wasabi/_agent_eval/oracle_artifact_build.py new file mode 100644 index 00000000..6bf39f2f --- /dev/null +++ b/benchmarks/arteval_bench/data/benchmark/sosp24_wasabi/_agent_eval/oracle_artifact_build.py @@ -0,0 +1,190 @@ +#!/usr/bin/env python3 +import xml.etree.ElementTree as ET +import fnmatch + +from utils import HOME +from utils import REPO_DIR +from utils import logger + +from evaluator.oracle_artifact_build_primitives import OracleArtifactBuildBase +from evaluator import utils + + +@dataclasses.dataclass(frozen=True, slots=True) +class _BuildInputsRequirement(utils.BaseRequirement): + oracle: "OracleArtifactBuild" + + def check(self, ctx: object) -> utils.CheckResult: + del ctx + + if not REPO_DIR.exists(): + logger.info("Build: FAIL - base project directory not found") + return utils.CheckResult.failure("base project directory not found") + + poms = self.oracle.find_poms(REPO_DIR) + if not poms: + logger.info("Build: FAIL - no pom.xml files found under wasabi-testing") + return utils.CheckResult.failure("no pom.xml files found under wasabi-testing") + + root_pom = REPO_DIR / "pom.xml" + top_defaults = {} + if root_pom.exists(): + root_mod = self.oracle.parse_pom(root_pom) + if not root_mod.get("error"): + if root_mod.get("groupId"): + top_defaults["groupId"] = root_mod["groupId"] + if root_mod.get("version"): + top_defaults["version"] = root_mod["version"] + + modules = [] + errors = [] + for pom in poms: + m = self.oracle.parse_pom(pom, top_defaults=top_defaults) + if m.get("error"): + errors.append((pom, m["error"])) + continue + if not all([m.get("artifactId"), m.get("groupId"), m.get("version")]): + errors.append((pom, "missing groupId/artifactId/version after inheritance")) + else: + modules.append(m) + + if errors: + logger.info("Build: FAIL - POM parsing errors present") + for pom, err in errors[:5]: + logger.info(f" - {pom}: {err}") + if len(errors) > 5: + logger.info(f" ... {len(errors)-5} more") + return utils.CheckResult.failure("POM parsing errors present") + + self.oracle._modules = modules + return utils.CheckResult.success() + + +@dataclasses.dataclass(frozen=True, slots=True) +class _CodeBuildRequirement(utils.BaseRequirement): + oracle: "OracleArtifactBuild" + + def check(self, ctx: object) -> utils.CheckResult: + del ctx + + modules = getattr(self.oracle, "_modules", None) + if not modules: + return utils.CheckResult.success() + + missing_targets = [] + missing_installs = [] + + for m in modules: + if not self.oracle.has_target_jar(m): + missing_targets.append(str(m["dir"])) + if not self.oracle.has_installed_artifact(m): + missing_installs.append(f"{m['groupId']}:{m['artifactId']}:{m['version']}") + + if missing_targets or missing_installs: + logger.info("Code build: FAIL") + if missing_targets: + logger.info(" Missing built JARs in target/:") + for d in missing_targets[:10]: + logger.info(f" - {d}") + if len(missing_targets) > 10: + logger.info(f" ... {len(missing_targets)-10} more") + if missing_installs: + logger.info(" Missing artifacts in local ~/.m2 repository:") + for gav in missing_installs[:10]: + logger.info(f" - {gav}") + if len(missing_installs) > 10: + logger.info(f" ... {len(missing_installs)-10} more") + + return utils.CheckResult.failure("missing built jars and/or installed artifacts") + + logger.info("Code build: PASS") + return utils.CheckResult.success() + + +class OracleArtifactBuild(OracleArtifactBuildBase): + def __init__(self, *, logger=logger): + super().__init__(logger=logger) + self.maven_packages_dir = HOME / ".m2" / "repository" + self._modules = None + + def requirements(self): + return ( + _BuildInputsRequirement(name="Build", oracle=self), + _CodeBuildRequirement(name="Code build", oracle=self), + ) + + def xget(self, elem, tag): + """ + Helper function to handle POM tags with or without default namespace + """ + if elem is None: + return None + v = elem.find(tag) + if v is not None and v.text: + return v.text.strip() + for child in elem: + t = child.tag.split('}', 1)[-1] + if t == tag: + return (child.text or "").strip() + return None + + def parse_pom(self, pom_path, top_defaults=None): + """ + Collects POM files into dictionary + """ + try: + tree = ET.parse(pom_path) + root = tree.getroot() + except Exception as e: + return {"dir": pom_path.parent, "pom": pom_path, "error": f"XML parse error: {e}"} + + artifactId = self.xget(root, "artifactId") + groupId = self.xget(root, "groupId") + version = self.xget(root, "version") + packaging = self.xget(root, "packaging") or "jar" + + parent = root.find("parent") + if parent is not None: + p_groupId = self.xget(parent, "groupId") + p_version = self.xget(parent, "version") + if not groupId and p_groupId: + groupId = p_groupId + if not version and p_version: + version = p_version + + if top_defaults: + groupId = groupId or top_defaults.get("groupId") + version = version or top_defaults.get("version") + + return { + "dir": pom_path.parent, + "pom": pom_path, + "groupId": groupId, + "artifactId": artifactId, + "version": version, + "packaging": packaging + } + + def find_poms(self, base): + return sorted(base.rglob("pom.xml")) + + def repo_path(self, groupId, artifactId, version): + parts = groupId.split(".") + return self.maven_packages_dir.joinpath(*parts, artifactId, version) + + def has_target_jar(self, module): + if module["packaging"] == "pom": + return True # no jar expected + target = module["dir"] / "target" + if not target.is_dir(): + return False + pattern = f"{module['artifactId']}-{module['version']}*.jar" + return any(fnmatch.fnmatch(p.name, pattern) for p in target.glob("*.jar")) + + def has_installed_artifact(self, module): + rp = self.repo_path(module["groupId"], module["artifactId"], module["version"]) + if module["packaging"] == "pom": + return (rp / f"{module['artifactId']}-{module['version']}.pom").is_file() + return any(p.suffix == ".jar" and fnmatch.fnmatch( + p.name, f"{module['artifactId']}-{module['version']}*.jar") + for p in rp.glob("*.jar")) \ No newline at end of file diff --git a/benchmarks/arteval_bench/data/benchmark/sosp24_wasabi/_agent_eval/oracle_benchmark_prep.py b/benchmarks/arteval_bench/data/benchmark/sosp24_wasabi/_agent_eval/oracle_benchmark_prep.py new file mode 100644 index 00000000..96f19eef --- /dev/null +++ b/benchmarks/arteval_bench/data/benchmark/sosp24_wasabi/_agent_eval/oracle_benchmark_prep.py @@ -0,0 +1,214 @@ +#!/usr/bin/env python3 +import sys +import shlex +import subprocess +from dataclasses import dataclass +from pathlib import Path + +from utils import BENCH_DIR +from utils import logger + +from evaluator.utils import EntryConfig +from evaluator.oracle_benchmark_prep_primitives import ( + OracleBenchmarkPrepBase, + Requirement, +) +from evaluator import utils + + + +REPOS = { + "hadoop": ("https://github.com/apache/hadoop.git", "60867de"), + "hbase": ("https://github.com/apache/hbase.git", "89ca7f4"), + "hive": ("https://github.com/apache/hive.git", "e08a600"), +} + +ASPECTJ_MARKERS = [ + "ajc$preClinit", + "ajc$initFailureCause", + "ajc$tjp", + "ajc$before$", + "ajc$after$", + "ajc$around$", + "ajc$interField$", + "ajc$interMethod$", + "org.aspectj.runtime.reflect.Factory", + "org.aspectj.runtime.internal.AroundClosure", + "org.aspectj.lang.JoinPoint", + "org.aspectj.lang.JoinPoint$StaticPart", + "org.aspectj.lang.ProceedingJoinPoint", + "org.aspectj.lang.Signature", + "org.aspectj.lang.NoAspectBoundException", +] + + +def _required_path(paths, key: str, *, label: str) -> Path: + """Returns a required path from a mapping with a clear error.""" + try: + return paths[key] + except KeyError as e: + raise ValueError(f"Missing {label}[{key!r}] in EntryConfig") from e + + +@dataclass(frozen=True, slots=True) +class _RepoCommitRequirement(utils.BaseRequirement): + oracle: "OracleBenchmarkPrep" + app: str + app_root: Path + expected_commit_prefix: str + + def check(self, ctx) -> utils.CheckResult: + ok, msg = self.oracle.check_repo_commit(self.app, self.app_root, self.expected_commit_prefix) + ctx.logger.info(msg) + return utils.CheckResult.success() if ok else utils.CheckResult.failure(msg) + + +@dataclass(frozen=True, slots=True) +class _WeavingRequirement(utils.BaseRequirement): + oracle: "OracleBenchmarkPrep" + app: str + app_root: Path + + def check(self, ctx) -> utils.CheckResult: + ok, msg = self.oracle.check_app_weaving(self.app, self.app_root) + ctx.logger.info(msg) + return utils.CheckResult.success() if ok else utils.CheckResult.failure(msg) + + +class OracleBenchmarkPrep(OracleBenchmarkPrepBase): + + def __init__(self, *, config: EntryConfig, logger: logger.__class__): + super().__init__(logger = logger) + self._config = config + + self.max_class_dirs = 200 + self.max_classess_per_dir = 2000 + + def requirements(self) -> tuple[Requirement, ...]: + bench_root = _required_path(self._config.repository_paths, "benchmarks", label="repository_paths") + + reqs: list[Requirement] = [] + for app in REPOS: + app_root = bench_root / app + + expected_commit = REPOS[app][1] + reqs.append( + _RepoCommitRequirement( + name = f"{app}: clone", + oracle = self, + app = app, + app_root = app_root, + expected_commit_prefix = expected_commit, + ) + ) + + reqs.append( + _WeavingRequirement( + name = f"{app}: weaving", + oracle = self, + app = app, + app_root = app_root, + ) + ) + + return tuple(reqs) + + def run_shell_command(self, cmd): + """ + Run a bash command given as argument. + """ + try: + cp = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True) + return cp.returncode, (cp.stdout or "").strip(), (cp.stderr or "").strip() + except FileNotFoundError as e: + return 127, "", str(e) + + def find_class_dirs(self, app_root: Path): + """ + Find directories that contain .class files. + """ + qroot = shlex.quote(str(app_root)) + cmd = [ + "bash", + "-lc", + ( + f"shopt -s nullglob; " + f"find {qroot} -type f -name '*.class' " + f"-not -path '*/.git/*' -not -path '*/.m2/*' -not -path '*/.gradle/*' " + f"-printf '%h\n' | sort -u" + ), + ] + rc, out, err = self.run_shell_command(cmd) + if rc != 0: + return [], f"find failed: {err or out}" + dirs = [Path(p) for p in out.splitlines() if p] + return dirs, "" + + def iter_class_files(self, classes_dir: Path, limit: int): + """ + Iterate over .class files from a class directory, processing up to + a configurable number of files. + """ + q = shlex.quote(str(classes_dir)) + cmd = ["bash", "-lc", f"shopt -s nullglob; find {q} -type f -name '*.class' | sort"] + rc, out, err = self.run_shell_command(cmd) + if rc != 0 or not out: + return [] + files = [Path(p) for p in out.splitlines() if p] + if limit and len(files) > limit: + step = max(len(files) // limit, 1) + files = files[::step][:limit] + return files + + def check_repo_commit(self, app: str, app_root: Path, expected_commit_prefix: str): + """ + Verify the repo at app_root is a git repo and HEAD matches an expected commit ID prefix. + """ + if not app_root.is_dir(): + return False, f"{app}: FAIL (clone) - directory not found: {app_root}" + + rc, out, err = self.run_shell_command(["git", "-C", str(app_root), "rev-parse", "HEAD"]) + if rc != 0: + return False, f"{app}: FAIL (clone) - not a git repo or unreadable HEAD: {err or out}" + + head = (out or "").strip() + if head.startswith(expected_commit_prefix): + return True, f"{app}: PASS (clone) - commit {head[:12]} matches {expected_commit_prefix}" + else: + return False, f"{app}: FAIL (clone) - HEAD {head[:12]} != expected {expected_commit_prefix}*" + + + def classfile_has_aspect_markers(self, class_path: Path): + """ + Search through a decoded .class for AspectJ markers. + """ + pattern = "|".join(ASPECTJ_MARKERS) + cmd = ["bash", "-lc", f"strings {shlex.quote(str(class_path))} | grep -a -E '{pattern}' -m 1"] + rc, out, err = self.run_shell_command(cmd) + if rc == 0 and out: + matched = next((m for m in ASPECTJ_MARKERS if m in out), out) + return True, matched + return False, "" + + def check_app_weaving(self, app: str, app_root: Path): + """ + Scan compiled .class files for AspectJ markers. + """ + if not app_root.is_dir(): + return False, f"{app}: FAIL (waving) - directory not found: {app_root}" + + class_dirs, err = self.find_class_dirs(app_root) + if err: + return False, f"{app}: FAIL (waving) - {err}" + if not class_dirs: + return False, f"{app}: FAIL (waving) - no compiled .class files found under {app_root}" + + dirs = class_dirs[:self.max_class_dirs] if (self.max_class_dirs and len(class_dirs) > self.max_class_dirs) else class_dirs + + for cdir in dirs: + for cf in self.iter_class_files(cdir, self.max_classess_per_dir): + ok, marker = self.classfile_has_aspect_markers(cf) + if ok: + return True, f"{app}: PASS (weaving) - marker '{marker}' in {cf}" + + return False, f"{app}: FAIL (weaving) - scanned .class files but found no AspectJ markers" diff --git a/benchmarks/arteval_bench/data/benchmark/sosp24_wasabi/_agent_eval/oracle_env_setup.py b/benchmarks/arteval_bench/data/benchmark/sosp24_wasabi/_agent_eval/oracle_env_setup.py new file mode 100644 index 00000000..4c6016e2 --- /dev/null +++ b/benchmarks/arteval_bench/data/benchmark/sosp24_wasabi/_agent_eval/oracle_env_setup.py @@ -0,0 +1,185 @@ +#!/usr/bin/env python3 +import os +import re +import shutil +import subprocess +from dataclasses import dataclass +from typing import Iterable, Optional, Tuple +from pathlib import Path + +from utils import REPO_DIR +from utils import logger as _default_logger + +from evaluator.oracle_env_setup_primitives import OracleEnvSetupBase, Requirement +from evaluator import utils + + +VersionTuple = Tuple[int, ...] + + +@dataclass(frozen=True) +class Dependency: + name: str + binary: str + cmd: Optional[list] = None + parse_regex: Optional[str] = None + require: Optional[VersionTuple] = None + compare: Optional[str] = None + + +DEPENDENCIES: list[Dependency] = [ + + Dependency( + name="git", binary="git" + ), + + Dependency( + name="maven", binary="mvn", + cmd=["mvn", "-v"], parse_regex=r"Apache Maven\s+([0-9.]+)", + require=(3, 6, 3), compare="gte", + ), + Dependency( + name="gradle", binary="gradle", + cmd=["gradle", "-v"], parse_regex=r"Gradle\s+([0-9.]+)", + require=(4, 4, 1), compare="gte", + ), + Dependency( + name="ant", binary="ant", + cmd=["ant", "-version"], parse_regex=r"version\s+([0-9.]+)", + require=(1, 10), compare="gte", + ), + Dependency( + name="python3", binary="python3", + cmd=["python3", "--version"], parse_regex=r"Python\s+([0-9.]+)", + require=(3, 10), compare="gte", + ), + Dependency( + name="java", binary="java", + cmd=["java", "-version"], parse_regex=r'version\s+"([^"]+)"', + require=(1, 8), compare="eq", + ), +] + + +@dataclass(frozen=True, slots=True) +class _PrereqsRequirement(utils.BaseRequirement): + oracle: "OracleEnvSetup" + + def check(self, ctx: object) -> utils.CheckResult: + del ctx + ok, why = self.oracle.prereqs_check() + if ok: + return utils.CheckResult.success() + return utils.CheckResult.failure(why or "Prerequisites failed") + + +@dataclass(frozen=True, slots=True) +class _PathsRequirement(utils.BaseRequirement): + oracle: "OracleEnvSetup" + + def check(self, ctx: object) -> utils.CheckResult: + del ctx + ok, why = self.oracle.paths_check() + if ok: + return utils.CheckResult.success() + return utils.CheckResult.failure(why or "Paths failed") + + +class OracleEnvSetup(OracleEnvSetupBase): + + def __init__(self, *, logger=_default_logger) -> None: + super().__init__(logger=logger) + + self.expected_root_dir = REPO_DIR + self.expected_java_home = "/usr/lib/jvm/java-8-openjdk-amd64/jre" + + def requirements(self) -> Tuple[Requirement, ...]: + return ( + _PrereqsRequirement(name="Prerequisites", oracle=self), + _PathsRequirement(name="Paths", oracle=self), + ) + + def run_shell_command(self, cmd: Iterable[str]) -> Tuple[int, str, str]: + """ + Run a command and return (rc, stdout, stderr) tuple. + """ + try: + cp = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True) + return cp.returncode, cp.stdout or "", cp.stderr or "" + except FileNotFoundError: + return 127, "", "" + + def parse_version_tuple(self, text: str) -> VersionTuple: + """ + Extract the first version-like token from arbitrary text. + For example, for Java: '1.8.0_422' -> (1, 8, 0) + """ + m = re.search(r"(\d+(?:\.\d+){0,3})", text) + return tuple(int(x) for x in m.group(1).split(".")) if m else () + + def extract_version(self, text: str, pattern: str) -> Tuple[VersionTuple, str]: + """ + Apply regex pattern on a version string. + """ + m = re.search(pattern, text, re.I) + if not m: + return (), "unknown" + ver_str = m.group(1) + return self.parse_version_tuple(ver_str), ver_str + + def cmp_versions(self, found: VersionTuple, required: VersionTuple, mode: str) -> bool: + """ + Compare versions either to match exactly ('eq') + or the installed version is greather than the reference one ('gte'). + """ + if not found: + return False + f, r = list(found), list(required) + while len(f) < len(r): f.append(0) + while len(r) < len(f): r.append(0) + return (f == r) if mode == "eq" else (f >= r) + + def paths_check(self): + wasabi_root = os.environ.get("WASABI_ROOT_DIR", "") + if not (wasabi_root == self.expected_root_dir and Path(wasabi_root).exists()): + return False, "WASABI_ROOT_DIR incorrect" + java_home = os.environ.get("JAVA_HOME", "") + if not (java_home == self.expected_java_home and Path(java_home).exists()): + return False, "JAVA_HOME incorrect" + return True, "" + + def check_dependency(self, dep: Dependency) -> Optional[str]: + """ + Core method that checks whether a certain dependency of a version + equal or greather than that specified in the README is installed. + """ + if shutil.which(dep.binary) is None: + return f"{dep.name} missing" + + + if dep.cmd is None and dep.parse_regex is None and dep.require is None: + return None + + rc, out, err = self.run_shell_command(dep.cmd or []) + text = (out + "\n" + err).strip() + + if dep.parse_regex and dep.require and dep.compare: + ver_tuple, ver_str = self.extract_version(text, dep.parse_regex) + if not ver_tuple: + return f"{dep.name} version unreadable" + ok = self.cmp_versions(ver_tuple, dep.require, dep.compare) + cmp_word = "==" if dep.compare == "eq" else ">=" + want = ".".join(map(str, dep.require)) + return None if ok else f"{dep.name} {cmp_word} {want} not met (got {ver_str})" + + return f"{dep.name} check misconfigured" + + def prereqs_check(self): + problems: list[str] = [] + for dep in DEPENDENCIES: + msg = self.check_dependency(dep) + if msg: + problems.append(msg) + if problems: + return False, "; ".join(problems) + return True, "" \ No newline at end of file diff --git a/benchmarks/arteval_bench/data/benchmark/sosp24_wasabi/_agent_eval/oracle_experiment_runs.py b/benchmarks/arteval_bench/data/benchmark/sosp24_wasabi/_agent_eval/oracle_experiment_runs.py new file mode 100644 index 00000000..e37e0d42 --- /dev/null +++ b/benchmarks/arteval_bench/data/benchmark/sosp24_wasabi/_agent_eval/oracle_experiment_runs.py @@ -0,0 +1,121 @@ +from collections import defaultdict +import os + +from utils import RESULTS_ROOT_DIR +from utils import GROUND_TRUTH_FILE +from utils import SIMILARITY_RATIO + +from utils import logger + +class OracleExperimentRuns: + def __init__(self): + pass + + def get_benchmark_name(self, loc): + """ + Classifies the location based on its prefix. + """ + if loc.startswith("org.apache.hadoop.hdfs") and "SecondaryNameNode.doWork" not in loc: + return "hdfs" + elif loc.startswith("org.apache.hadoop.yarn"): + return "yarn" + elif loc.startswith("org.apache.hadoop.mapreduce") or loc.startswith("org.apache.hadoop.mapred"): + return "mapreduce" + elif loc.startswith("org.apache.hadoop.hbase"): + return "hbase" + elif loc.startswith("org.apache.hadoop.hive"): + return "hive" + elif loc.startswith("org.apache.cassandra"): + return "cassandra" + elif loc.startswith("org.apache.hadoop") or "SecondaryNameNode.doWork" in loc: # initialy found in hadoop-common, added here to match Table 3 + return "hadoop" + elif loc.startswith("org.elasticsearch"): + return "elasticsearch" + else: + return "unknown" + + def aggregate_bugs(self, root_dir): + """ + Searches for bug report files and aggregates bugs based on their type and + which application have been found in. + """ + bugs = defaultdict(lambda: defaultdict(set)) + unique = dict() + + for dirpath, _, files in os.walk(root_dir): + for file in files: + if file.endswith(".csv"): + file_path = os.path.join(dirpath, file) + + with open(file_path, 'r') as f: + for line in f: + if "how-bug" in line or "when-missing-" in line: + tokens = line.strip().split(",") + + bug_type = tokens[1] + bug_loc = tokens[2] + + key = bug_type + bug_loc + if key in unique: + continue + unique[key] = "x" + + benchmark = self.get_benchmark_name(bug_loc) + bugs[bug_type][benchmark].add(bug_loc) + + return bugs + + def get_ground_truth_bugs(self, file_path: str): + """ + Reads the ground truth values from a file into a dictionary. + """ + ground_truth = defaultdict(lambda: defaultdict(set)) + + try: + with open(file_path, 'r') as f: + for line in f: + tokens = line.strip().split(",") + benchmark = tokens[0] + bug_type = tokens[1] + retry_location = tokens[2] + ground_truth[bug_type][benchmark].add(retry_location) + except Exception: + logger.info(f"Cannot open {file_path} or file not present.") + + return ground_truth + + def count_bugs(self, bugs, ground_truth): + """ + Compares the total number of bugs found against the ground truth. + """ + total_ground_truth = 0 + total_found = 0 + + for bug_type, benchmarks in ground_truth.items(): + for benchmark, ground_truth_locations in benchmarks.items(): + total_ground_truth += len(ground_truth_locations) + bug_locations = bugs.get(bug_type, {}).get(benchmark, set()) + matching_locations = ground_truth_locations & bug_locations + total_found += len(matching_locations) + + if total_ground_truth == 0: + logger.info("No ground truth bugs available.") + return False + + coverage = total_found / total_ground_truth + logger.info(f"Found {total_found} out of {total_ground_truth} ground truth bugs ({coverage:.2%}).") + + passed = coverage >= SIMILARITY_RATIO + logger.info("Results reproduced: PASS" if passed else "Results reproduced: FAIL") + return passed + + + def run(self): + bugs = self.aggregate_bugs(str(RESULTS_ROOT_DIR)) + ground_truth = self.get_ground_truth_bugs(str(GROUND_TRUTH_FILE)) + passed = self.count_bugs(bugs, ground_truth) + + if passed: + return True + + return False \ No newline at end of file From f3b2b1ce588fb537c72dae8cff06dfa86d69cca6 Mon Sep 17 00:00:00 2001 From: Bogdan-Alexandru Stoica Date: Sat, 31 Jan 2026 01:23:09 -0600 Subject: [PATCH 6/6] refactor: adapt acto's oracles to use the standardized interface --- .../benchmark/sosp23_acto/_agent_eval/main.py | 88 +- .../_agent_eval/oracle_artifact_build.py | 178 ++- .../_agent_eval/oracle_env_setup.py | 379 +++--- .../_agent_eval/oracle_experiment_runs.py | 1203 ++++++++--------- .../sosp23_acto/_agent_eval/utils.py | 40 - 5 files changed, 890 insertions(+), 998 deletions(-) delete mode 100644 benchmarks/arteval_bench/data/benchmark/sosp23_acto/_agent_eval/utils.py diff --git a/benchmarks/arteval_bench/data/benchmark/sosp23_acto/_agent_eval/main.py b/benchmarks/arteval_bench/data/benchmark/sosp23_acto/_agent_eval/main.py index 2f434ee5..7f1676d8 100644 --- a/benchmarks/arteval_bench/data/benchmark/sosp23_acto/_agent_eval/main.py +++ b/benchmarks/arteval_bench/data/benchmark/sosp23_acto/_agent_eval/main.py @@ -1,32 +1,78 @@ #!/usr/bin/env python3 -import sys +"""Runs environment setup checks for ACTO.""" + +from __future__ import annotations +from pathlib import Path from typing import Dict +import os +import sys -from oracle_artifact_build import OracleArtifactBuild -from oracle_env_setup import OracleEnvSetup -from oracle_benchmark_prep import OracleBenchmarkPrep -from oracle_experiment_runs import OracleExperimentRuns -from utils import logger +_AGENT_EVAL_DIR = Path(__file__).resolve().parent +_AGENT_SRC_DIR = _AGENT_EVAL_DIR.parents[3] / "src" +sys.path.append(str(_AGENT_SRC_DIR)) -def main(): - results: Dict[str, int] = {} +from evaluator.utils import ( # pylint: disable = wrong-import-position + EntryConfig, + LoggerConfig, + get_logger, + record_result, +) +from oracle_artifact_build import OracleArtifactBuild # pylint: disable = wrong-import-position +from oracle_env_setup import OracleEnvSetup # pylint: disable = wrong-import-position + + +ACTO_CONFIG = EntryConfig( + name = "sosp23-acto", + home_dir = Path.home() / "sosp23_acto", + repository_paths = {"sosp23-acto": (Path.home() / "sosp23_acto" / "acto")}, + results_paths = { + "table5": (Path.home() / "sosp23_acto" / "acto" / "table5.txt"), + "table6": (Path.home() / "sosp23_acto" / "acto" / "table6.txt"), + "table7": (Path.home() / "sosp23_acto" / "acto" / "table7.txt"), + "table8": (Path.home() / "sosp23_acto" / "acto" / "table8.txt"), + }, + ground_truth_paths = { + "table5": ( + Path.home() / "sosp23_acto" / "_agent_eval" / "refs" / "table5.ref.json" + ), + "table6": ( + Path.home() / "sosp23_acto" / "_agent_eval" / "refs" / "table6.ref.json" + ), + "table7": ( + Path.home() / "sosp23_acto" / "_agent_eval" / "refs" / "table7.ref.json" + ), + "table8": ( + Path.home() / "sosp23_acto" / "_agent_eval" / "refs" / "table8.ref.json" + ), + }, + similarity_ratio = 0.75, +) + + +def main(argv: list[str]) -> int: + verbose = "--verbose" in argv + + results: Dict[str, int] = {} score = 0 - for cls in (OracleEnvSetup, OracleArtifactBuild, OracleBenchmarkPrep, OracleExperimentRuns): - checker = cls() - ok = checker.run() - name = cls.__name__ - logger.info(f"{name}: {'PASS' if ok else 'FAIL'}") - if ok: - results[name] = 1 - score += 1 - else: - results[name] = 0 - - logger.info(f"Agent scores: {results}") + + logger_name = os.environ.get("EVAL_LOGGER_NAME", "ACTO-EVAL") + logger = get_logger(LoggerConfig(root_name = logger_name)) + + env_checker = OracleEnvSetup(config = ACTO_CONFIG, logger = logger) + score += record_result( + logger, results, type(env_checker).__name__, env_checker.run(verbose = verbose) + ) + + build_checker = OracleArtifactBuild(config = ACTO_CONFIG, logger = logger) + score += record_result( + logger, results, type(build_checker).__name__, build_checker.run(verbose = verbose) + ) + + logger.info("Agent scores: %s", results) return score if __name__ == "__main__": - main() \ No newline at end of file + raise SystemExit(main(sys.argv[1:])) \ No newline at end of file diff --git a/benchmarks/arteval_bench/data/benchmark/sosp23_acto/_agent_eval/oracle_artifact_build.py b/benchmarks/arteval_bench/data/benchmark/sosp23_acto/_agent_eval/oracle_artifact_build.py index 3cd57bd8..34ec8eff 100644 --- a/benchmarks/arteval_bench/data/benchmark/sosp23_acto/_agent_eval/oracle_artifact_build.py +++ b/benchmarks/arteval_bench/data/benchmark/sosp23_acto/_agent_eval/oracle_artifact_build.py @@ -1,82 +1,128 @@ -import os -import subprocess -from dataclasses import dataclass -from typing import Iterable, List, Optional, Tuple -from pathlib import Path +"""Artifact build oracle. + +This module defines a concrete artifact-build oracle that declares build commands +using the primitives in oracle_artifact_build_primitives.py. + +The oracle is intentionally simple: it declares an ordered list of build command +requirements and relies on the base class to execute them, log results, and +produce a structured report for main.py. + +An EntryConfig instance is expected to be provided by main.py and must include +repository_paths entries for any referenced repo_key values. +""" -from utils import REPO_DIR -from utils import logger +from __future__ import annotations +from collections.abc import Mapping, Sequence +from dataclasses import dataclass, field +from evaluator.oracle_artifact_build_primitives import ( + BuildCommandRequirement, + BuildRequirement, + OracleArtifactBuildBase, +) +from evaluator.utils import EntryConfig +from pathlib import Path +import logging -@dataclass(frozen=True) +@dataclass(frozen = True, slots = True, kw_only = True) class BuildTarget: + """Declarative description of one build command to run. + + Attributes: + name: Display name used in logs and reports. + repo_key: Key into EntryConfig.repository_paths. + command: argv-style command to execute. + cwd_relative: Optional subdirectory within the repo to execute from. + optional: If True, failures are reported as warnings instead of errors. + timeout_seconds: Per-command timeout. + env_overrides: Environment variables to override for the command. + """ + name: str repo_key: str - cmd: List[str] + command: Sequence[str] + cwd_relative: Path | None = None + optional: bool = False + timeout_seconds: float = 60.0 + env_overrides: Mapping[str, str] = field(default_factory = dict) + + def __post_init__(self) -> None: + if not self.name: + raise ValueError("BuildTarget.name must be non-empty") + if not self.repo_key: + raise ValueError(f"{self.name}: repo_key must be non-empty") + if not self.command: + raise ValueError(f"{self.name}: command must be non-empty") + if self.timeout_seconds <= 0: + raise ValueError(f"{self.name}: timeout_seconds must be > 0") + if self.cwd_relative is not None and not isinstance(self.cwd_relative, Path): + object.__setattr__(self, "cwd_relative", Path(self.cwd_relative)) -BUILD_TARGETS: List[BuildTarget] = [ + object.__setattr__(self, "command", tuple(self.command)) + + +DEFAULT_BUILD_TARGETS: tuple[BuildTarget, ...] = ( BuildTarget( - name="acto", - repo_key="acto", - cmd=["make", "lib"], + name = "acto: make lib", + repo_key = "acto", + command = ("make", "lib"), + timeout_seconds = 60.0, ), -] +) -class OracleArtifactBuild: +class OracleArtifactBuild(OracleArtifactBuildBase): + """The artifact build oracle.""" - def __init__(self) -> None: - self.repo_dir = REPO_DIR + _DEFAULT_TARGET_SPECS: tuple[tuple[str, tuple[str, ...], float], ...] = ( + ("acto: make lib", ("make", "lib"), 60.0), + ) - def run_shell_command( + def __init__( self, - cmd: Iterable[str], - cwd: Optional[Path] = None, - ) -> Tuple[int, str, str]: - """ - Run a command and return (rc, stdout, stderr) tuple. - """ - try: - cp = subprocess.run( - cmd, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - text=True, - cwd=str(cwd) if cwd is not None else None, + *, + config: EntryConfig, + logger: logging.Logger, + targets: Sequence[BuildTarget] | None = None, + ) -> None: + super().__init__(logger = logger) + self._config = config + + if targets is None: + targets = self._make_default_targets(config) + + self._targets = tuple(targets) + names = [t.name for t in self._targets] + if len(names) != len(set(names)): + raise ValueError(f"Duplicate build target names: {names!r}") + + def _make_default_targets(self, config: EntryConfig) -> tuple[BuildTarget, ...]: + """Creates the default BuildTarget list for this config.""" + repo_key = config.name + return tuple( + BuildTarget( + name = name, + repo_key = repo_key, + command = command, + timeout_seconds = timeout_seconds, + ) + for (name, command, timeout_seconds) in self._DEFAULT_TARGET_SPECS + ) + + def requirements(self) -> Sequence[BuildRequirement]: + """Returns an ordered list of build requirements to validate.""" + requirements: list[BuildRequirement] = [] + for target in self._targets: + requirements.append( + BuildCommandRequirement( + name = target.name, + optional = target.optional, + cwd = self._config.repository_paths[self._config.name], + command = target.command, + cwd_relative = target.cwd_relative, + timeout_seconds = target.timeout_seconds, + env_overrides = target.env_overrides, + ) ) - return cp.returncode, cp.stdout or "", cp.stderr or "" - except FileNotFoundError: - return 127, "", "" - - def build_target(self, target: BuildTarget) -> Optional[str]: - """ - Build a single target using its configured repository and command. - """ - repo_path = Path(os.path.expanduser(self.repo_dir)) - if not repo_path.exists(): - return f"{target.name} repo directory missing" - - rc, out, err = self.run_shell_command(target.cmd, cwd=repo_path) - if rc != 0: - return f"{target.name} build failed (rc={rc})" - - return None - - def build_check(self): - """ - Run builds for all configured targets and collect failures. - """ - problems: List[str] = [] - for target in BUILD_TARGETS: - msg = self.build_target(target) - if msg: - problems.append(msg) - if problems: - return False, "; ".join(problems) - return True, "" - - def run(self): - ok, why = self.build_check() - logger.info(f"Build: {'PASS' if ok else 'FAIL' + (' - ' + why if why else '')}") - return ok + return requirements diff --git a/benchmarks/arteval_bench/data/benchmark/sosp23_acto/_agent_eval/oracle_env_setup.py b/benchmarks/arteval_bench/data/benchmark/sosp23_acto/_agent_eval/oracle_env_setup.py index 94b56b83..65899e35 100644 --- a/benchmarks/arteval_bench/data/benchmark/sosp23_acto/_agent_eval/oracle_env_setup.py +++ b/benchmarks/arteval_bench/data/benchmark/sosp23_acto/_agent_eval/oracle_env_setup.py @@ -1,208 +1,173 @@ -import os -import re -import shutil -import subprocess -from dataclasses import dataclass -from typing import Iterable, List, Optional, Tuple +#!/usr/bin/env python3 +"""Environment setup oracle for the SOSP'23 ACTO artifact. + +Validates: + - Required dependencies and minimum versions where applicable. + - Repository directory exists. + - Ground-truth reference files exist (required). + - Result files exist (optional; typically generated later). +""" + +from __future__ import annotations +from collections.abc import Mapping, Sequence +from evaluator.oracle_env_setup_primitives import ( + DependencyVersionRequirement, + EnvironmentVariableRequirement, + EnvQuantifier, + FilesystemPathRequirement, + OracleEnvSetupBase, + PathType, + Requirement, + VersionCompare, +) +from evaluator.utils import EntryConfig from pathlib import Path - -from utils import HOME, REPO_DIR -from utils import logger - -VersionTuple = Tuple[int, ...] - - -@dataclass(frozen=True) -class Dependency: - name: str - binary: str - cmd: Optional[list] = None - parse_regex: Optional[str] = None - require: Optional[VersionTuple] = None - compare: Optional[str] = None - - -DEPENDENCIES: List[Dependency] = [ - - # Basic tooling - Dependency( - name="git", binary="git" - ), - - # Docker, latest version is okay - Dependency( - name="docker", binary="docker", - ), - - # Python v3.8+ - Dependency( - name="python3", binary="python3", - cmd=["python3", "--version"], parse_regex=r"Python\s+([0-9.]+)", - require=(3, 8), compare="gte", - ), - - # pip3 for Python 3.8+ - Dependency( - name="pip3", binary="pip3", - ), - - # Go toolchain (golang package), latest STL version - Dependency( - name="go", binary="go", - ), - - # Kind v0.20.0 - Dependency( - name="kind", binary="kind", - cmd=["kind", "version"], parse_regex=r"v([0-9.]+)", - require=(0, 20, 0), compare="gte", - ), - - # Kubectl v1.22.9 - Dependency( - name="kubectl", binary="kubectl", - cmd=["kubectl", "version", "--client", "--short"], - parse_regex=r"Client Version:\s+v?([0-9.]+)", - require=(1, 22, 9), compare="gte", - ), -] - - -class OracleEnvSetup: - - def __init__(self) -> None: - # Root of the cloned repositories - self.expected_root_dir = REPO_DIR - - # Go paths that should be present in PATH - self.go_root = HOME / "go" - self.go_bin = self.go_root / "bin" - - # Python virtual environment inside the repo - self.venv_dir = HOME / ".venv" - - def run_shell_command(self, cmd: Iterable[str]) -> Tuple[int, str, str]: - """ - Run a command and return (rc, stdout, stderr) tuple. - """ - try: - cp = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True) - return cp.returncode, cp.stdout or "", cp.stderr or "" - except FileNotFoundError: - return 127, "", "" - - def parse_version_tuple(self, text: str) -> VersionTuple: - """ - Extract the first version-like token from arbitrary text. - For example, for Java: '1.8.0_422' -> (1, 8, 0) - """ - m = re.search(r"(\d+(?:\.\d+){0,3})", text) - return tuple(int(x) for x in m.group(1).split(".")) if m else () - - def extract_version(self, text: str, pattern: str) -> Tuple[VersionTuple, str]: - """ - Apply regex pattern on a version string. - """ - m = re.search(pattern, text, re.I) - if not m: - return (), "unknown" - ver_str = m.group(1) - return self.parse_version_tuple(ver_str), ver_str - - def cmp_versions(self, found: VersionTuple, required: VersionTuple, mode: str) -> bool: - """ - Compare versions either to be greater or equal to the reference. - """ - if not found: - return False - f, r = list(found), list(required) - while len(f) < len(r): - f.append(0) - while len(r) < len(f): - r.append(0) - return (f == r) if mode == "eq" else (f >= r) - - def paths_check(self): - """ - Check that Python virtual environment is succesfully created - and that Go-related paths are set properly. - """ - problems: List[str] = [] - - # Check repositories exist - if not Path(self.expected_root_dir).exists(): - problems.append(f"{dir} directory not found repository not cloned properly") - - # Check Python virtual environment is created - if not Path(self.venv_dir).exists(): - problems.append(".venv virtual environment missing (run 'python3 -m venv .venv')") - - # Check Go directories exit - if not Path(self.go_root).exists(): - problems.append("$HOME/go directory missing (install golang and configure GOPATH)") - if not Path(self.go_bin).exists(): - problems.append("$HOME/go/bin directory missing (ensure Go tools are installed)") - - # Check PATH contains Go path - path_env = os.environ.get("PATH", "") - go_root_str = str(self.go_root) - go_bin_str = str(self.go_bin) - if go_root_str not in path_env or go_bin_str not in path_env: - problems.append("PATH missing $HOME/go or $HOME/go/bin " - "(export PATH=$HOME/go:$HOME/go/bin:$PATH)") - - if problems: - return False, "; ".join(problems) - return True, "" - - def check_dependency(self, dep: Dependency) -> Optional[str]: - """ - Core method that checks whether a certain dependency of a version - equal or greather than a reference version is installed. - """ - if shutil.which(dep.binary) is None: - return f"{dep.name} missing" - - # If no version information is required, presence is enough - if dep.cmd is None and dep.parse_regex is None and dep.require is None: - return None - - rc, out, err = self.run_shell_command(dep.cmd or []) - text = (out + "\n" + err).strip() - - if dep.parse_regex and dep.require and dep.compare: - ver_tuple, ver_str = self.extract_version(text, dep.parse_regex) - if not ver_tuple: - return f"{dep.name} version unreadable" - ok = self.cmp_versions(ver_tuple, dep.require, dep.compare) - cmp_word = "==" if dep.compare == "eq" else ">=" - want = ".".join(map(str, dep.require)) - return None if ok else f"{dep.name} {cmp_word} {want} not met (got {ver_str})" - - return f"{dep.name} check misconfigured" - - def prereqs_check(self): - problems: List[str] = [] - for dep in DEPENDENCIES: - msg = self.check_dependency(dep) - if msg: - problems.append(msg) - if problems: - return False, "; ".join(problems) - return True, "" - - def run(self): - results = [] - - ok, why = self.prereqs_check() - logger.info(f"Prerequisites: {'PASS' if ok else 'FAIL' + (' - ' + why if why else '')}") - results.append(ok) - - ok, why = self.paths_check() - logger.info(f"Paths: {'PASS' if ok else 'FAIL' + (' - ' + why if why else '')}") - results.append(ok) - - if all(results): - return True - - return False \ No newline at end of file +import logging + + +def _required_path(paths: Mapping[str, Path], key: str, *, label: str) -> Path: + """Returns a required path from a mapping with a clear error message. + + Args: + paths: Mapping containing paths. + key: Required key. + label: Label used in error messages. + + Returns: + The path from the mapping. + + Raises: + ValueError: If the key is missing. + """ + try: + return paths[key] + except KeyError as exc: + raise ValueError(f"Missing {label}[{key!r}] in EntryConfig") from exc + + +class OracleEnvSetup(OracleEnvSetupBase): + """Validates environment prerequisites for the ACTO _agent_eval bundle.""" + + def __init__(self, *, config: EntryConfig, logger: logging.Logger) -> None: + super().__init__(logger) + self._config = config + + def requirements(self) -> Sequence[Requirement]: + """Returns an ordered list of requirements to validate.""" + repo_root = _required_path( + self._config.repository_paths, + self._config.name, + label = "repository_paths", + ) + + if not self._config.ground_truth_paths: + raise ValueError("EntryConfig.ground_truth_paths must be non-empty") + + home_dir = self._config.home_dir + venv_dir = home_dir / ".venv" + go_root = home_dir / "go" + go_bin = go_root / "bin" + + reqs: list[Requirement] = [ + # Docker 23.0.0+ + DependencyVersionRequirement( + name = "docker", + command = ("docker", "--version"), + required_version = (23, 0, 0), + compare = VersionCompare.GEQ, + ), + # pip 23.0.1+ + DependencyVersionRequirement( + name = "pip3", + command = ("pip3", "--version"), + required_version = (23, 0, 1), + compare = VersionCompare.GEQ, + ), + # Python 3.8+ + DependencyVersionRequirement( + name = "python3", + command = ("python3", "--version"), + required_version = (3, 8, 0), + compare = VersionCompare.GEQ, + version_regex = r"Python\s+([0-9.]+)", + ), + # Go 1.20+ + DependencyVersionRequirement( + name = "go", + command = ("go", "version"), + required_version = (1, 20, 0), + compare = VersionCompare.GEQ, + version_regex = r"go(\d+\.\d+(?:\.\d+)?)", + ), + # kind 0.20.0+ + DependencyVersionRequirement( + name = "kind", + command = ("kind", "version"), + required_version = (0, 20, 0), + compare = VersionCompare.GEQ, + version_regex = r"v([0-9.]+)", + ), + # kubectl 1.22.9+ + DependencyVersionRequirement( + name = "kubectl", + command = ("kubectl", "version", "--client", "--short"), + required_version = (1, 22, 9), + compare = VersionCompare.GEQ, + version_regex = r"Client Version:\s+v?([0-9.]+)", + ), + # Directory checks + FilesystemPathRequirement( + name = "repo_root_exists", + path = repo_root, + path_type = PathType.DIRECTORY, + ), + FilesystemPathRequirement( + name = "venv_exists", + path = venv_dir, + path_type = PathType.DIRECTORY, + ), + FilesystemPathRequirement( + name = "go_root_exists", + path = go_root, + path_type = PathType.DIRECTORY, + ), + FilesystemPathRequirement( + name = "go_bin_exists", + path = go_bin, + path_type = PathType.DIRECTORY, + ), + # PATH checks for Go + EnvironmentVariableRequirement( + name = "PATH_contains_go_root", + env_var = "PATH", + expected = str(go_root), + quantifier = EnvQuantifier.CONTAINS, + ), + EnvironmentVariableRequirement( + name = "PATH_contains_go_bin", + env_var = "PATH", + expected = str(go_bin), + quantifier = EnvQuantifier.CONTAINS, + ), + ] + + for key, path in sorted(self._config.ground_truth_paths.items()): + reqs.append( + FilesystemPathRequirement( + name = f"ground_truth[{key}]", + path = path, + path_type = PathType.FILE, + ) + ) + + for key, path in sorted(self._config.results_paths.items()): + reqs.append( + FilesystemPathRequirement( + name = f"results[{key}]", + optional = True, + path = path, + path_type = PathType.FILE, + ) + ) + + return tuple(reqs) diff --git a/benchmarks/arteval_bench/data/benchmark/sosp23_acto/_agent_eval/oracle_experiment_runs.py b/benchmarks/arteval_bench/data/benchmark/sosp23_acto/_agent_eval/oracle_experiment_runs.py index 1cc8355c..c6932de3 100644 --- a/benchmarks/arteval_bench/data/benchmark/sosp23_acto/_agent_eval/oracle_experiment_runs.py +++ b/benchmarks/arteval_bench/data/benchmark/sosp23_acto/_agent_eval/oracle_experiment_runs.py @@ -1,740 +1,615 @@ -import json -from dataclasses import dataclass -from pathlib import Path -from typing import Tuple - -from utils import RESULTS_PATH_TABLES, REFERENCE_PATH_TABLES, SIMILARITY_RATIO, logger - - -@dataclass(frozen=True) -class Table5Row: - operator: str - undesired_state: int - system_error: int - operator_error: int - recovery_failure: int - total: int - - -@dataclass(frozen=True) -class Table6Row: - symptom: str - bugs: int - - -@dataclass(frozen=True) -class Table7Row: - test_oracle: str - bugs: int - - -@dataclass(frozen=True) -class Table8Row: - operator: str - operations: int - - -class OracleExperimentRuns: - - def __init__(self) -> None: - # File paths for each table - self.table5_results_path = Path(RESULTS_PATH_TABLES["table5"]) - self.table5_reference_path = Path(REFERENCE_PATH_TABLES["table5"]) - self.table6_results_path = Path(RESULTS_PATH_TABLES["table6"]) - self.table6_reference_path = Path(REFERENCE_PATH_TABLES["table6"]) - self.table7_results_path = Path(RESULTS_PATH_TABLES["table7"]) - self.table7_reference_path = Path(REFERENCE_PATH_TABLES["table7"]) - self.table8_results_path = Path(RESULTS_PATH_TABLES["table8"]) - self.table8_reference_path = Path(REFERENCE_PATH_TABLES["table8"]) - - # Parsed rows for tables - self.table5_rows: list[Table5Row] = [] - self.table6_rows: list[Table6Row] = [] - self.table7_rows: list[Table7Row] = [] - self.table8_rows: list[Table8Row] = [] - - # Totals - self.table5_exp_total: int | None = None - self.table5_ref_total: int | None = None - self.table6_exp_total: int | None = None - self.table6_ref_total: int | None = None - self.table7_exp_total: int | None = None - self.table7_ref_total: int | None = None - self.table8_exp_total: int | None = None - self.table8_ref_total: int | None = None - - # Raw non-empty lines from result files - self._table5_raw_lines: list[str] = [] - self._table6_raw_lines: list[str] = [] - self._table7_raw_lines: list[str] = [] - self._table8_raw_lines: list[str] = [] - - def is_separator_line(self, line: str) -> bool: - """ - Return True if this is a header separator line (Markdown spaces and dashes). - """ - stripped = line.strip() - if not stripped: - return False - return all(ch in "- " for ch in stripped) - - def parse_int(self, text: str) -> Tuple[bool, int]: - """ - Parse a numeric string into an int. - """ - try: - return True, int(text.replace(",", "")) - except ValueError: - return False, 0 - - # ----------------------- - # TABLE-5 helpers - # ----------------------- - - def load_table5(self) -> Tuple[bool, str]: - """ - Load raw TABLE-5 file into memory. - """ - if not self.table5_reference_path.exists(): - return False, f"{self.table5_reference_path} (TABLE-5 reference) not found" - - if not self.table5_results_path.exists(): - return False, f"{self.table5_results_path} (TABLE-5 results) not found" - - text = self.table5_results_path.read_text(encoding="utf-8") - lines = [line.rstrip("\n") for line in text.splitlines() if line.strip()] - if not lines: - return False, f"{self.table5_results_path} is empty" - - self._table5_raw_lines = lines - return True, "" - - def parse_table5(self) -> Tuple[bool, str]: - """ - Parse TABLE-5 and extract the bottom-right 'Total' cell. - """ - EXPECTED_HEADERS: list[str] = [ - "Operator", - "Undesired State", - "System Error", - "Operator Error", - "Recovery Failure", - "Total", - ] - - header_line: str | None = None - data_lines: list[str] = [] - saw_separator = False - - for line in self._table5_raw_lines: - if header_line is None: - header_line = line - continue - - if not saw_separator and self.is_separator_line(line): - saw_separator = True - continue - - if saw_separator: - data_lines.append(line) - - if header_line is None: - return False, "TABLE-5: no table header found" - - if any(h not in header_line for h in EXPECTED_HEADERS): - return False, f"TABLE-5: unexpected headers: {header_line!r}" - - self.table5_rows = [] - self.table5_exp_total = None - - for line in data_lines: - parts = line.split() - if len(parts) != 6: - return False, f"TABLE-5: row has {len(parts)} fields, expected 6: {line!r}" - - operator = parts[0] - - ok, undesired = self.parse_int(parts[1]) - if not ok: - return False, f"TABLE-5: unparseable int in 'Undesired State': {parts[1]!r}" - - ok, system_err = self.parse_int(parts[2]) - if not ok: - return False, f"TABLE-5: unparseable int in 'System Error': {parts[2]!r}" - - ok, operator_err = self.parse_int(parts[3]) - if not ok: - return False, f"TABLE-5: unparseable int in 'Operator Error': {parts[3]!r}" - - ok, recovery_fail = self.parse_int(parts[4]) - if not ok: - return False, f"TABLE-5: unparseable int in 'Recovery Failure': {parts[4]!r}" - - ok, total = self.parse_int(parts[5]) - if not ok: - return False, f"TABLE-5: unparseable int in 'Total': {parts[5]!r}" - - row = Table5Row( - operator=operator, - undesired_state=undesired, - system_error=system_err, - operator_error=operator_err, - recovery_failure=recovery_fail, - total=total, - ) - self.table5_rows.append(row) +"""Experiment runs oracle for TABLE-5..TABLE-8 (elementwise sequence checks). - if operator == "Total": - self.table5_exp_total = total +This oracle validates that the tables produced by running the full set of +experiments match reference JSON values elementwise (aligned by label) +within a configurable similarity threshold. +""" - if self.table5_exp_total is None: - return False, "TABLE-5: missing 'Total' row" +from __future__ import annotations - return True, "" +import abc +import dataclasses +import json +import math +from collections.abc import Callable, Sequence +from pathlib import Path - def load_table5_reference(self) -> Tuple[bool, str]: - """ - Load TABLE-5 reference JSON. - """ - path = self.table5_reference_path +from utils import RESULTS_PATH_TABLES, REFERENCE_PATH_TABLES, SIMILARITY_RATIO, logger - if not path.exists(): - return False, f"{path} not found" +# Update this import path to wherever you placed the base primitives module. +from evaluator.oracles.experiment_runs_oracle_primitives import ( # pylint: disable=g-import-not-at-top + ElementwiseMetrics, + ExperimentRunRequirement, + ExperimentRunsContext, + OracleExperimentRunsBase, +) - try: - raw = json.loads(path.read_text(encoding="utf-8")) - except json.JSONDecodeError as e: - return False, f"{path} invalid JSON: {e}" - if not isinstance(raw, dict): - return False, f"{path} must contain a JSON object" +# --------------------------------------------------------------------------- +# Helper functions +# --------------------------------------------------------------------------- - totals = raw.get("totals") - if not isinstance(totals, dict): - return False, f"{path} missing 'totals' object" - if "total_all" not in totals: - return False, f"{path} missing 'total_all' field in 'totals'" +_LabeledIntPairs = list[tuple[str, int]] - try: - self.table5_ref_total = int(totals["total_all"]) - except (TypeError, ValueError): - return False, f"{path} field 'totals.total_all' must be an integer" - return True, "" +def _normalize_label(label: str) -> str: + """Canonicalizes labels for alignment (keeps case; collapses whitespace).""" + return " ".join(label.split()).strip() - # ----------------------- - # TABLE-6 helpers - # ----------------------- - def load_table6(self) -> Tuple[bool, str]: - """ - Load raw TABLE-6 file into memory. - """ - if not self.table6_reference_path.exists(): - return False, f"{self.table6_reference_path} (TABLE-6 reference) not found" +def _is_separator_line(line: str) -> bool: + """Returns True if this is a header separator line (Markdown spaces/dashes).""" + stripped = line.strip() + if not stripped: + return False + return all(ch in "- " for ch in stripped) - if not self.table6_results_path.exists(): - return False, f"{self.table6_results_path} (TABLE-6 results) not found" - text = self.table6_results_path.read_text(encoding="utf-8") - lines = [line.rstrip("\n") for line in text.splitlines() if line.strip()] - if not lines: - return False, f"{self.table6_results_path} is empty" +def _read_nonempty_lines(path: Path) -> list[str]: + """Reads file and returns non-empty lines with trailing newlines stripped.""" + text = path.read_text(encoding = "utf-8") + return [line.rstrip("\n") for line in text.splitlines() if line.strip()] - self._table6_raw_lines = lines - return True, "" - def parse_table6(self) -> Tuple[bool, str]: - """ - Parse TABLE-6 and compute the sum of all '# Bugs' values. - """ - EXPECTED_HEADERS: list[str] = [ - "Consequence", - "# Bugs", - ] +def _parse_int_token(text: str, *, label: str) -> int: + """Parses an integer allowing commas.""" + try: + return int(text.replace(",", "")) + except ValueError as exc: + raise ValueError(f"{label}: unparseable int: {text!r}") from exc - header_line: str | None = None - data_lines: list[str] = [] - saw_separator = False - for line in self._table6_raw_lines: - if header_line is None: - header_line = line - continue +def _extract_table_body(lines: Sequence[str], *, label: str) -> tuple[str, list[str]]: + """Splits raw lines into (header_line, data_lines) after separator.""" + if not lines: + raise ValueError(f"{label}: table is empty") - if not saw_separator and self.is_separator_line(line): - saw_separator = True - continue + header_line = lines[0] + saw_separator = False + data_lines: list[str] = [] - if saw_separator: - data_lines.append(line) + for line in lines[1:]: + if not saw_separator and _is_separator_line(line): + saw_separator = True + continue + if saw_separator: + data_lines.append(line) - if header_line is None: - return False, "TABLE-6: no table header found" + if not saw_separator: + raise ValueError(f"{label}: missing header separator line") - if any(h not in header_line for h in EXPECTED_HEADERS): - return False, f"TABLE-6: unexpected headers: {header_line!r}" + return header_line, data_lines - self.table6_rows = [] - self.table6_exp_total = 0 - for line in data_lines: - parts = line.split() - if len(parts) < 2: - return False, f"TABLE-6: row has {len(parts)} fields, expected at least 2: {line!r}" +def _load_json_object(path: Path, *, label: str) -> dict: + """Loads a JSON file expected to contain an object.""" + try: + raw = json.loads(path.read_text(encoding = "utf-8")) + except json.JSONDecodeError as exc: + raise ValueError(f"{label}: invalid JSON: {exc}") from exc - label = " ".join(parts[:-1]) - bugs_str = parts[-1] + if not isinstance(raw, dict): + raise ValueError(f"{label}: must contain a JSON object") + return raw - ok, bugs = self.parse_int(bugs_str) - if not ok: - return False, f"TABLE-6: unparseable int in '# Bugs': {bugs_str!r}" - row = Table6Row( - symptom=label, - bugs=bugs, - ) - self.table6_rows.append(row) - self.table6_exp_total += bugs +def _get_first_str_field( + obj: dict, + *, + keys: Sequence[str], + label: str, +) -> str: + """Returns the first present string field among keys.""" + for k in keys: + if k in obj: + v = obj[k] + if isinstance(v, str) and v.strip(): + return v + raise ValueError(f"{label}: field {k!r} must be a non-empty string") + raise ValueError(f"{label}: missing any of fields {list(keys)!r}") - return True, "" - def load_table6_reference(self) -> Tuple[bool, str]: - """ - Load TABLE-6 reference JSON. - """ - path = self.table6_reference_path - - if not path.exists(): - return False, f"{path} not found" - - try: - raw = json.loads(path.read_text(encoding="utf-8")) - except json.JSONDecodeError as e: - return False, f"{path} invalid JSON: {e}" - - if not isinstance(raw, dict): - return False, f"{path} must contain a JSON object" - - symptoms = raw.get("symptoms") - if not isinstance(symptoms, list): - return False, f"{path} missing 'symptoms' list" - - total = 0 - for idx, obj in enumerate(symptoms): - if not isinstance(obj, dict): - return False, f"{path} entry #{idx} in 'symptoms' is not an object" - if "bugs" not in obj: - return False, f"{path} entry #{idx} in 'symptoms' missing 'bugs' tag" +def _get_first_int_field( + obj: dict, + *, + keys: Sequence[str], + label: str, +) -> int: + """Returns the first present int-ish field among keys.""" + for k in keys: + if k in obj: + v = obj[k] try: - total += int(obj["bugs"]) - except (TypeError, ValueError): - return False, f"{path} entry #{idx} in 'symptoms' has non-integer 'bugs' tag" - - self.table6_ref_total = total - return True, "" - - # ----------------------- - # TABLE-7 helpers - # ----------------------- - - def load_table7(self) -> Tuple[bool, str]: - """ - Load raw TABLE-7 file into memory. - """ - if not self.table7_reference_path.exists(): - return False, f"{self.table7_reference_path} (TABLE-7 reference) not found" - - if not self.table7_results_path.exists(): - return False, f"{self.table7_results_path} (TABLE-7 results) not found" + return int(v) + except (TypeError, ValueError) as exc: + raise ValueError(f"{label}: field {k!r} must be an integer") from exc + raise ValueError(f"{label}: missing any of fields {list(keys)!r}") + + +def _pairs_to_unique_map( + pairs: _LabeledIntPairs, + *, + label: str, +) -> tuple[dict[str, int], list[str]]: + """Builds a label->value map and preserves the (normalized) order from pairs.""" + mapping: dict[str, int] = {} + order: list[str] = [] + for raw_label, value in pairs: + key = _normalize_label(raw_label) + if key in mapping: + raise ValueError(f"{label}: duplicate label: {raw_label!r}") + mapping[key] = value + order.append(key) + return mapping, order + + +def _summarize_labels(items: Sequence[str], *, max_items: int = 10) -> str: + shown = list(items[:max_items]) + suffix = f", ... ({len(items) - len(shown)} more)" if len(items) > len(shown) else "" + return ", ".join(shown) + suffix + + +def _align_by_reference( + observed: _LabeledIntPairs, + reference: _LabeledIntPairs, + *, + label: str, +) -> tuple[list[str], list[float], list[float]]: + """Aligns observed/reference by label using reference order.""" + obs_map, _ = _pairs_to_unique_map(observed, label = f"{label}: observed") + ref_map, ref_order = _pairs_to_unique_map(reference, label = f"{label}: reference") + + obs_keys = set(obs_map.keys()) + ref_keys = set(ref_map.keys()) + + missing = sorted(ref_keys - obs_keys) + extra = sorted(obs_keys - ref_keys) + + if missing: + raise ValueError( + f"{label}: missing labels in observed: {_summarize_labels(missing)}" + ) + if extra: + raise ValueError( + f"{label}: extra labels in observed: {_summarize_labels(extra)}" + ) - text = self.table7_results_path.read_text(encoding="utf-8") - lines = [line.rstrip("\n") for line in text.splitlines() if line.strip()] - if not lines: - return False, f"{self.table7_results_path} is empty" - - self._table7_raw_lines = lines - return True, "" - - def parse_table7(self) -> Tuple[bool, str]: - """ - Parse TABLE-7 and compute the sum of integer bug counts (ignoring percentages). - """ - EXPECTED_HEADERS: list[str] = [ - "Test Oracle", - "# Bugs (Percentage)", - ] - - header_line: str | None = None - data_lines: list[str] = [] - saw_separator = False - - for line in self._table7_raw_lines: - if header_line is None: - header_line = line - continue - - if not saw_separator and self.is_separator_line(line): - saw_separator = True - continue - - if saw_separator: - data_lines.append(line) - - if header_line is None: - return False, "TABLE-7: no table header found" - - if any(h not in header_line for h in EXPECTED_HEADERS): - return False, f"TABLE-7: unexpected headers: {header_line!r}" - - self.table7_rows = [] - self.table7_exp_total = 0 - - for line in data_lines: - parts = line.split() - if len(parts) < 2: - return False, f"TABLE-7: row has {len(parts)} fields, expected at least 2: {line!r}" - - # Ignore last token that is in "(xx.xx%)" format, the integer is second last. - last = parts[-1] - if last.startswith("(") and last.endswith("%)"): - if len(parts) < 3: - return False, f"TABLE-7: malformed row with percentage but no integer: {line!r}" - bugs_str = parts[-2] - label = " ".join(parts[:-2]) - else: - bugs_str = parts[-1] - label = " ".join(parts[:-1]) - - ok, bugs = self.parse_int(bugs_str) - if not ok: - return False, f"TABLE-7: unparseable int in '# Bugs': {bugs_str!r}" - - row = Table7Row( - test_oracle=label, - bugs=bugs, + observed_values: list[float] = [] + reference_values: list[float] = [] + for key in ref_order: + observed_values.append(float(obs_map[key])) + reference_values.append(float(ref_map[key])) + + return ref_order, observed_values, reference_values + + +def _collect_threshold_mismatches( + labels: Sequence[str], + observed_values: Sequence[float], + reference_values: Sequence[float], + *, + threshold: float, + max_items: int, +) -> tuple[int, list[str]]: + """Returns (num_failed, mismatch_lines) with score diagnostics.""" + scores = ElementwiseMetrics.similarity_scores(observed_values, reference_values) + failures = 0 + lines: list[str] = [] + for i, s in enumerate(scores): + if s.cmp < threshold: + failures += 1 + if len(lines) < max_items: + lines.append( + f"[{labels[i]}] score={s.cmp:.6f} observed={s.observed!r} reference={s.reference!r}" + ) + return failures, lines + + +# --------------------------------------------------------------------------- +# Experiment results parsing (Markdown-style) +# --------------------------------------------------------------------------- + +_TABLE5_FIELDS: tuple[str, ...] = ( + "undesired_state", + "system_error", + "operator_error", + "recovery_failure", +) + + +def _parse_table5_results_pairs(lines: Sequence[str], *, field: str) -> _LabeledIntPairs: + """Parses TABLE-5 results and returns (operator, value) for one field.""" + if field not in _TABLE5_FIELDS: + raise ValueError(f"TABLE-5: unsupported field: {field!r}") + + expected_headers = [ + "Operator", + "Undesired State", + "System Error", + "Operator Error", + "Recovery Failure", + "Total", + ] + header_line, data_lines = _extract_table_body(lines, label = "TABLE-5") + if any(h not in header_line for h in expected_headers): + raise ValueError(f"TABLE-5: unexpected headers: {header_line!r}") + + out: _LabeledIntPairs = [] + for line in data_lines: + parts = line.split() + if len(parts) != 6: + raise ValueError( + f"TABLE-5: row has {len(parts)} fields, expected 6: {line!r}" ) - self.table7_rows.append(row) - self.table7_exp_total += bugs - - return True, "" - def load_table7_reference(self) -> Tuple[bool, str]: - """ - Load TABLE-7 reference JSON. - """ - path = self.table7_reference_path - - if not path.exists(): - return False, f"{path} not found" + operator = parts[0] + if operator == "Total": + # Ignore aggregate totals row. + continue + + undesired = _parse_int_token(parts[1], label = "TABLE-5 undesired_state") + system_err = _parse_int_token(parts[2], label = "TABLE-5 system_error") + operator_err = _parse_int_token(parts[3], label = "TABLE-5 operator_error") + recovery_fail = _parse_int_token(parts[4], label = "TABLE-5 recovery_failure") + # parts[5] is the per-operator "Total" column; ignore by design. + + if field == "undesired_state": + out.append((operator, undesired)) + elif field == "system_error": + out.append((operator, system_err)) + elif field == "operator_error": + out.append((operator, operator_err)) + else: + out.append((operator, recovery_fail)) + + return out + + +def _parse_table6_results_pairs(lines: Sequence[str]) -> _LabeledIntPairs: + """Parses TABLE-6 results and returns (label, bugs) pairs.""" + expected_headers = ["Consequence", "# Bugs"] + header_line, data_lines = _extract_table_body(lines, label = "TABLE-6") + if any(h not in header_line for h in expected_headers): + raise ValueError(f"TABLE-6: unexpected headers: {header_line!r}") + + out: _LabeledIntPairs = [] + for line in data_lines: + parts = line.split() + if len(parts) < 2: + raise ValueError( + f"TABLE-6: row has {len(parts)} fields, expected at least 2: {line!r}" + ) + label = " ".join(parts[:-1]) + bugs = _parse_int_token(parts[-1], label = "TABLE-6 bugs") + out.append((label, bugs)) + return out + + +def _parse_table7_results_pairs(lines: Sequence[str]) -> _LabeledIntPairs: + """Parses TABLE-7 results and returns (label, bugs) pairs (ignores % token).""" + expected_headers = ["Test Oracle", "# Bugs (Percentage)"] + header_line, data_lines = _extract_table_body(lines, label = "TABLE-7") + if any(h not in header_line for h in expected_headers): + raise ValueError(f"TABLE-7: unexpected headers: {header_line!r}") + + out: _LabeledIntPairs = [] + for line in data_lines: + parts = line.split() + if len(parts) < 2: + raise ValueError( + f"TABLE-7: row has {len(parts)} fields, expected at least 2: {line!r}" + ) - try: - raw = json.loads(path.read_text(encoding="utf-8")) - except json.JSONDecodeError as e: - return False, f"{path} invalid JSON: {e}" - - if not isinstance(raw, dict): - return False, f"{path} must contain a JSON object" - - test_oracles = raw.get("test_oracles") - if not isinstance(test_oracles, list): - return False, f"{path} missing 'test_oracles' list" - - total = 0 - for idx, obj in enumerate(test_oracles): - if not isinstance(obj, dict): - return False, f"{path} entry #{idx} in 'test_oracles' is not an object" - if "bugs" not in obj: - return False, f"{path} entry #{idx} in 'test_oracles' missing 'bugs' tag" - try: - total += int(obj["bugs"]) - except (TypeError, ValueError): - return False, f"{path} entry #{idx} in 'test_oracles' has non-integer 'bugs' tag" + last = parts[-1] + if last.startswith("(") and last.endswith("%)"): + if len(parts) < 3: + raise ValueError( + f"TABLE-7: malformed row with percentage but no integer: {line!r}" + ) + bugs_str = parts[-2] + label = " ".join(parts[:-2]) + else: + bugs_str = parts[-1] + label = " ".join(parts[:-1]) - self.table7_ref_total = total - return True, "" + bugs = _parse_int_token(bugs_str, label = "TABLE-7 bugs") + out.append((label, bugs)) - # ----------------------- - # TABLE-8 helpers - # ----------------------- + return out - def load_table8(self) -> Tuple[bool, str]: - """ - Load raw TABLE-8 file into memory. - """ - if not self.table8_reference_path.exists(): - return False, f"{self.table8_reference_path} (TABLE-8 reference) not found" - if not self.table8_results_path.exists(): - return False, f"{self.table8_results_path} (TABLE-8 results) not found" +def _parse_table8_results_pairs(lines: Sequence[str]) -> _LabeledIntPairs: + """Parses TABLE-8 results and returns (operator, operations) pairs.""" + expected_headers = ["Operator", "# Operations"] + header_line, data_lines = _extract_table_body(lines, label = "TABLE-8") + if any(h not in header_line for h in expected_headers): + raise ValueError(f"TABLE-8: unexpected headers: {header_line!r}") - text = self.table8_results_path.read_text(encoding="utf-8") - lines = [line.rstrip("\n") for line in text.splitlines() if line.strip()] - if not lines: - return False, f"{self.table8_results_path} is empty" + out: _LabeledIntPairs = [] + for line in data_lines: + parts = line.split() + if len(parts) != 2: + raise ValueError( + f"TABLE-8: row has {len(parts)} fields, expected 2: {line!r}" + ) + operator = parts[0] + ops = _parse_int_token(parts[1], label = "TABLE-8 operations") + out.append((operator, ops)) - self._table8_raw_lines = lines - return True, "" + return out - def parse_table8(self) -> Tuple[bool, str]: - """ - Parse TABLE-8 and compute the sum of '# Operations'. - """ - EXPECTED_HEADERS: list[str] = [ - "Operator", - "# Operations", - ] - header_line: str | None = None - data_lines: list[str] = [] - saw_separator = False +# --------------------------------------------------------------------------- +# Reference results parsing (JSON) +# --------------------------------------------------------------------------- - for line in self._table8_raw_lines: - if header_line is None: - header_line = line - continue +def _parse_table5_reference_pairs(path: Path, *, field: str) -> _LabeledIntPairs: + """Parses TABLE-5 reference JSON and returns (operator, value) for one field.""" + if field not in _TABLE5_FIELDS: + raise ValueError(f"TABLE-5 reference: unsupported field: {field!r}") - if not saw_separator and self.is_separator_line(line): - saw_separator = True - continue + raw = _load_json_object(path, label = "TABLE-5 reference") + ops = raw.get("operators") + if not isinstance(ops, list): + raise ValueError("TABLE-5 reference: missing 'operators' list") - if saw_separator: - data_lines.append(line) + out: _LabeledIntPairs = [] + for idx, obj in enumerate(ops): + if not isinstance(obj, dict): + raise ValueError(f"TABLE-5 reference: entry #{idx} is not an object") - if header_line is None: - return False, "TABLE-8: no table header found" + operator = _get_first_str_field( + obj, + keys = ("operator", "label"), + label = f"TABLE-5 reference: entry #{idx}", + ) + value = _get_first_int_field( + obj, + keys = (field, "value"), + label = f"TABLE-5 reference: entry #{idx}", + ) + # Ignore obj["total"] by design. + out.append((operator, value)) - if any(h not in header_line for h in EXPECTED_HEADERS): - return False, f"TABLE-8: unexpected headers: {header_line!r}" + return out - self.table8_rows = [] - self.table8_exp_total = 0 - for line in data_lines: - parts = line.split() - if len(parts) != 2: - return False, f"TABLE-8: row has {len(parts)} fields, expected 2: {line!r}" +def _parse_table6_reference_pairs(path: Path) -> _LabeledIntPairs: + raw = _load_json_object(path, label = "TABLE-6 reference") + items = raw.get("symptoms") + if not isinstance(items, list): + raise ValueError("TABLE-6 reference: missing 'symptoms' list") - operator = parts[0] - ops_str = parts[1] + out: _LabeledIntPairs = [] + for idx, obj in enumerate(items): + if not isinstance(obj, dict): + raise ValueError(f"TABLE-6 reference: entry #{idx} is not an object") - ok, ops = self.parse_int(ops_str) - if not ok: - return False, f"TABLE-8: unparseable int in '# Operations': {ops_str!r}" + label = _get_first_str_field( + obj, + keys = ("symptom", "consequence", "label"), + label = f"TABLE-6 reference: entry #{idx}", + ) + value = _get_first_int_field( + obj, + keys = ("bugs", "value"), + label = f"TABLE-6 reference: entry #{idx}", + ) + out.append((label, value)) - row = Table8Row( - operator=operator, - operations=ops, - ) - self.table8_rows.append(row) - self.table8_exp_total += ops + return out - return True, "" - def load_table8_reference(self) -> Tuple[bool, str]: - """ - Load TABLE-8 reference JSON. - """ - path = self.table8_reference_path +def _parse_table7_reference_pairs(path: Path) -> _LabeledIntPairs: + raw = _load_json_object(path, label = "TABLE-7 reference") + items = raw.get("test_oracles") + if not isinstance(items, list): + raise ValueError("TABLE-7 reference: missing 'test_oracles' list") - if not path.exists(): - return False, f"{path} not found" + out: _LabeledIntPairs = [] + for idx, obj in enumerate(items): + if not isinstance(obj, dict): + raise ValueError(f"TABLE-7 reference: entry #{idx} is not an object") - try: - raw = json.loads(path.read_text(encoding="utf-8")) - except json.JSONDecodeError as e: - return False, f"{path} invalid JSON: {e}" - - if not isinstance(raw, dict): - return False, f"{path} must contain a JSON object" - - operators = raw.get("operators") - if not isinstance(operators, list): - return False, f"{path} missing 'operators' list" - - total = 0 - for idx, obj in enumerate(operators): - if not isinstance(obj, dict): - return False, f"{path} entry #{idx} in 'operators' is not an object" - if "operations" not in obj: - return False, f"{path} entry #{idx} in 'operators' missing 'operations' tag" - try: - total += int(obj["operations"]) - except (TypeError, ValueError): - return False, f"{path} entry #{idx} in 'operators' has non-integer 'operations'" - - self.table8_ref_total = total - return True, "" - - # ----------------------- - # Shared comparison logic - # ----------------------- - - def totals_within_tolerance(self, found: int, ref: int) -> bool: - """ - Check whether two values are within a given tolerance. - """ - if ref == 0: - return False - return abs(found - ref) <= (1.0 - SIMILARITY_RATIO) * max(abs(found), abs(ref)) - - def compare_table5_against_reference(self) -> Tuple[bool, str]: - """ - Compare TABLE-5 total with the reference total. - """ - if self.table5_exp_total is None: - return False, "TABLE-5: bottom-right total not parsed" - - if self.table5_ref_total is None: - return False, "TABLE-5: reference total_all not loaded" - - found = self.table5_exp_total - ref = self.table5_ref_total - - if not self.totals_within_tolerance(found, ref): - return False, ( - "TABLE-5: bottom-right total differs too much " - f"(got {found}, ref {ref})" - ) + label = _get_first_str_field( + obj, + keys = ("test_oracle", "oracle", "label"), + label = f"TABLE-7 reference: entry #{idx}", + ) + value = _get_first_int_field( + obj, + keys = ("bugs", "value"), + label = f"TABLE-7 reference: entry #{idx}", + ) + out.append((label, value)) - return True, "" + return out - def compare_table6_against_reference(self) -> Tuple[bool, str]: - """ - Compare TABLE-6 total with the reference total. - """ - if self.table6_exp_total is None: - return False, "TABLE6: sum of bugs not computed" - if self.table6_ref_total is None: - return False, "TABLE6: reference total_all not loaded" +def _parse_table8_reference_pairs(path: Path) -> _LabeledIntPairs: + raw = _load_json_object(path, label = "TABLE-8 reference") + items = raw.get("operators") + if not isinstance(items, list): + raise ValueError("TABLE-8 reference: missing 'operators' list") - found = self.table6_exp_total - ref = self.table6_ref_total + out: _LabeledIntPairs = [] + for idx, obj in enumerate(items): + if not isinstance(obj, dict): + raise ValueError(f"TABLE-8 reference: entry #{idx} is not an object") - if not self.totals_within_tolerance(found, ref): - return False, ( - "TABLE6: total bugs differs too much " - f"(got {found}, ref {ref})" + label = _get_first_str_field( + obj, + keys = ("operator", "label"), + label = f"TABLE-8 reference: entry #{idx}", + ) + value = _get_first_int_field( + obj, + keys = ("operations", "value"), + label = f"TABLE-8 reference: entry #{idx}", + ) + out.append((label, value)) + + return out + + +# --------------------------------------------------------------------------- +# Similarity metadata +# --------------------------------------------------------------------------- + +@dataclasses.dataclass(frozen = True, slots = True, kw_only = True) +class SimilarityRequirement(ExperimentRunRequirement): + """Compares two labeled sequences elementwise under a similarity threshold.""" + + label: str + results_path: Path + reference_path: Path + threshold: float + parse_results_fn: Callable[[Sequence[str]], _LabeledIntPairs] + parse_reference_fn: Callable[[Path], _LabeledIntPairs] + max_mismatches_to_report: int = 10 + + def __post_init__(self) -> None: + if not math.isfinite(self.threshold): + raise ValueError(f"{self.name}: threshold must be finite") + if self.threshold < 0.0 or self.threshold > 1.0: + raise ValueError(f"{self.name}: threshold must be in [0, 1]") + if self.max_mismatches_to_report <= 0: + raise ValueError(f"{self.name}: max_mismatches_to_report must be > 0") + object.__setattr__(self, "results_path", Path(self.results_path)) + object.__setattr__(self, "reference_path", Path(self.reference_path)) + + def check(self, ctx: ExperimentRunsContext) -> "CheckResult": + del ctx # Reserved for shared policies/logging. + + # Match legacy behavior: check reference existence first. + if not self.reference_path.exists(): + return CheckResult.failure( + f"{self.reference_path} ({self.label} reference) not found" ) - - return True, "" - - def compare_table7_against_reference(self) -> Tuple[bool, str]: - """ - Compare TABLE-7 total with the reference total. - """ - if self.table7_exp_total is None: - return False, "TABLE7: sum of bugs not computed" - - if self.table7_exp_total is None: - return False, "TABLE7: reference total not loaded" - - found = self.table7_exp_total - ref = self.table7_ref_total - - if not self.totals_within_tolerance(found, ref): - return False, ( - "TABLE7: total bugs differs too much " - f"(got {found}, ref {ref})" + if not self.results_path.exists(): + return CheckResult.failure( + f"{self.results_path} ({self.label} results) not found" ) - return True, "" - - def compare_table8_against_reference(self) -> Tuple[bool, str]: - """ - Compare TABLE-8 total with the reference total. - """ - if self.table8_exp_total is None: - return False, "TABLE8: sum of operations not computed" - - if self.table8_ref_total is None: - return False, "TABLE8: reference total not loaded" + try: + lines = _read_nonempty_lines(self.results_path) + except OSError as exc: + return CheckResult.failure(f"{self.label}: failed to read results: {exc}") + if not lines: + return CheckResult.failure(f"{self.label}: results file is empty") - found = self.table8_exp_total - ref = self.table8_ref_total + try: + observed_pairs = self.parse_results_fn(lines) + reference_pairs = self.parse_reference_fn(self.reference_path) + except ValueError as exc: + return CheckResult.failure(str(exc)) - if not self.totals_within_tolerance(found, ref): - return False, ( - "TABLE8: total operations differs too much " - f"(got {found}, ref {ref})" + try: + aligned_labels, observed_values, reference_values = _align_by_reference( + observed_pairs, + reference_pairs, + label = self.label, ) - - return True, "" - - # ----------------------- - # Invocations - # ----------------------- - - def _run_table(self, label: str, steps): - """ - Run all steps for a single table, stopping on first failure. - """ - for step in steps: - ok, why = step() - if not ok: - suffix = f" - {why}" if why else "" - logger.error(f"{label}: FAIL{suffix}") - return False - - logger.info(f"{label}: PASS") - return True - - def run_table5(self) -> bool: - return self._run_table( - "TABLE-5", - [ - self.load_table5, - self.parse_table5, - self.load_table5_reference, - self.compare_table5_against_reference, - ], + except ValueError as exc: + return CheckResult.failure(str(exc)) + + failures, mismatch_lines = _collect_threshold_mismatches( + aligned_labels, + observed_values, + reference_values, + threshold = self.threshold, + max_items = self.max_mismatches_to_report, ) - def run_table6(self) -> bool: - return self._run_table( - "TABLE-6", - [ - self.load_table6, - self.parse_table6, - self.load_table6_reference, - self.compare_table6_against_reference, - ], - ) + if failures == 0: + return CheckResult.success() - def run_table7(self) -> bool: - return self._run_table( - "TABLE-7", - [ - self.load_table7, - self.parse_table7, - self.load_table7_reference, - self.compare_table7_against_reference, - ], + msg = ( + f"{self.label}: {failures} entries below threshold {self.threshold:.6f}\n" + "mismatches:\n" + "\n".join(mismatch_lines) ) + if failures > len(mismatch_lines): + msg = f"{msg}\n... ({failures - len(mismatch_lines)} more)" + return CheckResult.failure(msg) + + +# --------------------------------------------------------------------------- +# Oracle's main logic +# --------------------------------------------------------------------------- + +class OracleExperimentRuns(OracleExperimentRunsBase): + """Derived oracle that validates TABLE-5..TABLE-8 outputs.""" + + _NAME = "ExperimentRuns" + + def __init__(self, *, threshold: float = SIMILARITY_RATIO) -> None: + super().__init__(logger = logger) + self._threshold = threshold + + self._table5_results_path = Path(RESULTS_PATH_TABLES["table5"]) + self._table5_reference_path = Path(REFERENCE_PATH_TABLES["table5"]) + self._table6_results_path = Path(RESULTS_PATH_TABLES["table6"]) + self._table6_reference_path = Path(REFERENCE_PATH_TABLES["table6"]) + self._table7_results_path = Path(RESULTS_PATH_TABLES["table7"]) + self._table7_reference_path = Path(REFERENCE_PATH_TABLES["table7"]) + self._table8_results_path = Path(RESULTS_PATH_TABLES["table8"]) + self._table8_reference_path = Path(REFERENCE_PATH_TABLES["table8"]) + + def requirements(self) -> Sequence[ExperimentRunRequirement]: + reqs: list[ExperimentRunRequirement] = [] + + # TABLE-5 is parsed as 4 independent labeled sequences (by operator) + for field in _TABLE5_FIELDS: + def _make_results_fn(f: str) -> Callable[[Sequence[str]], _LabeledIntPairs]: + return lambda lines: _parse_table5_results_pairs(lines, field = f) + + def _make_reference_fn(f: str) -> Callable[[Path], _LabeledIntPairs]: + return lambda path: _parse_table5_reference_pairs(path, field = f) + + reqs.append( + SimilarityRequirement( + name = f"TABLE-5/{field}", + label = f"TABLE-5 {field}", + results_path = self._table5_results_path, + reference_path = self._table5_reference_path, + threshold = self._threshold, + parse_results_fn = _make_results_fn(field), + parse_reference_fn = _make_reference_fn(field), + ) + ) - def run_table8(self) -> bool: - return self._run_table( - "TABLE-8", + # Tabels 6, 7, and 8 (with 8 being optional) + reqs.extend( [ - self.load_table8, - self.parse_table8, - self.load_table8_reference, - self.compare_table8_against_reference, - ], + SimilarityRequirement( + name = "TABLE-6", + label = "TABLE-6", + results_path = self._table6_results_path, + reference_path = self._table6_reference_path, + threshold = self._threshold, + parse_results_fn = _parse_table6_results_pairs, + parse_reference_fn = _parse_table6_reference_pairs, + ), + SimilarityRequirement( + name = "TABLE-7", + label = "TABLE-7", + results_path = self._table7_results_path, + reference_path = self._table7_reference_path, + threshold = self._threshold, + parse_results_fn = _parse_table7_results_pairs, + parse_reference_fn = _parse_table7_reference_pairs, + ), + SimilarityRequirement( + name = "TABLE-8", + label = "TABLE-8", + results_path = self._table8_results_path, + reference_path = self._table8_reference_path, + threshold = self._threshold, + parse_results_fn = _parse_table8_results_pairs, + parse_reference_fn = _parse_table8_reference_pairs, + ), + ] ) - def run(self): - """ - Run all table checks and return True only if every table passes. - Each table logs exactly one line: PASS or the first failure. - """ - results: list[bool] = [] - - results.append(self.run_table5()) - results.append(self.run_table6()) - results.append(self.run_table7()) - results.append(self.run_table8()) - - return all(results) + return reqs diff --git a/benchmarks/arteval_bench/data/benchmark/sosp23_acto/_agent_eval/utils.py b/benchmarks/arteval_bench/data/benchmark/sosp23_acto/_agent_eval/utils.py deleted file mode 100644 index cec05877..00000000 --- a/benchmarks/arteval_bench/data/benchmark/sosp23_acto/_agent_eval/utils.py +++ /dev/null @@ -1,40 +0,0 @@ -# --- CONSTANTS --- # -from pathlib import Path - -HOME = Path.home() / "sosp23_acto" -REPO_DIR = f"{HOME}/acto" - -RESULTS_PATH_TABLES = { - "table5": f"{REPO_DIR}/table5.txt", - "table6": f"{REPO_DIR}/table6.txt", - "table7": f"{REPO_DIR}/table7.txt", - "table8": f"{REPO_DIR}/table8.txt" - } -REFERENCE_PATH_TABLES = { - "table5": f"{HOME}/_agent_eval/refs/table5.ref.json", - "table6": f"{HOME}/_agent_eval/refs/table6.ref.json", - "table7": f"{HOME}/_agent_eval/refs/table7.ref.json", - "table8": f"{HOME}/_agent_eval/refs/table8.ref.json" - } - -SIMILARITY_RATIO = 0.75 - - -# --- CUSTOM LOGGER --- # -import logging -import os -from datetime import datetime - -os.makedirs('logs', exist_ok=True) - -LOG_FORMAT = '%(asctime)s | %(levelname)s | %(name)s | %(message)s' -DATE_FORMAT = '%Y-%m-%d %H:%M:%S' - -logger = logging.getLogger("OSDI24-ANVIL-AGENT-EVALUATOR") -logger.setLevel(logging.DEBUG) - -console_handler = logging.StreamHandler() -console_handler.setLevel(logging.INFO) -console_handler.setFormatter(logging.Formatter(LOG_FORMAT, datefmt=DATE_FORMAT)) - -logger.addHandler(console_handler) \ No newline at end of file