Add invert helper function to wrap linear_solve in FunctionLinearOperator#206
Conversation
patrick-kidger
left a comment
There was a problem hiding this comment.
I think in principle this looks pretty good to me!
I think my main (only?) concern is what will happen when we compose this into larger systems — like #196, where it turns out that we're playing whack-a-mole building our own compiler over operators. That's not really a blocker but we might want to be up-front about how far we intend to pursue such optimizations — as my preference would be 'not much', mostly for maintenance time reasons.
|
Thanks Patrick, appreciate the concern. Always happy to have a call if easier to discuss. The main general "optimisation" in the spirit of #196 I still have in mind is to use For inverse operators specifically, I think the two examples of an inverse chain and Woodbury Matrix identity are good ones to reason about to think how bad this could get (happy to submit draft PRs if you'd rather see a working example). Firstly though, I think we should limit our reasoning to direct solvers only, I don't envisage many cases where Let's start with Woodbury because it's a textbook case. We could either use @aidancrilly's #97 approach of providing each matrix directly or extracting from an A_inv = lx.invert(A, solver=A_solver)
A_inv_U = jax.vmap(A_inv.mv)(lx.materialise(U).pytree)
A_inv_b= A_inv.mv(b)
S = lx.invert(C, solver=C_solver) + PyTreeLinearOperator(V.mv(A_inv_U), out_structure=...)
sol = (IdentityLinearOperator(...) - A_inv_U @ lx.invert(S) @ V)(A_inv_B)Yes I did oversimplify things and miss out on an optimisation by stacking U and b but this seems pretty elegant to me and the opposite of hard to maintain (compared to a non-operator approach at least). Just like in #196 I don't think it makes any sense to try overengineer by detecting whether it's more efficient to calculate Now the more controversial approach is the inverse chain, which would probably need to be advertised as experimental to begin with. Yes I admit this does feel a bit like building a compiler over operators but I think it could be really valuable. I think I'd just take a really simple approach to begin with where each matrix in a composed chain is inverted unless the inner dimensions are bigger than the outer, no special handling of add operators to begin with (we could allow Woodbury to handle this in the future but I admit that would get convoluted and hard to maintain). So Also if you hate this but love the Woodbury we could just do that for now and I can play around with the chain stuff outwith lineax. I have no deep thoughts/great insight on how to handle block matrices, and if it turns out that this is too hard to address in lineax so be it, at least we're doing a best effort in providing the foundation of lx.invert in the first place and maybe others can play around and work it out. 🙂 Another thing to make sure you're comfortable with is that we don't EXPLICITLY call |
patrick-kidger
left a comment
There was a problem hiding this comment.
Okay! Then in this case I think I'm pretty happy with this. Two outstanding minor nits (one new one + the cache one) and then we can merge this. I'm currently aiming to merge this and #164 and then do a new release.
|
Not sure if you're still keen on having the helper in lineax main given the new stop gradient handling, but if you are one possible extension would be to add |
This is mostly a bug fix, documentation and under-the-hood performance improvement release with one new feature—the `lx.invert` [transformation](https://docs.kidger.site/lineax/api/linear_solve/#invert) which produces an operator representing the inverse of a matrix. Use of coloring rules should make using implicit solvers in [diffrax](https://docs.kidger.site/diffrax/) for tridiagonal `Jacobian/FunctionLinearOperator`s at least an order of magnitude faster. ## Breaking Changes * Extraction of diagonal/tridiagonals of now leverages the promise of a matrix being tagged as diagonal/tridiagonal more heavily. If you have previously used the tag for an operator that you just wanted lineax to TREAT as diagonal/tridiagonal you may now get incorrect results. In most cases the right fix will probably be to first manually extract (tri)diagonal and construct the `(Tri)DiagonalOperator` explicitly, please raise an [issue](https://github.com/patrick-kidger/lineax/issues/new) if you need any further assistance. * `lineax.linear_solve` now stop-gradient's automatically (#213), it is unlikely this will break any existing use-cases but may make manual stop-gradienting unecessary * Removed AuxLinearOperator (#203) ## Features * Add invert helper function to wrap `lineax.linear_solve` in `FunctionLinearOperator`. Materialising an inverse is now as simple as `lx.invert(op).as_matrix()`. (#206) ## Compatibility * lineax v0.1.1 now requires JAX >= 0.10.0 which provides a lowering to LAPACK/cuSolver's` ormqr for more efficient QR solve adopted in #219. ## Bugfixes * Fix derived tag check rules for composite operators (e.g. `Composed/Neg/Mul/AddLinearOperator`) (#192) * Linearisation of functions `custom_vjp`'s are now supported by `lineax.linearize(JacobianLinearOperator(f, x, jac="bwd"))` by using `jax.linear_transpose` under the hood. (#191) * Complex positive/negative semi-definite matrices no longer register as symmetric (#200) * `lineax.LSMR` no longer fails when initial residual is exactly zero. (HUGE thanks to @f0uriest for spotting this tricky and hard-to-spot bug #202) * Differentiating through `linear_solve`'s no longer differentiates through `solver.init` this means using solver's with no or incorrect jvp rule is now possible (#212) ## Performance * Coloring rules now used to _massively_ speed up diagonal/tridiagonal extraction of tagged `Jacobian/FunctionLinearOperators` (#164, #165) * Normal and iterative solvers now apply `lineax.linearise` under the hood to avoid multiple sequential AD passes (#198) * Furthermore, `lineax.Normal(lineax.Cholesky())` now materialises the inner operator before constructing the Gram matrix (#207) * `ComposedLinearOperator.as_matrix` no longer materialises each matrix first but instead batches `mv` of the first operator over the second matrix (#196) * JAX's [ormqr](https://docs.jax.dev/en/latest/_autosummary/jax.lax.linalg.ormqr.html) now used for more efficient QR solves (#219) ## Documentation * The `lineax.LSMR` iterative least square solver is now properly documented (#204) after @f0uriest's #202 bug-fixes make it more robust. Other repo infra PR's not affecting Python package include #214, #216 and #218. ## New Contributors * @patrick-kidger-bot 🤖 made their first contribution in #216 **Full Changelog**: v0.1.0...v0.1.1
Looking through old issues/PR's (#96 #97) it seems a real blocker to extending lineax further to block operators, Woodbury matrix identity and chained inversion is the concept of an
InverseLinearOperatorthat allows composition (both with@and+/-). I sketched up an implementation of this (see first commit on this branch) but quickly realised it was essentially identical to aFunctionLinearOperatorwith custom tag rules and the ability to cache the state throughlx.linearise. This felt like overengineering to me and would give us one more operator to maintain (e.g. if we add more single dispatch functions or colouring rules). As such I instead decided to introduce a helper function that offered the same functionality. The key thing we lose is the ability to cache anInverseLinearOperatorwithlx.lineariseand instead we have to make the decision which one we want when callinglx.invertwith thecachekeyword.One idea I have in mind for using this is to introduce
collapse: boolkeyword argument which, whenFalseand provided with aComposedLinearOperatorand a direct solver will return a composed chain of inverse operators for each child of theComposedLinearOperator(in reverse order of course). This is usually more efficient (at least if we have a more efficient QR solver whenormqris supported in jax) unless inner dimensions are larger than outer ones as it avoids the cost of pre-multiplying. We could havecollapse=Noneto allow auto-detection of the large inner dimension case.cache=Falsewould usually be preferred here to avoid excess memory requirements (whereas for the Woodbury matrix identitycache=Truewould probably be preferred).As this is something that is not used elsewhere now, but could be used widely across the codebase in the future I appreciate that getting the design just right is important and am therefore open to making significant changes to this if you have other ideas.