Skip to content

[feat] Add token-level sparse autoencoder modules#3796

Draft
robro612 wants to merge 8 commits into
huggingface:mainfrom
robro612:sae-sparse-token-encoder
Draft

[feat] Add token-level sparse autoencoder modules#3796
robro612 wants to merge 8 commits into
huggingface:mainfrom
robro612:sae-sparse-token-encoder

Conversation

@robro612

@robro612 robro612 commented Jun 4, 2026

Copy link
Copy Markdown

Summary

This PR adds token-level sparse autoencoder support for SparseEncoder models.
It implements the method described in Latent Terms: Dense Retrievers Contain Trivially Extractable BM25-ready Zipfian Vocabularies

It introduces:

  • SparseAutoEncoderTokenEncoder, which projects token embeddings into sparse per-token SAE activations
  • SparseTokenPooling, which pools sparse token activations into sentence_embedding
  • shared stateless SAE projection helpers used by both sentence-level SparseAutoEncoder and the token-level encoder
  • SparseEncoder integration for max_active_dims and set_pooling_include_prompt
  • model card tagging for token SAE models
  • package reference docs for the new modules

The token encoder supports padded, flattened, and packed token inputs. In training mode, when a decoder is available, it emits:

  • sae_input_normalized
  • sae_output_decoded

so external losses can train the SAE without baking loss logic into the module.

Motivation

SparseAutoEncoder currently operates on sentence embeddings. This PR adds the corresponding token-level path:

token_embeddings
-> SparseAutoEncoderTokenEncoder
-> token_sparse_values / token_sparse_indices
-> SparseTokenPooling
-> sentence_embedding

This makes token-level SAE projection compatible with the existing SparseEncoder feature-dict interface and final sparse embedding contract.

Tests

(python 3.12 due to an unresolved and seemingly unrelated issue with kenlm when using python 3.14)

uv run --python 3.12 --no-sync pytest tests/sparse_encoder/modules/test_sparse_auto_encoder_token_encoder.py
uv run --python 3.12 --no-sync pytest tests/sparse_encoder/modules/test_csr.py
uv run --python 3.12 --no-sync pytest tests/sparse_encoder/test_model.py -k 'csr_max_active_dims_passed_to_forward or max_active_dims_set_init or default_to_csr or set_pooling_include_prompt_updates_sparse_token_pooling'
git diff --check

Results:

  • 13 passed for token SAE module tests
  • 2 passed for CSR module regression tests
  • 4 passed for selected SparseEncoder integration tests
  • git diff --check clean

@robro612 robro612 changed the title Add token-level sparse autoencoder modules [feat] Add token-level sparse autoencoder modules Jun 4, 2026
@robro612 robro612 changed the title [feat] Add token-level sparse autoencoder modules [feat] Add token-level sparse autoencoder modules Jun 4, 2026
@tomaarsen

Copy link
Copy Markdown
Member

For other readers, we discussed this PR a bit over Slack. Some details:

I think there's sadly a few tricky problems that make the PR currently a bit hard to justify. In particular, I think it's a bit between a rock and a hard place.

The current PR is not really usable by anyone, there's no loss/training changes required to actually train with these modules, and there's no models to use this with. Even if there will be some more models in the future, they might be best off starting with a custom module so the authors can quickly iterate if they e.g. also want to train with 4k instead of only k and aux.

I also considered expanding the PR a lot with a loss, a custom data collator, updated trainer and model card generation, etc. I was able to get training working, but on e.g. lightonai/DenseOn and after 6M FineWeb-Edu tokens, I got ~0.50 NDCG@10 on NanoBEIR (using BM25 scoring) compared to ~0.60 NDCG@10 for the base model. This could also be because of the model prompt though: DenseOn without prompts gets about ~0.51. This was also without any hyperparameter tuning, etc.

An added complexity in both cases is that this produces scores that then need to be matched using BM25: the usual built-in model.similarity etc. isn't sufficient, so either way, it wouldn't immediately slot into ST nicely. If there were some ready-to-use models, it would be simpler to justify everything, but those could also be very viable with trust_remote_code=True.

  • Tom Aarsen

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants