File tree Expand file tree Collapse file tree 2 files changed +0
-21
lines changed Expand file tree Collapse file tree 2 files changed +0
-21
lines changed Original file line number Diff line number Diff line change 2424from optax ._src import base
2525from optax ._src import numerics
2626from optax ._src import utils
27- from optax ._src .deprecations import warn_deprecated_function # pylint: disable=g-importing-member
2827from optax .transforms import _accumulation
2928from optax .transforms import _adding
3029import optax .tree
@@ -1812,16 +1811,6 @@ def update_fn(
18121811
18131812### Legacy symbols to be removed. ###
18141813
1815-
1816- @functools .partial (
1817- warn_deprecated_function , replacement = 'optax.tree.cast'
1818- )
1819- def cast_tree (
1820- tree : chex .ArrayTree , dtype : Optional [chex .ArrayDType ]
1821- ) -> chex .ArrayTree :
1822- return optax .tree .cast (tree , dtype )
1823-
1824-
18251814trace = _accumulation .trace
18261815TraceState = _accumulation .TraceState
18271816ema = _accumulation .ema
Original file line number Diff line number Diff line change 2727import jax .scipy .stats .norm as multivariate_normal
2828from optax ._src import base
2929from optax ._src import numerics
30- from optax ._src .deprecations import warn_deprecated_function # pylint: disable=g-importing-member
3130import optax .tree
3231
3332
@@ -55,15 +54,6 @@ def canonicalize_key(key_or_seed: jax.Array | int) -> jax.Array:
5554 return jax .random .key (key_or_seed )
5655
5756
58- @functools .partial (
59- warn_deprecated_function , replacement = 'optax.tree.cast'
60- )
61- def cast_tree (
62- tree : chex .ArrayTree , dtype : Optional [chex .ArrayDType ]
63- ) -> chex .ArrayTree :
64- return optax .tree .cast (tree , dtype )
65-
66-
6757def set_diags (a : jax .Array , new_diags : chex .Array ) -> chex .Array :
6858 """Set the diagonals of every DxD matrix in an input of shape NxDxD.
6959
You can’t perform that action at this time.
0 commit comments