Skip to content

Fix construct_singular_matrix not respecting tags#214

Merged
patrick-kidger merged 1 commit into
mainfrom
singular-tags
Mar 10, 2026
Merged

Fix construct_singular_matrix not respecting tags#214
patrick-kidger merged 1 commit into
mainfrom
singular-tags

Conversation

@patrick-kidger
Copy link
Copy Markdown
Owner

No description provided.

@adconner
Copy link
Copy Markdown
Contributor

adconner commented Mar 9, 2026

Two comments: (1) I think you still dont respect tags. Zeroing the first row matrix = matrix.at[0, :].set(0) is the method for if there are tags, but this does not preserve symmetric matrices, much less psd/nsd matrices (which is the current only usage of tagged singular matrices)

(2) As long as you are thinking about singular matrix construction methods, maybe consider adding some less trivial singular patterns. Eg in #210 for some less trivial tests I change this to (still not attempting to respect tags)

def construct_singular_matrix(
    getkey, solver, tags, num=1, *, size=3, dtype=jnp.float64
):
    # There is currently no attempt to generate matrices respecting tags.
    matrices = construct_matrix(getkey, solver, tags, num, size=size, dtype=dtype)
    if isinstance(solver, (lx.Diagonal, lx.CG, lx.BiCGStab, lx.GMRES)):
        return tuple(matrix.at[0, :].set(0) for matrix in matrices)
    else:
        version = jr.choice(getkey(), np.array([0, 1, 2, 3]))
        if version == 0:
            return tuple(matrix.at[1:, 1:].set(0) for matrix in matrices)
        if version == 1:
            return tuple(matrix.at[0, :].set(0) for matrix in matrices)
        elif version == 2:
            return tuple(matrix.at[:, 0].set(0) for matrix in matrices)
        else:
            return tuple(matrix - matrix.T for matrix in matrices)

If we are serious about respecting tags in singular matrix construction and we also want nontrivial examples we have to think about each tag individually and actually each combination of tags: diagonal and triangular just zero a diagonal element. unit_diagonal impossible when combined with triangular, probably unnecessary when not but we could take any singular matrix with nonzero diagonal and normalize the rows to make the diagonal ones, similar when combined with symmetric or psd or nsd, just do simultaneous row/column normalization by the sqrt of diagonal values. symmetric can zero [1:,1:] to get rank 2 or zero an eigenvalue, psd and nsd probably you have to zero an eigenvalue I don't see a structural option.

@patrick-kidger
Copy link
Copy Markdown
Owner Author

(1) Note that this happens to the matrix before any transformations – for making it respect tags – are applied.

(2) Agreed, it would make sense to add more in here. I think we should be able to simply add more in as random options, or adjust the implementation of _construct_matrix_impl to respect tags+singularity in some way.

This aside, you're seeing this PR a little prematurely! I've not tested this change locally to see if it actually does the right thing, so for now I'm just speculatively trying this out to see what happens in CI.

@adconner
Copy link
Copy Markdown
Contributor

adconner commented Mar 9, 2026

Ah I see. Only saw the diff should have looked at the rest of the function!

One other restriction I noticed checking this out again is that we have to test matrices with distinct singular values, since otherwise the jax derivative sometimes returns NaN's. In particular tested singular matrices should always have rank exactly 1 less than maximal. So for instance my proposed skew symmetric matrix of odd size doesn't work as there are repeated singular values, causing test failures due to the jax derivative returning NaN

@patrick-kidger patrick-kidger force-pushed the singular-tags branch 4 times, most recently from 9b21f77 to 26c6bbc Compare March 10, 2026 14:05
@patrick-kidger patrick-kidger merged commit 5e00f8c into main Mar 10, 2026
1 check passed
@patrick-kidger patrick-kidger deleted the singular-tags branch March 10, 2026 16:18
@jpbrodrick89 jpbrodrick89 mentioned this pull request May 1, 2026
jpbrodrick89 added a commit that referenced this pull request May 1, 2026
This is mostly a bug fix, documentation and under-the-hood performance improvement release with one new feature—the `lx.invert` [transformation](https://docs.kidger.site/lineax/api/linear_solve/#invert) which produces an operator representing the inverse of a matrix. Use of coloring rules should make using implicit solvers in [diffrax](https://docs.kidger.site/diffrax/) for tridiagonal `Jacobian/FunctionLinearOperator`s at least an order of magnitude faster.

## Breaking Changes

* Extraction of diagonal/tridiagonals of now leverages the promise of a matrix being tagged as diagonal/tridiagonal more heavily. If you have previously used the tag for an operator that you just wanted lineax to TREAT as diagonal/tridiagonal you may now get incorrect results. In most cases the right fix will probably be to first manually extract (tri)diagonal and construct the `(Tri)DiagonalOperator` explicitly, please raise an [issue](https://github.com/patrick-kidger/lineax/issues/new) if you need any further assistance.
* `lineax.linear_solve` now stop-gradient's automatically (#213), it is unlikely this will break any existing use-cases but may make manual stop-gradienting unecessary
* Removed AuxLinearOperator (#203)

## Features

* Add invert helper function to wrap `lineax.linear_solve` in `FunctionLinearOperator`. Materialising an inverse is now as simple as `lx.invert(op).as_matrix()`. (#206)
 
## Compatibility

* lineax v0.1.1 now requires JAX >= 0.10.0 which provides a lowering to LAPACK/cuSolver's` ormqr for more efficient QR solve adopted in #219.

## Bugfixes

* Fix derived tag check rules for composite operators (e.g. `Composed/Neg/Mul/AddLinearOperator`) (#192)
* Linearisation of functions `custom_vjp`'s are now supported by `lineax.linearize(JacobianLinearOperator(f, x, jac="bwd"))` by using `jax.linear_transpose` under the hood. (#191)
* Complex positive/negative semi-definite matrices no longer register as symmetric (#200)
* `lineax.LSMR` no longer fails when initial residual is exactly zero. (HUGE thanks to @f0uriest for spotting this tricky and hard-to-spot bug #202)
* Differentiating through `linear_solve`'s no longer differentiates through `solver.init` this means using solver's with no or incorrect jvp rule is now possible (#212)

## Performance

* Coloring rules now used to _massively_ speed up diagonal/tridiagonal extraction of tagged `Jacobian/FunctionLinearOperators`  (#164, #165)
* Normal and iterative solvers now apply `lineax.linearise` under the hood to avoid multiple sequential AD passes (#198)
* Furthermore, `lineax.Normal(lineax.Cholesky())` now materialises the inner operator before constructing the Gram matrix (#207)
* `ComposedLinearOperator.as_matrix` no longer materialises each matrix first but instead batches `mv` of the first operator over the second matrix (#196)
*  JAX's [ormqr](https://docs.jax.dev/en/latest/_autosummary/jax.lax.linalg.ormqr.html) now used for more efficient QR solves (#219)

## Documentation

* The `lineax.LSMR` iterative least square solver is now properly documented (#204) after @f0uriest's #202 bug-fixes make it more robust.

Other repo infra PR's not affecting Python package include #214, #216 and #218.


## New Contributors
* @patrick-kidger-bot 🤖 made their first contribution in #216 

**Full Changelog**: v0.1.0...v0.1.1
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants