Skip to content

Commit 7e8132e

Browse files
committed
Fix tree_min and tree_max to handle zero-size array leafs.
1 parent a02de84 commit 7e8132e

File tree

2 files changed

+30
-6
lines changed

2 files changed

+30
-6
lines changed

optax/tree_utils/_tree_math.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -176,9 +176,13 @@ def tree_max(tree: Any) -> chex.Numeric:
176176
Returns:
177177
a scalar value.
178178
"""
179-
maxes = jax.tree.map(jnp.max, tree)
180-
# initializer=-jnp.inf should work but pytype wants a jax.Array.
181-
return jax.tree.reduce(jnp.maximum, maxes, initializer=jnp.array(-jnp.inf))
179+
def f(array):
180+
if jnp.size(array) == 0:
181+
return None
182+
else:
183+
return jnp.max(array)
184+
maxes = jax.tree.map(f, tree)
185+
return jax.tree.reduce(jnp.maximum, maxes, initializer=-float("inf"))
182186

183187

184188
def tree_min(tree: Any) -> chex.Numeric:
@@ -190,9 +194,13 @@ def tree_min(tree: Any) -> chex.Numeric:
190194
Returns:
191195
a scalar value.
192196
"""
193-
mins = jax.tree.map(jnp.min, tree)
194-
# initializer=jnp.inf should work but pytype wants a jax.Array.
195-
return jax.tree.reduce(jnp.minimum, mins, initializer=jnp.array(jnp.inf))
197+
def f(array):
198+
if jnp.size(array) == 0:
199+
return None
200+
else:
201+
return jnp.min(array)
202+
mins = jax.tree.map(f, tree)
203+
return jax.tree.reduce(jnp.minimum, mins, initializer=float("inf"))
196204

197205

198206
def tree_size(tree: Any) -> int:

optax/tree_utils/_tree_math_test.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,22 @@ def test_tree_min(self, key):
162162
got = tu.tree_min(tree)
163163
np.testing.assert_allclose(expected, got)
164164

165+
def test_tree_min_empty(self):
166+
tree = [jnp.ones([2, 3]), jnp.zeros([4, 0, 5])]
167+
got = tu.tree_min(tree)
168+
expected = 1.0
169+
assert expected == got
170+
171+
@parameterized.product(
172+
expected=(0, 10000, -10000, float("inf"), -float("inf")),
173+
dtype=('int8', 'uint8', 'float32'),
174+
which=(tu.tree_min, tu.tree_max),
175+
)
176+
def test_tree_max_min_empty_dtype(self, expected, dtype, which):
177+
tree = [expected, jnp.zeros(0, dtype)]
178+
got = which(tree)
179+
assert expected == got
180+
165181
@parameterized.parameters(
166182
'array_a', 'tree_a', 'tree_a_dict', 'tree_b', 'tree_b_dict'
167183
)

0 commit comments

Comments
 (0)