prevent linear_solve calling jvp of solver.init #212
Merged
patrick-kidger merged 1 commit on Mar 18, 2026
adconner reviewed on Mar 9, 2026
# Differentiating through operator only, but options has a dynamic array.
# solver.init should not be differentiated through.
jax.jvp(f, (m,), (mt,))
Maybe for completeness we could test the backward pass too:
_, f_vjp = jax.vjp(f, m)
f_vjp(vt)
where earlier (line 233, maybe) we define:
vt = jax.random.normal(getkey(), (3,))
Owner
@adconner's comment aside, this LGTM!
add vjp test
ab774a7 to
d226590
Owner
Alright, merged! Thank you for this 🎉
jpbrodrick89 added a commit that referenced this pull request on May 1, 2026
This is mostly a bug fix, documentation and under-the-hood performance improvement release with one new feature: the `lx.invert` [transformation](https://docs.kidger.site/lineax/api/linear_solve/#invert), which produces an operator representing the inverse of a matrix. Use of coloring rules should make using implicit solvers in [diffrax](https://docs.kidger.site/diffrax/) for tridiagonal `Jacobian/FunctionLinearOperator`s at least an order of magnitude faster.

## Breaking Changes

* Extraction of diagonals/tridiagonals now leverages the promise of a matrix being tagged as diagonal/tridiagonal more heavily. If you have previously used the tag for an operator that you just wanted lineax to TREAT as diagonal/tridiagonal, you may now get incorrect results. In most cases the right fix will probably be to first manually extract the (tri)diagonal and construct the `(Tri)DiagonalOperator` explicitly. Please raise an [issue](https://github.com/patrick-kidger/lineax/issues/new) if you need any further assistance.
* `lineax.linear_solve` now applies stop-gradient automatically (#213). This is unlikely to break any existing use cases but may make manual stop-gradienting unnecessary.
* Removed AuxLinearOperator (#203)

## Features

* Add invert helper function to wrap `lineax.linear_solve` in `FunctionLinearOperator`. Materialising an inverse is now as simple as `lx.invert(op).as_matrix()`. (#206)

## Compatibility

* lineax v0.1.1 now requires JAX >= 0.10.0, which provides a lowering to LAPACK/cuSolver's `ormqr` for the more efficient QR solve adopted in #219.

## Bugfixes

* Fix derived tag check rules for composite operators (e.g. `Composed/Neg/Mul/AddLinearOperator`) (#192)
* Linearisation of functions with `custom_vjp`s is now supported by `lineax.linearise(JacobianLinearOperator(f, x, jac="bwd"))`, which uses `jax.linear_transpose` under the hood. (#191)
* Complex positive/negative semi-definite matrices no longer register as symmetric (#200)
* `lineax.LSMR` no longer fails when the initial residual is exactly zero. (HUGE thanks to @f0uriest for spotting this tricky bug, #202)
* Differentiating through `linear_solve` no longer differentiates through `solver.init`. This means using solvers with no jvp rule, or an incorrect one, is now possible (#212)

## Performance

* Coloring rules are now used to _massively_ speed up diagonal/tridiagonal extraction of tagged `Jacobian/FunctionLinearOperators` (#164, #165)
* Normal and iterative solvers now apply `lineax.linearise` under the hood to avoid multiple sequential AD passes (#198)
* Furthermore, `lineax.Normal(lineax.Cholesky())` now materialises the inner operator before constructing the Gram matrix (#207)
* `ComposedLinearOperator.as_matrix` no longer materialises each matrix first but instead batches `mv` of the first operator over the second matrix (#196)
* JAX's [ormqr](https://docs.jax.dev/en/latest/_autosummary/jax.lax.linalg.ormqr.html) is now used for more efficient QR solves (#219)

## Documentation

* The `lineax.LSMR` iterative least-squares solver is now properly documented (#204), after @f0uriest's #202 bug fixes made it more robust.

Other repo infra PRs not affecting the Python package include #214, #216 and #218.

## New Contributors

* @patrick-kidger-bot 🤖 made their first contribution in #216

**Full Changelog**: v0.1.0...v0.1.1
Fixes #211. Previously, `stop_gradient` was called AFTER `solver.init`. This meant the jvp of `solver.init` was still called before the returned tangent was set to zero (the primal returned by the jvp is retained). If we call `stop_gradient` BEFORE, then the jvp is never called. This means we can use primitives without jvps (e.g. geqp3/geqrf) or with incorrect jvp primals (e.g. qr with `pivoting=True`).

Note that we don't need to apply `stop_gradient` to options, as we do not support taking a gradient with respect to e.g. preconditioners (we just get a nondiff error).

If we're happy with this I will modify my invert PR #206 to mirror the same pattern.
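
To illustrate the ordering argument above, here is a minimal JAX sketch. It is not the lineax code: `toy_init` is a hypothetical stand-in for `solver.init`, given a `jax.custom_jvp` rule that only records when it is invoked. Under these assumptions, placing `stop_gradient` after the call still triggers the jvp rule, while placing it before means the rule is never reached.

```python
import jax
import jax.numpy as jnp

jvp_calls = []  # records each invocation of the JVP rule below


@jax.custom_jvp
def toy_init(a):
    # Hypothetical stand-in for solver.init (not the lineax implementation).
    return 2.0 * a


@toy_init.defjvp
def toy_init_jvp(primals, tangents):
    # Side effect lets us observe whether the JVP rule was used at all.
    jvp_calls.append("called")
    (a,), (a_dot,) = primals, tangents
    return 2.0 * a, 2.0 * a_dot


def stop_after(a):
    # Old ordering: differentiate toy_init, then discard the tangent.
    return jax.lax.stop_gradient(toy_init(a))


def stop_before(a):
    # New ordering: the input carries no tangent, so no JVP rule is needed.
    return toy_init(jax.lax.stop_gradient(a))


a = jnp.arange(3.0)
a_dot = jnp.ones(3)

_, t_after = jax.jvp(stop_after, (a,), (a_dot,))
calls_after = len(jvp_calls)

jvp_calls.clear()
_, t_before = jax.jvp(stop_before, (a,), (a_dot,))
calls_before = len(jvp_calls)

print("JVP rule calls with stop_gradient after: ", calls_after)
print("JVP rule calls with stop_gradient before:", calls_before)
```

Both orderings return a zero output tangent; the difference is only whether the jvp rule runs along the way, which is what matters when that rule is missing or returns an incorrect primal.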