Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion src/jabs/classifier/cross_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,17 @@ def _prepare_cv_labels(
return features["labels"], None, None

behavior_names = list(getattr(classifier, "behavior_names", []))
labels, _ = classifier_utils.merge_labels(features["labels_by_behavior"], behavior_names)
class_names = [MULTICLASS_NONE_BEHAVIOR, *behavior_names]
multiclass_settings = classifier.project_settings or project.get_project_defaults()

labels_by_behavior = features["labels_by_behavior"]
if not labels_by_behavior:
# No labeled frames yet: return an empty label array so the caller finds
# no valid CV splits and skips cross-validation, mirroring the binary
# path, rather than letting merge_labels() raise on empty input.
return np.empty(0, dtype=np.intp), class_names, multiclass_settings

labels, _ = classifier_utils.merge_labels(labels_by_behavior, behavior_names)
return labels, class_names, multiclass_settings


Expand Down
61 changes: 34 additions & 27 deletions src/jabs/project/prediction_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,34 +177,41 @@ def _load_prediction_record(

try:
pred = io.load(path, BehaviorPrediction, behavior=behavior)
if nident is None or nident <= 0:
nident = pred.predicted_class.shape[0]
if pred.predicted_class.shape[0] != nident or pred.probabilities.shape[0] != nident:
print(f"unable to open saved inferences for {video}", file=sys.stderr)
return {}, {}, {}, None

# Guard against reading a record in the wrong mode: multi-class
# records carry class_names (and 3D probabilities), binary records
# do not. Reading across modes would silently return mis-shaped data.
record_is_multiclass = pred.class_names is not None
if record_is_multiclass != expect_multiclass:
expected = "multi-class" if expect_multiclass else "binary"
found = "multi-class" if record_is_multiclass else "binary"
raise ValueError(
f"Expected {expected} predictions for {behavior!r} in {video!r}, "
f"but the stored record is {found}."
)

class_names = pred.class_names

for i in range(nident):
predictions[i] = pred.predicted_class[i]
probabilities[i] = pred.probabilities[i]
if pred.predicted_class_postprocessed is not None:
postprocessed_predictions[i] = pred.predicted_class_postprocessed[i]

except (KeyError, FileNotFoundError):
# no saved predictions for this behavior for this video
pass
return predictions, probabilities, postprocessed_predictions, class_names
except ValueError as e:
# invalid/corrupted prediction record (e.g. a schema mismatch such as
# class_names present with 2-D probabilities); treat as no usable
# predictions rather than propagating the load error
print(f"unable to open saved inferences for {video}: {e}", file=sys.stderr)
return {}, {}, {}, None

if nident is None or nident <= 0:
nident = pred.predicted_class.shape[0]
if pred.predicted_class.shape[0] != nident or pred.probabilities.shape[0] != nident:
print(f"unable to open saved inferences for {video}", file=sys.stderr)
return {}, {}, {}, None

# Guard against reading a record in the wrong mode: multi-class records
# carry class_names (and 3D probabilities), binary records do not.
# Reading across modes would silently return mis-shaped data, so this is
# a logic error rather than a corrupt file and is surfaced by raising.
record_is_multiclass = pred.class_names is not None
if record_is_multiclass != expect_multiclass:
expected = "multi-class" if expect_multiclass else "binary"
found = "multi-class" if record_is_multiclass else "binary"
raise ValueError(
f"Expected {expected} predictions for {behavior!r} in {video!r}, "
f"but the stored record is {found}."
)

class_names = pred.class_names

for i in range(nident):
predictions[i] = pred.predicted_class[i]
probabilities[i] = pred.probabilities[i]
if pred.predicted_class_postprocessed is not None:
postprocessed_predictions[i] = pred.predicted_class_postprocessed[i]

return predictions, probabilities, postprocessed_predictions, class_names
46 changes: 46 additions & 0 deletions tests/classifier/test_cross_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,52 @@ def get_feature_importance(limit=10):
return []


class _EmptyMultiClassClassifier:
    """Stand-in multiclass classifier for the "no labeled frames" scenario.

    It exposes a single behavior name, a fixed settings dict, and reports
    zero available leave-one-group-out splits. Asking it to actually produce
    a split is treated as a test failure.
    """

    def __init__(self):
        # One behavior name is all this double exposes.
        self.behavior_names = ["Walk"]

    @property
    def project_settings(self) -> dict:
        """Return a fixed, non-empty settings dict."""
        settings = {"window_size": 5}
        return settings

    @staticmethod
    def get_leave_one_group_out_max(_labels, _groups) -> int:
        """Report that no splits are available, whatever the inputs are."""
        return 0

    @staticmethod
    def leave_one_group_out(*_args, **_kwargs):
        """Fail loudly: no split may be requested when the max is zero."""
        message = "leave_one_group_out should not be called when max splits is zero"
        raise AssertionError(message)


def test_multiclass_cv_skips_when_no_labels() -> None:
    """Multiclass CV with an empty ``labels_by_behavior`` is skipped, not raised.

    merge_labels() raises on an empty dict, so _prepare_cv_labels has to
    short-circuit; the multiclass path should then behave like the binary
    "no valid splits" case: empty results plus a "skipping CV" status message.
    """
    # Feature payload with zero rows and no per-behavior labels at all.
    empty_features = {
        "per_frame": pd.DataFrame({"a": []}),
        "window": pd.DataFrame({"b": []}),
        "groups": np.array([], dtype=np.int32),
        "labels_by_behavior": {},
    }

    # Minimal project stand-in: only get_project_defaults() is provided.
    project_stub_cls = type("P", (), {"get_project_defaults": lambda self: {"window_size": 5}})

    messages: list[str] = []
    cv_results = run_leave_one_group_out_cv(
        classifier=_EmptyMultiClassClassifier(),
        project=project_stub_cls(),
        features=empty_features,
        group_mapping={},
        behavior="Walk",
        k=1,
        status_callback=messages.append,
    )

    assert cv_results == []
    skip_notices = [text for text in messages if "skipping CV" in text]
    assert skip_notices


def test_run_leave_one_group_out_cv_returns_empty_when_no_valid_splits() -> None:
"""No valid CV splits should not raise; CV is skipped with empty results."""
features = {
Expand Down
5 changes: 4 additions & 1 deletion tests/project/test_prediction_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import numpy as np
import pytest

from jabs.core.utils import to_safe_name
from jabs.project.prediction_manager import MULTICLASS_PREDICTION_KEY, PredictionManager


Expand Down Expand Up @@ -206,7 +207,9 @@ def test_load_multiclass_predictions_invalid_shape_returns_empty(
h5.attrs["pose_file"] = "test_pose.h5"
h5.attrs["pose_hash"] = "testhash"
prediction_group = h5.create_group("predictions")
behavior_group = prediction_group.create_group(MULTICLASS_PREDICTION_KEY)
# write under the on-disk safe name so the loader actually finds the
# group and exercises the invalid-schema (ValueError) fallback path
behavior_group = prediction_group.create_group(to_safe_name(MULTICLASS_PREDICTION_KEY))
behavior_group.attrs["app_version"] = "1.0.0"
behavior_group.attrs["prediction_date"] = "2025-01-01"
behavior_group.create_dataset("predicted_class", data=[[1, 0, -1], [0, 1, -1]])
Expand Down
Loading