diff --git a/optax/_src/alias.py b/optax/_src/alias.py index 95f8885c0..56813cc24 100644 --- a/optax/_src/alias.py +++ b/optax/_src/alias.py @@ -1028,6 +1028,7 @@ def amsgrad( eps: float = 1e-8, eps_root: float = 0.0, mu_dtype: Optional[Any] = None, + bias_correction_v: bool = True ) -> base.GradientTransformationExtraArgs: """The AMSGrad optimizer. @@ -1077,7 +1078,12 @@ def amsgrad( """ return combine.chain( transform.scale_by_amsgrad( - b1=b1, b2=b2, eps=eps, eps_root=eps_root, mu_dtype=mu_dtype + b1=b1, + b2=b2, + eps=eps, + eps_root=eps_root, + mu_dtype=mu_dtype, + bias_correction_v=bias_correction_v ), transform.scale_by_learning_rate(learning_rate), ) diff --git a/optax/_src/transform.py b/optax/_src/transform.py index 9f1faa043..ce1a949a1 100644 --- a/optax/_src/transform.py +++ b/optax/_src/transform.py @@ -323,6 +323,7 @@ def scale_by_amsgrad( eps: float = 1e-8, eps_root: float = 0.0, mu_dtype: Optional[chex.ArrayDType] = None, + bias_correction_v: bool = True ) -> base.GradientTransformation: """Rescale updates according to the AMSGrad algorithm. @@ -336,6 +337,9 @@ def scale_by_amsgrad( numerical stability when backpropagating gradients through the rescaling. mu_dtype: Optional `dtype` to be used for the first order accumulator; if `None` then the `dtype` is inferred from `params` and `updates`. + bias_correction_v: Whether to apply bias correction to the second moment + estimate before taking the elementwise maximum (``nu_max``). Set to + ``False`` to use the uncorrected estimate, as in the AMSGrad paper. Returns: A :class:`optax.GradientTransformation` object. 
@@ -357,8 +361,11 @@ def update_fn(updates, state, params=None): nu = optax.tree.update_moment_per_elem_norm(updates, state.nu, b2, 2) count_inc = numerics.safe_increment(state.count) mu_hat = optax.tree.bias_correction(mu, b1, count_inc) - nu_hat = optax.tree.bias_correction(nu, b2, count_inc) - nu_max = jax.tree.map(jnp.maximum, state.nu_max, nu_hat) + if bias_correction_v: + nu_eff = optax.tree.bias_correction(nu, b2, count_inc) + else: + nu_eff = nu + nu_max = jax.tree.map(jnp.maximum, state.nu_max, nu_eff) updates = jax.tree.map( lambda m, v: None if m is None else m / (jnp.sqrt(v + eps_root) + eps), mu_hat,