Skip to content

Commit ccaff21

Browse files
vrouletOptaxDev
authored andcommitted
Remove use of cast_tree in favor of optax.tree.cast
PiperOrigin-RevId: 830951266
1 parent cc751e3 commit ccaff21

File tree

2 files changed

+0
-21
lines changed

2 files changed

+0
-21
lines changed

optax/_src/transform.py

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424
from optax._src import base
2525
from optax._src import numerics
2626
from optax._src import utils
27-
from optax._src.deprecations import warn_deprecated_function # pylint: disable=g-importing-member
2827
from optax.transforms import _accumulation
2928
from optax.transforms import _adding
3029
import optax.tree
@@ -1812,16 +1811,6 @@ def update_fn(
18121811

18131812
### Legacy symbols to be removed. ###
18141813

1815-
1816-
@functools.partial(
1817-
warn_deprecated_function, replacement='optax.tree.cast'
1818-
)
1819-
def cast_tree(
1820-
tree: chex.ArrayTree, dtype: Optional[chex.ArrayDType]
1821-
) -> chex.ArrayTree:
1822-
return optax.tree.cast(tree, dtype)
1823-
1824-
18251814
trace = _accumulation.trace
18261815
TraceState = _accumulation.TraceState
18271816
ema = _accumulation.ema

optax/_src/utils.py

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@
2727
import jax.scipy.stats.norm as multivariate_normal
2828
from optax._src import base
2929
from optax._src import numerics
30-
from optax._src.deprecations import warn_deprecated_function # pylint: disable=g-importing-member
3130
import optax.tree
3231

3332

@@ -55,15 +54,6 @@ def canonicalize_key(key_or_seed: jax.Array | int) -> jax.Array:
5554
return jax.random.key(key_or_seed)
5655

5756

58-
@functools.partial(
59-
warn_deprecated_function, replacement='optax.tree.cast'
60-
)
61-
def cast_tree(
62-
tree: chex.ArrayTree, dtype: Optional[chex.ArrayDType]
63-
) -> chex.ArrayTree:
64-
return optax.tree.cast(tree, dtype)
65-
66-
6757
def set_diags(a: jax.Array, new_diags: chex.Array) -> chex.Array:
6858
"""Set the diagonals of every DxD matrix in an input of shape NxDxD.
6959

0 commit comments

Comments
 (0)