
Commit a571029

Minor improvements and cleanup using available tree utilities.
1 parent a02de84 commit a571029

8 files changed (+16, -30 lines)

docs/api/utilities.rst

Lines changed: 1 addition & 1 deletion
@@ -61,7 +61,7 @@ Power iteration
 .. autofunction:: power_iteration

 Non-negative least squares
-~~~~~~~~~~~~~~~
+~~~~~~~~~~~~~~~~~~~~~~~~~~
 .. autofunction:: nnls


examples/perturbations.ipynb

Lines changed: 1 addition & 2 deletions
@@ -41,7 +41,6 @@
 "source": [
 "import jax\n",
 "import jax.numpy as jnp\n",
-"import operator\n",
 "from jax import tree_util as jtu\n",
 "\n",
 "import optax.tree\n",
@@ -773,7 +772,7 @@
 " pert_softmax = pert_argmax_fun(rng, inputs)\n",
 " argmax = argmax_tree(inputs)\n",
 " diffs = jax.tree.map(lambda x, y: jnp.sum((x - y) ** 2 / 4), argmax, pert_softmax)\n",
-" return jax.tree.reduce(operator.add, diffs)"
+" return optax.tree.sum(diffs)"
 ]
 },
 {

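For reference, optax.tree.sum collapses every leaf of a pytree into one scalar, which is what the removed jax.tree.reduce(operator.add, ...) call did by hand. A minimal sketch of the equivalence, separate from the notebook itself (the example tree is made up):

import operator

import jax
import jax.numpy as jnp
import optax.tree

diffs = {'a': jnp.array(1.0), 'b': (jnp.array(2.0), jnp.array(3.0))}

# Old pattern: fold the leaves manually with operator.add.
manual_total = jax.tree.reduce(operator.add, diffs)
# New pattern: single utility call, no operator import needed.
tree_total = optax.tree.sum(diffs)

assert jnp.allclose(manual_total, tree_total)  # both equal 6.0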
optax/_src/utils_test.py

Lines changed: 3 additions & 3 deletions
@@ -26,6 +26,7 @@
 from optax._src import transform
 from optax._src import update
 from optax._src import utils
+import optax.tree


 def _shape_to_tuple(shape):
@@ -40,8 +41,7 @@ class ScaleGradientTest(parameterized.TestCase):
   def test_scale_gradient_pytree(self, scale):
     def fn(inputs):
       outputs = utils.scale_gradient(inputs, scale)
-      outputs = jax.tree.map(lambda x: x**2, outputs)
-      return sum(jax.tree.leaves(outputs))
+      return optax.tree.norm(outputs, squared=True)

     inputs = {'a': -1.0, 'b': {'c': (2.0,), 'd': 0.0}}

@@ -50,7 +50,7 @@ def fn(inputs):
     jax.tree.map(lambda i, g: self.assertEqual(g, 2 * i * scale), inputs, grads)
     self.assertEqual(
         fn(inputs),
-        sum(jax.tree.leaves(jax.tree.map(lambda x: x**2, inputs))),
+        optax.tree.norm(inputs, squared=True),
     )


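The replacement relies on optax.tree.norm(tree, squared=True) returning the global squared l2 norm, i.e. the sum of x**2 over all leaves, which is what the deleted map-then-sum computed. A small sketch of that identity, using the same toy pytree as the test (illustrative only, not part of the commit):

import jax
import jax.numpy as jnp
import optax.tree

inputs = {'a': -1.0, 'b': {'c': (2.0,), 'd': 0.0}}

# Old pattern: square each leaf, then sum the flattened leaves.
old_value = sum(jax.tree.leaves(jax.tree.map(lambda x: x**2, inputs)))
# New pattern: squared global norm in one call.
new_value = optax.tree.norm(inputs, squared=True)

assert jnp.allclose(old_value, new_value)  # both equal 5.0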
optax/contrib/_privacy_test.py

Lines changed: 3 additions & 6 deletions
@@ -20,6 +20,7 @@
 import jax
 import jax.numpy as jnp
 from optax.contrib import _privacy
+import optax.tree


 class DifferentiallyPrivateAggregateTest(chex.TestCase):
@@ -64,12 +65,8 @@ def test_clipping_norm(self, l2_norm_clip):
     state = dp_agg.init(self.params)
     update_fn = self.variant(dp_agg.update)

-    # Shape of the three arrays below is (self.batch_size, )
-    norms = [
-        jnp.linalg.norm(g.reshape(self.batch_size, -1), axis=1)
-        for g in jax.tree.leaves(self.per_eg_grads)
-    ]
-    global_norms = jnp.linalg.norm(jnp.stack(norms), axis=0)
+    global_norms = jax.vmap(optax.tree.norm)(self.per_eg_grads)
+
     divisors = jnp.maximum(global_norms / l2_norm_clip, 1.0)
     # Since the values of all the parameters are the same within each example,
     # we can easily compute what the values should be:

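The key observation behind this change is that jax.vmap(optax.tree.norm) maps the tree-wide l2 norm over the leading (batch) axis of every leaf, so it reproduces the removed reshape-stack-norm computation in one line. A rough sketch with made-up shapes and values (not taken from the test):

import jax
import jax.numpy as jnp
import optax.tree

batch_size = 4
per_eg_grads = {
    'w': jnp.arange(24.0).reshape(batch_size, 3, 2),
    'b': jnp.arange(8.0).reshape(batch_size, 2),
}

# Old pattern: per-leaf, per-example norms, then the norm of the stacked norms.
norms = [
    jnp.linalg.norm(g.reshape(batch_size, -1), axis=1)
    for g in jax.tree.leaves(per_eg_grads)
]
old_global_norms = jnp.linalg.norm(jnp.stack(norms), axis=0)

# New pattern: vmap the tree-wide norm over the batch axis.
new_global_norms = jax.vmap(optax.tree.norm)(per_eg_grads)

assert jnp.allclose(old_global_norms, new_global_norms)  # shape (batch_size,)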
optax/contrib/_sophia.py

Lines changed: 2 additions & 4 deletions
@@ -157,10 +157,8 @@ def update_fn(updates, state: SophiaState, params=None, **hess_fn_kwargs):
         lambda m, h: m / jnp.maximum(gamma * h, eps), mu_hat, state.nu
     )
     if clip_threshold is not None:
-      sum_not_clipped = jax.tree.reduce(
-          lambda x, y: x + y,
-          jax.tree.map(lambda u: jnp.sum(jnp.abs(u) < clip_threshold), updates),
-      )
+      not_clipped = jax.tree.map(lambda u: jnp.abs(u) < clip_threshold, updates)
+      sum_not_clipped = optax.tree.sum(not_clipped)
       if verbose:
         win_rate = sum_not_clipped / optax.tree.size(updates)
         jax.lax.cond(

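The rewritten Sophia block counts how many update entries stay below the clipping threshold: the boolean leaves sum as 0/1 under optax.tree.sum, and dividing by optax.tree.size gives the fraction reported as win_rate. A standalone sketch of that bookkeeping (threshold and values are invented):

import jax
import jax.numpy as jnp
import optax.tree

clip_threshold = 1.0
updates = {'w': jnp.array([0.5, -2.0, 0.1]), 'b': jnp.array([3.0])}

# Boolean tree marking entries whose magnitude is below the threshold.
not_clipped = jax.tree.map(lambda u: jnp.abs(u) < clip_threshold, updates)
sum_not_clipped = optax.tree.sum(not_clipped)            # 2 entries survive
win_rate = sum_not_clipped / optax.tree.size(updates)    # 2 / 4 = 0.5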
optax/perturbations/_make_pert_test.py

Lines changed: 1 addition & 2 deletions
@@ -16,7 +16,6 @@
 """Tests for optax.perturbations, checking values and gradients."""

 from functools import partial  # pylint: disable=g-importing-member
-import operator

 from absl.testing import absltest
 from absl.testing import parameterized
@@ -159,7 +158,7 @@ def loss(tree):
       pred = apply_element_tree(tree)
       pred_true = apply_element_tree(example_tree)
       tree_loss = jax.tree.map(lambda x, y: (x - y) ** 2, pred, pred_true)
-      list_loss = jax.tree.reduce(operator.add, tree_loss)
+      list_loss = optax.tree.sum(tree_loss)
       return jax.tree.map(lambda *leaves: sum(leaves) / len(leaves), list_loss)

     loss_pert = jax.jit(_make_pert.make_perturbed_fun(

optax/transforms/_accumulation.py

Lines changed: 3 additions & 8 deletions
@@ -176,11 +176,8 @@ def skip_not_finite(
     - `num_not_finite`: total number of inf and NaN found in `updates`.
   """
   del gradient_step, params
-  all_is_finite = [
-      jnp.sum(jnp.logical_not(jnp.isfinite(p)))
-      for p in jax.tree.leaves(updates)
-  ]
-  num_not_finite = jnp.sum(jnp.array(all_is_finite))
+  not_finite = jax.tree.map(lambda x: ~jnp.isfinite(x), updates)
+  num_not_finite = optax.tree.sum(not_finite)
   should_skip = num_not_finite > 0
   return should_skip, {
       'should_skip': should_skip,
@@ -210,9 +207,7 @@ def skip_large_updates(
     - `norm_squared`: overall norm square of the `updates`.
   """
   del gradient_step, params
-  norm_sq = jnp.sum(
-      jnp.array([jnp.sum(p**2) for p in jax.tree.leaves(updates)])
-  )
+  norm_sq = optax.tree.norm(updates, squared=True)
   # This will also return True if `norm_sq` is NaN.
   should_skip = jnp.logical_not(norm_sq < max_squared_norm)
   return should_skip, {'should_skip': should_skip, 'norm_squared': norm_sq}

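Both helpers now reduce to the same two utilities: summing a boolean tree with optax.tree.sum counts the non-finite entries, and optax.tree.norm(..., squared=True) yields the global squared norm used by the size check. A quick sketch on a toy pytree (values chosen only to trigger both paths):

import jax
import jax.numpy as jnp
import optax.tree

updates = {'w': jnp.array([1.0, jnp.nan]), 'b': jnp.array([jnp.inf, 2.0])}

# skip_not_finite-style count: one NaN plus one inf.
not_finite = jax.tree.map(lambda x: ~jnp.isfinite(x), updates)
num_not_finite = optax.tree.sum(not_finite)  # 2

# skip_large_updates-style norm: NaN propagates, so the skip condition still fires.
norm_sq = optax.tree.norm(updates, squared=True)
should_skip = jnp.logical_not(norm_sq < 100.0)  # True, because norm_sq is NaN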
optax/tree_utils/_tree_math.py

Lines changed: 2 additions & 4 deletions
@@ -151,7 +151,7 @@ def tree_vdot(tree_x: Any, tree_y: Any) -> chex.Numeric:
     numerical issues.
   """
   vdots = jax.tree.map(_vdot_safe, tree_x, tree_y)
-  return jax.tree.reduce(operator.add, vdots, initializer=0)
+  return tree_sum(vdots)


 def tree_sum(tree: Any) -> chex.Numeric:
@@ -450,6 +450,4 @@
   def f(a, b):
     return jnp.allclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan)
   tree = jax.tree.map(f, a, b)
-  leaves = jax.tree.leaves(tree)
-  result = functools.reduce(operator.and_, leaves, True)
-  return result
+  return jax.tree.reduce(operator.and_, tree, True)

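Inside the tree utilities themselves, tree_vdot now delegates to tree_sum, and tree_allclose folds its boolean tree with a single jax.tree.reduce call whose True initializer makes an empty tree compare as all-close. A brief sketch of that reduce pattern, outside of optax (the example tree is arbitrary):

import operator

import jax
import jax.numpy as jnp

tree = {'a': jnp.array(True), 'b': (jnp.array(True), jnp.array(False))}

# Fold the leaves with logical-and; the initializer handles the empty-tree case.
all_close = jax.tree.reduce(operator.and_, tree, True)   # array(False)
empty_case = jax.tree.reduce(operator.and_, {}, True)    # True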