Commit 5db134c

Author: OptaxDev
Merge pull request #1368 from carlosgmartin:deprecate_optax_global_norm
PiperOrigin-RevId: 830524076
Parents: 912adf2 + f2dc8d3

File tree: 6 files changed (+21, -16 lines)

optax/_src/linear_algebra.py

Lines changed: 14 additions & 4 deletions
@@ -17,13 +17,13 @@
 from collections.abc import Callable
 import functools
 from typing import Optional, Union
+import warnings

 import chex
 import jax
 from jax import lax
 import jax.numpy as jnp
 from optax._src import base
-from optax._src import numerics
 import optax.tree


@@ -33,10 +33,20 @@ def _normalize_tree(x):


 def global_norm(updates: base.PyTree) -> chex.Array:
-  """Compute the global norm across a nested structure of tensors."""
-  return jnp.sqrt(
-      sum(jnp.sum(numerics.abs_sq(x)) for x in jax.tree.leaves(updates))
+  """Compute the global norm across a nested structure of tensors.
+
+  .. warning::
+    Deprecated in favor of :func:`optax.tree.norm`.
+  Args:
+    updates: A nested structure of tensors.
+  Returns:
+    The global L2 norm of the updates.
+  """
+  warnings.warn(
+      'optax.global_norm is deprecated in favor of optax.tree.norm',
+      DeprecationWarning
   )
+  return optax.tree.norm(updates)


 def _power_iteration_cond_fun(error_tolerance, num_iters, loop_vars):
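For context, a minimal migration sketch (not part of the commit): after this change `optax.global_norm` simply forwards to `optax.tree.norm` and emits a `DeprecationWarning`, so both calls return the same global L2 norm. The example values below are made up.

```python
# Minimal migration sketch: optax.global_norm now forwards to optax.tree.norm
# and emits a DeprecationWarning, so both calls return the same value.
import warnings

import jax.numpy as jnp
import optax
import optax.tree

params = {'w': jnp.ones((2, 3)), 'b': jnp.arange(3.0)}

new_norm = optax.tree.norm(params)  # preferred API

with warnings.catch_warnings(record=True):
  warnings.simplefilter('always')
  old_norm = optax.global_norm(params)  # deprecated alias, still exported

assert jnp.allclose(old_norm, new_norm)
print(float(new_norm))  # sqrt(6*1**2 + 0**2 + 1**2 + 2**2) = sqrt(11)
```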

optax/_src/linear_algebra_test.py

Lines changed: 1 addition & 1 deletion
@@ -53,7 +53,7 @@ def test_global_norm(self):
     }
     np.testing.assert_array_equal(
         jnp.sqrt(jnp.sum(flat_updates**2)),
-        linear_algebra.global_norm(nested_updates),
+        optax.tree.norm(nested_updates),
     )

   def test_power_iteration_cond_fun(self, dim=6):

optax/_src/utils.py

Lines changed: 0 additions & 2 deletions
@@ -26,7 +26,6 @@
 import jax.numpy as jnp
 import jax.scipy.stats.norm as multivariate_normal
 from optax._src import base
-from optax._src import linear_algebra
 from optax._src import numerics
 import optax.tree

@@ -345,4 +344,3 @@ def _value_and_grad(
 # TODO(b/183800387): remove legacy aliases.
 safe_norm = numerics.safe_norm
 safe_int32_increment = numerics.safe_int32_increment
-global_norm = linear_algebra.global_norm

optax/contrib/_sam.py

Lines changed: 2 additions & 2 deletions
@@ -54,7 +54,7 @@
 import jax.numpy as jnp
 from optax._src import base
 from optax._src import update
-from optax._src import utils
+import optax.tree

 # As a helper for SAM we need a gradient normalizing transformation.

@@ -74,7 +74,7 @@ def init_fn(params):

   def update_fn(updates, state, params=None):
     del params
-    g_norm = utils.global_norm(updates)
+    g_norm = optax.tree.norm(updates)
     updates = jax.tree.map(lambda g: g / g_norm, updates)
     return updates, state
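The SAM helper touched above rescales a gradient pytree by its global norm. Below is a standalone sketch of that pattern using the public API; the name `normalize_by_global_norm` and the `eps` guard are illustrative additions, not optax APIs.

```python
# Sketch of a gradient-normalizing transformation in the style of the SAM
# helper above. The name and the eps guard are illustrative additions.
import jax
import jax.numpy as jnp
import optax
import optax.tree


def normalize_by_global_norm(eps: float = 1e-12) -> optax.GradientTransformation:
  """Rescales updates so their global L2 norm is (approximately) 1."""

  def init_fn(params):
    del params
    return optax.EmptyState()

  def update_fn(updates, state, params=None):
    del params
    g_norm = optax.tree.norm(updates)
    # eps avoids a division by zero on an all-zero gradient tree.
    updates = jax.tree.map(lambda g: g / (g_norm + eps), updates)
    return updates, state

  return optax.GradientTransformation(init_fn, update_fn)


grads = {'w': jnp.full((2, 2), 3.0), 'b': jnp.zeros(2)}
tx = normalize_by_global_norm()
normalized, _ = tx.update(grads, tx.init(grads))
assert jnp.allclose(optax.tree.norm(normalized), 1.0)
```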

optax/transforms/_clipping.py

Lines changed: 2 additions & 3 deletions
@@ -24,7 +24,6 @@
 import jax
 import jax.numpy as jnp
 from optax._src import base
-from optax._src import linear_algebra
 from optax._src import numerics
 import optax.tree

@@ -90,7 +89,7 @@ def clip_by_global_norm(max_norm: float) -> base.GradientTransformation:

   def update_fn(updates, state, params=None):
     del params
-    g_norm = linear_algebra.global_norm(updates)
+    g_norm = optax.tree.norm(updates)
     # TODO(b/163995078): revert back to the following (faster) implementation
     # once analyzed how it affects backprop through update (e.g. meta-gradients)
     # g_norm = jnp.maximum(max_norm, g_norm)

@@ -154,7 +153,7 @@ def per_example_global_norm_clip(
         " `grads` to have a batch dimension in the 0th axis."
     )

-  global_grad_norms = jax.vmap(linear_algebra.global_norm)(grads)
+  global_grad_norms = jax.vmap(optax.tree.norm)(grads)
   multipliers = jnp.nan_to_num(
       jnp.minimum(l2_norm_clip / global_grad_norms, 1.0), nan=1.0
   )
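A short sketch (illustrative names and values, not from the commit) of how the per-example norms returned by `jax.vmap(optax.tree.norm)` turn into clipping multipliers; each leaf is assumed to carry the batch dimension on axis 0, as `per_example_global_norm_clip` requires.

```python
# Per-example clipping sketch: norms are computed per batch row with
# jax.vmap(optax.tree.norm), then each row is scaled so its norm is at most
# l2_norm_clip. Names and values here are illustrative.
import jax
import jax.numpy as jnp
import optax.tree

l2_norm_clip = 1.0
grads = {'w': jnp.ones((4, 3)), 'b': jnp.linspace(0.0, 3.0, 4)}  # batch of 4

per_example_norms = jax.vmap(optax.tree.norm)(grads)  # shape (4,)
multipliers = jnp.nan_to_num(
    jnp.minimum(l2_norm_clip / per_example_norms, 1.0), nan=1.0
)
clipped = jax.tree.map(
    lambda g: g * multipliers.reshape((-1,) + (1,) * (g.ndim - 1)), grads
)
assert jnp.all(jax.vmap(optax.tree.norm)(clipped) <= l2_norm_clip + 1e-6)
```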

optax/transforms/_clipping_test.py

Lines changed: 2 additions & 4 deletions
@@ -19,8 +19,8 @@
 import jax
 import jax.numpy as jnp
 import numpy as np
-from optax._src import linear_algebra
 from optax.transforms import _clipping
+import optax.tree


 STEPS = 50

@@ -70,9 +70,7 @@ def test_clip_by_global_norm(self):
       clipper = _clipping.clip_by_global_norm(1.0 / i)
       # Check that the clipper actually works and global norm is <= max_norm
       updates, _ = clipper.update(updates, None)
-      self.assertAlmostEqual(
-          linear_algebra.global_norm(updates), 1.0 / i, places=6
-      )
+      self.assertAlmostEqual(optax.tree.norm(updates), 1.0 / i, places=6)
       # Check that continuously clipping won't cause numerical issues.
       updates_step, _ = clipper.update(self.per_step_updates, None)
       chex.assert_trees_all_close(updates, updates_step)
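As a usage note, the behaviour the updated test asserts can be reproduced with the public API; a minimal sketch with made-up update values:

```python
# clip_by_global_norm keeps the joint L2 norm of all update leaves at or below
# max_norm; the test above verifies exactly this with optax.tree.norm.
import jax.numpy as jnp
import optax
import optax.tree

updates = {'w': jnp.full((3,), 10.0), 'b': jnp.full((2,), -10.0)}
clipper = optax.clip_by_global_norm(1.0)
state = clipper.init(updates)
clipped, _ = clipper.update(updates, state)
assert jnp.isclose(optax.tree.norm(clipped), 1.0, atol=1e-6)
```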
