diff --git a/docs/user-guide/classifier-types.md b/docs/user-guide/classifier-types.md index ddefbd08..ce2a7443 100644 --- a/docs/user-guide/classifier-types.md +++ b/docs/user-guide/classifier-types.md @@ -2,6 +2,8 @@ JABS supports three machine learning classifier types: **Random Forest**, **CatBoost**, and **XGBoost**. Each has different characteristics that may make it more suitable for your specific use case. +> **Classifier type vs. classifier mode:** This page covers the machine learning *algorithm*. Separately, JABS can train one binary classifier per behavior (the default) or a single classifier across all behaviors at once. See [Multi-Class Classification (Preview)](multi-class.md) for the experimental multi-class mode. + ## Random Forest (Default) Random Forest is the default classifier and a good starting point for most users. diff --git a/docs/user-guide/multi-class.md b/docs/user-guide/multi-class.md new file mode 100644 index 00000000..79404456 --- /dev/null +++ b/docs/user-guide/multi-class.md @@ -0,0 +1,64 @@ +# Multi-Class Classification (Preview) + +> **Preview feature.** Multi-class mode is under active development and is provided +> as a preview. Some capabilities are not yet available, and its behavior, stored +> data, and settings may change in upcoming JABS releases. Binary mode (the +> default) is unaffected. + +## Overview + +By default, JABS trains one **binary** classifier per behavior: each classifier +predicts whether a given frame contains that behavior or not, and behaviors are +independent of one another. + +**Multi-class mode** instead trains a *single* classifier across all annotated +behaviors at once. Each frame is assigned to exactly one class: one of your +behaviors, or the reserved **None** (background) class. This is appropriate when +your behaviors are **mutually exclusive** - that is, an animal cannot be doing two +of them on the same frame. + +## Enabling multi-class mode + +Open **Project Settings** and set **Classifier Mode** to **Multi-class (Preview)**. +The setting is stored with the project, and the default for all projects remains +**Binary**. + +Switching an existing project to multi-class mode is blocked if any frames are +labeled with two or more behaviors simultaneously; JABS lists the conflicting +videos so the overlaps can be resolved first. + +## Labeling for multi-class + +- Label each behavior as usual. Because classes are mutually exclusive, labeling a + frame with one behavior clears any other behavior label on that frame. +- The **None** button records an explicit *background* label - frames that are + none of your behaviors. In multi-class mode these explicit negatives are stored + on a reserved **None** track rather than as "not behavior" on an individual + behavior. The **Label Summary** reflects this: it shows the selected behavior's + count and a **None** count (instead of "Behavior" / "Not Behavior"). +- Only explicitly labeled frames (a behavior or **None**) are used for training; + unlabeled frames are ignored. + +## Known limitations (preview) + +- **No prediction post-processing.** The post-processing step available for binary + predictions is not yet applied to multi-class predictions. Multi-class + predictions are shown and saved as raw (argmax) results only. +- **Project-level training settings.** Window size and label balancing apply at the + project level for the single shared classifier rather than per behavior. Some + per-behavior options available in binary mode (for example, selective symmetric + augmentation per behavior) are not yet available in multi-class mode. +- **Mutual exclusivity required.** Behaviors must not overlap on the same frame. + Overlapping labels must be resolved before switching to multi-class mode or + training. +- **Migration.** Existing binary classifiers are not converted to multi-class + format (or vice versa); the two modes maintain separate classifier and + prediction files within a project. +- **Format stability.** The on-disk representation and available settings for + multi-class mode may change in future releases. + +## Command-line use + +`jabs-classify` auto-detects whether a saved classifier is binary or multi-class +and dispatches accordingly, so existing command-line workflows continue to work +with multi-class classifiers without additional flags. \ No newline at end of file diff --git a/src/jabs/resources/docs/user_guide/classifier-types.md b/src/jabs/resources/docs/user_guide/classifier-types.md index ddefbd08..ce2a7443 100644 --- a/src/jabs/resources/docs/user_guide/classifier-types.md +++ b/src/jabs/resources/docs/user_guide/classifier-types.md @@ -2,6 +2,8 @@ JABS supports three machine learning classifier types: **Random Forest**, **CatBoost**, and **XGBoost**. Each has different characteristics that may make it more suitable for your specific use case. +> **Classifier type vs. classifier mode:** This page covers the machine learning *algorithm*. Separately, JABS can train one binary classifier per behavior (the default) or a single classifier across all behaviors at once. See [Multi-Class Classification (Preview)](multi-class.md) for the experimental multi-class mode. + ## Random Forest (Default) Random Forest is the default classifier and a good starting point for most users. diff --git a/src/jabs/resources/docs/user_guide/multi-class.md b/src/jabs/resources/docs/user_guide/multi-class.md new file mode 100644 index 00000000..79404456 --- /dev/null +++ b/src/jabs/resources/docs/user_guide/multi-class.md @@ -0,0 +1,64 @@ +# Multi-Class Classification (Preview) + +> **Preview feature.** Multi-class mode is under active development and is provided +> as a preview. Some capabilities are not yet available, and its behavior, stored +> data, and settings may change in upcoming JABS releases. Binary mode (the +> default) is unaffected. + +## Overview + +By default, JABS trains one **binary** classifier per behavior: each classifier +predicts whether a given frame contains that behavior or not, and behaviors are +independent of one another. + +**Multi-class mode** instead trains a *single* classifier across all annotated +behaviors at once. Each frame is assigned to exactly one class: one of your +behaviors, or the reserved **None** (background) class. This is appropriate when +your behaviors are **mutually exclusive** - that is, an animal cannot be doing two +of them on the same frame. + +## Enabling multi-class mode + +Open **Project Settings** and set **Classifier Mode** to **Multi-class (Preview)**. +The setting is stored with the project, and the default for all projects remains +**Binary**. + +Switching an existing project to multi-class mode is blocked if any frames are +labeled with two or more behaviors simultaneously; JABS lists the conflicting +videos so the overlaps can be resolved first. + +## Labeling for multi-class + +- Label each behavior as usual. Because classes are mutually exclusive, labeling a + frame with one behavior clears any other behavior label on that frame. +- The **None** button records an explicit *background* label - frames that are + none of your behaviors. In multi-class mode these explicit negatives are stored + on a reserved **None** track rather than as "not behavior" on an individual + behavior. The **Label Summary** reflects this: it shows the selected behavior's + count and a **None** count (instead of "Behavior" / "Not Behavior"). +- Only explicitly labeled frames (a behavior or **None**) are used for training; + unlabeled frames are ignored. + +## Known limitations (preview) + +- **No prediction post-processing.** The post-processing step available for binary + predictions is not yet applied to multi-class predictions. Multi-class + predictions are shown and saved as raw (argmax) results only. +- **Project-level training settings.** Window size and label balancing apply at the + project level for the single shared classifier rather than per behavior. Some + per-behavior options available in binary mode (for example, selective symmetric + augmentation per behavior) are not yet available in multi-class mode. +- **Mutual exclusivity required.** Behaviors must not overlap on the same frame. + Overlapping labels must be resolved before switching to multi-class mode or + training. +- **Migration.** Existing binary classifiers are not converted to multi-class + format (or vice versa); the two modes maintain separate classifier and + prediction files within a project. +- **Format stability.** The on-disk representation and available settings for + multi-class mode may change in future releases. + +## Command-line use + +`jabs-classify` auto-detects whether a saved classifier is binary or multi-class +and dispatches accordingly, so existing command-line workflows continue to work +with multi-class classifiers without additional flags. \ No newline at end of file diff --git a/src/jabs/ui/dialogs/user_guide_dialog.py b/src/jabs/ui/dialogs/user_guide_dialog.py index 7a12a008..32cd8087 100644 --- a/src/jabs/ui/dialogs/user_guide_dialog.py +++ b/src/jabs/ui/dialogs/user_guide_dialog.py @@ -205,6 +205,7 @@ def _build_tree(self) -> None: "Feature File": "file-formats.md#feature-file", }, "Choosing a Classifier": "classifier-types.md", + "Multi-Class Classification (Preview)": "multi-class.md", "Post-Processing": "postprocessing.md", "Features Reference": "features.md", "Keyboard Shortcuts Reference": "keyboard-shortcuts.md", diff --git a/src/jabs/ui/main_control_widget/label_count_widget.py b/src/jabs/ui/main_control_widget/label_count_widget.py index 9cf6fc8d..992d8da2 100644 --- a/src/jabs/ui/main_control_widget/label_count_widget.py +++ b/src/jabs/ui/main_control_widget/label_count_widget.py @@ -47,6 +47,16 @@ def __init__(self, *args, **kwargs): frame_header = QtWidgets.QLabel("Frames") bout_header = QtWidgets.QLabel("Bouts") + # Row-header labels for the positive ("Behavior") and negative + # ("Not Behavior") classes. Kept as instance attributes so the text can + # be retitled per classifier mode (e.g. the selected behavior name and + # "None" in multi-class mode); see set_class_labels(). + self._positive_row_labels = [QtWidgets.QLabel("Behavior"), QtWidgets.QLabel("Behavior")] + self._negative_row_labels = [ + QtWidgets.QLabel("Not Behavior"), + QtWidgets.QLabel("Not Behavior"), + ] + layout = QtWidgets.QGridLayout() layout.setSpacing(2) layout.setContentsMargins(0, 0, 0, 0) @@ -55,13 +65,13 @@ def __init__(self, *args, **kwargs): layout.addWidget(frame_header, 0, 0, 1, 3, alignment=Qt.AlignmentFlag.AlignCenter) layout.addWidget(QtWidgets.QLabel("Subject"), 1, 1, alignment=Qt.AlignmentFlag.AlignRight) layout.addWidget(QtWidgets.QLabel("Total"), 1, 2, alignment=Qt.AlignmentFlag.AlignRight) - layout.addWidget(QtWidgets.QLabel("Behavior"), 2, 0) - layout.addWidget(QtWidgets.QLabel("Not Behavior"), 3, 0) + layout.addWidget(self._positive_row_labels[0], 2, 0) + layout.addWidget(self._negative_row_labels[0], 3, 0) layout.addWidget(bout_header, 4, 0, 1, 3, alignment=Qt.AlignmentFlag.AlignCenter) layout.addWidget(QtWidgets.QLabel("Subject"), 5, 1, alignment=Qt.AlignmentFlag.AlignRight) layout.addWidget(QtWidgets.QLabel("Total"), 5, 2, alignment=Qt.AlignmentFlag.AlignRight) - layout.addWidget(QtWidgets.QLabel("Behavior"), 6, 0) - layout.addWidget(QtWidgets.QLabel("Not Behavior"), 7, 0) + layout.addWidget(self._positive_row_labels[1], 6, 0) + layout.addWidget(self._negative_row_labels[1], 7, 0) # add labels containing counts to grid layout.addWidget( @@ -115,6 +125,21 @@ def __init__(self, *args, **kwargs): self.setLayout(layout) + def set_class_labels(self, positive_label: str, negative_label: str) -> None: + """Retitle the positive/negative row headers. + + Binary mode uses "Behavior"/"Not Behavior"; multi-class mode uses the + selected behavior name and the reserved background class name ("None"). + + Args: + positive_label: text for the behavior (positive) rows. + negative_label: text for the not-behavior (negative) rows. + """ + for label in self._positive_row_labels: + label.setText(positive_label) + for label in self._negative_row_labels: + label.setText(negative_label) + def set_counts( self, frame_behavior_current, diff --git a/src/jabs/ui/main_control_widget/main_control_widget.py b/src/jabs/ui/main_control_widget/main_control_widget.py index 2d992de6..cf093023 100644 --- a/src/jabs/ui/main_control_widget/main_control_widget.py +++ b/src/jabs/ui/main_control_widget/main_control_widget.py @@ -408,6 +408,10 @@ def set_classifier_selection(self, classifier_type): # unable to use the classifier pass + def set_label_summary_class_labels(self, positive_label: str, negative_label: str) -> None: + """retitle the label summary's positive/negative row headers""" + self._frame_counts.set_class_labels(positive_label, negative_label) + def set_frame_counts( self, label_behavior_current, diff --git a/src/jabs/ui/main_window/central_widget.py b/src/jabs/ui/main_window/central_widget.py index ec2c0801..07eefcb0 100644 --- a/src/jabs/ui/main_window/central_widget.py +++ b/src/jabs/ui/main_window/central_widget.py @@ -155,6 +155,9 @@ def __init__(self, *args, **kwargs) -> None: self._progress_dialog = None self._counts = None + # project-wide counts for the reserved "None" background track, used to + # source the negative-class row of the label summary in multi-class mode + self._none_counts = None self._bouts_behavior = 0 self._bouts_not_behavior = 0 @@ -355,6 +358,11 @@ def set_project(self, project: Project) -> None: self._labels = None self._loaded_video = None + # The reserved "None"-track counts are behavior-independent, so they are + # cached across behavior changes and only invalidated when the project + # changes (they belong to the previous project here). + self._none_counts = None + self._controls.update_project_settings(project.settings) self._controls.set_classifier_mode(project.settings_manager.classifier_mode) self._search_bar_widget.update_project(project) @@ -1394,11 +1402,24 @@ def _update_label_counts(self) -> None: if self._loaded_video is None: return + multiclass = self._project.settings_manager.classifier_mode == ClassifierMode.MULTICLASS + # update counts for the current video self._counts[self._loaded_video.name] = self._project.load_counts( self._loaded_video.name, self.behavior ) + # In multi-class mode the negative class shown in the label summary is + # the reserved "None" background track (explicit negatives are stored + # there), not the selected behavior's NOT_BEHAVIOR labels. Load and + # refresh those counts in parallel with the behavior counts. + if multiclass: + if self._none_counts is None: + self._none_counts = self._project.counts(MULTICLASS_NONE_BEHAVIOR) + self._none_counts[self._loaded_video.name] = self._project.load_counts( + self._loaded_video.name, MULTICLASS_NONE_BEHAVIOR + ) + current_identity = self._controls.current_identity_index label_behavior_current = 0 @@ -1411,17 +1432,42 @@ def _update_label_counts(self) -> None: bout_not_behavior_project = 0 for video, video_counts in self._counts.items(): + none_video_counts = self._none_counts.get(video, {}) if multiclass else {} for identity, counts in video_counts.items(): - label_behavior_project += counts["unfragmented_frame_counts"][0] - label_not_behavior_project += counts["unfragmented_frame_counts"][1] - bout_behavior_project += counts["unfragmented_bout_counts"][0] - bout_not_behavior_project += counts["unfragmented_bout_counts"][1] + behavior_frames = counts["unfragmented_frame_counts"][0] + behavior_bouts = counts["unfragmented_bout_counts"][0] + + if multiclass: + # negative class = BEHAVIOR labels on the "None" track + none_counts = none_video_counts.get(identity) + negative_frames = ( + none_counts["unfragmented_frame_counts"][0] if none_counts else 0 + ) + negative_bouts = ( + none_counts["unfragmented_bout_counts"][0] if none_counts else 0 + ) + else: + # negative class = NOT_BEHAVIOR labels on the selected behavior + negative_frames = counts["unfragmented_frame_counts"][1] + negative_bouts = counts["unfragmented_bout_counts"][1] + + label_behavior_project += behavior_frames + label_not_behavior_project += negative_frames + bout_behavior_project += behavior_bouts + bout_not_behavior_project += negative_bouts if video == self._loaded_video.name and identity == current_identity: - label_behavior_current = counts["unfragmented_frame_counts"][0] - label_not_behavior_current = counts["unfragmented_frame_counts"][1] - bout_behavior_current = counts["unfragmented_bout_counts"][0] - bout_not_behavior_current = counts["unfragmented_bout_counts"][1] + label_behavior_current = behavior_frames + label_not_behavior_current = negative_frames + bout_behavior_current = behavior_bouts + bout_not_behavior_current = negative_bouts + + # retitle the summary rows: behavior name / "None" in multi-class mode, + # the standard "Behavior" / "Not Behavior" otherwise + if multiclass: + self._controls.set_label_summary_class_labels(self.behavior, MULTICLASS_NONE_BEHAVIOR) + else: + self._controls.set_label_summary_class_labels("Behavior", "Not Behavior") self._controls.set_frame_counts( label_behavior_current, diff --git a/tests/ui/test_label_count_widget.py b/tests/ui/test_label_count_widget.py new file mode 100644 index 00000000..c64cc656 --- /dev/null +++ b/tests/ui/test_label_count_widget.py @@ -0,0 +1,61 @@ +import pytest + +try: + from PySide6.QtWidgets import QApplication + + from jabs.ui.main_control_widget.label_count_widget import FrameLabelCountWidget + + SKIP_UI_TESTS = False + SKIP_REASON = None +except ImportError as e: + SKIP_UI_TESTS = True + SKIP_REASON = f"Qt/UI dependencies not available: {e}" + +pytestmark = pytest.mark.skipif( + SKIP_UI_TESTS, + reason=SKIP_REASON if SKIP_UI_TESTS else "", +) + + +@pytest.fixture(scope="module", autouse=True) +def qapp(): + """Ensure a QApplication exists for widget tests.""" + app = QApplication.instance() + if app is None: + app = QApplication([]) + yield app + + +def test_default_class_labels() -> None: + """The summary defaults to the binary-mode row headers.""" + widget = FrameLabelCountWidget() + + assert [lbl.text() for lbl in widget._positive_row_labels] == ["Behavior", "Behavior"] + assert [lbl.text() for lbl in widget._negative_row_labels] == [ + "Not Behavior", + "Not Behavior", + ] + + +def test_set_class_labels_retitles_both_rows() -> None: + """set_class_labels updates the frame and bout row headers for both classes.""" + widget = FrameLabelCountWidget() + + widget.set_class_labels("Walk", "None") + + assert [lbl.text() for lbl in widget._positive_row_labels] == ["Walk", "Walk"] + assert [lbl.text() for lbl in widget._negative_row_labels] == ["None", "None"] + + +def test_set_class_labels_can_restore_defaults() -> None: + """Switching back to binary wording restores the standard headers.""" + widget = FrameLabelCountWidget() + + widget.set_class_labels("Walk", "None") + widget.set_class_labels("Behavior", "Not Behavior") + + assert [lbl.text() for lbl in widget._positive_row_labels] == ["Behavior", "Behavior"] + assert [lbl.text() for lbl in widget._negative_row_labels] == [ + "Not Behavior", + "Not Behavior", + ]