Skip to content

Commit 7e4b348

Browse files
author
Milder Hernandez Cagua
committed
Add VolatileVectorStoreRecordCollection search
1 parent 4135ad5 commit 7e4b348

File tree

12 files changed, with 399 additions and 86 deletions.

api-test/integration-tests/src/test/java/com/microsoft/semantickernel/tests/connectors/memory/jdbc/JDBCVectorStoreRecordCollectionTest.java

Lines changed: 1 addition & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -100,9 +100,7 @@ public void buildRecordCollection(QueryProvider provider) {
100100
}
101101

102102
private List<Hotel> getHotels() {
103-
ArrayList<Hotel> embeddings = new ArrayList<>();
104-
105-
return List.of(
103+
return Arrays.asList(
106104
new Hotel("id_1", "Hotel 1", 1, "Hotel 1 description", Arrays.asList(0.5f, 3.2f, 7.1f, -4.0f, 2.8f, 10.0f, -1.3f, 5.5f),null, null, null, 4.0),
107105
new Hotel("id_2", "Hotel 2", 2, "Hotel 2 description", Arrays.asList(-2.0f, 8.1f, 0.9f, 5.4f, -3.3f, 2.2f, 9.9f, -4.5f),null, null, null, 4.0),
108106
new Hotel("id_3", "Hotel 3", 3, "Hotel 3 description", Arrays.asList(4.5f, -6.2f, 3.1f, 7.7f, -0.8f, 1.1f, -2.2f, 8.3f),null, null, null, 5.0),

api-test/integration-tests/src/test/java/com/microsoft/semantickernel/tests/connectors/memory/redis/RedisHashSetVectorStoreRecordCollectionTest.java

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -132,7 +132,7 @@ private static RedisHashSetVectorStoreRecordCollection<Hotel> createCollection(@
132132
}
133133

134134
private static List<Hotel> getHotels() {
135-
return List.of(
135+
return Arrays.asList(
136136
new Hotel("id_1", "Hotel 1", 1, "Hotel 1 description", Arrays.asList(0.5f, 3.2f, 7.1f, -4.0f, 2.8f, 10.0f, -1.3f, 5.5f),null, null, 4.0),
137137
new Hotel("id_2", "Hotel 2", 2, "Hotel 2 description", Arrays.asList(-2.0f, 8.1f, 0.9f, 5.4f, -3.3f, 2.2f, 9.9f, -4.5f),null, null, 4.0),
138138
new Hotel("id_3", "Hotel 3", 3, "Hotel 3 description", Arrays.asList(4.5f, -6.2f, 3.1f, 7.7f, -0.8f, 1.1f, -2.2f, 8.3f),null, null, 5.0),

api-test/integration-tests/src/test/java/com/microsoft/semantickernel/tests/connectors/memory/redis/RedisJsonVectorStoreRecordCollectionTest.java

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -132,7 +132,7 @@ private static RedisJsonVectorStoreRecordCollection<Hotel> createCollection(@Non
132132
}
133133

134134
private static List<Hotel> getHotels() {
135-
return List.of(
135+
return Arrays.asList(
136136
new Hotel("id_1", "Hotel 1", 1, "Hotel 1 description", Arrays.asList(0.5f, 3.2f, 7.1f, -4.0f, 2.8f, 10.0f, -1.3f, 5.5f),null, null, 4.0),
137137
new Hotel("id_2", "Hotel 2", 2, "Hotel 2 description", Arrays.asList(-2.0f, 8.1f, 0.9f, 5.4f, -3.3f, 2.2f, 9.9f, -4.5f),null, null, 4.0),
138138
new Hotel("id_3", "Hotel 3", 3, "Hotel 3 description", Arrays.asList(4.5f, -6.2f, 3.1f, 7.7f, -0.8f, 1.1f, -2.2f, 8.3f),null, null, 5.0),

semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStoreQueryProvider.java

Lines changed: 2 additions & 50 deletions
Original file line number | Diff line number | Diff line change
@@ -502,61 +502,13 @@ public <Record> List<VectorSearchResult<Record>> search(String collectionName,
502502

503503
List<Record> records = getRecordsWithFilter(collectionName, recordDefinition, mapper,
504504
new GetRecordOptions(true), filter, parameters);
505-
List<VectorSearchResult<Record>> results = new ArrayList<>();
506505

507506
DistanceFunction distanceFunction = vectorField.getDistanceFunction() == null
508507
? DistanceFunction.EUCLIDEAN_DISTANCE
509508
: vectorField.getDistanceFunction();
510509

511-
for (Record record : records) {
512-
List<Float> recordVector;
513-
try {
514-
String json = new ObjectMapper().writeValueAsString(record);
515-
ArrayNode arrayNode = (ArrayNode) new ObjectMapper().readTree(json)
516-
.get(vectorField.getEffectiveStorageName());
517-
518-
recordVector = Stream.iterate(0, i -> i + 1)
519-
.limit(arrayNode.size())
520-
.map(i -> arrayNode.get(i).floatValue())
521-
.collect(Collectors.toList());
522-
} catch (JsonProcessingException e) {
523-
throw new RuntimeException(e);
524-
}
525-
526-
double score;
527-
switch (distanceFunction) {
528-
case COSINE_SIMILARITY:
529-
score = VectorOperations.cosineSimilarity(vector, recordVector);
530-
break;
531-
case COSINE_DISTANCE:
532-
score = VectorOperations.cosineDistance(vector, recordVector);
533-
break;
534-
case EUCLIDEAN_DISTANCE:
535-
score = VectorOperations
536-
.euclideanDistance(vector, recordVector);
537-
break;
538-
case DOT_PRODUCT:
539-
score = VectorOperations.dot(vector, recordVector);
540-
break;
541-
default:
542-
throw new SKException("Unsupported distance function");
543-
}
544-
545-
results.add(new VectorSearchResult<>(record, score));
546-
}
547-
548-
Comparator<VectorSearchResult<Record>> comparator = Comparator
549-
.comparingDouble(VectorSearchResult::getScore);
550-
// Higher scores are better
551-
if (distanceFunction == DistanceFunction.COSINE_SIMILARITY
552-
|| distanceFunction == DistanceFunction.DOT_PRODUCT) {
553-
comparator = comparator.reversed();
554-
}
555-
return results.stream()
556-
.sorted(comparator)
557-
.skip(options.getOffset())
558-
.limit(options.getLimit())
559-
.collect(Collectors.toList());
510+
return VectorOperations.exactSimilaritySearch(records, vector, vectorField,
511+
distanceFunction, options);
560512
}
561513

562514
/**
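
semantickernel-experimental/src/main/java/com/microsoft/semantickernel/data/VolatileVectorStoreCollectionSearchMapping.java

Lines changed: 48 additions & 0 deletions
Original file line number | Diff line number | Diff line change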
@@ -0,0 +1,48 @@
1+
// Copyright (c) Microsoft. All rights reserved.
2+
package com.microsoft.semantickernel.data;
3+
4+
import com.fasterxml.jackson.databind.JsonNode;
5+
import com.fasterxml.jackson.databind.ObjectMapper;
6+
import com.microsoft.semantickernel.data.filter.EqualToFilterClause;
7+
import com.microsoft.semantickernel.data.filter.FilterClause;
8+
import com.microsoft.semantickernel.data.vectorsearch.VectorSearchFilter;
9+
import com.microsoft.semantickernel.data.vectorstorage.definition.VectorStoreRecordDefinition;
10+
import com.microsoft.semantickernel.data.vectorstorage.definition.VectorStoreRecordField;
11+
import com.microsoft.semantickernel.exceptions.SKException;
12+
13+
import java.util.List;
14+
import java.util.stream.Collectors;
15+
16+
public class VolatileVectorStoreCollectionSearchMapping {
17+
18+
public static <Record> List<Record> filterRecords(List<Record> records,
19+
VectorSearchFilter filter,
20+
VectorStoreRecordDefinition recordDefinition, ObjectMapper objectMapper) {
21+
if (filter == null || filter.getFilterClauses().isEmpty()) {
22+
return records;
23+
}
24+
25+
return records.stream().filter(
26+
record -> {
27+
JsonNode recordNode = objectMapper.valueToTree(record);
28+
29+
for (FilterClause filterClause : filter.getFilterClauses()) {
30+
if (filterClause instanceof EqualToFilterClause) {
31+
EqualToFilterClause equalToFilterClause = (EqualToFilterClause) filterClause;
32+
VectorStoreRecordField field = recordDefinition
33+
.getField(equalToFilterClause.getFieldName());
34+
35+
Object value = objectMapper.convertValue(
36+
recordNode.get(field.getEffectiveStorageName()), field.getFieldType());
37+
if (!equalToFilterClause.getValue().equals(value)) {
38+
return false;
39+
}
40+
} else {
41+
throw new SKException(String.format("Unsupported filter clause type '%s'.",
42+
filterClause.getClass().getSimpleName()));
43+
}
44+
}
45+
return true;
46+
}).collect(Collectors.toList());
47+
}
48+
}

semantickernel-experimental/src/main/java/com/microsoft/semantickernel/data/VolatileVectorStoreRecordCollection.java

Lines changed: 66 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,21 +1,36 @@
11
// Copyright (c) Microsoft. All rights reserved.
22
package com.microsoft.semantickernel.data;
33

4+
import com.fasterxml.jackson.core.JsonProcessingException;
5+
import com.fasterxml.jackson.databind.JsonNode;
46
import com.fasterxml.jackson.databind.ObjectMapper;
7+
import com.fasterxml.jackson.databind.node.ArrayNode;
58
import com.fasterxml.jackson.databind.node.ObjectNode;
9+
import com.microsoft.semantickernel.data.filter.FilterClause;
10+
import com.microsoft.semantickernel.data.vectorsearch.VectorOperations;
11+
import com.microsoft.semantickernel.data.vectorsearch.VectorSearchResult;
612
import com.microsoft.semantickernel.data.vectorstorage.VectorStoreRecordCollection;
13+
import com.microsoft.semantickernel.data.vectorstorage.definition.DistanceFunction;
714
import com.microsoft.semantickernel.data.vectorstorage.definition.VectorStoreRecordDefinition;
15+
import com.microsoft.semantickernel.data.vectorstorage.definition.VectorStoreRecordVectorField;
816
import com.microsoft.semantickernel.data.vectorstorage.options.DeleteRecordOptions;
917
import com.microsoft.semantickernel.data.vectorstorage.options.GetRecordOptions;
1018
import com.microsoft.semantickernel.data.vectorstorage.options.UpsertRecordOptions;
19+
import com.microsoft.semantickernel.data.vectorstorage.options.VectorSearchOptions;
1120
import com.microsoft.semantickernel.exceptions.SKException;
21+
22+
import java.util.ArrayList;
1223
import java.util.Collections;
24+
import java.util.Comparator;
1325
import java.util.HashSet;
1426
import java.util.List;
1527
import java.util.Map;
1628
import java.util.concurrent.ConcurrentHashMap;
1729
import java.util.stream.Collectors;
30+
import java.util.stream.Stream;
31+
1832
import reactor.core.publisher.Mono;
33+
import reactor.core.scheduler.Schedulers;
1934

2035
public class VolatileVectorStoreRecordCollection<Record> implements
2136
VectorStoreRecordCollection<String, Record> {
@@ -33,7 +48,6 @@ public VolatileVectorStoreRecordCollection(String collectionName,
3348
this.collectionName = collectionName;
3449
this.options = options;
3550
this.collections = new ConcurrentHashMap<>();
36-
this.objectMapper = new ObjectMapper();
3751

3852
if (options.getRecordDefinition() != null) {
3953
this.recordDefinition = options.getRecordDefinition();
@@ -42,6 +56,12 @@ public VolatileVectorStoreRecordCollection(String collectionName,
4256
.fromRecordClass(this.options.getRecordClass());
4357
}
4458

59+
if (options.getObjectMapper() == null) {
60+
this.objectMapper = new ObjectMapper();
61+
} else {
62+
this.objectMapper = options.getObjectMapper();
63+
}
64+
4565
// Validate the key type
4666
VectorStoreRecordDefinition.validateSupportedTypes(
4767
Collections.singletonList(recordDefinition.getKeyField()),
@@ -222,4 +242,49 @@ private Map<String, Record> getCollection() {
222242
}
223243
return (Map<String, Record>) collections.get(collectionName);
224244
}
245+
246+
private List<Float> arrayNodeToFloatList(ArrayNode arrayNode) {
247+
return Stream.iterate(0, i -> i + 1)
248+
.limit(arrayNode.size())
249+
.map(i -> arrayNode.get(i).floatValue())
250+
.collect(Collectors.toList());
251+
}
252+
253+
/**
254+
* Vectorized search. This method searches for records that are similar to the given vector.
255+
*
256+
* @param vector The vector to search with.
257+
* @param options The options to use for the search.
258+
* @return A list of search results.
259+
*/
260+
@Override
261+
public Mono<List<VectorSearchResult<Record>>> searchAsync(List<Float> vector,
262+
final VectorSearchOptions options) {
263+
if (recordDefinition.getVectorFields().isEmpty()) {
264+
throw new SKException("No vector fields defined. Cannot perform vector search");
265+
}
266+
267+
return Mono.fromCallable(() -> {
268+
VectorStoreRecordVectorField firstVectorField = recordDefinition.getVectorFields()
269+
.get(0);
270+
VectorSearchOptions effectiveOptions = options == null
271+
? VectorSearchOptions.createDefault(firstVectorField.getName())
272+
: options;
273+
274+
VectorStoreRecordVectorField vectorField = effectiveOptions.getVectorFieldName() == null
275+
? firstVectorField
276+
: (VectorStoreRecordVectorField) recordDefinition
277+
.getField(effectiveOptions.getVectorFieldName());
278+
279+
DistanceFunction distanceFunction = vectorField.getDistanceFunction() == null
280+
? DistanceFunction.EUCLIDEAN_DISTANCE
281+
: vectorField.getDistanceFunction();
282+
283+
List<Record> records = VolatileVectorStoreCollectionSearchMapping.filterRecords(
284+
new ArrayList<>(getCollection().values()), effectiveOptions.getVectorSearchFilter(),
285+
recordDefinition, objectMapper);
286+
return VectorOperations.exactSimilaritySearch(records, vector, vectorField,
287+
distanceFunction, effectiveOptions);
288+
}).subscribeOn(Schedulers.boundedElastic());
289+
}
225290
}

semantickernel-experimental/src/main/java/com/microsoft/semantickernel/data/VolatileVectorStoreRecordCollectionOptions.java

Lines changed: 32 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,8 +1,10 @@
11
// Copyright (c) Microsoft. All rights reserved.
22
package com.microsoft.semantickernel.data;
33

4+
import com.fasterxml.jackson.databind.ObjectMapper;
45
import com.microsoft.semantickernel.data.vectorstorage.VectorStoreRecordCollectionOptions;
56
import com.microsoft.semantickernel.data.vectorstorage.definition.VectorStoreRecordDefinition;
7+
import edu.umd.cs.findbugs.annotations.SuppressFBWarnings;
68

79
import javax.annotation.Nonnull;
810
import javax.annotation.Nullable;
@@ -12,17 +14,21 @@ public class VolatileVectorStoreRecordCollectionOptions<Record>
1214
private final Class<Record> recordClass;
1315
@Nullable
1416
private final VectorStoreRecordDefinition recordDefinition;
17+
@Nullable
18+
private final ObjectMapper objectMapper;
1519

1620
/**
1721
* Creates a new instance of the Volatile vector store record collection options.
1822
*
1923
* @param recordClass The record class.
2024
* @param recordDefinition The record definition.
2125
*/
26+
@SuppressFBWarnings("EI_EXPOSE_REP2") // ObjectMapper only has package visibility
2227
public VolatileVectorStoreRecordCollectionOptions(@Nonnull Class<Record> recordClass,
23-
@Nullable VectorStoreRecordDefinition recordDefinition) {
28+
@Nullable VectorStoreRecordDefinition recordDefinition, ObjectMapper objectMapper) {
2429
this.recordClass = recordClass;
2530
this.recordDefinition = recordDefinition;
31+
this.objectMapper = objectMapper;
2632
}
2733

2834
/**
@@ -54,6 +60,15 @@ public Class<Record> getRecordClass() {
5460
return recordClass;
5561
}
5662

63+
/**
64+
* Gets the object mapper.
65+
*
66+
* @return the object mapper
67+
*/
68+
ObjectMapper getObjectMapper() {
69+
return objectMapper;
70+
}
71+
5772
/**
5873
* Gets the record definition.
5974
*
@@ -73,6 +88,8 @@ public static class Builder<Record> {
7388
private Class<Record> recordClass;
7489
@Nullable
7590
private VectorStoreRecordDefinition recordDefinition;
91+
@Nullable
92+
private ObjectMapper objectMapper;
7693

7794
/**
7895
* Sets the record class.
@@ -96,6 +113,18 @@ public Builder<Record> withRecordDefinition(VectorStoreRecordDefinition recordDe
96113
return this;
97114
}
98115

116+
/**
117+
* Sets the object mapper.
118+
*
119+
* @param objectMapper the object mapper
120+
* @return the builder
121+
*/
122+
@SuppressFBWarnings("EI_EXPOSE_REP2")
123+
public Builder<Record> withObjectMapper(ObjectMapper objectMapper) {
124+
this.objectMapper = objectMapper;
125+
return this;
126+
}
127+
99128
/**
100129
* Builds the options.
101130
*
@@ -106,7 +135,8 @@ public VolatileVectorStoreRecordCollectionOptions<Record> build() {
106135
throw new IllegalArgumentException("recordClass is required");
107136
}
108137

109-
return new VolatileVectorStoreRecordCollectionOptions<>(recordClass, recordDefinition);
138+
return new VolatileVectorStoreRecordCollectionOptions<>(recordClass, recordDefinition,
139+
objectMapper);
110140
}
111141
}
112142
}

semantickernel-experimental/src/main/java/com/microsoft/semantickernel/data/filter/EqualToFilterClause.java

Lines changed: 12 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,7 @@
11
// Copyright (c) Microsoft. All rights reserved.
22
package com.microsoft.semantickernel.data.filter;
33

4-
public abstract class EqualToFilterClause implements FilterClause {
4+
public class EqualToFilterClause implements FilterClause {
55

66
private final String fieldName;
77
private final Object value;
@@ -28,4 +28,15 @@ public String getFieldName() {
2828
public Object getValue() {
2929
return value;
3030
}
31+
32+
/**
33+
* Gets the filter string.
34+
*
35+
* @return The filter.
36+
*/
37+
@Override
38+
public String getFilter() {
39+
throw new UnsupportedOperationException(String.format(
40+
"Not implemented. Use one of %s derived classes.", this.getClass().getSimpleName()));
41+
}
3142
}

0 commit comments

Comments
 (0)