diff --git a/optax/losses/_smoothing.py b/optax/losses/_smoothing.py index 430215a30..0864ba61a 100644 --- a/optax/losses/_smoothing.py +++ b/optax/losses/_smoothing.py @@ -14,6 +14,8 @@ # ============================================================================== """Smoothing functions.""" +from typing import Union + from jax import typing import jax.numpy as jnp from optax._src import utils @@ -22,6 +24,8 @@ def smooth_labels( labels: typing.ArrayLike, alpha: float, + *, + axis: Union[int, tuple[int, ...], None] = -1, ) -> jnp.ndarray: """Apply label smoothing. @@ -32,6 +36,7 @@ def smooth_labels( Args: labels: One hot labels to be smoothed. alpha: The smoothing factor. + axis: Axis or axes along which to compute. Returns: a smoothed version of the one hot input labels. @@ -41,5 +46,5 @@ def smooth_labels( `_, 2019 """ utils.check_subdtype(labels, jnp.floating) - num_categories = labels.shape[-1] + num_categories = jnp.size(labels, axis) return (1.0 - alpha) * labels + alpha / num_categories