
Commit de7cdd1 (1 parent: adcaca1)

Apply minor improvements using tree utility functions. Deprecate optax.global_norm in favor of optax.tree.norm.

File tree: 11 files changed (+25, -26 lines)

examples/perturbations.ipynb — 1 addition, 2 deletions

@@ -41,7 +41,6 @@
     "source": [
     "import jax\n",
     "import jax.numpy as jnp\n",
-    "import operator\n",
     "from jax import tree_util as jtu\n",
     "\n",
     "import optax.tree\n",
@@ -773,7 +772,7 @@
     " pert_softmax = pert_argmax_fun(rng, inputs)\n",
     " argmax = argmax_tree(inputs)\n",
     " diffs = jax.tree.map(lambda x, y: jnp.sum((x - y) ** 2 / 4), argmax, pert_softmax)\n",
-    " return jax.tree.reduce(operator.add, diffs)"
+    " return optax.tree.sum(diffs)"
    ]
   },
   {
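
For reference, a minimal sketch of the substitution above, assuming optax.tree.sum adds up every element across all leaves of a pytree (the toy diffs tree below is illustrative, not taken from the notebook):

import operator

import jax
import jax.numpy as jnp
import optax.tree

# Toy tree standing in for the notebook's per-leaf squared differences.
diffs = {'a': jnp.array([1.0, 2.0]), 'b': jnp.array(3.0)}

# Old pattern: sum each leaf, then fold the leaves together with operator.add.
old = jax.tree.reduce(operator.add, jax.tree.map(jnp.sum, diffs))

# New pattern: one call that sums every element across all leaves.
new = optax.tree.sum(diffs)

assert jnp.allclose(old, new)  # both equal 1 + 2 + 3 = 6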

optax/_src/linear_algebra.py — 10 additions, 3 deletions

@@ -17,6 +17,7 @@
 from collections.abc import Callable
 import functools
 from typing import Optional, Union
+import warnings

 import chex
 import jax
@@ -33,10 +34,16 @@ def _normalize_tree(x):


 def global_norm(updates: base.PyTree) -> chex.Array:
-  """Compute the global norm across a nested structure of tensors."""
-  return jnp.sqrt(
-      sum(jnp.sum(numerics.abs_sq(x)) for x in jax.tree.leaves(updates))
+  """Compute the global norm across a nested structure of tensors.
+
+  .. warning::
+    Deprecated in favor of :func:`optax.tree.norm`.
+  """
+  warnings.warn(
+      "optax.global_norm is deprecated in favor of optax.tree.norm",
+      DeprecationWarning
   )
+  return optax.tree.norm(updates)


 def _power_iteration_cond_fun(error_tolerance, num_iters, loop_vars):
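
For downstream users the deprecation above is a one-line migration; a small sketch, assuming optax.tree.norm returns the same global L2 norm that global_norm used to compute (which is what the new implementation delegates to; the example values are illustrative):

import jax.numpy as jnp
import optax
import optax.tree

updates = {'w': jnp.array([[3.0, 0.0], [0.0, 4.0]]), 'b': jnp.zeros(2)}

# Deprecated path: still works, but now emits a DeprecationWarning.
old_norm = optax.global_norm(updates)

# Preferred replacement: global L2 norm over all leaves of the tree.
new_norm = optax.tree.norm(updates)

assert jnp.allclose(old_norm, new_norm)  # sqrt(3**2 + 4**2) = 5.0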

optax/_src/linear_algebra_test.py — 1 addition, 1 deletion

@@ -53,7 +53,7 @@ def test_global_norm(self):
     }
     np.testing.assert_array_equal(
         jnp.sqrt(jnp.sum(flat_updates**2)),
-        linear_algebra.global_norm(nested_updates),
+        optax.tree.norm(nested_updates),
     )

   def test_power_iteration_cond_fun(self, dim=6):

optax/_src/utils.py — 0 additions, 1 deletion

@@ -345,4 +345,3 @@ def _value_and_grad(
 # TODO(b/183800387): remove legacy aliases.
 safe_norm = numerics.safe_norm
 safe_int32_increment = numerics.safe_int32_increment
-global_norm = linear_algebra.global_norm

optax/contrib/_sam.py — 2 additions, 1 deletion

@@ -55,6 +55,7 @@
 from optax._src import base
 from optax._src import update
 from optax._src import utils
+import optax.tree

 # As a helper for SAM we need a gradient normalizing transformation.

@@ -74,7 +75,7 @@ def init_fn(params):

   def update_fn(updates, state, params=None):
     del params
-    g_norm = utils.global_norm(updates)
+    g_norm = optax.tree.norm(updates)
     updates = jax.tree.map(lambda g: g / g_norm, updates)
     return updates, state
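
The SAM helper above rescales every leaf by the global norm; a standalone sketch of that normalization, assuming optax.tree.norm as used in the diff (the gradient values are illustrative):

import jax
import jax.numpy as jnp
import optax.tree

grads = {'w': jnp.array([3.0, 4.0]), 'b': jnp.array([0.0])}

# Global L2 norm across all leaves, then divide each leaf by it so the
# result has unit global norm (SAM's normalized ascent direction).
g_norm = optax.tree.norm(grads)
unit_grads = jax.tree.map(lambda g: g / g_norm, grads)

assert jnp.allclose(optax.tree.norm(unit_grads), 1.0)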

optax/contrib/_sophia.py — 2 additions, 4 deletions

@@ -157,10 +157,8 @@ def update_fn(updates, state: SophiaState, params=None, **hess_fn_kwargs):
         lambda m, h: m / jnp.maximum(gamma * h, eps), mu_hat, state.nu
     )
     if clip_threshold is not None:
-      sum_not_clipped = jax.tree.reduce(
-          lambda x, y: x + y,
-          jax.tree.map(lambda u: jnp.sum(jnp.abs(u) < clip_threshold), updates),
-      )
+      not_clipped = jax.tree.map(lambda u: jnp.abs(u) < clip_threshold, updates)
+      sum_not_clipped = optax.tree.sum(not_clipped)
       if verbose:
         win_rate = sum_not_clipped / optax.tree.size(updates)
         jax.lax.cond(
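
The Sophia change counts un-clipped coordinates by summing a boolean mask over the whole tree; a sketch of that counting pattern, assuming optax.tree.sum and optax.tree.size behave as used above (the threshold and updates below are illustrative):

import jax
import jax.numpy as jnp
import optax.tree

clip_threshold = 1.0
updates = {'w': jnp.array([0.5, -2.0, 0.1]), 'b': jnp.array([3.0])}

# Boolean mask per leaf: True where the update magnitude stays below the threshold.
not_clipped = jax.tree.map(lambda u: jnp.abs(u) < clip_threshold, updates)

# Summing the booleans across the tree counts the un-clipped entries.
sum_not_clipped = optax.tree.sum(not_clipped)          # 2 here
win_rate = sum_not_clipped / optax.tree.size(updates)  # 2 / 4 = 0.5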

optax/perturbations/_make_pert_test.py — 1 addition, 2 deletions

@@ -16,7 +16,6 @@
 """Tests for optax.perturbations, checking values and gradients."""

 from functools import partial  # pylint: disable=g-importing-member
-import operator

 from absl.testing import absltest
 from absl.testing import parameterized
@@ -159,7 +158,7 @@ def loss(tree):
       pred = apply_element_tree(tree)
       pred_true = apply_element_tree(example_tree)
       tree_loss = jax.tree.map(lambda x, y: (x - y) ** 2, pred, pred_true)
-      list_loss = jax.tree.reduce(operator.add, tree_loss)
+      list_loss = optax.tree.sum(tree_loss)
       return jax.tree.map(lambda *leaves: sum(leaves) / len(leaves), list_loss)

     loss_pert = jax.jit(_make_pert.make_perturbed_fun(

optax/transforms/_accumulation.py — 3 additions, 8 deletions

@@ -176,11 +176,8 @@ def skip_not_finite(
     - `num_not_finite`: total number of inf and NaN found in `updates`.
   """
   del gradient_step, params
-  all_is_finite = [
-      jnp.sum(jnp.logical_not(jnp.isfinite(p)))
-      for p in jax.tree.leaves(updates)
-  ]
-  num_not_finite = jnp.sum(jnp.array(all_is_finite))
+  not_finite = jax.tree.map(lambda x: ~jnp.isfinite(x), updates)
+  num_not_finite = optax.tree.sum(not_finite)
   should_skip = num_not_finite > 0
   return should_skip, {
       'should_skip': should_skip,
@@ -210,9 +207,7 @@ def skip_large_updates(
     - `norm_squared`: overall norm square of the `updates`.
   """
   del gradient_step, params
-  norm_sq = jnp.sum(
-      jnp.array([jnp.sum(p**2) for p in jax.tree.leaves(updates)])
-  )
+  norm_sq = optax.tree.norm(updates, squared=True)
   # This will also return True if `norm_sq` is NaN.
   should_skip = jnp.logical_not(norm_sq < max_squared_norm)
   return should_skip, {'should_skip': should_skip, 'norm_squared': norm_sq}
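
Both rewrites above fold a manual loop over jax.tree.leaves into a single tree reduction; a small sketch of the two quantities, assuming the squared=True argument of optax.tree.norm used in the diff (the trees below are illustrative):

import jax
import jax.numpy as jnp
import optax.tree

# skip_not_finite pattern: count inf/NaN entries across the whole tree.
updates = {'w': jnp.array([1.0, jnp.inf]), 'b': jnp.array([2.0])}
not_finite = jax.tree.map(lambda x: ~jnp.isfinite(x), updates)
num_not_finite = optax.tree.sum(not_finite)  # 1 here

# skip_large_updates pattern: squared global norm in one call.
finite_updates = {'w': jnp.array([1.0, 2.0]), 'b': jnp.array([2.0])}
norm_sq = optax.tree.norm(finite_updates, squared=True)  # 1 + 4 + 4 = 9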

optax/transforms/_clipping.py — 2 additions, 2 deletions

@@ -90,7 +90,7 @@ def clip_by_global_norm(max_norm: float) -> base.GradientTransformation:

   def update_fn(updates, state, params=None):
     del params
-    g_norm = linear_algebra.global_norm(updates)
+    g_norm = optax.tree.norm(updates)
     # TODO(b/163995078): revert back to the following (faster) implementation
     # once analyzed how it affects backprop through update (e.g. meta-gradients)
     # g_norm = jnp.maximum(max_norm, g_norm)
@@ -154,7 +154,7 @@ def per_example_global_norm_clip(
         " `grads` to have a batch dimension in the 0th axis."
     )

-  global_grad_norms = jax.vmap(linear_algebra.global_norm)(grads)
+  global_grad_norms = jax.vmap(optax.tree.norm)(grads)
   multipliers = jnp.nan_to_num(
       jnp.minimum(l2_norm_clip / global_grad_norms, 1.0), nan=1.0
   )
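
The per-example clipping path maps the tree norm over a leading batch axis; a sketch of that vmap usage, assuming per-example grads whose leaves all share the 0th (batch) dimension, as the function's error message requires (values are illustrative):

import jax
import jax.numpy as jnp
import optax.tree

# Per-example gradients: every leaf carries the batch dimension first (batch = 2).
grads = {
    'w': jnp.array([[3.0, 4.0], [0.0, 1.0]]),
    'b': jnp.array([[0.0], [0.0]]),
}

# vmap over the batch axis yields one global norm per example.
per_example_norms = jax.vmap(optax.tree.norm)(grads)  # [5.0, 1.0]

# Clipping multipliers analogous to per_example_global_norm_clip (NaN handling omitted).
l2_norm_clip = 1.0
multipliers = jnp.minimum(l2_norm_clip / per_example_norms, 1.0)  # [0.2, 1.0]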

optax/transforms/_clipping_test.py — 2 additions, 1 deletion

@@ -21,6 +21,7 @@
 import numpy as np
 from optax._src import linear_algebra
 from optax.transforms import _clipping
+import optax.tree


 STEPS = 50
@@ -71,7 +72,7 @@ def test_clip_by_global_norm(self):
       # Check that the clipper actually works and global norm is <= max_norm
       updates, _ = clipper.update(updates, None)
       self.assertAlmostEqual(
-          linear_algebra.global_norm(updates), 1.0 / i, places=6
+          optax.tree.norm(updates), 1.0 / i, places=6
       )
       # Check that continuously clipping won't cause numerical issues.
       updates_step, _ = clipper.update(self.per_step_updates, None)
