diff --git a/optax/projections/_projections_test.py b/optax/projections/_projections_test.py index 2f4ea6a48..26a6a027e 100644 --- a/optax/projections/_projections_test.py +++ b/optax/projections/_projections_test.py @@ -244,7 +244,7 @@ def test_projection_vector(self): a = (3.0, 4.0) y_actual = proj.projection_vector(x, a) y_expected = (33 / 25, 44 / 25) - assert tree_allclose(y_actual, y_expected) + assert optax.tree.allclose(y_actual, y_expected) def test_projection_hyperplane(self): x = (1.0, 2.0) @@ -252,9 +252,7 @@ def test_projection_hyperplane(self): b = 5.0 y_actual = proj.projection_hyperplane(x, a, b) y_expected = (7 / 25, 26 / 25) - print(y_actual) - print(y_expected) - assert tree_allclose(y_actual, y_expected) + assert optax.tree.allclose(y_actual, y_expected) def test_projection_halfspace_1(self): x = (1.0, 2.0) @@ -262,9 +260,7 @@ def test_projection_halfspace_1(self): b = 5.0 y_actual = proj.projection_halfspace(x, a, b) y_expected = (7 / 25, 26 / 25) - print(y_actual) - print(y_expected) - assert tree_allclose(y_actual, y_expected) + assert optax.tree.allclose(y_actual, y_expected) def test_projection_halfspace_2(self): x = (1.0, -2.0) @@ -272,15 +268,7 @@ def test_projection_halfspace_2(self): b = 5.0 y_actual = proj.projection_halfspace(x, a, b) y_expected = x - print(y_actual) - print(y_expected) - assert tree_allclose(y_actual, y_expected) - - -def tree_allclose(a, b, rtol=1e-05, atol=1e-08, equal_nan=False): - # Replace this with optax.tree.allclose, once that's added. - return all(jax.tree.leaves(jax.tree.map(lambda a, b: jnp.allclose( - a, b, rtol=rtol, atol=atol, equal_nan=equal_nan), a, b))) + assert optax.tree.allclose(y_actual, y_expected) if __name__ == '__main__':