Commit 5c7d0d7

[SPARKNLP-1309] Fix repeating tokens in WordEmbeddings
1 parent 6600098 commit 5c7d0d7

File tree

4 files changed: +33, -2 lines changed

src/main/scala/com/johnsnowlabs/nlp/annotators/common/TokenizedWithSentence.scala

Lines changed: 4 additions & 2 deletions
@@ -26,7 +26,6 @@ object TokenizedWithSentence extends Annotated[TokenizedSentence] {
     val tokens = annotations
       .filter(_.annotatorType == annotatorType)
       .toArray
-
     val sentences = SentenceSplit.unpack(annotations)
 
     /** // Evaluate whether to enable this validation to check proper usage of DOCUMENT and
@@ -37,7 +36,10 @@ object TokenizedWithSentence extends Annotated[TokenizedSentence] {
     sentences
       .map(sentence => {
         val sentenceTokens = tokens
-          .filter(token => token.begin >= sentence.start & token.end <= sentence.end)
+          .filter(token =>
+            token.begin >= sentence.start &&
+              token.end <= sentence.end &&
+              token.metadata.getOrElse("sentence", "0").toInt == sentence.index)
           .map(token => IndexedToken(token.result, token.begin, token.end))
         sentenceTokens
       })
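
For context, the effect of the new filter can be illustrated outside Spark NLP. The sketch below uses simplified stand-in case classes (Token and Sentence here are hypothetical, not the library's Annotation types) and assumes the tokenizer writes a "sentence" index into each token's metadata, which is what the added check relies on. It shows how a span-only filter attaches a token to every sentence annotation whose character range covers it, while the sentence-index check keeps each token with its own sentence.

// Minimal stand-alone sketch (hypothetical Token/Sentence case classes, not the
// Spark NLP Annotation types) showing why the extra metadata check matters when
// two sentence annotations cover the same character range.
object SentenceIndexFilterSketch extends App {

  case class Token(result: String, begin: Int, end: Int, metadata: Map[String, String])
  case class Sentence(start: Int, end: Int, index: Int)

  // Two sentence annotations spanning the same characters of "Hello world".
  val sentences = Seq(Sentence(0, 10, 0), Sentence(0, 10, 1))

  val tokens = Seq(
    Token("Hello", 0, 4, Map("sentence" -> "0")),
    Token("world", 6, 10, Map("sentence" -> "0")))

  // Old filter: span check only -- both sentences claim every token, so tokens repeat.
  val repeated = sentences.map(s => tokens.filter(t => t.begin >= s.start && t.end <= s.end))

  // New filter: span check plus sentence-index check -- each token stays with its sentence.
  val deduplicated = sentences.map(s =>
    tokens.filter(t =>
      t.begin >= s.start &&
        t.end <= s.end &&
        t.metadata.getOrElse("sentence", "0").toInt == s.index))

  println(repeated.map(_.size)) // List(2, 2): every token counted once per sentence
  println(deduplicated.map(_.size)) // List(2, 0): each token assigned exactly once
}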

src/test/scala/com/johnsnowlabs/nlp/embeddings/WordEmbeddingsTestSpec.scala

Lines changed: 29 additions & 0 deletions
@@ -23,6 +23,7 @@ import com.johnsnowlabs.nlp.{Annotation, AssertAnnotations}
 import com.johnsnowlabs.tags.{FastTest, SlowTest}
 import org.apache.spark.ml.{Pipeline, PipelineModel}
 import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.functions.col
 import org.scalatest.flatspec.AnyFlatSpec
 
 class WordEmbeddingsTestSpec extends AnyFlatSpec with SparkSessionTest {
@@ -31,6 +32,34 @@ class WordEmbeddingsTestSpec extends AnyFlatSpec with SparkSessionTest {
     .option("header", "true")
     .csv("src/test/resources/embeddings/clinical_words.txt")
 
+  "Word Embeddings" should "Should not repeat tokens" taggedAs FastTest in {
+
+    val loaded = spark.read.parquet("src/test/resources/word-embedding/test-repeated-tokens")
+
+    val embeddings = WordEmbeddingsModel
+      .pretrained("glove_100d", "en")
+      .setInputCols(Array("splitter", "token"))
+      .setOutputCol("embedding")
+
+    val pipeline = new Pipeline()
+      .setStages(Array(embeddings))
+
+    val model = pipeline.fit(loaded)
+
+    val result = model.transform(loaded)
+    val duplicateBegins = result
+      .selectExpr("explode(embedding) as e")
+      .select(col("e.begin").alias("begin"))
+      .groupBy("begin")
+      .count()
+      .filter(col("count") > 2)
+      .count()
+
+    assert(
+      duplicateBegins == 0,
+      s"Found $duplicateBegins repeated tokens (duplicate begin positions)")
+  }
+
   "Word Embeddings" should "correctly embed clinical words not embed non-existent words" taggedAs SlowTest in {
 
     val notWords = spark.read
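
As a usage note, the duplicate check in the new test can be reproduced against a toy DataFrame instead of the committed parquet fixture (which is binary and not shown in this diff). The sketch below is a stand-alone approximation: the toy schema is an assumption, reduced to positional struct fields rather than the full Spark NLP annotation schema, and it mirrors the test's count > 2 threshold.

// Stand-alone approximation of the test's duplicate-begin query against a toy
// DataFrame (the schema here is an assumption: positional struct fields instead
// of the full Spark NLP annotation schema).
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.col

object DuplicateBeginCheckSketch extends App {

  val spark = SparkSession.builder().master("local[1]").appName("dup-check").getOrCreate()
  import spark.implicits._

  // Each row holds an array of (begin, end, result) tuples, mimicking an annotation column;
  // the second row deliberately repeats the same token three times.
  val df = Seq(
    Seq((0, 4, "Hello"), (6, 10, "world")),
    Seq((0, 4, "Hello"), (0, 4, "Hello"), (0, 4, "Hello")))
    .toDF("embedding")

  // Same shape as the test: explode the array, group by begin offset, and count
  // how many begin offsets occur more than twice across the whole column.
  val duplicateBegins = df
    .selectExpr("explode(embedding) as e")
    .select(col("e._1").alias("begin"))
    .groupBy("begin")
    .count()
    .filter(col("count") > 2)
    .count()

  println(s"begin offsets repeated more than twice: $duplicateBegins") // 1 for this toy data
  spark.stop()
}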
