6 files changed, +21 -16 lines changed
File 1 of 6

@@ -17,13 +17,13 @@
 from collections.abc import Callable
 import functools
 from typing import Optional, Union
+import warnings

 import chex
 import jax
 from jax import lax
 import jax.numpy as jnp
 from optax._src import base
-from optax._src import numerics
 import optax.tree


@@ -33,10 +33,20 @@ def _normalize_tree(x):


 def global_norm(updates: base.PyTree) -> chex.Array:
-  """Compute the global norm across a nested structure of tensors."""
-  return jnp.sqrt(
-      sum(jnp.sum(numerics.abs_sq(x)) for x in jax.tree.leaves(updates))
+  """Compute the global norm across a nested structure of tensors.
+
+  .. warning::
+    Deprecated in favor of :func:`optax.tree.norm`.
+  Args:
+    updates: A nested structure of tensors.
+  Returns:
+    The global L2 norm of the updates.
+  """
+  warnings.warn(
+      'optax.global_norm is deprecated in favor of optax.tree.norm',
+      DeprecationWarning
   )
+  return optax.tree.norm(updates)


 def _power_iteration_cond_fun(error_tolerance, num_iters, loop_vars):
File 2 of 6

@@ -53,7 +53,7 @@ def test_global_norm(self):
     }
     np.testing.assert_array_equal(
         jnp.sqrt(jnp.sum(flat_updates ** 2)),
-        linear_algebra.global_norm(nested_updates),
+        optax.tree.norm(nested_updates),
     )

   def test_power_iteration_cond_fun(self, dim=6):
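The test pins down the invariant behind the swap: the norm of a nested structure equals the norm of its flattened leaves. A standalone check of that equivalence (the tree contents here are made up; `jax.flatten_util.ravel_pytree` is used to flatten):

import jax.numpy as jnp
import optax
from jax.flatten_util import ravel_pytree

nested_updates = {'a': jnp.array([1.0, 2.0]), 'b': {'c': jnp.array([2.0])}}
flat_updates, _ = ravel_pytree(nested_updates)  # all leaves, concatenated

# Both sides compute sqrt(1 + 4 + 4) = 3.0.
assert jnp.allclose(jnp.sqrt(jnp.sum(flat_updates ** 2)),
                    optax.tree.norm(nested_updates))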
File 3 of 6

@@ -26,7 +26,6 @@
 import jax.numpy as jnp
 import jax.scipy.stats.norm as multivariate_normal
 from optax._src import base
-from optax._src import linear_algebra
 from optax._src import numerics
 import optax.tree

@@ -345,4 +344,3 @@ def _value_and_grad(
 # TODO(b/183800387): remove legacy aliases.
 safe_norm = numerics.safe_norm
 safe_int32_increment = numerics.safe_int32_increment
-global_norm = linear_algebra.global_norm
File 4 of 6

@@ -54,7 +54,7 @@
 import jax.numpy as jnp
 from optax._src import base
 from optax._src import update
-from optax._src import utils
+import optax.tree

 # As a helper for SAM we need a gradient normalizing transformation.

@@ -74,7 +74,7 @@ def init_fn(params):

   def update_fn(updates, state, params=None):
     del params
-    g_norm = utils.global_norm(updates)
+    g_norm = optax.tree.norm(updates)
     updates = jax.tree.map(lambda g: g / g_norm, updates)
     return updates, state

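For context, this is the SAM helper's gradient-normalizing transformation: it rescales every leaf so the update tree has unit global norm. A self-contained sketch of the same pattern as an `optax.GradientTransformation`; only the `update_fn` body appears in the diff, the surrounding `normalize` wrapper is reconstructed for illustration:

import jax
import jax.numpy as jnp
import optax

def normalize() -> optax.GradientTransformation:
  """Scales the update tree to unit global (L2) norm."""

  def init_fn(params):
    del params
    return optax.EmptyState()

  def update_fn(updates, state, params=None):
    del params
    g_norm = optax.tree.norm(updates)
    updates = jax.tree.map(lambda g: g / g_norm, updates)
    return updates, state

  return optax.GradientTransformation(init_fn, update_fn)

grads = {'w': jnp.array([3.0, 4.0]), 'b': jnp.array([12.0])}  # norm = 13
tx = normalize()
normed, _ = tx.update(grads, tx.init(grads))
print(optax.tree.norm(normed))  # ~1.0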
File 5 of 6

@@ -24,7 +24,6 @@
 import jax
 import jax.numpy as jnp
 from optax._src import base
-from optax._src import linear_algebra
 from optax._src import numerics
 import optax.tree

@@ -90,7 +89,7 @@ def clip_by_global_norm(max_norm: float) -> base.GradientTransformation:

   def update_fn(updates, state, params=None):
     del params
-    g_norm = linear_algebra.global_norm(updates)
+    g_norm = optax.tree.norm(updates)
     # TODO(b/163995078): revert back to the following (faster) implementation
     # once analyzed how it affects backprop through update (e.g. meta-gradients)
     # g_norm = jnp.maximum(max_norm, g_norm)
@@ -154,7 +153,7 @@ def per_example_global_norm_clip(
         " `grads` to have a batch dimension in the 0th axis."
     )

-  global_grad_norms = jax.vmap(linear_algebra.global_norm)(grads)
+  global_grad_norms = jax.vmap(optax.tree.norm)(grads)
   multipliers = jnp.nan_to_num(
       jnp.minimum(l2_norm_clip / global_grad_norms, 1.0), nan=1.0
   )
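Both call sites are behavior-preserving swaps. A short usage sketch of the two patterns touched here, clipping via the public `optax.clip_by_global_norm` and per-example norms via `jax.vmap` (array values are made up for illustration):

import jax
import jax.numpy as jnp
import optax

grads = {'w': jnp.array([30.0, 40.0])}  # global norm = 50

clipper = optax.clip_by_global_norm(1.0)
clipped, _ = clipper.update(grads, clipper.init(grads))
print(optax.tree.norm(clipped))  # clipped down to 1.0

# Per-example norms: vmap over the leading batch axis, one norm per example.
batched = {'w': jnp.array([[3.0, 4.0], [6.0, 8.0]])}
print(jax.vmap(optax.tree.norm)(batched))  # [ 5. 10.]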
File 6 of 6

@@ -19,8 +19,8 @@
 import jax
 import jax.numpy as jnp
 import numpy as np
-from optax._src import linear_algebra
 from optax.transforms import _clipping
+import optax.tree


 STEPS = 50
@@ -70,9 +70,7 @@ def test_clip_by_global_norm(self):
       clipper = _clipping.clip_by_global_norm(1.0 / i)
       # Check that the clipper actually works and global norm is <= max_norm
       updates, _ = clipper.update(updates, None)
-      self.assertAlmostEqual(
-          linear_algebra.global_norm(updates), 1.0 / i, places=6
-      )
+      self.assertAlmostEqual(optax.tree.norm(updates), 1.0 / i, places=6)
       # Check that continuously clipping won't cause numerical issues.
       updates_step, _ = clipper.update(self.per_step_updates, None)
       chex.assert_trees_all_close(updates, updates_step)