Skip to content

Commit d080f09

Browse files
committed
Add linear program solver based on the restarted Halpern primal-dual hybrid gradient (rHPDHG) algorithm.
1 parent 3d8c391 commit d080f09

File tree

12 files changed

+591
-3
lines changed

12 files changed

+591
-3
lines changed

.pylintrc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,7 @@ disable=R,
129129
wrong-import-order,
130130
xrange-builtin,
131131
zip-builtin-not-iterating,
132+
invalid-name,
132133

133134

134135
[REPORTS]

docs/api/linprog.rst

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
Linear programming
2+
==================
3+
4+
.. currentmodule:: optax.linprog
5+
6+
.. autosummary::
7+
rhpdhg
8+
9+
10+
Restarted Halpern primal-dual hybrid gradient method
11+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
12+
.. autofunction:: rhpdhg

docs/gallery.rst

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,7 @@
209209
.. only:: html
210210

211211
.. image:: /images/examples/linear_assignment_problem.png
212-
:alt:
212+
:alt: Linear assignment problem.
213213

214214
:doc:`_collections/examples/linear_assignment_problem`
215215

@@ -219,6 +219,23 @@
219219
</div>
220220

221221

222+
.. raw:: html
223+
224+
<div class="sphx-glr-thumbcontainer" tooltip="Linear programming.">
225+
226+
.. only:: html
227+
228+
.. image:: /images/examples/linear_programming.png
229+
:alt: Linear programming.
230+
231+
:doc:`_collections/examples/linear_programming`
232+
233+
.. raw:: html
234+
235+
<div class="sphx-glr-thumbnail-title">Linear programming.</div>
236+
</div>
237+
238+
222239
.. raw:: html
223240

224241
</div>
76.8 KB
Loading

docs/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ for instructions on installing JAX.
5454
:caption: 📖 Reference
5555
:maxdepth: 2
5656

57+
api/linprog
5758
api/assignment
5859
api/optimizers
5960
api/transformations

examples/linear_programming.ipynb

Lines changed: 229 additions & 0 deletions
Large diffs are not rendered by default.

optax/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
from optax import assignment
2121
from optax import contrib
22+
from optax import linprog
2223
from optax import losses
2324
from optax import monte_carlo
2425
from optax import perturbations
@@ -364,6 +365,7 @@
364365
"lion",
365366
"linear_onecycle_schedule",
366367
"linear_schedule",
368+
"linprog",
367369
"log_cosh",
368370
"lookahead",
369371
"LookaheadParams",

optax/_src/alias.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2482,7 +2482,7 @@ def lbfgs(
24822482
... )
24832483
... params = optax.apply_updates(params, updates)
24842484
... print('Objective function: ', f(params))
2485-
Objective function: 7.5166864
2485+
Objective function: 7.516686...
24862486
Objective function: 7.460699e-14
24872487
Objective function: 2.6505726e-28
24882488
Objective function: 0.0

optax/linprog/__init__.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
# Copyright 2024 DeepMind Technologies Limited. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
"""The linear programming sub-package."""
16+
17+
# pylint:disable=g-importing-member
18+
19+
from optax.linprog._rhpdhg import solve_general as rhpdhg

optax/linprog/_rhpdhg.py

Lines changed: 211 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,211 @@
1+
# Copyright 2024 DeepMind Technologies Limited. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
"""The restarted Halpern primal-dual hybrid gradient method."""
16+
17+
from jax import lax, numpy as jnp
18+
from optax import tree_utils as otu
19+
20+
21+
def solve_canonical(
22+
c, A, b, iters, reflect=True, restarts=True, tau=None, sigma=None
23+
):
24+
r"""Solves a linear program using the restarted Halpern primal-dual hybrid
25+
gradient (RHPDHG) method.
26+
27+
Minimizes :math:`c \cdot x` subject to :math:`A x = b` and :math:`x \geq 0`.
28+
29+
See also `MPAX <https://github.com/MIT-Lu-Lab/MPAX>`_.
30+
31+
Args:
32+
c: Cost vector.
33+
A: Equality constraint matrix.
34+
b: Equality constraint vector.
35+
iters: Number of iterations to run the solver for.
36+
reflect: Use reflection. See paper for details.
37+
restarts: Use restarts. See paper for details.
38+
tau: Primal step size. See paper for details.
39+
sigma: Dual step size. See paper for details.
40+
41+
Returns:
42+
A dictionary whose entries are as follows:
43+
- primal: The final primal solution.
44+
- dual: The final dual solution.
45+
- primal_iterates: The primal iterates.
46+
- dual_iterates: The dual iterates.
47+
48+
Examples:
49+
>>> from jax import numpy as jnp
50+
>>> import optax
51+
>>> c = -jnp.array([2, 1])
52+
>>> A = jnp.zeros([0, 2])
53+
>>> b = jnp.zeros(0)
54+
>>> G = jnp.array([[3, 1], [1, 1], [1, 4]])
55+
>>> h = jnp.array([21, 9, 24])
56+
>>> x = optax.linprog.rhpdhg(c, A, b, G, h, 1_000_000)['primal']
57+
>>> print(x[0])
58+
5.99...
59+
>>> print(x[1])
60+
2.99...
61+
62+
References:
63+
Haihao Lu, Jinwen Yang, `Restarted Halpern PDHG for Linear Programming
64+
<https://arxiv.org/abs/2407.16144>`_, 2024
65+
"""
66+
67+
if tau is None or sigma is None:
68+
A_norm = jnp.linalg.norm(A, axis=(0, 1), ord=2)
69+
if tau is None:
70+
tau = 1 / (2 * A_norm)
71+
if sigma is None:
72+
sigma = 1 / (2 * A_norm)
73+
74+
def T(z):
75+
# primal dual hybrid gradient (PDHG)
76+
x, y = z
77+
xn = x + tau * (y @ A - c)
78+
xn = xn.clip(min=0)
79+
yn = y + sigma * (b - A @ (2 * xn - x))
80+
return xn, yn
81+
82+
def H(z, k, z0):
83+
# Halpern PDHG
84+
Tz = T(z)
85+
if reflect:
86+
zc = otu.tree_sub(otu.tree_scalar_mul(2, Tz), z)
87+
else:
88+
zc = Tz
89+
kp2 = k + 2
90+
zn = otu.tree_add(
91+
otu.tree_scalar_mul((k + 1) / kp2, zc),
92+
otu.tree_scalar_mul(1 / kp2, z0),
93+
)
94+
return zn, Tz
95+
96+
def update(carry, _):
97+
z, k, z0, d0 = carry
98+
zn, Tz = H(z, k, z0)
99+
100+
if restarts:
101+
d = otu.tree_l2_norm(otu.tree_sub(z, Tz), squared=True)
102+
restart = d <= d0 * jnp.exp(-2)
103+
new_carry = otu.tree_where(
104+
restart,
105+
(zn, 0, zn, d),
106+
(zn, k + 1, z0, d0),
107+
)
108+
else:
109+
new_carry = zn, k + 1, z0, d0
110+
111+
return new_carry, z
112+
113+
def run():
114+
m, n = A.shape
115+
x = jnp.zeros(n)
116+
y = jnp.zeros(m)
117+
z0 = x, y
118+
d0 = otu.tree_l2_norm(otu.tree_sub(z0, T(z0)), squared=True)
119+
(z, _, _, _), zs = lax.scan(update, (z0, 0, z0, d0), length=iters)
120+
x, y = z
121+
xs, ys = zs
122+
return {
123+
"primal": x,
124+
"dual": y,
125+
"primal_iterates": xs,
126+
"dual_iterates": ys,
127+
}
128+
129+
return run()
130+
131+
132+
def general_to_canonical(c, A, b, G, h):
133+
"""Converts a linear program from general form to canonical form.
134+
135+
The solution to the new linear program will consist of the concatenation of
136+
- the positive part of x
137+
- the negative part of x
138+
- slacks
139+
140+
That is, we go from
141+
142+
Minimize c · x subject to
143+
A x = b
144+
G x ≤ h
145+
146+
to
147+
148+
Minimize c · (x⁺ - x⁻) subject to
149+
A (x⁺ - x⁻) = b
150+
G (x⁺ - x⁻) + s = h
151+
x⁺, x⁻, s ≥ 0
152+
153+
Args:
154+
c: Cost vector.
155+
A: Equality constraint matrix.
156+
b: Equality constraint vector.
157+
G: Inequality constraint matrix.
158+
h: Inequality constraint vector.
159+
160+
Returns:
161+
A triple (c', A', b') representing the corresponding canonical form.
162+
"""
163+
c_can = jnp.concatenate([c, -c, jnp.zeros(h.size)])
164+
G_ = jnp.concatenate([G, -G, jnp.eye(h.size)], 1)
165+
A_ = jnp.concatenate([A, -A, jnp.zeros([b.size, h.size])], 1)
166+
A_can = jnp.concatenate([A_, G_], 0)
167+
b_can = jnp.concatenate([b, h])
168+
return c_can, A_can, b_can
169+
170+
171+
def solve_general(
172+
c, A, b, G, h, iters, reflect=True, restarts=True, tau=None, sigma=None
173+
):
174+
r"""Solves a linear program using the restarted Halpern primal-dual hybrid
175+
gradient (RHPDHG) method.
176+
177+
Minimizes :math:`c \cdot x` subject to :math:`A x = b` and :math:`G x \leq h`.
178+
179+
See also `MPAX <https://github.com/MIT-Lu-Lab/MPAX>`_.
180+
181+
Args:
182+
c: Cost vector.
183+
A: Equality constraint matrix.
184+
b: Equality constraint vector.
185+
G: Inequality constraint matrix.
186+
h: Inequality constraint vector.
187+
iters: Number of iterations to run the solver for.
188+
reflect: Use reflection. See paper for details.
189+
restarts: Use restarts. See paper for details.
190+
tau: Primal step size. See paper for details.
191+
sigma: Dual step size. See paper for details.
192+
193+
Returns:
194+
A dictionary whose entries are as follows:
195+
- primal: The final primal solution.
196+
- slacks: The final primal slack values.
197+
- canonical_result: The result for the canonical program that was used
198+
internally to find this solution. See paper for details.
199+
200+
References:
201+
Haihao Lu, Jinwen Yang, `Restarted Halpern PDHG for Linear Programming
202+
<https://arxiv.org/abs/2407.16144>`_, 2024
203+
"""
204+
canonical = general_to_canonical(c, A, b, G, h)
205+
result = solve_canonical(*canonical, iters, reflect, restarts, tau, sigma)
206+
x_pos, x_neg, slacks = jnp.split(result["primal"], [c.size, c.size * 2])
207+
return {
208+
"primal": x_pos - x_neg,
209+
"slacks": slacks,
210+
"canonical_result": result,
211+
}

0 commit comments

Comments
 (0)