45 changes: 43 additions & 2 deletions src/metrax/classification_metrics.py
@@ -43,6 +43,23 @@ def _default_threshold(num_thresholds: int) -> jax.Array:
return thresholds


def _convert_logits_to_probabilities(
predictions: jax.Array, from_logits: bool )-> jax.Array:
"""Converts logits to probabilities if `from_logits` is True
Args:
predictions: JAX array of predicted values, expected to be logits if `from_logits` is True.
from_logits: Boolean indicating whether `predictions` are logits.
Returns:
JAX array of probabilities if `from_logits` is True, otherwise returns `predictions` unchanged.
"""

if from_logits:
predictions = jax.nn.softmax(predictions, axis=-1)
# Assuming binary classification, take the positive class probability.


return predictions
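
Note: the inline comment above mentions taking the positive-class probability, but no line in the diff actually does so; the full softmax output is returned. A minimal sketch of what that comment seems to intend, assuming two-class logits of shape `(..., 2)` (hypothetical, not part of this PR):

```python
import jax

def _binary_positive_class_prob(logits: jax.Array) -> jax.Array:
  """Hypothetical sketch: softmax over the class axis, keep class 1."""
  probs = jax.nn.softmax(logits, axis=-1)  # normalize over the last axis
  return probs[..., 1]  # probability of the positive class
```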

@flax.struct.dataclass
class Accuracy(base.Average):
r"""Computes accuracy, which is the frequency with which `predictions` match `labels`.
@@ -69,6 +86,7 @@ def from_model_output(
predictions: jax.Array,
labels: jax.Array,
sample_weights: jax.Array | None = None,
from_logits: bool = False,
) -> 'Accuracy':
"""Updates the metric state with new `predictions` and `labels`.

@@ -99,6 +117,12 @@ def from_model_output(
comparison, or if `sample_weights` cannot be broadcast to `labels`'
shape.
"""

if from_logits:
predictions = jax.nn.softmax(predictions, axis=-1)

Collaborator: This is inconsistent with the other metrics, where `_convert_logits_to_probabilities` is called.
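
A minimal sketch of the consistent version being asked for, reusing the shared helper defined earlier in this diff:

```python
# Sketch: call the shared helper instead of inlining the softmax.
predictions = _convert_logits_to_probabilities(predictions, from_logits)
```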




correct = predictions == labels
count = jnp.ones_like(labels, dtype=jnp.int32)
if sample_weights is not None:
@@ -149,6 +173,7 @@ def from_model_output(
predictions: jax.Array,
labels: jax.Array,
threshold: float = 0.5,
from_logits: bool = False,
) -> 'Precision':
"""Updates the metric.

@@ -166,7 +191,10 @@
ValueError: If type of `labels` is wrong or the shapes of `predictions`
and `labels` are incompatible.
"""
predictions = _convert_logits_to_probabilities(predictions, from_logits)

Collaborator: Can you update these so it only calls `_convert_logits_to_probabilities` if `from_logits` is true?
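
Since `_convert_logits_to_probabilities` already checks the flag internally, the request reads as making the guard explicit at the call site. A minimal sketch:

```python
# Sketch of the requested call-site guard; the helper's own
# `if from_logits:` check means behavior is unchanged.
if from_logits:
  predictions = _convert_logits_to_probabilities(predictions, from_logits)
```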


predictions = jnp.where(predictions >= threshold, 1, 0)

true_positives = jnp.sum((predictions == 1) & (labels == 1))
false_positives = jnp.sum((predictions == 1) & (labels == 0))

@@ -219,7 +247,7 @@ def empty(cls) -> 'Recall':

@classmethod
def from_model_output(
cls, predictions: jax.Array, labels: jax.Array, threshold: float = 0.5
cls, predictions: jax.Array, labels: jax.Array, threshold: float = 0.5, from_logits: bool = False

Collaborator: Can you update the docstrings (here and below)?
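
A sketch of the docstring addition being requested; the wording is a guess, not the author's text:

```python
@classmethod
def from_model_output(
    cls, predictions, labels, threshold=0.5, from_logits=False
) -> 'Recall':
  """Updates the metric.

  Args:
    predictions: JAX array of predicted probabilities, or logits when
      `from_logits` is True.
    labels: JAX array of ground-truth labels.
    threshold: Cutoff used to binarize `predictions`.
    from_logits: If True, convert `predictions` from logits to
      probabilities before applying `threshold`. Defaults to False.
  """
  ...
```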

) -> 'Recall':
"""Updates the metric.

@@ -237,6 +265,8 @@ def from_model_output(
ValueError: If type of `labels` is wrong or the shapes of `predictions`
and `labels` are incompatible.
"""
predictions = _convert_logits_to_probabilities(predictions, from_logits)

predictions = jnp.where(predictions >= threshold, 1, 0)
true_positives = jnp.sum((predictions == 1) & (labels == 1))
false_negatives = jnp.sum((predictions == 0) & (labels == 1))
@@ -325,6 +355,7 @@ def from_model_output(
labels: jax.Array,
sample_weights: jax.Array | None = None,
num_thresholds: int = 200,
from_logits: bool = False,
) -> 'AUCPR':
"""Updates the metric.

@@ -345,6 +376,8 @@
ValueError: If type of `labels` is wrong or the shapes of `predictions`
and `labels` are incompatible.
"""
predictions = _convert_logits_to_probabilities(predictions, from_logits)

pred_is_pos = jnp.greater(
predictions,
_default_threshold(num_thresholds=num_thresholds)[..., None],
@@ -513,6 +546,7 @@ def from_model_output(
labels: jax.Array,
sample_weights: jax.Array | None = None,
num_thresholds: int = 200,
from_logits: bool = False,
) -> 'AUCROC':
"""Updates the metric.

@@ -533,6 +567,8 @@
ValueError: If type of `labels` is wrong or the shapes of `predictions`
and `labels` are incompatible.
"""
predictions = _convert_logits_to_probabilities(predictions, from_logits)

pred_is_pos = jnp.greater(
predictions,
_default_threshold(num_thresholds=num_thresholds)[..., None],
@@ -622,7 +658,8 @@ def from_model_output(
predictions: jax.Array,
labels: jax.Array,
beta = beta,
threshold = 0.5,) -> 'FBetaScore':
threshold = 0.5,
from_logits : bool = False) -> 'FBetaScore':
"""Updates the metric.
Note: When only predictions and labels are given, the score calculated
is the F1 score if the FBetaScore beta value has not been previously modified.
@@ -656,6 +693,10 @@ def from_model_output(
if threshold < 0.0 or threshold > 1.0:
raise ValueError('The "Threshold" value must be between 0 and 1.')

# If the predictions are logits, convert them to probabilities


Collaborator: Remove extra newline.

predictions = _convert_logits_to_probabilities(predictions, from_logits)

# Modify predictions with the given threshold value
predictions = jnp.where(predictions >= threshold, 1, 0)

81 changes: 59 additions & 22 deletions src/metrax/classification_metrics_test.py
@@ -15,6 +15,8 @@
"""Tests for metrax classification metrics."""

import os

import jax
os.environ['KERAS_BACKEND'] = 'jax'

from absl.testing import absltest
@@ -34,6 +36,7 @@
).astype(np.float32)
OUTPUT_PREDS = np.random.uniform(size=(BATCHES, BATCH_SIZE))
OUTPUT_PREDS_F16 = OUTPUT_PREDS.astype(jnp.float16)
OUTPUT_LOGITS_F16 = np.random.randn(BATCHES, BATCH_SIZE).astype(jnp.float16)
OUTPUT_PREDS_F32 = OUTPUT_PREDS.astype(jnp.float32)
OUTPUT_PREDS_BF16 = OUTPUT_PREDS.astype(jnp.bfloat16)
OUTPUT_LABELS_BS1 = np.random.randint(
@@ -92,8 +95,9 @@ def test_fbeta_empty(self):
('basic_f32', OUTPUT_LABELS, OUTPUT_PREDS_F32, SAMPLE_WEIGHTS),
('basic_bf16', OUTPUT_LABELS, OUTPUT_PREDS_BF16, SAMPLE_WEIGHTS),
('batch_size_one', OUTPUT_LABELS_BS1, OUTPUT_PREDS_BS1, None),
('batch_size_logits_f16', OUTPUT_LABELS, OUTPUT_LOGITS_F16, SAMPLE_WEIGHTS,True),
)
def test_accuracy(self, y_true, y_pred, sample_weights):
def test_accuracy(self, y_true, y_pred, sample_weights, from_logits=False):
"""Test that `Accuracy` metric computes correct values."""
if sample_weights is None:
sample_weights = np.ones_like(y_true)
@@ -104,6 +108,7 @@ def test_accuracy(self, y_true, y_pred, sample_weights):
predictions=logits,
labels=labels,
sample_weights=weights,
from_logits=from_logits,
)
metrax_accuracy = metrax_accuracy.merge(update)
keras_accuracy.update_state(labels, logits, weights)
@@ -120,6 +125,7 @@

@parameterized.named_parameters(
('basic_f16', OUTPUT_LABELS, OUTPUT_PREDS_F16, 0.5),
('basic_f16_logits', OUTPUT_LABELS, OUTPUT_LOGITS_F16, 0.5, True),
('high_threshold_f16', OUTPUT_LABELS, OUTPUT_PREDS_F16, 0.7),
('low_threshold_f16', OUTPUT_LABELS, OUTPUT_PREDS_F16, 0.1),
('basic_f32', OUTPUT_LABELS, OUTPUT_PREDS_F32, 0.5),
@@ -130,12 +136,17 @@ def test_accuracy(self, y_true, y_pred, sample_weights):
('low_threshold_bf16', OUTPUT_LABELS, OUTPUT_PREDS_BF16, 0.1),
('batch_size_one', OUTPUT_LABELS_BS1, OUTPUT_PREDS_BS1, 0.5),
)
def test_precision(self, y_true, y_pred, threshold):
def test_precision(self, y_true, y_pred, threshold,from_logits=False):

Collaborator: Fix spacing after comma.

"""Test that `Precision` metric computes correct values."""
y_true = y_true.reshape((-1,))
y_pred = jnp.where(y_pred.reshape((-1,)) >= threshold, 1, 0)
y_true_keras = y_true.reshape((-1,))
if from_logits:
probs = jax.nn.softmax(y_pred, axis=-1)
y_pred_keras = jnp.where(probs.reshape((-1,)) >= threshold, 1, 0)
else:
y_pred_keras = jnp.where(y_pred.reshape((-1,)) >= threshold, 1, 0)

keras_precision = keras.metrics.Precision(thresholds=threshold)
keras_precision.update_state(y_true, y_pred)
keras_precision.update_state(y_true_keras, y_pred_keras)
expected = keras_precision.result()

metric = None
@@ -144,6 +155,7 @@ def test_precision(self, y_true, y_pred, threshold):
predictions=logits,
labels=labels,
threshold=threshold,
from_logits=from_logits,
)
metric = update if metric is None else metric.merge(update)

@@ -161,6 +173,7 @@
('basic', OUTPUT_LABELS, OUTPUT_PREDS, 0.5),
('high_threshold', OUTPUT_LABELS, OUTPUT_PREDS, 0.7),
('low_threshold', OUTPUT_LABELS, OUTPUT_PREDS, 0.1),
('basic_f16', OUTPUT_LABELS, OUTPUT_PREDS_F16, 0.5,True),
('basic_f32', OUTPUT_LABELS, OUTPUT_PREDS_F32, 0.5),
('high_threshold_f32', OUTPUT_LABELS, OUTPUT_PREDS_F32, 0.7),
('low_threshold_f32', OUTPUT_LABELS, OUTPUT_PREDS_F32, 0.1),
@@ -169,12 +182,18 @@ def test_precision(self, y_true, y_pred, threshold):
('low_threshold_bf16', OUTPUT_LABELS, OUTPUT_PREDS_BF16, 0.1),
('batch_size_one', OUTPUT_LABELS_BS1, OUTPUT_PREDS_BS1, 0.5),
)
def test_recall(self, y_true, y_pred, threshold):
def test_recall(self, y_true, y_pred, threshold, from_logits=False):

Collaborator: Can you remove the default value and add the value to the above test case tuples for clarity?
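
A sketch of the explicit parameterization being asked for, with `from_logits` carried in every tuple (names reuse the module's test constants; tuples illustrative):

```python
@parameterized.named_parameters(
    ('basic', OUTPUT_LABELS, OUTPUT_PREDS, 0.5, False),
    ('basic_f16_logits', OUTPUT_LABELS, OUTPUT_LOGITS_F16, 0.5, True),
)
def test_recall(self, y_true, y_pred, threshold, from_logits):
  ...
```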

"""Test that `Recall` metric computes correct values."""
y_true = y_true.reshape((-1,))
y_pred = jnp.where(y_pred.reshape((-1,)) >= threshold, 1, 0)

y_true_keras = y_true.reshape((-1,))
if from_logits:
probs = jax.nn.softmax(y_pred, axis=-1)
y_pred_keras = jnp.where(probs.reshape((-1,)) >= threshold, 1, 0)
else:
y_pred_keras = jnp.where(y_pred.reshape((-1,)) >= threshold, 1, 0)

keras_recall = keras.metrics.Recall(thresholds=threshold)
keras_recall.update_state(y_true, y_pred)
keras_recall.update_state(y_true_keras, y_pred_keras)
expected = keras_recall.result()

metric = None
@@ -183,6 +202,7 @@ def test_recall(self, y_true, y_pred, threshold):
predictions=logits,
labels=labels,
threshold=threshold,
from_logits=from_logits,
)
metric = update if metric is None else metric.merge(update)

@@ -193,19 +213,21 @@

@parameterized.product(
inputs=(
(OUTPUT_LABELS, OUTPUT_PREDS, None),
(OUTPUT_LABELS_BS1, OUTPUT_PREDS_BS1, None),
(OUTPUT_LABELS, OUTPUT_PREDS, SAMPLE_WEIGHTS),
(OUTPUT_LABELS, OUTPUT_PREDS, None, False),
(OUTPUT_LABELS_BS1, OUTPUT_PREDS_BS1, None, False),
(OUTPUT_LABELS, OUTPUT_PREDS, SAMPLE_WEIGHTS, False),
(OUTPUT_LABELS, OUTPUT_LOGITS_F16, SAMPLE_WEIGHTS, True),
),
dtype=(
jnp.float16,
jnp.float32,
jnp.bfloat16,
jnp.bfloat16
),
)
def test_aucpr(self, inputs, dtype):
def test_aucpr(self, inputs, dtype, from_logits=False):

Collaborator: Remove default here as well.

"""Test that `AUC-PR` Metric computes correct values."""
y_true, y_pred, sample_weights = inputs
y_true, y_pred, sample_weights, from_logits = inputs

Collaborator: This is shadowing the `from_logits` variable; dedupe.
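
A sketch of the dedupe being requested: drop the defaulted parameter so `from_logits` comes only from the unpacked inputs:

```python
def test_aucpr(self, inputs, dtype):
  """Test that `AUC-PR` Metric computes correct values."""
  # from_logits is taken solely from the parameterized tuple, so the
  # signature no longer declares a value that the unpacking shadows.
  y_true, y_pred, sample_weights, from_logits = inputs
  ...
```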

y_true = y_true.astype(dtype)
y_pred = y_pred.astype(dtype)
if sample_weights is None:
@@ -217,10 +239,13 @@
predictions=logits,
labels=labels,
sample_weights=weights,
from_logits=from_logits,
)
metric = update if metric is None else metric.merge(update)

keras_aucpr = keras.metrics.AUC(curve='PR')
if from_logits:
y_pred = jax.nn.softmax(y_pred, axis=-1)

Collaborator: Please match 2 spaces per indentation style here and everywhere else.

for labels, logits, weights in zip(y_true, y_pred, sample_weights):
keras_aucpr.update_state(labels, logits, sample_weight=weights)
expected = keras_aucpr.result()
@@ -234,19 +259,21 @@

@parameterized.product(
inputs=(
(OUTPUT_LABELS, OUTPUT_PREDS, None),
(OUTPUT_LABELS_BS1, OUTPUT_PREDS_BS1, None),
(OUTPUT_LABELS, OUTPUT_PREDS, SAMPLE_WEIGHTS),
(OUTPUT_LABELS, OUTPUT_PREDS, None, False),
(OUTPUT_LABELS_BS1, OUTPUT_PREDS_BS1, None, False),
(OUTPUT_LABELS, OUTPUT_PREDS, SAMPLE_WEIGHTS, False),
(OUTPUT_LABELS, OUTPUT_LOGITS_F16, SAMPLE_WEIGHTS, True)
),
dtype=(
jnp.float16,
jnp.float32,
jnp.bfloat16,
jnp.bfloat16
),
)
def test_aucroc(self, inputs, dtype):
def test_aucroc(self, inputs, dtype, from_logits=False):

Collaborator: Remove default arg.

"""Test that `AUC-ROC` Metric computes correct values."""
y_true, y_pred, sample_weights = inputs
y_true, y_pred, sample_weights,from_logits = inputs

Collaborator: Fix spacing.

y_true = y_true.astype(dtype)
y_pred = y_pred.astype(dtype)
if sample_weights is None:
@@ -258,10 +285,13 @@
predictions=logits,
labels=labels,
sample_weights=weights,
from_logits=from_logits, # AUCROC typically expects probabilities.
)
metric = update if metric is None else metric.merge(update)

keras_aucroc = keras.metrics.AUC(curve='ROC')
if from_logits:
y_pred = jax.nn.softmax(y_pred, axis=-1)
for labels, logits, weights in zip(y_true, y_pred, sample_weights):
keras_aucroc.update_state(labels, logits, sample_weight=weights)
expected = keras_aucroc.result()
@@ -286,16 +316,23 @@ def test_aucroc(self, inputs, dtype):
('low_threshold_bf16_beta_2.0', OUTPUT_LABELS, OUTPUT_PREDS_BF16, 0.1, 2.0),
('low_threshold_f16_beta_3.0', OUTPUT_LABELS, OUTPUT_PREDS_F16, 0.1, 3.0),
('basic_bf16_beta_3.0', OUTPUT_LABELS, OUTPUT_PREDS_BF16, 0.5, 3.0),
('batch_size_one_logits_beta_3.0', OUTPUT_LABELS_BS1, OUTPUT_PREDS_BS1, 0.5, 3.0, True),
)
def test_fbetascore(self, y_true, y_pred, threshold, beta):
def test_fbetascore(self, y_true, y_pred, threshold, beta, from_logits=False):
# Define the Keras FBeta class to be tested against
keras_fbeta = keras.metrics.FBetaScore(beta=beta, threshold=threshold)
keras_fbeta.update_state(y_true, y_pred)
if from_logits:
y_pred_keras = jax.nn.softmax(y_pred, axis=-1)
keras_fbeta.update_state(y_true, y_pred_keras)

else:
keras_fbeta.update_state(y_true, y_pred)

expected = keras_fbeta.result()

# Calculate the F-beta score using the metrax variant
metric = metrax.FBetaScore
metric = metric.from_model_output(y_pred, y_true, beta, threshold)
metric = metric.from_model_output(y_pred, y_true, beta, threshold, from_logits=from_logits)

# Use lower tolerance for lower precision dtypes.
rtol = 1e-2 if y_true.dtype in (jnp.float16, jnp.bfloat16) else 1e-5