
Conversation

@hvoss-techfak

Dear Dynamicslab,

Love the project that you are doing.
I was learning a bit more about ODEs and PDEs and thought that maybe a solver based on proximal gradients and iterative hard thresholding could be useful for more complex problems. This PR introduces a solver built on PyTorch gradient descent, along with a benchmark file.

The solver uses an adapted version of the cAdamW optimizer, as it performed slightly better than Adam or AdamW in my experiments.
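
The core loop is roughly the following (a simplified sketch, not the actual code in this PR; the function name and parameters are made up for illustration):

```python
import torch

def fit_sparse(theta, dxdt, threshold=0.1, lr=1e-2, n_iter=1000):
    """Minimize ||theta @ xi - dxdt||^2 by gradient descent, applying a
    hard threshold after each step (iterative hard thresholding)."""
    theta = torch.as_tensor(theta, dtype=torch.float64)
    dxdt = torch.as_tensor(dxdt, dtype=torch.float64)
    xi = torch.zeros(theta.shape[1], dxdt.shape[1],
                     dtype=torch.float64, requires_grad=True)
    opt = torch.optim.AdamW([xi], lr=lr)  # the PR swaps in an adapted cAdamW
    for _ in range(n_iter):
        opt.zero_grad()
        loss = torch.mean((theta @ xi - dxdt) ** 2)
        loss.backward()
        opt.step()
        with torch.no_grad():
            xi[xi.abs() < threshold] = 0.0  # proximal/hard-threshold step
    return xi.detach().numpy()
```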

The benchmark file runs multiple problems on all available solvers to check which one performs best for a given problem. The output looks something like this:

```
System: lorenz
Optimizer       Score   MSE         Time (s)  Complexity
STLSQ           1.0000  3.1003e-02  0.0526     9
SR3-L0          0.9959  7.5510e+00  0.0137     7
FROLS           1.0000  6.5393e-02  0.0889    30
SSR             0.9993  1.3134e+00  0.0480     6
TorchOptimizer  1.0000  3.0785e-02  1.3960     8

Best optimizer: TorchOptimizer | Score=1.0000 | MSE=3.0785e-02 | Time=1.3960s | Complexity=8
Discovered equations:
-9.973 x0 + 9.973 x1
-0.129 1 + 27.739 x0 + -0.949 x1 + -0.993 x0 x2
-2.656 x2 + 0.996 x0 x1
```
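
For each system, the benchmark essentially does the following (an illustrative sketch; the real script handles more configuration and error cases):

```python
import time
import numpy as np
import pysindy as ps

def run_benchmark(x, t, optimizers):
    """Fit each optimizer on the same trajectory and report fit quality,
    runtime, and complexity (number of nonzero coefficients)."""
    for name, opt in optimizers.items():
        model = ps.SINDy(optimizer=opt)
        start = time.time()
        model.fit(x, t=t)
        elapsed = time.time() - start
        score = model.score(x, t=t)  # R^2 on the estimated derivatives
        mse = np.mean((model.predict(x) - model.differentiate(x, t=t)) ** 2)
        print(f"{name:16s} {score:.4f}  {mse:.4e}  {elapsed:.4f}  {model.complexity}")

# e.g. run_benchmark(x, t, {"STLSQ": ps.STLSQ(threshold=0.1)})
```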

Copilot AI review requested due to automatic review settings December 5, 2025 11:06

Copilot AI left a comment


Pull request overview

This PR adds a PyTorch-based optimizer (TorchOptimizer) for sparse system identification using proximal gradient descent and iterative hard thresholding, along with a comprehensive benchmark script to compare optimizers across multiple nonlinear dynamical systems.

Key changes:

  • New TorchOptimizer class implementing gradient-based sparse regression with support for SGD, Adam, AdamW, and custom CAdamW optimizers
  • Benchmark script evaluating optimizers on 12 different ODE systems (Lorenz, Rössler, Van der Pol, etc.)
  • Test suite for the new optimizer with basic integration tests

Reviewed changes

Copilot reviewed 5 out of 5 changed files in this pull request and generated 19 comments.

| File | Description |
| --- | --- |
| pysindy/optimizers/torch_solver.py | Core implementation of the PyTorch-based optimizer with proximal gradient methods and iterative thresholding |
| test/test_optimizers/test_torch_optimizer.py | Basic test coverage for TorchOptimizer, including shape validation, sparsity, and SINDy integration |
| pysindy/optimizers/__init__.py | Registers TorchOptimizer as an optional dependency with conditional import |
| pyproject.toml | Adds torch to development dependencies |
| examples/benchmarks/benchmark.py | Comprehensive benchmark runner comparing multiple optimizers across various dynamical systems |
Comments suppressed due to low confidence (1)

examples/benchmarks/benchmark.py:346

  • `except` clause does nothing but `pass` and there is no explanatory comment:

```python
        except Exception:
```


@Jacob-Stevens-Haas
Member

Whoa, I don't know why copilot auto reviewed the PR. This is the first time that's happened. Let me look into that. I'm sorry you spun your gears on copilot's comments before I had a chance to talk about larger aspects.

The SR3 optimizer already solves using a hard threshold (or rather, converts between a hard threshold and an L-0 regularized problem). However, it uses cvxpy. Another point, the subclasses of BaseOptimizer conflate the abstract approach to sparse problem setup with the actual iterative solution to an optimization problem. This is the first PR that seeks to provide a similar class but with a different array package (well, there's another in development, but TBD).

My initial thought is that, since we're also trying to support jax arrays, it makes more sense to split the sparse problem setup from the actual iterative solution (subclasses of BaseOptimizer do both), and use torch/jax/cvxpy/numpy depending on whether the array type is jax.Array/torch.Tensor/cvxpy.Expression/numpy.ndarray. Since no feature libraries yet support anything other than numpy, I was planning to put work there first. As a very light approach, SR3._reduce() would check the type of array, and then dispatch to either its existing code or yours depending upon the type of arrays. What are your thoughts on that?

…Also added new tests for the jax implementation.
@hvoss-techfak
Author

hvoss-techfak commented Dec 6, 2025

Sounds good, although the SR3 implementations and the torch implementation tend to give somewhat different results for the same problem. Maybe a verbose print when switching optimizers would help users understand what is happening.

I know some JAX, so I have now also added the same implementation in JAX. Everything is the same except for the cAdamW optimizer, which I had to remove, as it is a bit of a pain to rewrite in JAX.
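
For reference, the per-iteration update in the JAX version boils down to something like this (an illustrative sketch, not the exact code in the PR):

```python
import jax
import jax.numpy as jnp

@jax.jit
def iht_step(xi, theta, dxdt, lr=1e-2, threshold=0.1):
    """One gradient step on the least-squares loss, then a hard threshold."""
    grad = jax.grad(lambda w: jnp.mean((theta @ w - dxdt) ** 2))(xi)
    xi = xi - lr * grad
    return jnp.where(jnp.abs(xi) < threshold, 0.0, xi)
```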

@Jacob-Stevens-Haas
Member

I was not asking you to make a jax implementation; I was wondering whether it made sense to split how the problem is regularized from how the regularized problem is minimized. Since jax and torch (and numpy) all have very similar APIs, a single minimization approach may be able to serve multiple array types.

E.g.

```python
def foo(x: np.ndarray):
    return np.do_thing(x)
```

being replaced with

```python
def foo(x: np.ndarray | torch.Tensor):
    array_pkg = np if isinstance(x, np.ndarray) else torch
    return array_pkg.do_thing(x)
```

A good example of this is your _soft_threshold and the existing _prox_l1. To that end, I would ask that you see how your code can integrate with SR3 first, so that we can limit the amount of redundant code.
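
Concretely, the merged helper could look something like this (a sketch written for illustration, not code from either implementation):

```python
import numpy as np
import torch

def _soft_threshold(x, lam):
    """Prox of the L1 norm, dispatching on the array type so one code
    path serves both numpy arrays and torch tensors."""
    xp = torch if isinstance(x, torch.Tensor) else np
    return xp.sign(x) * xp.maximum(xp.abs(x) - lam, xp.zeros_like(x))
```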

There's also the question of cAdamW. I don't know what it is or why it does better. It feels like one step further than adding a torch optimizer. I'd say let's discuss that and the benchmarks after we see where the best place to plug this in is.

@Jacob-Stevens-Haas
Member

PS can you clarify if/how you used AI in writing this code? It's fine if it was written with AI, but I want to know how much to ask of myself and how much to ask of you.

@Jacob-Stevens-Haas
Member

PPS, we're changing the default branch from master to main, so you'll need to change the target of the PR to main

@hvoss-techfak
Author

Hey, I simply added a JAX version because I thought it would make things easier going forward, and I wanted to write something in JAX again, as it has been a while for me.

I mainly did the pytorch implementation (and now the jax implementation) as a learning experience to better understand how to find dynamical systems programmatically. Because of that, I generally don't write my own code with AI tools. The only places I use ChatGPT are to add documentation to my code and to write tests when I actually open a pull request (as in this case).

I'm actually not quite sure about the best way to approach combining the different versions. Sure, the functions could simply be datatype-aware and change behavior depending on the input array type, but the actual solving also changes somewhat. The SR3 and torch/jax implementations generally give different results with the same input configurations:

```
System: rossler
Optimizer        Score   MSE         Time (s)  Complexity
SR3-L0           0.9909  3.6573e-04  0.0063    44
SR3-constrained  0.9909  3.6573e-04  0.0064    44
SR3-stable       0.5873  5.1498e-04  1.6874    60
TorchOptimizer   1.0000  9.0054e-10  0.0086    44
JaxOptimizer     1.0000  9.0054e-10  0.0139    44
```

Depending on whether someone uses numpy/jax/torch arrays, they could therefore get very different results, and I don't know whether that is desirable.

The cAdamW algorithm is a somewhat modified version of the optimizer from this repository: https://github.com/kyleliang919/C-Optim. I found that in some benchmark cases it gave better results than plain AdamW, and I successfully used it in one of my other projects to decrease the number of iterations needed by roughly 20%. Sadly, they do not have a PyPI package that can simply be installed and used, so I added my current version from my other project in this pull request. For clarity and scope, I could also simply remove it for now.

@hvoss-techfak
Author

I also just checked, and I don't think I can switch from master to main without forking the project again. The only branch GitHub allows me to target is master, as all the others were not forked from the original project.

@Jacob-Stevens-Haas
Member

I believe you can PR into any branch in the repo regardless of which branch you forked, but I don't think you can move the target of an existing PR.

> I'm actually not quite sure about the best way to approach combining the different versions

I admit that I'm not either. It might make more sense to leave that to a future refactoring.

Enough people have asked about benchmarking that I've added benchmarking via ASV to the repo. But before I ask you to refactor your benchmark into an ASV one, I admit I haven't thought about the acceptance criteria for methods like this. On one hand, in the past, all new approaches have had a published paper on system identification to back them up beyond a single benchmark. By that metric, this PR doesn't add a substantially novel approach. On the other hand, academic publishing has too high a bar for research software, and I'm coming around to the idea (especially with benchmarking) that runtime matters for some users (e.g. #653).

I know scikit-learn/scipy each have written algorithm notability/acceptance criteria, but this is the first time we've had to consider it in pysindy. I want to think about the way this should work in general, whether it should happen via benchmarks or what. I'd love to hear your thoughts on this, though.
