Skip to content

Use ormqr for more efficient QR solve#219

Merged
jpbrodrick89 merged 1 commit into
mainfrom
jpb/ormqr
Apr 30, 2026
Merged

Use ormqr for more efficient QR solve#219
jpbrodrick89 merged 1 commit into
mainfrom
jpb/ormqr

Conversation

@jpbrodrick89
Copy link
Copy Markdown
Collaborator

Typically on CPU this leads to a pretty stable 1.5x speedup in a QR solve (a 1000 times speed up in the multiplication itself but factorisation is still a heavy part of the solve), on GPU wins are mostly for large matrices but when jax-ml/jax#36575 is merged wins will be realised across the board (sometimes by >8x). We just adopt this as the default and hide complexity from the user.

I think this is the last change I'd like to make before the next release.

@adconner I'd appreciate a quick pass over sanity check if you have time. I'm very aware that your pivoted QR PR is still open, my current thinking is that we should focus on doing a more complete least square focussed release next. The first foundation that I think might be help is having max rank tags and tracking rather than passing this to the solver. However, if you'd find this useful and rather it released now rather than in a few months I can try help iterate with you on a temporary design (I'm currently leaning towards the QR(pivoted=True) scipy/numpy homage API rather than PivotedQR).

@patrick-kidger note this means I have to bump Jax's minimum version to 0.10.0, are you comfortable with this? Does this have any implications on what version we should call the next release.

@patrick-kidger
Copy link
Copy Markdown
Owner

note this means I have to bump Jax's minimum version to 0.10.0, are you comfortable with this? Does this have any implications on what version we should call the next release.

That's okay! I'd usually trigger the downstream Optimistix/Diffrax CI suites (just manually locally, no nice setup) on top of the new JAX+Lineax to be sure that the new JAX version doesn't break anything for them i.e. that we don't release a Lineax version that is incompatible with them.

Version number bump would just be patch, not minor.

@adconner
Copy link
Copy Markdown
Contributor

Looks good to me. Glad to see this added to lineax and also appreciate your work improving the GPU situation with jax-ml/jax#36575.

I think I am in general in favor of designs which let operators express rank assumptions/knowledge in tags (in addition to other properties). In the most expressive case of this design you can track rank lower bound and rank upper bound as tags.

I briefly discussed just a full_rank tag with @patrick-kidger in the context of #158, as for instance if this is the source of truth, then (1) jvp terms are always most efficient regardless of if a more general solver is used, (2) the assume_full_rank solver special method is no longer needed, solvers just check rank tags in their init just like other tags, and (3), AutoLinearSolver could be truly auto, not needing well_defined argument.

I think we landed on essentially that this feels too different from the way this assumption is communicated in other libraries, and basically that solver selection should be considered part of the source of truth for what the operators are. Maybe if we need rank tracking for other reasons we can consider an expanded design

@jpbrodrick89 jpbrodrick89 merged commit 53edb88 into main Apr 30, 2026
1 check passed
@jpbrodrick89
Copy link
Copy Markdown
Collaborator Author

Thanks @adconner , that's a very good point of reference I had read through before and had forgotten about. I think this is why future least squares development needs a more careful design discussion, I'll probably open an issue to discuss ideas for this in a week or so.

The other idea I had for a rank tracker is a Woodbury solver that could take an AddLinearOperator and detect which one is full rank and which one is a low-rank ComposedLinearOperator, but I'm still not sure that's a good design either.

I'll write up some release notes tomorrow for the current state.

@jpbrodrick89
Copy link
Copy Markdown
Collaborator Author

Sorry @patrick-kidger I was sloppy and merged without the diffrax and optimistix tests, I will set them up to run tomorrow and flag/handle/revert if any issues.

@jpbrodrick89 jpbrodrick89 mentioned this pull request May 1, 2026
jpbrodrick89 added a commit that referenced this pull request May 1, 2026
This is mostly a bug fix, documentation and under-the-hood performance improvement release with one new feature—the `lx.invert` [transformation](https://docs.kidger.site/lineax/api/linear_solve/#invert) which produces an operator representing the inverse of a matrix. Use of coloring rules should make using implicit solvers in [diffrax](https://docs.kidger.site/diffrax/) for tridiagonal `Jacobian/FunctionLinearOperator`s at least an order of magnitude faster.

## Breaking Changes

* Extraction of diagonal/tridiagonals of now leverages the promise of a matrix being tagged as diagonal/tridiagonal more heavily. If you have previously used the tag for an operator that you just wanted lineax to TREAT as diagonal/tridiagonal you may now get incorrect results. In most cases the right fix will probably be to first manually extract (tri)diagonal and construct the `(Tri)DiagonalOperator` explicitly, please raise an [issue](https://github.com/patrick-kidger/lineax/issues/new) if you need any further assistance.
* `lineax.linear_solve` now stop-gradient's automatically (#213), it is unlikely this will break any existing use-cases but may make manual stop-gradienting unecessary
* Removed AuxLinearOperator (#203)

## Features

* Add invert helper function to wrap `lineax.linear_solve` in `FunctionLinearOperator`. Materialising an inverse is now as simple as `lx.invert(op).as_matrix()`. (#206)
 
## Compatibility

* lineax v0.1.1 now requires JAX >= 0.10.0 which provides a lowering to LAPACK/cuSolver's` ormqr for more efficient QR solve adopted in #219.

## Bugfixes

* Fix derived tag check rules for composite operators (e.g. `Composed/Neg/Mul/AddLinearOperator`) (#192)
* Linearisation of functions `custom_vjp`'s are now supported by `lineax.linearize(JacobianLinearOperator(f, x, jac="bwd"))` by using `jax.linear_transpose` under the hood. (#191)
* Complex positive/negative semi-definite matrices no longer register as symmetric (#200)
* `lineax.LSMR` no longer fails when initial residual is exactly zero. (HUGE thanks to @f0uriest for spotting this tricky and hard-to-spot bug #202)
* Differentiating through `linear_solve`'s no longer differentiates through `solver.init` this means using solver's with no or incorrect jvp rule is now possible (#212)

## Performance

* Coloring rules now used to _massively_ speed up diagonal/tridiagonal extraction of tagged `Jacobian/FunctionLinearOperators`  (#164, #165)
* Normal and iterative solvers now apply `lineax.linearise` under the hood to avoid multiple sequential AD passes (#198)
* Furthermore, `lineax.Normal(lineax.Cholesky())` now materialises the inner operator before constructing the Gram matrix (#207)
* `ComposedLinearOperator.as_matrix` no longer materialises each matrix first but instead batches `mv` of the first operator over the second matrix (#196)
*  JAX's [ormqr](https://docs.jax.dev/en/latest/_autosummary/jax.lax.linalg.ormqr.html) now used for more efficient QR solves (#219)

## Documentation

* The `lineax.LSMR` iterative least square solver is now properly documented (#204) after @f0uriest's #202 bug-fixes make it more robust.

Other repo infra PR's not affecting Python package include #214, #216 and #218.


## New Contributors
* @patrick-kidger-bot 🤖 made their first contribution in #216 

**Full Changelog**: v0.1.0...v0.1.1
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants