Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 11 additions & 4 deletions api/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,17 @@ Import as `temporal_model.api`. Depends on `temporal-model-core`.

- `GET /health` — readiness + loaded model name/version + API code version.
- `POST /predict` — body `{ "frames": [...], "source": "s3" | "local",
"bucket": "<name>", "roi_xyxyn": [x_min, y_min, x_max, y_max] }`
"bucket": "<name>", "roi_xyxyn": [x_min, y_min, x_max, y_max],
"detections": [[{"xyxyn": [...], "confidence": 0.6}], []] }`
(ordered frames; `source` optional, falls back to `FRAME_SOURCE` — with
`s3`, frames are S3 keys and `bucket` optionally overrides `S3_BUCKET`;
with `local`, frames are relative paths under `FRAMES_ROOT` and `bucket`
is invalid; `roi_xyxyn` optional normalized region of interest — tubes
with no real detection intersecting it are dropped before scoring);
with no real detection intersecting it are dropped before scoring;
`detections` optional caller-supplied boxes, one list per frame
index-aligned with `frames`, `[]` = that frame's detector saw nothing —
skips the bundled YOLO and its cache entirely, tubes are built from the
supplied boxes);
returns `{ is_smoke, probability, version }` (`probability` = max kept-tube
calibrated probability, `null` if uncalibrated).
`version` is `{api, model}` — the code release (== the Docker image tag,
Expand All @@ -25,8 +30,10 @@ Import as `temporal_model.api`. Depends on `temporal-model-core`.
top-level `trigger_frame_index` (`null` if nothing crossed) — with
`verbose=true` it also fills `details.decision.trigger_tube_id` and
per-tube `details.tubes[].first_crossing_frame`. See
`docs/specs/2026-06-02-api-service-design.md` and
`docs/specs/2026-06-11-api-local-frames-design.md` for the full contract.
`docs/specs/2026-06-02-api-service-design.md`,
`docs/specs/2026-06-11-api-local-frames-design.md` and
`docs/specs/2026-06-11-api-supplied-detections-design.md` for the full
contract.

## Run

Expand Down
4 changes: 4 additions & 0 deletions api/src/temporal_model/api/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,7 @@ async def predict(
out = await runner.predict(
paths,
roi=body.roi_xyxyn,
detections=body.detections,
timer=timer,
profile=profile,
compute_trigger=compute_trigger,
Expand All @@ -210,6 +211,9 @@ async def predict(
compute_trigger=compute_trigger,
threshold_overridden=runner.threshold_overridden,
packaged_threshold=runner.packaged_threshold,
detections_source=(
"request" if body.detections is not None else "detector"
),
profiling=profiling,
)
except ApiError:
Expand Down
66 changes: 64 additions & 2 deletions api/src/temporal_model/api/model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,10 @@
from starlette.concurrency import run_in_threadpool

from temporal_model.core.stage_timer import StageTimer, stage_ctx
from temporal_model.core.types import Detection, FrameDetections

from .detection_cache import DetectionCache
from .schemas import SuppliedDetection

logger = logging.getLogger(__name__)

Expand All @@ -45,6 +47,37 @@ def _load_core_model(package_path: Path, device: str | None) -> Any:
return BboxTubeTemporalModel.from_package(package_path, device=device)


def _supplied_frame_detections(
frames: list[Any], detections: list[list[SuppliedDetection]]
) -> dict[str, FrameDetections]:
"""Convert caller-supplied xyxyn boxes to per-frame ``FrameDetections``.

``detections`` is index-aligned with ``frames`` (lengths validated at the
HTTP boundary; ``strict=True`` is a safety net). Boxes arrive as
normalized corners and become center-based xywhn ``Detection``s; supplied
boxes are smoke by definition (``class_id=0``).
"""
resolved: dict[str, FrameDetections] = {}
for idx, (frame, boxes) in enumerate(zip(frames, detections, strict=True)):
resolved[frame.frame_id] = FrameDetections(
frame_idx=idx,
frame_id=frame.frame_id,
timestamp=frame.timestamp,
detections=[
Detection(
class_id=0,
cx=(b.xyxyn[0] + b.xyxyn[2]) / 2.0,
cy=(b.xyxyn[1] + b.xyxyn[3]) / 2.0,
w=b.xyxyn[2] - b.xyxyn[0],
h=b.xyxyn[3] - b.xyxyn[1],
confidence=b.confidence,
)
for b in boxes
],
)
return resolved


class ModelRunner:
"""Holds the loaded model and serializes inference calls."""

Expand Down Expand Up @@ -122,6 +155,7 @@ async def predict(
frame_paths: list[Path],
*,
roi: tuple[float, float, float, float] | None = None,
detections: list[list[SuppliedDetection]] | None = None,
timer: StageTimer | None = None,
profile: dict[str, Any] | None = None,
compute_trigger: bool = False,
Expand All @@ -132,23 +166,51 @@ async def predict(
cache is accessed by one prediction at a time. When ``timer``/``profile``
are supplied, the ``detector`` stage is timed and cache counts recorded.
``roi`` is passed through to the core model untouched — the cache stays
full-frame (see the invariant in the ROI spec).
full-frame (see the invariant in the ROI spec). When ``detections`` is
supplied (index-aligned per-frame boxes from the caller's own
detector), the bundled detector and its cache are bypassed entirely:
no read, no write, no ``detector`` stage.
"""
async with self._lock:
return await run_in_threadpool(
self._predict_sync, frame_paths, roi, timer, profile, compute_trigger
self._predict_sync,
frame_paths,
roi,
detections,
timer,
profile,
compute_trigger,
)

def _predict_sync(
self,
frame_paths: list[Path],
roi: tuple[float, float, float, float] | None = None,
detections: list[list[SuppliedDetection]] | None = None,
timer: StageTimer | None = None,
profile: dict[str, Any] | None = None,
compute_trigger: bool = False,
) -> Any:
started = time.perf_counter()
frames = self._model.load_sequence(frame_paths)
if detections is not None:
out = self._model.predict(
frames,
frame_detections=_supplied_frame_detections(frames, detections),
roi=roi,
timer=timer,
compute_trigger=compute_trigger,
)
if profile is not None:
profile["n_frames"] = len(frames)
profile["cache_hits"] = 0
profile["cache_misses"] = 0
logger.info(
"predict: supplied detections, seq_len=%d, %.0fms",
len(frames),
(time.perf_counter() - started) * 1000.0,
)
return out
resolved: dict[str, Any] = {}
misses = []
for f in frames:
Expand Down
48 changes: 47 additions & 1 deletion api/src/temporal_model/api/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import re
from typing import Any, Literal

from pydantic import BaseModel, Field, field_validator
from pydantic import BaseModel, Field, field_validator, model_validator

from temporal_model.core.tubes import validate_roi

Expand All @@ -21,6 +21,30 @@
_BUCKET_RE = re.compile(r"^[a-z0-9][a-z0-9.-]{1,61}[a-z0-9]$")


class SuppliedDetection(BaseModel):
"""One caller-supplied detection box (normalized xyxyn corners).

Geometry rules match ``roi_xyxyn``. Checked inline rather than via the
core ``validate_roi`` helper so the error message names the detection
field, not "roi".
"""

xyxyn: tuple[float, float, float, float]
confidence: float = Field(ge=0.0, le=1.0)

@field_validator("xyxyn")
@classmethod
def _validate_xyxyn(
cls, v: tuple[float, float, float, float]
) -> tuple[float, float, float, float]:
x_min, y_min, x_max, y_max = v
if not all(0.0 <= c <= 1.0 for c in v):
raise ValueError("xyxyn coordinates must be in [0, 1]")
if x_min >= x_max or y_min >= y_max:
raise ValueError("xyxyn requires x_min < x_max and y_min < y_max")
return v


class PredictRequest(BaseModel):
frames: list[str]
# Where `frames` live: "s3" (keys in a bucket) or "local" (relative paths
Expand All @@ -37,6 +61,12 @@ class PredictRequest(BaseModel):
# detection intersecting it are dropped before scoring (see
# docs/specs/2026-06-10-api-roi-design.md).
roi_xyxyn: tuple[float, float, float, float] | None = None
# Optional caller-supplied detections, one list per frame, index-aligned
# with `frames` ([] = that frame's detector saw nothing — never null).
# When set, the bundled YOLO and its cache are bypassed entirely and tubes
# are built from these boxes (see
# docs/specs/2026-06-11-api-supplied-detections-design.md).
detections: list[list[SuppliedDetection]] | None = None

@field_validator("frames")
@classmethod
Expand Down Expand Up @@ -76,6 +106,15 @@ def _validate_roi(
raise ValueError(f"roi_xyxyn: {e}") from e
return v

@model_validator(mode="after")
def _detections_match_frames(self) -> "PredictRequest":
if self.detections is not None and len(self.detections) != len(self.frames):
raise ValueError(
"detections must have exactly one entry per frame "
f"(got {len(self.detections)} entries for {len(self.frames)} frames)"
)
return self


class FrameEntry(BaseModel):
frame_idx: int
Expand Down Expand Up @@ -113,6 +152,9 @@ class Preprocessing(BaseModel):
padded_frame_indices: list[int]
num_tube_candidates: int
num_tubes_outside_roi: int
# Provenance: "request" when the caller supplied the detections (bundled
# detector bypassed), "detector" when the bundled YOLO produced them.
detections_source: Literal["request", "detector"]


class Details(BaseModel):
Expand Down Expand Up @@ -160,6 +202,7 @@ def _to_details(
*,
threshold_overridden: bool,
packaged_threshold: float | None,
detections_source: Literal["request", "detector"],
profiling: dict[str, Any] | None = None,
compute_trigger: bool = False,
) -> Details:
Expand Down Expand Up @@ -189,6 +232,7 @@ def _to_details(
# Strict like num_candidates: core (same-commit path dependency)
# always emits the key; a silent 0 here would mask a core rename.
num_tubes_outside_roi=tubes_block["num_outside_roi"],
detections_source=detections_source,
),
tubes=[Tube(**t) for t in kept],
profiling=profiling,
Expand All @@ -205,6 +249,7 @@ def to_response(
compute_trigger: bool = False,
threshold_overridden: bool = False,
packaged_threshold: float | None = None,
detections_source: Literal["request", "detector"] = "detector",
profiling: dict[str, Any] | None = None,
) -> PredictResponse:
"""Reshape a core model output into the public response DTO."""
Expand All @@ -221,6 +266,7 @@ def to_response(
out.details,
threshold_overridden=threshold_overridden,
packaged_threshold=packaged_threshold,
detections_source=detections_source,
profiling=profiling,
compute_trigger=compute_trigger,
)
Expand Down
86 changes: 85 additions & 1 deletion api/tests/test_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,13 +66,22 @@ def __init__(self, output=None, error=None):
self._output = output
self._error = error
self.roi = None
self.detections = None
self.paths = None
self.compute_trigger = None

async def predict(
self, paths, *, roi=None, timer=None, profile=None, compute_trigger=False
self,
paths,
*,
roi=None,
detections=None,
timer=None,
profile=None,
compute_trigger=False,
):
self.roi = roi
self.detections = detections
self.paths = paths
self.compute_trigger = compute_trigger
if self._error:
Expand Down Expand Up @@ -465,6 +474,53 @@ def test_predict_invalid_roi_is_400(client):
assert "roi_xyxyn" in body["detail"]


def test_predict_passes_detections_to_runner(client):
r = client.post(
"/predict",
json={
"frames": KEYS,
"detections": [
[{"xyxyn": [0.1, 0.2, 0.3, 0.4], "confidence": 0.6}],
[],
],
},
)
assert r.status_code == 200
sent = client.app.state.runner.detections
assert sent[0][0].xyxyn == (0.1, 0.2, 0.3, 0.4)
assert sent[0][0].confidence == 0.6
assert sent[1] == []


def test_predict_without_detections_passes_none(client):
r = client.post("/predict", json={"frames": KEYS})
assert r.status_code == 200
assert client.app.state.runner.detections is None


def test_predict_detections_length_mismatch_is_400(client):
r = client.post("/predict", json={"frames": KEYS, "detections": [[]]})
assert r.status_code == 400
body = r.json()
assert body["code"] == "invalid_request"
assert "one entry per frame" in body["detail"]


def test_predict_malformed_detection_is_400(client):
r = client.post(
"/predict",
json={
"frames": KEYS,
"detections": [
[{"xyxyn": [0.3, 0.2, 0.1, 0.4], "confidence": 0.6}],
[],
],
},
)
assert r.status_code == 400
assert r.json()["code"] == "invalid_request"


@pytest.fixture
def local_client(monkeypatch, tmp_path):
# An edge-box style deployment: frame_source=local, frames on a shared
Expand Down Expand Up @@ -522,6 +578,34 @@ def test_predict_local_no_root_400_takes_precedence_over_model(
assert r.json()["code"] == "invalid_request"


def test_predict_detections_compose_with_roi(client):
r = client.post(
"/predict",
json={
"frames": KEYS,
"detections": [[], []],
"roi_xyxyn": [0.0, 0.0, 1.0, 1.0],
},
)
assert r.status_code == 200
assert client.app.state.runner.roi == (0.0, 0.0, 1.0, 1.0)
assert client.app.state.runner.detections == [[], []]


def test_predict_verbose_detections_source_request(client):
r = client.post(
"/predict?verbose=true", json={"frames": KEYS, "detections": [[], []]}
)
assert r.status_code == 200
assert r.json()["details"]["preprocessing"]["detections_source"] == "request"


def test_predict_verbose_detections_source_detector(client):
r = client.post("/predict?verbose=true", json={"frames": KEYS})
assert r.status_code == 200
assert r.json()["details"]["preprocessing"]["detections_source"] == "detector"


def test_predict_local_missing_frame_404(local_client):
r = local_client.post("/predict", json={"frames": ["cam12/missing.jpg"]})
assert r.status_code == 404
Expand Down
Loading
Loading