[feat] Add token-level sparse autoencoder modules#3796
Conversation
feat] Add token-level sparse autoencoder modules
|
For other readers, we discussed this PR a bit over Slack. Some details: I think there's sadly a few tricky problems that make the PR currently a bit hard to justify. In particular, I think it's a bit between a rock and a hard place. The current PR is not really usable by anyone, there's no loss/training changes required to actually train with these modules, and there's no models to use this with. Even if there will be some more models in the future, they might be best off starting with a custom module so the authors can quickly iterate if they e.g. also want to train with 4k instead of only k and aux. I also considered expanding the PR a lot with a loss, a custom data collator, updated trainer and model card generation, etc. I was able to get training working, but on e.g. lightonai/DenseOn and after 6M FineWeb-Edu tokens, I got ~0.50 NDCG@10 on NanoBEIR (using BM25 scoring) compared to ~0.60 NDCG@10 for the base model. This could also be because of the model prompt though: DenseOn without prompts gets about ~0.51. This was also without any hyperparameter tuning, etc. An added complexity in both cases is that this produces scores that then need to be matched using BM25: the usual built-in
|
Summary
This PR adds token-level sparse autoencoder support for
SparseEncodermodels.It implements the method described in Latent Terms: Dense Retrievers Contain Trivially Extractable BM25-ready Zipfian Vocabularies
It introduces:
SparseAutoEncoderTokenEncoder, which projects token embeddings into sparse per-token SAE activationsSparseTokenPooling, which pools sparse token activations intosentence_embeddingSparseAutoEncoderand the token-level encoderSparseEncoderintegration formax_active_dimsandset_pooling_include_promptThe token encoder supports padded, flattened, and packed token inputs. In training mode, when a decoder is available, it emits:
sae_input_normalizedsae_output_decodedso external losses can train the SAE without baking loss logic into the module.
Motivation
SparseAutoEncodercurrently operates on sentence embeddings. This PR adds the corresponding token-level path:This makes token-level SAE projection compatible with the existing
SparseEncoderfeature-dict interface and final sparse embedding contract.Tests
(python 3.12 due to an unresolved and seemingly unrelated issue with kenlm when using python 3.14)
uv run --python 3.12 --no-sync pytest tests/sparse_encoder/modules/test_sparse_auto_encoder_token_encoder.py uv run --python 3.12 --no-sync pytest tests/sparse_encoder/modules/test_csr.py uv run --python 3.12 --no-sync pytest tests/sparse_encoder/test_model.py -k 'csr_max_active_dims_passed_to_forward or max_active_dims_set_init or default_to_csr or set_pooling_include_prompt_updates_sparse_token_pooling' git diff --checkResults:
13 passedfor token SAE module tests2 passedfor CSR module regression tests4 passedfor selected SparseEncoder integration testsgit diff --checkclean