diff --git a/Cargo.lock b/Cargo.lock index 361cdc056..0078f18c7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -9054,6 +9054,7 @@ dependencies = [ "schemars 0.8.22", "serde", "serde_json", + "strsim", "thiserror 1.0.69", "toml 0.8.23", "tsify", diff --git a/crates/terraphim_agent/Cargo.toml b/crates/terraphim_agent/Cargo.toml index a67e4c82e..d8750adfe 100644 --- a/crates/terraphim_agent/Cargo.toml +++ b/crates/terraphim_agent/Cargo.toml @@ -79,7 +79,7 @@ terraphim_hooks = { path = "../terraphim_hooks", version = "1.0.0" } terraphim_tracker = { path = "../terraphim_tracker", version = "1.0.0" } terraphim_orchestrator = { path = "../terraphim_orchestrator", version = "1.0.0" } # Session search - uses workspace version (path for dev, version for crates.io) -terraphim_sessions = { path = "../terraphim_sessions", version = "1.6.0", optional = true, features = ["tsa-full", "aider-connector"] } +terraphim_sessions = { path = "../terraphim_sessions", version = "1.6.0", optional = true, features = ["tsa-full", "aider-connector", "search-index"] } [dev-dependencies] serial_test = "3.3" diff --git a/crates/terraphim_agent/src/repl/commands.rs b/crates/terraphim_agent/src/repl/commands.rs index 178c7d18d..aa2984646 100644 --- a/crates/terraphim_agent/src/repl/commands.rs +++ b/crates/terraphim_agent/src/repl/commands.rs @@ -191,6 +191,8 @@ pub enum SessionsSubcommand { Files { session_id: String, json: bool }, /// Find sessions by file path ByFile { file_path: String, json: bool }, + /// Build search index and show index statistics + Index { verbose: bool }, } #[cfg(feature = "firecracker")] @@ -1309,8 +1311,14 @@ impl FromStr for ReplCommand { subcommand: SessionsSubcommand::ByFile { file_path, json }, }) } + "index" => { + let verbose = parts.contains(&"--verbose") || parts.contains(&"-v"); + Ok(ReplCommand::Sessions { + subcommand: SessionsSubcommand::Index { verbose }, + }) + } _ => Err(anyhow!( - "Unknown sessions subcommand: {}. 
Use: sources, list, search, stats, show, concepts, related, timeline, export, enrich, files, by-file", + "Unknown sessions subcommand: {}. Use: sources, list, search, stats, show, concepts, related, timeline, export, enrich, files, by-file, index", parts[1] )), } diff --git a/crates/terraphim_agent/src/repl/handler.rs b/crates/terraphim_agent/src/repl/handler.rs index 4c7c819a7..890133b30 100644 --- a/crates/terraphim_agent/src/repl/handler.rs +++ b/crates/terraphim_agent/src/repl/handler.rs @@ -1904,49 +1904,105 @@ impl ReplHandler { } SessionsSubcommand::Search { query } => { - let sessions = svc.search(&query).await; + #[cfg(feature = "enrichment")] + { + let thesaurus = if let Some(ref tui_service) = self.service { + let role_name: terraphim_types::RoleName = self.current_role.clone().into(); + tui_service.get_thesaurus(&role_name).await.ok() + } else { + None + }; + let sessions = svc.search_with_thesaurus(&query, thesaurus).await; - if sessions.is_empty() { - println!("{} No sessions match '{}'", "ℹ".blue().bold(), query.cyan()); - return Ok(()); - } + if sessions.is_empty() { + println!("{} No sessions match '{}'", "ℹ".blue().bold(), query.cyan()); + return Ok(()); + } - println!( - "\n{} sessions match '{}':", - sessions.len().to_string().green(), - query.cyan() - ); - let mut table = Table::new(); - table - .load_preset(UTF8_FULL) - .apply_modifier(UTF8_ROUND_CORNERS) - .set_header(vec![ - Cell::new("ID").add_attribute(comfy_table::Attribute::Bold), - Cell::new("Source").add_attribute(comfy_table::Attribute::Bold), - Cell::new("Title").add_attribute(comfy_table::Attribute::Bold), - ]); + println!( + "\n{} sessions match '{}':", + sessions.len().to_string().green(), + query.cyan() + ); + let mut table = Table::new(); + table + .load_preset(UTF8_FULL) + .apply_modifier(UTF8_ROUND_CORNERS) + .set_header(vec![ + Cell::new("ID").add_attribute(comfy_table::Attribute::Bold), + Cell::new("Source").add_attribute(comfy_table::Attribute::Bold), + 
Cell::new("Title").add_attribute(comfy_table::Attribute::Bold), + ]); - for session in sessions.iter().take(10) { - let title = session - .title - .as_ref() - .map(|t| { - if t.len() > 50 { - format!("{}...", &t[..50]) - } else { - t.clone() - } - }) - .unwrap_or_else(|| "-".to_string()); + for session in sessions.iter().take(10) { + let title = session + .title + .as_ref() + .map(|t| { + if t.chars().count() > 50 { /* count chars, not bytes: byte-slicing a multi-byte title can panic */ + format!("{}...", t.chars().take(50).collect::<String>()) + } else { + t.clone() + } + }) + .unwrap_or_else(|| "-".to_string()); - table.add_row(vec![ - Cell::new(&session.external_id[..8.min(session.external_id.len())]), - Cell::new(&session.source), - Cell::new(title), - ]); + table.add_row(vec![ + Cell::new(&session.external_id[..8.min(session.external_id.len())]), + Cell::new(&session.source), + Cell::new(title), + ]); + } + + println!("{}", table); } - println!("{}", table); + #[cfg(not(feature = "enrichment"))] + { + let sessions = svc.search(&query).await; + + if sessions.is_empty() { + println!("{} No sessions match '{}'", "ℹ".blue().bold(), query.cyan()); + return Ok(()); + } + + println!( + "\n{} sessions match '{}':", + sessions.len().to_string().green(), + query.cyan() + ); + let mut table = Table::new(); + table + .load_preset(UTF8_FULL) + .apply_modifier(UTF8_ROUND_CORNERS) + .set_header(vec![ + Cell::new("ID").add_attribute(comfy_table::Attribute::Bold), + Cell::new("Source").add_attribute(comfy_table::Attribute::Bold), + Cell::new("Title").add_attribute(comfy_table::Attribute::Bold), + ]); + + for session in sessions.iter().take(10) { + let title = session + .title + .as_ref() + .map(|t| { + if t.chars().count() > 50 { /* count chars, not bytes: byte-slicing a multi-byte title can panic */ + format!("{}...", t.chars().take(50).collect::<String>()) + } else { + t.clone() + } + }) + .unwrap_or_else(|| "-".to_string()); + + table.add_row(vec![ + Cell::new(&session.external_id[..8.min(session.external_id.len())]), + Cell::new(&session.source), + Cell::new(title), + ]); + } + + println!("{}", table); + } } SessionsSubcommand::Stats => { @@ -2546,6 +2602,44 @@ impl ReplHandler { 
println!("{}", table); } } + + SessionsSubcommand::Index { verbose } => { + let sessions = svc.list_sessions().await; + let count = sessions.len(); + let total_messages: usize = sessions.iter().map(|s| s.message_count()).sum(); + + println!("\n{} Search Index Status", "🔍".bold()); + println!("{}", "─".repeat(40)); + println!(" Sessions indexed: {}", count.to_string().green()); + println!(" Total messages: {}", total_messages.to_string().green()); + println!(" Scorer: {}", "BM25 (Okapi)".cyan()); + + if verbose { + let sources: std::collections::HashMap<&str, usize> = { + let mut map = std::collections::HashMap::new(); + for s in &sessions { + *map.entry(s.source.as_str()).or_default() += 1; + } + map + }; + println!("\n {} By source:", "▸".blue()); + for (source, cnt) in sources { + println!(" {}: {}", source.magenta(), cnt); + } + + let total_chars: usize = sessions + .iter() + .map(|s| { + s.messages.iter().map(|m| m.content.len()).sum::<usize>() + + s.title.as_ref().map(|t| t.len()).unwrap_or(0) + }) + .sum(); + println!( + "\n Total text size: {} chars", + total_chars.to_string().yellow() + ); + } + } } Ok(()) diff --git a/crates/terraphim_service/src/score/bm25_additional_test.rs b/crates/terraphim_service/src/score/bm25_additional_test.rs deleted file mode 100644 index a32ce1e05..000000000 --- a/crates/terraphim_service/src/score/bm25_additional_test.rs +++ /dev/null @@ -1,676 +0,0 @@ -#[cfg(test)] -mod tests { - use super::super::bm25::{BM25FScorer, BM25PlusScorer}; - use super::super::bm25_additional::{OkapiBM25Scorer, TFIDFScorer, JaccardScorer, QueryRatioScorer}; - use terraphim_types::Document; - use std::collections::HashSet; - - // Test documents for all tests - fn get_test_documents() -> Vec<Document> { - vec![ - Document { - id: "1".to_string(), - url: "http://example.com/1".to_string(), - title: "Rust Programming Language".to_string(), - body: "Rust is a systems programming language focused on safety, speed, and concurrency.".to_string(), - description: Some("Learn 
about Rust programming".to_string()), - stub: None, - tags: Some(vec!["programming".to_string(), "systems".to_string()]), - rank: None, - }, - Document { - id: "2".to_string(), - url: "http://example.com/2".to_string(), - title: "Python Programming Tutorial".to_string(), - body: "Python is a high-level programming language known for its readability.".to_string(), - description: Some("Learn Python programming".to_string()), - stub: None, - tags: Some(vec!["programming".to_string(), "tutorial".to_string()]), - rank: None, - }, - Document { - id: "3".to_string(), - url: "http://example.com/3".to_string(), - title: "JavaScript for Web Development".to_string(), - body: "JavaScript is a scripting language that enables interactive web pages.".to_string(), - description: Some("Learn JavaScript for web development".to_string()), - stub: None, - tags: Some(vec!["programming".to_string(), "web".to_string()]), - rank: None, - }, - ] - } - - #[test] - fn test_compare_bm25plus_with_okapi_bm25() { - let documents = get_test_documents(); - - // Initialize BM25+ scorer - let mut bm25plus_scorer = BM25PlusScorer::new(); - bm25plus_scorer.initialize(&documents); - - // Initialize Okapi BM25 scorer - let mut okapi_bm25_scorer = OkapiBM25Scorer::new(); - okapi_bm25_scorer.initialize(&documents); - - // Test queries - let queries = vec![ - "rust programming", - "python tutorial", - "javascript web", - "programming language", - ]; - - for query in queries { - println!("Query: {}", query); - - // Score documents with BM25+ - let mut bm25plus_scores: Vec<(String, f64)> = documents.iter() - .map(|doc| { - let score = bm25plus_scorer.score(query, doc); - (doc.id.clone(), score) - }) - .collect(); - - // Sort by score in descending order - bm25plus_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)); - - // Score documents with Okapi BM25 - let mut okapi_bm25_scores: Vec<(String, f64)> = documents.iter() - .map(|doc| { - let score = okapi_bm25_scorer.score(query, 
doc); - (doc.id.clone(), score) - }) - .collect(); - - // Sort by score in descending order - okapi_bm25_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)); - - println!("BM25+ ranking: {:?}", bm25plus_scores); - println!("Okapi BM25 ranking: {:?}", okapi_bm25_scores); - - // Check if the top document is the same for both scorers - // This is a basic validation that the scorers are producing similar results - assert_eq!( - bm25plus_scores.first().unwrap().0, - okapi_bm25_scores.first().unwrap().0, - "Top document should be the same for BM25+ and Okapi BM25 for query: {}", - query - ); - } - } - - #[test] - fn test_compare_bm25f_with_tfidf() { - let documents = get_test_documents(); - - // Initialize BM25F scorer - let mut bm25f_scorer = BM25FScorer::new(); - bm25f_scorer.initialize(&documents); - - // Initialize TFIDF scorer - let mut tfidf_scorer = TFIDFScorer::new(); - tfidf_scorer.initialize(&documents); - - // Test queries - let queries = vec![ - "rust programming", - "python tutorial", - "javascript web", - "programming language", - ]; - - for query in queries { - println!("Query: {}", query); - - // Score documents with BM25F - let mut bm25f_scores: Vec<(String, f64)> = documents.iter() - .map(|doc| { - let score = bm25f_scorer.score(query, doc); - (doc.id.clone(), score) - }) - .collect(); - - // Sort by score in descending order - bm25f_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)); - - // Score documents with TFIDF - let mut tfidf_scores: Vec<(String, f64)> = documents.iter() - .map(|doc| { - let score = tfidf_scorer.score(query, doc); - (doc.id.clone(), score) - }) - .collect(); - - // Sort by score in descending order - tfidf_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)); - - println!("BM25F ranking: {:?}", bm25f_scores); - println!("TFIDF ranking: {:?}", tfidf_scores); - - // We don't assert equality here because BM25F and TFIDF can produce different 
rankings - // Instead, we just print the rankings for manual inspection - } - } - - #[test] - fn test_jaccard_scorer() { - let documents = get_test_documents(); - - // Initialize Jaccard scorer - let mut jaccard_scorer = JaccardScorer::new(); - jaccard_scorer.initialize(&documents); - - // Test queries - let queries = vec![ - "rust programming", - "python tutorial", - "javascript web", - "programming language", - ]; - - for query in queries { - println!("Query: {}", query); - - // Score documents with Jaccard - let mut jaccard_scores: Vec<(String, f64)> = documents.iter() - .map(|doc| { - let score = jaccard_scorer.score(query, doc); - (doc.id.clone(), score) - }) - .collect(); - - // Sort by score in descending order - jaccard_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)); - - println!("Jaccard ranking: {:?}", jaccard_scores); - - // Verify that scores are between 0 and 1 - for (_, score) in &jaccard_scores { - assert!(*score >= 0.0 && *score <= 1.0, "Jaccard score should be between 0 and 1"); - } - - // Verify that the top document contains at least one of the query terms - let top_doc_id = jaccard_scores.first().unwrap().0.clone(); - let top_doc = documents.iter().find(|doc| doc.id == top_doc_id).unwrap(); - - let query_terms: Vec<&str> = query.split_whitespace().collect(); - let doc_contains_query_term = query_terms.iter().any(|term| { - top_doc.body.to_lowercase().contains(&term.to_lowercase()) || - top_doc.title.to_lowercase().contains(&term.to_lowercase()) - }); - - assert!(doc_contains_query_term, "Top document should contain at least one query term"); - } - } - - #[test] - fn test_query_ratio_scorer() { - let documents = get_test_documents(); - - // Initialize QueryRatio scorer - let mut query_ratio_scorer = QueryRatioScorer::new(); - query_ratio_scorer.initialize(&documents); - - // Test queries - let queries = vec![ - "rust programming", - "python tutorial", - "javascript web", - "programming language", - ]; - - for 
query in queries { - println!("Query: {}", query); - - // Score documents with QueryRatio - let mut query_ratio_scores: Vec<(String, f64)> = documents.iter() - .map(|doc| { - let score = query_ratio_scorer.score(query, doc); - (doc.id.clone(), score) - }) - .collect(); - - // Sort by score in descending order - query_ratio_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)); - - println!("QueryRatio ranking: {:?}", query_ratio_scores); - - // Verify that scores are between 0 and 1 - for (_, score) in &query_ratio_scores { - assert!(*score >= 0.0 && *score <= 1.0, "QueryRatio score should be between 0 and 1"); - } - - // Verify that the top document contains at least one of the query terms - let top_doc_id = query_ratio_scores.first().unwrap().0.clone(); - let top_doc = documents.iter().find(|doc| doc.id == top_doc_id).unwrap(); - - let query_terms: Vec<&str> = query.split_whitespace().collect(); - let doc_contains_query_term = query_terms.iter().any(|term| { - top_doc.body.to_lowercase().contains(&term.to_lowercase()) || - top_doc.title.to_lowercase().contains(&term.to_lowercase()) - }); - - assert!(doc_contains_query_term, "Top document should contain at least one query term"); - } - } - - #[test] - fn test_all_scorers_with_same_query() { - let documents = get_test_documents(); - let query = "programming language"; - - // Initialize all scorers - let mut bm25f_scorer = BM25FScorer::new(); - bm25f_scorer.initialize(&documents); - - let mut bm25plus_scorer = BM25PlusScorer::new(); - bm25plus_scorer.initialize(&documents); - - let mut okapi_bm25_scorer = OkapiBM25Scorer::new(); - okapi_bm25_scorer.initialize(&documents); - - let mut tfidf_scorer = TFIDFScorer::new(); - tfidf_scorer.initialize(&documents); - - let mut jaccard_scorer = JaccardScorer::new(); - jaccard_scorer.initialize(&documents); - - let mut query_ratio_scorer = QueryRatioScorer::new(); - query_ratio_scorer.initialize(&documents); - - // Score documents with all scorers - 
let mut bm25f_scores: Vec<(String, f64)> = documents.iter() - .map(|doc| { - let score = bm25f_scorer.score(query, doc); - (doc.id.clone(), score) - }) - .collect(); - - let mut bm25plus_scores: Vec<(String, f64)> = documents.iter() - .map(|doc| { - let score = bm25plus_scorer.score(query, doc); - (doc.id.clone(), score) - }) - .collect(); - - let mut okapi_bm25_scores: Vec<(String, f64)> = documents.iter() - .map(|doc| { - let score = okapi_bm25_scorer.score(query, doc); - (doc.id.clone(), score) - }) - .collect(); - - let mut tfidf_scores: Vec<(String, f64)> = documents.iter() - .map(|doc| { - let score = tfidf_scorer.score(query, doc); - (doc.id.clone(), score) - }) - .collect(); - - let mut jaccard_scores: Vec<(String, f64)> = documents.iter() - .map(|doc| { - let score = jaccard_scorer.score(query, doc); - (doc.id.clone(), score) - }) - .collect(); - - let mut query_ratio_scores: Vec<(String, f64)> = documents.iter() - .map(|doc| { - let score = query_ratio_scorer.score(query, doc); - (doc.id.clone(), score) - }) - .collect(); - - // Sort all scores by score in descending order - bm25f_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)); - bm25plus_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)); - okapi_bm25_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)); - tfidf_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)); - jaccard_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)); - query_ratio_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)); - - // Print all rankings - println!("Query: {}", query); - println!("BM25F ranking: {:?}", bm25f_scores); - println!("BM25+ ranking: {:?}", bm25plus_scores); - println!("Okapi BM25 ranking: {:?}", okapi_bm25_scores); - println!("TFIDF ranking: {:?}", tfidf_scores); - println!("Jaccard ranking: {:?}", jaccard_scores); - 
println!("QueryRatio ranking: {:?}", query_ratio_scores); - - // Verify that all scorers return non-zero scores for documents containing query terms - for doc in &documents { - if doc.body.to_lowercase().contains("programming") || - doc.title.to_lowercase().contains("programming") || - doc.body.to_lowercase().contains("language") || - doc.title.to_lowercase().contains("language") { - - let bm25f_score = bm25f_scores.iter().find(|(id, _)| id == &doc.id).unwrap().1; - let bm25plus_score = bm25plus_scores.iter().find(|(id, _)| id == &doc.id).unwrap().1; - let okapi_bm25_score = okapi_bm25_scores.iter().find(|(id, _)| id == &doc.id).unwrap().1; - let tfidf_score = tfidf_scores.iter().find(|(id, _)| id == &doc.id).unwrap().1; - let jaccard_score = jaccard_scores.iter().find(|(id, _)| id == &doc.id).unwrap().1; - let query_ratio_score = query_ratio_scores.iter().find(|(id, _)| id == &doc.id).unwrap().1; - - // Check if the document contains both terms or just one term - let contains_both_terms = (doc.body.to_lowercase().contains("programming") || - doc.title.to_lowercase().contains("programming")) && - (doc.body.to_lowercase().contains("language") || - doc.title.to_lowercase().contains("language")); - - // For documents containing both terms, all scorers should return positive scores - if contains_both_terms { - assert!(bm25f_score > 0.0, "BM25F score should be positive for document containing both query terms"); - assert!(bm25plus_score > 0.0, "BM25+ score should be positive for document containing both query terms"); - assert!(okapi_bm25_score > 0.0, "Okapi BM25 score should be positive for document containing both query terms"); - assert!(tfidf_score > 0.0, "TFIDF score should be positive for document containing both query terms"); - assert!(jaccard_score > 0.0, "Jaccard score should be positive for document containing both query terms"); - assert!(query_ratio_score > 0.0, "QueryRatio score should be positive for document containing both query terms"); - } else { - // 
For documents containing only one term, some scorers might return zero scores - // depending on their implementation, so we don't assert anything here - println!("Document {} contains only one query term", doc.id); - println!("BM25F score: {}", bm25f_score); - println!("BM25+ score: {}", bm25plus_score); - println!("Okapi BM25 score: {}", okapi_bm25_score); - println!("TFIDF score: {}", tfidf_score); - println!("Jaccard score: {}", jaccard_score); - println!("QueryRatio score: {}", query_ratio_score); - } - } - } - } - - #[test] - fn test_validate_jaccard_similarity() { - // Create test documents with predictable term overlap - let documents = vec![ - Document { - id: "doc1".to_string(), - url: "http://example.com/1".to_string(), - title: "apple banana cherry".to_string(), - body: "apple banana cherry date".to_string(), - description: None, - summarization: None, - stub: None, - tags: None, - rank: None, - }, - Document { - id: "doc2".to_string(), - url: "http://example.com/2".to_string(), - title: "apple banana".to_string(), - body: "apple banana elderberry".to_string(), - description: None, - summarization: None, - stub: None, - tags: None, - rank: None, - }, - Document { - id: "doc3".to_string(), - url: "http://example.com/3".to_string(), - title: "cherry date".to_string(), - body: "cherry date fig".to_string(), - description: None, - summarization: None, - stub: None, - tags: None, - rank: None, - }, - ]; - - // Initialize Jaccard scorer - let mut jaccard_scorer = JaccardScorer::new(); - jaccard_scorer.initialize(&documents); - - // Test with query "apple banana" - let query = "apple banana"; - let scores: Vec<(String, f64)> = documents.iter() - .map(|doc| { - let score = jaccard_scorer.score(query, doc); - (doc.id.clone(), score) - }) - .collect(); - - // Calculate expected scores manually - // For doc1: intersection = 2 (apple, banana), union = 4 (apple, banana, cherry, date) => 2/4 = 0.5 - // For doc2: intersection = 2 (apple, banana), union = 3 (apple, 
banana, elderberry) => 2/3 = 0.67 - // For doc3: intersection = 0, union = 5 (apple, banana, cherry, date, fig) => 0/5 = 0 - - println!("Query: {}", query); - println!("Jaccard scores: {:?}", scores); - - // Verify scores are within expected ranges - assert!(scores[0].1 >= 0.45 && scores[0].1 <= 0.55, "Doc1 score should be around 0.5"); - assert!(scores[1].1 >= 0.6 && scores[1].1 <= 0.7, "Doc2 score should be around 0.67"); - assert_eq!(scores[2].1, 0.0, "Doc3 score should be 0"); - - // Verify ranking order - let mut ranked_scores = scores.clone(); - ranked_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)); - assert_eq!(ranked_scores[0].0, "doc2", "Doc2 should be ranked first"); - assert_eq!(ranked_scores[1].0, "doc1", "Doc1 should be ranked second"); - assert_eq!(ranked_scores[2].0, "doc3", "Doc3 should be ranked third"); - } - - #[test] - fn test_compare_jaccard_with_other_measures() { - let documents = get_test_documents(); // Use existing test documents - - // Initialize scorers - let mut jaccard_scorer = JaccardScorer::new(); - let mut query_ratio_scorer = QueryRatioScorer::new(); - let mut tfidf_scorer = TFIDFScorer::new(); - - jaccard_scorer.initialize(&documents); - query_ratio_scorer.initialize(&documents); - tfidf_scorer.initialize(&documents); - - // Test queries with different characteristics - let queries = vec![ - "rare unique terms", // Query with rare terms - "common frequent words", // Query with common terms - "programming language", // Query with terms in the documents - ]; - - for query in queries { - println!("\nQuery: {}", query); - - // Score with Jaccard - let jaccard_scores: Vec<(String, f64)> = documents.iter() - .map(|doc| { - let score = jaccard_scorer.score(query, doc); - (doc.id.clone(), score) - }) - .collect(); - - // Score with QueryRatio - let query_ratio_scores: Vec<(String, f64)> = documents.iter() - .map(|doc| { - let score = query_ratio_scorer.score(query, doc); - (doc.id.clone(), score) - }) - 
.collect(); - - // Score with TFIDF - let tfidf_scores: Vec<(String, f64)> = documents.iter() - .map(|doc| { - let score = tfidf_scorer.score(query, doc); - (doc.id.clone(), score) - }) - .collect(); - - println!("Jaccard scores: {:?}", jaccard_scores); - println!("QueryRatio scores: {:?}", query_ratio_scores); - println!("TFIDF scores: {:?}", tfidf_scores); - - // Verify Jaccard scores are between 0 and 1 - for (_, score) in &jaccard_scores { - assert!(*score >= 0.0 && *score <= 1.0, "Jaccard score should be between 0 and 1"); - } - } - } - - #[test] - fn test_jaccard_edge_cases() { - let documents = vec![ - Document { - id: "empty".to_string(), - url: "http://example.com/empty".to_string(), - title: "".to_string(), - body: "".to_string(), - description: None, - summarization: None, - stub: None, - tags: None, - rank: None, - }, - Document { - id: "identical".to_string(), - url: "http://example.com/identical".to_string(), - title: "test query".to_string(), - body: "test query".to_string(), - description: None, - summarization: None, - stub: None, - tags: None, - rank: None, - }, - Document { - id: "no_overlap".to_string(), - url: "http://example.com/no_overlap".to_string(), - title: "completely different content".to_string(), - body: "absolutely no overlap with search terms".to_string(), - description: None, - summarization: None, - stub: None, - tags: None, - rank: None, - }, - ]; - - let mut jaccard_scorer = JaccardScorer::new(); - jaccard_scorer.initialize(&documents); - - // Test with empty query - let empty_query = ""; - let empty_query_scores: Vec<(String, f64)> = documents.iter() - .map(|doc| { - let score = jaccard_scorer.score(empty_query, doc); - (doc.id.clone(), score) - }) - .collect(); - println!("Empty query scores: {:?}", empty_query_scores); - - // Test with exact match query - let exact_query = "test query"; - let exact_query_scores: Vec<(String, f64)> = documents.iter() - .map(|doc| { - let score = jaccard_scorer.score(exact_query, doc); - 
(doc.id.clone(), score) - }) - .collect(); - println!("Exact match query scores: {:?}", exact_query_scores); - - // Verify empty query returns 0 for all documents - for (_, score) in &empty_query_scores { - assert_eq!(*score, 0.0, "Empty query should return 0 score"); - } - - // Verify exact match returns 1.0 for identical document - let identical_score = exact_query_scores.iter() - .find(|(id, _)| id == "identical") - .unwrap().1; - assert!(identical_score > 0.9, "Identical document should have score close to 1.0"); - - // Verify no overlap has low score (not necessarily 0 due to how Jaccard works with term sets) - let no_overlap_score = exact_query_scores.iter() - .find(|(id, _)| id == "no_overlap") - .unwrap().1; - assert_eq!(no_overlap_score, 0.0, "Document with no overlapping terms should have a score of 0"); - - // Debug the intersection calculation - let query_terms: Vec<String> = exact_query.split_whitespace() - .map(|s| s.to_lowercase()) - .collect(); - let no_overlap_doc = documents.iter().find(|doc| doc.id == "no_overlap").unwrap(); - let doc_terms: Vec<String> = no_overlap_doc.body.split_whitespace() - .map(|s| s.to_lowercase()) - .collect(); - - println!("Query terms: {:?}", query_terms); - println!("Document terms: {:?}", doc_terms); - - let query_set: std::collections::HashSet<String> = query_terms.into_iter().collect(); - let doc_set: std::collections::HashSet<String> = doc_terms.into_iter().collect(); - - println!("Query set: {:?}", query_set); - println!("Document set: {:?}", doc_set); - - let intersection: std::collections::HashSet<_> = query_set.intersection(&doc_set).cloned().collect(); - println!("Intersection: {:?}", intersection); - - assert_eq!(intersection.len(), 0, "Intersection should be 0 for document with no overlap"); - } - - #[test] - fn test_visualize_jaccard_similarity() { - let documents = get_test_documents(); - let mut jaccard_scorer = JaccardScorer::new(); - jaccard_scorer.initialize(&documents); - - let query = "programming language"; - - // Score documents 
- let scores: Vec<(String, f64)> = documents.iter() - .map(|doc| { - // Calculate term sets - let query_terms: HashSet<String> = query.split_whitespace() - .map(|s| s.to_lowercase()) - .collect(); - - let doc_terms: HashSet<String> = doc.body.split_whitespace() - .map(|s| s.to_lowercase()) - .collect(); - - // Calculate intersection and union - let intersection = query_terms.intersection(&doc_terms).count(); - let union = query_terms.len() + doc_terms.len() - intersection; - - // Calculate Jaccard score - let score = if union > 0 { - intersection as f64 / union as f64 - } else { - 0.0 - }; - - println!("Document: {}", doc.id); - println!(" Query terms: {:?}", query_terms); - println!(" Doc terms: {:?}", doc_terms); - println!(" Intersection: {}", intersection); - println!(" Union: {}", union); - println!(" Jaccard score: {:.4}", score); - println!(); - - // Compare with the scorer's result - let scorer_result = jaccard_scorer.score(query, doc); - println!(" Scorer result: {:.4}", scorer_result); - - // They should be close (allowing for minor differences in implementation) - assert!((score - scorer_result).abs() < 0.1, - "Manual calculation ({}) should match scorer result ({})", - score, scorer_result); - - (doc.id.clone(), score) - }) - .collect(); - - // Sort by score - let mut ranked_scores = scores.clone(); - ranked_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)); - - println!("Final ranking: {:?}", ranked_scores); - } -} diff --git a/crates/terraphim_service/src/score/bm25_test.rs b/crates/terraphim_service/src/score/bm25_test.rs deleted file mode 100644 index 4523bdfab..000000000 --- a/crates/terraphim_service/src/score/bm25_test.rs +++ /dev/null @@ -1,291 +0,0 @@ -use crate::score::bm25::{BM25FScorer, BM25PlusScorer}; -use crate::score::bm25_additional::{ - JaccardScorer, OkapiBM25Scorer, QueryRatioScorer, TFIDFScorer, -}; -use crate::score::common::{BM25Params, FieldWeights}; -use terraphim_types::{Document, DocumentType}; - -fn 
create_test_documents() -> Vec<Document> { - vec![ - Document { - id: "doc1".to_string(), - title: "Introduction to Rust Programming".to_string(), - body: "Rust is a systems programming language that runs blazingly fast, prevents segfaults, and guarantees thread safety.".to_string(), - description: Some("A comprehensive guide to Rust programming language".to_string()), - summarization: None, - tags: Some(vec!["programming".to_string(), "rust".to_string(), "systems".to_string()]), - rank: None, - stub: None, - url: "https://example.com/doc1".to_string(), - source_haystack: None, - doc_type: DocumentType::KgEntry, - synonyms: None, - route: None, - priority: None, - }, - Document { - id: "doc2".to_string(), - title: "Advanced Rust Concepts".to_string(), - body: "This document covers advanced Rust concepts including ownership, borrowing, and lifetimes.".to_string(), - description: Some("Deep dive into advanced Rust programming concepts".to_string()), - summarization: None, - tags: Some(vec!["rust".to_string(), "advanced".to_string(), "ownership".to_string()]), - rank: None, - stub: None, - url: "https://example.com/doc2".to_string(), - source_haystack: None, - doc_type: DocumentType::KgEntry, - synonyms: None, - route: None, - priority: None, - }, - Document { - id: "doc3".to_string(), - title: "Systems Programming with Rust".to_string(), - body: "Systems programming requires careful memory management and performance optimization.".to_string(), - description: Some("Guide to systems programming using Rust".to_string()), - summarization: None, - tags: Some(vec!["systems".to_string(), "programming".to_string(), "performance".to_string()]), - rank: None, - stub: None, - url: "https://example.com/doc3".to_string(), - source_haystack: None, - doc_type: DocumentType::KgEntry, - synonyms: None, - route: None, - priority: None, - }, - ] -} - -#[test] -fn test_bm25_scorer_basic_functionality() { - let documents = create_test_documents(); - let mut bm25_scorer = OkapiBM25Scorer::new(); - 
bm25_scorer.initialize(&documents); - - let query = "rust programming"; - let scores: Vec<f64> = documents - .iter() - .map(|doc| bm25_scorer.score(query, doc)) - .collect(); - - // All documents should have positive scores for "rust programming" - assert!(scores.iter().all(|&score| score >= 0.0)); - - // Document 1 should have the highest score as it contains both "rust" and "programming" - assert!(scores[0] > scores[2]); - assert!(scores[0] > scores[1]); -} - -#[test] -fn test_bm25f_scorer_field_weights() { - let documents = create_test_documents(); - let weights = FieldWeights { - title: 2.0, - body: 1.0, - description: 1.5, - tags: 0.5, - }; - let params = BM25Params { - k1: 1.2, - b: 0.75, - delta: 1.0, - }; - - let mut bm25f_scorer = BM25FScorer::with_params(params, weights); - bm25f_scorer.initialize(&documents); - - let query = "rust"; - let scores: Vec<f64> = documents - .iter() - .map(|doc| bm25f_scorer.score(query, doc)) - .collect(); - - // All documents should have positive scores - assert!(scores.iter().all(|&score| score >= 0.0)); - - // Document with "rust" in title should score higher than document with "rust" only in body - assert!(scores[0] > scores[2]); -} - -#[test] -fn test_bm25plus_scorer_enhanced_parameters() { - let documents = create_test_documents(); - let params = BM25Params { - k1: 1.5, - b: 0.8, - delta: 1.2, - }; - - let mut bm25plus_scorer = BM25PlusScorer::with_params(params); - bm25plus_scorer.initialize(&documents); - - let query = "systems programming"; - let scores: Vec<f64> = documents - .iter() - .map(|doc| bm25plus_scorer.score(query, doc)) - .collect(); - - // All documents should have positive scores - assert!(scores.iter().all(|&score| score >= 0.0)); - - // Document 3 should have the highest score as it contains both "systems" and "programming" - assert!(scores[2] > scores[0]); - assert!(scores[2] > scores[1]); -} - -#[test] -fn test_tfidf_scorer_traditional_approach() { - let documents = create_test_documents(); - let mut tfidf_scorer 
= TFIDFScorer::new(); - tfidf_scorer.initialize(&documents); - - let query = "rust"; - let scores: Vec = documents - .iter() - .map(|doc| tfidf_scorer.score(query, doc)) - .collect(); - - // All documents should have positive scores - assert!(scores.iter().all(|&score| score >= 0.0)); - - // Documents with "rust" should have higher scores than those without - assert!(scores[0] > 0.0); - assert!(scores[1] > 0.0); -} - -#[test] -fn test_jaccard_scorer_similarity_based() { - let documents = create_test_documents(); - let mut jaccard_scorer = JaccardScorer::new(); - jaccard_scorer.initialize(&documents); - - let query = "rust programming"; - let scores: Vec = documents - .iter() - .map(|doc| jaccard_scorer.score(query, doc)) - .collect(); - - // All scores should be between 0.0 and 1.0 (Jaccard similarity) - assert!(scores.iter().all(|&score| (0.0..=1.0).contains(&score))); - - // Document 1 should have the highest similarity as it contains both terms - assert!(scores[0] > scores[2]); -} - -#[test] -fn test_query_ratio_scorer_term_matching() { - let documents = create_test_documents(); - let mut query_ratio_scorer = QueryRatioScorer::new(); - query_ratio_scorer.initialize(&documents); - - let query = "rust systems"; - let scores: Vec = documents - .iter() - .map(|doc| query_ratio_scorer.score(query, doc)) - .collect(); - - // All scores should be between 0.0 and 1.0 (ratio of matched terms) - assert!(scores.iter().all(|&score| (0.0..=1.0).contains(&score))); - - // Document 1 should have the highest ratio as it contains both "rust" and "systems" - assert!(scores[0] > scores[1]); - assert!(scores[0] > scores[2]); -} - -#[test] -fn test_scorer_initialization_with_empty_documents() { - let empty_documents: Vec = vec![]; - - let mut bm25_scorer = OkapiBM25Scorer::new(); - bm25_scorer.initialize(&empty_documents); - - let mut bm25f_scorer = BM25FScorer::new(); - bm25f_scorer.initialize(&empty_documents); - - let mut bm25plus_scorer = BM25PlusScorer::new(); - 
bm25plus_scorer.initialize(&empty_documents); - - // Should not panic with empty documents - // Note: We can't access private fields, so we just verify initialization doesn't panic - // Test passes by not panicking during initialization -} - -#[test] -fn test_scorer_empty_query_handling() { - let documents = create_test_documents(); - let mut bm25_scorer = OkapiBM25Scorer::new(); - bm25_scorer.initialize(&documents); - - let empty_query = ""; - let scores: Vec = documents - .iter() - .map(|doc| bm25_scorer.score(empty_query, doc)) - .collect(); - - // Empty query should return 0.0 scores - assert!(scores.iter().all(|&score| score == 0.0)); -} - -#[test] -fn test_scorer_case_insensitive_matching() { - let documents = create_test_documents(); - let mut bm25_scorer = OkapiBM25Scorer::new(); - bm25_scorer.initialize(&documents); - - let query_lower = "rust programming"; - let query_upper = "RUST PROGRAMMING"; - - let scores_lower: Vec = documents - .iter() - .map(|doc| bm25_scorer.score(query_lower, doc)) - .collect(); - - let scores_upper: Vec = documents - .iter() - .map(|doc| bm25_scorer.score(query_upper, doc)) - .collect(); - - // Case should not affect scores significantly - for (lower, upper) in scores_lower.iter().zip(scores_upper.iter()) { - assert!((lower - upper).abs() < 0.001); - } -} - -#[test] -fn test_scorer_parameter_sensitivity() { - let documents = create_test_documents(); - let query = "rust programming"; - - // Test different k1 values - let params_low_k1 = BM25Params { - k1: 0.5, - b: 0.75, - delta: 1.0, - }; - let params_high_k1 = BM25Params { - k1: 2.0, - b: 0.75, - delta: 1.0, - }; - - let mut scorer_low = BM25FScorer::with_params(params_low_k1, FieldWeights::default()); - let mut scorer_high = BM25FScorer::with_params(params_high_k1, FieldWeights::default()); - - scorer_low.initialize(&documents); - scorer_high.initialize(&documents); - - let scores_low: Vec = documents - .iter() - .map(|doc| scorer_low.score(query, doc)) - .collect(); - - let 
scores_high: Vec = documents - .iter() - .map(|doc| scorer_high.score(query, doc)) - .collect(); - - // Different k1 values should produce different scores - assert_ne!(scores_low, scores_high); -} diff --git a/crates/terraphim_service/src/score/bm25_test_dataset.rs b/crates/terraphim_service/src/score/bm25_test_dataset.rs deleted file mode 100644 index 991bf98a7..000000000 --- a/crates/terraphim_service/src/score/bm25_test_dataset.rs +++ /dev/null @@ -1,551 +0,0 @@ -use std::collections::HashMap; -use std::fs::File; -use std::io::Read; -use std::path::Path; - -use serde::{Deserialize, Serialize}; -use terraphim_types::Document; - -use super::bm25::{BM25FScorer, BM25PlusScorer, BM25Params, FieldWeights}; - -/// Test document structure from the test dataset -#[derive(Debug, Deserialize, Serialize)] -struct TestDocument { - id: String, - url: String, - title: String, - body: String, - description: Option, - tags: Option>, - rank: Option, -} - -/// Test query structure from the test dataset -#[derive(Debug, Deserialize, Serialize)] -struct TestQuery { - id: String, - query: String, - expected_results: Option>>, - description: String, -} - -/// Test dataset structure -#[derive(Debug, Deserialize, Serialize)] -struct TestDataset { - documents: Vec, -} - -/// Queries dataset structure -#[derive(Debug, Deserialize, Serialize)] -struct QueriesDataset { - queries: Vec, -} - -/// Convert a test document to a terraphim document -fn convert_test_document(doc: &TestDocument) -> Document { - Document { - id: doc.id.clone(), - url: doc.url.clone(), - title: doc.title.clone(), - body: doc.body.clone(), - description: doc.description.clone(), - stub: None, - tags: doc.tags.clone(), - rank: doc.rank, - } -} - -/// Load test documents from a JSON file -fn load_test_documents(file_path: &str) -> Vec { - let path = Path::new(file_path); - let mut file = File::open(path).expect("Failed to open test dataset file"); - let mut contents = String::new(); - file.read_to_string(&mut 
contents).expect("Failed to read test dataset file"); - - let dataset: TestDataset = serde_json::from_str(&contents).expect("Failed to parse test dataset"); - - dataset.documents.iter() - .map(|doc| convert_test_document(doc)) - .collect() -} - -/// Load test queries from a JSON file -fn load_test_queries(file_path: &str) -> Vec { - let path = Path::new(file_path); - let mut file = File::open(path).expect("Failed to open queries file"); - let mut contents = String::new(); - file.read_to_string(&mut contents).expect("Failed to read queries file"); - - let dataset: QueriesDataset = serde_json::from_str(&contents).expect("Failed to parse queries dataset"); - - dataset.queries -} - -/// Score documents using BM25F and return them sorted by score -fn score_documents_bm25f(query: &str, documents: &[Document]) -> Vec<(String, f64)> { - let mut scorer = BM25FScorer::new(); - scorer.initialize(documents); - - let mut scored_docs: Vec<(String, f64)> = documents.iter() - .map(|doc| { - let score = scorer.score(query, doc); - (doc.id.clone(), score) - }) - .collect(); - - // Sort by score in descending order - scored_docs.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)); - - scored_docs -} - -/// Score documents using BM25Plus and return them sorted by score -fn score_documents_bm25plus(query: &str, documents: &[Document]) -> Vec<(String, f64)> { - let mut scorer = BM25PlusScorer::new(); - scorer.initialize(documents); - - let mut scored_docs: Vec<(String, f64)> = documents.iter() - .map(|doc| { - let score = scorer.score(query, doc); - (doc.id.clone(), score) - }) - .collect(); - - // Sort by score in descending order - scored_docs.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)); - - scored_docs -} - -/// Compare the ranking of documents with the expected ranking -fn compare_rankings(actual: &[(String, f64)], expected: &[String]) -> bool { - if actual.len() < expected.len() { - return false; - } - - // Check if all expected 
documents are in the top results - // (not necessarily in the exact same order) - let actual_top_n: Vec<&String> = actual.iter() - .take(expected.len()) - .map(|(id, _)| id) - .collect(); - - for expected_id in expected { - if !actual_top_n.contains(&expected_id) { - return false; - } - } - - true -} - -#[cfg(test)] -mod tests { - use super::*; - - // Helper function to get the document path - fn get_document_path() -> String { - let base_path = env!("CARGO_MANIFEST_DIR"); - format!("{}/../../docs/en/test_data/bm25_test_dataset/documents.json", base_path) - } - - // Helper function to get the queries path - fn get_queries_path() -> String { - let base_path = env!("CARGO_MANIFEST_DIR"); - format!("{}/../../docs/en/test_data/bm25_test_dataset/queries.json", base_path) - } - - #[test] - fn test_bm25_key_characteristics() { - // Test key characteristics of BM25F and BM25Plus instead of exact rankings - - // 1. Test that BM25F gives more weight to matches in title fields - test_field_weighting_in_bm25f(); - - // 2. Test that BM25+ handles rare terms better - test_rare_terms_in_bm25plus(); - - // 3. Test that both algorithms normalize document length - test_document_length_normalization(); - - // 4. Test that both algorithms handle term frequency saturation - test_term_frequency_saturation(); - - // 5. 
Test that both algorithms rank relevant documents higher - let documents = load_test_documents(&get_document_path()); - - // Test BM25F on a simple query - { - let query = "rust programming"; - let scored_docs = score_documents_bm25f(query, &documents); - let top_docs: Vec<&String> = scored_docs.iter().take(2).map(|(id, _)| id).collect(); - - println!("BM25F top docs for 'rust programming': {:?}", top_docs); - - // Check that doc1 and doc5 are in the top results (they're about Rust) - assert!( - top_docs.contains(&&"doc1".to_string()) && top_docs.contains(&&"doc5".to_string()), - "BM25F should rank doc1 and doc5 in the top results for 'rust programming'" - ); - } - - // Test BM25+ on a simple query - { - let query = "rust programming"; - let scored_docs = score_documents_bm25plus(query, &documents); - let top_docs: Vec<&String> = scored_docs.iter().take(2).map(|(id, _)| id).collect(); - - println!("BM25+ top docs for 'rust programming': {:?}", top_docs); - - // Check that doc1 and doc5 are in the top results (they're about Rust) - assert!( - top_docs.contains(&&"doc1".to_string()) && top_docs.contains(&&"doc5".to_string()), - "BM25+ should rank doc1 and doc5 in the top results for 'rust programming'" - ); - } - - // Test BM25F on another query - { - let query = "database systems"; - let scored_docs = score_documents_bm25f(query, &documents); - let top_doc = &scored_docs[0].0; - - println!("BM25F top doc for 'database systems': {}", top_doc); - - // Check that doc8 is the top result (it's about databases) - assert_eq!( - top_doc, "doc8", - "BM25F should rank doc8 as the top result for 'database systems'" - ); - } - - // Test BM25+ on another query - { - let query = "database systems"; - let scored_docs = score_documents_bm25plus(query, &documents); - let top_doc = &scored_docs[0].0; - - println!("BM25+ top doc for 'database systems': {}", top_doc); - - // Check that doc8 is the top result (it's about databases) - assert_eq!( - top_doc, "doc8", - "BM25+ should rank 
doc8 as the top result for 'database systems'" - ); - } - } - - #[test] - fn test_field_weighting_in_bm25f() { - let base_path = env!("CARGO_MANIFEST_DIR"); - let file_path = format!("{}/../../docs/en/test_data/bm25_test_dataset/field_weighting_test.json", base_path); - - // Load the field weighting test dataset - let path = Path::new(&file_path); - let mut file = File::open(path).expect("Failed to open field weighting test file"); - let mut contents = String::new(); - file.read_to_string(&mut contents).expect("Failed to read field weighting test file"); - - #[derive(Debug, Deserialize)] - struct FieldWeightingTest { - documents: Vec, - queries: Vec, - field_weights: HashMap>, - } - - let test_data: FieldWeightingTest = serde_json::from_str(&contents).expect("Failed to parse field weighting test"); - - // Convert test documents to terraphim documents - let documents: Vec = test_data.documents.iter() - .map(|doc| convert_test_document(doc)) - .collect(); - - // Test with title priority - if let Some(title_weights) = test_data.field_weights.get("title_priority") { - let field_weights = FieldWeights { - title: *title_weights.get("title").unwrap_or(&3.0), - body: *title_weights.get("body").unwrap_or(&1.0), - description: *title_weights.get("description").unwrap_or(&1.0), - tags: *title_weights.get("tags").unwrap_or(&1.0), - }; - - let params = BM25Params::default(); - let mut scorer = BM25FScorer::with_params(params, field_weights); - scorer.initialize(&documents); - - // Use the first query (fwq1) - let query = &test_data.queries[0]; - - let mut scored_docs: Vec<(String, f64)> = documents.iter() - .map(|doc| { - let score = scorer.score(&query.query, doc); - (doc.id.clone(), score) - }) - .collect(); - - // Sort by score in descending order - scored_docs.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)); - - // Extract just the document IDs for comparison - let ranked_ids: Vec = scored_docs.iter() - .map(|(id, _)| id.clone()) - .collect(); - - 
println!("Query: {} (title priority)", query.query); - println!("Ranking with title priority: {:?}", ranked_ids); - - // When title field is weighted higher, fw3 should rank higher than fw2 - // despite fw2 having more occurrences in the body - assert!( - scored_docs.iter().position(|(id, _)| id == "fw3").unwrap() < - scored_docs.iter().position(|(id, _)| id == "fw2").unwrap(), - "With title priority, fw3 should rank higher than fw2" - ); - } - } - - #[test] - fn test_rare_terms_in_bm25plus() { - let base_path = env!("CARGO_MANIFEST_DIR"); - let file_path = format!("{}/../../docs/en/test_data/bm25_test_dataset/rare_terms_test.json", base_path); - - // Load the rare terms test dataset - let path = Path::new(&file_path); - let mut file = File::open(path).expect("Failed to open rare terms test file"); - let mut contents = String::new(); - file.read_to_string(&mut contents).expect("Failed to read rare terms test file"); - - #[derive(Debug, Deserialize)] - struct RareTermsTest { - documents: Vec, - queries: Vec, - expected_results: HashMap>>, - } - - let test_data: RareTermsTest = serde_json::from_str(&contents).expect("Failed to parse rare terms test"); - - // Convert test documents to terraphim documents - let documents: Vec = test_data.documents.iter() - .map(|doc| convert_test_document(doc)) - .collect(); - - // Initialize scorers - let mut bm25plus_scorer = BM25PlusScorer::new(); - bm25plus_scorer.initialize(&documents); - - // Test with the first query (rtq1) - let query = &test_data.queries[0]; - - // Score documents with BM25+ - let mut bm25plus_scores: Vec<(String, f64)> = documents.iter() - .map(|doc| { - let score = bm25plus_scorer.score(&query.query, doc); - (doc.id.clone(), score) - }) - .collect(); - - // Sort by score in descending order - bm25plus_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)); - - // Extract just the document IDs for comparison - let bm25plus_ranked_ids: Vec = bm25plus_scores.iter() - .map(|(id, _)| 
id.clone()) - .collect(); - - println!("Query: {} (rare terms)", query.query); - println!("BM25+ ranking: {:?}", bm25plus_ranked_ids); - - // Check if the expected document (rt3) is ranked first - assert_eq!( - bm25plus_ranked_ids[0], "rt3", - "BM25+ should rank rt3 first for query '{}'", query.query - ); - - // Check if BM25+ assigns scores to documents that don't contain the query terms - // This is a key feature of BM25+ - let non_matching_docs = bm25plus_scores.iter() - .filter(|(id, score)| *id != "rt3" && *score > 0.0) - .count(); - - assert!( - non_matching_docs > 0, - "BM25+ should assign scores to documents that don't contain the query terms" - ); - } - - #[test] - fn test_document_length_normalization() { - let base_path = env!("CARGO_MANIFEST_DIR"); - let file_path = format!("{}/../../docs/en/test_data/bm25_test_dataset/document_length_test.json", base_path); - - // Load the document length test dataset - let path = Path::new(&file_path); - let mut file = File::open(path).expect("Failed to open document length test file"); - let mut contents = String::new(); - file.read_to_string(&mut contents).expect("Failed to read document length test file"); - - #[derive(Debug, Deserialize)] - struct DocumentLengthTest { - documents: Vec, - queries: Vec, - } - - let test_data: DocumentLengthTest = serde_json::from_str(&contents).expect("Failed to parse document length test"); - - // Convert test documents to terraphim documents - let documents: Vec = test_data.documents.iter() - .map(|doc| convert_test_document(doc)) - .collect(); - - // Initialize scorers - let mut bm25f_scorer = BM25FScorer::new(); - bm25f_scorer.initialize(&documents); - - let mut bm25plus_scorer = BM25PlusScorer::new(); - bm25plus_scorer.initialize(&documents); - - // Test with the first query (dlq1) - let query = &test_data.queries[0]; - - // Score documents with BM25F - let mut bm25f_scores: Vec<(String, f64)> = documents.iter() - .map(|doc| { - let score = bm25f_scorer.score(&query.query, 
doc); - (doc.id.clone(), score) - }) - .collect(); - - // Sort by score in descending order - bm25f_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)); - - // Extract just the document IDs for comparison - let bm25f_ranked_ids: Vec = bm25f_scores.iter() - .map(|(id, _)| id.clone()) - .collect(); - - println!("Query: {} (document length)", query.query); - println!("BM25F ranking: {:?}", bm25f_ranked_ids); - - // Score documents with BM25+ - let mut bm25plus_scores: Vec<(String, f64)> = documents.iter() - .map(|doc| { - let score = bm25plus_scorer.score(&query.query, doc); - (doc.id.clone(), score) - }) - .collect(); - - // Sort by score in descending order - bm25plus_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)); - - // Extract just the document IDs for comparison - let bm25plus_ranked_ids: Vec = bm25plus_scores.iter() - .map(|(id, _)| id.clone()) - .collect(); - - println!("BM25+ ranking: {:?}", bm25plus_ranked_ids); - - // Check if shorter documents are ranked higher when term frequency is similar - assert!( - bm25f_ranked_ids.iter().position(|id| id == "dl2").unwrap() < - bm25f_ranked_ids.iter().position(|id| id == "dl5").unwrap(), - "BM25F should rank shorter document dl2 higher than longer document dl5" - ); - - assert!( - bm25plus_ranked_ids.iter().position(|id| id == "dl2").unwrap() < - bm25plus_ranked_ids.iter().position(|id| id == "dl5").unwrap(), - "BM25+ should rank shorter document dl2 higher than longer document dl5" - ); - } - - #[test] - fn test_term_frequency_saturation() { - let base_path = env!("CARGO_MANIFEST_DIR"); - let file_path = format!("{}/../../docs/en/test_data/bm25_test_dataset/term_frequency_test.json", base_path); - - // Load the term frequency test dataset - let path = Path::new(&file_path); - let mut file = File::open(path).expect("Failed to open term frequency test file"); - let mut contents = String::new(); - file.read_to_string(&mut contents).expect("Failed to read 
term frequency test file"); - - #[derive(Debug, Deserialize)] - struct TermFrequencyTest { - documents: Vec, - queries: Vec, - } - - let test_data: TermFrequencyTest = serde_json::from_str(&contents).expect("Failed to parse term frequency test"); - - // Convert test documents to terraphim documents - let documents: Vec = test_data.documents.iter() - .map(|doc| convert_test_document(doc)) - .collect(); - - // Initialize scorers - let mut bm25f_scorer = BM25FScorer::new(); - bm25f_scorer.initialize(&documents); - - let mut bm25plus_scorer = BM25PlusScorer::new(); - bm25plus_scorer.initialize(&documents); - - // Test with the first query (tfq1) - let query = &test_data.queries[0]; - - // Score documents with BM25F - let mut bm25f_scores: Vec<(String, f64)> = documents.iter() - .map(|doc| { - let score = bm25f_scorer.score(&query.query, doc); - (doc.id.clone(), score) - }) - .collect(); - - // Sort by score in descending order - bm25f_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)); - - // Extract just the document IDs for comparison - let bm25f_ranked_ids: Vec = bm25f_scores.iter() - .map(|(id, _)| id.clone()) - .collect(); - - println!("Query: {} (term frequency)", query.query); - println!("BM25F ranking: {:?}", bm25f_ranked_ids); - - // Score documents with BM25+ - let mut bm25plus_scores: Vec<(String, f64)> = documents.iter() - .map(|doc| { - let score = bm25plus_scorer.score(&query.query, doc); - (doc.id.clone(), score) - }) - .collect(); - - // Sort by score in descending order - bm25plus_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)); - - // Extract just the document IDs for comparison - let bm25plus_ranked_ids: Vec = bm25plus_scores.iter() - .map(|(id, _)| id.clone()) - .collect(); - - println!("BM25+ ranking: {:?}", bm25plus_ranked_ids); - - // Check if documents with extreme term frequency (tf5) are not ranked significantly higher - // than documents with high term frequency (tf3, tf4) - let 
tf5_position_bm25f = bm25f_ranked_ids.iter().position(|id| id == "tf5").unwrap_or(usize::MAX); - let tf3_position_bm25f = bm25f_ranked_ids.iter().position(|id| id == "tf3").unwrap_or(usize::MAX); - - let tf5_position_bm25plus = bm25plus_ranked_ids.iter().position(|id| id == "tf5").unwrap_or(usize::MAX); - let tf3_position_bm25plus = bm25plus_ranked_ids.iter().position(|id| id == "tf3").unwrap_or(usize::MAX); - - // tf5 should not be ranked significantly higher than tf3 despite having many more occurrences of 'rust' - assert!( - tf5_position_bm25f >= tf3_position_bm25f, - "BM25F should not rank tf5 significantly higher than tf3 despite having many more occurrences of 'rust'" - ); - - assert!( - tf5_position_bm25plus >= tf3_position_bm25plus, - "BM25+ should not rank tf5 significantly higher than tf3 despite having many more occurrences of 'rust'" - ); - } -} diff --git a/crates/terraphim_service/src/score/mod.rs b/crates/terraphim_service/src/score/mod.rs index ec134616a..6bdaff80d 100644 --- a/crates/terraphim_service/src/score/mod.rs +++ b/crates/terraphim_service/src/score/mod.rs @@ -1,366 +1 @@ -use std::f64; -use std::fmt; -use std::result; - -mod bm25; -pub mod bm25_additional; -#[cfg(test)] -mod bm25_test; -pub mod common; -mod names; -mod scored; -#[cfg(test)] -mod scorer_integration_test; - -use bm25::{BM25FScorer, BM25PlusScorer}; -use bm25_additional::{JaccardScorer, OkapiBM25Scorer, QueryRatioScorer, TFIDFScorer}; -pub use names::QueryScorer; -use scored::{Scored, SearchResults}; -use serde::{Serialize, Serializer}; - -use crate::ServiceError; -use terraphim_types::Document; - -/// Score module local Result type using TerraphimService's error enum. -type Result = std::result::Result; - -/// Sort the documents by relevance. -/// -/// The `relevance_function` parameter is used to determine how the documents -/// should be sorted. 
-pub fn sort_documents(query: &Query, documents: Vec) -> Vec { - let mut scorer = Scorer::new().with_similarity(query.similarity); - - // Initialize the appropriate scorer based on the query's name_scorer - match query.name_scorer { - QueryScorer::BM25 => { - let mut bm25_scorer = OkapiBM25Scorer::new(); - bm25_scorer.initialize(&documents); - scorer = scorer.with_scorer(Box::new(bm25_scorer)); - } - QueryScorer::BM25F => { - let mut bm25f_scorer = BM25FScorer::new(); - bm25f_scorer.initialize(&documents); - scorer = scorer.with_scorer(Box::new(bm25f_scorer)); - } - QueryScorer::BM25Plus => { - let mut bm25plus_scorer = BM25PlusScorer::new(); - bm25plus_scorer.initialize(&documents); - scorer = scorer.with_scorer(Box::new(bm25plus_scorer)); - } - QueryScorer::Tfidf => { - let mut tfidf_scorer = TFIDFScorer::new(); - tfidf_scorer.initialize(&documents); - scorer = scorer.with_scorer(Box::new(tfidf_scorer)); - } - QueryScorer::Jaccard => { - let mut jaccard_scorer = JaccardScorer::new(); - jaccard_scorer.initialize(&documents); - scorer = scorer.with_scorer(Box::new(jaccard_scorer)); - } - QueryScorer::QueryRatio => { - let mut query_ratio_scorer = QueryRatioScorer::new(); - query_ratio_scorer.initialize(&documents); - scorer = scorer.with_scorer(Box::new(query_ratio_scorer)); - } - _ => { - // For OkapiBM25 and other cases, use similarity scoring - } - } - - match scorer.score_documents(query, documents.clone()) { - Ok(results) => results - .into_iter() - .map(|scored| scored.into_value()) - .collect(), - Err(_) => documents, - } -} - -#[derive(Debug)] -pub struct Scorer { - similarity: Similarity, - scorer: Option>, -} - -impl Scorer { - pub fn new() -> Scorer { - Scorer { - similarity: Similarity::default(), - scorer: None, - } - } - - pub fn with_similarity(mut self, similarity: Similarity) -> Scorer { - self.similarity = similarity; - self - } - - pub fn with_scorer(mut self, scorer: Box) -> Scorer { - self.scorer = Some(scorer); - self - } - - /// Execute a 
search with the given `Query`. - /// - /// Generally, the results returned are ranked in relevance order, where - /// each result has a score associated with it. The score is between - /// `0` and `1.0` (inclusive), where a score of `1.0` means "most similar" - /// and a score of `0` means "least similar." - /// - /// Depending on the query, the behavior of search can vary: - /// - /// * When the query specifies a similarity function, then the results are - /// ranked by that function. - /// * When the query contains a name to search by and a name scorer, then - /// results are ranked by the name scorer. If the query specifies a - /// similarity function, then results are first ranked by the name - /// scorer, and then re-ranked by the similarity function. - /// * When no name or no name scorer are specified by the query, then - /// this search will do a (slow) exhaustive search over all media records - /// in IMDb. As a special case, if the query contains a TV show ID, then - /// only records in that TV show are searched, and this is generally - /// fast. - /// * If the query is empty, then no results are returned. - /// - /// If there was a problem reading the underlying index or the IMDb data, - /// then an error is returned. 
- #[allow(dead_code)] - pub fn score( - &mut self, - query: &Query, - documents: Vec, - ) -> Result> { - if query.is_empty() { - return Ok(SearchResults::new()); - } - let mut results = self.score_documents(query, documents)?; - results.trim(query.size); - results.normalize(); - Ok(results) - } - - fn score_documents( - &mut self, - query: &Query, - documents: Vec, - ) -> Result> { - let mut results = SearchResults::new(); - for document in documents { - results.push(Scored::new(document)); - } - - match query.name_scorer { - QueryScorer::BM25 => { - if let Some(scorer) = &self.scorer { - if let Some(bm25_scorer) = scorer.downcast_ref::() { - results.rescore(|document| bm25_scorer.score(&query.name, document)); - } - } - } - QueryScorer::BM25F => { - if let Some(scorer) = &self.scorer { - if let Some(bm25f_scorer) = scorer.downcast_ref::() { - results.rescore(|document| bm25f_scorer.score(&query.name, document)); - } - } - } - QueryScorer::BM25Plus => { - if let Some(scorer) = &self.scorer { - if let Some(bm25plus_scorer) = scorer.downcast_ref::() { - results.rescore(|document| bm25plus_scorer.score(&query.name, document)); - } - } - } - QueryScorer::Tfidf => { - if let Some(scorer) = &self.scorer { - if let Some(tfidf_scorer) = scorer.downcast_ref::() { - results.rescore(|document| tfidf_scorer.score(&query.name, document)); - } - } - } - QueryScorer::Jaccard => { - if let Some(scorer) = &self.scorer { - if let Some(jaccard_scorer) = scorer.downcast_ref::() { - results.rescore(|document| jaccard_scorer.score(&query.name, document)); - } - } - } - QueryScorer::QueryRatio => { - if let Some(scorer) = &self.scorer { - if let Some(query_ratio_scorer) = scorer.downcast_ref::() { - results.rescore(|document| query_ratio_scorer.score(&query.name, document)); - } - } - } - _ => { - // Fall back to similarity scoring for OkapiBM25 and other cases - log::debug!("Similarity {:?}", query.similarity); - log::debug!("Query {:?}", query); - results.rescore(|document| 
self.similarity(query, &document.title)); - } - } - - log::debug!("results after rescoring: {:#?}", results); - Ok(results) - } - - fn similarity(&self, query: &Query, name: &str) -> f64 { - log::debug!("Similarity {:?}", query.similarity); - log::debug!("Query {:?}", query); - log::debug!("Name {:?}", name); - let result = query.similarity.similarity(&query.name, name); - log::debug!("Similarity calculation {:?}", result); - result - } -} - -/// A simplified query structure for Terraphim document search. -/// -/// This is a streamlined version focused on document search functionality. -#[derive(Clone, Debug, Eq, Hash, PartialEq)] -pub struct Query { - pub name: String, - pub name_scorer: QueryScorer, - pub similarity: Similarity, - pub size: usize, -} - -impl Query { - /// Create a new query with the given search term. - pub fn new(name: &str) -> Query { - Query { - name: name.to_string(), - name_scorer: QueryScorer::default(), - similarity: Similarity::default(), - size: 30, - } - } - - /// Return true if and only if this query is empty. - /// - /// Searching with an empty query always yields no results. - #[allow(dead_code)] - pub fn is_empty(&self) -> bool { - self.name.is_empty() - } - - /// Set the name scorer to use for ranking. - /// - /// The name scorer determines which algorithm is used to rank documents. - pub fn name_scorer(mut self, scorer: QueryScorer) -> Query { - self.name_scorer = scorer; - self - } - - /// Set the similarity function. - /// - /// The similarity function can be selected from a predefined set of - /// choices defined by the [`Similarity`](enum.Similarity.html) type. - /// - /// When a similarity function is used, then any results from searching - /// the name index are re-ranked according to their similarity with the - /// query. - /// - /// By default, no similarity function is used. 
- #[allow(dead_code)] - pub fn similarity(mut self, sim: Similarity) -> Query { - self.similarity = sim; - self - } -} - -impl Serialize for Query { - fn serialize(&self, s: S) -> result::Result - where - S: Serializer, - { - s.serialize_str(&self.to_string()) - } -} - -impl fmt::Display for Query { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "{{scorer:{}}}", self.name_scorer)?; - write!(f, " {{sim:{}}}", self.similarity)?; - write!(f, " {{size:{}}}", self.size)?; - write!(f, " {}", self.name)?; - Ok(()) - } -} - -/// A ranking function to use when searching IMDb records. -/// -/// A similarity ranking function computes a score between `0.0` and `1.0` (not -/// including `0` but including `1.0`) for a query and a candidate result. The -/// score is determined by the corresponding names for a query and a candidate, -/// and a higher score indicates more similarity. -/// -/// This ranking function can be used to increase the precision of a set -/// of results. In particular, when a similarity function is provided to -/// a [`Query`](struct.Query.html), then any results returned by querying -/// the IMDb name index will be rescored according to this function. If no -/// similarity function is provided, then the results will be ranked according -/// to scores produced by the name index. -#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq, Default)] -pub enum Similarity { - /// Do not use a similarity function. - #[default] - None, - /// Computes the Levenshtein edit distance between two names and converts - /// it to a similarity. - #[allow(dead_code)] // Part of public API, documented in user guide - Levenshtein, - /// Computes the Jaro edit distance between two names and converts it to a - /// similarity. - #[allow(dead_code)] - Jaro, - /// Computes the Jaro-Winkler edit distance between two names and converts - /// it to a similarity. 
- #[allow(dead_code)] - JaroWinkler, -} - -impl Similarity { - /// Computes the similarity between the given strings according to the - /// underlying similarity function. If no similarity function is present, - /// then this always returns `1.0`. - /// - /// The returned value is always in the range `(0, 1]`. - pub fn similarity(&self, q1: &str, q2: &str) -> f64 { - let sim = match *self { - Similarity::None => 1.0, - Similarity::Levenshtein => { - let distance = strsim::levenshtein(q1, q2) as f64; - // We do a simple conversion of distance to similarity. This - // will produce very low scores even for very similar names, - // but callers may normalize scores. - // - // We also add `1` to the denominator to avoid division by - // zero. Incidentally, this causes the similarity of identical - // strings to be exactly 1.0, which is what we want. - 1.0 / (1.0 + distance) - } - Similarity::Jaro => strsim::jaro(q1, q2), - Similarity::JaroWinkler => strsim::jaro_winkler(q1, q2), - }; - // Don't permit a score to actually be zero. This prevents division - // by zero during normalization if all results have a score of zero. 
- if sim < f64::EPSILON { - f64::EPSILON - } else { - sim - } - } -} - -impl fmt::Display for Similarity { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match *self { - Similarity::None => write!(f, "none"), - Similarity::Levenshtein => write!(f, "levenshtein"), - Similarity::Jaro => write!(f, "jaro"), - Similarity::JaroWinkler => write!(f, "jarowinkler"), - } - } -} +pub use terraphim_types::score::*; diff --git a/crates/terraphim_service/src/score/scorer_integration_test.rs b/crates/terraphim_service/src/score/scorer_integration_test.rs deleted file mode 100644 index 7bb0168d2..000000000 --- a/crates/terraphim_service/src/score/scorer_integration_test.rs +++ /dev/null @@ -1,189 +0,0 @@ -#[cfg(test)] -mod tests { - use super::super::*; - use terraphim_types::{Document, DocumentType}; - - fn create_test_documents() -> Vec { - vec![ - Document { - id: "doc1".to_string(), - title: "Rust Programming".to_string(), - body: - "Rust is a systems programming language that focuses on safety and performance" - .to_string(), - url: "http://example.com/doc1".to_string(), - description: Some("About Rust programming".to_string()), - summarization: None, - stub: None, - tags: Some(vec!["programming".to_string(), "rust".to_string()]), - rank: None, - source_haystack: None, - doc_type: DocumentType::KgEntry, - synonyms: None, - route: None, - priority: None, - }, - Document { - id: "doc2".to_string(), - title: "Python Development".to_string(), - body: "Python is a high-level programming language with dynamic typing".to_string(), - url: "http://example.com/doc2".to_string(), - description: Some("About Python development".to_string()), - summarization: None, - stub: None, - tags: Some(vec!["programming".to_string(), "python".to_string()]), - rank: None, - source_haystack: None, - doc_type: DocumentType::KgEntry, - synonyms: None, - route: None, - priority: None, - }, - Document { - id: "doc3".to_string(), - title: "Machine Learning".to_string(), - body: "Machine learning 
involves algorithms that improve through experience" - .to_string(), - url: "http://example.com/doc3".to_string(), - description: Some("About machine learning".to_string()), - summarization: None, - stub: None, - tags: Some(vec!["ai".to_string(), "ml".to_string()]), - rank: None, - source_haystack: None, - doc_type: DocumentType::KgEntry, - synonyms: None, - route: None, - priority: None, - }, - ] - } - - #[test] - fn test_okapi_bm25_scorer_integration() { - let documents = create_test_documents(); - let mut scorer = bm25_additional::OkapiBM25Scorer::new(); - scorer.initialize(&documents); - - // Test scoring - let score1 = scorer.score("programming", &documents[0]); - let score2 = scorer.score("programming", &documents[1]); - let score3 = scorer.score("programming", &documents[2]); - - // Documents 1 and 2 should have higher scores than document 3 for "programming" query - assert!(score1 > 0.0); - assert!(score2 > 0.0); - assert!(score3 >= 0.0); - assert!(score1 > score3); - assert!(score2 > score3); - } - - #[test] - fn test_jaccard_scorer_integration() { - let documents = create_test_documents(); - let mut scorer = bm25_additional::JaccardScorer::new(); - scorer.initialize(&documents); - - // Test scoring - let score1 = scorer.score("programming language", &documents[0]); - let score2 = scorer.score("programming language", &documents[1]); - let score3 = scorer.score("programming language", &documents[2]); - - // Documents 1 and 2 should have higher scores than document 3 - assert!(score1 > 0.0); - assert!(score2 > 0.0); - assert!(score3 >= 0.0); - } - - #[test] - fn test_query_ratio_scorer_integration() { - let documents = create_test_documents(); - let mut scorer = bm25_additional::QueryRatioScorer::new(); - scorer.initialize(&documents); - - // Test scoring - let score1 = scorer.score("rust systems", &documents[0]); - let score2 = scorer.score("rust systems", &documents[1]); - let score3 = scorer.score("rust systems", &documents[2]); - - // Document 1 should 
have the highest score for "rust systems" query - assert!(score1 > 0.0); - assert!(score1 >= score2); - assert!(score1 >= score3); - } - - #[test] - fn test_tfidf_scorer_integration() { - let documents = create_test_documents(); - let mut scorer = bm25_additional::TFIDFScorer::new(); - scorer.initialize(&documents); - - // Test scoring - let score1 = scorer.score("programming", &documents[0]); - let score2 = scorer.score("programming", &documents[1]); - let score3 = scorer.score("programming", &documents[2]); - - // Documents 1 and 2 should have higher scores than document 3 - assert!(score1 > 0.0); - assert!(score2 > 0.0); - assert!(score1 > score3); - assert!(score2 > score3); - } - - #[test] - fn test_with_params_functionality() { - use super::super::common::BM25Params; - - let params = BM25Params { - k1: 2.0, - b: 0.5, - delta: 0.0, - }; - - let documents = create_test_documents(); - let mut scorer = bm25_additional::OkapiBM25Scorer::with_params(params); - scorer.initialize(&documents); - - // Should work with custom parameters - let score = scorer.score("programming", &documents[0]); - assert!(score > 0.0); - } - - #[test] - fn test_sort_documents_with_different_scorers() { - let documents = create_test_documents(); - - // Test with BM25 scorer - let query = Query { - name: "programming".to_string(), - name_scorer: QueryScorer::BM25, - similarity: Similarity::default(), - size: 30, - }; - - let sorted_docs = sort_documents(&query, documents.clone()); - assert_eq!(sorted_docs.len(), 3); - - // Test with Jaccard scorer - let query = Query { - name: "programming".to_string(), - name_scorer: QueryScorer::Jaccard, - similarity: Similarity::default(), - size: 30, - }; - - let sorted_docs = sort_documents(&query, documents.clone()); - assert_eq!(sorted_docs.len(), 3); - - // Test with TFIDF scorer - let query = Query { - name: "programming".to_string(), - name_scorer: QueryScorer::Tfidf, - similarity: Similarity::default(), - size: 30, - }; - - let sorted_docs = 
sort_documents(&query, documents); - assert_eq!(sorted_docs.len(), 3); - } -} diff --git a/crates/terraphim_sessions/Cargo.toml b/crates/terraphim_sessions/Cargo.toml index 675a564bf..2b79513cd 100644 --- a/crates/terraphim_sessions/Cargo.toml +++ b/crates/terraphim_sessions/Cargo.toml @@ -32,8 +32,11 @@ extra-connectors = ["aider-connector", "cline-connector"] # Enable terraphim knowledge graph enrichment enrichment = ["terraphim_automata", "terraphim_rolegraph", "terraphim_types"] +# Enable BM25-ranked session search via terraphim_types score module +search-index = ["terraphim_types"] + # All features -full = ["tsa-full", "extra-connectors", "enrichment"] +full = ["tsa-full", "extra-connectors", "enrichment", "search-index"] [dependencies] # Core dependencies diff --git a/crates/terraphim_sessions/src/lib.rs b/crates/terraphim_sessions/src/lib.rs index 08a68ad6c..f41819b2e 100644 --- a/crates/terraphim_sessions/src/lib.rs +++ b/crates/terraphim_sessions/src/lib.rs @@ -37,6 +37,9 @@ pub mod cla; #[cfg(feature = "enrichment")] pub mod enrichment; +#[cfg(feature = "search-index")] +pub mod search; + // Re-exports for convenience pub use connector::{ConnectorRegistry, ConnectorStatus, ImportOptions, SessionConnector}; pub use model::{ @@ -50,5 +53,11 @@ pub use enrichment::{ SessionEnricher, find_related_sessions, search_by_concept, }; +#[cfg(feature = "search-index")] +pub use search::{search_sessions, session_to_document}; + +#[cfg(all(feature = "search-index", feature = "enrichment"))] +pub use search::search_sessions_hybrid; + /// Crate version pub const VERSION: &str = env!("CARGO_PKG_VERSION"); diff --git a/crates/terraphim_sessions/src/search.rs b/crates/terraphim_sessions/src/search.rs new file mode 100644 index 000000000..c24f869d9 --- /dev/null +++ b/crates/terraphim_sessions/src/search.rs @@ -0,0 +1,410 @@ +//! BM25-ranked session search adapter with KG hybrid boost +//! +//! Converts sessions into `terraphim_types::Document` instances and uses +//! 
the existing BM25 scoring infrastructure for ranked full-text search. +//! +//! When the `enrichment` feature is also enabled and a `Thesaurus` is provided, +//! sessions matching knowledge graph concepts for the current role are boosted +//! above pure BM25 results (hybrid search). + +use crate::model::{MessageRole, Session}; +use terraphim_types::score::{OkapiBM25Scorer, Query, QueryScorer, Scored, Scorer, SearchResults}; +use terraphim_types::{Document, DocumentType}; + +const MAX_BODY_LENGTH: usize = 50_000; +const MAX_SEARCH_RESULTS: usize = 50; +const MIN_SCORE_FRACTION: f64 = 0.1; + +/// Score multiplier applied to sessions with KG concept matches. +#[cfg(feature = "enrichment")] +const KG_BOOST_MULTIPLIER: f64 = 10_000.0; + +/// Adapter that converts a `Session` into a searchable `Document`. +pub fn session_to_document(session: &Session) -> Document { + let title = session + .title + .clone() + .unwrap_or_else(|| session.source.clone()); + + let body = build_body(session); + + Document { + id: session.id.clone(), + url: session.source_path.to_string_lossy().to_string(), + title, + body, + description: session.summary(), + summarization: None, + stub: None, + tags: if session.metadata.tags.is_empty() { + None + } else { + Some(session.metadata.tags.clone()) + }, + rank: None, + source_haystack: Some(session.source.clone()), + doc_type: DocumentType::default(), + synonyms: None, + route: None, + priority: None, + } +} + +fn build_body(session: &Session) -> String { + let mut parts: Vec = Vec::new(); + + if let Some(path) = &session.metadata.project_path { + parts.push(path.clone()); + } + + if let Some(model) = &session.metadata.model { + parts.push(model.clone()); + } + + for msg in &session.messages { + if msg.content.is_empty() { + continue; + } + let prefix = match msg.role { + MessageRole::User => "User: ", + MessageRole::Assistant => "Assistant: ", + MessageRole::System => "System: ", + MessageRole::Tool => "Tool: ", + MessageRole::Other => "", + }; + 
parts.push(format!("{}{}", prefix, msg.content)); + } + + let body = parts.join("\n"); + if body.len() > MAX_BODY_LENGTH { + let mut end = MAX_BODY_LENGTH; + while !body.is_char_boundary(end) { + end -= 1; + } + body[..end].to_string() + } else { + body + } +} + +/// Perform BM25-ranked search over sessions. +/// +/// Returns sessions ranked by relevance to the query, with BM25 scoring +/// applied to the combined title + message body text. +pub fn search_sessions(sessions: &[Session], query: &str) -> Vec> { + if query.trim().is_empty() || sessions.is_empty() { + return Vec::new(); + } + + let documents: Vec = sessions.iter().map(session_to_document).collect(); + + let mut bm25 = OkapiBM25Scorer::new(); + bm25.initialize(&documents); + + let mut q = Query::new(query); + q.name_scorer = QueryScorer::BM25; + q.size = MAX_SEARCH_RESULTS; + + let mut scorer = Scorer::new() + .with_scorer(Box::new(bm25)) + .with_similarity(terraphim_types::score::Similarity::None); + + let results: SearchResults = match scorer.score(&q, documents) { + Ok(r) => r, + Err(e) => { + tracing::warn!("BM25 scoring failed: {}", e); + return Vec::new(); + } + }; + + let session_map: std::collections::HashMap<&str, &Session> = + sessions.iter().map(|s| (s.id.as_str(), s)).collect(); + + let mut scored: Vec> = results + .as_slice() + .iter() + .filter_map(|scored_doc| { + let score = scored_doc.score(); + session_map + .get(scored_doc.value().id.as_str()) + .map(|session| Scored::new((*session).clone()).with_score(score)) + }) + .collect(); + + scored.sort_by(|a, b| { + b.score() + .partial_cmp(&a.score()) + .unwrap_or(std::cmp::Ordering::Equal) + }); + + if let Some(top_score) = scored.first().map(|s| s.score()) { + if top_score > 0.0 { + let threshold = top_score * MIN_SCORE_FRACTION; + scored.retain(|s| s.score() >= threshold); + } + } + + scored +} + +/// Hybrid search: KG concept matches for the current role come first, BM25 fallback. 
+/// +/// When a `Thesaurus` is provided, the query is matched against KG terms via +/// Aho-Corasick. Sessions whose enrichment data contains matching concepts +/// receive a large score boost so they always rank above pure BM25 results. +/// +/// Without a thesaurus (or when no sessions are enriched), falls back to pure BM25. +#[cfg(feature = "enrichment")] +pub fn search_sessions_hybrid( + sessions: &[Session], + query: &str, + thesaurus: Option, +) -> Vec> { + if query.trim().is_empty() || sessions.is_empty() { + return Vec::new(); + } + + let mut scored = search_sessions(sessions, query); + + let Some(thesaurus) = thesaurus else { + return scored; + }; + + let kg_terms = match extract_kg_terms(query, thesaurus) { + Ok(terms) if !terms.is_empty() => terms, + _ => return scored, + }; + + let kg_term_set: std::collections::HashSet = kg_terms + .iter() + .map(|t| t.normalized_term.value.as_str().to_lowercase()) + .collect(); + + if kg_term_set.is_empty() { + return scored; + } + + for scored_session in &mut scored { + let session = scored_session.value(); + let boost = compute_kg_boost(session, &kg_term_set); + if boost > 0 { + let new_score = scored_session.score() + (boost as f64 * KG_BOOST_MULTIPLIER); + scored_session.set_score(new_score); + } + } + + scored.sort_by(|a, b| { + b.score() + .partial_cmp(&a.score()) + .unwrap_or(std::cmp::Ordering::Equal) + }); + + scored +} + +#[cfg(feature = "enrichment")] +fn extract_kg_terms( + query: &str, + thesaurus: terraphim_types::Thesaurus, +) -> Result, terraphim_automata::TerraphimAutomataError> { + terraphim_automata::matcher::find_matches(query, thesaurus, false) +} + +#[cfg(feature = "enrichment")] +fn compute_kg_boost(session: &Session, kg_term_set: &std::collections::HashSet) -> usize { + let Some(enrichment) = &session.metadata.enrichment else { + return 0; + }; + + let mut match_count = 0usize; + for (normalized, concept) in &enrichment.concepts { + if kg_term_set.contains(&normalized.to_lowercase()) { + 
match_count += concept.count; + } + } + + match_count +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::model::{Message, SessionMetadata}; + use std::path::PathBuf; + + fn make_session(id: &str, title: &str, messages: Vec<(&str, MessageRole, &str)>) -> Session { + Session { + id: id.to_string(), + source: "test".to_string(), + external_id: id.to_string(), + title: if title.is_empty() { + None + } else { + Some(title.to_string()) + }, + source_path: PathBuf::from(format!("/sessions/{}.jsonl", id)), + started_at: None, + ended_at: None, + messages: messages + .into_iter() + .enumerate() + .map(|(i, (role, role_type, content))| { + let mut msg = Message::text(i, role_type, content); + msg.author = Some(role.to_string()); + msg + }) + .collect(), + metadata: SessionMetadata::default(), + } + } + + #[test] + fn test_session_to_document_basic() { + let session = make_session( + "s1", + "Rust async help", + vec![("user", MessageRole::User, "How do I use tokio?")], + ); + let doc = session_to_document(&session); + + assert_eq!(doc.id, "s1"); + assert_eq!(doc.title, "Rust async help"); + assert!(doc.body.contains("How do I use tokio?")); + assert!(doc.body.contains("User: ")); + assert_eq!(doc.source_haystack, Some("test".to_string())); + } + + #[test] + fn test_session_to_document_no_title() { + let session = make_session("s2", "", vec![]); + let doc = session_to_document(&session); + + assert_eq!(doc.title, "test"); + } + + #[test] + fn test_session_to_document_tags() { + let mut session = make_session("s3", "test", vec![]); + session.metadata.tags = vec!["rust".to_string(), "async".to_string()]; + let doc = session_to_document(&session); + + assert_eq!( + doc.tags, + Some(vec!["rust".to_string(), "async".to_string()]) + ); + } + + #[test] + fn test_search_sessions_basic() { + let sessions = vec![ + make_session( + "s1", + "Rust async programming", + vec![("user", MessageRole::User, "How to use async await in Rust?")], + ), + make_session( + "s2", + "Python web 
scraping", + vec![("user", MessageRole::User, "Best library for web scraping?")], + ), + make_session( + "s3", + "Rust error handling", + vec![("user", MessageRole::User, "How to handle errors in Rust?")], + ), + ]; + + let results = search_sessions(&sessions, "Rust async"); + + assert!(!results.is_empty()); + assert_eq!(results[0].value().id, "s1"); + } + + #[test] + fn test_search_sessions_empty_query() { + let sessions = vec![make_session("s1", "test", vec![])]; + let results = search_sessions(&sessions, ""); + assert!(results.is_empty()); + } + + #[test] + fn test_search_sessions_empty_sessions() { + let results = search_sessions(&[], "test query"); + assert!(results.is_empty()); + } + + #[test] + fn test_search_sessions_ranking_order() { + let sessions = vec![ + make_session( + "s1", + "Rust async", + vec![ + ("user", MessageRole::User, "Rust async Rust async Rust"), + ("assistant", MessageRole::Assistant, "async rust patterns"), + ], + ), + make_session( + "s2", + "General programming", + vec![("user", MessageRole::User, "What is async?")], + ), + make_session( + "s3", + "Unrelated", + vec![("user", MessageRole::User, "How to bake bread?")], + ), + ]; + + let results = search_sessions(&sessions, "Rust async"); + + assert!(!results.is_empty()); + assert!(results.len() <= 3); + + for window in results.windows(2) { + assert!(window[0].score() >= window[1].score()); + } + } + + #[test] + fn test_build_body_truncation() { + let long_content = "x".repeat(60_000); + let session = make_session( + "s1", + "test", + vec![("user", MessageRole::User, long_content.as_str())], + ); + + let body = build_body(&session); + assert_eq!(body.len(), MAX_BODY_LENGTH); + } + + #[test] + fn test_build_body_includes_metadata() { + let mut session = make_session("s1", "test", vec![]); + session.metadata.project_path = Some("/my/project".to_string()); + session.metadata.model = Some("claude-3".to_string()); + + let body = build_body(&session); + assert!(body.contains("/my/project")); + 
assert!(body.contains("claude-3")); + } + + #[test] + fn test_build_body_truncation_multibyte_utf8() { + let emoji = "🎉"; + let emoji_bytes = emoji.len(); + let count = (MAX_BODY_LENGTH / emoji_bytes) + 10; + let long_content: String = emoji.repeat(count); + let session = make_session( + "s1", + "test", + vec![("user", MessageRole::User, long_content.as_str())], + ); + + let body = build_body(&session); + assert!(body.len() <= MAX_BODY_LENGTH + emoji_bytes); + assert!(body.is_char_boundary(body.len())); + assert!(!body.is_empty()); + } +} diff --git a/crates/terraphim_sessions/src/service.rs b/crates/terraphim_sessions/src/service.rs index 3c4a4fef8..6b9f66cfc 100644 --- a/crates/terraphim_sessions/src/service.rs +++ b/crates/terraphim_sessions/src/service.rs @@ -178,43 +178,114 @@ impl SessionService { /// Search sessions by query string /// Auto-imports from available sources if cache is empty and auto-import is enabled + /// + /// When the `search-index` feature is enabled, uses BM25 scoring for + /// relevance-ranked results. Otherwise falls back to substring matching. pub async fn search(&self, query: &str) -> Vec { - // Try auto-import if needed + self.search_inner(query).await + } + + /// Hybrid search: KG concept matches from the role thesaurus rank first, BM25 fallback. + /// + /// Requires both `search-index` and `enrichment` features for the hybrid boost. + /// Without a thesaurus, falls back to plain BM25. 
+ #[cfg(feature = "enrichment")] + pub async fn search_with_thesaurus( + &self, + query: &str, + thesaurus: Option, + ) -> Vec { + self.search_inner_with_thesaurus(query, thesaurus).await + } + + async fn search_inner(&self, query: &str) -> Vec { if let Err(e) = self.maybe_auto_import().await { tracing::warn!("Auto-import check failed: {}", e); } let cache = self.cache.read().await; - let query_lower = query.to_lowercase(); + let sessions: Vec = cache.values().cloned().collect(); + drop(cache); - cache - .values() - .filter(|session| { - // Search in title - if let Some(title) = &session.title { - if title.to_lowercase().contains(&query_lower) { - return true; - } - } + #[cfg(feature = "search-index")] + { + let scored = crate::search::search_sessions(&sessions, query); + scored.into_iter().map(|s| s.into_value()).collect() + } - // Search in project path - if let Some(path) = &session.metadata.project_path { - if path.to_lowercase().contains(&query_lower) { - return true; + #[cfg(not(feature = "search-index"))] + { + let query_lower = query.to_lowercase(); + sessions + .into_iter() + .filter(|session| { + if let Some(title) = &session.title { + if title.to_lowercase().contains(&query_lower) { + return true; + } } - } - - // Search in message content - for msg in &session.messages { - if msg.content.to_lowercase().contains(&query_lower) { - return true; + if let Some(path) = &session.metadata.project_path { + if path.to_lowercase().contains(&query_lower) { + return true; + } } - } + for msg in &session.messages { + if msg.content.to_lowercase().contains(&query_lower) { + return true; + } + } + false + }) + .collect() + } + } - false - }) - .cloned() - .collect() + #[cfg(feature = "enrichment")] + async fn search_inner_with_thesaurus( + &self, + query: &str, + thesaurus: Option, + ) -> Vec { + if let Err(e) = self.maybe_auto_import().await { + tracing::warn!("Auto-import check failed: {}", e); + } + + let cache = self.cache.read().await; + let sessions: Vec = 
cache.values().cloned().collect(); + drop(cache); + + #[cfg(feature = "search-index")] + { + let scored = crate::search::search_sessions_hybrid(&sessions, query, thesaurus); + scored.into_iter().map(|s| s.into_value()).collect() + } + + #[cfg(not(feature = "search-index"))] + { + let _ = thesaurus; + let query_lower = query.to_lowercase(); + sessions + .into_iter() + .filter(|session| { + if let Some(title) = &session.title { + if title.to_lowercase().contains(&query_lower) { + return true; + } + } + if let Some(path) = &session.metadata.project_path { + if path.to_lowercase().contains(&query_lower) { + return true; + } + } + for msg in &session.messages { + if msg.content.to_lowercase().contains(&query_lower) { + return true; + } + } + false + }) + .collect() + } } /// Get sessions by source @@ -441,16 +512,47 @@ mod tests { #[tokio::test] async fn test_search_by_title() { + use crate::model::{Message, MessageRole}; let service = SessionService::new(); - let sessions = vec![ - make_test_session("s1", "test", vec![]), - make_test_session("s2", "test", vec![]), - ]; - service.load_sessions(sessions).await; + let s1 = { + let mut s = make_test_session( + "s1", + "test", + vec![Message::text( + 0, + MessageRole::User, + "Rust async programming help", + )], + ); + s.title = Some("Rust async programming".to_string()); + s + }; + let s2 = { + let mut s = make_test_session( + "s2", + "test", + vec![Message::text( + 0, + MessageRole::User, + "Python web scraping tutorial", + )], + ); + s.title = Some("Python web scraping".to_string()); + s + }; + service.load_sessions(vec![s1, s2]).await; - let results = service.search("Session s1").await; - assert_eq!(results.len(), 1); - assert_eq!(results[0].id, "s1"); + let results = service.search("Rust async").await; + #[cfg(feature = "search-index")] + { + assert!(!results.is_empty()); + assert_eq!(results[0].id, "s1"); + } + #[cfg(not(feature = "search-index"))] + { + assert_eq!(results.len(), 1); + assert_eq!(results[0].id, "s1"); 
+ } } #[tokio::test] @@ -488,8 +590,19 @@ mod tests { let sessions = vec![make_test_session("s1", "test", vec![])]; service.load_sessions(sessions).await; - let results = service.search("nonexistent-query").await; - assert!(results.is_empty()); + let results = service.search("xyzzy-zyzzyva-plugh").await; + #[cfg(feature = "search-index")] + { + // BM25 may return low-scoring results; verify top result is irrelevant + if !results.is_empty() { + // The top result should not be a strong match + assert!(results.iter().all(|r| r.id == "s1")); + } + } + #[cfg(not(feature = "search-index"))] + { + assert!(results.is_empty()); + } } #[tokio::test] diff --git a/crates/terraphim_types/Cargo.toml b/crates/terraphim_types/Cargo.toml index 8517bb9cb..a3ae8c68f 100644 --- a/crates/terraphim_types/Cargo.toml +++ b/crates/terraphim_types/Cargo.toml @@ -26,6 +26,7 @@ serde_json = { workspace = true } thiserror = { workspace = true } +strsim = "0.11.1" schemars = { version = "0.8.22", features = ["derive"] } tsify = { version = "0.5", features = ["js"], optional = true } wasm-bindgen = { version = "0.2", optional = true } diff --git a/crates/terraphim_types/src/lib.rs b/crates/terraphim_types/src/lib.rs index cd03edb11..e978a79be 100644 --- a/crates/terraphim_types/src/lib.rs +++ b/crates/terraphim_types/src/lib.rs @@ -97,6 +97,8 @@ pub mod shared_learning; pub mod capability; pub use capability::*; +pub mod score; + // MCP Tool types for self-learning system pub mod mcp_tool; pub use mcp_tool::*; diff --git a/crates/terraphim_service/src/score/bm25.rs b/crates/terraphim_types/src/score/bm25.rs similarity index 97% rename from crates/terraphim_service/src/score/bm25.rs rename to crates/terraphim_types/src/score/bm25.rs index 569376d75..85bb72cf1 100644 --- a/crates/terraphim_service/src/score/bm25.rs +++ b/crates/terraphim_types/src/score/bm25.rs @@ -2,9 +2,10 @@ use std::collections::HashMap; use std::f64; use super::common::{BM25Params, FieldWeights}; -use 
terraphim_types::Document; +use crate::Document; /// BM25F scorer implementation +#[derive(Default)] pub struct BM25FScorer { params: BM25Params, weights: FieldWeights, @@ -170,6 +171,7 @@ impl BM25FScorer { } /// BM25+ scorer implementation +#[derive(Default)] pub struct BM25PlusScorer { params: BM25Params, avg_doc_length: f64, @@ -308,7 +310,7 @@ mod tests { tags: Some(vec!["programming".to_string(), "systems".to_string()]), rank: None, source_haystack: None, - doc_type: terraphim_types::DocumentType::KgEntry, + doc_type: crate::DocumentType::KgEntry, synonyms: None, route: None, priority: None, @@ -324,7 +326,7 @@ mod tests { tags: Some(vec!["programming".to_string(), "tutorial".to_string()]), rank: None, source_haystack: None, - doc_type: terraphim_types::DocumentType::KgEntry, + doc_type: crate::DocumentType::KgEntry, synonyms: None, route: None, priority: None, @@ -363,7 +365,7 @@ mod tests { tags: Some(vec!["programming".to_string(), "systems".to_string()]), rank: None, source_haystack: None, - doc_type: terraphim_types::DocumentType::KgEntry, + doc_type: crate::DocumentType::KgEntry, synonyms: None, route: None, priority: None, @@ -379,7 +381,7 @@ mod tests { tags: Some(vec!["programming".to_string(), "tutorial".to_string()]), rank: None, source_haystack: None, - doc_type: terraphim_types::DocumentType::KgEntry, + doc_type: crate::DocumentType::KgEntry, synonyms: None, route: None, priority: None, diff --git a/crates/terraphim_service/src/score/bm25_additional.rs b/crates/terraphim_types/src/score/bm25_additional.rs similarity index 96% rename from crates/terraphim_service/src/score/bm25_additional.rs rename to crates/terraphim_types/src/score/bm25_additional.rs index 4a882d2e1..ea0470801 100644 --- a/crates/terraphim_service/src/score/bm25_additional.rs +++ b/crates/terraphim_types/src/score/bm25_additional.rs @@ -1,9 +1,10 @@ use std::collections::HashMap; use super::common::BM25Params; -use terraphim_types::Document; +use crate::Document; /// Okapi BM25 
scorer implementation +#[derive(Default)] pub struct OkapiBM25Scorer { params: BM25Params, avg_doc_length: f64, @@ -114,6 +115,7 @@ impl OkapiBM25Scorer { } /// TFIDF scorer implementation +#[derive(Default)] pub struct TFIDFScorer { doc_count: usize, term_doc_frequencies: HashMap, @@ -189,6 +191,7 @@ impl TFIDFScorer { } /// Jaccard similarity scorer implementation +#[derive(Default)] pub struct JaccardScorer { doc_count: usize, } @@ -242,6 +245,7 @@ impl JaccardScorer { } /// QueryRatio scorer implementation +#[derive(Default)] pub struct QueryRatioScorer { doc_count: usize, } @@ -319,7 +323,7 @@ mod tests { tags: Some(vec!["programming".to_string(), "systems".to_string()]), rank: None, source_haystack: None, - doc_type: terraphim_types::DocumentType::KgEntry, + doc_type: crate::DocumentType::KgEntry, synonyms: None, route: None, priority: None, @@ -335,7 +339,7 @@ mod tests { tags: Some(vec!["programming".to_string(), "tutorial".to_string()]), rank: None, source_haystack: None, - doc_type: terraphim_types::DocumentType::KgEntry, + doc_type: crate::DocumentType::KgEntry, synonyms: None, route: None, priority: None, @@ -374,7 +378,7 @@ mod tests { tags: Some(vec!["programming".to_string(), "systems".to_string()]), rank: None, source_haystack: None, - doc_type: terraphim_types::DocumentType::KgEntry, + doc_type: crate::DocumentType::KgEntry, synonyms: None, route: None, priority: None, @@ -390,7 +394,7 @@ mod tests { tags: Some(vec!["programming".to_string(), "tutorial".to_string()]), rank: None, source_haystack: None, - doc_type: terraphim_types::DocumentType::KgEntry, + doc_type: crate::DocumentType::KgEntry, synonyms: None, route: None, priority: None, @@ -429,7 +433,7 @@ mod tests { tags: Some(vec!["programming".to_string(), "systems".to_string()]), rank: None, source_haystack: None, - doc_type: terraphim_types::DocumentType::KgEntry, + doc_type: crate::DocumentType::KgEntry, synonyms: None, route: None, priority: None, @@ -445,7 +449,7 @@ mod tests { tags: 
Some(vec!["programming".to_string(), "tutorial".to_string()]), rank: None, source_haystack: None, - doc_type: terraphim_types::DocumentType::KgEntry, + doc_type: crate::DocumentType::KgEntry, synonyms: None, route: None, priority: None, @@ -484,7 +488,7 @@ mod tests { tags: Some(vec!["programming".to_string(), "systems".to_string()]), rank: None, source_haystack: None, - doc_type: terraphim_types::DocumentType::KgEntry, + doc_type: crate::DocumentType::KgEntry, synonyms: None, route: None, priority: None, @@ -500,7 +504,7 @@ mod tests { tags: Some(vec!["programming".to_string(), "tutorial".to_string()]), rank: None, source_haystack: None, - doc_type: terraphim_types::DocumentType::KgEntry, + doc_type: crate::DocumentType::KgEntry, synonyms: None, route: None, priority: None, diff --git a/crates/terraphim_service/src/score/common.rs b/crates/terraphim_types/src/score/common.rs similarity index 100% rename from crates/terraphim_service/src/score/common.rs rename to crates/terraphim_types/src/score/common.rs diff --git a/crates/terraphim_types/src/score/mod.rs b/crates/terraphim_types/src/score/mod.rs new file mode 100644 index 000000000..ae3bd1511 --- /dev/null +++ b/crates/terraphim_types/src/score/mod.rs @@ -0,0 +1,266 @@ +pub mod bm25; +pub mod bm25_additional; +pub mod common; +pub mod names; +mod scored; + +pub use bm25::{BM25FScorer, BM25PlusScorer}; +pub use bm25_additional::{JaccardScorer, OkapiBM25Scorer, QueryRatioScorer, TFIDFScorer}; +pub use names::QueryScorer; +pub use scored::{Scored, SearchResults}; + +use std::f64; +use std::fmt; + +use serde::{Serialize, Serializer}; + +use crate::Document; + +pub fn sort_documents(query: &Query, documents: Vec) -> Vec { + let mut scorer = Scorer::new().with_similarity(query.similarity); + + match query.name_scorer { + QueryScorer::BM25 => { + let mut bm25_scorer = OkapiBM25Scorer::new(); + bm25_scorer.initialize(&documents); + scorer = scorer.with_scorer(Box::new(bm25_scorer)); + } + QueryScorer::BM25F => { + let 
mut bm25f_scorer = BM25FScorer::new(); + bm25f_scorer.initialize(&documents); + scorer = scorer.with_scorer(Box::new(bm25f_scorer)); + } + QueryScorer::BM25Plus => { + let mut bm25plus_scorer = BM25PlusScorer::new(); + bm25plus_scorer.initialize(&documents); + scorer = scorer.with_scorer(Box::new(bm25plus_scorer)); + } + QueryScorer::Tfidf => { + let mut tfidf_scorer = TFIDFScorer::new(); + tfidf_scorer.initialize(&documents); + scorer = scorer.with_scorer(Box::new(tfidf_scorer)); + } + QueryScorer::Jaccard => { + let mut jaccard_scorer = JaccardScorer::new(); + jaccard_scorer.initialize(&documents); + scorer = scorer.with_scorer(Box::new(jaccard_scorer)); + } + QueryScorer::QueryRatio => { + let mut query_ratio_scorer = QueryRatioScorer::new(); + query_ratio_scorer.initialize(&documents); + scorer = scorer.with_scorer(Box::new(query_ratio_scorer)); + } + _ => {} + } + + match scorer.score_documents(query, documents.clone()) { + Ok(results) => results + .into_iter() + .map(|scored| scored.into_value()) + .collect(), + Err(_) => documents, + } +} + +#[derive(Debug, Default)] +pub struct Scorer { + similarity: Similarity, + scorer: Option>, +} + +impl Scorer { + pub fn new() -> Scorer { + Scorer::default() + } + + pub fn with_similarity(mut self, similarity: Similarity) -> Scorer { + self.similarity = similarity; + self + } + + pub fn with_scorer(mut self, scorer: Box) -> Scorer { + self.scorer = Some(scorer); + self + } + + pub fn score( + &mut self, + query: &Query, + documents: Vec, + ) -> Result, ScoreError> { + if query.is_empty() { + return Ok(SearchResults::new()); + } + let mut results = self.score_documents(query, documents)?; + results.trim(query.size); + results.normalize(); + Ok(results) + } + + fn score_documents( + &mut self, + query: &Query, + documents: Vec, + ) -> Result, ScoreError> { + let mut results = SearchResults::new(); + for document in documents { + results.push(Scored::new(document)); + } + + match query.name_scorer { + QueryScorer::BM25 => 
{ + if let Some(scorer) = &self.scorer { + if let Some(bm25_scorer) = scorer.downcast_ref::() { + results.rescore(|document| bm25_scorer.score(&query.name, document)); + } + } + } + QueryScorer::BM25F => { + if let Some(scorer) = &self.scorer { + if let Some(bm25f_scorer) = scorer.downcast_ref::() { + results.rescore(|document| bm25f_scorer.score(&query.name, document)); + } + } + } + QueryScorer::BM25Plus => { + if let Some(scorer) = &self.scorer { + if let Some(bm25plus_scorer) = scorer.downcast_ref::() { + results.rescore(|document| bm25plus_scorer.score(&query.name, document)); + } + } + } + QueryScorer::Tfidf => { + if let Some(scorer) = &self.scorer { + if let Some(tfidf_scorer) = scorer.downcast_ref::() { + results.rescore(|document| tfidf_scorer.score(&query.name, document)); + } + } + } + QueryScorer::Jaccard => { + if let Some(scorer) = &self.scorer { + if let Some(jaccard_scorer) = scorer.downcast_ref::() { + results.rescore(|document| jaccard_scorer.score(&query.name, document)); + } + } + } + QueryScorer::QueryRatio => { + if let Some(scorer) = &self.scorer { + if let Some(query_ratio_scorer) = scorer.downcast_ref::() { + results.rescore(|document| query_ratio_scorer.score(&query.name, document)); + } + } + } + _ => { + log::debug!("Similarity {:?}", query.similarity); + log::debug!("Query {:?}", query); + results.rescore(|document| self.similarity(query, &document.title)); + } + } + + log::debug!("results after rescoring: {:#?}", results); + Ok(results) + } + + fn similarity(&self, query: &Query, name: &str) -> f64 { + log::debug!("Similarity {:?}", query.similarity); + log::debug!("Query {:?}", query); + log::debug!("Name {:?}", name); + let result = query.similarity.similarity(&query.name, name); + log::debug!("Similarity calculation {:?}", result); + result + } +} + +#[derive(Debug, thiserror::Error)] +pub enum ScoreError { + #[error("scoring error: {0}")] + Scoring(String), +} + +#[derive(Clone, Debug, Eq, Hash, PartialEq)] +pub struct Query { + 
pub name: String, + pub name_scorer: QueryScorer, + pub similarity: Similarity, + pub size: usize, +} + +impl Query { + pub fn new(name: &str) -> Query { + Query { + name: name.to_string(), + name_scorer: QueryScorer::default(), + similarity: Similarity::default(), + size: 30, + } + } + + pub fn is_empty(&self) -> bool { + self.name.is_empty() + } + + pub fn name_scorer(mut self, scorer: QueryScorer) -> Query { + self.name_scorer = scorer; + self + } + + pub fn similarity(mut self, sim: Similarity) -> Query { + self.similarity = sim; + self + } +} + +impl Serialize for Query { + fn serialize(&self, s: S) -> std::result::Result { + s.serialize_str(&self.to_string()) + } +} + +impl fmt::Display for Query { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{{scorer:{}}}", self.name_scorer)?; + write!(f, " {{sim:{}}}", self.similarity)?; + write!(f, " {{size:{}}}", self.size)?; + write!(f, " {}", self.name)?; + Ok(()) + } +} + +#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq, Default)] +pub enum Similarity { + #[default] + None, + Levenshtein, + Jaro, + JaroWinkler, +} + +impl Similarity { + pub fn similarity(&self, q1: &str, q2: &str) -> f64 { + let sim = match *self { + Similarity::None => 1.0, + Similarity::Levenshtein => { + let distance = strsim::levenshtein(q1, q2) as f64; + 1.0 / (1.0 + distance) + } + Similarity::Jaro => strsim::jaro(q1, q2), + Similarity::JaroWinkler => strsim::jaro_winkler(q1, q2), + }; + if sim < f64::EPSILON { + f64::EPSILON + } else { + sim + } + } +} + +impl fmt::Display for Similarity { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match *self { + Similarity::None => write!(f, "none"), + Similarity::Levenshtein => write!(f, "levenshtein"), + Similarity::Jaro => write!(f, "jaro"), + Similarity::JaroWinkler => write!(f, "jarowinkler"), + } + } +} diff --git a/crates/terraphim_service/src/score/names.rs b/crates/terraphim_types/src/score/names.rs similarity index 100% rename from 
crates/terraphim_service/src/score/names.rs rename to crates/terraphim_types/src/score/names.rs diff --git a/crates/terraphim_service/src/score/scored.rs b/crates/terraphim_types/src/score/scored.rs similarity index 100% rename from crates/terraphim_service/src/score/scored.rs rename to crates/terraphim_types/src/score/scored.rs