Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions api/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,12 @@ Import as `temporal_model.api`. Depends on `temporal-model-core`.
`version` is `{api, model}` — the code release (== the Docker image tag,
`null` on non-release builds) and the packaged model release.
`POST /predict?verbose=true` adds a `details` block (decision, preprocessing,
per-tube tracks). See `docs/specs/2026-06-02-api-service-design.md` for the
full contract.
per-tube tracks). `POST /predict?compute_trigger=true` runs the
first-crossing search (extra classifier work, off by default) and adds a
top-level `trigger_frame_index` (`null` if nothing crossed) — with
`verbose=true` it also fills `details.decision.trigger_tube_id` and
per-tube `details.tubes[].first_crossing_frame`. See
`docs/specs/2026-06-02-api-service-design.md` for the full contract.

## Run

Expand Down
12 changes: 10 additions & 2 deletions api/src/temporal_model/api/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,10 @@ def health(request: Request) -> HealthResponse:
dependencies=[Depends(require_token)],
)
async def predict(
body: PredictRequest, request: Request, verbose: bool = False
body: PredictRequest,
request: Request,
verbose: bool = False,
compute_trigger: bool = False,
) -> PredictResponse:
bucket = body.bucket or settings.s3_bucket
if not bucket:
Expand All @@ -154,7 +157,11 @@ async def predict(
)

out = await runner.predict(
paths, roi=body.roi_xyxyn, timer=timer, profile=profile
paths,
roi=body.roi_xyxyn,
timer=timer,
profile=profile,
compute_trigger=compute_trigger,
)

profiling = None
Expand All @@ -173,6 +180,7 @@ async def predict(
model_version=runner.version,
calibrated=runner.calibrated,
verbose=verbose,
compute_trigger=compute_trigger,
threshold_overridden=runner.threshold_overridden,
packaged_threshold=runner.packaged_threshold,
profiling=profiling,
Expand Down
10 changes: 8 additions & 2 deletions api/src/temporal_model/api/model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ async def predict(
roi: tuple[float, float, float, float] | None = None,
timer: StageTimer | None = None,
profile: dict[str, Any] | None = None,
compute_trigger: bool = False,
) -> Any:
"""Resolve detections (cache + detect misses) then run the model.

Expand All @@ -135,7 +136,7 @@ async def predict(
"""
async with self._lock:
return await run_in_threadpool(
self._predict_sync, frame_paths, roi, timer, profile
self._predict_sync, frame_paths, roi, timer, profile, compute_trigger
)

def _predict_sync(
Expand All @@ -144,6 +145,7 @@ def _predict_sync(
roi: tuple[float, float, float, float] | None = None,
timer: StageTimer | None = None,
profile: dict[str, Any] | None = None,
compute_trigger: bool = False,
) -> Any:
started = time.perf_counter()
frames = self._model.load_sequence(frame_paths)
Expand All @@ -160,7 +162,11 @@ def _predict_sync(
self._cache.put(fd.frame_id, fd)
resolved[fd.frame_id] = fd
out = self._model.predict(
frames, frame_detections=resolved, roi=roi, timer=timer
frames,
frame_detections=resolved,
roi=roi,
timer=timer,
compute_trigger=compute_trigger,
)
if profile is not None:
profile["n_frames"] = len(frames)
Expand Down
29 changes: 25 additions & 4 deletions api/src/temporal_model/api/schemas.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
"""Public request/response DTOs and the mapper from the core model output.

The default response is the lean verdict; ``?verbose=true`` adds a ``details``
block. ``details`` is only set when verbose, so the route serializes with
``exclude_unset=True`` to omit it otherwise (while keeping explicit ``null``s).
block and ``?compute_trigger=true`` adds the time-to-detection fields. Both
are only set when requested, so the route serializes with
``exclude_unset=True`` to omit them otherwise (while keeping explicit
``null``s).
"""

import re
Expand Down Expand Up @@ -87,6 +89,7 @@ class Tube(BaseModel):
end_frame: int
logit: float
probability: float | None
first_crossing_frame: int | None = None
entries: list[FrameEntry]


Expand All @@ -95,6 +98,7 @@ class Decision(BaseModel):
threshold: float
threshold_overridden: bool = False
packaged_threshold: float | None = None
trigger_tube_id: int | None = None


class Preprocessing(BaseModel):
Expand Down Expand Up @@ -127,6 +131,7 @@ class Version(BaseModel):
class PredictResponse(BaseModel):
is_smoke: bool
probability: float | None
trigger_frame_index: int | None = None
version: Version
details: Details | None = None

Expand All @@ -150,12 +155,23 @@ def _to_details(
threshold_overridden: bool,
packaged_threshold: float | None,
profiling: dict[str, Any] | None = None,
compute_trigger: bool = False,
) -> Details:
tubes_block = details["tubes"]
pre = details["preprocessing"]
decision = dict(details["decision"])
kept = tubes_block["kept"]
if not compute_trigger:
# Core emits these keys even on the fast path (always null there);
# dropping them keeps the DTO fields unset so exclude_unset omits
# them and the no-flag response is unchanged.
decision.pop("trigger_tube_id", None)
kept = [
{k: v for k, v in t.items() if k != "first_crossing_frame"} for t in kept
]
return Details(
decision=Decision(
**details["decision"],
**decision,
threshold_overridden=threshold_overridden,
packaged_threshold=packaged_threshold,
),
Expand All @@ -168,7 +184,7 @@ def _to_details(
# always emits the key; a silent 0 here would mask a core rename.
num_tubes_outside_roi=tubes_block["num_outside_roi"],
),
tubes=[Tube(**t) for t in tubes_block["kept"]],
tubes=[Tube(**t) for t in kept],
profiling=profiling,
)

Expand All @@ -180,6 +196,7 @@ def to_response(
model_version: str | None,
calibrated: bool,
verbose: bool,
compute_trigger: bool = False,
threshold_overridden: bool = False,
packaged_threshold: float | None = None,
profiling: dict[str, Any] | None = None,
Expand All @@ -190,11 +207,15 @@ def to_response(
"probability": _decision_probability(out.details, calibrated),
"version": Version(api=api_version, model=model_version),
}
if compute_trigger:
# Explicit null is meaningful here: searched, no crossing found.
kwargs["trigger_frame_index"] = out.trigger_frame_index
if verbose:
kwargs["details"] = _to_details(
out.details,
threshold_overridden=threshold_overridden,
packaged_threshold=packaged_threshold,
profiling=profiling,
compute_trigger=compute_trigger,
)
return PredictResponse(**kwargs)
34 changes: 33 additions & 1 deletion api/tests/test_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,13 @@ def __init__(self, output=None, error=None):
self._output = output
self._error = error
self.roi = None
self.compute_trigger = None

async def predict(self, paths, *, roi=None, timer=None, profile=None):
async def predict(
self, paths, *, roi=None, timer=None, profile=None, compute_trigger=False
):
self.roi = roi
self.compute_trigger = compute_trigger
if self._error:
raise self._error
if timer is not None:
Expand Down Expand Up @@ -212,6 +216,34 @@ def test_predict_verbose_surfaces_override(client):
assert decision["packaged_threshold"] == 0.5


def test_predict_compute_trigger_returns_trigger_frame_index(client):
    """``?compute_trigger=true`` alone adds ``trigger_frame_index`` to the body."""
    response = client.post("/predict?compute_trigger=true", json={"frames": KEYS})
    expected_payload = {
        "is_smoke": True,
        "probability": 0.98,
        "trigger_frame_index": 3,
        "version": {"api": None, "model": "1.2.0"},
    }
    assert response.status_code == 200
    # Exact equality also pins that no ``details`` block is returned.
    assert response.json() == expected_payload
    # The query flag must have been forwarded to the runner.
    runner = client.app.state.runner
    assert runner.compute_trigger is True


def test_predict_default_runs_fast_path(client):
    """Without the flag the trigger key is absent and the runner receives False."""
    response = client.post("/predict", json={"frames": KEYS})
    assert response.status_code == 200
    payload = response.json()
    assert "trigger_frame_index" not in payload
    runner = client.app.state.runner
    assert runner.compute_trigger is False


def test_predict_compute_trigger_verbose_adds_trigger_details(client):
    """Both flags together expose the top-level and per-tube trigger fields."""
    response = client.post(
        "/predict?compute_trigger=true&verbose=true", json={"frames": KEYS}
    )
    payload = response.json()
    assert response.status_code == 200
    assert payload["trigger_frame_index"] == 3
    details = payload["details"]
    assert details["decision"]["trigger_tube_id"] == 7
    first_tube = details["tubes"][0]
    assert first_tube["first_crossing_frame"] == 3


def test_predict_reports_api_version(client, monkeypatch):
# settings.api_version is read per request, so a monkeypatched value
# must show up as version.api.
Expand Down
36 changes: 34 additions & 2 deletions api/tests/test_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ def __init__(self):
self.detect_calls: list[list[str]] = []
self.predict_calls: list[set[str]] = []
self.roi_calls: list[tuple | None] = []
self.trigger_calls: list[bool] = []

def load_sequence(self, paths):
return [
Expand All @@ -137,9 +138,18 @@ def detect(self, frames):
for i, f in enumerate(frames)
]

def predict(self, frames, *, frame_detections=None, roi=None, timer=None):
def predict(
self,
frames,
*,
frame_detections=None,
roi=None,
timer=None,
compute_trigger=False,
):
self.predict_calls.append(set(frame_detections or {}))
self.roi_calls.append(roi)
self.trigger_calls.append(compute_trigger)
return SimpleNamespace(frame_ids=[f.frame_id for f in frames])


Expand Down Expand Up @@ -167,6 +177,20 @@ def test_predict_roi_defaults_to_none():
assert model.roi_calls[-1] is None


def test_predict_threads_compute_trigger_to_model():
    """``compute_trigger=True`` passed to the runner reaches the model's predict."""
    fake_model = _OrchestrationModel()
    model_runner = ModelRunner(fake_model, name="m", version="1", calibrated=True)
    frame_paths = ["c/x_00.jpg"]
    asyncio.run(model_runner.predict(frame_paths, compute_trigger=True))
    recorded_flags = fake_model.trigger_calls
    assert recorded_flags[-1] is True


def test_predict_compute_trigger_defaults_to_false():
    """Omitting ``compute_trigger`` makes the model's predict receive False."""
    fake_model = _OrchestrationModel()
    model_runner = ModelRunner(fake_model, name="m", version="1", calibrated=True)
    frame_paths = ["c/x_00.jpg"]
    asyncio.run(model_runner.predict(frame_paths))
    recorded_flags = fake_model.trigger_calls
    assert recorded_flags[-1] is False


def test_predict_caches_and_reuses_detections():
model = _OrchestrationModel()
runner = ModelRunner(
Expand Down Expand Up @@ -208,7 +232,15 @@ def load_sequence(self, paths):
def detect(self, misses):
return [SimpleNamespace(frame_id=f.frame_id) for f in misses]

def predict(self, frames, *, frame_detections=None, roi=None, timer=None):
def predict(
self,
frames,
*,
frame_detections=None,
roi=None,
timer=None,
compute_trigger=False,
):
self.predict_timer = timer
if timer is not None:
with timer.stage("classifier"):
Expand Down
75 changes: 75 additions & 0 deletions api/tests/test_schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,3 +302,78 @@ def test_verbose_details_num_tubes_outside_roi_is_strict():
to_response(
out, api_version=None, model_version="1", calibrated=True, verbose=True
)


def test_compute_trigger_sets_top_level_trigger_frame_index():
    """With the flag on and verbose off, only the top-level trigger is added."""
    core_output = SimpleNamespace(
        is_positive=True,
        trigger_frame_index=3,
        details=_details([_tube(7, 0.98)]),
    )
    options = {
        "api_version": None,
        "model_version": "1.2.0",
        "calibrated": True,
        "verbose": False,
        "compute_trigger": True,
    }
    response = to_response(core_output, **options)
    serialized = response.model_dump(exclude_unset=True)
    assert serialized["trigger_frame_index"] == 3
    assert "details" not in serialized


def test_compute_trigger_no_crossing_is_explicit_null():
    """Searched but nothing crossed: the key is present with an explicit null."""
    core_output = SimpleNamespace(
        is_positive=False,
        trigger_frame_index=None,
        details=_details([]),
    )
    options = {
        "api_version": None,
        "model_version": "1.2.0",
        "calibrated": True,
        "verbose": False,
        "compute_trigger": True,
    }
    response = to_response(core_output, **options)
    serialized = response.model_dump(exclude_unset=True)
    # Presence is checked separately from the value.
    assert "trigger_frame_index" in serialized
    assert serialized["trigger_frame_index"] is None


def test_default_omits_trigger_frame_index():
    """Even when the core output carries a trigger, the flag gates exposure."""
    core_output = SimpleNamespace(
        is_positive=True,
        trigger_frame_index=3,
        details=_details([_tube(7, 0.98)]),
    )
    response = to_response(
        core_output,
        api_version=None,
        model_version="1.2.0",
        calibrated=True,
        verbose=False,
    )
    serialized = response.model_dump(exclude_unset=True)
    assert "trigger_frame_index" not in serialized


def test_compute_trigger_verbose_adds_trigger_details():
    """Verbose plus the flag fills the trigger fields inside ``details``."""
    core_output = SimpleNamespace(
        is_positive=True,
        trigger_frame_index=3,
        details=_details([_tube(7, 0.98)]),
    )
    options = {
        "api_version": None,
        "model_version": "1.2.0",
        "calibrated": True,
        "verbose": True,
        "compute_trigger": True,
    }
    response = to_response(core_output, **options)
    serialized = response.model_dump(exclude_unset=True)
    details_block = serialized["details"]
    assert details_block["decision"]["trigger_tube_id"] == 7
    first_tube = details_block["tubes"][0]
    assert first_tube["first_crossing_frame"] == 3


def test_verbose_without_compute_trigger_omits_trigger_details():
    """Verbose alone leaves the trigger fields out of the dumped ``details``."""
    core_output = SimpleNamespace(
        is_positive=True,
        trigger_frame_index=3,
        details=_details([_tube(7, 0.98)]),
    )
    response = to_response(
        core_output,
        api_version=None,
        model_version="1.2.0",
        calibrated=True,
        verbose=True,
    )
    serialized = response.model_dump(exclude_unset=True)
    details_block = serialized["details"]
    assert "trigger_tube_id" not in details_block["decision"]
    first_tube = details_block["tubes"][0]
    assert "first_crossing_frame" not in first_tube
Loading
Loading