Skip to content

[v6] Add support for MultiVectorEncoder models#3794

Draft
tomaarsen wants to merge 7 commits into
huggingface:mainfrom
tomaarsen:v6/multi_vector
Draft

[v6] Add support for MultiVectorEncoder models#3794
tomaarsen wants to merge 7 commits into
huggingface:mainfrom
tomaarsen:v6/multi_vector

Conversation

@tomaarsen

@tomaarsen tomaarsen commented Jun 3, 2026

Copy link
Copy Markdown
Member

Hello!

Pull Request overview

  • Introduce MultiVectorEncoder, a first-class model family for ColBERT-style / late-interaction retrieval
  • Ship the full training, evaluation, and scoring surface alongside the model
  • Add SimilarityFunction.MAXSIM and maxsim / maxsim_pairwise utilities

Details

This brings PyLate's feature set into sentence-transformers as a first-class MultiVectorEncoder family sitting alongside SentenceTransformer, SparseEncoder, and CrossEncoder. Models produce a sequence of token-level vectors per input, while scoring uses MaxSim late-interaction. The naming mirrors SparseEncoder's output-shape framing ("late interaction" is one scoring strategy on top of multi-vector outputs, not the encoder itself).

The package at sentence_transformers/multi_vector_encoder/ ships with MultiVectorTransformer (query_length / document_length, PAD -> MASK query expansion, attend_to_expansion_tokens), MultiVectorMask (skiplist), HierarchicalPooling (Ward clustering for storage compression), four losses (MultiVectorMultipleNegativesRankingLoss, CachedMultiVectorMultipleNegativesRankingLoss, MultiVectorDistillKLDivLoss, MultiVectorMarginMSELoss), and five evaluators (MultiVector{InformationRetrieval,NanoBEIR,Triplet,Distillation,Reranking}Evaluator). ColBERT MaxSim scoring lives in util/similarity.py so model.similarity dispatch and evaluators can request "MaxSim" by name, while XTR scoring lives in multi_vector_encoder/scoring/xtr.py and slots into any of the four losses as score_metric=XTRScores(). There's also KDProcessing for join-at-iteration-time KD data, re-exported from sentence_transformers.util since dense distillation flows benefit too.

Hub interop is symmetric across the load paths: native ST saves load the obvious way, while PyLate v3 checkpoints (model_type == "ColBERT") are auto-promoted via _apply_legacy_fixups. Stanford-NLP checkpoints (architectures == ["HF_ColBERT"]) are detected and load the inline linear.weight + artifact.metadata via _load_default_modules. SentenceTransformer checkpoints with a final Dense head can be converted into multi-vector models with the projection weights preserved. I've tested e.g. lightonai/GTE-ModernColBERT-v1, colbert-ir/colbertv2.0, and answerdotai/answerai-colbert-small-v1, and all are round-trip within bf16 noise of PyLate.

#3614 was a prior community attempt at the same problem and is used as a reference for the inference subset only. This PR inherits from BaseModel rather than SentenceTransformer (the post-v5.4 pattern) and avoids the LateInteractionPooling module that conflated projection with masking.

Usage

from sentence_transformers import MultiVectorEncoder

# PyLate-trained ModernBERT-based ColBERT model on the Hub. Other compatible checkpoints include:
#   - "lightonai/Reason-ModernColBERT"
#   - "answerdotai/answerai-colbert-small-v1"
#   - "colbert-ir/colbertv2.0"
model = MultiVectorEncoder("answerdotai/answerai-colbert-small-v1")

queries = [
    "What is the capital of France?",
    "Who painted the Mona Lisa?",
]
documents = [
    "Paris is the capital and most populous city of France.",
    "The Mona Lisa was painted by Leonardo da Vinci in the 16th century.",
    "Berlin is the capital of Germany.",
    "Vincent van Gogh painted The Starry Night.",
    "The Eiffel Tower is a wrought-iron lattice tower in Paris, France.",
]

query_embeddings = model.encode_query(queries)
document_embeddings = model.encode_document(documents)
print(f"Query 0 shape:    {query_embeddings[0].shape}  (padded with mask tokens to query_length)")
print(f"Document 0 shape: {document_embeddings[0].shape}")
print(f"Document 4 shape: {document_embeddings[4].shape}  (longer document = more tokens)")
"""
Query 0 shape:    (32, 96)  (padded with mask tokens to query_length)
Document 0 shape: (13, 96)
Document 4 shape: (17, 96)  (longer document = more tokens)
"""

# MaxSim score every query against every document.
scores = model.similarity(query_embeddings, document_embeddings)
print(scores)
"""
tensor([[31.3581, 30.4041, 30.8991, 30.3172, 30.9990],
        [29.9239, 31.3825, 29.8381, 30.5045, 30.0169]])
"""
from sentence_transformers import MultiVectorEncoder

model = MultiVectorEncoder("tomaarsen/colpali-v1.3-merged-st")

queries = [
    "What is the variable represented on the y-axis of the graph?",
    "Total outlay is maximum in which year?",
]
images = [
    "https://huggingface.co/tomaarsen/colpali-v1.3-merged-st/resolve/main/assets/doc1.jpg",
    "https://huggingface.co/tomaarsen/colpali-v1.3-merged-st/resolve/main/assets/doc2.jpg",
    "https://huggingface.co/tomaarsen/colpali-v1.3-merged-st/resolve/main/assets/doc3.jpg",
    "https://huggingface.co/tomaarsen/colpali-v1.3-merged-st/resolve/main/assets/doc4.jpg",
]

query_embeddings = model.encode_query(queries, convert_to_tensor=True)
document_embeddings = model.encode_document(images, convert_to_tensor=True)
print(f"Query 0 shape:    {query_embeddings[0].shape}  (text-only path)")
print(f"Document 0 shape: {document_embeddings[0].shape}  (visual prompt + image patches)")
"""
Query 0 shape:    torch.Size([25, 128])  (text-only path)
Document 0 shape: torch.Size([1031, 128])  (visual prompt + image patches)
"""

scores = model.similarity(query_embeddings, document_embeddings)
print(scores)
"""
tensor([[19.5000, 17.3750, 17.3750, 16.7500],
        [ 5.5000, 11.1250,  5.4375,  6.4688]], device='cuda:0',
       dtype=torch.bfloat16)
"""

This work is still very much a work-in-progress. I'd like for the implementation to be sufficiently flexible that it can incorporate all forms of multi-vector models, from text-only (ColBERT) to text+image (ColPali) and much beyond. The idea is to implement as much as possible in standalone modules rather than the core MultiVectorEncoder class, so that future models with different architectural choices (e.g. different query expansion, skiplist, pooling, etc.) can be trained, evaluated, and loaded as expected.

There's still open questions in that regard, e.g. currently I'm working with a MultiVectorTransformer subclass of the Transformer, but perhaps I'd like to absorb all of those architectural choices into the core class and have the "Transformer" part be a more modular option. I'm also interested in trainable scoring mechanisms, but that also requires more architectural flexibility than currently exists.

I've also uploaded these models for testing:

cc @NohTow

  • Tom Aarsen

@tomaarsen tomaarsen marked this pull request as draft June 3, 2026 15:58
Comment on lines +108 to +109
try:
from transformers import PaliGemmaProcessor

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

wild

# We can't set widget examples from an IterableDataset without losing data
continue

if dataset[dataset_name].format["type"] == "custom":

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

:')

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants