diff --git a/packages/jabs-core/src/jabs/core/utils/__init__.py b/packages/jabs-core/src/jabs/core/utils/__init__.py index 254bb954..297fc254 100644 --- a/packages/jabs-core/src/jabs/core/utils/__init__.py +++ b/packages/jabs-core/src/jabs/core/utils/__init__.py @@ -1,7 +1,13 @@ """JABS utilities""" from .update_checker import check_for_update, is_pypi_install -from .utilities import get_bool_env_var, hash_file, hide_stderr, to_safe_name +from .utilities import ( + get_bool_env_var, + hash_file, + hide_stderr, + pose_file_stem, + to_safe_name, +) __all__ = [ "check_for_update", @@ -9,5 +15,6 @@ "hash_file", "hide_stderr", "is_pypi_install", + "pose_file_stem", "to_safe_name", ] diff --git a/packages/jabs-core/src/jabs/core/utils/utilities.py b/packages/jabs-core/src/jabs/core/utils/utilities.py index 5b2b7b84..bb64e994 100644 --- a/packages/jabs-core/src/jabs/core/utils/utilities.py +++ b/packages/jabs-core/src/jabs/core/utils/utilities.py @@ -65,6 +65,27 @@ def get_bool_env_var(var_name, default_value=False) -> bool: return value.lower() in ("true", "1", "yes", "on", "y", "t") +_POSE_SUFFIX_RE = re.compile(r"_pose_est_v\d+$") + + +def pose_file_stem(path: str | Path) -> str: + """Return the base name of a pose file with the ``_pose_est_vN`` suffix removed. + + For example, ``"video_pose_est_v6.h5"`` becomes ``"video"``. If the input does + not include the ``_pose_est_vN`` suffix, the stem is returned unchanged + (this allows callers to pass either a pose file path or a video file path + and get a consistent identifier). + + Args: + path: A pose file path (or any path-like value) whose stem may include + the ``_pose_est_vN`` suffix. + + Returns: + The stem with any trailing ``_pose_est_vN`` suffix stripped. + """ + return _POSE_SUFFIX_RE.sub("", Path(path).stem) + + def to_safe_name(behavior: str) -> str: """Create a version of the given behavior name that is safe to use in filenames. 
"""Unit tests for jabs.core.utils.utilities module."""

from pathlib import Path

import pytest

from jabs.core.utils import pose_file_stem

# (input path, expected stem) pairs, each carrying its own pytest id.
_STEM_CASES = [
    pytest.param("video_pose_est_v6.h5", "video", id="pose-v6"),
    pytest.param("video_pose_est_v2.h5", "video", id="pose-v2"),
    pytest.param("video_pose_est_v12.h5", "video", id="pose-v12"),
    pytest.param("/some/dir/video_pose_est_v6.h5", "video", id="pose-with-dir"),
    pytest.param("video.mp4", "video", id="video-mp4"),
    pytest.param("video.avi", "video", id="video-avi"),
    pytest.param("nested_name_pose_est_v8.h5", "nested_name", id="nested-underscore-name"),
    pytest.param("no_suffix.h5", "no_suffix", id="h5-no-pose-suffix"),
    pytest.param("plain_name", "plain_name", id="no-extension"),
]


@pytest.mark.parametrize(("path", "expected"), _STEM_CASES)
def test_pose_file_stem(path: str, expected: str) -> None:
    """The pose suffix is removed; names without it come back unchanged."""
    actual = pose_file_stem(path)
    assert actual == expected


def test_pose_file_stem_accepts_path() -> None:
    """A ``pathlib.Path`` argument works the same as a string."""
    pose_path = Path("/a/b/video_pose_est_v6.h5")
    assert pose_file_stem(pose_path) == "video"


def test_pose_file_stem_video_and_pose_match() -> None:
    """A video file and its pose file resolve to one and the same stem.

    Feature-cache directories rely on this so that they line up regardless of
    whether the caller supplies the pose file (jabs-classify, jabs-cli
    compute-features) or the video file (jabs-init, GUI).
    """
    from_video = pose_file_stem("video.mp4")
    from_pose = pose_file_stem("video_pose_est_v6.h5")
    assert from_video == from_pose
def _migrate_legacy_cache_dir(directory: Path, source_file: str | Path) -> None:
    """Move a legacy, pose-suffixed cache directory to its bare-stem name.

    Older jabs-classify / jabs-cli compute-features runs stored feature caches
    in a subdirectory named after the pose file stem, ``_pose_est_vN`` suffix
    included, whereas jabs-init and the GUI used the plain video stem. Cache
    paths are now keyed on the plain stem, so an existing suffixed directory is
    renamed to keep its contents discoverable.

    Nothing is renamed (and the legacy directory is left untouched) when:
      * the stem of ``source_file`` carries no ``_pose_est_vN`` suffix,
      * the suffixed directory does not exist under ``directory``, or
      * the bare-stem destination already exists, e.g. because several pose
        files in the same directory normalize to the same stem.

    A failing rename is logged as a warning and otherwise ignored.

    Args:
        directory: Cache root that holds the per-source subdirectories.
        source_file: Pose or video file path; only its stem is used.
    """
    suffixed_stem = Path(source_file).stem
    bare_stem = pose_file_stem(source_file)
    old_dir = directory / suffixed_stem
    new_dir = directory / bare_stem

    can_migrate = suffixed_stem != bare_stem and old_dir.is_dir() and not new_dir.exists()
    if not can_migrate:
        return

    logger.info("renaming legacy feature cache subdirectory %s -> %s", old_dir, new_dir)
    try:
        old_dir.rename(new_dir)
    except OSError:
        # Migration is best-effort: losing the old cache only means the
        # features get computed again, so a filesystem or permission problem
        # must not abort feature extraction.
        logger.warning(
            "failed to rename legacy feature cache subdirectory %s -> %s; "
            "features will be recomputed",
            old_dir,
            new_dir,
            exc_info=True,
        )
# Canonical pose file name: a non-empty base name, the ``_pose_est_vN`` marker
# and the ``.h5`` extension.
_POSE_FILE_NAME_RE = re.compile(r"^.+_pose_est_v[0-9]+\.h5$")


def _require_pose_file_name(pose_path: Path) -> None:
    """Validate that the filename matches the canonical ``*_pose_est_vN.h5`` pattern.

    Only the final path component (``pose_path.name``) is inspected; the file
    is not opened and does not need to exist.

    Args:
        pose_path: Path of the candidate pose file.

    Raises:
        ValueError: If the filename does not match the canonical pose-file
            pattern.
    """
    # ``fullmatch`` rather than ``match``: ``$`` on its own also matches just
    # before a trailing newline, which would let "x_pose_est_v6.h5\n" through.
    if _POSE_FILE_NAME_RE.fullmatch(pose_path.name) is None:
        raise ValueError(f"{pose_path} is not a valid pose file path")
def test_feature_dir_matches_for_pose_and_video_source(tmp_path, pose_est_v5) -> None:
    """Identity feature directory is the same for pose and video source filenames.

    The cache path must not depend on whether the caller passes the pose filename
    (jabs-classify / jabs-cli compute-features) or the video filename (jabs-init /
    GUI). Both must resolve to the same directory.
    """
    pose_source = IdentityFeatures(
        source_file="sample_pose_est_v5.h5",
        identity=_IDENTITY,
        directory=tmp_path,
        pose_est=pose_est_v5,
        op_settings={},
    )
    video_source = IdentityFeatures(
        source_file="sample.mp4",
        identity=_IDENTITY,
        directory=tmp_path,
        pose_est=pose_est_v5,
        op_settings={},
    )

    assert pose_source._identity_feature_dir == video_source._identity_feature_dir
    assert pose_source._identity_feature_dir == tmp_path / "sample" / str(_IDENTITY)


def test_legacy_cache_dir_is_renamed(tmp_path, pose_est_v5, caplog) -> None:
    """A legacy ``_pose_est_vN`` cache dir is renamed to the bare stem on construction.

    Cached features computed by the previous CLI layout must remain discoverable
    after the normalization fix.
    """
    legacy = tmp_path / "sample_pose_est_v5"
    legacy_identity = legacy / str(_IDENTITY)
    legacy_identity.mkdir(parents=True)
    sentinel = legacy_identity / "features.h5"
    sentinel.touch()

    with caplog.at_level("INFO", logger="jabs.feature_extraction.features"):
        instance = _make_identity_features(pose_est_v5, tmp_path, force=False)

    normalized = tmp_path / "sample"
    assert not legacy.exists(), "legacy dir should be renamed away"
    assert (normalized / str(_IDENTITY) / "features.h5").exists()
    assert instance._identity_feature_dir == normalized / str(_IDENTITY)
    assert any("renaming legacy feature cache" in r.message for r in caplog.records)


def test_legacy_cache_dir_left_alone_on_collision(tmp_path, pose_est_v5) -> None:
    """If the normalized destination already exists, the legacy dir is not renamed.

    This protects against collisions when multiple pose files in the same directory
    would normalize to the same stem (e.g. ``sample_pose_est_v5.h5`` and
    ``sample_pose_est_v6.h5`` both normalize to ``sample``).
    """
    legacy = tmp_path / "sample_pose_est_v5"
    legacy.mkdir()
    (legacy / "marker").touch()

    normalized = tmp_path / "sample"
    normalized.mkdir()
    (normalized / "other_marker").touch()

    _make_identity_features(pose_est_v5, tmp_path, force=False)

    assert legacy.exists(), "legacy dir must be preserved on collision"
    assert (legacy / "marker").exists()
    assert (normalized / "other_marker").exists()


def test_legacy_rename_failure_is_non_fatal(tmp_path, pose_est_v5, caplog, monkeypatch) -> None:
    """A rename failure logs a warning and lets construction proceed.

    Best-effort migration: an OS-level rename error must not abort feature
    extraction. The worst case is a recomputed cache.
    """
    legacy = tmp_path / "sample_pose_est_v5"
    legacy.mkdir()
    (legacy / "marker").touch()

    def _raise(self, *args, **kwargs):
        # PermissionError is an OSError subclass, i.e. what the migration catches.
        raise PermissionError("simulated rename failure")

    monkeypatch.setattr(Path, "rename", _raise)

    with caplog.at_level("WARNING", logger="jabs.feature_extraction.features"):
        instance = _make_identity_features(pose_est_v5, tmp_path, force=False)

    assert instance._identity_feature_dir == tmp_path / "sample" / str(_IDENTITY)
    assert legacy.exists(), "legacy dir untouched after failed rename"
    assert any("failed to rename" in r.message for r in caplog.records)


def test_no_rename_when_video_stem_used(tmp_path, pose_est_v5) -> None:
    """If the source filename has no ``_pose_est_vN`` suffix, no rename is attempted.

    A suffixed directory is planted next to the cache so the test observes that
    it is left in place, rather than only checking the resolved cache path.
    """
    legacy = tmp_path / "sample_pose_est_v5"
    legacy.mkdir()
    (legacy / "marker").touch()

    instance = IdentityFeatures(
        source_file="sample.mp4",
        identity=_IDENTITY,
        directory=tmp_path,
        pose_est=pose_est_v5,
        op_settings={},
    )

    assert instance._identity_feature_dir == tmp_path / "sample" / str(_IDENTITY)
    assert (legacy / "marker").exists(), "suffixed dir must not be touched for a video source"
"""Unit tests for jabs.scripts.classify helpers."""

from pathlib import Path

import pytest

from jabs.scripts.classify import _require_pose_file_name

# File names that follow the ``*_pose_est_vN.h5`` convention.
_CANONICAL_NAMES = [
    "sample_pose_est_v2.h5",
    "sample_pose_est_v6.h5",
    "sample_pose_est_v12.h5",
    "nested_name_pose_est_v8.h5",
]

# File names that break the convention in one way or another.
_INVALID_NAMES = [
    "sample.mp4",
    "sample.h5",
    "sample_v6.h5",
    "sample_pose_est.h5",
    "sample_pose_est_v6.hdf5",
    "predictions_pose_est_v6.h5.bak",
]


@pytest.mark.parametrize("name", _CANONICAL_NAMES)
def test_require_pose_file_name_accepts_canonical(name: str) -> None:
    """Validation passes silently for canonical ``*_pose_est_vN.h5`` names."""
    pose_path = Path("/some/dir") / name
    _require_pose_file_name(pose_path)


@pytest.mark.parametrize("name", _INVALID_NAMES)
def test_require_pose_file_name_rejects_invalid(name: str) -> None:
    """Any name outside the ``*_pose_est_vN.h5`` pattern raises ``ValueError``."""
    pose_path = Path(name)
    with pytest.raises(ValueError, match="not a valid pose file path"):
        _require_pose_file_name(pose_path)