This repository contains the official PyTorch implementation for the ICML 2025 paper "ProDiff: Prototype-Guided Diffusion for Minimal Information Trajectory Imputation" . The paper introduces ProDiff, a novel framework where learned trajectory prototypes are used to guide and enhance the performance of diffusion models for complex trajectory-related tasks, specifically focusing on imputing trajectories with minimal available information. This codebase provides the tools to reproduce the experiments presented in the paper and to further explore the ProDiff methodology.
The following sections detail the structure of this codebase, setup instructions, and how to run training and inference.
This project implements a deep learning model for trajectory generation and related tasks, leveraging a combination of Diffusion Models and a Trajectory Transformer. The system is designed to learn from trajectory data, generate new trajectories, and potentially perform tasks like trajectory completion or imputation under certain conditions. It incorporates a contrastive learning objective and uses prototype-based representations learned by the Transformer to guide the diffusion process.
The model architecture consists of:
- Trajectory Transformer (
diffProModel/protoTrans.py): Processes input trajectories (sequences of time, latitude, longitude) to extract meaningful features and generate representative prototypes. It employs self-attention mechanisms. - Diffusion Model (
diffProModel/Diffusion.py): A UNet-based model that learns to denoise trajectories. It takes masked or corrupted trajectory data and learned prototypes as input to generate or refine trajectories through a reverse diffusion process. Supports both DDPM and DDIM sampling.
The training process involves a combined loss including diffusion loss, contrastive loss (to make learned trajectory embeddings discriminative), and a cross-entropy loss related to prototype assignment.
.
├── conf/
│ └── config.py # Configuration file for all parameters
├── data/
│ ├── train.h5 # Placeholder for training data (HDF5 format expected)
│ └── test.h5 # Placeholder for test data (HDF5 format expected)
├── dataset/
│ └── data_util.py # Data loading and preprocessing utilities (TrajectoryDataset)
├── diffProModel/
│ ├── Diffusion.py # Core Diffusion model implementation (UNet, sampling)
│ ├── loss.py # Custom loss functions (e.g., ContrastiveLoss)
│ └── protoTrans.py # Trajectory Transformer model implementation
├── utils/
│ ├── logger.py # Logging utilities
│ ├── metric.py # Evaluation metrics for trajectories
│ └── utils.py # Helper functions (KMeans, masking, etc.)
├── main.py # Main script to run training and experiments
├── train.py # Contains the main training loop and logic
├── test.py # Handles model evaluation and testing
├── run_ddpm.sh # Example script to run experiments with DDPM sampling
├── run_ddim.sh # Example script to run experiments with DDIM sampling
└── README.md # This file
The project relies on several Python libraries. Key dependencies include:
- PyTorch (including
torch.distributedfor DDP) - NumPy
- tqdm
- h5py (implicitly, for loading
.h5data files)
Ensure you have a Python environment with these packages installed. For GPU support, a compatible CUDA version and PyTorch build are required.
- Prepare your trajectory data in HDF5 format (
.h5files). - Each HDF5 file should be structured in a way that
dataset.data_util.TrajectoryDatasetcan read it. Typically, this involves datasets for absolute time, latitude, and longitude for each trajectory. - Place your training data (e.g.,
train.h5) and test data (e.g.,test.h5) into thedata/directory, or update the paths inconf/config.py. - The
TrajectoryDatasetindataset/data_util.pyexpects file paths and a trajectory length. Modifyget_data_pathsinutils/utils.pyif your data storage differs.
All experimental parameters are managed through conf/config.py. This includes:
- Data settings: Dataset paths, trajectory length, number of workers for DataLoader.
- Model settings: Architecture details for both the Diffusion model (UNet) and the Transformer (embedding dimensions, number of layers/heads).
- Diffusion settings: Beta schedule, number of timesteps.
- Training settings: Batch size, number of epochs, learning rate, optimizer settings, distributed training flags, loss weights.
- Sampling settings: Batch size for generation, DDIM steps, eta.
Modify load_config() in conf/config.py to suit your experiment. The configuration is loaded as a nested namespace object in main.py.
The primary entry point for the project is main.py. You can also use the provided shell scripts (run_ddpm.sh, run_ddim.sh) as templates.
-
Configure: Adjust parameters in
conf/config.pyas needed (e.g., paths, batch size, epochs). -
Run Training:
- Single GPU / CPU:
python main.py --exp_name your_experiment_name
- Distributed Data Parallel (DDP) Training:
To use DDP, set
training.dis_gputoTrueinconf/config.py. Then launch usingtorchrunortorch.distributed.launch. For example, for a 2-GPU setup:Thetorchrun --nproc_per_node=2 main.py --exp_name your_ddp_experiment_name
run_ddpm.shscript provides an example of how to launch DDP training. It uses environment variables likeMASTER_ADDR,MASTER_PORT,RANK,WORLD_SIZE, andLOCAL_RANKwhich are typically set by the launch utility.
Training artifacts (model checkpoints, logs, results) will be saved in an experiment directory, usually structured as
experiments_json/<exp_name_with_timestamp>/. - Single GPU / CPU:
- Experiments are organized into directories under
experiments_json/. - Each experiment directory will contain:
config.json: A copy of the configuration used for the experiment.models/: Saved model checkpoints (diffusion model and transformer model).results/: Saved outputs like losses, generated samples, or evaluation metrics.- Log files.
- The
test_modelfunction (intest.py, called fromtrain.pyduring validation) handles the generation of samples and evaluation. - To run standalone testing or generation with a trained model:
- Ensure you have trained model checkpoints (
.ptfiles for diffusion and transformer, and.npyfor prototypes if used). - Modify
main.pyor create a new script to load the checkpoints and calltest_modelwith the appropriate configuration (e.g., point to the correct test data, set sampling parameters). - The
samplingsection inconf/config.pyallows control over batch size for generation and specific sampling parameters like DDIM steps.
- Ensure you have trained model checkpoints (
run_ddpm.sh: An example script to launch training. It sets up environment variables for distributed training and callspython main.py. You might need to adapt it to your environment (e.g., GPU count, paths).run_ddim.sh: Similar torun_ddpm.sh, potentially configured for DDIM sampling during validation/testing (though DDIM/DDPM choice for validation is often controlled within the config ortest_modellogic).
main.py:- Parses arguments (like experiment name).
- Sets up the experiment directory and logging.
- Loads the configuration from
conf/config.py. - Calls
train_mainto start the training and validation process.
train.py(train_mainfunction):- Sets up DDP if enabled.
- Initializes datasets, dataloaders, models (Diffusion and Transformer), optimizer, and scheduler.
- Contains the main training loop:
- Iterates through epochs and batches.
- Performs data scaling and masking.
- Passes data through the Transformer to get prototypes and features.
- Uses KMeans for clustering features (iteratively).
- Computes contrastive loss and cross-entropy loss for the Transformer.
- Computes diffusion loss using the Diffusion model, conditioned on masked data and matched prototypes.
- Performs backpropagation and optimizer step.
- Handles model saving and periodic validation by calling
test_model.
test.py(test_modelfunction):- Takes trained models, test dataloader, config, and other parameters.
- Sets models to evaluation mode.
- Iterates through the test data.
- Performs sampling using the Diffusion model (DDPM or DDIM based on config) conditioned on input data (e.g., masked trajectories) and prototypes from the Transformer.
- Calculates and logs evaluation metrics (from
utils/metric.py). - Saves generated samples/results.
This README provides a starting point. You may need to delve into the specific scripts and configuration options for more advanced usage or customization.