Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 53 additions & 5 deletions src/jabs/classifier/classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,13 +67,15 @@ def __init__(self, classifier: ClassifierType = ClassifierType.RANDOM_FOREST, n_
raise ValueError("Invalid classifier type")

@classmethod
def from_training_file(cls, path: Path):
def from_training_file(cls, path: Path, classifier_type: ClassifierType | None = None):
"""Initialize a classifier from an exported training data file.

This method will load the training data and train a classifier.

Args:
path: exported training data file
classifier_type: Override the classifier algorithm stored in the training
file. If ``None``, the type recorded in the file is used.

Returns:
trained classifier object
Expand All @@ -84,12 +86,13 @@ def from_training_file(cls, path: Path):
classifier = cls()
classifier.behavior_name = behavior
classifier.set_dict_settings(loaded_training_data["settings"])
classifier_type = ClassifierType(loaded_training_data["classifier_type"])
if classifier_type in classifier._supported_classifiers:
classifier.set_classifier(classifier_type)
file_classifier_type = ClassifierType(loaded_training_data["classifier_type"])
effective_type = classifier_type if classifier_type is not None else file_classifier_type
if effective_type in classifier._supported_classifiers:
classifier.set_classifier(effective_type)
else:
logging.warning(
f"Specified classifier type {classifier_type.name} is unavailable, using default: {classifier.classifier_type.name}"
f"Specified classifier type {effective_type.name} is unavailable, using default: {classifier.classifier_type.name}"
)
training_features = classifier.combine_data(
loaded_training_data["per_frame"], loaded_training_data["window"]
Expand Down Expand Up @@ -423,6 +426,51 @@ def save(self, path: Path):
self._classifier_hash = hash_file(Path(path))
self._classifier_source = "serialized"

@classmethod
def from_pickle(cls, path: Path) -> "Classifier":
"""Load a Classifier from a pickle file with full validation and metadata backfill.

Applies the same version, classifier-type, and metadata checks as ``load()``,
but as a classmethod factory so no dummy instance is required.

Args:
path: Path to the saved classifier pickle file.

Returns:
Loaded and validated ``Classifier`` instance.

Raises:
ValueError: If the file is not a ``Classifier``, was trained with an
incompatible sklearn or JABS version, or uses an unsupported
classifier type.
"""
with warnings.catch_warnings(record=True) as caught_warnings:
warnings.simplefilter("always", InconsistentVersionWarning)
c = joblib.load(path)
for warning in caught_warnings:
if issubclass(warning.category, InconsistentVersionWarning):
raise ValueError("Classifier trained with different version of sklearn.")
else:
warnings.warn(warning.message, warning.category, stacklevel=2)

if not isinstance(c, cls):
raise ValueError(f"{path} is not an instance of Classifier")

if c.version != _VERSION:
raise ValueError(
f"Unable to deserialize pickled classifier. File version {c.version}, expected {_VERSION}."
)

if c._classifier_type not in cls._supported_classifier_choices():
raise ValueError("Invalid classifier type")

if c._classifier_file is None:
c._classifier_file = Path(path).name
c._classifier_hash = hash_file(Path(path))
c._classifier_source = "pickle"

return c

def load(self, path: Path):
"""load a classifier from a file

Expand Down
59 changes: 56 additions & 3 deletions src/jabs/classifier/multi_class_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,25 +418,78 @@ def load(self, path: Path) -> None:
logger.info("MultiClassClassifier loaded from %s", path)

@classmethod
def from_training_file(cls, path: Path) -> MultiClassClassifier:
def from_pickle(cls, path: Path) -> MultiClassClassifier:
"""Load a MultiClassClassifier from a pickle file with full validation and metadata backfill.

Applies the same version, classifier-type, and metadata checks as ``load()``,
but as a classmethod factory so no dummy instance is required.

Args:
path: Path to the saved classifier pickle file.

Returns:
Loaded and validated ``MultiClassClassifier`` instance.

Raises:
ValueError: If the file is not a ``MultiClassClassifier``, was saved
with a different version, or uses an unsupported classifier type.
"""
with warnings.catch_warnings(record=True) as caught_warnings:
warnings.simplefilter("always", InconsistentVersionWarning)
c = joblib.load(path)
for warning in caught_warnings:
if issubclass(warning.category, InconsistentVersionWarning):
raise ValueError("Classifier trained with different version of sklearn.")
else:
warnings.warn(warning.message, warning.category, stacklevel=2)

if not isinstance(c, cls):
raise ValueError(f"{path} is not an instance of MultiClassClassifier")

if c._version != _VERSION:
raise ValueError(
f"Unable to deserialize pickled classifier. "
f"File version {c._version}, expected {_VERSION}."
)

if c._classifier_type not in cls._supported_classifier_choices():
raise ValueError("Invalid classifier type")

if c._classifier_file is None:
c._classifier_file = Path(path).name
c._classifier_hash = hash_file(Path(path))
c._classifier_source = "pickle"

logger.info("MultiClassClassifier loaded from %s", path)
return c

@classmethod
def from_training_file(
cls, path: Path, classifier_type: ClassifierType | None = None
) -> MultiClassClassifier:
"""Train a new MultiClassClassifier from an exported training file.

Args:
path: Path to a multi-class training HDF5 file produced by
``export_training_data_multiclass()``.
classifier_type: Override the classifier algorithm stored in the training
file. If ``None``, the type recorded in the file is used.

Returns:
A freshly trained ``MultiClassClassifier`` instance.

Raises:
ValueError: If the file is not a valid multi-class training export or
the stored classifier type is unsupported in the current environment.
the effective classifier type is unsupported in the current environment.
"""
loaded, _ = load_multiclass_training_data(path)

effective_type = (
classifier_type if classifier_type is not None else loaded["classifier_type"]
)
classifier = cls(
behavior_names=loaded["behavior_names"],
classifier_type=loaded["classifier_type"],
classifier_type=effective_type,
)
classifier.set_dict_settings(loaded["settings"])
classifier.train(
Expand Down
2 changes: 1 addition & 1 deletion src/jabs/feature_extraction/features.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ class IdentityFeatures:

def __init__(
self,
source_file: str,
source_file: str | Path,
identity: int,
directory: str | Path | None,
pose_est: PoseEstimation,
Expand Down
Loading
Loading