Added from_logits flag #109
base: main
Changes from all commits
@@ -43,6 +43,23 @@ def _default_threshold(num_thresholds: int) -> jax.Array:
  return thresholds

def _convert_logits_to_probabilities(
    predictions: jax.Array, from_logits: bool) -> jax.Array:
  """Converts logits to probabilities if `from_logits` is True.

  Args:
    predictions: JAX array of predicted values, expected to be logits if
      `from_logits` is True.
    from_logits: Boolean indicating whether `predictions` are logits.

  Returns:
    JAX array of probabilities if `from_logits` is True, otherwise returns
    `predictions` unchanged.
  """
  if from_logits:
    predictions = jax.nn.softmax(predictions, axis=-1)
    # Assuming binary classification, take the positive class probability.
  return predictions
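For a quick sense of what the helper does, a small illustrative sketch (not part of the diff; values arbitrary):

  import jax
  import jax.numpy as jnp

  logits = jnp.array([[2.0, -1.0],
                      [0.5, 0.5]])
  probs = jax.nn.softmax(logits, axis=-1)  # each row now sums to 1.0
  # probs ≈ [[0.9526, 0.0474],
  #          [0.5000, 0.5000]]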
@flax.struct.dataclass
class Accuracy(base.Average):
  r"""Computes accuracy, which is the frequency with which `predictions` match `labels`.
@@ -69,6 +86,7 @@ def from_model_output(
      predictions: jax.Array,
      labels: jax.Array,
      sample_weights: jax.Array | None = None,
      from_logits: bool = False,
  ) -> 'Accuracy':
    """Updates the metric state with new `predictions` and `labels`.

@@ -99,6 +117,12 @@ def from_model_output(
        comparison, or if `sample_weights` cannot be broadcast to `labels`'
        shape.
    """
    if from_logits:
      predictions = jax.nn.softmax(predictions, axis=-1)

    correct = predictions == labels
    count = jnp.ones_like(labels, dtype=jnp.int32)
    if sample_weights is not None:
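A minimal sketch of the new call shape for `Accuracy` (illustrative only; note the diff applies softmax to `predictions` before the elementwise comparison when the flag is set):

  import jax.numpy as jnp
  import metrax

  update = metrax.Accuracy.from_model_output(
      predictions=jnp.array([[-0.3, 1.2, 0.1]]),  # raw logits
      labels=jnp.array([[0, 1, 0]]),
      from_logits=True,
  )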
@@ -149,6 +173,7 @@ def from_model_output(
      predictions: jax.Array,
      labels: jax.Array,
      threshold: float = 0.5,
      from_logits: bool = False,
  ) -> 'Precision':
    """Updates the metric.

@@ -166,7 +191,10 @@ def from_model_output(
      ValueError: If type of `labels` is wrong or the shapes of `predictions`
        and `labels` are incompatible.
    """
    predictions = _convert_logits_to_probabilities(predictions, from_logits)

Collaborator: Can you update these so it only calls …

    predictions = jnp.where(predictions >= threshold, 1, 0)

    true_positives = jnp.sum((predictions == 1) & (labels == 1))
    false_positives = jnp.sum((predictions == 1) & (labels == 0))
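For intuition on how `from_logits` interacts with `threshold`, a small worked example (illustrative; assumes two-class logits and the default threshold of 0.5):

  import jax
  import jax.numpy as jnp

  logits = jnp.array([[1.5, -0.5], [-2.0, 2.0]])
  probs = jax.nn.softmax(logits, axis=-1)  # ≈ [[0.88, 0.12], [0.02, 0.98]]
  binary = jnp.where(probs >= 0.5, 1, 0)   # -> [[1, 0], [0, 1]]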
@@ -219,7 +247,7 @@ def empty(cls) -> 'Recall':

  @classmethod
  def from_model_output(
-     cls, predictions: jax.Array, labels: jax.Array, threshold: float = 0.5
+     cls, predictions: jax.Array, labels: jax.Array, threshold: float = 0.5, from_logits: bool = False

Collaborator: Can you update the docstrings (here and below)?

  ) -> 'Recall':
    """Updates the metric.
@@ -237,6 +265,8 @@ def from_model_output(
      ValueError: If type of `labels` is wrong or the shapes of `predictions`
        and `labels` are incompatible.
    """
    predictions = _convert_logits_to_probabilities(predictions, from_logits)

    predictions = jnp.where(predictions >= threshold, 1, 0)
    true_positives = jnp.sum((predictions == 1) & (labels == 1))
    false_negatives = jnp.sum((predictions == 0) & (labels == 1))
@@ -325,6 +355,7 @@ def from_model_output(
      labels: jax.Array,
      sample_weights: jax.Array | None = None,
      num_thresholds: int = 200,
      from_logits: bool = False,
  ) -> 'AUCPR':
    """Updates the metric.

@@ -345,6 +376,8 @@ def from_model_output(
      ValueError: If type of `labels` is wrong or the shapes of `predictions`
        and `labels` are incompatible.
    """
    predictions = _convert_logits_to_probabilities(predictions, from_logits)

    pred_is_pos = jnp.greater(
        predictions,
        _default_threshold(num_thresholds=num_thresholds)[..., None],
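The `[..., None]` comparison above checks every prediction against every threshold in a single broadcast; the same idiom in isolation (with a hypothetical stand-in for `_default_threshold`):

  import jax.numpy as jnp

  thresholds = jnp.linspace(0.0, 1.0, 5)  # stand-in for _default_threshold(...)
  preds = jnp.array([0.1, 0.6, 0.9])
  pred_is_pos = jnp.greater(preds, thresholds[..., None])  # shape (5, 3)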
@@ -513,6 +546,7 @@ def from_model_output(
      labels: jax.Array,
      sample_weights: jax.Array | None = None,
      num_thresholds: int = 200,
      from_logits: bool = False,
  ) -> 'AUCROC':
    """Updates the metric.

@@ -533,6 +567,8 @@ def from_model_output(
      ValueError: If type of `labels` is wrong or the shapes of `predictions`
        and `labels` are incompatible.
    """
    predictions = _convert_logits_to_probabilities(predictions, from_logits)

    pred_is_pos = jnp.greater(
        predictions,
        _default_threshold(num_thresholds=num_thresholds)[..., None],
@@ -622,7 +658,8 @@ def from_model_output(
      predictions: jax.Array,
      labels: jax.Array,
      beta = beta,
-     threshold = 0.5,) -> 'FBetaScore':
+     threshold = 0.5,
+     from_logits: bool = False) -> 'FBetaScore':
    """Updates the metric.

    Note: When only `predictions` and `labels` are given, the score calculated
    is the F1 score, provided the FBetaScore beta value has not been previously
    modified.
@@ -656,6 +693,10 @@ def from_model_output(
    if threshold < 0.0 or threshold > 1.0:
      raise ValueError('The "Threshold" value must be between 0 and 1.')

    # If the predictions are logits, convert them to probabilities.

Collaborator: Remove extra newline

    predictions = _convert_logits_to_probabilities(predictions, from_logits)

    # Modify predictions with the given threshold value.
    predictions = jnp.where(predictions >= threshold, 1, 0)
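For reference, the thresholded predictions feed the usual F-beta formula, which reduces to F1 at beta = 1; a tiny numeric check with illustrative values:

  precision, recall, beta = 0.8, 0.6, 1.0
  fbeta = (1 + beta**2) * precision * recall / (beta**2 * precision + recall)
  # beta = 1.0 gives the harmonic mean of precision and recall ≈ 0.6857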
@@ -15,6 +15,8 @@
"""Tests for metrax classification metrics."""

import os

import jax
os.environ['KERAS_BACKEND'] = 'jax'

from absl.testing import absltest
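Setting `KERAS_BACKEND` before `keras` is imported is what points the Keras reference metrics used below at JAX; a standalone illustration of the mechanism (hypothetical script, not from the diff):

  import os
  os.environ['KERAS_BACKEND'] = 'jax'  # must run before `import keras`

  import keras
  print(keras.backend.backend())  # -> 'jax'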
@@ -34,6 +36,7 @@
).astype(np.float32)
OUTPUT_PREDS = np.random.uniform(size=(BATCHES, BATCH_SIZE))
OUTPUT_PREDS_F16 = OUTPUT_PREDS.astype(jnp.float16)
OUTPUT_LOGITS_F16 = np.random.randn(BATCHES, BATCH_SIZE).astype(jnp.float16)
OUTPUT_PREDS_F32 = OUTPUT_PREDS.astype(jnp.float32)
OUTPUT_PREDS_BF16 = OUTPUT_PREDS.astype(jnp.bfloat16)
OUTPUT_LABELS_BS1 = np.random.randint(
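Unlike `OUTPUT_PREDS`, the new fixture draws from a standard normal, so its values are not confined to [0, 1] and behave like raw logits; a quick illustration of the distinction:

  import numpy as np

  probs = np.random.uniform(size=(2, 4))  # in [0, 1): usable as probabilities directly
  logits = np.random.randn(2, 4)          # unbounded reals: need softmax/sigmoid first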
@@ -92,8 +95,9 @@ def test_fbeta_empty(self):
      ('basic_f32', OUTPUT_LABELS, OUTPUT_PREDS_F32, SAMPLE_WEIGHTS),
      ('basic_bf16', OUTPUT_LABELS, OUTPUT_PREDS_BF16, SAMPLE_WEIGHTS),
      ('batch_size_one', OUTPUT_LABELS_BS1, OUTPUT_PREDS_BS1, None),
      ('batch_size_logits_f16', OUTPUT_LABELS, OUTPUT_LOGITS_F16, SAMPLE_WEIGHTS, True),
  )
- def test_accuracy(self, y_true, y_pred, sample_weights):
+ def test_accuracy(self, y_true, y_pred, sample_weights, from_logits=False):
    """Test that `Accuracy` metric computes correct values."""
    if sample_weights is None:
      sample_weights = np.ones_like(y_true)

@@ -104,6 +108,7 @@ def test_accuracy(self, y_true, y_pred, sample_weights):
        predictions=logits,
        labels=labels,
        sample_weights=weights,
        from_logits=from_logits,
    )
    metrax_accuracy = metrax_accuracy.merge(update)
    keras_accuracy.update_state(labels, logits, weights)
@@ -120,6 +125,7 @@ def test_accuracy(self, y_true, y_pred, sample_weights):

  @parameterized.named_parameters(
      ('basic_f16', OUTPUT_LABELS, OUTPUT_PREDS_F16, 0.5),
      ('basic_f16_logits', OUTPUT_LABELS, OUTPUT_LOGITS_F16, 0.5, True),
      ('high_threshold_f16', OUTPUT_LABELS, OUTPUT_PREDS_F16, 0.7),
      ('low_threshold_f16', OUTPUT_LABELS, OUTPUT_PREDS_F16, 0.1),
      ('basic_f32', OUTPUT_LABELS, OUTPUT_PREDS_F32, 0.5),
@@ -130,12 +136,17 @@ def test_accuracy(self, y_true, y_pred, sample_weights):
      ('low_threshold_bf16', OUTPUT_LABELS, OUTPUT_PREDS_BF16, 0.1),
      ('batch_size_one', OUTPUT_LABELS_BS1, OUTPUT_PREDS_BS1, 0.5),
  )
- def test_precision(self, y_true, y_pred, threshold):
+ def test_precision(self, y_true, y_pred, threshold,from_logits=False):

Collaborator: Fix spacing after comma

| """Test that `Precision` metric computes correct values.""" | ||
| y_true = y_true.reshape((-1,)) | ||
| y_pred = jnp.where(y_pred.reshape((-1,)) >= threshold, 1, 0) | ||
| y_true_keras = y_true.reshape((-1,)) | ||
| if from_logits: | ||
| probs = jax.nn.softmax(y_pred, axis=-1) | ||
| y_pred_keras = jnp.where(probs.reshape((-1,)) >= threshold, 1, 0) | ||
| else: | ||
| y_pred_keras = jnp.where(y_pred.reshape((-1,)) >= threshold, 1, 0) | ||
|
|
||
| keras_precision = keras.metrics.Precision(thresholds=threshold) | ||
| keras_precision.update_state(y_true, y_pred) | ||
| keras_precision.update_state(y_true_keras, y_pred_keras) | ||
| expected = keras_precision.result() | ||
|
|
||
| metric = None | ||
|
|
@@ -144,6 +155,7 @@ def test_precision(self, y_true, y_pred, threshold):
        predictions=logits,
        labels=labels,
        threshold=threshold,
        from_logits=from_logits,
    )
    metric = update if metric is None else metric.merge(update)
@@ -161,6 +173,7 @@ def test_precision(self, y_true, y_pred, threshold):
      ('basic', OUTPUT_LABELS, OUTPUT_PREDS, 0.5),
      ('high_threshold', OUTPUT_LABELS, OUTPUT_PREDS, 0.7),
      ('low_threshold', OUTPUT_LABELS, OUTPUT_PREDS, 0.1),
      ('basic_f16', OUTPUT_LABELS, OUTPUT_PREDS_F16, 0.5, True),
      ('basic_f32', OUTPUT_LABELS, OUTPUT_PREDS_F32, 0.5),
      ('high_threshold_f32', OUTPUT_LABELS, OUTPUT_PREDS_F32, 0.7),
      ('low_threshold_f32', OUTPUT_LABELS, OUTPUT_PREDS_F32, 0.1),

@@ -169,12 +182,18 @@ def test_precision(self, y_true, y_pred, threshold):
      ('low_threshold_bf16', OUTPUT_LABELS, OUTPUT_PREDS_BF16, 0.1),
      ('batch_size_one', OUTPUT_LABELS_BS1, OUTPUT_PREDS_BS1, 0.5),
  )
- def test_recall(self, y_true, y_pred, threshold):
+ def test_recall(self, y_true, y_pred, threshold, from_logits=False):

Collaborator: Can you remove the default value and add the value to the above test case tuples for clarity?

| """Test that `Recall` metric computes correct values.""" | ||
| y_true = y_true.reshape((-1,)) | ||
| y_pred = jnp.where(y_pred.reshape((-1,)) >= threshold, 1, 0) | ||
|
|
||
| y_true_keras = y_true.reshape((-1,)) | ||
| if from_logits: | ||
| probs = jax.nn.softmax(y_pred, axis=-1) | ||
| y_pred_keras = jnp.where(probs.reshape((-1,)) >= threshold, 1, 0) | ||
| else: | ||
| y_pred_keras = jnp.where(y_pred.reshape((-1,)) >= threshold, 1, 0) | ||
|
|
||
| keras_recall = keras.metrics.Recall(thresholds=threshold) | ||
| keras_recall.update_state(y_true, y_pred) | ||
| keras_recall.update_state(y_true_keras, y_pred_keras) | ||
| expected = keras_recall.result() | ||
|
|
||
| metric = None | ||
|
|
@@ -183,6 +202,7 @@ def test_recall(self, y_true, y_pred, threshold):
        predictions=logits,
        labels=labels,
        threshold=threshold,
        from_logits=from_logits,
    )
    metric = update if metric is None else metric.merge(update)
@@ -193,19 +213,21 @@ def test_recall(self, y_true, y_pred, threshold):

  @parameterized.product(
      inputs=(
-         (OUTPUT_LABELS, OUTPUT_PREDS, None),
-         (OUTPUT_LABELS_BS1, OUTPUT_PREDS_BS1, None),
-         (OUTPUT_LABELS, OUTPUT_PREDS, SAMPLE_WEIGHTS),
+         (OUTPUT_LABELS, OUTPUT_PREDS, None, False),
+         (OUTPUT_LABELS_BS1, OUTPUT_PREDS_BS1, None, False),
+         (OUTPUT_LABELS, OUTPUT_PREDS, SAMPLE_WEIGHTS, False),
+         (OUTPUT_LABELS, OUTPUT_LOGITS_F16, SAMPLE_WEIGHTS, True),
      ),
      dtype=(
          jnp.float16,
          jnp.float32,
-         jnp.bfloat16,
+         jnp.bfloat16
      ),
  )
- def test_aucpr(self, inputs, dtype):
+ def test_aucpr(self, inputs, dtype, from_logits=False):

Collaborator: Remove default here as well

| """Test that `AUC-PR` Metric computes correct values.""" | ||
| y_true, y_pred, sample_weights = inputs | ||
| y_true, y_pred, sample_weights, from_logits = inputs | ||

Collaborator: This is shadowing the from_logits variable - dedupe

    y_true = y_true.astype(dtype)
    y_pred = y_pred.astype(dtype)
    if sample_weights is None:
@@ -217,10 +239,13 @@ def test_aucpr(self, inputs, dtype):
        predictions=logits,
        labels=labels,
        sample_weights=weights,
        from_logits=from_logits,
    )
    metric = update if metric is None else metric.merge(update)

    keras_aucpr = keras.metrics.AUC(curve='PR')
    if from_logits:
        y_pred = jax.nn.softmax(y_pred, axis=-1)

Collaborator: Please match 2 spaces per indentation style here and everywhere else

    for labels, logits, weights in zip(y_true, y_pred, sample_weights):
      keras_aucpr.update_state(labels, logits, sample_weight=weights)
    expected = keras_aucpr.result()
@@ -234,19 +259,21 @@ def test_aucpr(self, inputs, dtype):

  @parameterized.product(
      inputs=(
-         (OUTPUT_LABELS, OUTPUT_PREDS, None),
-         (OUTPUT_LABELS_BS1, OUTPUT_PREDS_BS1, None),
-         (OUTPUT_LABELS, OUTPUT_PREDS, SAMPLE_WEIGHTS),
+         (OUTPUT_LABELS, OUTPUT_PREDS, None, False),
+         (OUTPUT_LABELS_BS1, OUTPUT_PREDS_BS1, None, False),
+         (OUTPUT_LABELS, OUTPUT_PREDS, SAMPLE_WEIGHTS, False),
+         (OUTPUT_LABELS, OUTPUT_LOGITS_F16, SAMPLE_WEIGHTS, True)
      ),
      dtype=(
          jnp.float16,
          jnp.float32,
-         jnp.bfloat16,
+         jnp.bfloat16
      ),
  )
- def test_aucroc(self, inputs, dtype):
+ def test_aucroc(self, inputs, dtype, from_logits=False):

Collaborator: Remove default arg

| """Test that `AUC-ROC` Metric computes correct values.""" | ||
| y_true, y_pred, sample_weights = inputs | ||
| y_true, y_pred, sample_weights,from_logits = inputs | ||

Collaborator: Fix spacing

    y_true = y_true.astype(dtype)
    y_pred = y_pred.astype(dtype)
    if sample_weights is None:
@@ -258,10 +285,13 @@ def test_aucroc(self, inputs, dtype):
        predictions=logits,
        labels=labels,
        sample_weights=weights,
        from_logits=from_logits,  # AUCROC typically expects probabilities.
    )
    metric = update if metric is None else metric.merge(update)

    keras_aucroc = keras.metrics.AUC(curve='ROC')
    if from_logits:
      y_pred = jax.nn.softmax(y_pred, axis=-1)
    for labels, logits, weights in zip(y_true, y_pred, sample_weights):
      keras_aucroc.update_state(labels, logits, sample_weight=weights)
    expected = keras_aucroc.result()
@@ -286,16 +316,23 @@ def test_aucroc(self, inputs, dtype):
      ('low_threshold_bf16_beta_2.0', OUTPUT_LABELS, OUTPUT_PREDS_BF16, 0.1, 2.0),
      ('low_threshold_f16_beta_3.0', OUTPUT_LABELS, OUTPUT_PREDS_F16, 0.1, 3.0),
      ('basic_bf16_beta_3.0', OUTPUT_LABELS, OUTPUT_PREDS_BF16, 0.5, 3.0),
      ('batch_size_one_logits_beta_3.0', OUTPUT_LABELS_BS1, OUTPUT_PREDS_BS1, 0.5, 3.0, True),
  )
- def test_fbetascore(self, y_true, y_pred, threshold, beta):
+ def test_fbetascore(self, y_true, y_pred, threshold, beta, from_logits=False):
    # Define the Keras FBeta class to be tested against.
    keras_fbeta = keras.metrics.FBetaScore(beta=beta, threshold=threshold)
-   keras_fbeta.update_state(y_true, y_pred)
+   if from_logits:
+     y_pred_keras = jax.nn.softmax(y_pred, axis=-1)
+     keras_fbeta.update_state(y_true, y_pred_keras)
+   else:
+     keras_fbeta.update_state(y_true, y_pred)

    expected = keras_fbeta.result()

    # Calculate the F-beta score using the metrax variant.
    metric = metrax.FBetaScore
-   metric = metric.from_model_output(y_pred, y_true, beta, threshold)
+   metric = metric.from_model_output(y_pred, y_true, beta, threshold, from_logits=from_logits)
    # Use lower tolerance for lower precision dtypes.
    rtol = 1e-2 if y_true.dtype in (jnp.float16, jnp.bfloat16) else 1e-5

Collaborator: This is inconsistent with the other metrics, where `_convert_logits_to_probabilities` is called.