Skip to content

Commit 47cfed7

Browse files
committed
test: upgrade testcontainers version and refactor some tests
1 parent 228c912 commit 47cfed7

File tree

84 files changed

+645
-643
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

84 files changed

+645
-643
lines changed

content-retrievers/langchain4j-community-lucene/src/test/java/dev/langchain4j/community/rag/content/retriever/lucene/EmbeddingSearchIT.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ void hybridQuery1() {
5353
results.stream().map(match -> match.embedded().text()).collect(Collectors.toList());
5454

5555
assertThat(results).hasSize(3);
56-
assertThat(actualTextSegments).isEqualTo(expectedTextSegments);
56+
assertThat(actualTextSegments).containsExactlyElementsOf(expectedTextSegments);
5757
}
5858

5959
@Test
@@ -75,7 +75,7 @@ void hybridQuery2() {
7575
results.stream().map(match -> match.embedded().text()).collect(Collectors.toList());
7676

7777
assertThat(results).hasSize(1);
78-
assertThat(actualTextSegments).isEqualTo(expectedTextSegments);
78+
assertThat(actualTextSegments).containsExactlyElementsOf(expectedTextSegments);
7979
}
8080

8181
@Test
@@ -91,7 +91,7 @@ void hybridQuery3() {
9191

9292
debugQuery(query, results);
9393

94-
assertThat(results).hasSize(0);
94+
assertThat(results).isEmpty();
9595
}
9696

9797
@BeforeEach

content-retrievers/langchain4j-community-lucene/src/test/java/dev/langchain4j/community/rag/content/retriever/lucene/FullTextSearchIT.java

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ void query1() {
5858
Collections.sort(actualTextSegments);
5959

6060
assertThat(results).hasSize(1);
61-
assertThat(actualTextSegments).isEqualTo(expectedTextSegments);
61+
assertThat(actualTextSegments).containsExactlyElementsOf(expectedTextSegments);
6262
}
6363

6464
@Test
@@ -86,7 +86,7 @@ void queryAll() {
8686
Collections.sort(actualTextSegments);
8787

8888
assertThat(results).hasSize(hits.length + misses.length);
89-
assertThat(actualTextSegments).isEqualTo(expectedTextSegments);
89+
assertThat(actualTextSegments).containsExactlyElementsOf(expectedTextSegments);
9090
}
9191

9292
@Test
@@ -109,8 +109,8 @@ void queryContent() {
109109
results.stream().map(content -> content.textSegment().text()).collect(Collectors.toList());
110110
Collections.sort(actualTextSegments);
111111

112-
assertThat(results).hasSize(hits.length);
113-
assertThat(actualTextSegments).isEqualTo(expectedTextSegments);
112+
assertThat(results).hasSameSizeAs(hits);
113+
assertThat(actualTextSegments).containsExactlyElementsOf(expectedTextSegments);
114114
}
115115

116116
@Test
@@ -131,7 +131,7 @@ void queryWithMaxTokens() {
131131
Collections.sort(actualTextSegments);
132132

133133
assertThat(results).hasSize(1);
134-
assertThat(actualTextSegments).isEqualTo(expectedTextSegments);
134+
assertThat(actualTextSegments).containsExactlyElementsOf(expectedTextSegments);
135135
}
136136

137137
@Test
@@ -185,7 +185,7 @@ void queryWithMinScore() {
185185
Collections.sort(actualTextSegments);
186186

187187
assertThat(results).hasSize(2);
188-
assertThat(actualTextSegments).isEqualTo(expectedTextSegments);
188+
assertThat(actualTextSegments).containsExactlyElementsOf(expectedTextSegments);
189189
}
190190

191191
@Test
@@ -200,7 +200,7 @@ void retrieverWithBadTokenCountField() {
200200
List<Content> results = contentRetriever.retrieve(query);
201201

202202
// No limiting by token count, since wrong field is used
203-
assertThat(results).hasSize(hits.length);
203+
assertThat(results).hasSameSizeAs(hits);
204204
}
205205

206206
@Test

content-retrievers/langchain4j-community-lucene/src/test/java/dev/langchain4j/community/rag/content/retriever/lucene/HybridSearchIT.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
import org.slf4j.LoggerFactory;
2020

2121
class HybridSearchIT {
22-
22+
2323
private static final Logger log = LoggerFactory.getLogger(HybridSearchIT.class);
2424

2525
private static final TextEmbedding[] hits = {
@@ -59,7 +59,7 @@ void hybridQuery1() {
5959
results.stream().map(content -> content.textSegment().text()).collect(Collectors.toList());
6060

6161
assertThat(results).hasSize(3);
62-
assertThat(actualTextSegments).isEqualTo(expectedTextSegments);
62+
assertThat(actualTextSegments).containsExactlyElementsOf(expectedTextSegments);
6363
}
6464

6565
@Test
@@ -84,7 +84,7 @@ void hybridQuery2() {
8484
results.stream().map(content -> content.textSegment().text()).collect(Collectors.toList());
8585

8686
assertThat(results).hasSize(1);
87-
assertThat(actualTextSegments).isEqualTo(expectedTextSegments);
87+
assertThat(actualTextSegments).containsExactlyElementsOf(expectedTextSegments);
8888
}
8989

9090
@Test
@@ -103,7 +103,7 @@ void hybridQuery3() {
103103
List<Content> results = contentRetriever.retrieve(Query.from(queryText));
104104
debugQuery(query, results);
105105

106-
assertThat(results).hasSize(0);
106+
assertThat(results).isEmpty();
107107
}
108108

109109
@BeforeEach

content-retrievers/langchain4j-community-lucene/src/test/java/dev/langchain4j/community/rag/content/retriever/lucene/IndexerTest.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ void addAllEmbeddings() {
5353

5454
List<Content> results = contentRetriever.retrieve(query);
5555

56-
assertThat(results).hasSize(0);
56+
assertThat(results).isEmpty();
5757
}
5858

5959
@Test

content-retrievers/langchain4j-community-neo4j-retriever/pom.xml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,13 +62,13 @@
6262
<!-- test dependencies -->
6363
<dependency>
6464
<groupId>org.testcontainers</groupId>
65-
<artifactId>junit-jupiter</artifactId>
65+
<artifactId>testcontainers-junit-jupiter</artifactId>
6666
<scope>test</scope>
6767
</dependency>
6868

6969
<dependency>
7070
<groupId>org.testcontainers</groupId>
71-
<artifactId>neo4j</artifactId>
71+
<artifactId>testcontainers-neo4j</artifactId>
7272
<scope>test</scope>
7373
</dependency>
7474

content-retrievers/langchain4j-community-neo4j-retriever/src/test/java/dev/langchain4j/community/rag/content/retriever/neo4j/Neo4jEmbeddingStoreIngestorBaseTest.java

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
package dev.langchain4j.community.rag.content.retriever.neo4j;
22

33
import static org.assertj.core.api.Assertions.assertThat;
4-
import static org.junit.jupiter.api.Assertions.assertEquals;
54

65
import dev.langchain4j.community.store.embedding.neo4j.Neo4jEmbeddingStore;
76
import dev.langchain4j.community.store.embedding.neo4j.Neo4jEmbeddingStoreIngestor;
@@ -43,7 +42,7 @@ WITH parent, collect(node.text) AS chunks, max(score) AS score
4342
protected static EmbeddingModel embeddingModel;
4443

4544
@BeforeAll
46-
public static void beforeAll() {
45+
static void beforeAll() {
4746
Neo4jContainerBaseTest.beforeAll();
4847

4948
embeddingStore =
@@ -130,8 +129,8 @@ protected static void commonResults(List<Content> results, String... retrieveQue
130129
Content result = results.get(0);
131130

132131
assertThat(result.textSegment().text().toLowerCase()).containsIgnoringWhitespaces(retrieveQuery);
133-
assertEquals("Wikipedia link", result.textSegment().metadata().getString("source"));
134-
assertEquals("https://example.com/ai", result.textSegment().metadata().getString("url"));
132+
assertThat(result.textSegment().metadata().getString("source")).isEqualTo("Wikipedia link");
133+
assertThat(result.textSegment().metadata().getString("url")).isEqualTo("https://example.com/ai");
135134
}
136135

137136
protected static EmbeddingStoreContentRetriever getEmbeddingStoreContentRetriever(

content-retrievers/langchain4j-community-neo4j-retriever/src/test/java/dev/langchain4j/community/rag/content/retriever/neo4j/Neo4jEmbeddingStoreIngestorIT.java

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
2525

2626
@EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = ".+")
27-
public class Neo4jEmbeddingStoreIngestorIT extends Neo4jEmbeddingStoreIngestorBaseTest {
27+
class Neo4jEmbeddingStoreIngestorIT extends Neo4jEmbeddingStoreIngestorBaseTest {
2828

2929
ChatModel chatModel = OpenAiChatModel.builder()
3030
.baseUrl(System.getenv("OPENAI_BASE_URL"))
@@ -117,8 +117,7 @@ void testRetrieverWithCustomAnswerModelAndPrompt() {
117117
.build();
118118

119119
final String chainResult = chain.execute(Query.from(retrieveQuery));
120-
assertThat(chainResult).containsIgnoringCase("dattebayo");
121-
assertThat(chainResult).containsIgnoringCase("super saiyan");
120+
assertThat(chainResult).containsIgnoringCase("dattebayo").containsIgnoringCase("super saiyan");
122121

123122
RetrievalQAChain chainWithPromptBuilder = RetrievalQAChain.builder()
124123
.chatModel(chatModel)
@@ -127,17 +126,18 @@ void testRetrieverWithCustomAnswerModelAndPrompt() {
127126
.build();
128127

129128
final String chainResultWithPromptBuilder = chainWithPromptBuilder.execute(Query.from(retrieveQuery));
130-
assertThat(chainResultWithPromptBuilder).containsIgnoringCase("dattebayo");
131-
assertThat(chainResultWithPromptBuilder).containsIgnoringCase("super saiyan");
129+
assertThat(chainResultWithPromptBuilder)
130+
.containsIgnoringCase("dattebayo")
131+
.containsIgnoringCase("super saiyan");
132132
}
133133

134134
@Test
135-
public void testSummaryGraphIngestor() {
135+
void testSummaryGraphIngestor() {
136136
summaryGraphIngestorCommon(chatModel);
137137
}
138138

139139
@Test
140-
public void testHypotheticalQuestionIngestor() {
140+
void testHypotheticalQuestionIngestor() {
141141
hypotheticalQuestionIngestorCommon(chatModel);
142142
}
143143
}

content-retrievers/langchain4j-community-neo4j-retriever/src/test/java/dev/langchain4j/community/rag/content/retriever/neo4j/Neo4jEmbeddingStoreIngestorQAChainTest.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,12 +27,12 @@
2727
import org.mockito.junit.jupiter.MockitoExtension;
2828

2929
@ExtendWith(MockitoExtension.class)
30-
public class Neo4jEmbeddingStoreIngestorQAChainTest extends Neo4jEmbeddingStoreIngestorBaseTest {
30+
class Neo4jEmbeddingStoreIngestorQAChainTest extends Neo4jEmbeddingStoreIngestorBaseTest {
3131
@Mock
3232
private ChatModel chatLanguageModel;
3333

3434
@Test
35-
public void testBasicRetrieverWithChatQuestionAndAnswerModel() {
35+
void testBasicRetrieverWithChatQuestionAndAnswerModel() {
3636

3737
final Neo4jEmbeddingStore neo4jEmbeddingStore = Neo4jEmbeddingStore.builder()
3838
.driver(driver)

content-retrievers/langchain4j-community-neo4j-retriever/src/test/java/dev/langchain4j/community/rag/content/retriever/neo4j/Neo4jEmbeddingStoreIngestorTest.java

Lines changed: 21 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,6 @@
11
package dev.langchain4j.community.rag.content.retriever.neo4j;
22

33
import static org.assertj.core.api.Assertions.assertThat;
4-
import static org.junit.jupiter.api.Assertions.assertEquals;
5-
import static org.junit.jupiter.api.Assertions.assertFalse;
6-
import static org.junit.jupiter.api.Assertions.assertTrue;
74
import static org.mockito.ArgumentMatchers.anyList;
85
import static org.mockito.Mockito.when;
96

@@ -33,13 +30,13 @@
3330
import org.mockito.junit.jupiter.MockitoExtension;
3431

3532
@ExtendWith(MockitoExtension.class)
36-
public class Neo4jEmbeddingStoreIngestorTest extends Neo4jEmbeddingStoreIngestorBaseTest {
33+
class Neo4jEmbeddingStoreIngestorTest extends Neo4jEmbeddingStoreIngestorBaseTest {
3734

3835
@Mock
3936
private ChatModel chatLanguageModel;
4037

4138
@Test
42-
public void testBasicRetriever() {
39+
void testBasicRetriever() {
4340
Document parentDoc = getDocumentMiscTopics();
4441

4542
// Child splitter: splits into sentences using OpenNLP
@@ -67,7 +64,7 @@ public void testBasicRetriever() {
6764
}
6865

6966
@Test
70-
public void testRetrieverWithChatModel() {
67+
void testRetrieverWithChatModel() {
7168

7269
final Neo4jEmbeddingStore neo4jEmbeddingStore = Neo4jEmbeddingStore.builder()
7370
.driver(driver)
@@ -164,13 +161,13 @@ void testRetrieverWithCustomRetrievalAndEmbeddingCreationQuery() {
164161
void testRetrieverWithCustomRetrievalAndEmbeddingCreationQueryMainDocIdAndParams() {
165162
String customCreationQuery =
166163
"""
167-
UNWIND $rows AS row
168-
MATCH (p:MainDoc {customParentId: $customParentId})
169-
CREATE (p)-[:REFERS_TO]->(u:%1$s {%2$s: row.%2$s})
170-
SET u += row.%3$s
171-
WITH row, u
172-
CALL db.create.setNodeVectorProperty(u, $embeddingProperty, row.%4$s)
173-
RETURN count(*)""";
164+
UNWIND $rows AS row
165+
MATCH (p:MainDoc {customParentId: $customParentId})
166+
CREATE (p)-[:REFERS_TO]->(u:%1$s {%2$s: row.%2$s})
167+
SET u += row.%3$s
168+
WITH row, u
169+
CALL db.create.setNodeVectorProperty(u, $embeddingProperty, row.%4$s)
170+
RETURN count(*)""";
174171
final Neo4jEmbeddingStore neo4jEmbeddingStore = Neo4jEmbeddingStore.builder()
175172
.driver(driver)
176173
.retrievalQuery(CUSTOM_RETRIEVAL)
@@ -242,15 +239,15 @@ void testRetrieverWithCustomRetrievalAndEmbeddingCreationQueryAndPreInsertedData
242239
List<Content> results = retriever.retrieve(new Query("quantum physics"));
243240

244241
// Assert
245-
assertEquals(1, results.size());
242+
assertThat(results).hasSize(1);
246243
Content parent = results.get(0);
247244

248-
assertTrue(parent.textSegment().text().contains("quantum physics"));
249-
assertEquals("science", parent.textSegment().metadata().getString("source"));
245+
assertThat(parent.textSegment().text()).contains("quantum physics");
246+
assertThat(parent.textSegment().metadata().getString("source")).isEqualTo("science");
250247
}
251248

252249
@Test
253-
public void testSummaryGraphIngestor() {
250+
void testSummaryGraphIngestor() {
254251

255252
when(chatLanguageModel.chat(anyList()))
256253
.thenReturn(ChatResponse.builder()
@@ -261,7 +258,7 @@ public void testSummaryGraphIngestor() {
261258
}
262259

263260
@Test
264-
public void testHypotheticalQuestionIngestor() {
261+
void testHypotheticalQuestionIngestor() {
265262

266263
when(chatLanguageModel.chat(anyList()))
267264
.thenReturn(ChatResponse.builder()
@@ -272,7 +269,7 @@ public void testHypotheticalQuestionIngestor() {
272269
}
273270

274271
@Test
275-
public void testParentChildRetriever() {
272+
void testParentChildRetriever() {
276273
EmbeddingModel embeddingModel = new AllMiniLmL6V2QuantizedEmbeddingModel();
277274

278275
int maxSegmentSize = 250;
@@ -325,10 +322,10 @@ protected static void summaryGraphIngestorCommon(ChatModel chatModel) {
325322
// Query and validate results
326323
List<Content> results = retriever.retrieve(Query.from("What is Machine Learning?"));
327324

328-
assertFalse(results.isEmpty(), "Should retrieve at least one parent document");
325+
assertThat(results).as("Should retrieve at least one parent document").isNotEmpty();
329326

330327
Content result = results.get(0);
331-
assertTrue(result.textSegment().text().toLowerCase().contains("machine learning"));
328+
assertThat(result.textSegment().text().toLowerCase()).contains("machine learning");
332329
assertThat(result.textSegment().metadata().getString("url")).isEqualTo("https://example.com/ai");
333330
}
334331

@@ -352,12 +349,12 @@ protected static void hypotheticalQuestionIngestorCommon(ChatModel chatModel) {
352349
final EmbeddingStoreContentRetriever retriever = getEmbeddingStoreContentRetriever(ingestor);
353350
List<Content> results = retriever.retrieve(Query.from("Who is John Doe?"));
354351

355-
assertFalse(results.isEmpty(), "Should retrieve at least one parent document");
352+
assertThat(results).as("Should retrieve at least one parent document").isNotEmpty();
356353

357354
Content result = results.get(0);
358355

359356
assertThat(result.textSegment().text().toLowerCase()).containsIgnoringWhitespaces("super saiyan");
360-
assertEquals("Wikipedia link", result.textSegment().metadata().getString("source"));
361-
assertEquals("https://example.com/ai", result.textSegment().metadata().getString("url"));
357+
assertThat(result.textSegment().metadata().getString("source")).isEqualTo("Wikipedia link");
358+
assertThat(result.textSegment().metadata().getString("url")).isEqualTo("https://example.com/ai");
362359
}
363360
}

0 commit comments

Comments
 (0)