jvp optimisations for pseudoinverse solvers #217
@patrick-kidger happy to own maintenance of the maths and performance here, but i'd appreciate your input on whether singledispatch with default |
patrick-kidger
left a comment
So the usual thing I've done (e.g. in Diffrax) is not to make this extensible (whether via method or singledispatch), but instead to hardcode the particular solvers directly: if isinstance(solver, Normal): ...other code path....
This leaves the 'public API' unchanged, so there's room to make a choice later if/when other considerations arise. And it's edge case enough that realistically this affects nobody by being hardcoded.
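A minimal sketch of the hardcoded approach described above. The classes here are hypothetical stand-ins (not the real Lineax solver types), and the returned strings only label which code path would run:

```python
from dataclasses import dataclass


# Hypothetical stand-ins for solver classes; NOT the real Lineax types.
@dataclass(frozen=True)
class Normal:
    inner_solver: str = "cholesky"


@dataclass(frozen=True)
class SVD:
    pass


def jvp_path(solver) -> str:
    """Choose a JVP code path by hardcoding the solvers that get special treatment."""
    if isinstance(solver, Normal):
        # Specialised path for the one solver we know how to optimise.
        return "gram-inverse via inner solver"
    # Every other solver keeps the existing generic path.
    return "generic pseudoinverse path"
```

Because the branch lives inside the private JVP code, nothing new is exposed for users to register against, which is what keeps the public API unchanged.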
| _gram_inverse_mv, | ||
| _row_space_projection, |
nit, leading underscore means private to the file it's in.
Hmm, the reason I didn't do |
Sorry, I was being stupid here: Modified Cholesky and LDLT do NOT handle the ill-posed/low-rank (dependent columns/rows) cases; they only handle indefinite matrices. The main use case would be something like SuiteSparse SPQR or MKL sparse QR, which is a bit more niche (and I don't think spineax aims to support those yet, right?). Happy to just go with isinstance if you prefer non-extensibility until there is an explicit request.
Okay! If there isn't already a pressing use-case then yup let's go with |
Converted to draft, as I think the design needs further iteration and should be left for a future release. The reason I originally used singledispatch was that it was useful on a variable projection project I was working on, but there's probably a better way, and it might be better to bring that into Lineax main too.
Preamble
First of all, sincere apologies for adding another PR to the backlog. This is definitely not one that would fit in an extension package due to its invasive change of the linear_solve jvp rule, but it is also not one I'd mind terribly if you deprioritise reviewing, as I don't think it's a critical bottleneck in my work. My hope is that submitting it now rather than waiting for the backlog to clear simply means it has more time to swim around in your head for a smoother review when we eventually get there (no pressure! 🙂). Furthermore, this could affect new solvers such as pivoted QR.
Intention of PR
I've been working on variable projection as well as thinking about improving JAX's least-squares JVP rule and comparing to Lineax for inspiration. I noticed we were missing some low-hanging fruit when it comes to leveraging standard optimisations for the projection and Gram operators that are used in the JVP rule.
Dependent columns
The JVP rule has a term $A^\dagger A y$, which is the projection of $y$ onto the row space of $A$. This has the following optimisations:
Dependent rows
The JVP rule has a term $A^\dagger A^{H\dagger} \mathrm{d}A^H(b-Ax)$, where the first two matrices can be written as the (pseudo-)inverse of the Gram matrix: $A^\dagger A^{H\dagger}=(A^H A)^\dagger$. This has the following optimisations:
inner_operator so we just use a single call to the inner_solver and avoid an application of
Caveat
Savings are not always quite as good as they sound as the "vecs" are summed with others that still exist before applying the outer solve, but the saving from the inner matrix is still very real (i.e. savings are about half what they're advertised to be).
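For reference, the two identities used in the sections above can be checked with a short derivation. This is a sketch assuming a singular value decomposition $A = U\Sigma V^H$ of rank $r$; the symbols $U$, $\Sigma$, $V$, $V_r$ and $r$ are introduced here for illustration and do not appear in the PR itself. $V_r$ denotes the first $r$ columns of $V$.

$$
A^\dagger = V\Sigma^\dagger U^H, \qquad (A^H)^\dagger = (A^\dagger)^H = U\,\Sigma^{\dagger H} V^H
$$

$$
A^\dagger A = V\,\Sigma^\dagger\Sigma\,V^H = V_r V_r^H
$$

$$
A^\dagger A^{H\dagger} = V\,\Sigma^\dagger\Sigma^{\dagger H}\,V^H = V\,(\Sigma^H\Sigma)^\dagger\,V^H = (A^H A)^\dagger
$$

The second line is the orthogonal projector onto the row space of $A$, and it reduces to the identity when the columns of $A$ are independent ($r$ equals the number of columns). The third line uses $A^H A = V\,\Sigma^H\Sigma\,V^H$.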
Design
I have introduced two singledispatch functions in _solve.py: _gram_inverse_mv and _row_space_projection, which return NotImplemented (NOT raise NotImplementedError) by default to allow fallback to the current path, and otherwise allow leveraging these optimisations in the jvp rule.
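The dispatch-with-fallback pattern described here can be sketched roughly as follows. The solver classes, the function signatures, and the tuple return values are all hypothetical placeholders for illustration; they are not the signatures used in the PR or in Lineax:

```python
from functools import singledispatch


class Normal:
    """Hypothetical stand-in for a normal-equations solver."""


class OtherSolver:
    """Hypothetical solver with no specialised rule registered."""


@singledispatch
def gram_inverse_mv(solver, vec):
    # Default: signal "no optimised rule" by returning (not raising) NotImplemented.
    return NotImplemented


@gram_inverse_mv.register
def _(solver: Normal, vec):
    # Placeholder for the optimised computation; here it only tags the input.
    return ("optimised", vec)


def generic_gram_inverse_mv(solver, vec):
    # Placeholder for the existing, unoptimised code path.
    return ("generic", vec)


def apply_gram_inverse(solver, vec):
    out = gram_inverse_mv(solver, vec)
    if out is NotImplemented:
        # No specialised rule registered for this solver: fall back.
        out = generic_gram_inverse_mv(solver, vec)
    return out
```

Returning the NotImplemented sentinel lets the caller branch with a plain identity check instead of a try/except, so an unrelated NotImplementedError raised inside a registered rule is not mistaken for "no rule available".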