Fix construct_singular_matrix not respecting tags#214
Conversation
b318309 to
7da8854
Compare
|
Two comments: (1) I think you still dont respect tags. Zeroing the first row (2) As long as you are thinking about singular matrix construction methods, maybe consider adding some less trivial singular patterns. Eg in #210 for some less trivial tests I change this to (still not attempting to respect tags) def construct_singular_matrix(
getkey, solver, tags, num=1, *, size=3, dtype=jnp.float64
):
# There is currently no attempt to generate matrices respecting tags.
matrices = construct_matrix(getkey, solver, tags, num, size=size, dtype=dtype)
if isinstance(solver, (lx.Diagonal, lx.CG, lx.BiCGStab, lx.GMRES)):
return tuple(matrix.at[0, :].set(0) for matrix in matrices)
else:
version = jr.choice(getkey(), np.array([0, 1, 2, 3]))
if version == 0:
return tuple(matrix.at[1:, 1:].set(0) for matrix in matrices)
if version == 1:
return tuple(matrix.at[0, :].set(0) for matrix in matrices)
elif version == 2:
return tuple(matrix.at[:, 0].set(0) for matrix in matrices)
else:
return tuple(matrix - matrix.T for matrix in matrices)If we are serious about respecting tags in singular matrix construction and we also want nontrivial examples we have to think about each tag individually and actually each combination of tags: |
|
(1) Note that this happens to the matrix before any transformations – for making it respect tags – are applied. (2) Agreed, it would make sense to add more in here. I think we should be able to simply add more in as random options, or adjust the implementation of This aside, you're seeing this PR a little prematurely! I've not tested this change locally to see if it actually does the right thing, so for now I'm just speculatively trying this out to see what happens in CI. |
|
Ah I see. Only saw the diff should have looked at the rest of the function! One other restriction I noticed checking this out again is that we have to test matrices with distinct singular values, since otherwise the jax derivative sometimes returns NaN's. In particular tested singular matrices should always have rank exactly 1 less than maximal. So for instance my proposed skew symmetric matrix of odd size doesn't work as there are repeated singular values, causing test failures due to the jax derivative returning NaN |
9b21f77 to
26c6bbc
Compare
26c6bbc to
7add52a
Compare
This is mostly a bug fix, documentation and under-the-hood performance improvement release with one new feature—the `lx.invert` [transformation](https://docs.kidger.site/lineax/api/linear_solve/#invert) which produces an operator representing the inverse of a matrix. Use of coloring rules should make using implicit solvers in [diffrax](https://docs.kidger.site/diffrax/) for tridiagonal `Jacobian/FunctionLinearOperator`s at least an order of magnitude faster. ## Breaking Changes * Extraction of diagonal/tridiagonals of now leverages the promise of a matrix being tagged as diagonal/tridiagonal more heavily. If you have previously used the tag for an operator that you just wanted lineax to TREAT as diagonal/tridiagonal you may now get incorrect results. In most cases the right fix will probably be to first manually extract (tri)diagonal and construct the `(Tri)DiagonalOperator` explicitly, please raise an [issue](https://github.com/patrick-kidger/lineax/issues/new) if you need any further assistance. * `lineax.linear_solve` now stop-gradient's automatically (#213), it is unlikely this will break any existing use-cases but may make manual stop-gradienting unecessary * Removed AuxLinearOperator (#203) ## Features * Add invert helper function to wrap `lineax.linear_solve` in `FunctionLinearOperator`. Materialising an inverse is now as simple as `lx.invert(op).as_matrix()`. (#206) ## Compatibility * lineax v0.1.1 now requires JAX >= 0.10.0 which provides a lowering to LAPACK/cuSolver's` ormqr for more efficient QR solve adopted in #219. ## Bugfixes * Fix derived tag check rules for composite operators (e.g. `Composed/Neg/Mul/AddLinearOperator`) (#192) * Linearisation of functions `custom_vjp`'s are now supported by `lineax.linearize(JacobianLinearOperator(f, x, jac="bwd"))` by using `jax.linear_transpose` under the hood. (#191) * Complex positive/negative semi-definite matrices no longer register as symmetric (#200) * `lineax.LSMR` no longer fails when initial residual is exactly zero. (HUGE thanks to @f0uriest for spotting this tricky and hard-to-spot bug #202) * Differentiating through `linear_solve`'s no longer differentiates through `solver.init` this means using solver's with no or incorrect jvp rule is now possible (#212) ## Performance * Coloring rules now used to _massively_ speed up diagonal/tridiagonal extraction of tagged `Jacobian/FunctionLinearOperators` (#164, #165) * Normal and iterative solvers now apply `lineax.linearise` under the hood to avoid multiple sequential AD passes (#198) * Furthermore, `lineax.Normal(lineax.Cholesky())` now materialises the inner operator before constructing the Gram matrix (#207) * `ComposedLinearOperator.as_matrix` no longer materialises each matrix first but instead batches `mv` of the first operator over the second matrix (#196) * JAX's [ormqr](https://docs.jax.dev/en/latest/_autosummary/jax.lax.linalg.ormqr.html) now used for more efficient QR solves (#219) ## Documentation * The `lineax.LSMR` iterative least square solver is now properly documented (#204) after @f0uriest's #202 bug-fixes make it more robust. Other repo infra PR's not affecting Python package include #214, #216 and #218. ## New Contributors * @patrick-kidger-bot 🤖 made their first contribution in #216 **Full Changelog**: v0.1.0...v0.1.1
No description provided.