In Singular Learning Theory (SLT), the Local Learning Coefficient (LLC) quantifies the effective local dimensionality of a model around a trained optimum. Estimating it reliably can be tricky; that is what we explore here.
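For orientation, a widely used LLC estimator takes the form λ̂ = nβ*(E[Lₙ(w)] − Lₙ(w*)), where the expectation is over a posterior tempered at β* = 1/log n. A minimal sketch of that arithmetic (the sample losses below are fabricated purely for illustration, not output of this repo):

```python
import math

def llc_estimate(sample_losses, loss_at_optimum, n_data):
    """WBIC-style LLC estimate from tempered-posterior losses.

    lambda_hat = n * beta_star * (mean posterior loss - loss at the optimum),
    with inverse temperature beta_star = 1 / log(n).
    """
    beta_star = 1.0 / math.log(n_data)
    mean_loss = sum(sample_losses) / len(sample_losses)
    return n_data * beta_star * (mean_loss - loss_at_optimum)

# Fabricated example: posterior-sample losses averaging 0.52 vs 0.50 at the optimum.
lam = llc_estimate([0.51, 0.52, 0.53], 0.50, n_data=1000)
```

The hard part in practice is producing good posterior samples, which is exactly what the samplers benchmarked here are for.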
This repo provides benchmark estimators of the LLC on small but non-trivial neural networks, using standard industrial tooling:
- BlackJAX for sampling
- ArviZ for diagnostics
- Hydra for configuration management
- Haiku for neural network definitions
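To make concrete what these estimators compute, here is a self-contained toy: SGLD run on a one-dimensional quadratic loss (pure Python, independent of this repo's code; all constants are invented for illustration). For this regular model the true LLC is d/2 = 0.5:

```python
import math
import random

random.seed(0)

n = 100                      # toy "dataset size"
beta = 1.0 / math.log(n)     # WBIC inverse temperature beta* = 1/log(n)
step = 1e-3                  # SGLD step size

def loss(w):                 # toy loss L(w) = w^2 / 2 (regular, d = 1)
    return 0.5 * w * w

def grad(w):
    return w

# SGLD update: w <- w - (step/2) * n * beta * grad(L(w)) + sqrt(step) * N(0, 1)
w, losses = 0.0, []
for t in range(60_000):
    w += -0.5 * step * n * beta * grad(w) + math.sqrt(step) * random.gauss(0.0, 1.0)
    if t >= 10_000:          # discard burn-in
        losses.append(loss(w))

lambda_hat = n * beta * (sum(losses) / len(losses) - loss(0.0))
print(lambda_hat)            # should land near the true value 0.5
```

Real networks replace the quadratic with a singular loss landscape, which is why sampler choice and diagnostics matter.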
We target networks with parameter-space dimension up to about
Requires Python 3.11+.
```bash
# Using uv (recommended):
uv venv --python 3.12 && source .venv/bin/activate
uv sync --extra cpu     # For CPU/macOS
uv sync --extra cuda12  # For CUDA 12 (Linux)

# Or using pip:
pip install .[cpu]      # For CPU/macOS
pip install .[cuda12]   # For CUDA 12 (Linux)
```

Notes: the `pyproject.toml` previously required `jaxlib >= 0.7.1` (or the CUDA variant at the same version). To make this project compatible with Intel Macs, we have relaxed that requirement to `jaxlib >= 0.4.38`; we are currently testing whether this causes any issues.
`lambda_hat` provides two entry points for running experiments. Configuration is managed by Hydra.
Run the default configuration (MLP target, all samplers):
```bash
# Console script (recommended)
uv run lambda-hat
```

Outputs (logs, plots, metrics) are automatically saved in a timestamped directory under `outputs/`.
The configuration is composable. You can select presets defined in the `conf/` directory.
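For orientation, a sampler preset might look like the following. This is a hypothetical sketch, not copied from the repo; the actual keys live under `conf/` and may differ, though `hmc.draws` and `sgld.step_size` do appear among the overridable parameters:

```yaml
# conf/sampler/fast.yaml (hypothetical sketch; see the repo's conf/ for real keys)
hmc:
  draws: 500
  warmup: 200
sgld:
  step_size: 1e-5
  num_steps: 2000
```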
Run a quick, small experiment using the fast sampler settings and small model/data:
```bash
uv run lambda-hat sampler=fast model=small data=small
```

Override any configuration parameter from the command line:
```bash
# Change the dataset size and random seed
uv run lambda-hat data.n_data=5000 seed=123

# Change the model architecture
uv run lambda-hat model.depth=5 model.target_params=20000

# Adjust sampler settings
uv run lambda-hat sampler.hmc.draws=2000 sampler.sgld.step_size=1e-5
```

Hydra allows running sweeps over parameters using the `--multirun` (or `-m`) flag.
```bash
# Sweep over different model sizes
uv run lambda-hat --multirun model.target_params=1000,5000,10000

# Compare base vs fast sampler settings
uv run lambda-hat --multirun sampler=base,fast
```

Combine sweeps (Cartesian product):

```bash
# 2 sizes x 2 sampler configs = 4 runs
uv run lambda-hat --multirun model.target_params=1000,5000 sampler=base,fast
```

Multi-run outputs are saved under `multirun/`.
