Skip to content

Commit e0f0add

Browse files
committed
Minor improvements and cleanup using available tree utilities.
1 parent adcaca1 commit e0f0add

File tree

6 files changed

+10
-21
lines changed

6 files changed

+10
-21
lines changed

docs/api/utilities.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ Power iteration
6161
.. autofunction:: power_iteration
6262

6363
Non-negative least squares
64-
~~~~~~~~~~~~~~~
64+
~~~~~~~~~~~~~~~~~~~~~~~~~~
6565
.. autofunction:: nnls
6666

6767

examples/perturbations.ipynb

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,6 @@
4141
"source": [
4242
"import jax\n",
4343
"import jax.numpy as jnp\n",
44-
"import operator\n",
4544
"from jax import tree_util as jtu\n",
4645
"\n",
4746
"import optax.tree\n",
@@ -773,7 +772,7 @@
773772
" pert_softmax = pert_argmax_fun(rng, inputs)\n",
774773
" argmax = argmax_tree(inputs)\n",
775774
" diffs = jax.tree.map(lambda x, y: jnp.sum((x - y) ** 2 / 4), argmax, pert_softmax)\n",
776-
" return jax.tree.reduce(operator.add, diffs)"
775+
" return optax.tree.sum(diffs)"
777776
]
778777
},
779778
{

optax/contrib/_sophia.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -157,10 +157,8 @@ def update_fn(updates, state: SophiaState, params=None, **hess_fn_kwargs):
157157
lambda m, h: m / jnp.maximum(gamma * h, eps), mu_hat, state.nu
158158
)
159159
if clip_threshold is not None:
160-
sum_not_clipped = jax.tree.reduce(
161-
lambda x, y: x + y,
162-
jax.tree.map(lambda u: jnp.sum(jnp.abs(u) < clip_threshold), updates),
163-
)
160+
not_clipped = jax.tree.map(lambda u: jnp.abs(u) < clip_threshold, updates)
161+
sum_not_clipped = optax.tree.sum(not_clipped)
164162
if verbose:
165163
win_rate = sum_not_clipped / optax.tree.size(updates)
166164
jax.lax.cond(

optax/perturbations/_make_pert_test.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
"""Tests for optax.perturbations, checking values and gradients."""
1717

1818
from functools import partial # pylint: disable=g-importing-member
19-
import operator
2019

2120
from absl.testing import absltest
2221
from absl.testing import parameterized
@@ -159,7 +158,7 @@ def loss(tree):
159158
pred = apply_element_tree(tree)
160159
pred_true = apply_element_tree(example_tree)
161160
tree_loss = jax.tree.map(lambda x, y: (x - y) ** 2, pred, pred_true)
162-
list_loss = jax.tree.reduce(operator.add, tree_loss)
161+
list_loss = optax.tree.sum(tree_loss)
163162
return jax.tree.map(lambda *leaves: sum(leaves) / len(leaves), list_loss)
164163

165164
loss_pert = jax.jit(_make_pert.make_perturbed_fun(

optax/transforms/_accumulation.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -176,11 +176,8 @@ def skip_not_finite(
176176
- `num_not_finite`: total number of inf and NaN found in `updates`.
177177
"""
178178
del gradient_step, params
179-
all_is_finite = [
180-
jnp.sum(jnp.logical_not(jnp.isfinite(p)))
181-
for p in jax.tree.leaves(updates)
182-
]
183-
num_not_finite = jnp.sum(jnp.array(all_is_finite))
179+
not_finite = jax.tree.map(lambda x: ~jnp.isfinite(x), updates)
180+
num_not_finite = optax.tree.sum(not_finite)
184181
should_skip = num_not_finite > 0
185182
return should_skip, {
186183
'should_skip': should_skip,
@@ -210,9 +207,7 @@ def skip_large_updates(
210207
- `norm_squared`: overall norm square of the `updates`.
211208
"""
212209
del gradient_step, params
213-
norm_sq = jnp.sum(
214-
jnp.array([jnp.sum(p**2) for p in jax.tree.leaves(updates)])
215-
)
210+
norm_sq = optax.tree.norm(updates, squared=True)
216211
# This will also return True if `norm_sq` is NaN.
217212
should_skip = jnp.logical_not(norm_sq < max_squared_norm)
218213
return should_skip, {'should_skip': should_skip, 'norm_squared': norm_sq}

optax/tree_utils/_tree_math.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,7 @@ def tree_vdot(tree_x: Any, tree_y: Any) -> chex.Numeric:
151151
numerical issues.
152152
"""
153153
vdots = jax.tree.map(_vdot_safe, tree_x, tree_y)
154-
return jax.tree.reduce(operator.add, vdots, initializer=0)
154+
return tree_sum(vdots)
155155

156156

157157
def tree_sum(tree: Any) -> chex.Numeric:
@@ -450,6 +450,4 @@ def tree_allclose(
450450
def f(a, b):
451451
return jnp.allclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan)
452452
tree = jax.tree.map(f, a, b)
453-
leaves = jax.tree.leaves(tree)
454-
result = functools.reduce(operator.and_, leaves, True)
455-
return result
453+
return jax.tree.reduce(operator.and_, tree, True)

0 commit comments

Comments
 (0)