prevent linear_solve calling jvp of solver.init #212

Merged
patrick-kidger merged 1 commit into patrick-kidger:main from jpbrodrick89:jpb/fix-stop-gradient
Mar 18, 2026

@jpbrodrick89
Collaborator

@jpbrodrick89 jpbrodrick89 commented Mar 9, 2026

Fixes #211. Previously, stop_gradient was called AFTER solver.init, so the jvp of solver.init was still evaluated before the returned tangent was set to zero (only the primal returned by the jvp was retained). If we call stop_gradient BEFORE solver.init, the jvp is never called. This means we can use primitives without jvps (e.g. geqp3/geqrf) or with incorrect jvp primals (e.g. qr with pivoting=True).
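
The ordering effect can be reproduced in plain JAX. The following is a minimal sketch, not the lineax code: `toy_init` is a hypothetical stand-in for solver.init, and `jvp_calls` is an illustrative list that records whether its JVP rule runs.

```python
import jax
import jax.numpy as jnp
from jax import lax

jvp_calls = []  # illustrative: records each time the toy JVP rule runs


@jax.custom_jvp
def toy_init(x):
    # Hypothetical stand-in for solver.init (e.g. a factorisation).
    return 2.0 * x


@toy_init.defjvp
def _toy_init_jvp(primals, tangents):
    jvp_calls.append("called")
    (x,), (t,) = primals, tangents
    return toy_init(x), 2.0 * t


def stop_after(x):
    # Old ordering: init is differentiated, then its tangent is discarded.
    return lax.stop_gradient(toy_init(x))


def stop_before(x):
    # New ordering: init only ever sees an input with a zero tangent.
    return toy_init(lax.stop_gradient(x))


x = jnp.arange(3.0)
t = jnp.ones(3)

jvp_calls.clear()
_, tangent_after = jax.jvp(stop_after, (x,), (t,))
calls_after = len(jvp_calls)

jvp_calls.clear()
_, tangent_before = jax.jvp(stop_before, (x,), (t,))
calls_before = len(jvp_calls)
```

Both orderings return a zero tangent, but only `stop_after` runs the JVP rule of `toy_init`. If that rule were missing or wrong, only the `stop_before` ordering would be unaffected.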

Note that we don't need to apply stop_gradient to options, as we do not support taking a gradient with respect to e.g. preconditioners (we just get a nondiff error).

If we're happy with this I will modify my invert PR #206 to mirror the same pattern.

@jpbrodrick89
Collaborator Author

@adconner lmk if this gets your #210 to pass tests.

@adconner adconner mentioned this pull request Mar 9, 2026
@adconner
Contributor

adconner commented Mar 9, 2026

> @adconner lmk if this gets your #210 to pass tests.

Yes fixed!

Comment thread tests/test_solve.py

# Differentiating through operator only, but options has a dynamic array.
# solver.init should not be differentiated through.
jax.jvp(f, (m,), (mt,))
Contributor

Maybe for completeness we could test the backward pass too:

_, f_vjp = jax.vjp(f, m)
f_vjp(vt)

where earlier (around line 233, maybe) we define

vt = jax.random.normal(getkey(), (3,))
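
For illustration, here is a self-contained sketch of what such a backward-pass check exercises, written in plain JAX rather than against the actual test file. `toy_init` and this `f` are hypothetical stand-ins (not the `f` from tests/test_solve.py), and the JVP rule deliberately raises so that any attempt to differentiate through the init step would fail loudly.

```python
import jax
import jax.numpy as jnp
from jax import lax


@jax.custom_jvp
def toy_init(m):
    # Hypothetical stand-in for solver.init: pretend it has no usable JVP.
    return jnp.linalg.inv(m)


@toy_init.defjvp
def _toy_init_jvp(primals, tangents):
    raise NotImplementedError("toy_init has no JVP rule")


def f(m):
    # stop_gradient is applied BEFORE the init step, as in this PR.
    state = toy_init(lax.stop_gradient(m))
    # m also enters directly, so the result has a nonzero gradient.
    return state @ (m @ jnp.ones(3))


m = jnp.array([[4.0, 1.0, 0.0], [1.0, 3.0, 1.0], [0.0, 1.0, 2.0]])
vt = jax.random.normal(jax.random.PRNGKey(0), (3,))

# Backward pass: succeeds only if toy_init's JVP rule is never invoked.
_, f_vjp = jax.vjp(f, m)
(grad_m,) = f_vjp(vt)

# With state held constant, the cotangent of m is outer(state^T vt, ones).
expected = jnp.outer(jnp.linalg.inv(m).T @ vt, jnp.ones(3))
```

The point of the sketch is that `jax.vjp` completes without touching the JVP rule of the init step, which is the property the suggested test would guard.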

@patrick-kidger
Owner

@adconner's comment aside, this LGTM!

@patrick-kidger patrick-kidger force-pushed the jpb/fix-stop-gradient branch from ab774a7 to d226590 Compare March 17, 2026 23:30
@patrick-kidger patrick-kidger merged commit b6f3087 into patrick-kidger:main Mar 18, 2026
1 check passed
@patrick-kidger
Owner

Alright, merged! Thank you for this 🎉

@jpbrodrick89 jpbrodrick89 mentioned this pull request May 1, 2026
jpbrodrick89 added a commit that referenced this pull request May 1, 2026
This is mostly a bug fix, documentation and under-the-hood performance improvement release with one new feature—the `lx.invert` [transformation](https://docs.kidger.site/lineax/api/linear_solve/#invert) which produces an operator representing the inverse of a matrix. Use of coloring rules should make using implicit solvers in [diffrax](https://docs.kidger.site/diffrax/) for tridiagonal `Jacobian/FunctionLinearOperator`s at least an order of magnitude faster.

## Breaking Changes

* Extraction of diagonals/tridiagonals now leans more heavily on the promise made by tagging a matrix as diagonal/tridiagonal. If you previously used the tag for an operator that you just wanted lineax to TREAT as diagonal/tridiagonal, you may now get incorrect results. In most cases the right fix will probably be to manually extract the (tri)diagonal first and construct the `(Tri)DiagonalOperator` explicitly. Please raise an [issue](https://github.com/patrick-kidger/lineax/issues/new) if you need any further assistance.
* `lineax.linear_solve` now applies stop-gradients automatically (#213). It is unlikely this will break any existing use cases, but it may make manual stop-gradienting unnecessary.
* Removed AuxLinearOperator (#203)

## Features

* Add invert helper function to wrap `lineax.linear_solve` in `FunctionLinearOperator`. Materialising an inverse is now as simple as `lx.invert(op).as_matrix()`. (#206)
 
## Compatibility

* lineax v0.1.1 now requires JAX >= 0.10.0, which provides a lowering to LAPACK/cuSolver's `ormqr` for the more efficient QR solve adopted in #219.

## Bugfixes

* Fix derived tag check rules for composite operators (e.g. `Composed/Neg/Mul/AddLinearOperator`) (#192)
* Linearisation of functions with `custom_vjp`s is now supported by `lineax.linearise(JacobianLinearOperator(f, x, jac="bwd"))`, which uses `jax.linear_transpose` under the hood. (#191)
* Complex positive/negative semi-definite matrices no longer register as symmetric (#200)
* `lineax.LSMR` no longer fails when the initial residual is exactly zero. (HUGE thanks to @f0uriest for spotting this tricky, hard-to-spot bug, #202)
* Differentiating through `linear_solve` no longer differentiates through `solver.init`. This means that using solvers with no jvp rule, or an incorrect one, is now possible (#212)

## Performance

* Coloring rules are now used to _massively_ speed up diagonal/tridiagonal extraction of tagged `Jacobian/FunctionLinearOperators` (#164, #165)
* Normal and iterative solvers now apply `lineax.linearise` under the hood to avoid multiple sequential AD passes (#198)
* Furthermore, `lineax.Normal(lineax.Cholesky())` now materialises the inner operator before constructing the Gram matrix (#207)
* `ComposedLinearOperator.as_matrix` no longer materialises each matrix first but instead batches `mv` of the first operator over the second matrix (#196)
* JAX's [ormqr](https://docs.jax.dev/en/latest/_autosummary/jax.lax.linalg.ormqr.html) is now used for more efficient QR solves (#219)

## Documentation

* The `lineax.LSMR` iterative least squares solver is now properly documented (#204) after @f0uriest's #202 bug fixes made it more robust.

Other repo infra PRs not affecting the Python package include #214, #216 and #218.


## New Contributors
* @patrick-kidger-bot 🤖 made their first contribution in #216 

**Full Changelog**: v0.1.0...v0.1.1
Development

Successfully merging this pull request may close these issues.

Differentiation of linear_solve causes differentiation through factorization (solver state)

3 participants