From 7e8132ea8ef86afb04581ec2d6a58de90c934607 Mon Sep 17 00:00:00 2001
From: carlosgmartin
Date: Tue, 15 Jul 2025 13:20:27 -0400
Subject: [PATCH] Fix tree_min and tree_max to handle zero-size array leafs.

---
 optax/tree_utils/_tree_math.py      | 20 ++++++++++++++------
 optax/tree_utils/_tree_math_test.py | 16 ++++++++++++++++
 2 files changed, 30 insertions(+), 6 deletions(-)

diff --git a/optax/tree_utils/_tree_math.py b/optax/tree_utils/_tree_math.py
index 9387a0976..d0de50470 100644
--- a/optax/tree_utils/_tree_math.py
+++ b/optax/tree_utils/_tree_math.py
@@ -176,9 +176,13 @@ def tree_max(tree: Any) -> chex.Numeric:
   Returns:
     a scalar value.
   """
-  maxes = jax.tree.map(jnp.max, tree)
-  # initializer=-jnp.inf should work but pytype wants a jax.Array.
-  return jax.tree.reduce(jnp.maximum, maxes, initializer=jnp.array(-jnp.inf))
+  def f(array):
+    if jnp.size(array) == 0:
+      return None
+    else:
+      return jnp.max(array)
+  maxes = jax.tree.map(f, tree)
+  return jax.tree.reduce(jnp.maximum, maxes, initializer=-float("inf"))
 
 
 def tree_min(tree: Any) -> chex.Numeric:
@@ -190,9 +194,13 @@ def tree_min(tree: Any) -> chex.Numeric:
   Returns:
     a scalar value.
   """
-  mins = jax.tree.map(jnp.min, tree)
-  # initializer=jnp.inf should work but pytype wants a jax.Array.
-  return jax.tree.reduce(jnp.minimum, mins, initializer=jnp.array(jnp.inf))
+  def f(array):
+    if jnp.size(array) == 0:
+      return None
+    else:
+      return jnp.min(array)
+  mins = jax.tree.map(f, tree)
+  return jax.tree.reduce(jnp.minimum, mins, initializer=float("inf"))
 
 
 def tree_size(tree: Any) -> int:
diff --git a/optax/tree_utils/_tree_math_test.py b/optax/tree_utils/_tree_math_test.py
index f17857423..bd860220e 100644
--- a/optax/tree_utils/_tree_math_test.py
+++ b/optax/tree_utils/_tree_math_test.py
@@ -162,6 +162,22 @@ def test_tree_min(self, key):
     got = tu.tree_min(tree)
     np.testing.assert_allclose(expected, got)
 
+  def test_tree_min_empty(self):
+    tree = [jnp.ones([2, 3]), jnp.zeros([4, 0, 5])]
+    got = tu.tree_min(tree)
+    expected = 1.0
+    assert expected == got
+
+  @parameterized.product(
+      expected=(0, 10000, -10000, float("inf"), -float("inf")),
+      dtype=('int8', 'uint8', 'float32'),
+      which=(tu.tree_min, tu.tree_max),
+  )
+  def test_tree_max_min_empty_dtype(self, expected, dtype, which):
+    tree = [expected, jnp.zeros(0, dtype)]
+    got = which(tree)
+    assert expected == got
+
   @parameterized.parameters(
       'array_a', 'tree_a', 'tree_a_dict', 'tree_b', 'tree_b_dict'
   )