Skip to content

Add invert helper function to wrap linear_solve in FunctionLinearOperator#206

Merged
jpbrodrick89 merged 13 commits into
patrick-kidger:mainfrom
jpbrodrick89:jpb/invert
Apr 29, 2026
Merged

Add invert helper function to wrap linear_solve in FunctionLinearOperator#206
jpbrodrick89 merged 13 commits into
patrick-kidger:mainfrom
jpbrodrick89:jpb/invert

Conversation

@jpbrodrick89
Copy link
Copy Markdown
Collaborator

Looking through old issues/PR's (#96 #97) it seems a real blocker to extending lineax further to block operators, Woodbury matrix identity and chained inversion is the concept of an InverseLinearOperator that allows composition (both with @ and +/-). I sketched up an implementation of this (see first commit on this branch) but quickly realised it was essentially identical to a FunctionLinearOperator with custom tag rules and the ability to cache the state through lx.linearise. This felt like overengineering to me and would give us one more operator to maintain (e.g. if we add more single dispatch functions or colouring rules). As such I instead decided to introduce a helper function that offered the same functionality. The key thing we lose is the ability to cache an InverseLinearOperator with lx.linearise and instead we have to make the decision which one we want when calling lx.invert with the cache keyword.

One idea I have in mind for using this is to introduce collapse: bool keyword argument which, when False and provided with a ComposedLinearOperator and a direct solver will return a composed chain of inverse operators for each child of the ComposedLinearOperator(in reverse order of course). This is usually more efficient (at least if we have a more efficient QR solver when ormqr is supported in jax) unless inner dimensions are larger than outer ones as it avoids the cost of pre-multiplying. We could have collapse=None to allow auto-detection of the large inner dimension case. cache=False would usually be preferred here to avoid excess memory requirements (whereas for the Woodbury matrix identity cache=True would probably be preferred).

As this is something that is not used elsewhere now, but could be used widely across the codebase in the future I appreciate that getting the design just right is important and am therefore open to making significant changes to this if you have other ideas.

Copy link
Copy Markdown
Owner

@patrick-kidger patrick-kidger left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think in principle this looks pretty good to me!

I think my main (only?) concern is what will happen when we compose this into larger systems — like #196, where it turns out that we're playing whack-a-mole building our own compiler over operators. That's not really a blocker but we might want to be up-front about how far we intend to pursue such optimizations — as my preference would be 'not much', mostly for maintenance time reasons.

Comment thread lineax/_solve.py Outdated
@jpbrodrick89
Copy link
Copy Markdown
Collaborator Author

jpbrodrick89 commented Feb 27, 2026

Thanks Patrick, appreciate the concern. Always happy to have a call if easier to discuss. The main general "optimisation" in the spirit of #196 I still have in mind is to use multi_dot in ComposedLinearOperator.mv for shape optimisation only when it supports batching properly (see Jax issue 35308). As mentioned in #196 I don't propose optimising this further based on operator type (or depending on the often misleading xla floor count) or linear transposing, MAYBE based on structure (through tags) but I don't have any ideas for this right now. As such, inverse operator would act like any other when composing and its ordering would depend purely on shape (this is reasonable as once state is cached most solvers have similar scaling to matmuls), when adding to another operator lineax already does the simple usually preferred approach of just applying each operator separately. I also don't think it's possible or sensible to attempt to automatically detect independent applications of an inverse and vmap under the hood (e.g. lx.invert(A) @ C + B @ lx.invert(A)). We either just rely on XLA or do this manually in the solver as we would for the Woodbury case.

For inverse operators specifically, I think the two examples of an inverse chain and Woodbury Matrix identity are good ones to reason about to think how bad this could get (happy to submit draft PRs if you'd rather see a working example). Firstly though, I think we should limit our reasoning to direct solvers only, I don't envisage many cases where lx.invert would be sensible for iterative operators.

Let's start with Woodbury because it's a textbook case. We could either use @aidancrilly's #97 approach of providing each matrix directly or extracting from an AddLinearOperator (the key thing to note is that lx.invert doesn't FORCE us into either implementation and if we decide the latter is to complex/hard to maintain we can opt for the more explicit approach). Either way we'd end up M=A+UCV with A and C square. We will assume both A_solver, C_solver, S_solver are direct square solvers and that we cannot detect sparsity for S (so calling as_matrix is always mandatory) as I don't think it makes any sense to use an iterative solver. The code would then look something like this if cache defaults to True and assuming for simplicity that C has flat in and out structure (note I think we handle pytrees out the box):

A_inv = lx.invert(A, solver=A_solver)
A_inv_U = jax.vmap(A_inv.mv)(lx.materialise(U).pytree)
A_inv_b= A_inv.mv(b)
S = lx.invert(C, solver=C_solver) + PyTreeLinearOperator(V.mv(A_inv_U), out_structure=...)
sol = (IdentityLinearOperator(...) - A_inv_U @ lx.invert(S) @ V)(A_inv_B)

Yes I did oversimplify things and miss out on an optimisation by stacking U and b but this seems pretty elegant to me and the opposite of hard to maintain (compared to a non-operator approach at least). Just like in #196 I don't think it makes any sense to try overengineer by detecting whether it's more efficient to calculate A_inv_U or V_A_inv.

Now the more controversial approach is the inverse chain, which would probably need to be advertised as experimental to begin with. Yes I admit this does feel a bit like building a compiler over operators but I think it could be really valuable. I think I'd just take a really simple approach to begin with where each matrix in a composed chain is inverted unless the inner dimensions are bigger than the outer, no special handling of add operators to begin with (we could allow Woodbury to handle this in the future but I admit that would get convoluted and hard to maintain).

So lx.invert(A@(B+C)) -> lx.invert(B+C) @ lx.invert(A)

Also if you hate this but love the Woodbury we could just do that for now and I can play around with the chain stuff outwith lineax.

I have no deep thoughts/great insight on how to handle block matrices, and if it turns out that this is too hard to address in lineax so be it, at least we're doing a best effort in providing the foundation of lx.invert in the first place and maybe others can play around and work it out. 🙂

Another thing to make sure you're comfortable with is that we don't EXPLICITLY call solver.transpose or solver.conj. I think transpose is fine as it will call your custom transpose rule which calls these under the hood. I'm not 100% sure if there's a great efficiency loss when calling lx.conj(lx.invert(op)) but my gut feel is that it doesn't matte

Copy link
Copy Markdown
Owner

@patrick-kidger patrick-kidger left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay! Then in this case I think I'm pretty happy with this. Two outstanding minor nits (one new one + the cache one) and then we can merge this. I'm currently aiming to merge this and #164 and then do a new release.

Comment thread lineax/_solve.py Outdated
@jpbrodrick89
Copy link
Copy Markdown
Collaborator Author

jpbrodrick89 commented Mar 12, 2026

Not sure if you're still keen on having the helper in lineax main given the new stop gradient handling, but if you are one possible extension would be to add ("has_inverse", op) tag to returned inverse operator or similar to tell lineax we already have the inverse of the inverse (the original operator) computed. Probably overly complex/overengineering for now but just thought I'd mention as it crossed my mind. 🙂

@jpbrodrick89 jpbrodrick89 merged commit 0cb6090 into patrick-kidger:main Apr 29, 2026
1 check passed
@jpbrodrick89 jpbrodrick89 mentioned this pull request May 1, 2026
jpbrodrick89 added a commit that referenced this pull request May 1, 2026
This is mostly a bug fix, documentation and under-the-hood performance improvement release with one new feature—the `lx.invert` [transformation](https://docs.kidger.site/lineax/api/linear_solve/#invert) which produces an operator representing the inverse of a matrix. Use of coloring rules should make using implicit solvers in [diffrax](https://docs.kidger.site/diffrax/) for tridiagonal `Jacobian/FunctionLinearOperator`s at least an order of magnitude faster.

## Breaking Changes

* Extraction of diagonal/tridiagonals of now leverages the promise of a matrix being tagged as diagonal/tridiagonal more heavily. If you have previously used the tag for an operator that you just wanted lineax to TREAT as diagonal/tridiagonal you may now get incorrect results. In most cases the right fix will probably be to first manually extract (tri)diagonal and construct the `(Tri)DiagonalOperator` explicitly, please raise an [issue](https://github.com/patrick-kidger/lineax/issues/new) if you need any further assistance.
* `lineax.linear_solve` now stop-gradient's automatically (#213), it is unlikely this will break any existing use-cases but may make manual stop-gradienting unecessary
* Removed AuxLinearOperator (#203)

## Features

* Add invert helper function to wrap `lineax.linear_solve` in `FunctionLinearOperator`. Materialising an inverse is now as simple as `lx.invert(op).as_matrix()`. (#206)
 
## Compatibility

* lineax v0.1.1 now requires JAX >= 0.10.0 which provides a lowering to LAPACK/cuSolver's` ormqr for more efficient QR solve adopted in #219.

## Bugfixes

* Fix derived tag check rules for composite operators (e.g. `Composed/Neg/Mul/AddLinearOperator`) (#192)
* Linearisation of functions `custom_vjp`'s are now supported by `lineax.linearize(JacobianLinearOperator(f, x, jac="bwd"))` by using `jax.linear_transpose` under the hood. (#191)
* Complex positive/negative semi-definite matrices no longer register as symmetric (#200)
* `lineax.LSMR` no longer fails when initial residual is exactly zero. (HUGE thanks to @f0uriest for spotting this tricky and hard-to-spot bug #202)
* Differentiating through `linear_solve`'s no longer differentiates through `solver.init` this means using solver's with no or incorrect jvp rule is now possible (#212)

## Performance

* Coloring rules now used to _massively_ speed up diagonal/tridiagonal extraction of tagged `Jacobian/FunctionLinearOperators`  (#164, #165)
* Normal and iterative solvers now apply `lineax.linearise` under the hood to avoid multiple sequential AD passes (#198)
* Furthermore, `lineax.Normal(lineax.Cholesky())` now materialises the inner operator before constructing the Gram matrix (#207)
* `ComposedLinearOperator.as_matrix` no longer materialises each matrix first but instead batches `mv` of the first operator over the second matrix (#196)
*  JAX's [ormqr](https://docs.jax.dev/en/latest/_autosummary/jax.lax.linalg.ormqr.html) now used for more efficient QR solves (#219)

## Documentation

* The `lineax.LSMR` iterative least square solver is now properly documented (#204) after @f0uriest's #202 bug-fixes make it more robust.

Other repo infra PR's not affecting Python package include #214, #216 and #218.


## New Contributors
* @patrick-kidger-bot 🤖 made their first contribution in #216 

**Full Changelog**: v0.1.0...v0.1.1
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants