Skip to content

Commit 4135ad5

Browse files
author
Milder Hernandez Cagua
committed
Clean Vector search interfaces
1 parent 14a5b3d commit 4135ad5

15 files changed

+274
-462
lines changed

semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/azureaisearch/AzureAISearchVectorStoreRecordCollection.java

Lines changed: 50 additions & 58 deletions
Original file line number | Diff line number | Diff line change
@@ -14,7 +14,7 @@
1414
import com.azure.search.documents.models.VectorQuery;
1515
import com.azure.search.documents.models.VectorizableTextQuery;
1616
import com.azure.search.documents.models.VectorizedQuery;
17-
import com.microsoft.semantickernel.data.vectorsearch.VectorizableSearch;
17+
import com.microsoft.semantickernel.data.vectorsearch.VectorizableTextSearch;
1818
import com.microsoft.semantickernel.data.vectorsearch.VectorSearchResult;
1919
import com.microsoft.semantickernel.data.vectorsearch.VectorizedSearch;
2020
import com.microsoft.semantickernel.data.vectorstorage.VectorStoreRecordCollection;
@@ -28,9 +28,6 @@
2828
import com.microsoft.semantickernel.data.vectorstorage.options.GetRecordOptions;
2929
import com.microsoft.semantickernel.data.vectorstorage.options.UpsertRecordOptions;
3030
import com.microsoft.semantickernel.data.vectorstorage.options.VectorSearchOptions;
31-
import com.microsoft.semantickernel.data.vectorsearch.queries.VectorSearchQuery;
32-
import com.microsoft.semantickernel.data.vectorsearch.queries.VectorizableTextSearchQuery;
33-
import com.microsoft.semantickernel.data.vectorsearch.queries.VectorizedSearchQuery;
3431
import com.microsoft.semantickernel.exceptions.SKException;
3532
import edu.umd.cs.findbugs.annotations.SuppressFBWarnings;
3633
import java.time.OffsetDateTime;
@@ -49,7 +46,7 @@
4946
public class AzureAISearchVectorStoreRecordCollection<Record> implements
5047
VectorStoreRecordCollection<String, Record>,
5148
VectorizedSearch<Record>,
52-
VectorizableSearch<Record> {
49+
VectorizableTextSearch<Record> {
5350

5451
private static final HashSet<Class<?>> supportedKeyTypes = new HashSet<>(
5552
Collections.singletonList(
@@ -277,8 +274,25 @@ public Mono<Void> deleteBatchAsync(List<String> keys, DeleteRecordOptions option
277274
}).collect(Collectors.toList())).then();
278275
}
279276

280-
private Mono<List<VectorSearchResult<Record>>> searchAndMapAsync(SearchOptions searchOptions,
277+
private Mono<List<VectorSearchResult<Record>>> searchAndMapAsync(
278+
List<VectorQuery> vectorQueries, VectorSearchOptions options,
281279
GetRecordOptions getRecordOptions) {
280+
281+
String filter = AzureAISearchVectorStoreCollectionSearchMapping
282+
.buildFilterString(options.getVectorSearchFilter(), recordDefinition);
283+
284+
SearchOptions searchOptions = new SearchOptions()
285+
.setFilter(filter)
286+
.setTop(options.getLimit())
287+
.setSkip(options.getOffset())
288+
.setScoringParameters()
289+
.setVectorSearchOptions(new com.azure.search.documents.models.VectorSearchOptions()
290+
.setQueries(vectorQueries));
291+
292+
if (!options.isIncludeVectors()) {
293+
searchOptions.setSelect(nonVectorFields.toArray(new String[0]));
294+
}
295+
282296
VectorStoreRecordMapper<Record, SearchDocument> mapper = this.options
283297
.getVectorStoreRecordMapper();
284298

@@ -300,70 +314,32 @@ record = response.getDocument(this.options.getRecordClass());
300314
}
301315

302316
/**
303-
* Search the vector store for records that match the given embedding and filter.
317+
* Vectorizable text search. This method searches for records that are similar to the given text.
304318
*
305-
* @param query The vector search query.
319+
* @param searchText The text to search with.
320+
* @param options The options to use for the search.
306321
* @return A list of search results.
307322
*/
308323
@Override
309-
public Mono<List<VectorSearchResult<Record>>> searchAsync(VectorSearchQuery query) {
324+
public Mono<List<VectorSearchResult<Record>>> searchAsync(String searchText,
325+
VectorSearchOptions options) {
310326
if (firstVectorFieldName == null) {
311327
throw new SKException("No vector fields defined. Cannot perform vector search");
312328
}
313329

314-
VectorSearchOptions options = query.getSearchOptions();
315330
if (options == null) {
316331
options = VectorSearchOptions.createDefault(firstVectorFieldName);
317332
}
318333

319334
List<VectorQuery> vectorQueries = new ArrayList<>();
320-
321-
if (query instanceof VectorizedSearchQuery) {
322-
vectorQueries.add(new VectorizedQuery(((VectorizedSearchQuery) query).getVector())
323-
.setFields(recordDefinition.getField(options.getVectorFieldName() != null
324-
? options.getVectorFieldName()
325-
: firstVectorFieldName).getEffectiveStorageName())
326-
.setKNearestNeighborsCount(options.getLimit()));
327-
} else if (query instanceof VectorizableTextSearchQuery) {
328-
vectorQueries
329-
.add(new VectorizableTextQuery(((VectorizableTextSearchQuery) query).getQueryText())
330-
.setFields(recordDefinition.getField(options.getVectorFieldName() != null
331-
? options.getVectorFieldName()
332-
: firstVectorFieldName).getEffectiveStorageName())
333-
.setKNearestNeighborsCount(options.getLimit()));
334-
} else {
335-
throw new SKException("Unsupported query type: " + query.getQueryType());
336-
}
337-
338-
String filter = AzureAISearchVectorStoreCollectionSearchMapping
339-
.buildFilterString(options.getVectorSearchFilter(), recordDefinition);
340-
341-
SearchOptions searchOptions = new SearchOptions()
342-
.setFilter(filter)
343-
.setTop(options.getLimit())
344-
.setSkip(options.getOffset())
345-
.setScoringParameters()
346-
.setVectorSearchOptions(new com.azure.search.documents.models.VectorSearchOptions()
347-
.setQueries(vectorQueries));
348-
349-
if (!options.isIncludeVectors()) {
350-
searchOptions.setSelect(nonVectorFields.toArray(new String[0]));
351-
}
352-
353-
return searchAndMapAsync(searchOptions, new GetRecordOptions(options.isIncludeVectors()));
354-
}
355-
356-
/**
357-
* Vectorizable text search. This method searches for records that are similar to the given text.
358-
*
359-
* @param searchText The text to search with.
360-
* @param options The options to use for the search.
361-
* @return A list of search results.
362-
*/
363-
@Override
364-
public Mono<List<VectorSearchResult<Record>>> searchAsync(String searchText,
365-
VectorSearchOptions options) {
366-
return searchAsync(VectorSearchQuery.createQuery(searchText, options));
335+
vectorQueries.add(new VectorizableTextQuery(searchText)
336+
.setFields(recordDefinition.getField(options.getVectorFieldName() != null
337+
? options.getVectorFieldName()
338+
: firstVectorFieldName).getEffectiveStorageName())
339+
.setKNearestNeighborsCount(options.getLimit()));
340+
341+
return searchAndMapAsync(vectorQueries, options,
342+
new GetRecordOptions(options.isIncludeVectors()));
367343
}
368344

369345
/**
@@ -376,6 +352,22 @@ public Mono<List<VectorSearchResult<Record>>> searchAsync(String searchText,
376352
@Override
377353
public Mono<List<VectorSearchResult<Record>>> searchAsync(List<Float> vector,
378354
VectorSearchOptions options) {
379-
return searchAsync(VectorSearchQuery.createQuery(vector, options));
355+
if (firstVectorFieldName == null) {
356+
throw new SKException("No vector fields defined. Cannot perform vector search");
357+
}
358+
359+
if (options == null) {
360+
options = VectorSearchOptions.createDefault(firstVectorFieldName);
361+
}
362+
363+
List<VectorQuery> vectorQueries = new ArrayList<>();
364+
vectorQueries.add(new VectorizedQuery(vector)
365+
.setFields(recordDefinition.getField(options.getVectorFieldName() != null
366+
? options.getVectorFieldName()
367+
: firstVectorFieldName).getEffectiveStorageName())
368+
.setKNearestNeighborsCount(options.getLimit()));
369+
370+
return searchAndMapAsync(vectorQueries, options,
371+
new GetRecordOptions(options.isIncludeVectors()));
380372
}
381373
}

semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStoreQueryProvider.java

Lines changed: 70 additions & 79 deletions
Original file line number | Diff line number | Diff line change
@@ -6,8 +6,6 @@
66
import com.fasterxml.jackson.databind.node.ArrayNode;
77
import com.microsoft.semantickernel.data.vectorsearch.VectorOperations;
88
import com.microsoft.semantickernel.data.vectorsearch.VectorSearchResult;
9-
import com.microsoft.semantickernel.data.vectorsearch.queries.VectorSearchQuery;
10-
import com.microsoft.semantickernel.data.vectorsearch.queries.VectorizedSearchQuery;
119
import com.microsoft.semantickernel.data.vectorstorage.VectorStoreRecordMapper;
1210
import com.microsoft.semantickernel.data.vectorstorage.definition.DistanceFunction;
1311
import com.microsoft.semantickernel.data.vectorstorage.definition.VectorStoreRecordDefinition;
@@ -471,101 +469,94 @@ protected <Record> List<Record> getRecordsWithFilter(String collectionName,
471469
*
472470
* @param <Record> the record type
473471
* @param collectionName the collection name
474-
* @param query the vectorized search query, containing the vector and search options
472+
* @param vector the vector to search with
473+
* @param options the search options
475474
* @param recordDefinition the record definition
476475
* @param mapper the mapper, responsible for mapping the result set to the record type.
477476
* @return the search results
478477
*/
479478
@Override
480479
public <Record> List<VectorSearchResult<Record>> search(String collectionName,
481-
VectorSearchQuery query, VectorStoreRecordDefinition recordDefinition,
480+
List<Float> vector, VectorSearchOptions options,
481+
VectorStoreRecordDefinition recordDefinition,
482482
VectorStoreRecordMapper<Record, ResultSet> mapper) {
483483
if (recordDefinition.getVectorFields().isEmpty()) {
484484
throw new SKException("No vector fields defined. Cannot perform vector search");
485485
}
486486

487-
if (query instanceof VectorizedSearchQuery) {
488-
VectorizedSearchQuery vectorizedSearchQuery = (VectorizedSearchQuery) query;
489-
VectorSearchOptions options = query.getSearchOptions();
487+
VectorStoreRecordVectorField firstVectorField = recordDefinition.getVectorFields()
488+
.get(0);
489+
if (options == null) {
490+
options = VectorSearchOptions.createDefault(firstVectorField.getName());
491+
}
490492

491-
VectorStoreRecordVectorField firstVectorField = recordDefinition.getVectorFields()
492-
.get(0);
493-
if (options == null) {
494-
options = VectorSearchOptions.createDefault(firstVectorField.getName());
493+
VectorStoreRecordVectorField vectorField = options.getVectorFieldName() == null
494+
? firstVectorField
495+
: (VectorStoreRecordVectorField) recordDefinition
496+
.getField(options.getVectorFieldName());
497+
498+
String filter = SQLVectorStoreRecordCollectionSearchMapping
499+
.buildFilter(options.getVectorSearchFilter(), recordDefinition);
500+
List<Object> parameters = SQLVectorStoreRecordCollectionSearchMapping
501+
.getFilterParameters(options.getVectorSearchFilter());
502+
503+
List<Record> records = getRecordsWithFilter(collectionName, recordDefinition, mapper,
504+
new GetRecordOptions(true), filter, parameters);
505+
List<VectorSearchResult<Record>> results = new ArrayList<>();
506+
507+
DistanceFunction distanceFunction = vectorField.getDistanceFunction() == null
508+
? DistanceFunction.EUCLIDEAN_DISTANCE
509+
: vectorField.getDistanceFunction();
510+
511+
for (Record record : records) {
512+
List<Float> recordVector;
513+
try {
514+
String json = new ObjectMapper().writeValueAsString(record);
515+
ArrayNode arrayNode = (ArrayNode) new ObjectMapper().readTree(json)
516+
.get(vectorField.getEffectiveStorageName());
517+
518+
recordVector = Stream.iterate(0, i -> i + 1)
519+
.limit(arrayNode.size())
520+
.map(i -> arrayNode.get(i).floatValue())
521+
.collect(Collectors.toList());
522+
} catch (JsonProcessingException e) {
523+
throw new RuntimeException(e);
495524
}
496525

497-
VectorStoreRecordVectorField vectorField = options.getVectorFieldName() == null
498-
? firstVectorField
499-
: (VectorStoreRecordVectorField) recordDefinition
500-
.getField(options.getVectorFieldName());
501-
502-
String filter = SQLVectorStoreRecordCollectionSearchMapping
503-
.buildFilter(options.getVectorSearchFilter(), recordDefinition);
504-
List<Object> parameters = SQLVectorStoreRecordCollectionSearchMapping
505-
.getFilterParameters(options.getVectorSearchFilter());
506-
507-
List<Record> records = getRecordsWithFilter(collectionName, recordDefinition, mapper,
508-
new GetRecordOptions(true), filter, parameters);
509-
List<VectorSearchResult<Record>> results = new ArrayList<>();
510-
511-
DistanceFunction distanceFunction = vectorField.getDistanceFunction() == null
512-
? DistanceFunction.EUCLIDEAN_DISTANCE
513-
: vectorField.getDistanceFunction();
514-
515-
for (Record record : records) {
516-
List<Float> vector;
517-
try {
518-
String json = new ObjectMapper().writeValueAsString(record);
519-
ArrayNode arrayNode = (ArrayNode) new ObjectMapper().readTree(json)
520-
.get(vectorField.getEffectiveStorageName());
521-
522-
vector = Stream.iterate(0, i -> i + 1)
523-
.limit(arrayNode.size())
524-
.map(i -> arrayNode.get(i).floatValue())
525-
.collect(Collectors.toList());
526-
} catch (JsonProcessingException e) {
527-
throw new RuntimeException(e);
528-
}
529-
530-
double score;
531-
switch (distanceFunction) {
532-
case COSINE_SIMILARITY:
533-
score = VectorOperations.cosineSimilarity(vectorizedSearchQuery.getVector(),
534-
vector);
535-
break;
536-
case COSINE_DISTANCE:
537-
score = VectorOperations.cosineDistance(vectorizedSearchQuery.getVector(),
538-
vector);
539-
break;
540-
case EUCLIDEAN_DISTANCE:
541-
score = VectorOperations
542-
.euclideanDistance(vectorizedSearchQuery.getVector(), vector);
543-
break;
544-
case DOT_PRODUCT:
545-
score = VectorOperations.dot(vectorizedSearchQuery.getVector(), vector);
546-
break;
547-
default:
548-
throw new SKException("Unsupported distance function");
549-
}
550-
551-
results.add(new VectorSearchResult<>(record, score));
526+
double score;
527+
switch (distanceFunction) {
528+
case COSINE_SIMILARITY:
529+
score = VectorOperations.cosineSimilarity(vector, recordVector);
530+
break;
531+
case COSINE_DISTANCE:
532+
score = VectorOperations.cosineDistance(vector, recordVector);
533+
break;
534+
case EUCLIDEAN_DISTANCE:
535+
score = VectorOperations
536+
.euclideanDistance(vector, recordVector);
537+
break;
538+
case DOT_PRODUCT:
539+
score = VectorOperations.dot(vector, recordVector);
540+
break;
541+
default:
542+
throw new SKException("Unsupported distance function");
552543
}
553544

554-
Comparator<VectorSearchResult<Record>> comparator = Comparator
555-
.comparingDouble(VectorSearchResult::getScore);
556-
// Higher scores are better
557-
if (distanceFunction == DistanceFunction.COSINE_SIMILARITY
558-
|| distanceFunction == DistanceFunction.DOT_PRODUCT) {
559-
comparator = comparator.reversed();
560-
}
561-
return results.stream()
562-
.sorted(comparator)
563-
.skip(options.getOffset())
564-
.limit(options.getLimit())
565-
.collect(Collectors.toList());
545+
results.add(new VectorSearchResult<>(record, score));
566546
}
567547

568-
throw new SKException("Unsupported query type");
548+
Comparator<VectorSearchResult<Record>> comparator = Comparator
549+
.comparingDouble(VectorSearchResult::getScore);
550+
// Higher scores are better
551+
if (distanceFunction == DistanceFunction.COSINE_SIMILARITY
552+
|| distanceFunction == DistanceFunction.DOT_PRODUCT) {
553+
comparator = comparator.reversed();
554+
}
555+
return results.stream()
556+
.sorted(comparator)
557+
.skip(options.getOffset())
558+
.limit(options.getLimit())
559+
.collect(Collectors.toList());
569560
}
570561

571562
/**

semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStoreRecordCollection.java

Lines changed: 5 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -7,7 +7,6 @@
77
import com.microsoft.semantickernel.connectors.data.postgres.PostgreSQLVectorStoreRecordMapper;
88
import com.microsoft.semantickernel.data.vectorsearch.VectorSearchResult;
99
import com.microsoft.semantickernel.data.vectorsearch.VectorizedSearch;
10-
import com.microsoft.semantickernel.data.vectorsearch.queries.VectorSearchQuery;
1110
import com.microsoft.semantickernel.data.vectorstorage.VectorStoreRecordMapper;
1211
import com.microsoft.semantickernel.data.vectorstorage.VectorStoreRecordCollection;
1312
import com.microsoft.semantickernel.data.vectorstorage.definition.VectorStoreRecordDefinition;
@@ -298,14 +297,6 @@ public Mono<Void> prepareAsync() {
298297
.subscribeOn(Schedulers.boundedElastic()).then();
299298
}
300299

301-
@Override
302-
public Mono<List<VectorSearchResult<Record>>> searchAsync(VectorSearchQuery query) {
303-
return Mono.fromCallable(
304-
() -> queryProvider.search(this.collectionName, query, recordDefinition,
305-
vectorStoreRecordMapper))
306-
.subscribeOn(Schedulers.boundedElastic());
307-
}
308-
309300
/**
310301
* Vectorized search. This method searches for records that are similar to the given vector.
311302
*
@@ -316,7 +307,11 @@ public Mono<List<VectorSearchResult<Record>>> searchAsync(VectorSearchQuery quer
316307
@Override
317308
public Mono<List<VectorSearchResult<Record>>> searchAsync(List<Float> vector,
318309
VectorSearchOptions vectorSearchOptions) {
319-
return this.searchAsync(VectorSearchQuery.createQuery(vector, vectorSearchOptions));
310+
return Mono.fromCallable(
311+
() -> queryProvider.search(this.collectionName, vector, vectorSearchOptions,
312+
recordDefinition,
313+
vectorStoreRecordMapper))
314+
.subscribeOn(Schedulers.boundedElastic());
320315
}
321316

322317
public static class Builder<Record>

0 commit comments

Comments
 (0)