Graph augmentation strategies for self-supervised pretraining of graph neural networks in spatial omics.
Spatial Augmentations is a research framework for exploring and benchmarking graph-based data augmentation strategies to improve Graph Neural Network (GNN) pretraining on spatial omics datasets (e.g., spatial transcriptomics, spatial proteomics). This project is built on PyTorch Lightning and Hydra for modularity, reproducibility, and ease of experimentation.
The key methods used are Bootstrapped Graph Latents (BGRL) (Thakoor et al., 2021) and GRACE (Zhu et al., 2022) for self-supervised pretraining of GNNs via graph augmentations. Augmentation strategies are evaluated on two representative tasks: domain identification and phenotype prediction.
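As a rough illustration of what "graph augmentations" means here, the sketch below builds two stochastic views of a toy graph using random edge dropping and feature masking, the two generic augmentations described in the GRACE and BGRL papers. It is a minimal sketch in pure Python, not the code in this repository: the function names (`drop_edges`, `mask_features`, `make_two_views`) and the default probabilities are assumptions made for the example, and the augmentations benchmarked in this project may differ.

```python
import random


def drop_edges(edges, drop_prob, rng):
    """Keep each edge independently with probability 1 - drop_prob."""
    return [edge for edge in edges if rng.random() >= drop_prob]


def mask_features(features, mask_prob, rng):
    """Zero out whole feature columns; the same mask is shared by all nodes."""
    num_features = len(features[0])
    keep = [rng.random() >= mask_prob for _ in range(num_features)]
    return [[x if k else 0.0 for x, k in zip(row, keep)] for row in features]


def make_two_views(edges, features, drop_prob=0.2, mask_prob=0.2, seed=0):
    """Create two independently augmented views of the same graph."""
    rng = random.Random(seed)
    view_a = (drop_edges(edges, drop_prob, rng), mask_features(features, mask_prob, rng))
    view_b = (drop_edges(edges, drop_prob, rng), mask_features(features, mask_prob, rng))
    return view_a, view_b


# Toy graph (hypothetical data): 4 cells as nodes, 3 features per cell.
edges = [(0, 1), (1, 2), (2, 3), (3, 0)]
features = [
    [1.0, 2.0, 3.0],
    [4.0, 5.0, 6.0],
    [7.0, 8.0, 9.0],
    [1.5, 2.5, 3.5],
]

view_a, view_b = make_two_views(edges, features)
print("view A edges:", view_a[0])
print("view B edges:", view_b[0])
```

In a self-supervised setup of this kind, both views are passed through the encoder and the training objective pulls the two embeddings of the same node together, so the choice of augmentation determines which invariances the model learns.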
This repository is associated with the NeurIPS workshop paper:
“Exploring Augmentation-Driven Invariances for Graph Self-supervised Learning in Spatial Omics”
Authors: Lovro Rabuzin, Michel Tarnow, Prof. Dr. Valentina Boeva
├── .github <- GitHub Actions workflows
├── configs <- Hydra configs
├── data <- Project data
├── logs <- Logs generated by hydra and lightning loggers
├── notebooks <- Jupyter notebooks
├── scripts <- Shell scripts
├── src <- Source code
├── tests <- Tests of any kind
│
├── .gitignore <- List of files ignored by git
├── .pre-commit-config.yaml <- Configuration of pre-commit hooks for code formatting
├── .project-root <- File for inferring the position of project root directory
├── environment.yaml <- File for installing conda environment
├── LICENSE <- License file
├── Makefile <- Makefile with commands like `make train` or `make test`
├── pyproject.toml <- Configuration options for testing and linting
├── requirements.txt <- File for installing python dependencies
└── README.md
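The `.project-root` marker listed above lets code locate the repository root regardless of the working directory. The sketch below shows one way such a lookup can work, by walking up the directory tree until the marker file is found. The helper name `find_project_root` is hypothetical, and how this repository actually resolves its root is not shown here.

```python
from pathlib import Path


def find_project_root(start, marker=".project-root"):
    """Return the closest ancestor of `start` (or `start` itself) containing `marker`."""
    start = Path(start).resolve()
    for directory in [start, *start.parents]:
        if (directory / marker).exists():
            return directory
    raise FileNotFoundError(f"No '{marker}' file found above {start}")
```

For example, calling `find_project_root(Path.cwd())` from inside `src/` would return the top-level repository directory, which can then be used to build paths to `configs/` or `data/`.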
Datasets are not included in this repository due to their size. Raw datasets can be downloaded from the resources provided below.
| Dataset | Task | Description |
|---|---|---|
| Domain123 | Domain Identification | Datasets 1, 2, and 3 from Schaub et al. (2025): mouse brain spatial transcriptomics datasets (MERFISH, STARmap, BaristaSeq) |
| Domain4 | Domain Identification | Dataset 4 from Schaub et al. (2025): mouse brain spatial transcriptomics dataset (Xenium) |
| Domain7 | Domain Identification | Dataset 7 from Schaub et al. (2025): mouse brain spatial transcriptomics dataset (MERFISH) |
| NSCLC | Phenotype Prediction | Non-small cell lung cancer (NSCLC) spatial proteomics dataset (IMC) from Cords et al. (2024) |
# clone project
git clone https://github.com/BoevaLab/spatial-augmentations.git
cd spatial-augmentations
# create conda environment and install dependencies
conda env create -f environment.yaml -n myenv
# activate conda environment
conda activate myenv

The torch_scatter, torch_sparse, and abc_atlas packages have to be transferred manually. The installation approaches have not been tested outside of a Linux/HPC environment.
# clone project
git clone https://github.com/BoevaLab/spatial-augmentations.git
cd spatial-augmentations
# [OPTIONAL] create conda environment
conda create -n myenv python=3.9
conda activate myenv
# install pytorch according to instructions
# https://pytorch.org/get-started/
# install requirements
pip install -r requirements.txt

Train model with default configuration (here, domain identification model):
# train on CPU
python src/train_domain.py trainer=cpu
# train on GPU
python src/train_domain.py trainer=gpu

Train model with chosen experiment configuration from configs/experiment/:
python src/train_domain.py experiment=experiment_name.yaml

You can override any parameter from command line like this:
python src/train_domain.py trainer.max_epochs=20 data.batch_size=64

Distributed under the MIT License. See LICENSE for more information.