Skip to content

Commit 52b1b60

Browse files
justin-cechmanektylerhutcherson
authored andcommitted
moves vectorizer dtype to base class attribute
1 parent bb93c23 commit 52b1b60

File tree

10 files changed

+156
-52
lines changed

10 files changed

+156
-52
lines changed

redisvl/utils/vectorize/base.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from pydantic.v1 import BaseModel, validator
66

77
from redisvl.redis.utils import array_to_buffer
8+
from redisvl.schema.fields import VectorDataType
89

910

1011
class Vectorizers(Enum):
@@ -19,11 +20,22 @@ class Vectorizers(Enum):
1920
class BaseVectorizer(BaseModel, ABC):
2021
model: str
2122
dims: int
23+
dtype: str
2224

2325
@property
2426
def type(self) -> str:
2527
return "base"
2628

29+
@validator("dtype")
30+
def check_dtype(dtype):
31+
try:
32+
VectorDataType(dtype.upper())
33+
except ValueError:
34+
raise ValueError(
35+
f"Invalid data type: {dtype}. Supported types are: {[t.lower() for t in VectorDataType]}"
36+
)
37+
return dtype
38+
2739
@validator("dims")
2840
@classmethod
2941
def check_dims(cls, value):

redisvl/utils/vectorize/text/azureopenai.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,10 @@ class AzureOpenAITextVectorizer(BaseVectorizer):
5252
_aclient: Any = PrivateAttr()
5353

5454
def __init__(
55-
self, model: str = "text-embedding-ada-002", api_config: Optional[Dict] = None
55+
self,
56+
model: str = "text-embedding-ada-002",
57+
api_config: Optional[Dict] = None,
58+
dtype: str = "float32",
5659
):
5760
"""Initialize the AzureOpenAI vectorizer.
5861
@@ -63,13 +66,17 @@ def __init__(
6366
api_config (Optional[Dict], optional): Dictionary containing the
6467
API key, API version, Azure endpoint, and any other API options.
6568
Defaults to None.
69+
dtype (str): the default datatype to use when embedding text as byte arrays.
70+
Used when setting `as_buffer=True` in calls to embed() and embed_many().
71+
Defaults to 'float32'.
6672
6773
Raises:
6874
ImportError: If the openai library is not installed.
6975
ValueError: If the AzureOpenAI API key, version, or endpoint are not provided.
76+
ValueError: If an invalid dtype is provided.
7077
"""
7178
self._initialize_clients(api_config)
72-
super().__init__(model=model, dims=self._set_model_dims(model))
79+
super().__init__(model=model, dims=self._set_model_dims(model), dtype=dtype)
7380

7481
def _initialize_clients(self, api_config: Optional[Dict]):
7582
"""
@@ -190,7 +197,7 @@ def embed_many(
190197
if len(texts) > 0 and not isinstance(texts[0], str):
191198
raise TypeError("Must pass in a list of str values to embed.")
192199

193-
dtype = kwargs.pop("dtype", "float32")
200+
dtype = kwargs.pop("dtype", self.dtype)
194201

195202
embeddings: List = []
196203
for batch in self.batchify(texts, batch_size, preprocess):
@@ -234,7 +241,7 @@ def embed(
234241
if preprocess:
235242
text = preprocess(text)
236243

237-
dtype = kwargs.pop("dtype", "float32")
244+
dtype = kwargs.pop("dtype", self.dtype)
238245

239246
result = self._client.embeddings.create(input=[text], model=self.model)
240247
return self._process_embedding(result.data[0].embedding, as_buffer, dtype)
@@ -274,7 +281,7 @@ async def aembed_many(
274281
if len(texts) > 0 and not isinstance(texts[0], str):
275282
raise TypeError("Must pass in a list of str values to embed.")
276283

277-
dtype = kwargs.pop("dtype", "float32")
284+
dtype = kwargs.pop("dtype", self.dtype)
278285

279286
embeddings: List = []
280287
for batch in self.batchify(texts, batch_size, preprocess):
@@ -320,7 +327,7 @@ async def aembed(
320327
if preprocess:
321328
text = preprocess(text)
322329

323-
dtype = kwargs.pop("dtype", "float32")
330+
dtype = kwargs.pop("dtype", self.dtype)
324331

325332
result = await self._aclient.embeddings.create(input=[text], model=self.model)
326333
return self._process_embedding(result.data[0].embedding, as_buffer, dtype)

redisvl/utils/vectorize/text/bedrock.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ def __init__(
4949
self,
5050
model: str = "amazon.titan-embed-text-v2:0",
5151
api_config: Optional[Dict[str, str]] = None,
52+
dtype: str = "float32",
5253
) -> None:
5354
"""Initialize the AWS Bedrock Vectorizer.
5455
@@ -57,10 +58,13 @@ def __init__(
5758
api_config (Optional[Dict[str, str]]): AWS credentials and config.
5859
Can include: aws_access_key_id, aws_secret_access_key, aws_region
5960
If not provided, will use environment variables.
61+
dtype (str): the default datatype to use when embedding text as byte arrays.
62+
Used when setting `as_buffer=True` in calls to embed() and embed_many().
6063
6164
Raises:
6265
ValueError: If credentials are not provided in config or environment.
6366
ImportError: If boto3 is not installed.
67+
ValueError: If an invalid dtype is provided.
6468
"""
6569
try:
6670
import boto3 # type: ignore
@@ -94,7 +98,7 @@ def __init__(
9498
region_name=aws_region,
9599
)
96100

97-
super().__init__(model=model, dims=self._set_model_dims(model))
101+
super().__init__(model=model, dims=self._set_model_dims(model), dtype=dtype)
98102

99103
def _set_model_dims(self, model: str) -> int:
100104
"""Initialize model and determine embedding dimensions."""
@@ -145,7 +149,7 @@ def embed(
145149
response_body = json.loads(response["body"].read())
146150
embedding = response_body["embedding"]
147151

148-
dtype = kwargs.pop("dtype", "float32")
152+
dtype = kwargs.pop("dtype", self.dtype)
149153
return self._process_embedding(embedding, as_buffer, dtype)
150154

151155
@retry(
@@ -181,7 +185,7 @@ def embed_many(
181185
raise TypeError("Texts must be a list of strings")
182186

183187
embeddings: List[List[float]] = []
184-
dtype = kwargs.pop("dtype", "float32")
188+
dtype = kwargs.pop("dtype", self.dtype)
185189

186190
for batch in self.batchify(texts, batch_size, preprocess):
187191
# Process each text in the batch individually since Bedrock

redisvl/utils/vectorize/text/cohere.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,10 @@ class CohereTextVectorizer(BaseVectorizer):
4747
_client: Any = PrivateAttr()
4848

4949
def __init__(
50-
self, model: str = "embed-english-v3.0", api_config: Optional[Dict] = None
50+
self,
51+
model: str = "embed-english-v3.0",
52+
api_config: Optional[Dict] = None,
53+
dtype: str = "float32",
5154
):
5255
"""Initialize the Cohere vectorizer.
5356
@@ -57,14 +60,17 @@ def __init__(
5760
model (str): Model to use for embedding. Defaults to 'embed-english-v3.0'.
5861
api_config (Optional[Dict], optional): Dictionary containing the API key.
5962
Defaults to None.
63+
dtype (str): the default datatype to use when embedding text as byte arrays.
64+
Used when setting `as_buffer=True` in calls to embed() and embed_many().
65+
Defaults to 'float32'.
6066
6167
Raises:
6268
ImportError: If the cohere library is not installed.
6369
ValueError: If the API key is not provided.
64-
70+
ValueError: If an invalid dtype is provided.
6571
"""
6672
self._initialize_client(api_config)
67-
super().__init__(model=model, dims=self._set_model_dims(model))
73+
super().__init__(model=model, dims=self._set_model_dims(model), dtype=dtype)
6874

6975
def _initialize_client(self, api_config: Optional[Dict]):
7076
"""
@@ -159,7 +165,7 @@ def embed(
159165
if preprocess:
160166
text = preprocess(text)
161167

162-
dtype = kwargs.pop("dtype", "float32")
168+
dtype = kwargs.pop("dtype", self.dtype)
163169

164170
embedding = self._client.embed(
165171
texts=[text], model=self.model, input_type=input_type
@@ -228,7 +234,7 @@ def embed_many(
228234
See https://docs.cohere.com/reference/embed."
229235
)
230236

231-
dtype = kwargs.pop("dtype", "float32")
237+
dtype = kwargs.pop("dtype", self.dtype)
232238

233239
embeddings: List = []
234240
for batch in self.batchify(texts, batch_size, preprocess):

redisvl/utils/vectorize/text/custom.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88

99
class CustomTextVectorizer(BaseVectorizer):
10-
"""The CustomTextVectorizer class wraps user-defined embeding methods to create
10+
"""The CustomTextVectorizer class wraps user-defined embedding methods to create
1111
embeddings for text data.
1212
1313
This vectorizer is designed to accept a provided callable text vectorizer and
@@ -44,6 +44,7 @@ def __init__(
4444
embed_many: Optional[Callable] = None,
4545
aembed: Optional[Callable] = None,
4646
aembed_many: Optional[Callable] = None,
47+
dtype: str = "float32",
4748
):
4849
"""Initialize the Custom vectorizer.
4950
@@ -52,10 +53,14 @@ def __init__(
5253
embed_many (Optional[Callable)]: a Callable function that accepts a list of string objects and returns a list containing lists of floats. Defaults to None.
5354
aembed (Optional[Callable]): an asyncronous Callable function that accepts a string object and returns a lists of floats. Defaults to None.
5455
aembed_many (Optional[Callable]): an asyncronous Callable function that accepts a list of string objects and returns a list containing lists of floats. Defaults to None.
56+
dtype (str): the default datatype to use when embedding text as byte arrays.
57+
Used when setting `as_buffer=True` in calls to embed() and embed_many().
58+
Defaults to 'float32'.
5559
5660
Raises:
57-
ValueError if any of the provided functions accept or return incorrect types.
58-
TypeError if any of the provided functions are not Callable objects.
61+
ValueError: if any of the provided functions accept or return incorrect types.
62+
TypeError: if any of the provided functions are not Callable objects.
63+
ValueError: If an invalid dtype is provided.
5964
"""
6065

6166
self._validate_embed(embed)
@@ -71,7 +76,7 @@ def __init__(
7176
self._validate_aembed_many(aembed_many)
7277
self._aembed_many_func = aembed_many
7378

74-
super().__init__(model=self.type, dims=self._set_model_dims())
79+
super().__init__(model=self.type, dims=self._set_model_dims(), dtype=dtype)
7580

7681
def _validate_embed(self, func: Callable):
7782
"""calls the func with dummy input and validates that it returns a vector"""
@@ -173,7 +178,7 @@ def embed(
173178
if preprocess:
174179
text = preprocess(text)
175180

176-
dtype = kwargs.pop("dtype", "float32")
181+
dtype = kwargs.pop("dtype", self.dtype)
177182

178183
result = self._embed_func(text, **kwargs)
179184
return self._process_embedding(result, as_buffer, dtype)
@@ -212,7 +217,7 @@ def embed_many(
212217
if not self._embed_many_func:
213218
raise NotImplementedError
214219

215-
dtype = kwargs.pop("dtype", "float32")
220+
dtype = kwargs.pop("dtype", self.dtype)
216221

217222
embeddings: List = []
218223
for batch in self.batchify(texts, batch_size, preprocess):
@@ -254,7 +259,7 @@ async def aembed(
254259
if preprocess:
255260
text = preprocess(text)
256261

257-
dtype = kwargs.pop("dtype", "float32")
262+
dtype = kwargs.pop("dtype", self.dtype)
258263

259264
result = await self._aembed_func(text, **kwargs)
260265
return self._process_embedding(result, as_buffer, dtype)
@@ -293,7 +298,7 @@ async def aembed_many(
293298
if not self._aembed_many_func:
294299
raise NotImplementedError
295300

296-
dtype = kwargs.pop("dtype", "float32")
301+
dtype = kwargs.pop("dtype", self.dtype)
297302

298303
embeddings: List = []
299304
for batch in self.batchify(texts, batch_size, preprocess):

redisvl/utils/vectorize/text/huggingface.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,21 +33,28 @@ class HFTextVectorizer(BaseVectorizer):
3333
_client: Any = PrivateAttr()
3434

3535
def __init__(
36-
self, model: str = "sentence-transformers/all-mpnet-base-v2", **kwargs
36+
self,
37+
model: str = "sentence-transformers/all-mpnet-base-v2",
38+
dtype: str = "float32",
39+
**kwargs,
3740
):
3841
"""Initialize the Hugging Face text vectorizer.
3942
4043
Args:
4144
model (str): The pre-trained model from Hugging Face's Sentence
4245
Transformers to be used for embedding. Defaults to
4346
'sentence-transformers/all-mpnet-base-v2'.
47+
dtype (str): the default datatype to use when embedding text as byte arrays.
48+
Used when setting `as_buffer=True` in calls to embed() and embed_many().
49+
Defaults to 'float32'.
4450
4551
Raises:
4652
ImportError: If the sentence-transformers library is not installed.
4753
ValueError: If there is an error setting the embedding model dimensions.
54+
ValueError: If an invalid dtype is provided.
4855
"""
4956
self._initialize_client(model)
50-
super().__init__(model=model, dims=self._set_model_dims())
57+
super().__init__(model=model, dims=self._set_model_dims(), dtype=dtype)
5158

5259
def _initialize_client(self, model: str):
5360
"""Setup the HuggingFace client"""
@@ -100,7 +107,7 @@ def embed(
100107
if preprocess:
101108
text = preprocess(text)
102109

103-
dtype = kwargs.pop("dtype", "float32")
110+
dtype = kwargs.pop("dtype", self.dtype)
104111

105112
embedding = self._client.encode([text], **kwargs)[0]
106113
return self._process_embedding(embedding.tolist(), as_buffer, dtype)
@@ -136,7 +143,7 @@ def embed_many(
136143
if len(texts) > 0 and not isinstance(texts[0], str):
137144
raise TypeError("Must pass in a list of str values to embed.")
138145

139-
dtype = kwargs.pop("dtype", "float32")
146+
dtype = kwargs.pop("dtype", self.dtype)
140147

141148
embeddings: List = []
142149
for batch in self.batchify(texts, batch_size, preprocess):

redisvl/utils/vectorize/text/mistral.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -46,21 +46,30 @@ class MistralAITextVectorizer(BaseVectorizer):
4646
_client: Any = PrivateAttr()
4747
_aclient: Any = PrivateAttr()
4848

49-
def __init__(self, model: str = "mistral-embed", api_config: Optional[Dict] = None):
49+
def __init__(
50+
self,
51+
model: str = "mistral-embed",
52+
api_config: Optional[Dict] = None,
53+
dtype: str = "float32",
54+
):
5055
"""Initialize the MistralAI vectorizer.
5156
5257
Args:
5358
model (str): Model to use for embedding. Defaults to
5459
'text-embedding-ada-002'.
5560
api_config (Optional[Dict], optional): Dictionary containing the
5661
API key. Defaults to None.
62+
dtype (str): the default datatype to use when embedding text as byte arrays.
63+
Used when setting `as_buffer=True` in calls to embed() and embed_many().
64+
Defaults to 'float32'.
5765
5866
Raises:
5967
ImportError: If the mistralai library is not installed.
6068
ValueError: If the Mistral API key is not provided.
69+
ValueError: If an invalid dtype is provided.
6170
"""
6271
self._initialize_clients(api_config)
63-
super().__init__(model=model, dims=self._set_model_dims(model))
72+
super().__init__(model=model, dims=self._set_model_dims(model), dtype=dtype)
6473

6574
def _initialize_clients(self, api_config: Optional[Dict]):
6675
"""
@@ -140,7 +149,7 @@ def embed_many(
140149
if len(texts) > 0 and not isinstance(texts[0], str):
141150
raise TypeError("Must pass in a list of str values to embed.")
142151

143-
dtype = kwargs.pop("dtype", "float32")
152+
dtype = kwargs.pop("dtype", self.dtype)
144153

145154
embeddings: List = []
146155
for batch in self.batchify(texts, batch_size, preprocess):
@@ -184,7 +193,7 @@ def embed(
184193
if preprocess:
185194
text = preprocess(text)
186195

187-
dtype = kwargs.pop("dtype", "float32")
196+
dtype = kwargs.pop("dtype", self.dtype)
188197

189198
result = self._client.embeddings(model=self.model, input=[text])
190199
return self._process_embedding(result.data[0].embedding, as_buffer, dtype)
@@ -224,7 +233,7 @@ async def aembed_many(
224233
if len(texts) > 0 and not isinstance(texts[0], str):
225234
raise TypeError("Must pass in a list of str values to embed.")
226235

227-
dtype = kwargs.pop("dtype", "float32")
236+
dtype = kwargs.pop("dtype", self.dtype)
228237

229238
embeddings: List = []
230239
for batch in self.batchify(texts, batch_size, preprocess):
@@ -268,7 +277,7 @@ async def aembed(
268277
if preprocess:
269278
text = preprocess(text)
270279

271-
dtype = kwargs.pop("dtype", "float32")
280+
dtype = kwargs.pop("dtype", self.dtype)
272281

273282
result = await self._aclient.embeddings(model=self.model, input=[text])
274283
return self._process_embedding(result.data[0].embedding, as_buffer, dtype)

0 commit comments

Comments
 (0)