Skip to content

Commit 9dd81b9

Browse files
committed
[ENH] Add muvera support
1 parent 833a982 commit 9dd81b9

File tree

6 files changed

+881
-0
lines changed

6 files changed

+881
-0
lines changed

chromadb/test/ef/test_ef.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ def test_get_builtins_holds() -> None:
5757
"ChromaCloudQwenEmbeddingFunction",
5858
"ChromaCloudSpladeEmbeddingFunction",
5959
"ChromaBm25EmbeddingFunction",
60+
"PylateColBERTEmbeddingFunction",
6061
}
6162

6263
assert expected_builtins == embedding_functions.get_builtins()

chromadb/utils/embedding_functions/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,9 @@
9292
ChromaBm25EmbeddingFunction,
9393
)
9494

95+
from chromadb.utils.embedding_functions.pylate_colbert_embedding_function import (
96+
PylateColBERTEmbeddingFunction,
97+
)
9598

9699
# Get all the class names for backward compatibility
97100
_all_classes: Set[str] = {
@@ -127,6 +130,7 @@
127130
"ChromaCloudQwenEmbeddingFunction",
128131
"ChromaCloudSpladeEmbeddingFunction",
129132
"ChromaBm25EmbeddingFunction",
133+
"PylateColBERTEmbeddingFunction",
130134
}
131135

132136

@@ -163,6 +167,7 @@ def get_builtins() -> Set[str]:
163167
"cloudflare_workers_ai": CloudflareWorkersAIEmbeddingFunction,
164168
"together_ai": TogetherAIEmbeddingFunction,
165169
"chroma-cloud-qwen": ChromaCloudQwenEmbeddingFunction,
170+
"pylate_colbert": PylateColBERTEmbeddingFunction,
166171
}
167172

168173
sparse_known_embedding_functions: Dict[str, Type[SparseEmbeddingFunction]] = { # type: ignore
@@ -291,6 +296,7 @@ def config_to_embedding_function(config: Dict[str, Any]) -> EmbeddingFunction:
291296
"ChromaCloudQwenEmbeddingFunction",
292297
"ChromaCloudSpladeEmbeddingFunction",
293298
"ChromaBm25EmbeddingFunction",
299+
"PylateColBERTEmbeddingFunction",
294300
"register_embedding_function",
295301
"config_to_embedding_function",
296302
"known_embedding_functions",

chromadb/utils/embedding_functions/jina_embedding_function.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import os
1212
import numpy as np
1313
import warnings
14+
from chromadb.utils.muvera import create_fdes
1415
import importlib
1516
import base64
1617
import io
@@ -37,6 +38,7 @@ def __init__(
3738
dimensions: Optional[int] = None,
3839
embedding_type: Optional[str] = None,
3940
normalized: Optional[bool] = None,
41+
return_multivector: Optional[bool] = None,
4042
query_config: Optional[JinaQueryConfig] = None,
4143
):
4244
"""
@@ -101,6 +103,7 @@ def __init__(
101103
self.dimensions = dimensions
102104
self.embedding_type = embedding_type
103105
self.normalized = normalized
106+
self.return_multivector = return_multivector
104107
self.query_config = query_config
105108

106109
self._api_url = "https://api.jina.ai/v1/embeddings"
@@ -149,6 +152,8 @@ def _build_payload(self, input: Embeddable, is_query: bool) -> Dict[str, Any]:
149152
payload["embedding_type"] = self.embedding_type
150153
if self.normalized is not None:
151154
payload["normalized"] = self.normalized
155+
if self.return_multivector is not None:
156+
payload["return_multivector"] = self.return_multivector
152157

153158
# overwrite parameteres when query payload is used
154159
if is_query and self.query_config is not None:
@@ -170,6 +175,35 @@ def _convert_resp(self, resp: Any, is_query: bool = False) -> Embeddings:
170175
if "data" not in resp:
171176
raise RuntimeError(resp.get("detail", "Unknown error"))
172177

178+
if self.return_multivector:
179+
# if it gives back multivector embeddings
180+
multi_embeddings_data: List[Dict[str, Any]] = resp["data"]
181+
sorted_multi_embeddings = sorted(
182+
multi_embeddings_data, key=lambda e: e["index"]
183+
)
184+
multi_embeddings: List[Embeddings] = [
185+
[
186+
np.array(vec, dtype=np.float32)
187+
for vec in multi_embedding_obj["embeddings"]
188+
]
189+
for multi_embedding_obj in sorted_multi_embeddings
190+
]
191+
192+
if not multi_embeddings or not multi_embeddings[0]:
193+
raise RuntimeError(
194+
"Invalid multivector embeddings format from Jina API"
195+
)
196+
197+
dims = len(multi_embeddings[0][0])
198+
fdes = create_fdes(
199+
multi_embeddings,
200+
dims=dims,
201+
is_query=is_query,
202+
fill_empty_partitions=not is_query,
203+
)
204+
205+
return fdes
206+
173207
embeddings_data: List[Dict[str, Union[int, List[float]]]] = resp["data"]
174208

175209
# Sort resulting embeddings by index
@@ -231,6 +265,7 @@ def build_from_config(config: Dict[str, Any]) -> "EmbeddingFunction[Embeddable]"
231265
dimensions = config.get("dimensions")
232266
embedding_type = config.get("embedding_type")
233267
normalized = config.get("normalized")
268+
return_multivector = config.get("return_multivector")
234269
query_config = config.get("query_config")
235270

236271
if api_key_env_var is None or model_name is None:
@@ -245,6 +280,7 @@ def build_from_config(config: Dict[str, Any]) -> "EmbeddingFunction[Embeddable]"
245280
dimensions=dimensions,
246281
embedding_type=embedding_type,
247282
normalized=normalized,
283+
return_multivector=return_multivector,
248284
query_config=query_config,
249285
)
250286

@@ -258,6 +294,7 @@ def get_config(self) -> Dict[str, Any]:
258294
"dimensions": self.dimensions,
259295
"embedding_type": self.embedding_type,
260296
"normalized": self.normalized,
297+
"return_multivector": self.return_multivector,
261298
"query_config": self.query_config,
262299
}
263300

Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
from chromadb.api.types import Embeddings, Documents, EmbeddingFunction, Space
2+
from typing import List, Dict, Any
3+
from chromadb.utils.embedding_functions.schemas import validate_config_schema
4+
from chromadb.utils.muvera import create_fdes
5+
6+
7+
class PylateColBERTEmbeddingFunction(EmbeddingFunction[Documents]):
8+
"""
9+
This class is used to get embeddings for a list of texts using the ColBERT API.
10+
"""
11+
12+
def __init__(
13+
self,
14+
model_name: str,
15+
):
16+
"""
17+
Initialize the PylateColBERTEmbeddingFunction.
18+
19+
Args:
20+
model_name (str): The name of the model to use for text embeddings.
21+
Examples: "mixedbread-ai/mxbai-edge-colbert-v0-17m", "mixedbread-ai/mxbai-edge-colbert-v0-32m", "lightonai/colbertv2.0", "answerdotai/answerai-colbert-small-v1", "jinaai/jina-colbert-v2", "GTE-ModernColBERT-v1"
22+
"""
23+
try:
24+
from pylate import models
25+
except ImportError:
26+
raise ValueError(
27+
"The pylate colbert python package is not installed. Please install it with `pip install pylate-colbert`"
28+
)
29+
30+
self.model_name = model_name
31+
self.model = models.ColBERT(model_name_or_path=model_name)
32+
33+
def __call__(self, input: Documents) -> Embeddings:
34+
"""
35+
Get the embeddings for a list of texts.
36+
37+
Args:
38+
input (Documents): A list of texts to get embeddings for.
39+
40+
Returns:
41+
Embeddings: The embeddings for the texts.
42+
"""
43+
multivec = self.model.encode(input, batch_size=32, is_query=False)
44+
if not multivec or not multivec[0]:
45+
raise ValueError("Model returned empty multivector embeddings")
46+
return create_fdes(
47+
multivec,
48+
dims=len(multivec[0][0]),
49+
is_query=False,
50+
fill_empty_partitions=True,
51+
)
52+
53+
def embed_query(self, input: Documents) -> Embeddings:
54+
"""
55+
Get the embeddings for a list of texts.
56+
57+
Args:
58+
input (Documents): A list of texts to get embeddings for.
59+
60+
Returns:
61+
Embeddings: The embeddings for the texts.
62+
"""
63+
multivec = self.model.encode(input, batch_size=32, is_query=True)
64+
if not multivec or not multivec[0]:
65+
raise ValueError("Model returned empty multivector embeddings")
66+
return create_fdes(
67+
multivec,
68+
dims=len(multivec[0][0]),
69+
is_query=True,
70+
fill_empty_partitions=False,
71+
)
72+
73+
@staticmethod
74+
def name() -> str:
75+
return "pylate_colbert"
76+
77+
def default_space(self) -> Space:
78+
return "ip" # muvera uses dot product to approximate multivec similarity
79+
80+
def supported_spaces(self) -> List[Space]:
81+
return [
82+
"ip"
83+
] # no cosine bc muvera does not normalize the fde, no l2 bc muvera uses dot product
84+
85+
@staticmethod
86+
def build_from_config(config: Dict[str, Any]) -> "EmbeddingFunction[Documents]":
87+
model_name = config.get("model_name")
88+
89+
if model_name is None:
90+
assert False, "This code should not be reached"
91+
92+
return PylateColBERTEmbeddingFunction(model_name=model_name)
93+
94+
def get_config(self) -> Dict[str, Any]:
95+
return {"model_name": self.model_name}
96+
97+
def validate_config_update(
98+
self, old_config: Dict[str, Any], new_config: Dict[str, Any]
99+
) -> None:
100+
if "model_name" in new_config:
101+
raise ValueError(
102+
"The model name cannot be changed after the embedding function has been initialized."
103+
)
104+
105+
@staticmethod
106+
def validate_config(config: Dict[str, Any]) -> None:
107+
"""
108+
Validate the configuration using the JSON schema.
109+
110+
Args:
111+
config: Configuration to validate
112+
113+
Raises:
114+
ValidationError: If the configuration does not match the schema
115+
"""
116+
validate_config_schema(config, "pylate_colbert")

0 commit comments

Comments
 (0)