Skip to content

Commit a02de84

Browse files
vrouletOptaxDev
authored andcommitted
Trimming the library.
Remove tests of deprecated functions. Reschedule deletion of deprecated functions for earlier future releases. PiperOrigin-RevId: 777652709
1 parent 5aad135 commit a02de84

File tree

9 files changed

+34
-1220
lines changed

9 files changed

+34
-1220
lines changed

docs/api/control_variates.rst

Lines changed: 0 additions & 21 deletions
This file was deleted.

docs/api/stochastic_gradient_estimators.rst

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,32 @@
1-
Stochastic Gradient Estimators
2-
==============================
1+
Stochastic Gradient Estimators and Control Variates
2+
===================================================
33

44
.. warning::
5-
This module has been deprecated and will be removed in optax 0.3.0.
5+
This module has been deprecated and will be removed in optax 0.2.7
66

77
.. currentmodule:: optax.monte_carlo
88

99
.. autosummary::
10+
control_delta_method
11+
control_variates_jacobians
1012
measure_valued_jacobians
13+
moving_avg_baseline
1114
pathwise_jacobians
1215
score_function_jacobians
1316

17+
18+
Control delta method
19+
~~~~~~~~~~~~~~~~~~~~
20+
.. autofunction:: control_delta_method
21+
22+
Control variates Jacobians
23+
~~~~~~~~~~~~~~~~~~~~~~~~~~
24+
.. autofunction:: control_variates_jacobians
25+
26+
Moving average baseline
27+
~~~~~~~~~~~~~~~~~~~~~~~
28+
.. autofunction:: moving_avg_baseline
29+
1430
Measure valued Jacobians
1531
~~~~~~~~~~~~~~~~~~~~~~~~
1632
.. autofunction:: measure_valued_jacobians

optax/_src/alias.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1358,14 +1358,14 @@ def noisy_sgd(
13581358
"""
13591359
if seed is not None:
13601360
warnings.warn(
1361-
'"seed" is deprecated and will be removed in optax 0.3.0, use "key".',
1361+
'"seed" is deprecated and will be removed in optax 0.2.7, use "key".',
13621362
DeprecationWarning,
13631363
)
13641364
if key is not None:
13651365
raise ValueError('Only one of seed or key can be specified.')
13661366
key = jax.random.key(seed)
13671367
if key is None:
1368-
warnings.warn('Specifying a key will be required in optax 0.3.0.')
1368+
warnings.warn('Specifying a key will be required in optax 0.2.7.')
13691369
key = jax.random.key(0)
13701370
key = utils.canonicalize_key(key)
13711371

optax/contrib/_privacy.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,14 +70,14 @@ def differentially_private_aggregate(
7070

7171
if seed is not None:
7272
warnings.warn(
73-
'"seed" is deprecated and will be removed in optax 0.3.0, use "key".',
73+
'"seed" is deprecated and will be removed in optax 0.2.7, use "key".',
7474
DeprecationWarning,
7575
)
7676
if key is not None:
7777
raise ValueError('Only one of seed or key can be specified.')
7878
key = jax.random.key(seed)
7979
if key is None:
80-
warnings.warn('Specifying a key will be required in optax 0.3.0.')
80+
warnings.warn('Specifying a key will be required in optax 0.2.7.')
8181
key = jax.random.key(0)
8282
key = utils.canonicalize_key(key)
8383

optax/monte_carlo/control_variates.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ def control_delta_method(
9393
state.
9494
9595
.. deprecated:: 0.2.4
96-
This function will be removed in 0.3.0
96+
This function will be removed in 0.2.7
9797
"""
9898

9999
def delta(
@@ -171,7 +171,7 @@ def moving_avg_baseline(
171171
state.
172172
173173
.. deprecated:: 0.2.4
174-
This function will be removed in 0.3.0
174+
This function will be removed in 0.2.7
175175
"""
176176

177177
def moving_avg(
@@ -281,7 +281,7 @@ def control_variates_jacobians(
281281
* The updated CV state.
282282
283283
.. deprecated:: 0.2.4
284-
This function will be removed in 0.3.0
284+
This function will be removed in 0.2.7
285285
"""
286286
control_variate = control_variate_from_function(function)
287287
stochastic_cv, expected_value_cv, update_state_cv = control_variate
@@ -418,7 +418,7 @@ def estimate_control_variate_coefficients(
418418
in `params`.
419419
420420
.. deprecated:: 0.2.4
421-
This function will be removed in 0.3.0
421+
This function will be removed in 0.2.7
422422
"""
423423
# Resample to avoid biased gradients.
424424
cv_rng, _ = jax.random.split(rng)

0 commit comments

Comments
 (0)