diff --git a/src/jabs/classifier/classifier.py b/src/jabs/classifier/classifier.py index a4848daf..bc21a58f 100644 --- a/src/jabs/classifier/classifier.py +++ b/src/jabs/classifier/classifier.py @@ -67,13 +67,15 @@ def __init__(self, classifier: ClassifierType = ClassifierType.RANDOM_FOREST, n_ raise ValueError("Invalid classifier type") @classmethod - def from_training_file(cls, path: Path): + def from_training_file(cls, path: Path, classifier_type: ClassifierType | None = None): """Initialize a classifier from an exported training data file. This method will load the training data and train a classifier. Args: path: exported training data file + classifier_type: Override the classifier algorithm stored in the training + file. If ``None``, the type recorded in the file is used. Returns: trained classifier object @@ -84,12 +86,13 @@ def from_training_file(cls, path: Path): classifier = cls() classifier.behavior_name = behavior classifier.set_dict_settings(loaded_training_data["settings"]) - classifier_type = ClassifierType(loaded_training_data["classifier_type"]) - if classifier_type in classifier._supported_classifiers: - classifier.set_classifier(classifier_type) + file_classifier_type = ClassifierType(loaded_training_data["classifier_type"]) + effective_type = classifier_type if classifier_type is not None else file_classifier_type + if effective_type in classifier._supported_classifiers: + classifier.set_classifier(effective_type) else: logging.warning( - f"Specified classifier type {classifier_type.name} is unavailable, using default: {classifier.classifier_type.name}" + f"Specified classifier type {effective_type.name} is unavailable, using default: {classifier.classifier_type.name}" ) training_features = classifier.combine_data( loaded_training_data["per_frame"], loaded_training_data["window"] @@ -423,6 +426,51 @@ def save(self, path: Path): self._classifier_hash = hash_file(Path(path)) self._classifier_source = "serialized" + @classmethod + def from_pickle(cls, path: Path) -> "Classifier": + """Load a Classifier from a pickle file with full validation and metadata backfill. + + Applies the same version, classifier-type, and metadata checks as ``load()``, + but as a classmethod factory so no dummy instance is required. + + Args: + path: Path to the saved classifier pickle file. + + Returns: + Loaded and validated ``Classifier`` instance. + + Raises: + ValueError: If the file is not a ``Classifier``, was trained with an + incompatible sklearn or JABS version, or uses an unsupported + classifier type. + """ + with warnings.catch_warnings(record=True) as caught_warnings: + warnings.simplefilter("always", InconsistentVersionWarning) + c = joblib.load(path) + for warning in caught_warnings: + if issubclass(warning.category, InconsistentVersionWarning): + raise ValueError("Classifier trained with different version of sklearn.") + else: + warnings.warn(warning.message, warning.category, stacklevel=2) + + if not isinstance(c, cls): + raise ValueError(f"{path} is not an instance of Classifier") + + if c.version != _VERSION: + raise ValueError( + f"Unable to deserialize pickled classifier. File version {c.version}, expected {_VERSION}." + ) + + if c._classifier_type not in cls._supported_classifier_choices(): + raise ValueError("Invalid classifier type") + + if c._classifier_file is None: + c._classifier_file = Path(path).name + c._classifier_hash = hash_file(Path(path)) + c._classifier_source = "pickle" + + return c + def load(self, path: Path): """load a classifier from a file diff --git a/src/jabs/classifier/multi_class_classifier.py b/src/jabs/classifier/multi_class_classifier.py index 872b56f8..0cfe4ba1 100644 --- a/src/jabs/classifier/multi_class_classifier.py +++ b/src/jabs/classifier/multi_class_classifier.py @@ -418,25 +418,78 @@ def load(self, path: Path) -> None: logger.info("MultiClassClassifier loaded from %s", path) @classmethod - def from_training_file(cls, path: Path) -> MultiClassClassifier: + def from_pickle(cls, path: Path) -> MultiClassClassifier: + """Load a MultiClassClassifier from a pickle file with full validation and metadata backfill. + + Applies the same version, classifier-type, and metadata checks as ``load()``, + but as a classmethod factory so no dummy instance is required. + + Args: + path: Path to the saved classifier pickle file. + + Returns: + Loaded and validated ``MultiClassClassifier`` instance. + + Raises: + ValueError: If the file is not a ``MultiClassClassifier``, was saved + with a different version, or uses an unsupported classifier type. + """ + with warnings.catch_warnings(record=True) as caught_warnings: + warnings.simplefilter("always", InconsistentVersionWarning) + c = joblib.load(path) + for warning in caught_warnings: + if issubclass(warning.category, InconsistentVersionWarning): + raise ValueError("Classifier trained with different version of sklearn.") + else: + warnings.warn(warning.message, warning.category, stacklevel=2) + + if not isinstance(c, cls): + raise ValueError(f"{path} is not an instance of MultiClassClassifier") + + if c._version != _VERSION: + raise ValueError( + f"Unable to deserialize pickled classifier. " + f"File version {c._version}, expected {_VERSION}." + ) + + if c._classifier_type not in cls._supported_classifier_choices(): + raise ValueError("Invalid classifier type") + + if c._classifier_file is None: + c._classifier_file = Path(path).name + c._classifier_hash = hash_file(Path(path)) + c._classifier_source = "pickle" + + logger.info("MultiClassClassifier loaded from %s", path) + return c + + @classmethod + def from_training_file( + cls, path: Path, classifier_type: ClassifierType | None = None + ) -> MultiClassClassifier: """Train a new MultiClassClassifier from an exported training file. Args: path: Path to a multi-class training HDF5 file produced by ``export_training_data_multiclass()``. + classifier_type: Override the classifier algorithm stored in the training + file. If ``None``, the type recorded in the file is used. Returns: A freshly trained ``MultiClassClassifier`` instance. Raises: ValueError: If the file is not a valid multi-class training export or - the stored classifier type is unsupported in the current environment. + the effective classifier type is unsupported in the current environment. """ loaded, _ = load_multiclass_training_data(path) + effective_type = ( + classifier_type if classifier_type is not None else loaded["classifier_type"] + ) classifier = cls( behavior_names=loaded["behavior_names"], - classifier_type=loaded["classifier_type"], + classifier_type=effective_type, ) classifier.set_dict_settings(loaded["settings"]) classifier.train( diff --git a/src/jabs/feature_extraction/features.py b/src/jabs/feature_extraction/features.py index 1a747a99..45535afb 100644 --- a/src/jabs/feature_extraction/features.py +++ b/src/jabs/feature_extraction/features.py @@ -104,7 +104,7 @@ class IdentityFeatures: def __init__( self, - source_file: str, + source_file: str | Path, identity: int, directory: str | Path | None, pose_est: PoseEstimation, diff --git a/src/jabs/scripts/classify.py b/src/jabs/scripts/classify.py index d6492eea..9f27fa8c 100755 --- a/src/jabs/scripts/classify.py +++ b/src/jabs/scripts/classify.py @@ -9,18 +9,22 @@ import argparse import re import sys +import warnings from pathlib import Path +import h5py +import joblib import numpy as np import pandas as pd from rich.progress import BarColumn, Progress, TextColumn +from sklearn.exceptions import InconsistentVersionWarning -from jabs.classifier import Classifier -from jabs.core.constants import APP_NAME -from jabs.core.enums import CacheFormat +from jabs.classifier import Classifier, MultiClassClassifier +from jabs.core.constants import APP_NAME, MULTICLASS_NONE_BEHAVIOR +from jabs.core.enums import CacheFormat, ClassifierType from jabs.feature_extraction import IdentityFeatures from jabs.pose_estimation import open_pose_file -from jabs.project.prediction_manager import PredictionManager +from jabs.project.prediction_manager import MULTICLASS_PREDICTION_KEY, PredictionManager DEFAULT_FPS = 30 @@ -28,10 +32,20 @@ __CLASSIFIER_CHOICES = Classifier().classifier_choices() -def get_pose_stem(pose_path: Path): - """get the stem name of a pose file +def get_pose_stem(pose_path: Path) -> str: + """Get the stem name of a pose file. - takes a pose path as input and returns the name component with the '_pose_est_v#.h5' suffix removed + Takes a pose path as input and returns the name component with the + '_pose_est_v#.h5' suffix removed. + + Args: + pose_path: Path to the pose estimation file. + + Returns: + Stem portion of the filename without the pose suffix. + + Raises: + ValueError: If the path does not match the expected pose file naming convention. """ m = re.match(r"^(.+)(_pose_est_v[0-9]+\.h5)$", pose_path.name) if m: @@ -40,33 +54,124 @@ def get_pose_stem(pose_path: Path): raise ValueError(f"{pose_path} is not a valid pose file path") +def _load_classifier_from_pickle(path: Path) -> Classifier | MultiClassClassifier: + """Load a binary or multi-class classifier from a pickle file. + + Peeks at the deserialized type, then delegates to the class-specific + ``from_pickle()`` classmethod so that version checks, supported + classifier-type checks, and metadata backfill are applied consistently. + + Args: + path: Path to the saved classifier pickle file. + + Returns: + Loaded ``Classifier`` or ``MultiClassClassifier`` instance. + + Raises: + ValueError: If the file was trained with an incompatible sklearn or JABS + version, uses an unsupported classifier type, or contains an + unrecognized object type. + FileNotFoundError: If ``path`` does not exist. + PermissionError: If the file cannot be read. + Exception: Other exceptions raised by joblib/pickle during deserialization + (e.g. corrupt file). + """ + with warnings.catch_warnings(record=True) as caught_warnings: + warnings.simplefilter("always", InconsistentVersionWarning) + obj = joblib.load(path) + for w in caught_warnings: + if issubclass(w.category, InconsistentVersionWarning): + raise ValueError("Classifier trained with a different version of sklearn.") + warnings.warn(w.message, w.category, stacklevel=2) + + if isinstance(obj, MultiClassClassifier): + return MultiClassClassifier.from_pickle(path) + if isinstance(obj, Classifier): + return Classifier.from_pickle(path) + raise ValueError(f"Unrecognized classifier type in {path}: {type(obj).__name__}") + + +def _is_multiclass_training_file(path: Path) -> bool: + """Return True if the training file contains multi-class training data. + + Args: + path: Path to the training HDF5 file. + + Returns: + True if the file has ``classifier_mode == "multiclass"``, False otherwise. + """ + with h5py.File(path, "r") as f: + return f.attrs.get("classifier_mode", "") == "multiclass" + + +def train_multiclass( + training_file: Path, classifier_type: ClassifierType | None = None +) -> MultiClassClassifier: + """Train a multi-class classifier using the provided training file. + + Loads training data from the specified HDF5 file, initializes a + ``MultiClassClassifier``, and prints training details such as behavior names, + classifier type, window size, and other relevant settings. + + Args: + training_file: Path to the multi-class training HDF5 file exported by JABS. + classifier_type: Override the classifier algorithm stored in the training file. + If ``None``, the type recorded in the file is used. + + Returns: + Trained ``MultiClassClassifier`` instance. + """ + classifier = MultiClassClassifier.from_training_file( + training_file, classifier_type=classifier_type + ) + classifier_settings = classifier.project_settings + + print("Training multi-class classifier for:", ", ".join(classifier.behavior_names)) + print(f" Classifier Type: {classifier.classifier_name}") + print(f" Window Size: {classifier_settings['window_size']}") + print(f" Social: {classifier_settings['social']}") + print(f" Balanced Labels: {classifier_settings['balance_labels']}") + print(f" Symmetric Behavior: {classifier_settings['symmetric_behavior']}") + print(f" CM Units: {bool(classifier_settings['cm_units'])}") + + return classifier + + def train_and_classify( training_file_path: Path, input_pose_file: Path, out_dir: Path, - fps=DEFAULT_FPS, + fps: int = DEFAULT_FPS, feature_dir: str | None = None, cache_window: bool = False, use_pose_hash: bool = False, -): + classifier_type: ClassifierType | None = None, +) -> None: """Train a classifier using the provided training file and classify behaviors in a pose file. - Loads the training data, trains a classifier, and applies it to the input pose file to predict behaviors. - The classification results are saved to the specified output directory. + Loads the training data, trains a classifier, and applies it to the input pose file + to predict behaviors. The classification results are saved to the specified output directory. Args: - training_file_path (Path): Path to the training HDF5 file. - input_pose_file (Path): Path to the input pose HDF5 file to classify. - out_dir (Path): Directory to store classification output. - fps (int, optional): Frames per second for feature extraction. Defaults to DEFAULT_FPS. - feature_dir (str or None, optional): Directory for feature cache. If provided, features are cached here. - cache_window (bool, optional): Whether to cache window features. Defaults to False. - use_pose_hash (bool, optional): Include pose file hash as a subdirectory in the cache path. Defaults to False. + training_file_path: Path to the training HDF5 file. + input_pose_file: Path to the input pose HDF5 file to classify. + out_dir: Directory to store classification output. + fps: Frames per second for feature extraction. + feature_dir: Directory for feature cache. If provided, features are cached here. + cache_window: Whether to cache window features. + use_pose_hash: Include pose file hash as a subdirectory in the cache path. + classifier_type: Override the classifier algorithm stored in the training file. + If ``None``, the type recorded in the file is used. """ if not training_file_path.exists(): sys.exit("Unable to open training data\n") - classifier = train(training_file_path) + if _is_multiclass_training_file(training_file_path): + classifier: Classifier | MultiClassClassifier = train_multiclass( + training_file_path, classifier_type=classifier_type + ) + else: + classifier = train(training_file_path, classifier_type=classifier_type) classify_pose( classifier, input_pose_file, @@ -80,48 +185,75 @@ def train_and_classify( def classify_pose( - classifier: Classifier, + classifier: Classifier | MultiClassClassifier, input_pose_file: Path, out_dir: Path, - behavior: str, - fps=DEFAULT_FPS, + behavior: str | None = None, + fps: int = DEFAULT_FPS, feature_dir: str | None = None, cache_window: bool = False, use_pose_hash: bool = False, -): +) -> None: """Classify behaviors in a pose file using a trained classifier. - Loads pose data, extracts features for each identity, predicts behavior labels and probabilities, - and writes the results to an output HDF5 file. + Loads pose data, extracts features for each identity, predicts behavior labels + and probabilities, and writes the results to an output HDF5 file. + + For binary classifiers, ``behavior`` names the behavior being classified and is + used as the prediction record key. For multi-class classifiers, ``behavior`` is + ignored - the key is always ``MULTICLASS_PREDICTION_KEY`` and ``class_names`` + are populated from the classifier. Args: - classifier (Classifier): Trained classifier instance. - input_pose_file (Path): Path to the input pose HDF5 file. - out_dir (Path): Directory to store classification output. - behavior (str): Name of the behavior being classified. - fps (int, optional): Frames per second for feature extraction. Defaults to DEFAULT_FPS. - feature_dir (str or None, optional): Directory for feature cache. If provided, features are cached here. - cache_window (bool, optional): Whether to cache window features. Defaults to False. - use_pose_hash (bool, optional): Include pose file hash as a subdirectory in the cache path. Defaults to False. + classifier: Trained binary or multi-class classifier instance. + input_pose_file: Path to the input pose HDF5 file. + out_dir: Directory to store classification output. + behavior: Behavior name for binary classifiers. Ignored for multi-class. + fps: Frames per second for feature extraction. + feature_dir: Directory for feature cache. If provided, features are cached here. + cache_window: Whether to cache window features. + use_pose_hash: Include pose file hash as a subdirectory in the cache path. + + Raises: + ValueError: If a binary classifier is given but ``behavior`` is None. """ + multiclass = isinstance(classifier, MultiClassClassifier) + + if multiclass: + class_names: list[str] | None = [MULTICLASS_NONE_BEHAVIOR, *classifier.behavior_names] + behavior_key = MULTICLASS_PREDICTION_KEY + else: + if behavior is None: + raise ValueError("behavior is required for binary classifiers") + class_names = None + behavior_key = behavior + pose_est = open_pose_file(input_pose_file) pose_stem = get_pose_stem(input_pose_file) - # allocate numpy arrays to write to h5 file - prediction_labels = np.full((pose_est.num_identities, pose_est.num_frames), -1, dtype=np.int8) - prediction_prob = np.zeros_like(prediction_labels, dtype=np.float32) + n_identities = pose_est.num_identities + n_frames = pose_est.num_frames + + prediction_labels = np.full((n_identities, n_frames), -1, dtype=np.int8) + if multiclass: + assert class_names is not None + n_classes = len(class_names) + prediction_prob: np.ndarray = np.zeros( + (n_identities, n_frames, n_classes), dtype=np.float32 + ) + else: + prediction_prob = np.zeros((n_identities, n_frames), dtype=np.float32) classifier_settings = classifier.project_settings print(f"Classifying {input_pose_file}...") - # run prediction for each identity with Progress( TextColumn("[progress.description]{task.description}"), BarColumn(), TextColumn("{task.completed} of {task.total} identities"), ) as progress: - task = progress.add_task("Processing", total=pose_est.num_identities) + task = progress.add_task("Processing", total=n_identities) for curr_id in pose_est.identities: features = IdentityFeatures( input_pose_file, @@ -137,53 +269,53 @@ def classify_pose( per_frame_features = pd.DataFrame(features["per_frame"]) window_features = pd.DataFrame(features["window"]) - - data = Classifier.combine_data(per_frame_features, window_features) + data = classifier.combine_data(per_frame_features, window_features) if data.shape[0] > 0: - # predict probabilities and derive predictions - predictions, probabilities = classifier.derive_predictions( - classifier.predict_proba(data, features["frame_indexes"]) - ) - - # Copy results into results matrix + prob = classifier.predict_proba(data, features["frame_indexes"]) + predictions, confidence = classifier.derive_predictions(prob) prediction_labels[curr_id] = predictions - prediction_prob[curr_id] = probabilities + # Multiclass: persist full class-probability matrix (n_frames, n_classes). + # Binary: persist per-frame confidence scalar. + prediction_prob[curr_id] = prob if multiclass else confidence progress.update(task, advance=1) print(f"Writing predictions to {out_dir}") - behavior_out_dir = out_dir try: - behavior_out_dir.mkdir(parents=True, exist_ok=True) + out_dir.mkdir(parents=True, exist_ok=True) except OSError as e: sys.exit(f"Unable to create output directory: {e}") - behavior_out_path = behavior_out_dir / (pose_stem + "_behavior.h5") + + behavior_out_path = out_dir / (pose_stem + "_behavior.h5") PredictionManager.write_predictions( - behavior, + behavior_key, behavior_out_path, prediction_labels, prediction_prob, pose_est, classifier, + class_names=class_names, ) -def train(training_file: Path) -> Classifier: - """Train a classifier using the provided training file. +def train(training_file: Path, classifier_type: ClassifierType | None = None) -> Classifier: + """Train a binary classifier using the provided training file. Loads training data from the specified HDF5 file, initializes a classifier, and prints training details such as behavior name, classifier type, window size, and other relevant settings. Args: - training_file (Path): Path to the training HDF5 file exported by JABS. + training_file: Path to the training HDF5 file exported by JABS. + classifier_type: Override the classifier algorithm stored in the training file. + If ``None``, the type recorded in the file is used. Returns: - Classifier: The trained classifier instance. + Trained ``Classifier`` instance. """ - classifier = Classifier.from_training_file(training_file) + classifier = Classifier.from_training_file(training_file, classifier_type=classifier_type) classifier_settings = classifier.project_settings print("Training classifier for:", classifier.behavior_name) @@ -197,8 +329,8 @@ def train(training_file: Path) -> Classifier: return classifier -def main(): - """jabs-classify entrypoint. dispatch to different main functions depending on command specified""" +def main() -> None: + """jabs-classify entrypoint - dispatch to different main functions depending on command.""" if len(sys.argv) < 2: usage_main() elif sys.argv[1] == "classify": @@ -209,8 +341,8 @@ def main(): usage_main() -def usage_main(): - """print usage information for the script""" +def usage_main() -> None: + """Print usage information for the script.""" print("usage: " + script_name() + " COMMAND COMMAND_ARGS\n", file=sys.stderr) print("commands:", file=sys.stderr) print(" classify classify a pose file", file=sys.stderr) @@ -224,9 +356,8 @@ def usage_main(): ) -def classify_main(): - """implementation of the `jabs-classify classify` command""" - # strip out the 'command' from sys.argv +def classify_main() -> None: + """Implementation of the `jabs-classify classify` command.""" classify_args = sys.argv[2:] parser = argparse.ArgumentParser(prog=f"{script_name()} classify") @@ -254,7 +385,8 @@ def classify_main(): ) training_group.add_argument( "--classifier", - help=f"Classifier file produced from the `{script_name()} train` command", + help=f"Classifier file produced from the `{script_name()} train` command or saved " + "by the JABS GUI (binary .pickle or multi-class _multiclass.pickle)", ) required_args.add_argument( @@ -281,8 +413,9 @@ def classify_main(): parser.add_argument( "--skip-window-cache", help=( - "Default will cache all features when --feature-dir is provided. Providing this flag will only cache " - "per-frame features, reducing cache size at the cost of needing to re-calculate window features." + "Default will cache all features when --feature-dir is provided. Providing this flag " + "will only cache per-frame features, reducing cache size at the cost of needing to " + "re-calculate window features." ), default=False, action="store_true", @@ -313,27 +446,33 @@ def classify_main(): feature_dir=args.feature_dir, cache_window=not args.skip_window_cache, use_pose_hash=args.use_pose_hash, + classifier_type=args.classifier_type, ) elif args.classifier is not None: try: - classifier = Classifier() - classifier.load(Path(args.classifier)) - except ValueError as e: + classifier = _load_classifier_from_pickle(Path(args.classifier)) + except Exception as e: print(f"Unable to load classifier from {args.classifier}:") sys.exit(str(e)) - behavior = classifier.behavior_name classifier_settings = classifier.project_settings - print(f"Classifying using trained classifier: {args.classifier}") - try: - print(f" Classifier type: {__CLASSIFIER_CHOICES[classifier.classifier_type]}") - except KeyError: - sys.exit("Error: Classifier type not supported on this platform") - print(f" Behavior: {behavior}") - print(f" Window Size: {classifier_settings['window_size']}") - print(f" Social: {classifier_settings['social']}") - print(f" CM Units: {classifier_settings['cm_units']}") + + if isinstance(classifier, MultiClassClassifier): + print(" Mode: multi-class") + print(f" Behaviors: {', '.join(classifier.behavior_names)}") + print(f" Window Size: {classifier_settings['window_size']}") + behavior = None + else: + try: + print(f" Classifier type: {__CLASSIFIER_CHOICES[classifier.classifier_type]}") + except KeyError: + sys.exit("Error: Classifier type not supported on this platform") + behavior = classifier.behavior_name + print(f" Behavior: {behavior}") + print(f" Window Size: {classifier_settings['window_size']}") + print(f" Social: {classifier_settings['social']}") + print(f" CM Units: {classifier_settings['cm_units']}") classify_pose( classifier, @@ -347,9 +486,8 @@ def classify_main(): ) -def train_main(): - """implementation of the `jabs-classify train` command""" - # strip out the 'command' component from sys.argv +def train_main() -> None: + """Implementation of the `jabs-classify train` command.""" train_args = sys.argv[2:] parser = argparse.ArgumentParser(prog=f"{script_name()} train") @@ -357,14 +495,23 @@ def train_main(): parser.add_argument("out_file", help="output filename") args = parser.parse_args(train_args) - classifier = train(args.training_file) + training_path = Path(args.training_file) + + if not training_path.exists(): + sys.exit("Unable to open training data\n") + + trained: Classifier | MultiClassClassifier + if _is_multiclass_training_file(training_path): + trained = train_multiclass(training_path) + else: + trained = train(training_path) print(f"Saving trained classifier to '{args.out_file}'") - classifier.save(Path(args.out_file)) + trained.save(Path(args.out_file)) def script_name() -> str: - """return the script name""" + """Return the script name.""" return Path(sys.argv[0]).name diff --git a/src/jabs/ui/behavior_timeline/behavior_timeline_widget.py b/src/jabs/ui/behavior_timeline/behavior_timeline_widget.py index 4a9d5646..eb63e70b 100644 --- a/src/jabs/ui/behavior_timeline/behavior_timeline_widget.py +++ b/src/jabs/ui/behavior_timeline/behavior_timeline_widget.py @@ -272,6 +272,11 @@ def pose(self, pose_est: PoseEstimation) -> None: self._num_frames = pose_est.num_frames self._reset_layout() + @property + def multiclass_color_lut(self) -> npt.NDArray[np.uint8] | None: + """Return the multiclass color LUT, or None when in binary mode.""" + return self._multiclass_color_lut + def set_classifier_mode(self, mode: ClassifierMode, behavior_names: list[str]) -> None: """Set the classifier mode and rebuild the layout for multi-class or binary display. diff --git a/src/jabs/ui/colors.py b/src/jabs/ui/colors.py index f65bfaf3..64ef7623 100644 --- a/src/jabs/ui/colors.py +++ b/src/jabs/ui/colors.py @@ -39,6 +39,7 @@ (BACKGROUND_COLOR.redF(), BACKGROUND_COLOR.greenF(), BACKGROUND_COLOR.blueF()), (NOT_BEHAVIOR_COLOR.redF(), NOT_BEHAVIOR_COLOR.greenF(), NOT_BEHAVIOR_COLOR.blueF()), (BEHAVIOR_COLOR.redF(), BEHAVIOR_COLOR.greenF(), BEHAVIOR_COLOR.blueF()), + (0.0, 0.0, 0.0), # black is hard to read on dark backgrounds ] diff --git a/src/jabs/ui/main_window/central_widget.py b/src/jabs/ui/main_window/central_widget.py index 0e6376e7..54381d8a 100644 --- a/src/jabs/ui/main_window/central_widget.py +++ b/src/jabs/ui/main_window/central_widget.py @@ -274,16 +274,47 @@ def label_overlay_mode(self, mode: PlayerWidget.LabelOverlayMode) -> None: """ if mode != self._label_overlay_mode: self._label_overlay_mode = mode - # also update self._player_widget labels if mode == PlayerWidget.LabelOverlayMode.LABEL: - self._player_widget.set_labels( - [labels.get_labels() for labels in self._get_label_list()] - ) + if ( + self._project is not None + and self._labels is not None + and self._project.settings_manager.classifier_mode == ClassifierMode.MULTICLASS + ): + behavior_names = self._controls.behaviors + multiclass_arrays = [ + self._labels.build_multiclass_label_array(str(i), behavior_names) + for i in range(self._pose_est.num_identities) + ] + lut = self._jabs_timeline.multiclass_color_lut + if lut is not None: + self._player_widget.set_label_color_lut(lut) + self._player_widget.set_labels(multiclass_arrays) + else: + self._player_widget.set_label_color_lut(None) + self._player_widget.set_labels(None) + else: + self._player_widget.set_label_color_lut(None) + self._player_widget.set_labels( + [labels.get_labels() for labels in self._get_label_list()] + ) elif mode == PlayerWidget.LabelOverlayMode.PREDICTION: - # prediction_list, _ = self._get_prediction_list() - self._player_widget.set_labels(self._prediction_list) + if ( + self._project is not None + and self._loaded_video is not None + and self._project.settings_manager.classifier_mode == ClassifierMode.MULTICLASS + ): + lut = self._jabs_timeline.multiclass_color_lut + if lut is not None: + self._player_widget.set_label_color_lut(lut) + self._player_widget.set_labels(self._build_multiclass_overlay_labels()) + else: + self._player_widget.set_label_color_lut(None) + self._player_widget.set_labels(None) + else: + self._player_widget.set_label_color_lut(None) + self._player_widget.set_labels(self._prediction_list) else: - # if the player is set to show nothing, clear the labels + self._player_widget.set_label_color_lut(None) self._player_widget.set_labels(None) @property @@ -817,16 +848,19 @@ def _set_label_track(self) -> None: ] if self._project.settings_manager.classifier_mode == ClassifierMode.MULTICLASS: behavior_names = self._controls.behaviors - self._jabs_timeline.set_labels( - [ - self._labels.build_multiclass_label_array(str(i), behavior_names) - for i in range(self._pose_est.num_identities) - ], - mask_list, - ) + multiclass_arrays = [ + self._labels.build_multiclass_label_array(str(i), behavior_names) + for i in range(self._pose_est.num_identities) + ] + self._jabs_timeline.set_labels(multiclass_arrays, mask_list) if self._label_overlay_mode == PlayerWidget.LabelOverlayMode.LABEL: - label_list = self._get_label_list() - self._player_widget.set_labels([labels.get_labels() for labels in label_list]) + lut = self._jabs_timeline.multiclass_color_lut + if lut is not None: + self._player_widget.set_label_color_lut(lut) + self._player_widget.set_labels(multiclass_arrays) + else: + self._player_widget.set_label_color_lut(None) + self._player_widget.set_labels(None) else: label_list = self._get_label_list() self._jabs_timeline.set_labels( @@ -834,7 +868,7 @@ def _set_label_track(self) -> None: mask_list, ) if self._label_overlay_mode == PlayerWidget.LabelOverlayMode.LABEL: - # if configured to show labels, update the player widget with the new labels + self._player_widget.set_label_color_lut(None) self._player_widget.set_labels([labels.get_labels() for labels in label_list]) self._set_prediction_vis() @@ -1148,6 +1182,24 @@ def _update_classify_progress(self, step: int) -> None: return self._progress_dialog.setValue(step) + def _build_multiclass_overlay_labels(self) -> list[np.ndarray]: + """Build per-identity LUT-index arrays for the multiclass prediction overlay. + + Maps prediction class indices from ``self._predictions`` to LUT indices: + -1 (no pose) -> 0, 0..N-1 -> 1..N. + + Returns: + List of arrays, one per identity, with LUT indices for each frame. + """ + n_frames = self._player_widget.num_frames + overlay_labels = [] + for i in range(self._pose_est.num_identities): + arr = self._predictions.get(i) + if arr is None: + arr = np.full(n_frames, -1, dtype=np.int16) + overlay_labels.append(np.where(arr == -1, 0, np.asarray(arr, dtype=np.int16) + 1)) + return overlay_labels + def _set_prediction_vis(self) -> None: """update data being displayed by the prediction visualization widget""" if self._project is None or self._loaded_video is None: @@ -1157,8 +1209,13 @@ def _set_prediction_vis(self) -> None: predictions_rows, probabilities_rows = self._get_multiclass_prediction_rows() self._jabs_timeline.set_predictions(predictions_rows, probabilities_rows) if self._label_overlay_mode == PlayerWidget.LabelOverlayMode.PREDICTION: - # Multi-class frame overlay is tracked under T9; keep disabled for now. - self._player_widget.set_labels(None) + lut = self._jabs_timeline.multiclass_color_lut + if lut is not None: + self._player_widget.set_label_color_lut(lut) + self._player_widget.set_labels(self._build_multiclass_overlay_labels()) + else: + self._player_widget.set_label_color_lut(None) + self._player_widget.set_labels(None) return self._prediction_list, self._probability_list = self._get_prediction_list() @@ -1167,7 +1224,7 @@ def _set_prediction_vis(self) -> None: [[prob] for prob in self._probability_list], ) if self._label_overlay_mode == PlayerWidget.LabelOverlayMode.PREDICTION: - # if the player is set to show predictions, update the player widget + self._player_widget.set_label_color_lut(None) self._player_widget.set_labels(self._prediction_list) def _get_prediction_list(self) -> tuple[list[np.ndarray], list[np.ndarray]]: diff --git a/src/jabs/ui/player_widget/frame_with_overlays.py b/src/jabs/ui/player_widget/frame_with_overlays.py index e54bd777..16c2649c 100644 --- a/src/jabs/ui/player_widget/frame_with_overlays.py +++ b/src/jabs/ui/player_widget/frame_with_overlays.py @@ -1,6 +1,7 @@ import enum import numpy as np +import numpy.typing as npt from intervaltree import IntervalTree from PySide6 import QtCore, QtGui, QtWidgets @@ -79,6 +80,7 @@ def __init__(self, *args, **kwargs): self._pose_overlay_mode = self.PoseOverlayMode.NONE self._id_overlay_mode = self.IdentityOverlayMode.BBOX + self._label_color_lut: npt.NDArray[np.uint8] | None = None self._control_overlay = ControlOverlay(self) self._control_overlay.playback_speed_changed.connect(self.playback_speed_changed) @@ -243,6 +245,25 @@ def set_active_identity(self, identity: int) -> None: """ self._active_identity = identity + @property + def label_color_lut(self) -> npt.NDArray[np.uint8] | None: + """Return the per-class RGBA color LUT used by the label overlay, or None for binary mode.""" + return self._label_color_lut + + def set_label_color_lut(self, lut: npt.NDArray[np.uint8] | None) -> None: + """Set the color LUT for label overlay rendering. + + When set, the label overlay treats each label value as a LUT index and + looks up its RGBA color directly, enabling multi-class coloring. When + ``None``, the overlay falls back to the hardcoded binary color scheme. + + Args: + lut: RGBA array of shape ``(N, 4)`` dtype ``uint8``, or ``None`` to + restore binary-mode coloring. + """ + self._label_color_lut = lut + self.update() + def set_label_overlay(self, labels: list[np.ndarray]) -> None: """set label values to use for overlaying on the frame. diff --git a/src/jabs/ui/player_widget/overlays/label_overlay.py b/src/jabs/ui/player_widget/overlays/label_overlay.py index 6a9eba46..84c02738 100644 --- a/src/jabs/ui/player_widget/overlays/label_overlay.py +++ b/src/jabs/ui/player_widget/overlays/label_overlay.py @@ -80,13 +80,20 @@ def _overlay_labels(self, painter: QtGui.QPainter, crop_rect: QtCore.QRect) -> N behavior_y = widget_y - self._BEHAVIOR_LABEL_SIZE - match self.parent.labels[identity][self.parent.current_frame]: - case TrackLabels.Label.BEHAVIOR: - prediction_color = BEHAVIOR_COLOR - case TrackLabels.Label.NOT_BEHAVIOR: - prediction_color = NOT_BEHAVIOR_COLOR - case _: - prediction_color = BACKGROUND_COLOR + label_val = int(self.parent.labels[identity][self.parent.current_frame]) + lut = self.parent.label_color_lut + if lut is not None: + idx = max(0, min(label_val, len(lut) - 1)) + r, g, b, a = lut[idx] + prediction_color = QtGui.QColor(int(r), int(g), int(b), int(a)) + else: + match label_val: + case TrackLabels.Label.BEHAVIOR: + prediction_color = BEHAVIOR_COLOR + case TrackLabels.Label.NOT_BEHAVIOR: + prediction_color = NOT_BEHAVIOR_COLOR + case _: + prediction_color = BACKGROUND_COLOR painter.setBrush(prediction_color) painter.setPen(self._BEHAVIOR_LABEL_OUTLINE_COLOR) diff --git a/src/jabs/ui/player_widget/player_widget.py b/src/jabs/ui/player_widget/player_widget.py index ee53eb85..99745aca 100644 --- a/src/jabs/ui/player_widget/player_widget.py +++ b/src/jabs/ui/player_widget/player_widget.py @@ -582,6 +582,17 @@ def set_active_identity(self, identity: int) -> None: self._player_thread.setActiveIdentity.emit(identity) self.reload_frame() + def set_label_color_lut(self, lut: np.ndarray | None) -> None: + """Set the color LUT for the label overlay. + + When set, each label value is treated as a LUT index for RGBA color + lookup. Pass ``None`` to restore binary-mode coloring. + + Args: + lut: RGBA array of shape ``(N, 4)`` dtype ``uint8``, or ``None``. + """ + self._frame_widget.set_label_color_lut(lut) + def set_labels(self, labels: list[np.ndarray] | None) -> None: """set labels used for overlay in the frame widget