@@ -23,6 +23,7 @@ import com.johnsnowlabs.nlp.{Annotation, AssertAnnotations}
2323import com .johnsnowlabs .tags .{FastTest , SlowTest }
2424import org .apache .spark .ml .{Pipeline , PipelineModel }
2525import org .apache .spark .sql .DataFrame
26+ import org .apache .spark .sql .functions .col
2627import org .scalatest .flatspec .AnyFlatSpec
2728
2829class WordEmbeddingsTestSpec extends AnyFlatSpec with SparkSessionTest {
@@ -31,6 +32,34 @@ class WordEmbeddingsTestSpec extends AnyFlatSpec with SparkSessionTest {
3132 .option(" header" , " true" )
3233 .csv(" src/test/resources/embeddings/clinical_words.txt" )
3334
// Regression test: after running word embeddings over the stored fixture, no
// `begin` offset may occur more than twice among the exploded `embedding`
// annotations of the whole result.
"Word Embeddings" should "Should not repeat tokens" taggedAs FastTest in {

  // Fixture already carries the annotation columns consumed below
  // ("splitter" and "token"), so no upstream annotators are added to the pipeline.
  val loaded = spark.read.parquet("src/test/resources/word-embedding/test-repeated-tokens")

  val embeddings = WordEmbeddingsModel
    .pretrained("glove_100d", "en")
    .setInputCols(Array("splitter", "token"))
    .setOutputCol("embedding")

  val pipeline = new Pipeline()
    .setStages(Array(embeddings))

  val model = pipeline.fit(loaded)

  val result = model.transform(loaded)

  // One row per embedding annotation, reduced to its `begin` offset, then
  // counted per offset across the entire DataFrame (not per input row).
  // NOTE(review): the threshold is `> 2`, so an offset seen exactly twice is
  // accepted — presumably the fixture legitimately yields up to two annotations
  // per offset; confirm against the fixture before tightening to `> 1`.
  val duplicateBegins = result
    .selectExpr("explode(embedding) as e")
    .select(col("e.begin").alias("begin"))
    .groupBy("begin")
    .count()
    .filter(col("count") > 2)
    .count()

  assert(
    duplicateBegins == 0,
    s"Found $duplicateBegins repeated tokens (duplicate begin positions)")
}
62+
3463 " Word Embeddings" should " correctly embed clinical words not embed non-existent words" taggedAs SlowTest in {
3564
3665 val notWords = spark.read
0 commit comments