This repository contains the original implementation for "Learning Long-Context Diffusion Policies via Past-Token Prediction".
Website: https://long-context-dp.github.io
Paper: https://arxiv.org/abs/2505.09561
We are very grateful for the great work done by Chi et al. on the Diffusion Policy GitHub repository, from which we adapted our repository.
- Install Diffusion Policy (with Conda):
conda env create -f conda_environment.yaml
- Install additional packages:
conda activate robodiff-lh
pip install -r requirements.txt
- Download necessary data:
For Diffusion Policy benchmarks, see https://diffusion-policy.cs.columbia.edu/data/training/. For example, with robomimic data:
mkdir data && cd data
wget https://diffusion-policy.cs.columbia.edu/data/training/robomimic_image.zip
unzip robomimic_image.zip && rm -f robomimic_image.zip && cd ..
For long-history simulation benchmarks, see the following data sources:
- ALOHA: https://drive.google.com/file/d/1gwzIRBmn0a4Orj2okMNQ9qiPPpxmqdKA/view?usp=sharing. We recommend downloading the ZIP through the link and copying it (e.g., with scp) into your data/ folder.
If it's easier, you can download the file from the command line with the following commands (note that this depends on Google's current download-verification flow, which can change):
cd data
export FILE_ID="1gwzIRBmn0a4Orj2okMNQ9qiPPpxmqdKA" FILE_NAME="aloha_twomodes_single.zip" TMP_HTML="/tmp/gdrive_download.html"
wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate \
"https://drive.google.com/uc?export=download&id=${FILE_ID}" -O "${TMP_HTML}"
export UUID=$(grep -oP 'name="uuid" value="\K[^"]+' "${TMP_HTML}")
export CONFIRM=$(grep -oP 'name="confirm" value="\K[^"]+' "${TMP_HTML}")
export FINAL_URL="https://drive.usercontent.google.com/download?id=${FILE_ID}&export=download&confirm=${CONFIRM}&uuid=${UUID}"
wget --load-cookies /tmp/cookies.txt --no-check-certificate "$FINAL_URL" -O "${FILE_NAME}"
rm -f /tmp/cookies.txt "$TMP_HTML"
unzip aloha_twomodes_single.zip && rm -f aloha_twomodes_single.zip
cd ..
- long-square: https://drive.google.com/file/d/1-ZDi8-aVx1I8aZCan-vXJQIpLyCCNwym/view?usp=sharing. We recommend downloading the ZIP through the link and copying it (e.g., with scp) into your data/ folder.
If it's easier, you can download the file from the command line with the following commands (note that this depends on Google's current download-verification flow, which can change):
cd data
export FILE_ID="1-ZDi8-aVx1I8aZCan-vXJQIpLyCCNwym" FILE_NAME="longhistsquare100.zip" TMP_HTML="/tmp/gdrive_download.html"
wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate \
"https://drive.google.com/uc?export=download&id=${FILE_ID}" -O "${TMP_HTML}"
export UUID=$(grep -oP 'name="uuid" value="\K[^"]+' "${TMP_HTML}")
export CONFIRM=$(grep -oP 'name="confirm" value="\K[^"]+' "${TMP_HTML}")
export FINAL_URL="https://drive.usercontent.google.com/download?id=${FILE_ID}&export=download&confirm=${CONFIRM}&uuid=${UUID}"
wget --load-cookies /tmp/cookies.txt --no-check-certificate "$FINAL_URL" -O "${FILE_NAME}"
rm -f /tmp/cookies.txt "$TMP_HTML"
unzip longhistsquare100.zip && rm -f longhistsquare100.zip
cd ..
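The command-line downloads above work by fetching Google Drive's confirmation page and scraping two hidden form fields (confirm and uuid) from it. If the grep -oP lines stop matching (for example, on systems whose grep lacks -P, or because Google changes the page), the following minimal Python sketch of the same parsing step may help with debugging. It assumes, exactly as the shell commands do, that the page contains name="..." value="..." attributes in that order; the function names are hypothetical and not part of this repository, and the sample page in the usage lines is made up.

```python
import re
from typing import Optional
from urllib.parse import urlencode


def extract_form_field(html: str, name: str) -> Optional[str]:
    """Return the value of a hidden form field, or None if it is absent."""
    # Equivalent to the grep -oP patterns used in the shell snippets.
    match = re.search(rf'name="{re.escape(name)}" value="([^"]+)"', html)
    return match.group(1) if match else None


def build_final_url(file_id: str, html: str) -> str:
    """Build the direct-download URL, as the shell FINAL_URL line does."""
    confirm = extract_form_field(html, "confirm")
    uuid = extract_form_field(html, "uuid")
    if confirm is None or uuid is None:
        raise ValueError("confirm/uuid not found; Google's page format may have changed")
    query = urlencode({"id": file_id, "export": "download", "confirm": confirm, "uuid": uuid})
    return f"https://drive.usercontent.google.com/download?{query}"


if __name__ == "__main__":
    # Made-up page content, for illustration only.
    sample = '<input name="confirm" value="t"><input name="uuid" value="abc-123">'
    print(build_final_url("1-ZDi8-aVx1I8aZCan-vXJQIpLyCCNwym", sample))
```

If extract_form_field returns None on a page you saved with the first wget call, the page layout has most likely changed and the shell commands will fail the same way.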
Activate Conda and log into wandb, if you haven't already:
conda activate robodiff-lh
wandb login
To reproduce our experiments, use experiment_configs/, a folder containing all configurations used for our experiments.
We provide a simple CLI for some common use cases; for more advanced use cases, you can run a configuration directly with train.py
as follows:
python train.py \
--config-dir="[CONFIGDIR]" \
--config-name="[CONFIGNAME]" \
logging.name="[UNIQUE NAME]" \
hydra.run.dir="data/outputs/[UNIQUE DIRECTORY]" \
[YOUR OVERRIDDEN ARGUMENTS]
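If you launch many runs from Python (e.g., a sweep), it can help to assemble the train.py command as an argument list rather than a string. The following is a minimal sketch, not part of this repository: build_train_command is a hypothetical helper, the config directory and name in the usage lines are placeholders, and training.seed is assumed to exist as it does in the upstream Diffusion Policy configs. It follows Hydra's convention that --config-dir and --config-name are flags, while config overrides are positional key=value arguments.

```python
import shlex
from typing import Dict, List


def build_train_command(config_dir: str, config_name: str, run_name: str,
                        overrides: Dict[str, str]) -> List[str]:
    """Hypothetical helper: assemble a train.py invocation as an argv list."""
    cmd = [
        "python", "train.py",
        f"--config-dir={config_dir}",
        f"--config-name={config_name}",
        # Hydra overrides are positional key=value pairs, not --flags.
        f"logging.name={run_name}",
        # Reusing the run name gives each run its own output directory.
        f"hydra.run.dir=data/outputs/{run_name}",
    ]
    cmd += [f"{key}={value}" for key, value in overrides.items()]
    return cmd


if __name__ == "__main__":
    # Placeholder config names; substitute a real file from experiment_configs/.
    cmd = build_train_command("experiment_configs", "my_config", "square_seed42",
                              {"training.seed": "42"})
    print(shlex.join(cmd))
    # To actually launch the run: subprocess.run(cmd, check=True)
```

Passing a list to subprocess.run avoids shell-quoting problems with paths that contain spaces or Hydra interpolations.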
To use the CLI, run transformer_history.sh as follows:
./transformer_history.sh --global_obs N --global_action N --global_horizon N --config_choice NAME --past_action_pred BOOL --past_steps_reg N --emb BOOL --cached BOOL
config_choice: [tool | square | square_past | transport | aloha | ...]
This help menu can also be generated by running ./transformer_history.sh -h. Please see the appendix of our paper for the recommended obs, action, and horizon lengths for each config.
Generally, for short-history runs we use obs=2, act=1, horizon=16, mimicking the DP repository, and for long-history runs we use obs=16, act=1, horizon=32,
with the exception of our long-history benchmark settings (see the appendix).
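To keep track of which settings a run used, you could generate the transformer_history.sh invocation from the generic values quoted above. This is a minimal, hypothetical sketch: the two presets only encode the short- and long-history defaults stated in this README (per-config values, including those for the long-history benchmarks, are in the paper's appendix), and we have not checked which flags the script treats as optional, so pass any others (e.g., past_action_pred, emb, cached) explicitly.

```python
from typing import List

# Generic defaults quoted in this README; per-config values are in the paper's appendix.
PRESETS = {
    "short": {"global_obs": 2, "global_action": 1, "global_horizon": 16},
    "long": {"global_obs": 16, "global_action": 1, "global_horizon": 32},
}


def history_cmd(config_choice: str, regime: str = "long", **flags: object) -> List[str]:
    """Hypothetical helper: build a transformer_history.sh argv list from a preset."""
    settings = dict(PRESETS[regime])
    settings["config_choice"] = config_choice
    settings.update(flags)  # extra flags override or extend the preset
    cmd = ["./transformer_history.sh"]
    for name, value in settings.items():
        # Booleans become "True"/"False", matching the "--emb True" style above.
        cmd += [f"--{name}", str(value)]
    return cmd


if __name__ == "__main__":
    print(" ".join(history_cmd("square", "long", past_action_pred=True)))
```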
If you are using embeddings, we also provide a selection of our short-history encoders here: https://drive.google.com/file/d/1tSYyWg3HZbTtEhzpAXQpl28DSrWsXc7J/view?usp=sharing.
We recommend downloading the ZIP through the link and copying it (e.g., with scp) into the repository's base directory. If it's easier, you can download the file from the command line
with the following commands (note that this depends on Google's current download-verification flow, which can change):
# Run from the repository's base directory, not data/
export FILE_ID="1tSYyWg3HZbTtEhzpAXQpl28DSrWsXc7J" FILE_NAME="obs_encoders.zip" TMP_HTML="/tmp/gdrive_download.html"
wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate \
"https://drive.google.com/uc?export=download&id=${FILE_ID}" -O "${TMP_HTML}"
export UUID=$(grep -oP 'name="uuid" value="\K[^"]+' "${TMP_HTML}")
export CONFIRM=$(grep -oP 'name="confirm" value="\K[^"]+' "${TMP_HTML}")
export FINAL_URL="https://drive.usercontent.google.com/download?id=${FILE_ID}&export=download&confirm=${CONFIRM}&uuid=${UUID}"
wget --load-cookies /tmp/cookies.txt --no-check-certificate "$FINAL_URL" -O "${FILE_NAME}"
rm -f /tmp/cookies.txt "$TMP_HTML"
unzip obs_encoders.zip && rm -f obs_encoders.zip
The encoders are required to run with --emb True --cached True (embedding caching). After downloading the observation encoders, cache the embeddings for a dataset by running:
python rewrite_with_embeddings.py -c [CHECKPOINT] -o [DIR FOR OUTPUT LOGS] -f [FILE TO CONVERT]
For example:
python rewrite_with_embeddings.py -c obs_encoders/square_encoder.ckpt -o square_caching -f data/robomimic/datasets/square/mh/image_abs.hdf5
Once caching is complete, runs that use the cached embeddings are substantially faster.
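The speedup from caching comes from running the observation encoder once per frame when the dataset is rewritten, instead of once per frame in every training epoch. The toy sketch below illustrates only this accounting; it is not the repository's implementation (ToyEncoder, cache_embeddings, and train are hypothetical stand-ins), and it assumes the encoder is not updated during training, which is what makes precomputed embeddings valid in the first place.

```python
from typing import Dict, Optional, Sequence, Tuple

Embedding = Tuple[float, ...]


class ToyEncoder:
    """Stand-in for a frozen observation encoder that counts forward passes."""

    def __init__(self) -> None:
        self.calls = 0

    def __call__(self, frame_id: int) -> Embedding:
        self.calls += 1
        return (float(frame_id),)  # placeholder embedding


def cache_embeddings(frames: Sequence[int], encoder: ToyEncoder) -> Dict[int, Embedding]:
    """One pass over the dataset, analogous to writing a cached copy of the data."""
    return {frame: encoder(frame) for frame in frames}


def train(frames: Sequence[int], epochs: int, encoder: Optional[ToyEncoder] = None,
          cache: Optional[Dict[int, Embedding]] = None) -> float:
    """Consume one embedding per frame per epoch, reading from the cache if given."""
    total = 0.0
    for _ in range(epochs):
        for frame in frames:
            emb = cache[frame] if cache is not None else encoder(frame)
            total += sum(emb)  # stand-in for a training step on the embedding
    return total


if __name__ == "__main__":
    frames = list(range(100))
    uncached, cached = ToyEncoder(), ToyEncoder()
    a = train(frames, epochs=10, encoder=uncached)
    b = train(frames, epochs=10, cache=cache_embeddings(frames, cached))
    print(a == b, uncached.calls, cached.calls)  # True 1000 100
```

Both paths see identical embeddings, but the cached path runs the encoder 10x fewer times over 10 epochs; the saving grows with the number of epochs.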
We have created transformer_eval.sh: a script for evaluating checkpoints (with various amounts of test-time consistency). To run, use:
./transformer_eval.sh --checkpoint PATH --name NAME --perturb VALUE --samples N
For example (usually you don't have to change the chunking): ./transformer_eval.sh --checkpoint /path/to/ckpt.ckpt --name testtime10 --perturb perturbs/none.yaml --samples 10
This writes the output to data/testtime10_{start_time}; in general, to data/{NAME}_{start_time} for the --name you pass.
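Because the output directory name ends with a start time, scripts that post-process evaluations need to locate the latest run for a given --name. A minimal sketch, assuming outputs land directly under data/ as described above: latest_eval_dir is a hypothetical helper, and it picks the most recently modified directory rather than parsing {start_time}, whose format is not documented here. Note that a name such as testtime10 would also match a longer name like testtime10_v2, so prefer names that are not prefixes of one another.

```python
from pathlib import Path
from typing import Optional


def latest_eval_dir(name: str, root: str = "data") -> Optional[Path]:
    """Return the most recently modified root/{name}_* directory, or None."""
    candidates = [p for p in Path(root).glob(f"{name}_*") if p.is_dir()]
    if not candidates:
        return None
    # Modification time avoids depending on the start_time string format.
    return max(candidates, key=lambda p: p.stat().st_mtime)


if __name__ == "__main__":
    print(latest_eval_dir("testtime10"))
```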
To gather data to evaluate consistency, run the following (all examples on square):
- Check the expert consistency with the following command (the policy checkpoint is only needed as a convenient way to get the dataset normalizer):
python get_action_loss_train.py --checkpoint obs_encoders/square_encoder.ckpt --output_dir square_correlation [--transport TRUE]
Add --transport TRUE for the two-arm setup. This prints the correlations for a few epochs.
- Gather rollouts using an existing checkpoint that you want to compare to the expert values:
./run_gather_rollouts.sh output_dir /path/to/squarecheckpoint
This writes the output, as several pkl files, to rollouts/output_dir (this may take a while).
- Merge the actions into one file for training.
python rollouts/merge_actions.py [PATH_TO_GATHERED_ACTIONS_FOLDER]
- Train on the rollouts:
python rollouts_via_policy.py /path/to/merged_actions.pkl [checkpoint] [transport]
where the last argument is "transport" for the two-arm setup and is omitted for the one-arm setup. This prints the correlations.
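Conceptually, the merge step concatenates the per-chunk rollout files into a single pickle. The sketch below shows that idea only: it assumes each .pkl file holds a Python list of records, which we have not verified against rollouts/merge_actions.py, so use the provided script for real data. merge_pickles is a hypothetical name. Write the merged file outside the input folder (or under a name that does not match *.pkl) so that a re-run does not pick up its own output.

```python
import pickle
from pathlib import Path
from typing import Any, List


def merge_pickles(folder: str, out_file: str) -> int:
    """Hypothetical merge: concatenate the lists stored in folder/*.pkl into one pickle."""
    merged: List[Any] = []
    for path in sorted(Path(folder).glob("*.pkl")):  # sorted for a deterministic order
        with open(path, "rb") as f:
            chunk = pickle.load(f)  # only unpickle files you generated yourself
        if not isinstance(chunk, list):
            raise TypeError(f"{path} does not hold a list; the real file format may differ")
        merged.extend(chunk)
    with open(out_file, "wb") as f:
        pickle.dump(merged, f)
    return len(merged)  # number of merged records
```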
Here is a directory of useful files:
- experiment_configs/: a folder with all configurations used for experiments
- transformer_history.sh: a shell script exposing some of the common experiment modes
- gather_rollouts.py: a script that gathers rollouts for the action predictability analysis in rollouts/
- rollouts/merge_actions.py: a script that merges several data files gathered for the predictability analysis
- rollouts_via_policy.py: a script for measuring the predictability of a rollout data file
- transformer_eval.sh: a script for evaluating checkpoints (with various amounts of test-time consistency)
If you use this code, please cite our paper:
@misc{torne2025learninglongcontextdiffusionpolicies,
title={Learning Long-Context Diffusion Policies via Past-Token Prediction},
author={Marcel Torne and Andy Tang and Yuejiang Liu and Chelsea Finn},
year={2025},
eprint={2505.09561},
archivePrefix={arXiv},
primaryClass={cs.RO},
url={https://arxiv.org/abs/2505.09561},
}