Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 9 additions & 10 deletions exca/steps/backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,27 +249,26 @@ def __init__(
self.step_uid = step_uid

# Returns None: the driver re-reads from cache, so the result never
# round-trips through a job pickle (matters for submitit).
# round-trips through a job pickle.
def __call__(self, batch: items.StepItems) -> None:
folder = self.cache_dict.folder
if folder is not None:
folder.mkdir(parents=True, exist_ok=True)
uids = batch.uids
results = items._annotated_batch(self.step, batch, uids=uids)
result_items = self.step._run_items(batch)
try:
with self.cache_dict.write():
uid_iter = iter(uids)
for result in results:
uid = next(uid_iter)
for i, result in enumerate(result_items):
uid = batch.uids[i]
if uid not in self.cache_dict:
self.cache_dict[uid] = result
except Exception as e:
failed_uid: str | None = getattr(e, "_failed_uid", None)
if folder is not None and failed_uid is not None:
e.add_note(f" -> cached as {self.step_uid}[{failed_uid}]")
inflight: list[str] = getattr(e, "_inflight_uids", [])
if folder is not None and inflight:
e.add_note(f" -> error recorded at {self.step_uid}{inflight}")
tb = "".join(traceback.format_exception(e))
with errors.ErrorRegistry(folder) as reg:
reg.record(failed_uid, e, tb)
for uid in inflight:
reg.record(uid, e, tb)
raise


Expand Down
23 changes: 12 additions & 11 deletions exca/steps/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,7 @@
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""Core step classes.

Step holds pydantic config and `_run`; identity (`step_uid`, `uid`)
is computed by `exca.steps.identity` from `(self, value)` at call time.
`_dispatch` routes computation inline or via backend; ``_run_batch``
does the actual work.
"""
"""Core step classes: Step and Chain."""

from __future__ import annotations

Expand Down Expand Up @@ -165,7 +159,11 @@ def __pydantic_init_subclass__(cls, **kwargs: tp.Any) -> None:
f"{cls.__name__} overrides _forward which was removed; "
"override _run instead"
)
has_run = cls._run is not Step._run or cls._run_batch is not Step._run_batch
has_run = (
cls._run is not Step._run
or cls._run_batch is not Step._run_batch
or cls._run_items is not Step._run_items
)
if has_run:
flags.add("has_run")
if _has_all_defaults(cls._run):
Expand Down Expand Up @@ -239,6 +237,10 @@ def _inner_mode(self) -> identity.ModeType:
return resolved._inner_mode()
return "cached" if self.infra is None else self.infra.mode

def _run_items(self, batch: items.StepItems) -> items.StepItems:
    """Process *batch* and return the result as StepItems.

    Default implementation delegates to ``batch.apply_step(self)``;
    subclasses (e.g. Chain) may override to route the batch differently.
    """
    return batch.apply_step(self)

def _dispatch(self, batch: items.StepItems) -> items.StepItems:
"""Push *batch* through this step, return result as StepItems."""
self._check_cache_type()
Expand Down Expand Up @@ -552,9 +554,8 @@ def _walk_steps(
current = step._dispatch(current)
return current

def _run_batch(self, values: tp.Iterable[tp.Any]) -> tp.Iterator[tp.Any]:
# values is actually StepItems here (FIX with _run_items?)
yield from self._walk_steps(values) # type: ignore[arg-type]
def _run_items(self, batch: items.StepItems) -> items.StepItems:
    """Push *batch* through every step of the chain via ``_walk_steps``."""
    return self._walk_steps(batch)

def _inner_mode(self) -> identity.ModeType:
own: identity.ModeType = "cached" if self.infra is None else self.infra.mode
Expand Down
72 changes: 43 additions & 29 deletions exca/steps/items.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

from __future__ import annotations

import collections
import typing as tp

import exca.cachedict
Expand Down Expand Up @@ -45,33 +46,47 @@ def __iter__(self) -> tp.Iterator[tp.Any]:
return iter(self._values)


def _annotated_batch(
step: Step,
values: tp.Iterable[tp.Any],
*,
uids: tp.Sequence[str] | None = None,
) -> tp.Iterator[tp.Any]:
"""Iterate ``step._run_batch(values)`` with yield-count validation and error annotation."""
n_out = 0
expected = len(uids) if uids is not None else None
try:
for result in step._run_batch(values):
if expected is not None and n_out >= expected:
raise RuntimeError(
f"{step!r}._run_batch yielded more than {expected} results"
)
n_out += 1
yield result
except Exception as e:
uid = uids[n_out] if uids is not None and n_out < len(uids) else None
e.add_note(f" -> while running step {step!r}" + (f"[{uid}]" if uid else ""))
if uid is not None:
e._failed_uid = uid # type: ignore[attr-defined]
raise
if expected is not None and n_out < expected:
raise RuntimeError(
f"{step!r}._run_batch yielded {n_out} results for {expected} inputs"
)
class _AnnotatedBatch:
"""Wraps ``step._run_batch`` with consumption tracking, yield validation, and error annotation.

On error, ``_inflight_uids`` on the exception contains the consumed-but-not-yielded uids.
"""

def __init__(
self, step: Step, values: tp.Iterable[tp.Any], uids: tp.Sequence[str]
) -> None:
self.step = step
self._values = values
self._uid_iter = iter(uids)
self._expected = len(uids)
self._inflight: collections.deque[str] = collections.deque()
self.n_out = 0

def _tracked(self) -> tp.Iterator[tp.Any]:
for v in self._values:
self._inflight.append(next(self._uid_iter))
yield v

def __iter__(self) -> tp.Iterator[tp.Any]:
try:
for result in self.step._run_batch(self._tracked()):
if self.n_out >= self._expected:
raise RuntimeError(
f"{self.step!r}._run_batch yielded more than {self._expected} results"
)
self._inflight.popleft()
self.n_out += 1
yield result
except Exception as e:
failed = list(self._inflight)
e.add_note(f" -> in {self.step!r}, inflight uids: {failed}")
if failed:
e._inflight_uids = failed # type: ignore[attr-defined]
raise
if self.n_out < self._expected:
raise RuntimeError(
f"{self.step!r}._run_batch yielded {self.n_out} results for {self._expected} inputs"
)


class StepItems(Items):
Expand Down Expand Up @@ -135,7 +150,6 @@ def select(self, uids: tp.Sequence[str]) -> StepItems:

def __iter__(self) -> tp.Iterator[tp.Any]:
current: tp.Iterable[tp.Any] = (self._source[uid] for uid in self.uids)
uids = self.uids
for step in self._pending:
current = _annotated_batch(step, current, uids=uids)
current = _AnnotatedBatch(step, current, self.uids)
return iter(current)
8 changes: 3 additions & 5 deletions exca/steps/test_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,17 +167,15 @@ def boom(self: tp.Any, item_uids: list[str]) -> None:


def test_cached_as_note_survives_pickle(tmp_path: Path) -> None:
"""`_CachingCall` adds a `-> cached as ...` note before pickling into
errors.db. Notes must round-trip so the re-raise points to the cache key."""
with pytest.raises(ValueError) as exc_info:
_add(True, tmp_path).run(5.0)
notes = getattr(exc_info.value, "__notes__", [])
assert any("cached as" in n for n in notes)
assert any("error recorded at" in n for n in notes), "missing record note"
with pytest.raises(ValueError) as exc_info:
_add(False, tmp_path).run(5.0)
notes = getattr(exc_info.value, "__notes__", [])
assert any("cached as" in n for n in notes)
assert any("mode='retry'" in n for n in notes)
assert any("error recorded at" in n for n in notes), "missing record note"
assert any("mode='retry'" in n for n in notes), "missing retry note"


def test_orphan_errors_db_self_heals_on_recompute(tmp_path: Path) -> None:
Expand Down
48 changes: 47 additions & 1 deletion exca/steps/test_steps.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
"""Tests for Step and Chain basic functionality (no caching tests here, see test_cache.py)."""

import collections
import itertools
import pickle
import traceback
import typing as tp
Expand Down Expand Up @@ -376,7 +377,7 @@ def test_chain_error_note() -> None:
with pytest.raises(ValueError) as exc_info:
chain.run(1)
formatted = _format_exc(exc_info.value)
assert "Add" in formatted and "while running step" in formatted
assert "Add" in formatted and "inflight uids" in formatted


# =============================================================================
Expand Down Expand Up @@ -566,3 +567,48 @@ def test_run_batch_yield_count(
infra: tp.Any = {"backend": "Cached", "folder": tmp_path} if with_infra else None
with pytest.raises(RuntimeError, match=match):
list(_NumYield(num=num, infra=infra).run(items.Items([10, 20, 30])))


class _GroupedMult(Step):
    """Consumes inputs in groups, multiplies by 10."""

    # class-level log of every group processed, inspected by tests
    _CALLS: tp.ClassVar[list[list[int]]] = []
    group_size: int = 2
    fail_value: int = -1

    def _run_batch(self, values: tp.Iterable[tp.Any]) -> tp.Iterator[tp.Any]:
        # pull fixed-size chunks from the input until it runs dry
        source = iter(values)
        while True:
            chunk = list(itertools.islice(source, self.group_size))
            if not chunk:
                break
            if self.fail_value in chunk:
                raise ValueError("boom")
            self._CALLS.append(chunk)
            for item in chunk:
                yield item * 10


@pytest.mark.parametrize("with_infra", [False, True])
def test_batch_error_inflight_uids(tmp_path: Path, with_infra: bool) -> None:
    """A mid-batch failure exposes the consumed-but-unyielded uids on the exception."""
    infra: tp.Any = {"backend": "Cached", "folder": tmp_path} if with_infra else None
    step = _GroupedMult(group_size=2, fail_value=3, infra=infra)
    with pytest.raises(ValueError, match="boom") as exc_info:
        list(step.run(items.Items([1, 2, 3, 4, 5, 6])))
    # group [3, 4] was consumed but never yielded -> both uids are inflight
    inflight = getattr(exc_info.value, "_inflight_uids", [])
    assert len(inflight) == 2, f"expected 2 inflight uids, got {inflight}"


def test_chained_group_sizes_call_order() -> None:
    """Chained steps with different group sizes interleave lazily (pull-based)."""
    _GroupedMult._CALLS.clear()
    chain = Chain(steps=[_GroupedMult(group_size=s) for s in (2, 3, 1)])
    _ = list(chain.run(items.Items([1, 2, 3, 4, 5])))
    calls = list(_GroupedMult._CALLS)
    _GroupedMult._CALLS.clear()
    # fmt: off
    assert calls == [
        [1, 2], [3, 4], # step1 runs both pairs (step2 needs 3)
        [10, 20, 30], # step2 fills its group of 3
        [100], [200], [300], # step3 drains immediately
        [5], # step1 partial: 1 left
        [40, 50], # step2 partial: only 2 left
        [400], [500], # step3 drains
    ]
    # fmt: on
Loading