
long-context-dp/ldp


This repository contains the original implementation for "Learning Long-Context Diffusion Policies via Past-Token Prediction".

Website: https://long-context-dp.github.io

Paper: https://arxiv.org/abs/2505.09561

We are very grateful to Chi et al. for their work on the Diffusion Policy GitHub repository, from which this repository is adapted.

🧰 Installation

  1. Install Diffusion Policy (with Conda):
conda env create -f conda_environment.yaml
  2. Install additional packages:
conda activate robodiff-lh
pip install -r requirements.txt
  3. Download necessary data:

For Diffusion Policy benchmarks, see https://diffusion-policy.cs.columbia.edu/data/training/. For example, with robomimic data:

mkdir data && cd data
wget https://diffusion-policy.cs.columbia.edu/data/training/robomimic_image.zip
unzip robomimic_image.zip && rm -f robomimic_image.zip && cd ..

For the long-history simulation benchmarks, download the following archives from Google Drive:

cd data
export FILE_ID="1gwzIRBmn0a4Orj2okMNQ9qiPPpxmqdKA" FILE_NAME="aloha_twomodes_single.zip" TMP_HTML="/tmp/gdrive_download.html"

wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate \
"https://drive.google.com/uc?export=download&id=${FILE_ID}" -O "${TMP_HTML}"

export UUID=$(grep -oP 'name="uuid" value="\K[^"]+' "${TMP_HTML}")
export CONFIRM=$(grep -oP 'name="confirm" value="\K[^"]+' "${TMP_HTML}")

export FINAL_URL="https://drive.usercontent.google.com/download?id=${FILE_ID}&export=download&confirm=${CONFIRM}&uuid=${UUID}"

wget --load-cookies /tmp/cookies.txt --no-check-certificate "$FINAL_URL" -O "${FILE_NAME}"

rm -f /tmp/cookies.txt "$TMP_HTML"

unzip aloha_twomodes_single.zip && rm -f aloha_twomodes_single.zip
cd ..
cd data
export FILE_ID="1-ZDi8-aVx1I8aZCan-vXJQIpLyCCNwym" FILE_NAME="longhistsquare100.zip" TMP_HTML="/tmp/gdrive_download.html"

wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate \
"https://drive.google.com/uc?export=download&id=${FILE_ID}" -O "${TMP_HTML}"

export UUID=$(grep -oP 'name="uuid" value="\K[^"]+' "${TMP_HTML}")
export CONFIRM=$(grep -oP 'name="confirm" value="\K[^"]+' "${TMP_HTML}")

export FINAL_URL="https://drive.usercontent.google.com/download?id=${FILE_ID}&export=download&confirm=${CONFIRM}&uuid=${UUID}"

wget --load-cookies /tmp/cookies.txt --no-check-certificate "$FINAL_URL" -O "${FILE_NAME}"

rm -f /tmp/cookies.txt "$TMP_HTML"

unzip longhistsquare100.zip && rm -f longhistsquare100.zip
cd ..
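The same Google Drive steps repeat for every archive, so they can be wrapped in one shell function. The sketch below defines a hypothetical helper, `gdrive_download`, that is not part of the repository. It issues the same two requests as the commands above, but uses `mktemp` for the cookie and HTML files so that parallel downloads do not overwrite each other. Like the original commands, it assumes GNU grep (for `-P`) and Google's current confirmation page, which may change.

```shell
#!/usr/bin/env bash
# Hypothetical helper (not part of the repository) wrapping the two-step
# Google Drive download shown above. Requires wget and GNU grep (for -P).
# Assumes Google's confirmation page still exposes "uuid" and "confirm"
# form fields; this may change without notice.
gdrive_download() {
  local file_id="$1" file_name="$2"
  local tmp_html cookies uuid confirm status
  tmp_html="$(mktemp)"
  cookies="$(mktemp)"

  # Step 1: fetch the confirmation page and keep the session cookies.
  if ! wget --quiet --save-cookies "$cookies" --keep-session-cookies --no-check-certificate \
      "https://drive.google.com/uc?export=download&id=${file_id}" -O "$tmp_html"; then
    rm -f "$cookies" "$tmp_html"
    return 1
  fi

  # Step 2: extract the form fields that confirm the download.
  uuid="$(grep -oP 'name="uuid" value="\K[^"]+' "$tmp_html")"
  confirm="$(grep -oP 'name="confirm" value="\K[^"]+' "$tmp_html")"

  # Step 3: download the actual file, then remove the temporary files.
  wget --load-cookies "$cookies" --no-check-certificate \
    "https://drive.usercontent.google.com/download?id=${file_id}&export=download&confirm=${confirm}&uuid=${uuid}" \
    -O "$file_name"
  status=$?
  rm -f "$cookies" "$tmp_html"
  return "$status"
}

# Usage (from the data/ directory), e.g. for the ALOHA archive above:
# gdrive_download 1gwzIRBmn0a4Orj2okMNQ9qiPPpxmqdKA aloha_twomodes_single.zip
# unzip aloha_twomodes_single.zip && rm -f aloha_twomodes_single.zip
```

The function returns wget's exit status for the final download, so it can be chained with `&& unzip ...` as in the original commands.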

Running the code

Training Models

Activate Conda and log into wandb, if you haven't already:

conda activate robodiff-lh
wandb login

To reproduce our experiments, use experiment_configs/, a folder with all configurations used for our experiments. We provide a simple CLI for common use cases; for more advanced use cases, run a config directly as follows:

python train.py \
        --config-dir="[CONFIGDIR]" \
        --config-name="[CONFIGNAME]" \
        logging.name="[UNIQUE NAME]" \
        hydra.run.dir="data/outputs/[UNIQUE DIRECTORY]" \
        [YOUR OVERRIDDEN ARGUMENTS]
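One way to keep the run name and output directory unique is to derive both from a timestamp. The sketch below is an illustration under assumptions: `my_config` is a hypothetical config name (substitute a file from experiment_configs/), and overrides are written in Hydra's standard `key=value` syntax. It only prints the command, so you can inspect it before launching.

```shell
#!/usr/bin/env bash
# Sketch: fill in the train.py placeholders so that the wandb run name and
# the Hydra output directory are unique and match each other.
CONFIG_DIR="experiment_configs"
CONFIG_NAME="my_config"   # hypothetical; use a real config from experiment_configs/
RUN_NAME="${CONFIG_NAME}_$(date +%Y.%m.%d_%H.%M.%S)"

# Build the command as an array so each argument stays a single word.
cmd=(python train.py
  --config-dir="$CONFIG_DIR"
  --config-name="$CONFIG_NAME"
  logging.name="$RUN_NAME"
  hydra.run.dir="data/outputs/$RUN_NAME")

# Print for inspection; run it with "${cmd[@]}" once it looks right.
echo "${cmd[@]}"
```

Any further overrides can be appended to the `cmd` array in the same `key=value` form.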

To use the CLI, run transformer_history.sh as follows:

./transformer_history.sh --global_obs N --global_action N --global_horizon N --config_choice NAME --past_action_pred BOOL --past_steps_reg N --emb BOOL --cached BOOL
config_choice: [tool | square | square_past | transport | aloha | ...]

This help menu can also be printed by running ./transformer_history.sh -h. The recommended obs, action, and horizon lengths for each config are listed in the appendix of our paper. As a rule of thumb, we use obs=2, act=1, horizon=16 for short history (mirroring the Diffusion Policy repository) and obs=16, act=1, horizon=32 for long history, except in our long-history simulation benchmarks.
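To make the rule of thumb above concrete, the hypothetical helper `history_cmd` below (not part of the repository) prints a transformer_history.sh command with the short- or long-history lengths filled in. It only sets the flags whose typical values are stated here. Any other flags, such as `--past_action_pred` or `--past_steps_reg`, are passed through unchanged, so check `./transformer_history.sh -h` and the paper's appendix for the values your config needs.

```shell
#!/usr/bin/env bash
# Hypothetical helper: prints a transformer_history.sh command using the
# typical obs/action/horizon lengths from this README.
# Extra arguments (e.g. --past_action_pred, --emb) are appended unchanged.
history_cmd() {
  local mode="$1" config="$2"
  shift 2
  local obs horizon
  case "$mode" in
    short) obs=2;  horizon=16 ;;   # mirrors the Diffusion Policy defaults
    long)  obs=16; horizon=32 ;;   # typical long-history setting
    *) echo "usage: history_cmd short|long CONFIG [EXTRA FLAGS...]" >&2; return 1 ;;
  esac
  echo ./transformer_history.sh --global_obs "$obs" --global_action 1 \
    --global_horizon "$horizon" --config_choice "$config" "$@"
}

# Print the command first; copy it (or pipe it to bash) once it looks right.
history_cmd long square
```

Printing instead of executing keeps the helper harmless if the script requires flags that are not set here.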

If you are using embeddings, we also provide a selection of our short-history observation encoders here: https://drive.google.com/file/d/1tSYyWg3HZbTtEhzpAXQpl28DSrWsXc7J/view?usp=sharing. We recommend downloading the ZIP from the link and copying it (e.g., with scp) into the repository root. Alternatively, you can download it from the command line as follows (this relies on Google Drive's current download-confirmation flow, which may change):

# Run from the repository root, not from data/
export FILE_ID="1tSYyWg3HZbTtEhzpAXQpl28DSrWsXc7J" FILE_NAME="obs_encoders.zip" TMP_HTML="/tmp/gdrive_download.html"

wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate \
"https://drive.google.com/uc?export=download&id=${FILE_ID}" -O "${TMP_HTML}"

export UUID=$(grep -oP 'name="uuid" value="\K[^"]+' "${TMP_HTML}")
export CONFIRM=$(grep -oP 'name="confirm" value="\K[^"]+' "${TMP_HTML}")

export FINAL_URL="https://drive.usercontent.google.com/download?id=${FILE_ID}&export=download&confirm=${CONFIRM}&uuid=${UUID}"

wget --load-cookies /tmp/cookies.txt --no-check-certificate "$FINAL_URL" -O "${FILE_NAME}"

rm -f /tmp/cookies.txt "$TMP_HTML"

unzip obs_encoders.zip && rm -f obs_encoders.zip

These encoders are required to run with --emb True --cached True (embedding caching). After downloading the observation encoders, cache the embeddings for a dataset with:

python rewrite_with_embeddings.py -c [CHECKPOINT] -o [DIR FOR OUTPUT LOGS] -f [FILE TO CONVERT]

For example: python rewrite_with_embeddings.py -c obs_encoders/square_encoder.ckpt -o square_caching -f data/robomimic/datasets/square/mh/image_abs.hdf5. Once caching is complete, runs that use the cached embeddings are substantially faster.
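Caching can take a while, so it may help to fail early when an input path is wrong. The sketch below is a hypothetical wrapper, `cache_embeddings`, that is not part of the repository. It checks that the encoder checkpoint and the HDF5 file exist, then calls rewrite_with_embeddings.py with the same `-c`, `-o`, and `-f` flags shown above.

```shell
#!/usr/bin/env bash
# Hypothetical wrapper around rewrite_with_embeddings.py: verifies that the
# encoder checkpoint and the dataset exist before starting a long caching run.
cache_embeddings() {
  local ckpt="$1" out_dir="$2" dataset="$3" f
  for f in "$ckpt" "$dataset"; do
    if [ ! -f "$f" ]; then
      echo "missing file: $f" >&2
      return 1
    fi
  done
  python rewrite_with_embeddings.py -c "$ckpt" -o "$out_dir" -f "$dataset"
}

# Example with the square encoder and dataset named in this README:
# cache_embeddings obs_encoders/square_encoder.ckpt square_caching \
#   data/robomimic/datasets/square/mh/image_abs.hdf5
```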

Evaluating with Test-Time Improvements

We provide transformer_eval.sh, a script for evaluating checkpoints with varying amounts of test-time consistency. To run it, use:

./transformer_eval.sh --checkpoint PATH --name NAME --perturb VALUE --samples N

For example (usually you don't have to change the chunking): ./transformer_eval.sh --checkpoint /path/to/ckpt.ckpt --name testtime10 --perturb perturbs/none.yaml --samples 10

This will put the output in data/testtime10_{start_time}.
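Because each run gets a timestamped directory, repeated evaluations with the same `--name` accumulate side by side. The hypothetical helper `latest_eval_dir` below (not part of the repository) returns the most recently modified one. It assumes outputs follow the `data/NAME_{start_time}` pattern described above and relies on `ls -dt` sorting by modification time, newest first.

```shell
#!/usr/bin/env bash
# Hypothetical helper: print the newest output directory for a given --name.
# Assumes the data/NAME_{start_time} layout described in this README.
latest_eval_dir() {
  local name="$1"
  # -d lists directories themselves; -t sorts by modification time, newest first.
  ls -dt data/"${name}"_*/ 2>/dev/null | head -n 1
}

# Example: newest results for the "testtime10" run.
latest_eval_dir testtime10
```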

Gathering Data to Evaluate Consistency

To gather data to evaluate consistency, run the following (all examples on square):

  1. Check the expert consistency (the checkpoint is only needed as a convenient way to get the normalizer for the dataset):
python get_action_loss_train.py --checkpoint obs_encoders/square_encoder.ckpt --output_dir square_correlation [--transport TRUE]

Add --transport TRUE for a two-arm setup. This prints the correlations for a few epochs.

  2. Gather rollouts from an existing checkpoint that you want to compare against the expert values:
./run_gather_rollouts.ph output_dir /path/to/squarecheckpoint

This writes several pkl files to rollouts/output_dir and may take a while.

  3. Merge the gathered actions into one file for training:
python rollouts/merge_actions.py [PATH_TO_GATHERED_ACTIONS_FOLDER]
  4. Train on the rollouts:
python rollouts_via_policy.py /path/to/merged_actions.pkl [checkpoint] [transport]

The last argument is "transport" for a two-arm setup and is omitted for a one-arm setup. This prints the correlations.
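The four steps can be chained in a single script. The sketch below is a hedged outline for the square task, not a script shipped with the repository. `CKPT`, `EVAL_CKPT`, `OUT`, and `MERGED_PKL` are hypothetical placeholders, because where merge_actions.py writes its output is not documented here, and the gather script is called by the name this README uses. By default (`DRY_RUN=1`), the helper `run` prints each command instead of executing it.

```shell
#!/usr/bin/env bash
# Sketch of the consistency pipeline above for the square task.
# All paths below are hypothetical placeholders; set them before running.
# With DRY_RUN=1 (the default), commands are printed instead of executed.
DRY_RUN="${DRY_RUN:-1}"
run() {
  if [ "$DRY_RUN" = 1 ]; then echo "$*"; else "$@"; fi
}

CKPT=/path/to/squarecheckpoint           # placeholder: policy to compare against the expert
EVAL_CKPT=/path/to/checkpoint            # placeholder: the [checkpoint] argument of step 4
OUT=square_rollouts                      # hypothetical output name for step 2
MERGED_PKL=/path/to/merged_actions.pkl   # placeholder: output of step 3

# 1. Expert consistency (add --transport TRUE for two-arm setups).
run python get_action_loss_train.py --checkpoint obs_encoders/square_encoder.ckpt --output_dir square_correlation
# 2. Gather rollouts; several .pkl files are written to rollouts/$OUT.
run ./run_gather_rollouts.ph "$OUT" "$CKPT"
# 3. Merge the gathered actions into one file.
run python rollouts/merge_actions.py "rollouts/$OUT"
# 4. Measure predictability (append "transport" for two-arm setups).
run python rollouts_via_policy.py "$MERGED_PKL" "$EVAL_CKPT"
```

Running it once with the default dry run shows the exact commands; rerun with `DRY_RUN=0` after the placeholders are filled in.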

Here is a directory of useful files:

  • experiment_configs/: a folder with all configurations used for experiments
  • transformer_history.sh: a shell script accessing some of the common modes for experiments
  • gather_rollouts.py: a script that gathers rollouts for the action-predictability analysis in rollouts/
  • rollouts/merge_actions.py: a script that merges several data files gathered for the predictability analysis
  • rollouts_via_policy.py: a script for measuring the predictability of a rollout data file
  • transformer_eval.sh: a script for evaluating checkpoints (with various amounts of test-time consistency)

Citation

If you use this code, please cite our paper:

@misc{torne2025learninglongcontextdiffusionpolicies,
      title={Learning Long-Context Diffusion Policies via Past-Token Prediction}, 
      author={Marcel Torne and Andy Tang and Yuejiang Liu and Chelsea Finn},
      year={2025},
      eprint={2505.09561},
      archivePrefix={arXiv},
      primaryClass={cs.RO},
      url={https://arxiv.org/abs/2505.09561}, 
}
