Add ConvGRU ensemble training pipeline#10
Conversation
Transfer the training-ready ConvGRU ensemble model from ConvGRU-Ensemble into mlcast, enabling participants at the hackathon to train models on sampled radar datasets. New modules: - losses.py: CRPS, afCRPS, MaskedLoss, and build_loss() factory - utils.py: rain rate <-> normalized reflectivity conversions - models/convgru.py: RadarLightningModel with ensemble support - configs.py: @auto_config experiment factory (weatherduck pattern) - __main__.py: CLI entry point for training Updated modules: - modules/convgru_modules.py: ensemble generation via noisy decoder - data/zarr_dataset.py: SampledRadarDataset using CSV coordinates - data/zarr_datamodule.py: RadarDataModule with chronological splits - pyproject.toml: add fiddle, pandas, pytorch-lightning, torchvision deps
Replace argparse with Fiddle's absl_flags integration so that any
config parameter can be overridden from the command line using
--config set:key.path=value syntax. This is essential for the
hackathon where participants need to iterate quickly on HPC.
Usage:
python -m mlcast \
--config config:convgru_experiment \
--config set:data.zarr_path=/path/to/data.zarr \
--config set:data.csv_path=/path/to/sampled.csv \
--config set:data.batch_size=32 \
--config set:pl_module.num_blocks=4 \
--config set:trainer.max_epochs=50
Also adds absl-py, etils, importlib-resources as dependencies.
Restructure CLI to use subcommands so future commands (test, predict) can be added alongside train.
Adds [project.scripts] so 'uv run mlcast train' works alongside 'python -m mlcast train'.
- Create `tests/data/test_normalization.py` to verify the symmetry of `rainrate_to_normalized` and `normalized_to_rainrate`. - Create `tests/test_losses.py` to verify the expected tensor output shapes for `CRPS` and `afCRPS` across different reduction modes. This establishes the initial test coverage required for Phase 1 before refactoring the utilities and loss modules.
- Rename `src/mlcast/utils.py` to `src/mlcast/data/normalization.py` - Update all references in `zarr_dataset.py`, `convgru.py`, and tests
- Create `NORMALIZATION_REGISTRY` in `normalization.py` to map CF standard names to their normalization functions. - Add test to verify the registry mapping.
- Create `src/mlcast/visualization.py` to house visualization utilities. - Move `apply_radar_colormap` and `log_images` out of `convgru.py`. - Update `RadarLightningModel` to use the extracted `log_images`.
- Rename `afCRPS` to `AFCRPS` to follow naming conventions. - Update `build_loss` signature to explicitly default to `loss_class="mse"`. - Add type checking in `build_loss` to ensure `loss_class` is a string or class. - Use a dedicated `LossClass` variable internally for instantiation. - Expand docstrings to include explicit expected tensor shapes.
- Update `__main__.py` to use argparse subparsers and default to `training_experiment`. - Rename `convgru_experiment` to `training_experiment` in `configs.py`. - Update TensorBoardLogger default name to `mlcast`. - Defer documentation updates to Phase 7 to align with implementation timeline.
…dule - Rename `src/mlcast/models/convgru.py` to `base.py`. - Refactor `RadarLightningModel` into a generic `NowcastLightningModule` that accepts an injected PyTorch `nn.Module`. - Extract the core `EncoderDecoder` logic into a separate file (`src/mlcast/modules/convgru_modules.py`) and rename it to `ConvGruModel`. - Update `src/mlcast/configs.py` to use Fiddle to inject `ConvGruModel` into `NowcastLightningModule`. This completes Phase 2A of the restructuring plan, establishing a clean separation of concerns between the training orchestrator and the underlying neural network architecture.
- Add `jaxtyping` and `beartype` dependencies for static and runtime type checking. - Add rigorous NumPy-style docstrings to all methods in `base.py`. - Decorate PyTorch module `forward()` methods with `@jaxtyped(typechecker=beartype)` to enforce shape constraints at runtime. - Fix static type errors identified by `mypy`.
- Add `storage_options` to dataset factory classes to support reading Zarr stores anonymously from S3 object storage - Create `use_anon_s3_dataset` Fiddler in `mlcast.config.fiddlers` to easily configure the dataset factory for remote AWS connection strings - Revert default experiment configuration to use local dummy paths and `rainfall_rate` rather than hardcoding the Italian S3 dataset - Implement graceful CF `standard_name` validation inside dataset classes to intercept `cf_xarray` KeyErrors. When a requested variable isn't found, emit a clear `ValueError` listing all available valid CF standard names in the dataset, along with CLI hints on how to select them. - Add `use_anon_s3_dataset` usage example to dynamic CLI help text.
- Create standalone script at `examples/scripts/download_mlcast_dataset_sample.py` to download temporal slices of remote datasets. - Use `mlcast_datasets.open_catalog()` and Intake to dynamically resolve remote paths. - Support dot-notation catalog traversal (e.g. `precipitation.radklim_hourly`). - Automatically mirror remote dataset directory structure onto the local filesystem, using `mode='w'` to safely overwrite existing local caches. - Provide a Dask progress bar to show download/writing status. - Include an optional `--data-stage` argument (defaulting to `source_data`) designed to validate the stage but temporarily bypass traversal until the catalog structure updates to support it as the root node.
Adds optional `gpu-cu128` and `gpu-cu130` extras so users can opt into CUDA-enabled torch builds via `uv sync --extra gpu-cu128/cu130`, while keeping CPU torch as the default. Updates README with install instructions.
- Promote nvidia-ml-py and psutil to core dependencies (required for MLflow system metrics monitoring) - Add LogSystemInfoCallback: logs system/git tags as MLflow run tags, starts SystemMetricsMonitor manually, prints run URL at train start - Fix GPU warning suppression: patch gpu_monitor._logger immediately before SystemMetricsMonitor.start() to avoid being reset by mlflow.__init__ dictConfig call - Add use_mlflow_logger() fiddler: swaps TensorBoardLogger for MLFlowLogger, appends LogSystemInfoCallback to trainer callbacks - Export use_mlflow_logger from config/__init__.py - log_images() now dispatches on TensorBoardLogger vs MLFlowLogger, converting tensors to PIL for MLflow's log_image API - Pass self.logger (not self.logger.experiment) to log_images()
logging.config.dictConfig (called by mlflow.__init__) resets logger levels on existing loggers but does not clear their filters. Replacing setLevel(ERROR) with a logging.Filter subclass makes the suppression robust to any subsequent dictConfig calls regardless of timing.
The approach of patching mlflow's gpu_monitor logger (both via setLevel and logging.Filter) did not reliably suppress the warnings in practice. Removing the dead code for now.
Calling torch.set_float32_matmul_precision('high') at training startup
silences PyTorch's warning about underutilised Tensor Cores and enables
TF32 for matrix multiplications, improving throughput with negligible
impact on training precision.
…owcastLightningModule
…g.yaml Introduces load_yaml_config() (config/loader.py) which deserialises a Fiddle YAML dump back into a fdl.Config by mirroring the representers registered by fiddle._src.experimental.yaml_serialization in reverse (using a custom _FiddleLoader subclass of yaml.SafeLoader). The CLI (mlcast train) now detects a YAML file passed as --config, removes it from the remaining argv, loads it, and seeds the internal state of Fiddle's FiddleFlag directly so that all subsequent set: and fiddler: overrides are applied by Fiddle's own flag machinery without any custom override logic. Also fixes MLflow hyperparameter logging: fdp.as_dict_flattened produces keys like trainer.callbacks[0].monitor which MLflow rejects. Brackets are replaced with dot notation (e.g. .0.) before logging.
…g Path - Rename fixture to follow fp_ convention and return Path instead of an open dataset, so each test and dataloader worker can open the store independently - Update all tests using the fixture (test_fixture.py, test_cli_training.py, test_source_datasets.py) - Add test_cli_train_from_yaml_config exercising the new YAML load path - Update __main__.py module docstring to reflect current CLI usage - Add AGENTS.md documenting project conventions
…CLI, MLflow/WandB config upload, and YAML-based config loading
… CLI help, add config-diagram pre-commit hook
…riables skips input_channels if not in network signature
…ion with Mermaid diagram, mfai adapter example, and README snippet integration tests
…diagram with PNGs
…t classes - DatasetSample TypedDict introduced; __getitem__ now returns input/target/target_mask - target_mask computed per-timestep per-channel from target before nan_to_num (fixes collapsed-mask bug) - forecast_steps removed from NowcastLightningModule; shared_step derives it from future.shape[1] - forecast_steps moved to dataset_factory in base config and fiddlers - Contract 3 (steps > forecast_steps) removed from consistency_checks; now enforced by dataset guard - Old contracts 4/5 renumbered to 3/4; tests updated accordingly - README: all pl_module.forecast_steps references updated to dataset_factory.forecast_steps - Regenerate config diagram to reflect moved forecast_steps
- SourceDataDatasetBase extracts all shared logic: __init__, input_steps property, ds property, _validate_standard_names, _apply_augmentations, and _build_sample - _build_sample centralises post-isel processing: mask capture, nan_to_num, input/target split, augmentations, and DatasetSample assembly - DatasetSample constructed without mask first; target_mask added conditionally, eliminating duplicated if/else branches - Both subclass __getitem__ reduced to isel slicing + _build_sample call - _detect_axes extracted as module-private free function taking (ds, standard_name); sets t_dim/y_dim/x_dim via return value rather than side effects; stacklevel=3 accounts for the extra call frame
… steps becomes a property - input_steps + forecast_steps replace steps as the two primary constructor params on SourceDataDatasetBase and both subclasses; steps is now a @Property returning input_steps + forecast_steps - Guards updated: input_steps < 1 and forecast_steps < 1 raise ValueError - config/base.py: steps=18, forecast_steps=12 -> input_steps=6, forecast_steps=12 - config/fiddlers.py: use_random_sampler forwards input_steps instead of steps - README: past_steps renamed to input_steps throughout; mfai HalfUNetNowcaster example updated to accept input_steps/num_vars and compute in_channels internally; einops.rearrange used for channel-stacking with explanatory comments; config site reads cfg.data.dataset_factory.input_steps directly
|
Ok, @franchg I think I'm done with my refactor 🥳 I hope you like what I've done to your code 🎁 ➡️ https://github.com/leifdenby/mlcast/tree/feat/convgru-ensemble-training I've (with my agent's help) written up this list of changes (below). I went through this by hand and added/removed detail so I think it gives quite a complete overview. As well as reading the list below (hopefully it makes sense...) if you could also read through the updated README that would be great. I tried to take care to make it clear how the CLI now works, what the code structure is, and what ConvGRU (our first architecture) is. In the README I have also added an example of how to couple in a new architecture not yet in the If you are happy with these change my suggestion is that:
We can then build on this when we discover things I have overlooked :) ChangesWhat Changed
WhyAll these changes were motivate by a desire to:
What This Now EnablesWith this, it is now:
In addition the support for mlflow and remote S3 Zarr stores were convenient for my own experimentation, but I thought that enabling this flexiblity should make it easier to add more choices here in future. |
Feat/convgru ensemble training
|
@franchg just a heads-up, I have continue working from this branch in https://github.com/leifdenby/mlcast/tree/feat/dmi-convgru-training (because I think you are continue with this branch yourself and have already merged my other branch). I found some issues, made some additions on my branch and changed a few things:
from mlcast.config import training_experiment
# Returns a fdl.Config graph — nothing is instantiated yet.
cfg = training_experiment.as_buildable()
## Ratio mode — chronological split by fraction of the time axis:
# Implicit test split (remainder = 15%)
cfg.data.splits = {"train": 0.70, "val": 0.15}
# Explicit test fraction
cfg.data.splits = {"train": 0.70, "val": 0.15, "test": 0.15}
## Datetime mode — explicit date boundaries:
# All three splits
cfg.data.splits = {
"train": ("2016-02-29", "2022-12-31"),
"val": ("2023-01-01", "2023-12-31"),
"test": ("2024-01-01", "2025-12-31"),
}
# No test split (test_dataset will be None)
cfg.data.splits = {
"train": ("2016-02-29", "2022-12-31"),
"val": ("2023-01-01", "2023-12-31"),
"test": None,
}the data module then passes these split definitions to each dataset instances created, and internally in the dataset |
Pulls the Marshall-Palmer constants (a, b), dBZ clip range (dbz_floor, dbz_ceiling), and normalized output range (norm_min, norm_max) into a single frozen dataclass shared by all conversions. The floor anchors 0 mm/h <-> dbz_floor exactly in both directions and the ceiling caps the forward conversion only. Per-call constants are cast to the input dtype via _typed_constants so the Z-R math stays in fp32 throughout instead of upcasting to fp64 and casting back. ~2.5x faster on a typical fp32 batch, which matters since this runs in the training dataloader hot path. Module-level shim functions and registries remain bound to DEFAULT_SCALING for backwards compatibility (nowcasting_module.py unchanged). Tests cover dtype preservation across fp16/fp32/fp64, zero/NaN anchors, round-trip across various (a, b) values, and the custom norm_min/norm_max range.
leifdenby
left a comment
There was a problem hiding this comment.
@franchg and I discussed this PR yesterday and agreed to merge it in and the continue work by creating further PRs. The purpose of doing this is to give more visibility to the work we have done here, and make it easier for others to build on this work.
We expect to make further PRs to address the following:
- @franchg is refactoring the source dataset sampler (https://github.com/mlcast-community/mlcast-dataset-sampler) to 1) use a faster implementation of the intensity based sampling (by retaining statistics of extracted sequences found during the nan-filtering step) and 2) to switch to parquet-based storage for the sampling information so that meta information (about the parameters of the sampling process) can be included directly in the parquet files
- I will be adding a PR to introduce the generalisation of how dataset splits are defined by 1) allowing for splitting (in future) on other coordinates besides
timeby making it explicit in the argument that defines the splits what coordinate to be split on, and allowing for range-based defined splits (tuples like so(2020-01-01T12:00, 2021-01-01T12:00)
And additional PRs that fix the following bugs I identified earlier:
- issue is that DMI radar is stored as float64 (which the other datasets are not), a these need converting to float32 before returning from the pytorch Dataset class, wasn't handled previously
.gitignorewas excluding themlcast.datasubmodule so thatuv add {mlcast-repo-path}install wouldn't includemlcast.datasubmodule
Summary
@auto_configpattern (as discussed in Designing configuration infrastructure for mlcast python package #5 and used in weatherduck)New files
losses.py— CRPS, afCRPS, MaskedLoss, andbuild_loss()factoryutils.py— rain rate ↔ normalized reflectivity conversions (Marshall-Palmer)models/convgru.py—RadarLightningModelwith ensemble support, TensorBoard image logging, and inference APIconfigs.py—@auto_configexperiment factory following the weatherduck pattern__main__.py— CLI entry point (python -m mlcast --zarr-path ... --csv-path ...)Updated files
modules/convgru_modules.py— added ensemble generation via noisy decoder inputsdata/zarr_dataset.py—SampledRadarDatasetloading datacubes from Zarr using CSV(t, x, y)coordinatesdata/zarr_datamodule.py—RadarDataModulewith chronological train/val/test splits and augmentationpyproject.toml— addedfiddle,pandas,pytorch-lightning,torchvisiondependenciesUsage (hackathon)
Closes #7 (partially — brings ensemble ConvGRU into mlcast)
Implements #5 (Fiddle configuration)
Test plan
uv pip install -e .and verify imports workpython -m mlcast --helpto verify CLI