Skip to content

Commit 5306263

Browse files
author
Milder Hernandez Cagua
committed
Add SQLiteVectorStoreQueryProvider
1 parent 7e4b348 commit 5306263

File tree

4 files changed

+194
-14
lines changed

4 files changed

+194
-14
lines changed

api-test/integration-tests/pom.xml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,11 @@
7878
<artifactId>postgresql</artifactId>
7979
<version>42.7.3</version> <!-- Use the latest version -->
8080
</dependency>
81+
<dependency>
82+
<groupId>org.xerial</groupId>
83+
<artifactId>sqlite-jdbc</artifactId>
84+
<version>3.46.1.0</version>
85+
</dependency>
8186

8287
<dependency>
8388
<groupId>org.testcontainers</groupId>

api-test/integration-tests/src/test/java/com/microsoft/semantickernel/tests/connectors/memory/jdbc/JDBCVectorStoreRecordCollectionTest.java

Lines changed: 25 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import com.microsoft.semantickernel.connectors.data.mysql.MySQLVectorStoreQueryProvider;
77
import com.microsoft.semantickernel.connectors.data.postgres.PostgreSQLVectorStoreQueryProvider;
88
import com.microsoft.semantickernel.connectors.data.jdbc.filter.SQLEqualToFilterClause;
9+
import com.microsoft.semantickernel.connectors.data.sqlite.SQLiteVectorStoreQueryProvider;
910
import com.microsoft.semantickernel.data.vectorsearch.VectorSearchFilter;
1011
import com.microsoft.semantickernel.data.vectorsearch.VectorSearchResult;
1112
import com.microsoft.semantickernel.data.vectorstorage.options.GetRecordOptions;
@@ -17,6 +18,7 @@
1718
import org.junit.jupiter.params.provider.EnumSource;
1819
import org.junit.jupiter.params.provider.MethodSource;
1920
import org.postgresql.ds.PGSimpleDataSource;
21+
import org.sqlite.SQLiteDataSource;
2022
import org.testcontainers.containers.MySQLContainer;
2123
import org.testcontainers.containers.PostgreSQLContainer;
2224
import org.testcontainers.junit.jupiter.Container;
@@ -47,7 +49,8 @@ public class JDBCVectorStoreRecordCollectionTest {
4749

4850
public enum QueryProvider {
4951
MySQL,
50-
PostgreSQL
52+
PostgreSQL,
53+
SQLite
5154
}
5255

5356
private JDBCVectorStoreRecordCollection<Hotel> buildRecordCollection(QueryProvider provider, @Nonnull String collectionName) {
@@ -75,6 +78,14 @@ private JDBCVectorStoreRecordCollection<Hotel> buildRecordCollection(QueryProvid
7578
.withDataSource(dataSource)
7679
.build();
7780
break;
81+
case SQLite:
82+
SQLiteDataSource sqliteDataSource = new SQLiteDataSource();
83+
sqliteDataSource.setUrl("jdbc:sqlite:file:testdb");
84+
dataSource = sqliteDataSource;
85+
queryProvider = SQLiteVectorStoreQueryProvider.builder()
86+
.withDataSource(sqliteDataSource)
87+
.build();
88+
break;
7889
default:
7990
throw new IllegalArgumentException("Unknown query provider: " + provider);
8091
}
@@ -131,7 +142,7 @@ public void upsertAndGetRecordAsync(QueryProvider provider) {
131142

132143
// Upsert the first time
133144
for (Hotel hotel : hotels) {
134-
Hotel retrievedHotel = recordCollection.getAsync(hotel.getId(), null).block();
145+
Hotel retrievedHotel = recordCollection.getAsync(hotel.getId(), null).block();
135146
assertNotNull(retrievedHotel);
136147
assertEquals(hotel.getId(), retrievedHotel.getId());
137148
assertEquals(hotel.getRating(), retrievedHotel.getRating());
@@ -146,7 +157,7 @@ public void upsertAndGetRecordAsync(QueryProvider provider) {
146157
}
147158

148159
for (Hotel hotel : hotels) {
149-
Hotel retrievedHotel = recordCollection.getAsync(hotel.getId(), null).block();
160+
Hotel retrievedHotel = recordCollection.getAsync(hotel.getId(), null).block();
150161
assertNotNull(retrievedHotel);
151162
assertEquals(hotel.getId(), retrievedHotel.getId());
152163
assertEquals(1.0, retrievedHotel.getRating());
@@ -157,7 +168,7 @@ public void upsertAndGetRecordAsync(QueryProvider provider) {
157168
@EnumSource(QueryProvider.class)
158169
public void getBatchAsync(QueryProvider provider) {
159170
String collectionName = "getBatchAsync";
160-
JDBCVectorStoreRecordCollection<Hotel> recordCollection = buildRecordCollection(provider, collectionName);
171+
JDBCVectorStoreRecordCollection<Hotel> recordCollection = buildRecordCollection(provider, collectionName);
161172

162173
List<Hotel> hotels = getHotels();
163174
for (Hotel hotel : hotels) {
@@ -169,7 +180,7 @@ public void getBatchAsync(QueryProvider provider) {
169180
keys.add(hotel.getId());
170181
}
171182

172-
List<Hotel> retrievedHotels = recordCollection.getBatchAsync(keys, null).block();
183+
List<Hotel> retrievedHotels = recordCollection.getBatchAsync(keys, null).block();
173184
assertNotNull(retrievedHotels);
174185
assertEquals(hotels.size(), retrievedHotels.size());
175186
}
@@ -178,7 +189,7 @@ public void getBatchAsync(QueryProvider provider) {
178189
@EnumSource(QueryProvider.class)
179190
public void upsertBatchAndGetBatchAsync(QueryProvider provider) {
180191
String collectionName = "upsertBatchAndGetBatchAsync";
181-
JDBCVectorStoreRecordCollection<Hotel> recordCollection = buildRecordCollection(provider, collectionName);
192+
JDBCVectorStoreRecordCollection<Hotel> recordCollection = buildRecordCollection(provider, collectionName);
182193

183194
List<Hotel> hotels = getHotels();
184195
recordCollection.upsertBatchAsync(hotels, null).block();
@@ -188,7 +199,7 @@ public void upsertBatchAndGetBatchAsync(QueryProvider provider) {
188199
keys.add(hotel.getId());
189200
}
190201

191-
List<Hotel> retrievedHotels = recordCollection.getBatchAsync(keys, null).block();
202+
List<Hotel> retrievedHotels = recordCollection.getBatchAsync(keys, null).block();
192203
assertNotNull(retrievedHotels);
193204
assertEquals(hotels.size(), retrievedHotels.size());
194205
}
@@ -209,7 +220,7 @@ public void insertAndReplaceAsync(QueryProvider provider) {
209220
keys.add(hotel.getId());
210221
}
211222

212-
List<Hotel> retrievedHotels = recordCollection.getBatchAsync(keys, null).block();
223+
List<Hotel> retrievedHotels = recordCollection.getBatchAsync(keys, null).block();
213224
assertNotNull(retrievedHotels);
214225
assertEquals(hotels.size(), retrievedHotels.size());
215226
}
@@ -225,7 +236,7 @@ public void deleteRecordAsync(QueryProvider provider) {
225236

226237
for (Hotel hotel : hotels) {
227238
recordCollection.deleteAsync(hotel.getId(), null).block();
228-
Hotel retrievedHotel = recordCollection.getAsync(hotel.getId(), null).block();
239+
Hotel retrievedHotel = recordCollection.getAsync(hotel.getId(), null).block();
229240
assertNull(retrievedHotel);
230241
}
231242
}
@@ -247,7 +258,7 @@ public void deleteBatchAsync(QueryProvider provider) {
247258
recordCollection.deleteBatchAsync(keys, null).block();
248259

249260
for (String key : keys) {
250-
Hotel retrievedHotel = recordCollection.getAsync(key, null).block();
261+
Hotel retrievedHotel = recordCollection.getAsync(key, null).block();
251262
assertNull(retrievedHotel);
252263
}
253264
}
@@ -266,7 +277,7 @@ public void getWithNoVectors(QueryProvider provider) {
266277
.build();
267278

268279
for (Hotel hotel : hotels) {
269-
Hotel retrievedHotel = recordCollection.getAsync(hotel.getId(), options).block();
280+
Hotel retrievedHotel = recordCollection.getAsync(hotel.getId(), options).block();
270281
assertNotNull(retrievedHotel);
271282
assertEquals(hotel.getId(), retrievedHotel.getId());
272283
assertNull(retrievedHotel.getEuclidean());
@@ -277,7 +288,7 @@ public void getWithNoVectors(QueryProvider provider) {
277288
.build();
278289

279290
for (Hotel hotel : hotels) {
280-
Hotel retrievedHotel = recordCollection.getAsync(hotel.getId(), options).block();
291+
Hotel retrievedHotel = recordCollection.getAsync(hotel.getId(), options).block();
281292
assertNotNull(retrievedHotel);
282293
assertEquals(hotel.getId(), retrievedHotel.getId());
283294
assertNotNull(retrievedHotel.getEuclidean());
@@ -302,7 +313,7 @@ public void getBatchWithNoVectors(QueryProvider provider) {
302313
keys.add(hotel.getId());
303314
}
304315

305-
List<Hotel> retrievedHotels = recordCollection.getBatchAsync(keys, options).block();
316+
List<Hotel> retrievedHotels = recordCollection.getBatchAsync(keys, options).block();
306317
assertNotNull(retrievedHotels);
307318
assertEquals(hotels.size(), retrievedHotels.size());
308319

@@ -314,7 +325,7 @@ public void getBatchWithNoVectors(QueryProvider provider) {
314325
.includeVectors(true)
315326
.build();
316327

317-
retrievedHotels = recordCollection.getBatchAsync(keys, options).block();
328+
retrievedHotels = recordCollection.getBatchAsync(keys, options).block();
318329
assertNotNull(retrievedHotels);
319330
assertEquals(hotels.size(), retrievedHotels.size());
320331

semantickernel-experimental/pom.xml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,11 @@
112112
<artifactId>postgresql</artifactId>
113113
<version>42.7.3</version> <!-- Use the latest version -->
114114
</dependency>
115+
<dependency>
116+
<groupId>org.xerial</groupId>
117+
<artifactId>sqlite-jdbc</artifactId>
118+
<version>3.46.1.0</version>
119+
</dependency>
115120

116121
</dependencies>
117122

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,159 @@
1+
// Copyright (c) Microsoft. All rights reserved.
2+
package com.microsoft.semantickernel.connectors.data.sqlite;
3+
4+
import com.fasterxml.jackson.core.JsonProcessingException;
5+
import com.fasterxml.jackson.databind.JsonNode;
6+
import com.fasterxml.jackson.databind.ObjectMapper;
7+
import com.microsoft.semantickernel.connectors.data.jdbc.JDBCVectorStoreQueryProvider;
8+
import com.microsoft.semantickernel.connectors.data.jdbc.SQLVectorStoreQueryProvider;
9+
import com.microsoft.semantickernel.data.vectorstorage.definition.VectorStoreRecordDefinition;
10+
import com.microsoft.semantickernel.data.vectorstorage.definition.VectorStoreRecordField;
11+
import com.microsoft.semantickernel.data.vectorstorage.definition.VectorStoreRecordVectorField;
12+
import com.microsoft.semantickernel.data.vectorstorage.options.UpsertRecordOptions;
13+
import com.microsoft.semantickernel.exceptions.SKException;
14+
import edu.umd.cs.findbugs.annotations.SuppressFBWarnings;
15+
16+
import javax.annotation.Nonnull;
17+
import javax.sql.DataSource;
18+
import java.sql.Connection;
19+
import java.sql.PreparedStatement;
20+
import java.sql.SQLException;
21+
import java.util.List;
22+
import java.util.stream.Collectors;
23+
24+
public class SQLiteVectorStoreQueryProvider extends
25+
JDBCVectorStoreQueryProvider implements SQLVectorStoreQueryProvider {
26+
27+
private final DataSource dataSource;
28+
private final ObjectMapper objectMapper;
29+
30+
@SuppressFBWarnings("EI_EXPOSE_REP2")
31+
private SQLiteVectorStoreQueryProvider(
32+
@Nonnull DataSource dataSource,
33+
@Nonnull String collectionsTable,
34+
@Nonnull String prefixForCollectionTables,
35+
@Nonnull ObjectMapper objectMapper) {
36+
super(dataSource, collectionsTable, prefixForCollectionTables);
37+
this.dataSource = dataSource;
38+
this.objectMapper = objectMapper;
39+
}
40+
41+
/**
42+
* Creates a new builder.
43+
* @return the builder
44+
*/
45+
public static Builder builder() {
46+
return new Builder();
47+
}
48+
49+
private void setUpsertStatementValues(PreparedStatement statement, Object record,
50+
List<VectorStoreRecordField> fields) {
51+
JsonNode jsonNode = objectMapper.valueToTree(record);
52+
53+
for (int i = 0; i < fields.size(); ++i) {
54+
VectorStoreRecordField field = fields.get(i);
55+
try {
56+
JsonNode valueNode = jsonNode.get(field.getEffectiveStorageName());
57+
58+
if (field instanceof VectorStoreRecordVectorField) {
59+
// Convert the vector field to a string
60+
if (!field.getFieldType().equals(String.class)) {
61+
statement.setObject(i + 1, objectMapper.writeValueAsString(valueNode));
62+
continue;
63+
}
64+
}
65+
66+
statement.setObject(i + 1,
67+
objectMapper.convertValue(valueNode, field.getFieldType()));
68+
} catch (SQLException | JsonProcessingException e) {
69+
throw new RuntimeException(e);
70+
}
71+
}
72+
}
73+
74+
/**
75+
* Upserts records into the collection.
76+
* @param collectionName the collection name
77+
* @param records the records to upsert
78+
* @param recordDefinition the record definition
79+
* @param options the upsert options
80+
* @throws SKException if the upsert fails
81+
*/
82+
@Override
83+
@SuppressFBWarnings("SQL_PREPARED_STATEMENT_GENERATED_FROM_NONCONSTANT_STRING") // SQL query is generated dynamically with valid identifiers
84+
public void upsertRecords(String collectionName, List<?> records,
85+
VectorStoreRecordDefinition recordDefinition, UpsertRecordOptions options) {
86+
List<VectorStoreRecordField> fields = recordDefinition.getAllFields();
87+
88+
String query = formatQuery("INSERT OR REPLACE INTO %s (%s) VALUES (%s)",
89+
getCollectionTableName(collectionName),
90+
getQueryColumnsFromFields(fields),
91+
getWildcardString(fields.size()));
92+
93+
try (Connection connection = dataSource.getConnection();
94+
PreparedStatement statement = connection.prepareStatement(query)) {
95+
for (Object record : records) {
96+
setUpsertStatementValues(statement, record, recordDefinition.getAllFields());
97+
statement.addBatch();
98+
}
99+
100+
statement.executeBatch();
101+
} catch (SQLException e) {
102+
throw new SKException("Failed to upsert records", e);
103+
}
104+
}
105+
106+
public static class Builder
107+
extends JDBCVectorStoreQueryProvider.Builder {
108+
private DataSource dataSource;
109+
private String collectionsTable = DEFAULT_COLLECTIONS_TABLE;
110+
private String prefixForCollectionTables = DEFAULT_PREFIX_FOR_COLLECTION_TABLES;
111+
private ObjectMapper objectMapper = new ObjectMapper();
112+
113+
@SuppressFBWarnings("EI_EXPOSE_REP2")
114+
public Builder withDataSource(DataSource dataSource) {
115+
this.dataSource = dataSource;
116+
return this;
117+
}
118+
119+
/**
120+
* Sets the collections table name.
121+
* @param collectionsTable the collections table name
122+
* @return the builder
123+
*/
124+
public Builder withCollectionsTable(String collectionsTable) {
125+
this.collectionsTable = validateSQLidentifier(collectionsTable);
126+
return this;
127+
}
128+
129+
/**
130+
* Sets the prefix for collection tables.
131+
* @param prefixForCollectionTables the prefix for collection tables
132+
* @return the builder
133+
*/
134+
public Builder withPrefixForCollectionTables(String prefixForCollectionTables) {
135+
this.prefixForCollectionTables = validateSQLidentifier(prefixForCollectionTables);
136+
return this;
137+
}
138+
139+
/**
140+
* Sets the object mapper.
141+
* @param objectMapper the object mapper
142+
* @return the builder
143+
*/
144+
@SuppressFBWarnings("EI_EXPOSE_REP2")
145+
public Builder withObjectMapper(ObjectMapper objectMapper) {
146+
this.objectMapper = objectMapper;
147+
return this;
148+
}
149+
150+
public SQLiteVectorStoreQueryProvider build() {
151+
if (dataSource == null) {
152+
throw new SKException("DataSource is required");
153+
}
154+
155+
return new SQLiteVectorStoreQueryProvider(dataSource, collectionsTable,
156+
prefixForCollectionTables, objectMapper);
157+
}
158+
}
159+
}

0 commit comments

Comments
 (0)