Added Torch solver and benchmarks #662
base: master
Conversation
Pull request overview
This PR adds a PyTorch-based optimizer (TorchOptimizer) for sparse system identification using proximal gradient descent and iterative hard thresholding, along with a comprehensive benchmark script to compare optimizers across multiple nonlinear dynamical systems.
Key changes:
- New `TorchOptimizer` class implementing gradient-based sparse regression with support for SGD, Adam, AdamW, and custom CAdamW optimizers
- Benchmark script evaluating optimizers on 12 different ODE systems (Lorenz, Rössler, Van der Pol, etc.)
- Test suite for the new optimizer with basic integration tests
Reviewed changes
Copilot reviewed 5 out of 5 changed files in this pull request and generated 19 comments.
| File | Description |
|---|---|
| pysindy/optimizers/torch_solver.py | Core implementation of the PyTorch-based optimizer with proximal gradient methods and iterative thresholding |
| test/test_optimizers/test_torch_optimizer.py | Basic test coverage for TorchOptimizer including shape validation, sparsity, and SINDy integration |
| pysindy/optimizers/__init__.py | Registers TorchOptimizer as an optional dependency with conditional import |
| pyproject.toml | Adds torch to development dependencies |
| examples/benchmarks/benchmark.py | Comprehensive benchmark runner comparing multiple optimizers across various dynamical systems |
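For context, here is a minimal sketch of what "proximal gradient descent with iterative hard thresholding" means in this setting. It is purely illustrative and is not the actual code in torch_solver.py; the function name, signature, and choice of Adam are assumptions:

```python
import torch


def iht_sparse_regression(theta, x_dot, threshold=0.1, lr=1e-2, n_iter=2000):
    """Illustrative sketch: fit x_dot ~ theta @ xi by gradient steps on the
    least-squares loss, hard-thresholding small coefficients after each step."""
    theta = torch.as_tensor(theta, dtype=torch.float64)
    x_dot = torch.as_tensor(x_dot, dtype=torch.float64)
    xi = torch.zeros(
        theta.shape[1], x_dot.shape[1], dtype=torch.float64, requires_grad=True
    )
    opt = torch.optim.Adam([xi], lr=lr)
    for _ in range(n_iter):
        opt.zero_grad()
        loss = torch.mean((theta @ xi - x_dot) ** 2)
        loss.backward()
        opt.step()
        with torch.no_grad():
            # proximal step for the L0 penalty: zero out small coefficients
            xi[xi.abs() < threshold] = 0.0
    return xi.detach().numpy()
```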
Comments suppressed due to low confidence (1)
examples/benchmarks/benchmark.py:346
- 'except' clause does nothing but pass and there is no explanatory comment.
except Exception:
Whoa, I don't know why copilot auto-reviewed the PR. This is the first time that's happened. Let me look into that. I'm sorry you spun your gears on copilot's comments before I had a chance to talk about larger aspects.

The SR3 optimizer already solves using a hard threshold (or rather, converts between a hard threshold and an L0-regularized problem). However, it uses cvxpy. Another point concerns the optimizer subclasses: my initial thought is that, as we're trying to also support jax arrays, it makes more sense to split the sparse problem setup from the actual iterative solution (the optimizer subclasses).
…Also added new tests for the jax implementation.
Sounds good, although the SR3 implementations and the torch implementation tend to give somewhat different results for the same problem. Maybe just a verbose print when switching the optimizer would help users understand what is happening.

I know some JAX, so I have now also added the same implementation in JAX. Everything is the same except for the CAdamW optimizer, which I had to remove as it is a bit of a pain to rewrite in JAX.
I was not asking you to make a jax implementation; I was wondering whether it made sense to split how the problem is regularized from how the regularized problem is minimized. Since jax and torch (and numpy) all have very similar APIs, a single minimization approach may be able to serve multiple array types. E.g.

```python
def foo(x: np.ndarray):
    return np.do_thing(x)
```

being replaced with

```python
def foo(x: np.ndarray | torch.Tensor):
    array_pkg = np if isinstance(x, np.ndarray) else torch
    return array_pkg.do_thing(x)
```

A good example of this is your implementation.

There's also the question of cAdamW. I don't know what it is or why it does better. It feels like one step further than adding a torch optimizer. I'd say let's discuss that and the benchmarks after we see where the best place to plug this in is.
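Spelled out a little more (purely illustrative; none of these helpers exist in pysindy), the same kind of dispatch could let one hard-threshold step serve numpy, torch, and jax arrays:

```python
import numpy as np

try:
    import torch
except ImportError:  # torch stays optional
    torch = None

try:
    import jax.numpy as jnp
except ImportError:  # jax stays optional
    jnp = None


def _array_namespace(x):
    """Return the module whose API matches the type of ``x``."""
    if torch is not None and isinstance(x, torch.Tensor):
        return torch
    if jnp is not None and isinstance(x, jnp.ndarray):
        return jnp
    return np


def hard_threshold(coef, threshold):
    """Zero out small entries of ``coef`` regardless of the array backend."""
    xp = _array_namespace(coef)
    return xp.where(xp.abs(coef) >= threshold, coef, xp.zeros_like(coef))
```

The point is that the thresholding/regularization logic is written once, and only the array namespace changes.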
PS: can you clarify if/how you used AI in writing this code? It's fine if it was written with AI, but I want to know how much to ask of myself and how much to ask of you.
PPS: we're changing the default branch from master to main.
Hey, I simply added a jax version as I thought this would make things easier going forward, and I wanted to write something in JAX again as it has been some time for me. I mainly did the pytorch implementation (and now the jax implementation) as a learning experience to better understand how to find dynamic systems programmatically. Therefore, I generally don't write my own code with AI tools. The only place where I use ChatGPT is to add documentation for my code and to write tests if I actually do a pull request (like in this case).

I'm actually not quite sure about the best way to approach combining the different versions. Sure, the functions could simply be datatype-aware and change depending on the inserted data type, but the actual solving also changes somewhat. The SR3 and torch/jax implementations generally give different results with the same input configuration (the rossler system is one example). Depending on whether someone uses numpy/jax/torch arrays, they could therefore get very different results, and I don't know if this is desirable.

The cAdamW algorithm is a somewhat changed version from this repository: https://github.com/kyleliang919/C-Optim. I found that in some cases in the benchmark it gave better results than normal AdamW, and I successfully used it in one of my other projects to decrease the number of iterations needed by roughly 20%. Sadly, they do not have a PyPI package that can simply be installed and used, so I added my current version from my other project in this pull request. Just for clarity and scope, I could also simply remove it for now.
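The core trick, as far as I understand the C-Optim code, is a "cautious" masking of the optimizer update: components of the proposed step that point against the current gradient are zeroed out (and the remaining ones rescaled) before being applied. A rough sketch of just that idea, not the actual CAdamW code in this PR:

```python
import torch


def cautious_update(param, proposed_update, lr):
    """Sketch of the 'cautious' masking idea: only apply the components of the
    base optimizer's proposed update whose sign agrees with the current gradient.
    Illustration of the concept, not the C-Optim/CAdamW implementation."""
    mask = (proposed_update * param.grad > 0).to(proposed_update.dtype)
    # rescale so the overall update magnitude is roughly preserved
    mask = mask * mask.numel() / (mask.sum() + 1e-8)
    with torch.no_grad():
        param -= lr * mask * proposed_update
```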
I also just checked and I don't think I can switch from the master to main branch without forking the project again. The only branch GitHub allows me to track is the master branch, as all others are not forked from the original project.
I believe you can PR into any branch in the repo regardless of which branch you forked, but I don't think you can move the target of an existing PR.
I admit that I'm not sure either. It might make more sense to leave that to a future refactoring.

Enough people have asked about benchmarking that I've added benchmarking via ASV to the repo. But before I ask you to refactor your benchmark into an ASV one, I admit I haven't thought about the acceptance criteria for methods like this. On one hand, in the past, all new approaches have had a published paper on system identification to back them up beyond a single benchmark. By that metric, this PR doesn't add a substantially novel approach. On the other hand, academic publishing has too high a bar for research software, and I'm coming around to the idea (especially with benchmarking) that runtime matters for some users (e.g. #653). I know scikit-learn and scipy have each written algorithm notability/acceptance criteria, but this is the first time we've had to consider it in pysindy. I want to think about the way this should work in general, whether it should happen via benchmarks or what. I'd love to hear your thoughts on this, though.
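For reference (and not as a request yet), an ASV version of one benchmark case might look roughly like the sketch below; the class name, module layout, and choice of optimizers are illustrative, not how the repo's ASV suite is actually organized:

```python
import numpy as np
from scipy.integrate import solve_ivp

import pysindy as ps


def lorenz(t, x, sigma=10.0, rho=28.0, beta=8.0 / 3.0):
    return [
        sigma * (x[1] - x[0]),
        x[0] * (rho - x[2]) - x[1],
        x[0] * x[1] - beta * x[2],
    ]


class LorenzFitSuite:
    """ASV times any method whose name starts with ``time_``."""

    params = ["STLSQ", "SR3"]
    param_names = ["optimizer"]

    def setup(self, optimizer):
        t = np.arange(0, 10, 0.002)
        sol = solve_ivp(lorenz, (t[0], t[-1]), [-8.0, 8.0, 27.0], t_eval=t)
        self.x = sol.y.T
        self.t = t
        self.opt = getattr(ps, optimizer)()

    def time_fit(self, optimizer):
        ps.SINDy(optimizer=self.opt).fit(self.x, t=self.t)
```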
Dear Dynamicslab,
Love the project that you are doing.
I was learning a bit more about ODEs and PDEs and thought that maybe a solver based on proximal gradients and iterative hard thresholding could be useful for more complex problems.
This code introduces a solver based on a torch gradient descent algorithm and also a benchmark file.
The solver uses an adapted version of the cAdamW optimizer, as it performed slightly better than Adam or AdamW in my experiments.
The benchmark file runs multiple problems on all available solvers to check which one performs best given the problem statement. The output looks something like this:
```
System: lorenz
Optimizer        Score    MSE         Time (s)   Complexity
STLSQ            1.0000   3.1003e-02  0.0526     9
SR3-L0           0.9959   7.5510e+00  0.0137     7
FROLS            1.0000   6.5393e-02  0.0889     30
SSR              0.9993   1.3134e+00  0.0480     6
TorchOptimizer   1.0000   3.0785e-02  1.3960     8

Best optimizer: TorchOptimizer | Score=1.0000 | MSE=3.0785e-02 | Time=1.3960s | Complexity=8

Discovered equations:
-9.973 x0 + 9.973 x1
-0.129 1 + 27.739 x0 + -0.949 x1 + -0.993 x0 x2
-2.656 x2 + 0.996 x0 x1
```
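For completeness, here is roughly how the new optimizer is meant to be plugged into the usual pysindy workflow. I'm assuming the default constructor arguments work out of the box, and the demo system below is just an arbitrary linear ODE:

```python
# Hypothetical usage sketch: assumes torch is installed so the optional import
# succeeds, and that TorchOptimizer works with its default constructor arguments.
import numpy as np
from scipy.integrate import solve_ivp

import pysindy as ps
from pysindy.optimizers import TorchOptimizer

# Arbitrary linear demo system: x' = -0.1 x + 2 y, y' = -2 x - 0.1 y
t = np.arange(0, 25, 0.01)
sol = solve_ivp(
    lambda t, x: [-0.1 * x[0] + 2.0 * x[1], -2.0 * x[0] - 0.1 * x[1]],
    (t[0], t[-1]),
    [2.0, 0.0],
    t_eval=t,
)

model = ps.SINDy(optimizer=TorchOptimizer())  # default arguments assumed
model.fit(sol.y.T, t=t)
model.print()
```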