diff --git a/benches/agg_bench.rs b/benches/agg_bench.rs index 2f47bc0e4e..a9f3ce7f4b 100644 --- a/benches/agg_bench.rs +++ b/benches/agg_bench.rs @@ -1,5 +1,6 @@ use binggan::plugins::PeakMemAllocPlugin; use binggan::{black_box, InputGroup, PeakMemAlloc, INSTRUMENTED_SYSTEM}; +use common::DateTime; use rand::prelude::SliceRandom; use rand::rngs::StdRng; use rand::{Rng, SeedableRng}; @@ -61,6 +62,12 @@ fn bench_agg(mut group: InputGroup) { register!(group, terms_many_with_avg_sub_agg); register!(group, terms_many_json_mixed_type_with_avg_sub_agg); + register!(group, composite_term_many_page_1000); + register!(group, composite_term_many_page_1000_with_avg_sub_agg); + register!(group, composite_term_few); + register!(group, composite_histogram); + register!(group, composite_histogram_calendar); + register!(group, cardinality_agg); register!(group, terms_few_with_cardinality_agg); @@ -225,6 +232,75 @@ fn terms_many_json_mixed_type_with_avg_sub_agg(index: &Index) { }); execute_agg(index, agg_req); } +fn composite_term_few(index: &Index) { + let agg_req = json!({ + "my_ctf": { + "composite": { + "sources": [ + { "text_few_terms": { "terms": { "field": "text_few_terms" } } } + ], + "size": 1000 + } + }, + }); + execute_agg(index, agg_req); +} +fn composite_term_many_page_1000(index: &Index) { + let agg_req = json!({ + "my_ctmp1000": { + "composite": { + "sources": [ + { "text_many_terms": { "terms": { "field": "text_many_terms" } } } + ], + "size": 1000 + } + }, + }); + execute_agg(index, agg_req); +} +fn composite_term_many_page_1000_with_avg_sub_agg(index: &Index) { + let agg_req = json!({ + "my_ctmp1000wasa": { + "composite": { + "sources": [ + { "text_many_terms": { "terms": { "field": "text_many_terms" } } } + ], + "size": 1000, + + }, + "aggs": { + "average_f64": { "avg": { "field": "score_f64" } } + } + }, + }); + execute_agg(index, agg_req); +} +fn composite_histogram(index: &Index) { + let agg_req = json!({ + "my_ch": { + "composite": { + "sources": [ + { "f64_histogram": { "histogram": { "field": "score_f64", "interval": 1 } } } + ], + "size": 1000 + } + }, + }); + execute_agg(index, agg_req); +} +fn composite_histogram_calendar(index: &Index) { + let agg_req = json!({ + "my_chc": { + "composite": { + "sources": [ + { "time_histogram": { "date_histogram": { "field": "timestamp", "calendar_interval": "month" } } } + ], + "size": 1000 + } + }, + }); + execute_agg(index, agg_req); +} fn execute_agg(index: &Index, agg_req: serde_json::Value) { let agg_req: Aggregations = serde_json::from_value(agg_req).unwrap(); @@ -404,6 +480,7 @@ fn get_test_index_bench(cardinality: Cardinality) -> tantivy::Result { let score_field = schema_builder.add_u64_field("score", score_fieldtype.clone()); let score_field_f64 = schema_builder.add_f64_field("score_f64", score_fieldtype.clone()); let score_field_i64 = schema_builder.add_i64_field("score_i64", score_fieldtype); + let date_field = schema_builder.add_date_field("timestamp", FAST); let index = Index::create_from_tempdir(schema_builder.build())?; let few_terms_data = ["INFO", "ERROR", "WARN", "DEBUG"]; @@ -459,6 +536,7 @@ fn get_test_index_bench(cardinality: Cardinality) -> tantivy::Result { score_field => val as u64, score_field_f64 => lg_norm.sample(&mut rng), score_field_i64 => val as i64, + date_field => DateTime::from_timestamp_millis((val * 1_000_000.) 
as i64), ))?; if cardinality == Cardinality::OptionalSparse { for _ in 0..20 { diff --git a/columnar/src/column_values/mod.rs b/columnar/src/column_values/mod.rs index f26bf6d337..cdab3be3e0 100644 --- a/columnar/src/column_values/mod.rs +++ b/columnar/src/column_values/mod.rs @@ -31,7 +31,7 @@ pub use u64_based::{ serialize_and_load_u64_based_column_values, serialize_u64_based_column_values, }; pub use u128_based::{ - CompactSpaceU64Accessor, open_u128_as_compact_u64, open_u128_mapped, + CompactHit, CompactSpaceU64Accessor, open_u128_as_compact_u64, open_u128_mapped, serialize_column_values_u128, }; pub use vec_column::VecColumn; diff --git a/columnar/src/column_values/u128_based/compact_space/mod.rs b/columnar/src/column_values/u128_based/compact_space/mod.rs index 2c815bdce7..8255456ee5 100644 --- a/columnar/src/column_values/u128_based/compact_space/mod.rs +++ b/columnar/src/column_values/u128_based/compact_space/mod.rs @@ -292,6 +292,19 @@ impl BinarySerializable for IPCodecParams { } } +/// Represents the result of looking up a u128 value in the compact space. +/// +/// If a value is outside the compact space, the next compact value is returned. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum CompactHit { + /// The value exists in the compact space + Exact(u32), + /// The value does not exist in the compact space, but the next higher value does + Next(u32), + /// The value is greater than the maximum compact value + AfterLast, +} + /// Exposes the compact space compressed values as u64. /// /// This allows faster access to the values, as u64 is faster to work with than u128. @@ -309,6 +322,11 @@ impl CompactSpaceU64Accessor { pub fn compact_to_u128(&self, compact: u32) -> u128 { self.0.compact_to_u128(compact) } + + /// Finds the next compact space value for a given u128 value. + pub fn u128_to_next_compact(&self, value: u128) -> CompactHit { + self.0.u128_to_next_compact(value) + } } impl ColumnValues for CompactSpaceU64Accessor { @@ -430,6 +448,26 @@ impl CompactSpaceDecompressor { Ok(decompressor) } + /// Finds the next compact space value for a given u128 value + pub fn u128_to_next_compact(&self, value: u128) -> CompactHit { + // Try to convert to compact space + match self.u128_to_compact(value) { + // Value is in compact space, return its compact representation + Ok(compact) => CompactHit::Exact(compact), + // Value is not in compact space + Err(pos) => { + if pos >= self.params.compact_space.ranges_mapping.len() { + // Value is beyond all ranges, no next value exists + CompactHit::AfterLast + } else { + // Get the next range and return its start compact value + let next_range = &self.params.compact_space.ranges_mapping[pos]; + CompactHit::Next(next_range.compact_start) + } + } + } + } + /// Converting to compact space for the decompressor is more complex, since we may get values /// which are outside the compact space. e.g. 
if we map /// 1000 => 5 @@ -823,6 +861,41 @@ mod tests { let _data = test_aux_vals(vals); } + #[test] + fn test_u128_to_next_compact() { + let vals = &[100u128, 200u128, 1_000_000_000u128, 1_000_000_100u128]; + let mut data = test_aux_vals(vals); + + let _header = U128Header::deserialize(&mut data); + let decomp = CompactSpaceDecompressor::open(data).unwrap(); + + // Test value that's already in a range + let compact_100 = decomp.u128_to_compact(100).unwrap(); + assert_eq!( + decomp.u128_to_next_compact(100), + CompactHit::Exact(compact_100) + ); + + // Test value between two ranges + let compact_million = decomp.u128_to_compact(1_000_000_000).unwrap(); + assert_eq!( + decomp.u128_to_next_compact(250), + CompactHit::Next(compact_million) + ); + + // Test value before the first range + assert_eq!( + decomp.u128_to_next_compact(50), + CompactHit::Next(compact_100) + ); + + // Test value after the last range + assert_eq!( + decomp.u128_to_next_compact(10_000_000_000), + CompactHit::AfterLast + ); + } + use proptest::prelude::*; fn num_strategy() -> impl Strategy { diff --git a/columnar/src/column_values/u128_based/mod.rs b/columnar/src/column_values/u128_based/mod.rs index 62e9a1f929..d26f5ce353 100644 --- a/columnar/src/column_values/u128_based/mod.rs +++ b/columnar/src/column_values/u128_based/mod.rs @@ -7,7 +7,7 @@ mod compact_space; use common::{BinarySerializable, OwnedBytes, VInt}; pub use compact_space::{ - CompactSpaceCompressor, CompactSpaceDecompressor, CompactSpaceU64Accessor, + CompactHit, CompactSpaceCompressor, CompactSpaceDecompressor, CompactSpaceU64Accessor, }; use crate::column_values::monotonic_map_column; diff --git a/columnar/src/lib.rs b/columnar/src/lib.rs index 0925d912de..2055aceb33 100644 --- a/columnar/src/lib.rs +++ b/columnar/src/lib.rs @@ -59,7 +59,7 @@ pub struct RowAddr { pub row_id: RowId, } -pub use sstable::Dictionary; +pub use sstable::{Dictionary, TermOrdHit}; pub type Streamer<'a> = sstable::Streamer<'a, VoidSSTable>; pub use common::DateTime; diff --git a/src/aggregation/accessor_helpers.rs b/src/aggregation/accessor_helpers.rs index eb44a734b9..439bb455fb 100644 --- a/src/aggregation/accessor_helpers.rs +++ b/src/aggregation/accessor_helpers.rs @@ -94,11 +94,21 @@ pub(crate) fn get_all_ff_reader_or_empty( allowed_column_types: Option<&[ColumnType]>, fallback_type: ColumnType, ) -> crate::Result, ColumnType)>> { - let ff_fields = reader.fast_fields(); - let mut ff_field_with_type = - ff_fields.u64_lenient_for_type_all(allowed_column_types, field_name)?; + let mut ff_field_with_type = get_all_ff_readers(reader, field_name, allowed_column_types)?; if ff_field_with_type.is_empty() { ff_field_with_type.push((Column::build_empty_column(reader.num_docs()), fallback_type)); } Ok(ff_field_with_type) } + +/// Get all fast field reader. 
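The `CompactHit` returned by `u128_to_next_compact` above is consumed later by the IP after-key precomputation; below is a minimal standalone sketch (a hypothetical helper, not part of this change) of folding the three cases into an optional inclusive lower bound in compact space.

```rust
use columnar::column_values::CompactHit;

// Hypothetical helper: interpret a `CompactHit` as an inclusive lower bound,
// mirroring how the after-key logic treats `Next` ("start at the next mapped
// value") and `AfterLast` ("no remaining value qualifies").
fn compact_lower_bound(hit: CompactHit) -> Option<u32> {
    match hit {
        CompactHit::Exact(compact) | CompactHit::Next(compact) => Some(compact),
        CompactHit::AfterLast => None,
    }
}

fn main() {
    assert_eq!(compact_lower_bound(CompactHit::Exact(7)), Some(7));
    assert_eq!(compact_lower_bound(CompactHit::Next(8)), Some(8));
    assert_eq!(compact_lower_bound(CompactHit::AfterLast), None);
}
```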
+pub(crate) fn get_all_ff_readers( + reader: &SegmentReader, + field_name: &str, + allowed_column_types: Option<&[ColumnType]>, +) -> crate::Result, ColumnType)>> { + let ff_fields = reader.fast_fields(); + let ff_field_with_type = + ff_fields.u64_lenient_for_type_all(allowed_column_types, field_name)?; + Ok(ff_field_with_type) +} diff --git a/src/aggregation/agg_data.rs b/src/aggregation/agg_data.rs index c7dc2e4e60..7117760373 100644 --- a/src/aggregation/agg_data.rs +++ b/src/aggregation/agg_data.rs @@ -9,10 +9,12 @@ use crate::aggregation::accessor_helpers::{ get_numeric_or_date_column_types, }; use crate::aggregation::agg_req::{Aggregation, AggregationVariants, Aggregations}; +pub use crate::aggregation::bucket::{CompositeAggReqData, CompositeSourceAccessors}; use crate::aggregation::bucket::{ - HistogramAggReqData, HistogramBounds, IncludeExcludeParam, MissingTermAggReqData, - RangeAggReqData, SegmentHistogramCollector, SegmentRangeCollector, SegmentTermCollector, - TermMissingAgg, TermsAggReqData, TermsAggregation, TermsAggregationInternal, + CompositeAggregation, HistogramAggReqData, HistogramBounds, IncludeExcludeParam, + MissingTermAggReqData, RangeAggReqData, SegmentCompositeCollector, SegmentHistogramCollector, + SegmentRangeCollector, SegmentTermCollector, TermMissingAgg, TermsAggReqData, TermsAggregation, + TermsAggregationInternal, }; use crate::aggregation::metric::{ AverageAggregation, CardinalityAggReqData, CardinalityAggregationReq, CountAggregation, @@ -67,6 +69,12 @@ impl AggregationsSegmentCtx { self.per_request.range_req_data.push(Some(Box::new(data))); self.per_request.range_req_data.len() - 1 } + pub(crate) fn push_composite_req_data(&mut self, data: CompositeAggReqData) -> usize { + self.per_request + .composite_req_data + .push(Some(Box::new(data))); + self.per_request.composite_req_data.len() - 1 + } #[inline] pub(crate) fn get_term_req_data(&self, idx: usize) -> &TermsAggReqData { @@ -102,6 +110,12 @@ impl AggregationsSegmentCtx { .as_deref() .expect("range_req_data slot is empty (taken)") } + #[inline] + pub(crate) fn get_composite_req_data(&self, idx: usize) -> &CompositeAggReqData { + self.per_request.composite_req_data[idx] + .as_deref() + .expect("composite_req_data slot is empty (taken)") + } // ---------- mutable getters ---------- @@ -128,8 +142,14 @@ impl AggregationsSegmentCtx { .as_deref_mut() .expect("histogram_req_data slot is empty (taken)") } + #[inline] + pub(crate) fn get_composite_req_data_mut(&mut self, idx: usize) -> &mut CompositeAggReqData { + self.per_request.composite_req_data[idx] + .as_deref_mut() + .expect("composite_req_data slot is empty (taken)") + } - // ---------- take / put (terms, histogram, range) ---------- + // ---------- take / put (terms, histogram, range, composite) ---------- /// Move out the boxed Terms request at `idx`, leaving `None`. #[inline] @@ -179,6 +199,25 @@ impl AggregationsSegmentCtx { debug_assert!(self.per_request.range_req_data[idx].is_none()); self.per_request.range_req_data[idx] = Some(value); } + + /// Move out the Composite request at `idx`. + #[inline] + pub(crate) fn take_composite_req_data(&mut self, idx: usize) -> Box { + self.per_request.composite_req_data[idx] + .take() + .expect("composite_req_data slot is empty (taken)") + } + + /// Put back a Composite request into an empty slot at `idx`. 
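`take_composite_req_data` / `put_back_composite_req_data` below follow the same slot pattern already used for the terms/histogram/range variants; here is a minimal standalone sketch of that pattern with generic stand-in types (not the crate's own).

```rust
// Each request's data lives in a `Vec<Option<Box<T>>>` slot: it is moved out
// while the collector needs mutable access to the rest of the context, and is
// put back into the same (now empty) slot afterwards.
fn main() {
    let mut slots: Vec<Option<Box<String>>> = vec![Some(Box::new("req data".to_string()))];
    let data = slots[0].take().expect("slot is empty (taken)");
    // ... use `data` while freely mutating other parts of the context ...
    assert!(slots[0].is_none());
    slots[0] = Some(data);
}
```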
+ #[inline] + pub(crate) fn put_back_composite_req_data( + &mut self, + idx: usize, + value: Box, + ) { + debug_assert!(self.per_request.composite_req_data[idx].is_none()); + self.per_request.composite_req_data[idx] = Some(value); + } } /// Each type of aggregation has its own request data struct. This struct holds @@ -196,6 +235,8 @@ pub struct PerRequestAggSegCtx { pub histogram_req_data: Vec>>, /// RangeAggReqData contains the request data for a range aggregation. pub range_req_data: Vec>>, + /// CompositeAggReqData contains the request data for a composite aggregation. + pub composite_req_data: Vec>>, /// Shared by avg, min, max, sum, stats, extended_stats, count pub stats_metric_req_data: Vec, /// CardinalityAggReqData contains the request data for a cardinality aggregation. @@ -246,6 +287,11 @@ impl PerRequestAggSegCtx { .iter() .map(|t| t.get_memory_consumption()) .sum::() + + self + .composite_req_data + .iter() + .map(|t| t.as_ref().unwrap().get_memory_consumption()) + .sum::() + self.agg_tree.len() * std::mem::size_of::() } @@ -277,6 +323,11 @@ impl PerRequestAggSegCtx { .expect("range_req_data slot is empty (taken)") .name .as_str(), + AggKind::Composite => &self.composite_req_data[idx] + .as_deref() + .expect("composite_req_data slot is empty (taken)") + .name + .as_str(), } } @@ -394,6 +445,9 @@ pub(crate) fn build_segment_agg_collector( AggKind::Range => Ok(Box::new(SegmentRangeCollector::from_req_and_validate( req, node, )?)), + AggKind::Composite => Ok(Box::new(SegmentCompositeCollector::from_req_and_validate( + req, node, + )?)), } } @@ -423,6 +477,7 @@ pub enum AggKind { Histogram, DateHistogram, Range, + Composite, } impl AggKind { @@ -437,6 +492,7 @@ impl AggKind { AggKind::Histogram => "Histogram", AggKind::DateHistogram => "DateHistogram", AggKind::Range => "Range", + AggKind::Composite => "Composite", } } } @@ -500,6 +556,7 @@ fn build_nodes( field_type, column_block_accessor: Default::default(), name: agg_name.to_string(), + // will be filled later when building collectors sub_aggregation_blueprint: None, req: histo_req.clone(), is_date_histogram: false, @@ -527,6 +584,7 @@ fn build_nodes( field_type, column_block_accessor: Default::default(), name: agg_name.to_string(), + // will be filled later when building collectors sub_aggregation_blueprint: None, req: histo_req, is_date_histogram: true, @@ -686,6 +744,14 @@ fn build_nodes( children, }]) } + AggregationVariants::Composite(composite_req) => Ok(vec![build_composite_node( + agg_name, + reader, + segment_ordinal, + data, + &req.sub_aggregation, + composite_req, + )?]), } } @@ -876,6 +942,37 @@ fn build_terms_or_cardinality_nodes( Ok(nodes) } +fn build_composite_node( + agg_name: &str, + reader: &SegmentReader, + segment_ordinal: SegmentOrdinal, + data: &mut AggregationsSegmentCtx, + sub_aggs: &Aggregations, + req: &CompositeAggregation, +) -> crate::Result { + let mut composite_accessors = Vec::with_capacity(req.sources.len()); + for source in &req.sources { + let source_after_key_opt = req.after.get(source.name()); + let source_accessor = + CompositeSourceAccessors::build_for_source(reader, source, source_after_key_opt)?; + composite_accessors.push(source_accessor); + } + let agg = CompositeAggReqData { + name: agg_name.to_string(), + req: req.clone(), + composite_accessors, + // fields below will be filled later when building collectors + sub_aggregation_blueprint: None, + }; + let idx = data.push_composite_req_data(agg); + let children = build_children(sub_aggs, reader, segment_ordinal, data)?; + Ok(AggRefNode 
{ + kind: AggKind::Composite, + idx_in_req_data: idx, + children, + }) +} + /// Builds a single BitSet of allowed term ordinals for a string dictionary column according to /// include/exclude parameters. fn build_allowed_term_ids_for_str( diff --git a/src/aggregation/agg_req.rs b/src/aggregation/agg_req.rs index e5dfed85a5..3228fd17af 100644 --- a/src/aggregation/agg_req.rs +++ b/src/aggregation/agg_req.rs @@ -39,6 +39,7 @@ use super::metric::{ MaxAggregation, MinAggregation, PercentilesAggregationReq, StatsAggregation, SumAggregation, TopHitsAggregationReq, }; +use crate::aggregation::bucket::CompositeAggregation; /// The top-level aggregation request structure, which contains [`Aggregation`] and their user /// defined names. It is also used in buckets aggregations to define sub-aggregations. @@ -130,6 +131,9 @@ pub enum AggregationVariants { /// Put data into buckets of terms. #[serde(rename = "terms")] Terms(TermsAggregation), + /// Put data into multi level paginated buckets. + #[serde(rename = "composite")] + Composite(CompositeAggregation), // Metric aggregation types /// Computes the average of the extracted values. @@ -175,6 +179,11 @@ impl AggregationVariants { AggregationVariants::Range(range) => vec![range.field.as_str()], AggregationVariants::Histogram(histogram) => vec![histogram.field.as_str()], AggregationVariants::DateHistogram(histogram) => vec![histogram.field.as_str()], + AggregationVariants::Composite(composite) => composite + .sources + .iter() + .map(|source_map| source_map.field()) + .collect(), AggregationVariants::Average(avg) => vec![avg.field_name()], AggregationVariants::Count(count) => vec![count.field_name()], AggregationVariants::Max(max) => vec![max.field_name()], @@ -209,6 +218,14 @@ impl AggregationVariants { _ => None, } } + + pub(crate) fn as_composite(&self) -> Option<&CompositeAggregation> { + match &self { + AggregationVariants::Composite(composite) => Some(composite), + _ => None, + } + } + pub(crate) fn as_percentile(&self) -> Option<&PercentilesAggregationReq> { match &self { AggregationVariants::Percentiles(percentile_req) => Some(percentile_req), diff --git a/src/aggregation/agg_result.rs b/src/aggregation/agg_result.rs index 0055681825..0095548fe0 100644 --- a/src/aggregation/agg_result.rs +++ b/src/aggregation/agg_result.rs @@ -13,6 +13,7 @@ use super::metric::{ ExtendedStats, PercentilesMetricResult, SingleMetricResult, Stats, TopHitsMetricResult, }; use super::{AggregationError, Key}; +use crate::aggregation::intermediate_agg_result::CompositeIntermediateKey; use crate::TantivyError; #[derive(Clone, Default, Debug, PartialEq, Serialize, Deserialize)] @@ -156,6 +157,16 @@ pub enum BucketResult { /// The upper bound error for the doc count of each term. doc_count_error_upper_bound: Option, }, + /// This is the composite aggregation result + Composite { + /// The buckets + /// + /// See [`CompositeAggregation`](super::bucket::CompositeAggregation) + buckets: Vec, + /// The key to start after when paginating + #[serde(skip_serializing_if = "FxHashMap::is_empty")] + after_key: FxHashMap, + }, } impl BucketResult { @@ -172,6 +183,9 @@ impl BucketResult { sum_other_doc_count: _, doc_count_error_upper_bound: _, } => buckets.iter().map(|bucket| bucket.get_bucket_count()).sum(), + BucketResult::Composite { buckets, .. } => { + buckets.iter().map(|bucket| bucket.get_bucket_count()).sum() + } } } } @@ -308,3 +322,131 @@ impl RangeBucketEntry { 1 + self.sub_aggregation.get_bucket_count() } } + +/// The JSON mappable key to identify a composite bucket. 
+/// +/// This is similar to `Key`, but composite keys can also be boolean and null. +#[derive(Clone, Debug, Serialize, Deserialize)] +#[serde(untagged)] +pub enum CompositeKey { + /// Boolean key + Bool(bool), + /// String key + Str(String), + /// `i64` key + I64(i64), + /// `u64` key + U64(u64), + /// `f64` key + F64(f64), + /// Null key + Null, +} +impl Eq for CompositeKey {} +impl std::hash::Hash for CompositeKey { + fn hash(&self, state: &mut H) { + core::mem::discriminant(self).hash(state); + match self { + Self::Bool(val) => val.hash(state), + Self::Str(text) => text.hash(state), + Self::F64(val) => val.to_bits().hash(state), + Self::U64(val) => val.hash(state), + Self::I64(val) => val.hash(state), + Self::Null => {} + } + } +} +impl PartialEq for CompositeKey { + fn eq(&self, other: &Self) -> bool { + match (self, other) { + (Self::Bool(l), Self::Bool(r)) => l == r, + (Self::Str(l), Self::Str(r)) => l == r, + (Self::F64(l), Self::F64(r)) => l.to_bits() == r.to_bits(), + (Self::I64(l), Self::I64(r)) => l == r, + (Self::U64(l), Self::U64(r)) => l == r, + (Self::Null, Self::Null) => true, + ( + Self::Bool(_) + | Self::Str(_) + | Self::F64(_) + | Self::I64(_) + | Self::U64(_) + | Self::Null, + _, + ) => false, + } + } +} +impl From for CompositeKey { + fn from(value: CompositeIntermediateKey) -> Self { + match value { + CompositeIntermediateKey::Str(s) => Self::Str(s), + CompositeIntermediateKey::IpAddr(s) => { + // Prefer to use the IPv4 representation if possible + if let Some(ip) = s.to_ipv4_mapped() { + Self::Str(ip.to_string()) + } else { + Self::Str(s.to_string()) + } + } + CompositeIntermediateKey::F64(f) => Self::F64(f), + CompositeIntermediateKey::Bool(f) => Self::Bool(f), + CompositeIntermediateKey::U64(f) => Self::U64(f), + CompositeIntermediateKey::I64(f) => Self::I64(f), + CompositeIntermediateKey::DateTime(f) => Self::I64(f), + CompositeIntermediateKey::Null => Self::Null, + } + } +} + +/// This is the default entry for a bucket, which contains a composite key, count, and optionally +/// sub-aggregations. +/// +/// # JSON Format +/// ```json +/// { +/// ... +/// "my_composite": { +/// "buckets": [ +/// { +/// "key": { +/// "date": 1494201600000, +/// "product": "rocky" +/// }, +/// "doc_count": 5 +/// }, +/// { +/// "key": { +/// "date": 1494201600000, +/// "product": "balboa" +/// }, +/// "doc_count": 2 +/// }, +/// { +/// "key": { +/// "date": 1494201700000, +/// "product": "john" +/// }, +/// "doc_count": 3 +/// } +/// ] +/// } +/// ... +/// } +/// ``` +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] +pub struct CompositeBucketEntry { + /// The identifier of the bucket. + pub key: FxHashMap, + /// Number of documents in the bucket. + pub doc_count: u64, + #[serde(flatten)] + /// Sub-aggregations in this bucket. 
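`CompositeKey` above compares and hashes `f64` keys by bit pattern so that its `Eq` and `Hash` implementations stay consistent when keys are used in hash maps; a small standalone sketch of the consequence, assuming `CompositeKey` is in scope for the caller:

```rust
use std::collections::HashMap;

fn main() {
    let mut counts: HashMap<CompositeKey, u64> = HashMap::new();
    // 0.0 and -0.0 are `==` as floats but have different bit patterns, so
    // under the bitwise `PartialEq`/`Hash` above they form two distinct keys.
    *counts.entry(CompositeKey::F64(0.0)).or_default() += 1;
    *counts.entry(CompositeKey::F64(-0.0)).or_default() += 1;
    assert_eq!(counts.len(), 2);
}
```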
+ pub sub_aggregation: AggregationResults, +} + +impl CompositeBucketEntry { + pub(crate) fn get_bucket_count(&self) -> u64 { + 1 + self.sub_aggregation.get_bucket_count() + } +} diff --git a/src/aggregation/bucket/composite/accessors.rs b/src/aggregation/bucket/composite/accessors.rs new file mode 100644 index 0000000000..9c27b18f69 --- /dev/null +++ b/src/aggregation/bucket/composite/accessors.rs @@ -0,0 +1,657 @@ +use std::fmt::Debug; +use std::net::IpAddr; +use std::str::FromStr; + +use columnar::column_values::{CompactHit, CompactSpaceU64Accessor}; +use columnar::{Column, ColumnType, MonotonicallyMappableToU64, StrColumn, TermOrdHit}; + +use crate::aggregation::accessor_helpers::{get_all_ff_readers, get_numeric_or_date_column_types}; +use crate::aggregation::agg_result::CompositeKey; +use crate::aggregation::bucket::composite::numeric_types::num_proj; +use crate::aggregation::bucket::composite::numeric_types::num_proj::ProjectedNumber; +use crate::aggregation::bucket::{ + parse_into_milliseconds, CalendarInterval, CompositeAggregation, CompositeAggregationSource, + MissingOrder, Order, +}; +use crate::aggregation::date::parse_date; +use crate::aggregation::segment_agg_result::SegmentAggregationCollector; +use crate::schema::IntoIpv6Addr; +use crate::{SegmentReader, TantivyError}; + +/// Contains all information required by the SegmentCompositeCollector to perform the +/// composite aggregation on a segment. +pub struct CompositeAggReqData { + /// Note: sub_aggregation_blueprint is filled later when building collectors + pub sub_aggregation_blueprint: Option>, + /// The name of the aggregation. + pub name: String, + /// The normalized term aggregation request. + pub req: CompositeAggregation, + /// Accessors for each source, each source can have multiple accessors (columns). + pub composite_accessors: Vec, +} + +impl CompositeAggReqData { + /// Estimate the memory consumption of this struct in bytes. + pub fn get_memory_consumption(&self) -> usize { + std::mem::size_of::() + + self.composite_accessors.len() * std::mem::size_of::() + } +} + +/// Accessors for a single column in a composite source. +pub struct CompositeAccessor { + /// The fast field column + pub column: Column, + /// The column type + pub column_type: ColumnType, + /// Term dictionary if the column type is Str + /// + /// Only used by term sources + pub str_dict_column: Option, + /// Parsed date interval for date histogram sources + pub date_histogram_interval: PrecomputedDateInterval, +} + +/// Accessors to all the columns that belong to the field of a composite source. +pub struct CompositeSourceAccessors { + /// The accessors for this source + pub accessors: Vec, + /// The key after which to start collecting results. Applies to the first + /// column of the source. + pub after_key: PrecomputedAfterKey, + + /// The column index the after_key applies to. The after_key only applies to + /// one column. Columns before should be skipped. Columns after should be + /// kept without comparison to the after_key. + pub after_key_accessor_idx: usize, + + /// Whether to skip missing values because of the after_key. Skipping only + /// applies if the value for previous columns were exactly equal to the + /// corresponding after keys (is_on_after_key). + pub skip_missing: bool, + + /// The after key was set to null to indicate that the last collected key + /// was a missing value. + pub is_after_key_explicit_missing: bool, +} + +impl CompositeSourceAccessors { + /// Creates a new set of accessors for the composite source. 
+ /// + /// Precomputes some values to make collection faster. + pub fn build_for_source( + reader: &SegmentReader, + source: &CompositeAggregationSource, + // First option is None when no after key was set in the query, the + // second option is None when the after key was set but its value for + // this source was set to `null` + source_after_key_opt: Option<&CompositeKey>, + ) -> crate::Result { + let is_after_key_explicit_missing = source_after_key_opt + .map(|after_key| matches!(after_key, CompositeKey::Null)) + .unwrap_or(false); + let mut skip_missing = false; + if let Some(CompositeKey::Null) = source_after_key_opt { + if !source.missing_bucket() { + return Err(TantivyError::InvalidArgument( + "the 'after' key for a source cannot be null when 'missing_bucket' is false" + .to_string(), + )); + } + } else if source_after_key_opt.is_some() { + // if missing buckets come first and we have a non null after key, we skip missing + if MissingOrder::First == source.missing_order() { + skip_missing = true; + } + if MissingOrder::Default == source.missing_order() && Order::Asc == source.order() { + skip_missing = true; + } + }; + + match source { + CompositeAggregationSource::Terms(source) => { + let allowed_column_types = [ + ColumnType::I64, + ColumnType::U64, + ColumnType::F64, + ColumnType::Str, + ColumnType::DateTime, + ColumnType::Bool, + ColumnType::IpAddr, + // ColumnType::Bytes Unsupported + ]; + let mut columns_and_types = + get_all_ff_readers(reader, &source.field, Some(&allowed_column_types))?; + + columns_and_types + .sort_by_key(|(_, col_type)| col_type_order_key(col_type, source.order)); + let mut after_key_accessor_idx = 0; + if let Some(source_after_key_explicit_opt) = source_after_key_opt { + after_key_accessor_idx = skip_for_key( + &columns_and_types, + &source_after_key_explicit_opt, + source.missing_order, + source.order, + )?; + } + + let source_collectors: Vec = columns_and_types + .into_iter() + .map(|(column, column_type)| { + Ok(CompositeAccessor { + column, + column_type, + str_dict_column: reader.fast_fields().str(&source.field)?, + date_histogram_interval: PrecomputedDateInterval::NotApplicable, + }) + }) + .collect::>()?; + + let after_key = if let Some(first_col) = + source_collectors.get(after_key_accessor_idx) + { + match source_after_key_opt { + Some(after_key) => PrecomputedAfterKey::precompute( + &first_col, + after_key, + &source.field, + source.missing_order, + source.order, + )?, + None => { + precompute_missing_after_key(false, source.missing_order, source.order) + } + } + } else { + // if no columns, we don't care about the after_key + PrecomputedAfterKey::Next(0) + }; + + Ok(CompositeSourceAccessors { + accessors: source_collectors, + is_after_key_explicit_missing, + skip_missing, + after_key, + after_key_accessor_idx, + }) + } + CompositeAggregationSource::Histogram(source) => { + let column_and_types: Vec<(Column, ColumnType)> = get_all_ff_readers( + reader, + &source.field, + Some(get_numeric_or_date_column_types()), + )?; + let source_collectors: Vec = column_and_types + .into_iter() + .map(|(column, column_type)| { + Ok(CompositeAccessor { + column, + column_type, + str_dict_column: None, + date_histogram_interval: PrecomputedDateInterval::NotApplicable, + }) + }) + .collect::>()?; + let after_key = match source_after_key_opt { + Some(CompositeKey::I64(key)) => { + let normalized_key = *key as f64 / source.interval; + num_proj::f64_to_i64(normalized_key).into() + } + Some(CompositeKey::U64(key)) => { + let normalized_key = *key as f64 / 
source.interval; + num_proj::f64_to_i64(normalized_key).into() + } + Some(CompositeKey::F64(key)) => { + let normalized_key = *key / source.interval; + num_proj::f64_to_i64(normalized_key).into() + } + Some(CompositeKey::Null) => { + precompute_missing_after_key(true, source.missing_order, source.order) + } + None => precompute_missing_after_key(true, source.missing_order, source.order), + _ => { + return Err(crate::TantivyError::InvalidArgument( + "After key type invalid for interval composite source".to_string(), + )); + } + }; + Ok(CompositeSourceAccessors { + accessors: source_collectors, + is_after_key_explicit_missing, + skip_missing, + after_key, + after_key_accessor_idx: 0, + }) + } + CompositeAggregationSource::DateHistogram(source) => { + let column_and_types = + get_all_ff_readers(reader, &source.field, Some(&[ColumnType::DateTime]))?; + let date_histogram_interval = + PrecomputedDateInterval::from_date_histogram_source_intervals( + &source.fixed_interval, + source.calendar_interval, + )?; + let source_collectors: Vec = column_and_types + .into_iter() + .map(|(column, column_type)| { + Ok(CompositeAccessor { + column, + column_type, + str_dict_column: None, + date_histogram_interval, + }) + }) + .collect::>()?; + let after_key = match source_after_key_opt { + Some(CompositeKey::I64(key)) => PrecomputedAfterKey::Exact(key.to_u64()), + Some(CompositeKey::Null) => { + precompute_missing_after_key(true, source.missing_order, source.order) + } + None => precompute_missing_after_key(true, source.missing_order, source.order), + _ => { + return Err(crate::TantivyError::InvalidArgument( + "After key type invalid for interval composite source".to_string(), + )); + } + }; + Ok(CompositeSourceAccessors { + accessors: source_collectors, + is_after_key_explicit_missing, + skip_missing, + after_key, + after_key_accessor_idx: 0, + }) + } + } + } +} + +/// Sort orders: +/// - Asc: Bool->Str->F64/I64/U64->DateTime/IpAddr +/// - Desc: U64/I64/F64->Str->Bool->DateTime/IpAddr +fn col_type_order_key(col_type: &ColumnType, composite_order: Order) -> i32 { + let apply_order = match composite_order { + Order::Asc => 1, + Order::Desc => -1, + }; + match col_type { + ColumnType::Bool => 1 * apply_order, + ColumnType::Str => 2 * apply_order, + // numeric types are coerced so it will be either U64, I64 or F64 + ColumnType::F64 => 3 * apply_order, + ColumnType::I64 => 3 * apply_order, + ColumnType::U64 => 3 * apply_order, + // DateTime/IpAddr cannot be automatically deduced from + // json, so if present we are guaranteed to have exactly + // one column + ColumnType::DateTime => 4, + ColumnType::IpAddr => 4, + ColumnType::Bytes => panic!("unsupported"), + } +} + +fn find_skip_idx( + columns_and_types: &Vec<(T, ColumnType)>, + order: Order, + skip_until_col_type_order_key: i32, +) -> crate::Result { + for (idx, (_, col_type)) in columns_and_types.iter().enumerate() { + let col_type_order = col_type_order_key(col_type, order); + if col_type_order >= skip_until_col_type_order_key { + return Ok(idx); + } + } + Ok(columns_and_types.len()) +} + +fn skip_for_key( + columns_and_types: &Vec<(T, ColumnType)>, + after_key: &CompositeKey, + missing_order: MissingOrder, + order: Order, +) -> crate::Result { + match (after_key, order) { + // Asc: Bool->Str->F64/I64/U64->DateTime/IpAddr + (CompositeKey::Bool(_), Order::Asc) => find_skip_idx(columns_and_types, order, 1), + (CompositeKey::Str(_), Order::Asc) => find_skip_idx(columns_and_types, order, 2), + (CompositeKey::F64(_) | CompositeKey::I64(_) | CompositeKey::U64(_), 
Order::Asc) => { + find_skip_idx(columns_and_types, order, 3) + } + // Desc: U64/I64/F64->Str->Bool->DateTime/IpAddr + (CompositeKey::F64(_) | CompositeKey::I64(_) | CompositeKey::U64(_), Order::Desc) => { + find_skip_idx(columns_and_types, order, -3) + } + (CompositeKey::Str(_), Order::Desc) => find_skip_idx(columns_and_types, order, -2), + (CompositeKey::Bool(_), Order::Desc) => find_skip_idx(columns_and_types, order, -1), + (CompositeKey::Null, _) => { + match (missing_order, order) { + (MissingOrder::First, _) | (MissingOrder::Default, Order::Asc) => { + Ok(0) // don't skip any columns + } + (MissingOrder::Last, _) | (MissingOrder::Default, Order::Desc) => { + // all columns are skipped + Ok(columns_and_types.len()) + } + } + } + } +} + +fn precompute_missing_after_key( + is_after_key_explicit_missing: bool, + missing_order: MissingOrder, + order: Order, +) -> PrecomputedAfterKey { + let after_last = PrecomputedAfterKey::AfterLast; + let before_first = PrecomputedAfterKey::Next(0); + match (is_after_key_explicit_missing, missing_order, order) { + (true, MissingOrder::First, Order::Asc) => before_first, + (true, MissingOrder::First, Order::Desc) => after_last, + (true, MissingOrder::Last, Order::Asc) => after_last, + (true, MissingOrder::Last, Order::Desc) => before_first, + (true, MissingOrder::Default, Order::Asc) => before_first, + (true, MissingOrder::Default, Order::Desc) => after_last, + (false, _, Order::Asc) => before_first, + (false, _, Order::Desc) => after_last, + } +} + +/// A parsed representation of the date interval for date histogram sources +#[derive(Clone, Copy, Debug)] +pub enum PrecomputedDateInterval { + /// This is not a date histogram source + NotApplicable, + /// Source was configured with a fixed interval + FixedMilliseconds(i64), + /// Source was configured with a calendar interval + Calendar(CalendarInterval), +} + +impl PrecomputedDateInterval { + /// Validates the date histogram source interval fields and parses a date interval from them. + pub fn from_date_histogram_source_intervals( + fixed_interval: &Option, + calendar_interval: Option, + ) -> crate::Result { + match (fixed_interval, calendar_interval) { + (Some(_), Some(_)) | (None, None) => Err(TantivyError::InvalidArgument( + "date histogram source must one and only one of fixed_interval or \ + calendar_interval set" + .to_string(), + )), + (Some(fixed_interval), None) => { + let fixed_interval_ms = parse_into_milliseconds(&fixed_interval)?; + Ok(PrecomputedDateInterval::FixedMilliseconds( + fixed_interval_ms, + )) + } + (None, Some(calendar_interval)) => { + Ok(PrecomputedDateInterval::Calendar(calendar_interval)) + } + } + } +} + +/// The after key projected to the u64 column space +/// +/// Some column types (term, IP) might not have an exact representation of the +/// specified after key +#[derive(Debug)] +pub enum PrecomputedAfterKey { + /// The after key could be exactly represented in the column space. + Exact(u64), + /// The after key could not be exactly represented exactly represented, so + /// this is the next closest one. 
+ Next(u64), + /// The after key could not be represented in the column space, it is + /// greater than all value + AfterLast, +} + +impl From for PrecomputedAfterKey { + fn from(hit: TermOrdHit) -> Self { + match hit { + TermOrdHit::Exact(ord) => PrecomputedAfterKey::Exact(ord), + // TermOrdHit represents AfterLast as Next(u64::MAX), we keep it as is + TermOrdHit::Next(ord) => PrecomputedAfterKey::Next(ord), + } + } +} + +impl From for PrecomputedAfterKey { + fn from(hit: CompactHit) -> Self { + match hit { + CompactHit::Exact(ord) => PrecomputedAfterKey::Exact(ord as u64), + CompactHit::Next(ord) => PrecomputedAfterKey::Next(ord as u64), + CompactHit::AfterLast => PrecomputedAfterKey::AfterLast, + } + } +} + +impl From> for PrecomputedAfterKey { + fn from(num: ProjectedNumber) -> Self { + match num { + ProjectedNumber::Exact(number) => PrecomputedAfterKey::Exact(number.to_u64()), + ProjectedNumber::Next(number) => PrecomputedAfterKey::Next(number.to_u64()), + ProjectedNumber::AfterLast => PrecomputedAfterKey::AfterLast, + } + } +} + +// /!\ These operators only makes sense if both values are in the same column space +impl PrecomputedAfterKey { + pub fn equals(&self, column_value: u64) -> bool { + match self { + PrecomputedAfterKey::Exact(v) => *v == column_value, + PrecomputedAfterKey::Next(_) => false, + PrecomputedAfterKey::AfterLast => false, + } + } + + pub fn gt(&self, column_value: u64) -> bool { + match self { + PrecomputedAfterKey::Exact(v) => *v > column_value, + PrecomputedAfterKey::Next(v) => *v > column_value, + PrecomputedAfterKey::AfterLast => true, + } + } + + pub fn lt(&self, column_value: u64) -> bool { + match self { + PrecomputedAfterKey::Exact(v) => *v < column_value, + // a value equal to the next is greater than the after key + PrecomputedAfterKey::Next(v) => *v <= column_value, + PrecomputedAfterKey::AfterLast => false, + } + } + + fn precompute_i64(key: &CompositeKey, missing_order: MissingOrder, order: Order) -> Self { + // avoid rough casting + match key { + CompositeKey::I64(k) => PrecomputedAfterKey::Exact(k.to_u64()), + CompositeKey::U64(k) => num_proj::u64_to_i64(*k).into(), + CompositeKey::F64(k) => num_proj::f64_to_i64(*k).into(), + CompositeKey::Bool(_) => Self::keep_all(order), + CompositeKey::Str(_) => Self::keep_all(order), + CompositeKey::Null => precompute_missing_after_key(false, missing_order, order), + } + } + + fn precompute_u64(key: &CompositeKey, missing_order: MissingOrder, order: Order) -> Self { + match key { + CompositeKey::I64(k) => num_proj::i64_to_u64(*k).into(), + CompositeKey::U64(k) => PrecomputedAfterKey::Exact(*k), + CompositeKey::F64(k) => num_proj::f64_to_u64(*k).into(), + CompositeKey::Bool(_) => Self::keep_all(order), + CompositeKey::Str(_) => Self::keep_all(order), + CompositeKey::Null => precompute_missing_after_key(false, missing_order, order), + } + } + + fn precompute_f64(key: &CompositeKey, missing_order: MissingOrder, order: Order) -> Self { + match key { + CompositeKey::F64(k) => PrecomputedAfterKey::Exact(k.to_u64()), + CompositeKey::I64(k) => num_proj::i64_to_f64(*k).into(), + CompositeKey::U64(k) => num_proj::u64_to_f64(*k).into(), + CompositeKey::Bool(_) => Self::keep_all(order), + CompositeKey::Str(_) => Self::keep_all(order), + CompositeKey::Null => precompute_missing_after_key(false, missing_order, order), + } + } + + fn precompute_ip_addr(column: &Column, key: &str, field: &str) -> crate::Result { + let compact_space_accessor = column + .values + .clone() + .downcast_arc::() + .map_err(|_| { + 
TantivyError::AggregationError(crate::aggregation::AggregationError::InternalError( + "type mismatch: could not downcast to CompactSpaceU64Accessor".to_string(), + )) + })?; + let ip_u128 = IpAddr::from_str(key) + .map_err(|_| { + TantivyError::InvalidArgument(format!( + "failed to parse after_key '{}' as IpAddr for field '{}'", + key, field + )) + })? + .into_ipv6_addr() + .to_bits(); + let ip_next_compact = compact_space_accessor.u128_to_next_compact(ip_u128); + Ok(ip_next_compact.into()) + } + + fn precompute_term_ord( + str_dict_column: &Option, + key: &str, + field: &str, + ) -> crate::Result { + let dict = str_dict_column + .as_ref() + .expect("dictionary missing for str accessor") + .dictionary(); + let next_ord = dict.term_ord_or_next(key).map_err(|_| { + TantivyError::InvalidArgument(format!( + "failed to lookup after_key '{}' for field '{}'", + key, field + )) + })?; + Ok(next_ord.into()) + } + + /// Assumes that the relevant columns were already skipped + pub fn precompute( + composite_accessor: &CompositeAccessor, + source_after_key: &CompositeKey, + field: &str, + missing_order: MissingOrder, + order: Order, + ) -> crate::Result { + let precomputed_key = match (composite_accessor.column_type, source_after_key) { + (_, CompositeKey::F64(f)) if f.is_nan() => { + return Err(crate::TantivyError::InvalidArgument(format!( + "unexptected NaN in after key {:?}", + source_after_key + ))); + } + (ColumnType::I64, key) => { + PrecomputedAfterKey::precompute_i64(key, missing_order, order) + } + (ColumnType::U64, key) => { + PrecomputedAfterKey::precompute_u64(key, missing_order, order) + } + (ColumnType::F64, key) => { + PrecomputedAfterKey::precompute_f64(key, missing_order, order) + } + (ColumnType::Bool, CompositeKey::Bool(key)) => PrecomputedAfterKey::Exact(key.to_u64()), + (ColumnType::Bool, CompositeKey::Null) => { + precompute_missing_after_key(false, missing_order, order) + } + (ColumnType::Bool, _) => PrecomputedAfterKey::keep_all(order), + (ColumnType::Str, CompositeKey::Str(key)) => PrecomputedAfterKey::precompute_term_ord( + &composite_accessor.str_dict_column, + key, + field, + )?, + (ColumnType::Str, CompositeKey::Null) => { + precompute_missing_after_key(false, missing_order, order) + } + (ColumnType::Str, _) => PrecomputedAfterKey::keep_all(order), + (ColumnType::DateTime, CompositeKey::Str(key)) => { + PrecomputedAfterKey::Exact(parse_date(key)?.to_u64()) + } + (ColumnType::IpAddr, CompositeKey::Str(key)) => { + PrecomputedAfterKey::precompute_ip_addr(&composite_accessor.column, key, field)? 
+ } + (ColumnType::Bytes, _) => panic!("unsupported"), + (ColumnType::DateTime | ColumnType::IpAddr, CompositeKey::Null) => { + precompute_missing_after_key(false, missing_order, order) + } + (ColumnType::DateTime | ColumnType::IpAddr, _) => { + // we don't support fields for which the schema changes + return Err(crate::TantivyError::InvalidArgument(format!( + "after key {:?} does not match column type {:?} for field '{}'", + source_after_key, composite_accessor.column_type, field + ))); + } + }; + Ok(precomputed_key) + } + + fn keep_all(order: Order) -> Self { + match order { + Order::Asc => PrecomputedAfterKey::Next(0), + Order::Desc => PrecomputedAfterKey::Next(u64::MAX), + } + } +} + +#[cfg(test)] +mod tests { + use std::net::Ipv6Addr; + + use super::super::type_order_key; + use super::*; + use crate::aggregation::intermediate_agg_result::CompositeIntermediateKey; + + #[test] + fn test_sort_order_keys_aligned() { + // it is important that the order keys used to order column types and + // intermediate key types are the same + for order in [Order::Asc, Order::Desc] { + assert_eq!( + col_type_order_key(&ColumnType::Bool, order), + type_order_key(&CompositeIntermediateKey::Bool(true), order) + ); + assert_eq!( + col_type_order_key(&ColumnType::Str, order), + type_order_key(&CompositeIntermediateKey::Str("".to_string()), order) + ); + assert_eq!( + col_type_order_key(&ColumnType::I64, order), + type_order_key(&CompositeIntermediateKey::I64(0), order) + ); + assert_eq!( + col_type_order_key(&ColumnType::U64, order), + type_order_key(&CompositeIntermediateKey::U64(0), order) + ); + assert_eq!( + col_type_order_key(&ColumnType::F64, order), + type_order_key(&CompositeIntermediateKey::F64(0.0), order) + ); + assert_eq!( + col_type_order_key(&ColumnType::DateTime, order), + type_order_key(&CompositeIntermediateKey::DateTime(0), order) + ); + assert_eq!( + col_type_order_key(&ColumnType::IpAddr, order), + type_order_key( + &CompositeIntermediateKey::IpAddr(Ipv6Addr::LOCALHOST), + order + ) + ); + } + } +} diff --git a/src/aggregation/bucket/composite/calendar_interval.rs b/src/aggregation/bucket/composite/calendar_interval.rs new file mode 100644 index 0000000000..ce883b7ff2 --- /dev/null +++ b/src/aggregation/bucket/composite/calendar_interval.rs @@ -0,0 +1,140 @@ +use time::convert::{Day, Nanosecond}; +use time::{Time, UtcDateTime}; + +const NS_IN_DAY: i64 = Nanosecond::per_t::(Day) as i64; + +/// Computes the timestamp in nanoseconds corresponding to the beginning of the +/// year (January 1st at midnight UTC). +pub(super) fn try_year_bucket(timestamp_ns: i64) -> crate::Result { + year_bucket_using_time_crate(timestamp_ns).map_err(|e| { + crate::TantivyError::InvalidArgument(format!( + "Failed to compute year bucket for timestamp {}: {}", + timestamp_ns, + e.to_string() + )) + }) +} + +/// Computes the timestamp in nanoseconds corresponding to the beginning of the +/// month (1st at midnight UTC). +pub(super) fn try_month_bucket(timestamp_ns: i64) -> crate::Result { + month_bucket_using_time_crate(timestamp_ns).map_err(|e| { + crate::TantivyError::InvalidArgument(format!( + "Failed to compute month bucket for timestamp {}: {}", + timestamp_ns, + e.to_string() + )) + }) +} + +/// Computes the timestamp in nanoseconds corresponding to the beginning of the +/// week (Monday at midnight UTC). 
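The per-source after-key precomputation above is driven by the `after` map on the request; a hedged sketch of what a follow-up page request might look like. The `"product"` source name and value are illustrative, and the JSON shape assumes the Elasticsearch-style `after` object that `CompositeAggregation` deserializes into a map keyed by source name (matching `req.after.get(source.name())`).

```rust
use tantivy::aggregation::agg_req::Aggregations;

fn main() {
    // Echo the `after_key` returned with the previous page so that
    // `build_for_source` can skip everything up to and including it.
    let agg_req: Aggregations = serde_json::from_value(serde_json::json!({
        "my_composite": {
            "composite": {
                "sources": [
                    { "product": { "terms": { "field": "product" } } }
                ],
                "size": 10,
                "after": { "product": "rocky" }
            }
        }
    }))
    .unwrap();
    let _ = agg_req;
}
```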
+pub(super) fn week_bucket(timestamp_ns: i64) -> i64 { + // 1970-01-01 was a Thursday (weekday = 4) + let days_since_epoch = timestamp_ns.div_euclid(NS_IN_DAY); + // Find the weekday: 0=Monday, ..., 6=Sunday + let weekday = (days_since_epoch + 3).rem_euclid(7); + let monday_days_since_epoch = days_since_epoch - weekday; + monday_days_since_epoch * NS_IN_DAY +} + +fn year_bucket_using_time_crate(timestamp_ns: i64) -> Result { + let timestamp_ns = UtcDateTime::from_unix_timestamp_nanos(timestamp_ns as i128)? + .replace_ordinal(1)? + .replace_time(Time::MIDNIGHT) + .unix_timestamp_nanos(); + Ok(timestamp_ns as i64) +} + +fn month_bucket_using_time_crate(timestamp_ns: i64) -> Result { + let timestamp_ns = UtcDateTime::from_unix_timestamp_nanos(timestamp_ns as i128)? + .replace_day(1)? + .replace_time(Time::MIDNIGHT) + .unix_timestamp_nanos(); + Ok(timestamp_ns as i64) +} + +#[cfg(test)] +mod tests { + use std::i64; + + use time::format_description::well_known::Iso8601; + use time::UtcDateTime; + + use super::*; + + fn ts_ns(iso: &str) -> i64 { + UtcDateTime::parse(iso, &Iso8601::DEFAULT) + .unwrap() + .unix_timestamp_nanos() as i64 + } + + #[test] + fn test_year_bucket() { + let ts = ts_ns("1970-01-01T00:00:00Z"); + let res = try_year_bucket(ts).unwrap(); + assert_eq!(res, ts_ns("1970-01-01T00:00:00Z")); + + let ts = ts_ns("1970-06-01T10:00:01.010Z"); + let res = try_year_bucket(ts).unwrap(); + assert_eq!(res, ts_ns("1970-01-01T00:00:00Z")); + + let ts = ts_ns("2008-12-31T23:59:59.999999999Z"); // leap year + let res = try_year_bucket(ts).unwrap(); + assert_eq!(res, ts_ns("2008-01-01T00:00:00Z")); + + let ts = ts_ns("2008-01-01T00:00:00Z"); // leap year + let res = try_year_bucket(ts).unwrap(); + assert_eq!(res, ts_ns("2008-01-01T00:00:00Z")); + + let ts = ts_ns("2010-12-31T23:59:59.999999999Z"); + let res = try_year_bucket(ts).unwrap(); + assert_eq!(res, ts_ns("2010-01-01T00:00:00Z")); + + let ts = ts_ns("1972-06-01T00:10:00Z"); + let res = try_year_bucket(ts).unwrap(); + assert_eq!(res, ts_ns("1972-01-01T00:00:00Z")); + } + + #[test] + fn test_month_bucket() { + let ts = ts_ns("1970-01-15T00:00:00Z"); + let res = try_month_bucket(ts).unwrap(); + assert_eq!(res, ts_ns("1970-01-01T00:00:00Z")); + + let ts = ts_ns("1970-02-01T00:00:00Z"); + let res = try_month_bucket(ts).unwrap(); + assert_eq!(res, ts_ns("1970-02-01T00:00:00Z")); + + let ts = ts_ns("2000-01-31T23:59:59.999999999Z"); + let res = try_month_bucket(ts).unwrap(); + assert_eq!(res, ts_ns("2000-01-01T00:00:00Z")); + } + + #[test] + fn test_week_bucket() { + let ts = ts_ns("1970-01-05T00:00:00Z"); // Monday + let res = week_bucket(ts); + assert_eq!(res, ts_ns("1970-01-05T00:00:00Z")); + + let ts = ts_ns("1970-01-05T23:59:59Z"); // Monday + let res = week_bucket(ts); + assert_eq!(res, ts_ns("1970-01-05T00:00:00Z")); + + let ts = ts_ns("1970-01-07T01:13:00Z"); // Wednesday + let res = week_bucket(ts); + assert_eq!(res, ts_ns("1970-01-05T00:00:00Z")); + + let ts = ts_ns("1970-01-11T23:59:59.999999999Z"); // Sunday + let res = week_bucket(ts); + assert_eq!(res, ts_ns("1970-01-05T00:00:00Z")); + + let ts = ts_ns("2025-10-16T10:41:59.010Z"); // Thursday + let res = week_bucket(ts); + assert_eq!(res, ts_ns("2025-10-13T00:00:00Z")); + + let ts = ts_ns("1970-01-01T00:00:00Z"); // Thursday + let res = week_bucket(ts); + assert_eq!(res, ts_ns("1969-12-29T00:00:00Z")); // Negative + } +} diff --git a/src/aggregation/bucket/composite/collector.rs b/src/aggregation/bucket/composite/collector.rs new file mode 100644 index 0000000000..a849c087e2 --- 
/dev/null +++ b/src/aggregation/bucket/composite/collector.rs @@ -0,0 +1,623 @@ +use std::fmt::Debug; +use std::net::Ipv6Addr; + +use columnar::column_values::CompactSpaceU64Accessor; +use columnar::{ + Column, ColumnType, Dictionary, MonotonicallyMappableToU128, MonotonicallyMappableToU64, + NumericalValue, StrColumn, +}; +use rustc_hash::FxHashMap; +use smallvec::SmallVec; + +use crate::aggregation::agg_data::{ + build_segment_agg_collectors, AggRefNode, AggregationsSegmentCtx, +}; +use crate::aggregation::bucket::composite::accessors::{ + CompositeAccessor, CompositeAggReqData, PrecomputedDateInterval, +}; +use crate::aggregation::bucket::composite::calendar_interval; +use crate::aggregation::bucket::composite::map::{DynArrayHeapMap, MAX_DYN_ARRAY_SIZE}; +use crate::aggregation::bucket::{ + CalendarInterval, CompositeAggregationSource, MissingOrder, Order, +}; +use crate::aggregation::format_date; +use crate::aggregation::intermediate_agg_result::{ + CompositeIntermediateKey, IntermediateAggregationResult, IntermediateAggregationResults, + IntermediateBucketResult, IntermediateCompositeBucketEntry, IntermediateCompositeBucketResult, +}; +use crate::aggregation::segment_agg_result::SegmentAggregationCollector; +use crate::TantivyError; + +#[derive(Clone, Debug)] +struct CompositeBucketCollector { + count: u32, + sub_aggs: Option>, +} + +impl CompositeBucketCollector { + fn new(sub_aggs: Option>) -> Self { + CompositeBucketCollector { count: 0, sub_aggs } + } + #[inline] + fn collect( + &mut self, + doc: crate::DocId, + agg_data: &mut AggregationsSegmentCtx, + ) -> crate::Result<()> { + self.count += 1; + if let Some(sub_aggs) = &mut self.sub_aggs { + sub_aggs.collect(doc, agg_data)?; + } + Ok(()) + } +} + +/// The value is represented as a tuple of: +/// - the column index or missing value sentinel +/// - if the value is present, store the accessor index + 1 +/// - if the value is missing, store 0 (for missing first) or u8::MAX (for missing last) +/// - the fast field value u64 representation +/// - 0 if the field is missing +/// - regular u64 repr if the ordering is ascending +/// - bitwise NOT of the u64 repr if the ordering is descending +#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Default, Hash)] +struct InternalValueRepr(u8, u64); + +impl InternalValueRepr { + #[inline] + fn new(raw: u64, accessor_idx: u8, order: Order) -> Self { + match order { + Order::Asc => InternalValueRepr(accessor_idx + 1, raw), + Order::Desc => InternalValueRepr(accessor_idx + 1, !raw), + } + } + #[inline] + fn new_missing(order: Order, missing_order: MissingOrder) -> Self { + let column_idx = match (missing_order, order) { + (MissingOrder::First, _) => 0, + (MissingOrder::Last, _) => u8::MAX, + (MissingOrder::Default, Order::Asc) => 0, + (MissingOrder::Default, Order::Desc) => u8::MAX, + }; + InternalValueRepr(column_idx, 0) + } + #[inline] + fn decode(self, order: Order) -> Option<(u8, u64)> { + if self.0 == u8::MAX || self.0 == 0 { + return None; + } + match order { + Order::Asc => Some((self.0 - 1, self.1)), + Order::Desc => Some((self.0 - 1, !self.1)), + } + } +} + +/// The collector puts values from the fast field into the correct buckets and +/// does a conversion to the correct datatype. 
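`InternalValueRepr` above packs each source value so that plain tuple ordering of the encoded form yields the requested composite order; below is a standalone sketch (a mirror of the encoding, not the crate's type) of why the bitwise negation and the `0` / `u8::MAX` sentinels work.

```rust
// Byte 0 is the accessor index shifted by one, so that 0 and u8::MAX remain
// free as "missing first" / "missing last" sentinels; for descending sources
// the u64 payload is bitwise-negated, so ascending comparison of the encoded
// tuple produces descending order of the original values.
#[derive(PartialEq, Eq, PartialOrd, Ord, Debug)]
struct Repr(u8, u64);

fn encode_desc(raw: u64, accessor_idx: u8) -> Repr {
    Repr(accessor_idx + 1, !raw)
}

fn main() {
    let small = encode_desc(10, 0);
    let large = encode_desc(42, 0);
    // Under descending order, 42 must sort before 10.
    assert!(large < small);
    // A "missing last" sentinel sorts after every present value.
    assert!(Repr(u8::MAX, 0) > small);
    // A "missing first" sentinel sorts before every present value.
    assert!(Repr(0, 0) < large);
}
```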
+#[derive(Clone, Debug)] +pub struct SegmentCompositeCollector { + buckets: DynArrayHeapMap, + accessor_idx: usize, +} + +impl SegmentAggregationCollector for SegmentCompositeCollector { + fn add_intermediate_aggregation_result( + self: Box, + agg_data: &AggregationsSegmentCtx, + results: &mut IntermediateAggregationResults, + ) -> crate::Result<()> { + let name = agg_data + .get_composite_req_data(self.accessor_idx) + .name + .clone(); + + let buckets = self.into_intermediate_bucket_result(agg_data)?; + results.push( + name, + IntermediateAggregationResult::Bucket(IntermediateBucketResult::Composite { buckets }), + )?; + + Ok(()) + } + + #[inline] + fn collect( + &mut self, + doc: crate::DocId, + agg_data: &mut AggregationsSegmentCtx, + ) -> crate::Result<()> { + self.collect_block(&[doc], agg_data) + } + + #[inline] + fn collect_block( + &mut self, + docs: &[crate::DocId], + agg_data: &mut AggregationsSegmentCtx, + ) -> crate::Result<()> { + let mem_pre = self.get_memory_consumption(); + let composite_agg_data = agg_data.take_composite_req_data(self.accessor_idx); + + for doc in docs { + let mut sub_level_values = SmallVec::new(); + recursive_key_visitor( + *doc, + agg_data, + &composite_agg_data, + 0, + &mut sub_level_values, + &mut self.buckets, + true, + )?; + } + agg_data.put_back_composite_req_data(self.accessor_idx, composite_agg_data); + + let mem_delta = self.get_memory_consumption() - mem_pre; + if mem_delta > 0 { + agg_data.limits.add_memory_consumed(mem_delta)?; + } + + Ok(()) + } + + fn flush(&mut self, agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> { + for sub_agg_collector in self.buckets.values_mut() { + if let Some(sub_aggs_collector) = &mut sub_agg_collector.sub_aggs { + sub_aggs_collector.flush(agg_data)?; + } + } + Ok(()) + } +} + +impl SegmentCompositeCollector { + fn get_memory_consumption(&self) -> u64 { + // TODO: the footprint is underestimated because we don't account for the + // sub-aggregations which are trait objects + self.buckets.memory_consumption() + } + + pub(crate) fn from_req_and_validate( + req_data: &mut AggregationsSegmentCtx, + node: &AggRefNode, + ) -> crate::Result { + validate_req(req_data, node.idx_in_req_data)?; + + let has_sub_aggregations = !node.children.is_empty(); + let blueprint = if has_sub_aggregations { + let sub_aggregation = build_segment_agg_collectors(req_data, &node.children)?; + Some(sub_aggregation) + } else { + None + }; + let composite_req_data = req_data.get_composite_req_data_mut(node.idx_in_req_data); + composite_req_data.sub_aggregation_blueprint = blueprint; + + Ok(SegmentCompositeCollector { + buckets: DynArrayHeapMap::try_new(composite_req_data.req.sources.len())?, + accessor_idx: node.idx_in_req_data, + }) + } + + #[inline] + pub(crate) fn into_intermediate_bucket_result( + self, + agg_data: &AggregationsSegmentCtx, + ) -> crate::Result { + let mut dict: FxHashMap, IntermediateCompositeBucketEntry> = + Default::default(); + dict.reserve(self.buckets.size()); + let composite_data = agg_data.get_composite_req_data(self.accessor_idx); + for (key_internal_repr, agg) in self.buckets.into_iter() { + let key = resolve_key(&key_internal_repr, composite_data)?; + let mut sub_aggregation_res = IntermediateAggregationResults::default(); + if let Some(sub_aggs_collector) = agg.sub_aggs { + sub_aggs_collector + .add_intermediate_aggregation_result(agg_data, &mut sub_aggregation_res)?; + } + + dict.insert( + key, + IntermediateCompositeBucketEntry { + doc_count: agg.count, + sub_aggregation: sub_aggregation_res, + }, + 
); + } + + Ok(IntermediateCompositeBucketResult { + entries: dict, + target_size: composite_data.req.size, + orders: composite_data + .req + .sources + .iter() + .map(|source| match source { + CompositeAggregationSource::Terms(t) => (t.order, t.missing_order), + CompositeAggregationSource::Histogram(h) => (h.order, h.missing_order), + CompositeAggregationSource::DateHistogram(d) => (d.order, d.missing_order), + }) + .collect(), + }) + } +} + +fn validate_req(req_data: &mut AggregationsSegmentCtx, accessor_idx: usize) -> crate::Result<()> { + let composite_data = req_data.get_composite_req_data(accessor_idx); + let req = &composite_data.req; + if req.sources.is_empty() { + return Err(TantivyError::InvalidArgument( + "composite aggregation must have at least one source".to_string(), + )); + } + if req.size == 0 { + return Err(TantivyError::InvalidArgument( + "composite aggregation 'size' must be > 0".to_string(), + )); + } + let column_types_for_sources = composite_data.composite_accessors.iter().map(|item| { + item.accessors + .iter() + .map(|a| a.column_type) + .collect::>() + }); + + for (source, column_types) in req.sources.iter().zip(column_types_for_sources) { + if column_types.len() > MAX_DYN_ARRAY_SIZE { + return Err(TantivyError::InvalidArgument(format!( + "composite aggregation source supports maximum {MAX_DYN_ARRAY_SIZE} sources", + ))); + } + if column_types.contains(&ColumnType::Bytes) { + return Err(TantivyError::InvalidArgument( + "composite aggregation does not support 'bytes' field type".to_string(), + )); + } + if column_types.contains(&ColumnType::DateTime) && column_types.len() > 1 { + return Err(TantivyError::InvalidArgument( + "composite aggregation expects 'date' fields to have a single column".to_string(), + )); + } + if column_types.contains(&ColumnType::IpAddr) && column_types.len() > 1 { + return Err(TantivyError::InvalidArgument( + "composite aggregation expects 'ip' fields to have a single column".to_string(), + )); + } + match source { + CompositeAggregationSource::Terms(_) => { + if column_types.len() > 3 { + return Err(TantivyError::InvalidArgument( + "expected at most 3 columns for composite aggregation 'terms' source \ + (text, numerical and boolean)" + .to_string(), + )); + } + } + CompositeAggregationSource::Histogram(_) => { + if column_types.len() > 1 { + return Err(TantivyError::InvalidArgument( + "expected at most 1 column for composite aggregation 'histogram' source \ + (numerical or date)" + .to_string(), + )); + } + } + CompositeAggregationSource::DateHistogram(_) => { + if column_types.len() > 1 { + return Err(TantivyError::InvalidArgument( + "expected at most 1 column (date) for composite aggregation \ + 'date_histogram' source" + .to_string(), + )); + } + } + } + } + Ok(()) +} + +fn collect_bucket_with_limit( + doc_id: crate::DocId, + agg_data: &mut AggregationsSegmentCtx, + composite_agg_data: &CompositeAggReqData, + buckets: &mut DynArrayHeapMap, + key: &[InternalValueRepr], +) -> crate::Result<()> { + // we still have room for buckets, just insert + if (buckets.size() as u32) < composite_agg_data.req.size { + buckets + .get_or_insert_with(key, || { + CompositeBucketCollector::new(composite_agg_data.sub_aggregation_blueprint.clone()) + }) + .collect(doc_id, agg_data)?; + return Ok(()); + } + + // map is full, but we can still update the bucket if it already exists + if let Some(entry) = buckets.get_mut(key) { + entry.collect(doc_id, agg_data)?; + return Ok(()); + } + + // check if the item qualfies to enter the top-k, and evict the highest if 
it does + if let Some(highest_key) = buckets.peek_highest() { + if key < highest_key { + buckets.evict_highest(); + buckets + .get_or_insert_with(key, || { + CompositeBucketCollector::new( + composite_agg_data.sub_aggregation_blueprint.clone(), + ) + }) + .collect(doc_id, agg_data)?; + } + } + + Ok(()) +} + +/// Converts the composite key from its internal column space representation +/// (segment specific) into its intermediate form. +fn resolve_key( + internal_key: &[InternalValueRepr], + agg_data: &CompositeAggReqData, +) -> crate::Result> { + internal_key + .into_iter() + .enumerate() + .map(|(idx, val)| { + resolve_internal_value_repr( + *val, + &agg_data.req.sources[idx], + &agg_data.composite_accessors[idx].accessors, + ) + }) + .collect() +} + +fn resolve_internal_value_repr( + internal_value_repr: InternalValueRepr, + source: &CompositeAggregationSource, + composite_accessors: &[CompositeAccessor], +) -> crate::Result { + let decoded_value_opt = match source { + CompositeAggregationSource::Terms(source) => internal_value_repr.decode(source.order), + CompositeAggregationSource::Histogram(source) => internal_value_repr.decode(source.order), + CompositeAggregationSource::DateHistogram(source) => { + internal_value_repr.decode(source.order) + } + }; + let Some((decoded_accessor_idx, val)) = decoded_value_opt else { + return Ok(CompositeIntermediateKey::Null); + }; + let CompositeAccessor { + column_type, + str_dict_column, + column, + .. + } = &composite_accessors[decoded_accessor_idx as usize]; + let key = match source { + CompositeAggregationSource::Terms(_) => { + resolve_term(val, column_type, str_dict_column, column)? + } + CompositeAggregationSource::Histogram(source) => { + // Results are collected as interval indices to avoid Fx Hash collisions. + // Multiply back by the interval to get the bucket value. 
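+            // e.g. with `interval: 2.5`, a stored index of 3 resolves to the bucket key 7.5 (3 * 2.5).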
+ CompositeIntermediateKey::F64(i64::from_u64(val) as f64 * source.interval) + } + CompositeAggregationSource::DateHistogram(_) => { + CompositeIntermediateKey::I64(i64::from_u64(val)) + } + }; + + Ok(key) +} + +fn resolve_term( + val: u64, + column_type: &ColumnType, + str_dict_column: &Option, + column: &Column, +) -> crate::Result { + let key = if *column_type == ColumnType::Str { + let fallback_dict = Dictionary::empty(); + let term_dict = str_dict_column + .as_ref() + .map(|el| el.dictionary()) + .unwrap_or_else(|| &fallback_dict); + + // TODO try use sorted_ords_to_term_cb to batch + let mut buffer = Vec::new(); + term_dict.ord_to_term(val, &mut buffer)?; + CompositeIntermediateKey::Str( + String::from_utf8(buffer.to_vec()).expect("could not convert to String"), + ) + } else if *column_type == ColumnType::DateTime { + let val = i64::from_u64(val); + let date = format_date(val)?; + CompositeIntermediateKey::Str(date) + } else if *column_type == ColumnType::Bool { + let val = bool::from_u64(val); + CompositeIntermediateKey::Bool(val) + } else if *column_type == ColumnType::IpAddr { + let compact_space_accessor = column + .values + .clone() + .downcast_arc::() + .map_err(|_| { + TantivyError::AggregationError(crate::aggregation::AggregationError::InternalError( + "Type mismatch: Could not downcast to CompactSpaceU64Accessor".to_string(), + )) + })?; + let val: u128 = compact_space_accessor.compact_to_u128(val as u32); + let val = Ipv6Addr::from_u128(val); + CompositeIntermediateKey::IpAddr(val) + } else { + if *column_type == ColumnType::U64 { + CompositeIntermediateKey::U64(val) + } else if *column_type == ColumnType::I64 { + CompositeIntermediateKey::I64(i64::from_u64(val)) + } else { + let val = f64::from_u64(val); + let val: NumericalValue = val.into(); + + match val.normalize() { + NumericalValue::U64(val) => CompositeIntermediateKey::U64(val), + NumericalValue::I64(val) => CompositeIntermediateKey::I64(val), + NumericalValue::F64(val) => CompositeIntermediateKey::F64(val), + } + } + }; + Ok(key) +} + +/// Depth-first walk of the accessors to build the composite key combinations +/// and update the buckets. 
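+/// Each recursion level handles one source: every value of that source's
+/// columns is pushed onto `sub_level_values` before visiting the next level,
+/// and once all sources have been visited the assembled key is inserted via
+/// `collect_bucket_with_limit`. While `is_on_after_key` is true, combinations
+/// at or before the previous page's `after` key are skipped.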
+fn recursive_key_visitor( + doc_id: crate::DocId, + agg_data: &mut AggregationsSegmentCtx, + composite_agg_data: &CompositeAggReqData, + source_idx_for_recursion: usize, + sub_level_values: &mut SmallVec<[InternalValueRepr; MAX_DYN_ARRAY_SIZE]>, + buckets: &mut DynArrayHeapMap, + // whether the we need to considere the after_key in the following levels + is_on_after_key: bool, +) -> crate::Result<()> { + if source_idx_for_recursion == composite_agg_data.req.sources.len() { + if !is_on_after_key { + collect_bucket_with_limit( + doc_id, + agg_data, + composite_agg_data, + buckets, + sub_level_values, + )?; + } + return Ok(()); + } + + let current_level_accessor = &composite_agg_data.composite_accessors[source_idx_for_recursion]; + let current_level_source = &composite_agg_data.req.sources[source_idx_for_recursion]; + let mut missing = true; + for (i, accessor) in current_level_accessor.accessors.iter().enumerate() { + // TODO: optimize with prefetching using fetch_block + // TODO: currently duplicate values for a document imply double counting + // in doc_count (this is also the case in term aggregations) + let values = accessor.column.values_for_doc(doc_id); + for value in values { + missing = false; + if is_on_after_key && i < current_level_accessor.after_key_accessor_idx { + break; + } + let bucket_value: u64 = match current_level_source { + CompositeAggregationSource::Terms(_) => value, + CompositeAggregationSource::Histogram(source) => { + let float_value = match accessor.column_type { + ColumnType::U64 => value as f64, + ColumnType::I64 => i64::from_u64(value) as f64, + // Dates are stored as nanoseconds since epoch but the + // interval is in milliseconds + ColumnType::DateTime => i64::from_u64(value) as f64 / 1_000_000., + ColumnType::F64 => f64::from_u64(value), + _ => { + panic!( + "unexpected type {:?}. This should not happen", + accessor.column_type + ) + } + }; + // We use the interval index (as i64) instead of its value + // (f64) because Fx Hash has a very high collision rate when + // lower bits are similar. The index needs to be multiplied + // back by the interval when building the result. + let bucket_index = (float_value / source.interval).floor() as i64; + i64::to_u64(bucket_index) + } + CompositeAggregationSource::DateHistogram(_) => { + let value_ns = match accessor.column_type { + // Dates are stored as nanoseconds since epoch but the + // interval is in milliseconds + ColumnType::DateTime => i64::from_u64(value), + _ => { + panic!( + "unexpected type {:?}. This should not happen", + accessor.column_type + ) + } + }; + let bucket_value_i64 = match accessor.date_histogram_interval { + PrecomputedDateInterval::FixedMilliseconds(fixed_interval_ms) => { + (value_ns / 1_000_000 / fixed_interval_ms) * fixed_interval_ms + } + PrecomputedDateInterval::Calendar(CalendarInterval::Year) => { + calendar_interval::try_year_bucket(value_ns)? / 1_000_000 + } + PrecomputedDateInterval::Calendar(CalendarInterval::Month) => { + calendar_interval::try_month_bucket(value_ns)? 
/ 1_000_000 + } + PrecomputedDateInterval::Calendar(CalendarInterval::Week) => { + calendar_interval::week_bucket(value_ns) / 1_000_000 + } + PrecomputedDateInterval::NotApplicable => { + panic!("interval not precomputed for date histogram source") + } + }; + i64::to_u64(bucket_value_i64) + } + }; + + if i == current_level_accessor.after_key_accessor_idx + && is_on_after_key + && current_level_source.order() == Order::Asc + && current_level_accessor.after_key.gt(bucket_value) + { + continue; + } + if i == current_level_accessor.after_key_accessor_idx + && is_on_after_key + && current_level_source.order() == Order::Desc + && current_level_accessor.after_key.lt(bucket_value) + { + continue; + } + sub_level_values.push(InternalValueRepr::new( + bucket_value, + i as u8, + current_level_source.order(), + )); + let still_on_after_key = current_level_accessor.after_key_accessor_idx == i + && current_level_accessor.after_key.equals(bucket_value); + recursive_key_visitor( + doc_id, + agg_data, + composite_agg_data, + source_idx_for_recursion + 1, + sub_level_values, + buckets, + is_on_after_key && still_on_after_key, + )?; + sub_level_values.pop(); + } + } + if missing && current_level_source.missing_bucket() { + if is_on_after_key && current_level_accessor.skip_missing { + return Ok(()); + } + sub_level_values.push(InternalValueRepr::new_missing( + current_level_source.order(), + current_level_source.missing_order(), + )); + recursive_key_visitor( + doc_id, + agg_data, + composite_agg_data, + source_idx_for_recursion + 1, + sub_level_values, + buckets, + is_on_after_key && current_level_accessor.is_after_key_explicit_missing, + )?; + sub_level_values.pop(); + } + Ok(()) +} diff --git a/src/aggregation/bucket/composite/map.rs b/src/aggregation/bucket/composite/map.rs new file mode 100644 index 0000000000..4f9610cfdf --- /dev/null +++ b/src/aggregation/bucket/composite/map.rs @@ -0,0 +1,364 @@ +use std::collections::BinaryHeap; +use std::fmt::Debug; +use std::hash::Hash; + +use rustc_hash::FxHashMap; +use smallvec::SmallVec; + +use crate::TantivyError; + +/// Map backed by a hash map for fast access and a binary heap to track the +/// highest key. The key is an array of fixed size S. +#[derive(Clone, Debug)] +struct ArrayHeapMap { + pub(crate) buckets: FxHashMap<[K; S], V>, + pub(crate) heap: BinaryHeap<[K; S]>, +} + +impl Default for ArrayHeapMap { + fn default() -> Self { + ArrayHeapMap { + buckets: FxHashMap::default(), + heap: BinaryHeap::default(), + } + } +} + +impl ArrayHeapMap { + /// Panics if the length of `key` is not S. + fn get_or_insert_with V>(&mut self, key: &[K], f: F) -> &mut V { + let key_array: &[K; S] = key.try_into().expect("Key length mismatch"); + self.buckets.entry(key_array.clone()).or_insert_with(|| { + self.heap.push(key_array.clone()); + f() + }) + } + + /// Panics if the length of `key` is not S. 
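+    /// Returns `None` if the key is not present.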
+ fn get_mut(&mut self, key: &[K]) -> Option<&mut V> { + let key_array: &[K; S] = key.try_into().expect("Key length mismatch"); + self.buckets.get_mut(key_array) + } + + fn peek_highest(&self) -> Option<&[K]> { + self.heap.peek().map(|k_array| k_array.as_slice()) + } + + fn evict_highest(&mut self) { + if let Some(highest) = self.heap.pop() { + self.buckets.remove(&highest); + } + } + + fn memory_consumption(&self) -> u64 { + let key_size = std::mem::size_of::<[K; S]>(); + let map_size = (key_size + std::mem::size_of::()) * self.buckets.capacity(); + let heap_size = key_size * self.heap.capacity(); + (map_size + heap_size) as u64 + } +} + +impl ArrayHeapMap { + fn into_iter(self) -> Box, V)>> { + Box::new( + self.buckets + .into_iter() + .map(|(k, v)| (SmallVec::from_slice(&k), v)), + ) + } + + fn values_mut<'a>(&'a mut self) -> Box + 'a> { + Box::new(self.buckets.values_mut()) + } +} + +pub(super) const MAX_DYN_ARRAY_SIZE: usize = 16; +const MAX_DYN_ARRAY_SIZE_PLUS_ONE: usize = MAX_DYN_ARRAY_SIZE + 1; + +/// A map optimized for memory footprint, fast access and efficient eviction of +/// the highest key. +/// +/// Keys are inlined arrays of size 1 to [MAX_DYN_ARRAY_SIZE] but for a given +/// instance the key size is fixed. This allows to avoid heap allocations for the +/// keys. +#[derive(Clone, Debug)] +pub(super) struct DynArrayHeapMap(DynArrayHeapMapInner); + +/// Wrapper around ArrayHeapMap to dynamically dispatch on the array size. +#[derive(Clone, Debug)] +enum DynArrayHeapMapInner { + Dim1(ArrayHeapMap), + Dim2(ArrayHeapMap), + Dim3(ArrayHeapMap), + Dim4(ArrayHeapMap), + Dim5(ArrayHeapMap), + Dim6(ArrayHeapMap), + Dim7(ArrayHeapMap), + Dim8(ArrayHeapMap), + Dim9(ArrayHeapMap), + Dim10(ArrayHeapMap), + Dim11(ArrayHeapMap), + Dim12(ArrayHeapMap), + Dim13(ArrayHeapMap), + Dim14(ArrayHeapMap), + Dim15(ArrayHeapMap), + Dim16(ArrayHeapMap), +} + +impl DynArrayHeapMap { + /// Creates a new heap map with dynamic array keys of size `key_dimension`. + pub(super) fn try_new(key_dimension: usize) -> crate::Result { + let inner = match key_dimension { + 0 => { + return Err(TantivyError::InvalidArgument( + "DynArrayHeapMap dimension must be at least 1".to_string(), + )) + } + 1 => DynArrayHeapMapInner::Dim1(ArrayHeapMap::default()), + 2 => DynArrayHeapMapInner::Dim2(ArrayHeapMap::default()), + 3 => DynArrayHeapMapInner::Dim3(ArrayHeapMap::default()), + 4 => DynArrayHeapMapInner::Dim4(ArrayHeapMap::default()), + 5 => DynArrayHeapMapInner::Dim5(ArrayHeapMap::default()), + 6 => DynArrayHeapMapInner::Dim6(ArrayHeapMap::default()), + 7 => DynArrayHeapMapInner::Dim7(ArrayHeapMap::default()), + 8 => DynArrayHeapMapInner::Dim8(ArrayHeapMap::default()), + 9 => DynArrayHeapMapInner::Dim9(ArrayHeapMap::default()), + 10 => DynArrayHeapMapInner::Dim10(ArrayHeapMap::default()), + 11 => DynArrayHeapMapInner::Dim11(ArrayHeapMap::default()), + 12 => DynArrayHeapMapInner::Dim12(ArrayHeapMap::default()), + 13 => DynArrayHeapMapInner::Dim13(ArrayHeapMap::default()), + 14 => DynArrayHeapMapInner::Dim14(ArrayHeapMap::default()), + 15 => DynArrayHeapMapInner::Dim15(ArrayHeapMap::default()), + 16 => DynArrayHeapMapInner::Dim16(ArrayHeapMap::default()), + MAX_DYN_ARRAY_SIZE_PLUS_ONE.. => { + return Err(TantivyError::InvalidArgument(format!( + "DynArrayHeapMap supports maximum {MAX_DYN_ARRAY_SIZE} dimensions, got \ + {key_dimension}", + ))) + } + }; + Ok(DynArrayHeapMap(inner)) + } + + /// Number of elements in the map. This is not the dimension of the keys. 
+ pub(super) fn size(&self) -> usize { + match &self.0 { + DynArrayHeapMapInner::Dim1(map) => map.buckets.len(), + DynArrayHeapMapInner::Dim2(map) => map.buckets.len(), + DynArrayHeapMapInner::Dim3(map) => map.buckets.len(), + DynArrayHeapMapInner::Dim4(map) => map.buckets.len(), + DynArrayHeapMapInner::Dim5(map) => map.buckets.len(), + DynArrayHeapMapInner::Dim6(map) => map.buckets.len(), + DynArrayHeapMapInner::Dim7(map) => map.buckets.len(), + DynArrayHeapMapInner::Dim8(map) => map.buckets.len(), + DynArrayHeapMapInner::Dim9(map) => map.buckets.len(), + DynArrayHeapMapInner::Dim10(map) => map.buckets.len(), + DynArrayHeapMapInner::Dim11(map) => map.buckets.len(), + DynArrayHeapMapInner::Dim12(map) => map.buckets.len(), + DynArrayHeapMapInner::Dim13(map) => map.buckets.len(), + DynArrayHeapMapInner::Dim14(map) => map.buckets.len(), + DynArrayHeapMapInner::Dim15(map) => map.buckets.len(), + DynArrayHeapMapInner::Dim16(map) => map.buckets.len(), + } + } +} + +impl DynArrayHeapMap { + /// Get a mutable reference to the value corresponding to `key` or inserts a new + /// value created by calling `f`. + /// + /// Panics if the length of `key` does not match the key dimension of the map. + pub(super) fn get_or_insert_with V>(&mut self, key: &[K], f: F) -> &mut V { + match &mut self.0 { + DynArrayHeapMapInner::Dim1(map) => map.get_or_insert_with(key, f), + DynArrayHeapMapInner::Dim2(map) => map.get_or_insert_with(key, f), + DynArrayHeapMapInner::Dim3(map) => map.get_or_insert_with(key, f), + DynArrayHeapMapInner::Dim4(map) => map.get_or_insert_with(key, f), + DynArrayHeapMapInner::Dim5(map) => map.get_or_insert_with(key, f), + DynArrayHeapMapInner::Dim6(map) => map.get_or_insert_with(key, f), + DynArrayHeapMapInner::Dim7(map) => map.get_or_insert_with(key, f), + DynArrayHeapMapInner::Dim8(map) => map.get_or_insert_with(key, f), + DynArrayHeapMapInner::Dim9(map) => map.get_or_insert_with(key, f), + DynArrayHeapMapInner::Dim10(map) => map.get_or_insert_with(key, f), + DynArrayHeapMapInner::Dim11(map) => map.get_or_insert_with(key, f), + DynArrayHeapMapInner::Dim12(map) => map.get_or_insert_with(key, f), + DynArrayHeapMapInner::Dim13(map) => map.get_or_insert_with(key, f), + DynArrayHeapMapInner::Dim14(map) => map.get_or_insert_with(key, f), + DynArrayHeapMapInner::Dim15(map) => map.get_or_insert_with(key, f), + DynArrayHeapMapInner::Dim16(map) => map.get_or_insert_with(key, f), + } + } + + /// Returns a mutable reference to the value corresponding to `key`. + /// + /// Panics if the length of `key` does not match the key dimension of the map. 
+ pub fn get_mut(&mut self, key: &[K]) -> Option<&mut V> { + match &mut self.0 { + DynArrayHeapMapInner::Dim1(map) => map.get_mut(key), + DynArrayHeapMapInner::Dim2(map) => map.get_mut(key), + DynArrayHeapMapInner::Dim3(map) => map.get_mut(key), + DynArrayHeapMapInner::Dim4(map) => map.get_mut(key), + DynArrayHeapMapInner::Dim5(map) => map.get_mut(key), + DynArrayHeapMapInner::Dim6(map) => map.get_mut(key), + DynArrayHeapMapInner::Dim7(map) => map.get_mut(key), + DynArrayHeapMapInner::Dim8(map) => map.get_mut(key), + DynArrayHeapMapInner::Dim9(map) => map.get_mut(key), + DynArrayHeapMapInner::Dim10(map) => map.get_mut(key), + DynArrayHeapMapInner::Dim11(map) => map.get_mut(key), + DynArrayHeapMapInner::Dim12(map) => map.get_mut(key), + DynArrayHeapMapInner::Dim13(map) => map.get_mut(key), + DynArrayHeapMapInner::Dim14(map) => map.get_mut(key), + DynArrayHeapMapInner::Dim15(map) => map.get_mut(key), + DynArrayHeapMapInner::Dim16(map) => map.get_mut(key), + } + } + + /// Returns a reference to the highest key in the map. + pub(super) fn peek_highest(&self) -> Option<&[K]> { + match &self.0 { + DynArrayHeapMapInner::Dim1(map) => map.peek_highest(), + DynArrayHeapMapInner::Dim2(map) => map.peek_highest(), + DynArrayHeapMapInner::Dim3(map) => map.peek_highest(), + DynArrayHeapMapInner::Dim4(map) => map.peek_highest(), + DynArrayHeapMapInner::Dim5(map) => map.peek_highest(), + DynArrayHeapMapInner::Dim6(map) => map.peek_highest(), + DynArrayHeapMapInner::Dim7(map) => map.peek_highest(), + DynArrayHeapMapInner::Dim8(map) => map.peek_highest(), + DynArrayHeapMapInner::Dim9(map) => map.peek_highest(), + DynArrayHeapMapInner::Dim10(map) => map.peek_highest(), + DynArrayHeapMapInner::Dim11(map) => map.peek_highest(), + DynArrayHeapMapInner::Dim12(map) => map.peek_highest(), + DynArrayHeapMapInner::Dim13(map) => map.peek_highest(), + DynArrayHeapMapInner::Dim14(map) => map.peek_highest(), + DynArrayHeapMapInner::Dim15(map) => map.peek_highest(), + DynArrayHeapMapInner::Dim16(map) => map.peek_highest(), + } + } + + /// Removes the entry with the highest key from the map. 
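+    /// Does nothing if the map is empty.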
+ pub(super) fn evict_highest(&mut self) { + match &mut self.0 { + DynArrayHeapMapInner::Dim1(map) => map.evict_highest(), + DynArrayHeapMapInner::Dim2(map) => map.evict_highest(), + DynArrayHeapMapInner::Dim3(map) => map.evict_highest(), + DynArrayHeapMapInner::Dim4(map) => map.evict_highest(), + DynArrayHeapMapInner::Dim5(map) => map.evict_highest(), + DynArrayHeapMapInner::Dim6(map) => map.evict_highest(), + DynArrayHeapMapInner::Dim7(map) => map.evict_highest(), + DynArrayHeapMapInner::Dim8(map) => map.evict_highest(), + DynArrayHeapMapInner::Dim9(map) => map.evict_highest(), + DynArrayHeapMapInner::Dim10(map) => map.evict_highest(), + DynArrayHeapMapInner::Dim11(map) => map.evict_highest(), + DynArrayHeapMapInner::Dim12(map) => map.evict_highest(), + DynArrayHeapMapInner::Dim13(map) => map.evict_highest(), + DynArrayHeapMapInner::Dim14(map) => map.evict_highest(), + DynArrayHeapMapInner::Dim15(map) => map.evict_highest(), + DynArrayHeapMapInner::Dim16(map) => map.evict_highest(), + } + } + + pub(crate) fn memory_consumption(&self) -> u64 { + match &self.0 { + DynArrayHeapMapInner::Dim1(map) => map.memory_consumption(), + DynArrayHeapMapInner::Dim2(map) => map.memory_consumption(), + DynArrayHeapMapInner::Dim3(map) => map.memory_consumption(), + DynArrayHeapMapInner::Dim4(map) => map.memory_consumption(), + DynArrayHeapMapInner::Dim5(map) => map.memory_consumption(), + DynArrayHeapMapInner::Dim6(map) => map.memory_consumption(), + DynArrayHeapMapInner::Dim7(map) => map.memory_consumption(), + DynArrayHeapMapInner::Dim8(map) => map.memory_consumption(), + DynArrayHeapMapInner::Dim9(map) => map.memory_consumption(), + DynArrayHeapMapInner::Dim10(map) => map.memory_consumption(), + DynArrayHeapMapInner::Dim11(map) => map.memory_consumption(), + DynArrayHeapMapInner::Dim12(map) => map.memory_consumption(), + DynArrayHeapMapInner::Dim13(map) => map.memory_consumption(), + DynArrayHeapMapInner::Dim14(map) => map.memory_consumption(), + DynArrayHeapMapInner::Dim15(map) => map.memory_consumption(), + DynArrayHeapMapInner::Dim16(map) => map.memory_consumption(), + } + } +} + +impl DynArrayHeapMap { + /// Turns this map into an iterator over key-value pairs. + pub fn into_iter(self) -> impl Iterator, V)> { + match self.0 { + DynArrayHeapMapInner::Dim1(map) => map.into_iter(), + DynArrayHeapMapInner::Dim2(map) => map.into_iter(), + DynArrayHeapMapInner::Dim3(map) => map.into_iter(), + DynArrayHeapMapInner::Dim4(map) => map.into_iter(), + DynArrayHeapMapInner::Dim5(map) => map.into_iter(), + DynArrayHeapMapInner::Dim6(map) => map.into_iter(), + DynArrayHeapMapInner::Dim7(map) => map.into_iter(), + DynArrayHeapMapInner::Dim8(map) => map.into_iter(), + DynArrayHeapMapInner::Dim9(map) => map.into_iter(), + DynArrayHeapMapInner::Dim10(map) => map.into_iter(), + DynArrayHeapMapInner::Dim11(map) => map.into_iter(), + DynArrayHeapMapInner::Dim12(map) => map.into_iter(), + DynArrayHeapMapInner::Dim13(map) => map.into_iter(), + DynArrayHeapMapInner::Dim14(map) => map.into_iter(), + DynArrayHeapMapInner::Dim15(map) => map.into_iter(), + DynArrayHeapMapInner::Dim16(map) => map.into_iter(), + } + } + + /// Returns an iterator over mutable references to the values in the map. 
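+    /// The iteration order is unspecified, as the values live in a hash map.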
+ pub(super) fn values_mut(&mut self) -> impl Iterator { + match &mut self.0 { + DynArrayHeapMapInner::Dim1(map) => map.values_mut(), + DynArrayHeapMapInner::Dim2(map) => map.values_mut(), + DynArrayHeapMapInner::Dim3(map) => map.values_mut(), + DynArrayHeapMapInner::Dim4(map) => map.values_mut(), + DynArrayHeapMapInner::Dim5(map) => map.values_mut(), + DynArrayHeapMapInner::Dim6(map) => map.values_mut(), + DynArrayHeapMapInner::Dim7(map) => map.values_mut(), + DynArrayHeapMapInner::Dim8(map) => map.values_mut(), + DynArrayHeapMapInner::Dim9(map) => map.values_mut(), + DynArrayHeapMapInner::Dim10(map) => map.values_mut(), + DynArrayHeapMapInner::Dim11(map) => map.values_mut(), + DynArrayHeapMapInner::Dim12(map) => map.values_mut(), + DynArrayHeapMapInner::Dim13(map) => map.values_mut(), + DynArrayHeapMapInner::Dim14(map) => map.values_mut(), + DynArrayHeapMapInner::Dim15(map) => map.values_mut(), + DynArrayHeapMapInner::Dim16(map) => map.values_mut(), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_dyn_array_heap_map() { + let mut map = DynArrayHeapMap::::try_new(2).unwrap(); + // insert + let key1 = [1u32, 2u32]; + let key2 = [2u32, 1u32]; + map.get_or_insert_with(&key1, || "a"); + map.get_or_insert_with(&key2, || "b"); + assert_eq!(map.size(), 2); + + // evict highest + assert_eq!(map.peek_highest(), Some(&key2[..])); + map.evict_highest(); + assert_eq!(map.size(), 1); + assert_eq!(map.peek_highest(), Some(&key1[..])); + + // mutable iterator + { + let mut mut_iter = map.values_mut(); + let v = mut_iter.next().unwrap(); + assert_eq!(*v, "a"); + *v = "c"; + assert_eq!(mut_iter.next(), None); + } + + // into_iter + let mut iter = map.into_iter(); + let (k, v) = iter.next().unwrap(); + assert_eq!(k.as_slice(), &key1); + assert_eq!(v, "c"); + assert_eq!(iter.next(), None); + } +} diff --git a/src/aggregation/bucket/composite/mod.rs b/src/aggregation/bucket/composite/mod.rs new file mode 100644 index 0000000000..3e8706f9b6 --- /dev/null +++ b/src/aggregation/bucket/composite/mod.rs @@ -0,0 +1,1789 @@ +mod accessors; +mod calendar_interval; +mod collector; +mod map; +mod numeric_types; + +use core::panic; +use std::cmp::Ordering; +use std::fmt::Debug; + +use rustc_hash::FxHashMap; +use serde::{Deserialize, Serialize}; + +use crate::aggregation::agg_result::CompositeKey; +pub use crate::aggregation::bucket::composite::accessors::{ + CompositeAccessor, CompositeAggReqData, CompositeSourceAccessors, PrecomputedDateInterval, +}; +pub use crate::aggregation::bucket::composite::collector::SegmentCompositeCollector; +use crate::aggregation::bucket::composite::numeric_types::num_cmp::{ + cmp_i64_f64, cmp_i64_u64, cmp_u64_f64, +}; +use crate::aggregation::bucket::Order; +use crate::aggregation::deserialize_f64; +use crate::aggregation::intermediate_agg_result::CompositeIntermediateKey; +use crate::TantivyError; + +/// Position of missing keys in the ordering. +#[derive(Clone, Copy, Debug, PartialEq, Serialize, Deserialize, Default)] +#[serde(rename_all = "lowercase")] +pub enum MissingOrder { + /// Missing keys appear first in ascending order, last in descending order. + #[default] + Default, + /// Missing keys should appear first. + First, + /// Missing keys should appear last. + Last, +} + +/// Term source for a composite aggregation. +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] +pub struct TermCompositeAggregationSource { + /// The name used to refer to this source in the composite key. + #[serde(skip)] + pub name: String, + /// The field to aggregate on. 
+ pub field: String, + /// The order for this source. + #[serde(default = "Order::asc")] + pub order: Order, + /// Whether to create a `null` bucket for documents without value for this + /// field. By default documents without a value are ignored. + #[serde(default)] + pub missing_bucket: bool, + /// Whether missing keys should appear first or last. + #[serde(default)] + pub missing_order: MissingOrder, +} + +/// Histogram source for a composite aggregation. +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] +pub struct HistogramCompositeAggregationSource { + /// The name used to refer to this source in the composite key. + #[serde(skip)] + pub name: String, + /// The field to aggregate on. + pub field: String, + /// The interval for the histogram. For datetime fields, this is expressed. + /// in milliseconds. + #[serde(deserialize_with = "deserialize_f64")] + pub interval: f64, + /// The order for this source. + #[serde(default = "Order::asc")] + pub order: Order, + /// Whether to create a `null` bucket for documents without value for this + /// field. By default documents without a value are ignored. + #[serde(default)] + pub missing_bucket: bool, + /// Whether missing keys should appear first or last. + #[serde(default)] + pub missing_order: MissingOrder, +} + +/// Calendar intervals supported for date histogram sources +#[derive(Clone, Copy, Debug, PartialEq, Serialize, Deserialize)] +#[serde(rename_all = "lowercase")] +pub enum CalendarInterval { + /// A year between Jan 1st and Dec 31st, taking into account leap years. + Year, + /// A month between the 1st and the last day of the month. + Month, + /// A week between Monday and Sunday. + Week, +} + +/// Date histogram source for a composite aggregation. +/// +/// Time zone not supported yet. Every interval is aligned on UTC. +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] +pub struct DateHistogramCompositeAggregationSource { + /// The name used to refer to this source in the composite key. + #[serde(skip)] + pub name: String, + /// The field to aggregate on. + pub field: String, + /// The fixed interval for the histogram. Either this or `calendar_interval`. + /// must be set. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub fixed_interval: Option, + /// The calendar adjusted interval for the histogram. Either this or + /// `fixed_interval` must be set. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub calendar_interval: Option, + /// The order for this source. + #[serde(default = "Order::asc")] + pub order: Order, + /// Whether to create a `null` bucket for documents without value for this + /// field. By default documents without a value are ignored. Not supported + /// in Elasticsearch. + #[serde(default)] + pub missing_bucket: bool, + /// Whether missing keys should appear first or last. + #[serde(default)] + pub missing_order: MissingOrder, +} + +/// Source for the composite aggregation. A composite aggregation can have +/// multiple sources. +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum CompositeAggregationSource { + /// Terms source. + Terms(TermCompositeAggregationSource), + /// Histogram source. + Histogram(HistogramCompositeAggregationSource), + /// Date histogram source. 
+ DateHistogram(DateHistogramCompositeAggregationSource), +} + +impl CompositeAggregationSource { + pub(crate) fn field(&self) -> &str { + match self { + CompositeAggregationSource::Terms(source) => &source.field, + CompositeAggregationSource::Histogram(source) => &source.field, + CompositeAggregationSource::DateHistogram(source) => &source.field, + } + } + + pub(crate) fn name(&self) -> &str { + match self { + CompositeAggregationSource::Terms(source) => &source.name, + CompositeAggregationSource::Histogram(source) => &source.name, + CompositeAggregationSource::DateHistogram(source) => &source.name, + } + } + + pub(crate) fn order(&self) -> Order { + match self { + CompositeAggregationSource::Terms(source) => source.order, + CompositeAggregationSource::Histogram(source) => source.order, + CompositeAggregationSource::DateHistogram(source) => source.order, + } + } + + pub(crate) fn missing_order(&self) -> MissingOrder { + match self { + CompositeAggregationSource::Terms(source) => source.missing_order, + CompositeAggregationSource::Histogram(source) => source.missing_order, + CompositeAggregationSource::DateHistogram(source) => source.missing_order, + } + } + + pub(crate) fn missing_bucket(&self) -> bool { + match self { + CompositeAggregationSource::Terms(source) => source.missing_bucket, + CompositeAggregationSource::Histogram(source) => source.missing_bucket, + CompositeAggregationSource::DateHistogram(source) => source.missing_bucket, + } + } +} + +/// A paginable aggregation that performs on multiple dimensions (sources), +/// potentially mixing term and range queries. +/// +/// Pagination is made possible because the buckets are ordered by the composite +/// key, so the next page can be fetched "efficiently" by filtering using range +/// queries on the key dimensions. +#[derive(Clone, Debug, Default, PartialEq, Serialize, Deserialize)] +#[serde( + try_from = "CompositeAggregationSerde", + into = "CompositeAggregationSerde" +)] +pub struct CompositeAggregation { + /// The fields and bucketting strategies. + pub sources: Vec, + /// Number of buckets to return (page size). + pub size: u32, + /// The key of the previous page's last bucket. 
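+    /// Leave it empty to fetch the first page, otherwise pass the `after_key`
+    /// returned by the previous request, e.g. `{"category": "books", "status": "active"}`.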
+ pub after: FxHashMap, +} + +#[derive(Serialize, Deserialize)] +struct CompositeAggregationSerde { + sources: Vec>, + size: u32, + #[serde(default, skip_serializing_if = "FxHashMap::is_empty")] + after: FxHashMap, +} + +impl TryFrom for CompositeAggregation { + type Error = TantivyError; + + fn try_from(value: CompositeAggregationSerde) -> Result { + let mut sources = Vec::with_capacity(value.sources.len()); + for map in value.sources { + if map.len() != 1 { + return Err(TantivyError::InvalidArgument( + "each composite source must have exactly one named entry".to_string(), + )); + } + let (name, mut source) = map.into_iter().next().unwrap(); + match &mut source { + CompositeAggregationSource::Terms(source) => { + source.name = name; + } + CompositeAggregationSource::Histogram(source) => { + source.name = name; + } + CompositeAggregationSource::DateHistogram(source) => { + source.name = name; + } + } + sources.push(source); + } + Ok(CompositeAggregation { + sources, + size: value.size, + after: value.after, + }) + } +} + +impl From for CompositeAggregationSerde { + fn from(value: CompositeAggregation) -> Self { + let mut serde_sources = Vec::with_capacity(value.sources.len()); + for source in value.sources { + let (name, stored_source) = match source { + CompositeAggregationSource::Terms(source) => { + let name = source.name.clone(); + // name field is #[serde(skip)] so it won't be serialized inside the value + (name, CompositeAggregationSource::Terms(source)) + } + CompositeAggregationSource::Histogram(source) => { + let name = source.name.clone(); + (name, CompositeAggregationSource::Histogram(source)) + } + CompositeAggregationSource::DateHistogram(source) => { + let name = source.name.clone(); + (name, CompositeAggregationSource::DateHistogram(source)) + } + }; + let mut map = FxHashMap::default(); + map.insert(name, stored_source); + serde_sources.push(map); + } + CompositeAggregationSerde { + sources: serde_sources, + size: value.size, + after: value.after, + } + } +} + +/// Ordering key for intermediate composite keys when comparing different types. +fn type_order_key(key: &CompositeIntermediateKey, order: Order) -> i32 { + let apply_order = match order { + Order::Asc => 1, + Order::Desc => -1, + }; + match key { + CompositeIntermediateKey::Null => panic!("unexpected"), + CompositeIntermediateKey::Bool(_) => 1 * apply_order, + CompositeIntermediateKey::Str(_) => 2 * apply_order, + CompositeIntermediateKey::I64(_) => 3 * apply_order, + CompositeIntermediateKey::F64(_) => 3 * apply_order, + CompositeIntermediateKey::U64(_) => 3 * apply_order, + CompositeIntermediateKey::IpAddr(_) => 4, + CompositeIntermediateKey::DateTime(_) => 4, + } +} + +/// Calculates the ordering between intermediate keys. 
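+/// `order` applies to comparisons between keys of the same (or numeric) type,
+/// `missing_order` decides where `Null` keys go (it follows `order` when left
+/// as `Default`), and keys of otherwise incompatible types are ordered by a
+/// type-based ranking.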
+pub fn composite_intermediate_key_ordering( + left_opt: &CompositeIntermediateKey, + right_opt: &CompositeIntermediateKey, + order: Order, + missing_order: MissingOrder, +) -> crate::Result { + use CompositeIntermediateKey as CIKey; + let mut forced_ordering = false; + let asc_ordering = match (left_opt, right_opt) { + // null comparisons + (CIKey::Null, CIKey::Null) => Ordering::Equal, + (CIKey::Null, _) => { + forced_ordering = missing_order != MissingOrder::Default; + match missing_order { + MissingOrder::First => Ordering::Less, + MissingOrder::Last => Ordering::Greater, + MissingOrder::Default => Ordering::Less, + } + } + (_, CIKey::Null) => { + forced_ordering = missing_order != MissingOrder::Default; + match missing_order { + MissingOrder::First => Ordering::Greater, + MissingOrder::Last => Ordering::Less, + MissingOrder::Default => Ordering::Greater, + } + } + // same type comparisons + (CIKey::Bool(left), CIKey::Bool(right)) => left.cmp(right), + (CIKey::I64(left), CIKey::I64(right)) => left.cmp(right), + (CIKey::Str(left), CIKey::Str(right)) => left.cmp(right), + (CIKey::IpAddr(left), CIKey::IpAddr(right)) => left.cmp(right), + (CIKey::DateTime(left), CIKey::DateTime(right)) => left.cmp(right), + (CIKey::U64(left), CIKey::U64(right)) => left.cmp(right), + (CIKey::F64(f), CIKey::F64(_)) | (CIKey::F64(_), CIKey::F64(f)) if f.is_nan() => { + return Err(TantivyError::InvalidArgument( + "NaN comparison is not supported".to_string(), + )) + } + (CIKey::F64(left), CIKey::F64(right)) => left.partial_cmp(right).unwrap_or(Ordering::Equal), + // numeric cross-type comparisons + (CIKey::F64(left), CIKey::I64(right)) => cmp_i64_f64(*right, *left)?.reverse(), + (CIKey::F64(left), CIKey::U64(right)) => cmp_u64_f64(*right, *left)?.reverse(), + (CIKey::I64(left), CIKey::F64(right)) => cmp_i64_f64(*left, *right)?, + (CIKey::I64(left), CIKey::U64(right)) => cmp_i64_u64(*left, *right), + (CIKey::U64(left), CIKey::I64(right)) => cmp_i64_u64(*right, *left).reverse(), + (CIKey::U64(left), CIKey::F64(right)) => cmp_u64_f64(*left, *right)?, + // other cross-type comparisons + (type_a, type_b) => { + forced_ordering = true; + type_order_key(type_a, order).cmp(&type_order_key(type_b, order)) + } + }; + if !forced_ordering && order == Order::Desc { + Ok(asc_ordering.reverse()) + } else { + Ok(asc_ordering) + } +} + +#[cfg(test)] +mod tests { + use std::net::{Ipv4Addr, Ipv6Addr}; + + use serde_json::json; + use time::format_description::well_known::Rfc3339; + use time::OffsetDateTime; + + use crate::aggregation::agg_req::Aggregations; + use crate::aggregation::tests::exec_request; + use crate::schema::{Schema, FAST, STRING}; + use crate::Index; + + fn datetime_from_iso_str(date_str: &str) -> common::DateTime { + let dt = OffsetDateTime::parse(date_str, &Rfc3339) + .expect(&format!("Failed to parse date: {}", date_str)); + let timestamp_secs = dt.unix_timestamp_nanos(); + common::DateTime::from_timestamp_nanos(timestamp_secs as i64) + } + + fn ms_timestamp_from_iso_str(date_str: &str) -> i64 { + let dt = OffsetDateTime::parse(date_str, &Rfc3339) + .expect(&format!("Failed to parse date: {}", date_str)); + (dt.unix_timestamp_nanos() / 1_000_000) as i64 + } + + /// Runs the query and compares the result buckets to the expected buckets, + /// then run the same query with a all possible `after` keys and different + /// page sizes. 
+ fn exec_and_assert_all_paginations( + index: &Index, + composite_agg_req: serde_json::Value, + expected_buckets: serde_json::Value, + ) { + let agg_req: Aggregations = serde_json::from_value(json!({ + "my_composite": { + "composite": composite_agg_req + } + })) + .unwrap(); + let res = exec_request(agg_req, &index).unwrap(); + let buckets = &res["my_composite"]["buckets"]; + assert_eq!(buckets, &expected_buckets); + + // Check that all returned buckets can be used as after key + // Note: this is not a requirement of the API, only the key explicitly + // returned as after_key is guaranteed to work, but this is a nice + // property of the implementation. + for (i, expected_bucket) in expected_buckets.as_array().unwrap().iter().enumerate() { + let new_composite_agg_req = json!({ + "sources": composite_agg_req["sources"].clone(), + "size": composite_agg_req["size"].clone(), + "after": expected_bucket["key"].clone() + }); + let agg_req: Aggregations = serde_json::from_value(json!({ + "my_composite": { + "composite": new_composite_agg_req + } + })) + .unwrap(); + let paginated_res = exec_request(agg_req, &index).unwrap(); + assert_eq!( + &paginated_res["my_composite"]["buckets"], + &json!(&expected_buckets.as_array().unwrap()[i + 1..]), + "query with after key from bucket failed: {}", + new_composite_agg_req.to_string() + ); + } + + // paginate 1 by 1 + let one_by_one_composite_agg_req = json!({ + "sources": composite_agg_req["sources"].clone(), + "size": 1, + }); + let mut after_key = None; + for i in 0..expected_buckets.as_array().unwrap().len() { + let mut paged_req = one_by_one_composite_agg_req.clone(); + if let Some(after_key) = after_key { + paged_req["after"] = after_key; + } + let agg_req: Aggregations = serde_json::from_value(json!({ + "my_composite": { + "composite": paged_req + } + })) + .unwrap(); + let paged_res = exec_request(agg_req, &index).unwrap(); + assert_eq!( + &paged_res["my_composite"]["buckets"], + &json!(&[&expected_buckets[i]]), + "1-by-1 pagination failed at index {}, query: {}", + i, + paged_req.to_string() + ); + after_key = paged_res["my_composite"].get("after_key").cloned(); + } + // Ideally, we should not require the user to issue an extra request + // because we could know that this is the last page. 
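+        // Fetch one more page and check that it comes back empty with no after_key.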
+ if let Some(last_after_key) = after_key { + let mut last_page_req = one_by_one_composite_agg_req.clone(); + last_page_req["after"] = last_after_key; + let agg_req: Aggregations = serde_json::from_value(json!({ + "my_composite": { + "composite": last_page_req + } + })) + .unwrap(); + let paged_res = exec_request(agg_req, &index).unwrap(); + assert_eq!( + &paged_res["my_composite"]["buckets"], + &json!([]), + "last page request failed, query: {}", + last_page_req.to_string() + ); + after_key = paged_res["my_composite"].get("after_key").cloned(); + } + assert_eq!(after_key, None); + } + + fn composite_aggregation_test(merge_segments: bool) -> crate::Result<()> { + let mut schema_builder = Schema::builder(); + let string_field = schema_builder.add_text_field("string_id", STRING | FAST); + let index = Index::create_in_ram(schema_builder.build()); + { + let mut index_writer = index.writer_with_num_threads(1, 20_000_000)?; + index_writer.add_document(doc!(string_field => "terma"))?; + index_writer.add_document(doc!(string_field => "termb"))?; + index_writer.add_document(doc!(string_field => "termc"))?; + index_writer.add_document(doc!(string_field => "terma"))?; + index_writer.commit()?; + index_writer.add_document(doc!(string_field => "terma"))?; + index_writer.add_document(doc!(string_field => "terma"))?; + index_writer.add_document(doc!(string_field => "termb"))?; + index_writer.add_document(doc!(string_field => "terma"))?; + index_writer.commit()?; + if merge_segments { + index_writer.wait_merging_threads()?; + } + } + + let agg_req: Aggregations = serde_json::from_value(json!({ + "my_composite": { + "composite": { + "sources": [ + {"term1": {"terms": {"field": "string_id"}}} + ], + "size": 10 + } + } + })) + .unwrap(); + + let res = exec_request(agg_req, &index)?; + let buckets = &res["my_composite"]["buckets"]; + + assert_eq!( + buckets, + &json!([ + {"key": {"term1": "terma"}, "doc_count": 5}, + {"key": {"term1": "termb"}, "doc_count": 2}, + {"key": {"term1": "termc"}, "doc_count": 1} + ]) + ); + + Ok(()) + } + + #[test] + fn composite_aggregation_term_single_segment() -> crate::Result<()> { + composite_aggregation_test(true) + } + + #[test] + fn composite_aggregation_term_multi_segment() -> crate::Result<()> { + composite_aggregation_test(false) + } + + fn composite_aggregation_term_size_limit(merge_segments: bool) -> crate::Result<()> { + let mut schema_builder = Schema::builder(); + let string_field = schema_builder.add_text_field("string_id", STRING | FAST); + let index = Index::create_in_ram(schema_builder.build()); + { + let mut index_writer = index.writer_with_num_threads(1, 20_000_000)?; + index_writer.add_document(doc!(string_field => "terma"))?; + index_writer.add_document(doc!(string_field => "termb"))?; + index_writer.commit()?; + index_writer.add_document(doc!(string_field => "termc"))?; + index_writer.add_document(doc!(string_field => "termd"))?; + index_writer.add_document(doc!(string_field => "terme"))?; + index_writer.commit()?; + if merge_segments { + index_writer.wait_merging_threads()?; + } + } + + let agg_req: Aggregations = serde_json::from_value(json!({ + "my_composite": { + "composite": { + "sources": [ + {"myterm": {"terms": {"field": "string_id"}}} + ], + "size": 3 + } + } + })) + .unwrap(); + let res = exec_request(agg_req, &index)?; + let buckets = &res["my_composite"]["buckets"]; + // Should only return 3 buckets due to size limit + assert_eq!( + buckets, + &json!([ + {"key": {"myterm": "terma"}, "doc_count": 1}, + {"key": {"myterm": "termb"}, "doc_count": 
1}, + {"key": {"myterm": "termc"}, "doc_count": 1} + ]) + ); + + // next page + let agg_req: Aggregations = serde_json::from_value(json!({ + "my_composite": { + "composite": { + "sources": [ + {"myterm": {"terms": {"field": "string_id"}}} + ], + "size": 3, + "after": &res["my_composite"]["after_key"] + } + } + })) + .unwrap(); + let res = exec_request(agg_req, &index)?; + let buckets = &res["my_composite"]["buckets"]; + assert_eq!( + buckets, + &json!([ + {"key": {"myterm": "termd"}, "doc_count": 1}, + {"key": {"myterm": "terme"}, "doc_count": 1} + ]) + ); + assert!(res["my_composite"].get("after_key").is_none()); + + Ok(()) + } + + #[test] + fn composite_aggregation_term_size_limit_single_segment() -> crate::Result<()> { + composite_aggregation_term_size_limit(true) + } + + #[test] + fn composite_aggregation_term_size_limit_multi_segment() -> crate::Result<()> { + composite_aggregation_term_size_limit(false) + } + + #[test] + fn composite_aggregation_term_ordering() -> crate::Result<()> { + let mut schema_builder = Schema::builder(); + let string_field = schema_builder.add_text_field("string_id", STRING | FAST); + let index = Index::create_in_ram(schema_builder.build()); + { + let mut index_writer = index.writer_with_num_threads(1, 20_000_000)?; + index_writer.add_document(doc!(string_field => "zebra"))?; + index_writer.add_document(doc!(string_field => "apple"))?; + index_writer.add_document(doc!(string_field => "banana"))?; + index_writer.add_document(doc!(string_field => "cherry"))?; + index_writer.add_document(doc!(string_field => "dog"))?; + index_writer.add_document(doc!(string_field => "elephant"))?; + index_writer.add_document(doc!(string_field => "fox"))?; + index_writer.add_document(doc!(string_field => "grape"))?; + index_writer.commit()?; + } + + // Test ascending order (default) + let agg_req: Aggregations = serde_json::from_value(json!({ + "fruity_aggreg": { + "composite": { + "sources": [ + {"myterm": {"terms": {"field": "string_id", "order": "asc"}}} + ], + "size": 5 + } + } + })) + .unwrap(); + let res = exec_request(agg_req, &index)?; + let buckets = &res["fruity_aggreg"]["buckets"]; + // Should return only 5 buckets due to size limit, in ascending order + assert_eq!( + buckets, + &json!([ + {"key": {"myterm": "apple"}, "doc_count": 1}, + {"key": {"myterm": "banana"}, "doc_count": 1}, + {"key": {"myterm": "cherry"}, "doc_count": 1}, + {"key": {"myterm": "dog"}, "doc_count": 1}, + {"key": {"myterm": "elephant"}, "doc_count": 1} + ]) + ); + + // Test descending order + let agg_req: Aggregations = serde_json::from_value(json!({ + "fruity_aggreg": { + "composite": { + "sources": [ + {"myterm": {"terms": {"field": "string_id", "order": "desc"}}} + ], + "size": 5 + } + } + })) + .unwrap(); + let res = exec_request(agg_req, &index)?; + let buckets = &res["fruity_aggreg"]["buckets"]; + // Should return only 5 buckets due to size limit, in descending order + assert_eq!( + buckets, + &json!([ + {"key": {"myterm": "zebra"}, "doc_count": 1}, + {"key": {"myterm": "grape"}, "doc_count": 1}, + {"key": {"myterm": "fox"}, "doc_count": 1}, + {"key": {"myterm": "elephant"}, "doc_count": 1}, + {"key": {"myterm": "dog"}, "doc_count": 1} + ]) + ); + + // next page in descending order + let agg_req: Aggregations = serde_json::from_value(json!({ + "fruity_aggreg": { + "composite": { + "sources": [ + {"myterm": {"terms": {"field": "string_id", "order": "desc"}}} + ], + "size": 5, + "after": &res["fruity_aggreg"]["after_key"] + } + } + })) + .unwrap(); + let res = exec_request(agg_req, &index)?; 
+ let buckets = &res["fruity_aggreg"]["buckets"]; + // Should return only 5 buckets due to size limit, in descending order + assert_eq!( + buckets, + &json!([ + {"key": {"myterm": "cherry"}, "doc_count": 1}, + {"key": {"myterm": "banana"}, "doc_count": 1}, + {"key": {"myterm": "apple"}, "doc_count": 1} + ]) + ); + assert!(res["my_composite"].get("after_key").is_none()); + + Ok(()) + } + + #[test] + fn composite_aggregation_term_missing_values() -> crate::Result<()> { + // Create index with some documents having missing values + let mut schema_builder = Schema::builder(); + let string_field = schema_builder.add_text_field("string_id", STRING | FAST); + let index = Index::create_in_ram(schema_builder.build()); + { + let mut index_writer = index.writer_with_num_threads(1, 20_000_000)?; + index_writer.add_document(doc!(string_field => "terma"))?; + index_writer.add_document(doc!(string_field => "termb"))?; + index_writer.add_document(doc!())?; + index_writer.add_document(doc!(string_field => "terma"))?; + index_writer.commit()?; + } + + // Test without missing bucket (should ignore missing values) + exec_and_assert_all_paginations( + &index, + json!({ + "sources": [ + {"myterm": {"terms": {"field": "string_id", "missing_bucket": false}}} + ], + "size": 10 + }), + json!([ + {"key": {"myterm": "terma"}, "doc_count": 2}, + {"key": {"myterm": "termb"}, "doc_count": 1} + ]), + ); + + // Test with missing bucket enabled + exec_and_assert_all_paginations( + &index, + json!({ + "sources": [ + {"myterm": {"terms": {"field": "string_id", "missing_bucket": true}}} + ], + "size": 10 + }), + // Should have 3 buckets including the missing bucket + // Missing bucket should come first in ascending order by default + json!([ + {"key": {"myterm": null}, "doc_count": 1}, + {"key": {"myterm": "terma"}, "doc_count": 2}, + {"key": {"myterm": "termb"}, "doc_count": 1} + ]), + ); + + Ok(()) + } + + #[test] + fn composite_aggregation_term_missing_order() -> crate::Result<()> { + // Create index with missing values + let mut schema_builder = Schema::builder(); + let string_field = schema_builder.add_text_field("string_id", STRING | FAST); + let index = Index::create_in_ram(schema_builder.build()); + + { + let mut index_writer = index.writer_with_num_threads(1, 20_000_000)?; + index_writer.add_document(doc!(string_field => "termb"))?; + index_writer.add_document(doc!())?; + index_writer.add_document(doc!(string_field => "terma"))?; + index_writer.commit()?; + } + + // Test missing_order: "first" + exec_and_assert_all_paginations( + &index, + json!({ + "sources": [ + { + "myterm": { + "terms": { + "field": "string_id", + "missing_bucket": true, + "missing_order": "first", + "order": "asc" + } + } + } + ], + "size": 10 + }), + json!([ + {"key": {"myterm": null}, "doc_count": 1}, + {"key": {"myterm": "terma"}, "doc_count": 1}, + {"key": {"myterm": "termb"}, "doc_count": 1} + ]), + ); + + // Test missing_order: "last" + exec_and_assert_all_paginations( + &index, + json!({ + "sources": [ + { + "myterm": { + "terms": { + "field": "string_id", + "missing_bucket": true, + "missing_order": "last", + "order": "asc" + } + } + } + ], + "size": 10 + }), + json!([ + {"key": {"myterm": "terma"}, "doc_count": 1}, + {"key": {"myterm": "termb"}, "doc_count": 1}, + {"key": {"myterm": null}, "doc_count": 1} + ]), + ); + + // Test missing_order: "default" with desc order + exec_and_assert_all_paginations( + &index, + json!({ + "sources": [ + { + "myterm": { + "terms": { + "field": "string_id", + "missing_bucket": true, + "missing_order": 
"default", + "order": "desc" + } + } + } + ], + "size": 10 + }), + json!([ + {"key": {"myterm": "termb"}, "doc_count": 1}, + {"key": {"myterm": "terma"}, "doc_count": 1}, + {"key": {"myterm": null}, "doc_count": 1} + ]), + ); + + Ok(()) + } + + #[test] + fn composite_aggregation_term_multi_source() -> crate::Result<()> { + let mut schema_builder = Schema::builder(); + let cat = schema_builder.add_text_field("category", STRING | FAST); + let status = schema_builder.add_text_field("status", STRING | FAST); + let index = Index::create_in_ram(schema_builder.build()); + { + let mut index_writer = index.writer_with_num_threads(1, 20_000_000)?; + index_writer.add_document(doc!(cat => "electronics", status => "active"))?; + index_writer.add_document(doc!(cat => "electronics", status => "inactive"))?; + index_writer.add_document(doc!(cat => "electronics", status => "active"))?; + index_writer.add_document(doc!(cat => "books", status => "active"))?; + index_writer.add_document(doc!(cat => "books", status => "inactive"))?; + index_writer.add_document(doc!(cat => "clothing", status => "active"))?; + index_writer.commit()?; + } + + exec_and_assert_all_paginations( + &index, + json!({ + "sources": [ + {"category": {"terms": {"field": "category"}}}, + {"status": {"terms": {"field": "status"}}} + ], + "size": 10 + }), + // Should have composite keys with both dimensions in sorted order + json!([ + {"key": {"category": "books", "status": "active"}, "doc_count": 1}, + {"key": {"category": "books", "status": "inactive"}, "doc_count": 1}, + {"key": {"category": "clothing", "status": "active"}, "doc_count": 1}, + {"key": {"category": "electronics", "status": "active"}, "doc_count": 2}, + {"key": {"category": "electronics", "status": "inactive"}, "doc_count": 1} + ]), + ); + + Ok(()) + } + + #[test] + fn composite_aggregation_term_multi_source_ordering() -> crate::Result<()> { + let mut schema_builder = Schema::builder(); + let cat = schema_builder.add_text_field("category", STRING | FAST); + let priority = schema_builder.add_text_field("priority", STRING | FAST); + let index = Index::create_in_ram(schema_builder.build()); + { + let mut index_writer = index.writer_with_num_threads(1, 20_000_000)?; + index_writer.add_document(doc!(cat => "zebra", priority => "high"))?; + index_writer.add_document(doc!(cat => "apple", priority => "low"))?; + index_writer.add_document(doc!(cat => "zebra", priority => "low"))?; + index_writer.add_document(doc!(cat => "apple", priority => "high"))?; + index_writer.commit()?; + } + + // Test with different ordering on different sources + exec_and_assert_all_paginations( + &index, + json!({ + "sources": [ + {"category": {"terms": {"field": "category", "order": "asc"}}}, + {"priority": {"terms": {"field": "priority", "order": "desc"}}} + ], + "size": 10 + }), + json!([ + {"key": {"category": "apple", "priority": "low"}, "doc_count": 1}, + {"key": {"category": "apple", "priority": "high"}, "doc_count": 1}, + {"key": {"category": "zebra", "priority": "low"}, "doc_count": 1}, + {"key": {"category": "zebra", "priority": "high"}, "doc_count": 1} + ]), + ); + + Ok(()) + } + + #[test] + fn composite_aggregation_term_with_sub_aggregations() -> crate::Result<()> { + let mut schema_builder = Schema::builder(); + let score_field = schema_builder.add_f64_field("score_f64", FAST); + let string_field = schema_builder.add_text_field("string_id", STRING | FAST); + let index = Index::create_in_ram(schema_builder.build()); + { + let mut index_writer = index.writer_with_num_threads(1, 20_000_000)?; + 
index_writer.add_document(doc!(score_field => 5.0f64, string_field => "terma"))?; + index_writer.add_document(doc!(score_field => 2.0f64, string_field => "termb"))?; + index_writer.add_document(doc!(score_field => 3.0f64, string_field => "terma"))?; + index_writer.add_document(doc!(score_field => 7.0f64, string_field => "termb"))?; + index_writer.commit()?; + } + + let agg_req: Aggregations = serde_json::from_value(json!({ + "my_composite": { + "composite": { + "sources": [ + {"myterm": {"terms": {"field": "string_id"}}} + ], + "size": 10 + }, + "aggs": { + "avg_score": { + "avg": { + "field": "score_f64" + } + }, + "max_score": { + "max": { + "field": "score_f64" + } + } + } + } + })) + .unwrap(); + + let res = exec_request(agg_req, &index)?; + let buckets = &res["my_composite"]["buckets"]; + + // Check that sub-aggregations are computed for each bucket with specific values + assert_eq!( + buckets, + &json!([ + { + "key": {"myterm": "terma"}, + "doc_count": 2, + "avg_score": {"value": 4.0}, // (5+3)/2 + "max_score": {"value": 5.0} + }, + { + "key": {"myterm": "termb"}, + "doc_count": 2, + "avg_score": {"value": 4.5}, // (2+7)/2 + "max_score": {"value": 7.0} + } + ]) + ); + + Ok(()) + } + + #[test] + fn composite_aggregation_term_validation_errors() -> crate::Result<()> { + // Create index with explicit document creation + let mut schema_builder = Schema::builder(); + let string_field = schema_builder.add_text_field("string_id", STRING | FAST); + let index = Index::create_in_ram(schema_builder.build()); + + { + let mut index_writer = index.writer_with_num_threads(1, 20_000_000)?; + index_writer.add_document(doc!(string_field => "term"))?; + index_writer.commit()?; + } + + // Test empty sources + let agg_req: Aggregations = serde_json::from_value(json!({ + "my_composite": { + "composite": { + "sources": [], + "size": 10 + } + } + })) + .unwrap(); + + let res = exec_request(agg_req, &index); + assert!(res.is_err()); + + // Test size = 0 + let agg_req: Aggregations = serde_json::from_value(json!({ + "my_composite": { + "composite": { + "sources": [ + {"myterm": {"terms": {"field": "string_id"}}} + ], + "size": 0 + } + } + })) + .unwrap(); + + let res = exec_request(agg_req, &index); + assert!(res.is_err()); + + Ok(()) + } + + #[test] + fn composite_aggregation_term_numeric_fields() -> crate::Result<()> { + let mut schema_builder = Schema::builder(); + let score_field = schema_builder.add_f64_field("score", FAST); + let index = Index::create_in_ram(schema_builder.build()); + { + let mut index_writer = index.writer_with_num_threads(1, 20_000_000)?; + index_writer.add_document(doc!(score_field => 1.0f64))?; + index_writer.add_document(doc!(score_field => 2.0f64))?; + index_writer.add_document(doc!(score_field => 1.0f64))?; + index_writer.add_document(doc!(score_field => 3.33f64))?; + index_writer.commit()?; + index_writer.add_document(doc!(score_field => 1.0f64))?; + index_writer.commit()?; + } + + // Test composite aggregation on numeric field + exec_and_assert_all_paginations( + &index, + json!({ + "sources": [ + {"score": {"terms": {"field": "score"}}} + ], + "size": 10 + }), + json!([ + {"key": {"score": 1}, "doc_count": 3}, + {"key": {"score": 2}, "doc_count": 1}, + {"key": {"score": 3.33}, "doc_count": 1} + ]), + ); + + Ok(()) + } + + #[test] + fn composite_aggregation_term_date_fields() -> crate::Result<()> { + let mut schema_builder = Schema::builder(); + let date_field = schema_builder.add_date_field("timestamp", FAST); + let index = Index::create_in_ram(schema_builder.build()); + { + 
let mut index_writer = index.writer_with_num_threads(1, 20_000_000)?; + // Add documents with different dates + index_writer + .add_document(doc!(date_field => datetime_from_iso_str("2021-01-01T00:00:00Z")))?; + index_writer + .add_document(doc!(date_field => datetime_from_iso_str("2022-01-01T00:00:00Z")))?; + index_writer + .add_document(doc!(date_field => datetime_from_iso_str("2021-01-01T00:00:00Z")))?; // duplicate + index_writer + .add_document(doc!(date_field => datetime_from_iso_str("2023-01-01T00:00:00Z")))?; + index_writer.commit()?; + } + + // Test composite aggregation on date field + exec_and_assert_all_paginations( + &index, + json!({ + "sources": [ + {"timestamp": {"terms": {"field": "timestamp"}}} + ], + "size": 10 + }), + json!([ + {"key": {"timestamp": "2021-01-01T00:00:00Z"}, "doc_count": 2}, + {"key": {"timestamp": "2022-01-01T00:00:00Z"}, "doc_count": 1}, + {"key": {"timestamp": "2023-01-01T00:00:00Z"}, "doc_count": 1} + ]), + ); + + Ok(()) + } + + #[test] + fn composite_aggregation_term_ip_fields() -> crate::Result<()> { + let mut schema_builder = Schema::builder(); + let ip_field = schema_builder.add_ip_addr_field("ip_addr", FAST); + let index = Index::create_in_ram(schema_builder.build()); + { + let ipv4 = |ip: &str| ip.parse::().unwrap().to_ipv6_mapped(); + let ipv6 = |ip: &str| ip.parse::().unwrap(); + let mut index_writer = index.writer_with_num_threads(1, 20_000_000)?; + index_writer.add_document(doc!(ip_field => ipv4("192.168.1.1")))?; + index_writer.add_document(doc!(ip_field => ipv4("10.0.0.1")))?; + index_writer.add_document(doc!(ip_field => ipv4("192.168.1.1")))?; // duplicate + index_writer.add_document(doc!(ip_field => ipv4("172.16.0.1")))?; + index_writer.add_document(doc!(ip_field => ipv6("2001:db8::1")))?; + index_writer.add_document(doc!(ip_field => ipv6("::1")))?; // localhost + index_writer.add_document(doc!())?; + index_writer.add_document(doc!(ip_field => ipv6("2001:db8::1")))?; // duplicate + index_writer.commit()?; + } + + // Test composite aggregation on IP field + exec_and_assert_all_paginations( + &index, + json!({ + "sources": [ + {"ip_addr": {"terms": {"field": "ip_addr"}}} + ], + "size": 10 + }), + json!([ + {"key": {"ip_addr": "::1"}, "doc_count": 1}, + {"key": {"ip_addr": "10.0.0.1"}, "doc_count": 1}, + {"key": {"ip_addr": "172.16.0.1"}, "doc_count": 1}, + {"key": {"ip_addr": "192.168.1.1"}, "doc_count": 2}, + {"key": {"ip_addr": "2001:db8::1"}, "doc_count": 2} + ]), + ); + + Ok(()) + } + + #[test] + fn composite_aggregation_term_multiple_column_types() -> crate::Result<()> { + let mut schema_builder = Schema::builder(); + let score_field = schema_builder.add_f64_field("score", FAST); + let string_field = schema_builder.add_text_field("string_id", STRING | FAST); + let index = Index::create_in_ram(schema_builder.build()); + { + let mut index_writer = index.writer_with_num_threads(1, 20_000_000)?; + index_writer.add_document(doc!(score_field => 1.0f64, string_field => "apple"))?; + index_writer.add_document(doc!(score_field => 2.0f64, string_field => "banana"))?; + index_writer.add_document(doc!(score_field => 1.0f64, string_field => "apple"))?; + index_writer.add_document(doc!(score_field => 2.0f64, string_field => "banana"))?; + index_writer.add_document(doc!(score_field => 3.0f64, string_field => "cherry"))?; + index_writer.add_document(doc!(score_field => 1.0f64, string_field => "banana"))?; + index_writer.commit()?; + } + + // Test composite aggregation mixing numeric and text fields + exec_and_assert_all_paginations( + &index, + 
json!({ + "sources": [ + {"category": {"terms": {"field": "string_id", "order": "asc"}}}, + {"score": {"terms": {"field": "score", "order": "desc"}}} + ], + "size": 10 + }), + json!([ + {"key": {"category": "apple", "score": 1}, "doc_count": 2}, + {"key": {"category": "banana", "score": 2}, "doc_count": 2}, + {"key": {"category": "banana", "score": 1}, "doc_count": 1}, + {"key": {"category": "cherry", "score": 3}, "doc_count": 1} + ]), + ); + + Ok(()) + } + + #[test] + fn composite_aggregation_term_json_various_types() -> crate::Result<()> { + let mut schema_builder = Schema::builder(); + let json_field = schema_builder.add_json_field("json_data", FAST); + let index = Index::create_in_ram(schema_builder.build()); + { + let mut index_writer = index.writer_with_num_threads(1, 20_000_000)?; + index_writer.add_document( + doc!(json_field => json!({"cat": "elec", "price": 999, "avail": true})), + )?; + index_writer.add_document( + doc!(json_field => json!({"cat": "books", "price": 15, "avail": false})), + )?; + index_writer.add_document( + doc!(json_field => json!({"cat": "elec", "price": 200, "avail": true})), + )?; + index_writer.add_document( + doc!(json_field => json!({"cat": "books", "price": 25, "avail": true})), + )?; + index_writer.commit()?; + } + + exec_and_assert_all_paginations( + &index, + json!({ + "sources": [ + {"cat": {"terms": {"field": "json_data.cat"}}}, + {"avail": {"terms": {"field": "json_data.avail"}}}, + {"price": {"terms": {"field": "json_data.price", "order": "desc"}}} + ], + "size": 10 + }), + json!([ + {"key": {"cat": "books", "avail": false, "price": 15}, "doc_count": 1}, + {"key": {"cat": "books", "avail": true, "price": 25}, "doc_count": 1}, + {"key": {"cat": "elec", "avail": true, "price": 999}, "doc_count": 1}, + {"key": {"cat": "elec", "avail": true, "price": 200}, "doc_count": 1} + ]), + ); + + Ok(()) + } + + #[test] + fn composite_aggregation_term_json_missing_fields() -> crate::Result<()> { + let mut schema_builder = Schema::builder(); + let json_field = schema_builder.add_json_field("json_data", FAST); + let index = Index::create_in_ram(schema_builder.build()); + { + let mut index_writer = index.writer_with_num_threads(1, 20_000_000)?; + index_writer + .add_document(doc!(json_field => json!({"cat": "elec", "brand": "apple"})))?; + index_writer + .add_document(doc!(json_field => json!({"cat": "books", "brand": "gut"})))?; + index_writer.add_document(doc!(json_field => json!({"cat": "books"})))?; // missing brand + index_writer.add_document(doc!(json_field => json!({"brand": "samsung"})))?; // missing category + index_writer + .add_document(doc!(json_field => json!({"cat": "elec", "brand": "samsung"})))?; + index_writer.commit()?; + } + + // Test with missing bucket enabled + exec_and_assert_all_paginations( + &index, + json!({ + "sources": [ + {"cat": {"terms": {"field": "json_data.cat", "missing_bucket": true}}}, + {"brand": {"terms": {"field": "json_data.brand", "missing_bucket": true, "missing_order": "last"}}} + ], + "size": 10 + }), + json!([ + {"key": {"cat": null, "brand": "samsung"}, "doc_count": 1}, + {"key": {"cat": "books", "brand": "gut"}, "doc_count": 1}, + {"key": {"cat": "books", "brand": null}, "doc_count": 1}, + {"key": {"cat": "elec", "brand": "apple"}, "doc_count": 1}, + {"key": {"cat": "elec", "brand": "samsung"}, "doc_count": 1} + ]), + ); + + // Small twist on the missing order of the second source + exec_and_assert_all_paginations( + &index, + json!({ + "sources": [ + {"cat": {"terms": {"field": "json_data.cat", "missing_bucket": 
true}}}, + {"brand": {"terms": {"field": "json_data.brand", "missing_bucket": true, "missing_order": "first"}}} + ], + "size": 10 + }), + json!([ + {"key": {"cat": null, "brand": "samsung"}, "doc_count": 1}, + {"key": {"cat": "books", "brand": null}, "doc_count": 1}, + {"key": {"cat": "books", "brand": "gut"}, "doc_count": 1}, + {"key": {"cat": "elec", "brand": "apple"}, "doc_count": 1}, + {"key": {"cat": "elec", "brand": "samsung"}, "doc_count": 1} + ]), + ); + + Ok(()) + } + + #[test] + fn composite_aggregation_term_json_nested_fields() -> crate::Result<()> { + let mut schema_builder = Schema::builder(); + let json_field = schema_builder.add_json_field("json_data", FAST); + let index = Index::create_in_ram(schema_builder.build()); + { + let mut index_writer = index.writer_with_num_threads(1, 20_000_000)?; + index_writer.add_document( + doc!(json_field => json!({"prod": {"name": "laptop", "cpu": "intel"}})), + )?; + index_writer.add_document( + doc!(json_field => json!({"prod": {"name": "phone", "cpu": "snap"}})), + )?; + index_writer.add_document( + doc!(json_field => json!({"prod": {"name": "laptop", "cpu": "amd"}})), + )?; + index_writer.add_document( + doc!(json_field => json!({"prod": {"name": "tablet", "cpu": "intel"}})), + )?; + index_writer.commit()?; + } + + exec_and_assert_all_paginations( + &index, + json!({ + "sources": [ + {"name": {"terms": {"field": "json_data.prod.name"}}}, + {"cpu": {"terms": {"field": "json_data.prod.cpu"}}} + ], + "size": 10 + }), + json!([ + {"key": {"name": "laptop", "cpu": "amd"}, "doc_count": 1}, + {"key": {"name": "laptop", "cpu": "intel"}, "doc_count": 1}, + {"key": {"name": "phone", "cpu": "snap"}, "doc_count": 1}, + {"key": {"name": "tablet", "cpu": "intel"}, "doc_count": 1} + ]), + ); + + Ok(()) + } + + #[test] + fn composite_aggregation_term_json_mixed_types() -> crate::Result<()> { + let mut schema_builder = Schema::builder(); + let json_field = schema_builder.add_json_field("json_data", FAST); + let index = Index::create_in_ram(schema_builder.build()); + { + let mut index_writer = index.writer_with_num_threads(1, 20_000_000)?; + index_writer.add_document(doc!(json_field => json!({"id": "doc1"})))?; + // this segment's numeric is i64 + index_writer.add_document(doc!(json_field => json!({"id": 100})))?; + index_writer.add_document(doc!(json_field => json!({"id": true})))?; + index_writer.add_document(doc!(json_field => json!({"id": "doc2"})))?; + index_writer.add_document(doc!(json_field => json!({"id": 50})))?; + index_writer.add_document(doc!(json_field => json!({"id": false})))?; + index_writer.add_document(doc!(json_field => json!({"id": "doc3"})))?; + index_writer.commit()?; + // this segment's numeric is f64 + index_writer.add_document(doc!(json_field => json!({"id": 33.3})))?; + index_writer.add_document(doc!(json_field => json!({"id": 50})))?; + index_writer.commit()?; + } + + exec_and_assert_all_paginations( + &index, + json!({ + "sources": [ + {"id": {"terms": {"field": "json_data.id", "order": "asc"}}} + ], + "size": 10 + }), + json!([ + {"key": {"id": false}, "doc_count": 1}, + {"key": {"id": true}, "doc_count": 1}, + {"key": {"id": "doc1"}, "doc_count": 1}, + {"key": {"id": "doc2"}, "doc_count": 1}, + {"key": {"id": "doc3"}, "doc_count": 1}, + {"key": {"id": 33.3}, "doc_count": 1}, + {"key": {"id": 50}, "doc_count": 2}, + {"key": {"id": 100}, "doc_count": 1} + ]), + ); + + // Test descending order + exec_and_assert_all_paginations( + &index, + json!({ + "sources": [ + {"id": {"terms": {"field": "json_data.id", "order": "desc"}}} + 
], + "size": 10 + }), + json!([ + {"key": {"id": 100}, "doc_count": 1}, + {"key": {"id": 50}, "doc_count": 2}, + {"key": {"id": 33.3}, "doc_count": 1}, + {"key": {"id": "doc3"}, "doc_count": 1}, + {"key": {"id": "doc2"}, "doc_count": 1}, + {"key": {"id": "doc1"}, "doc_count": 1}, + {"key": {"id": true}, "doc_count": 1}, + {"key": {"id": false}, "doc_count": 1} + ]), + ); + + Ok(()) + } + + #[test] + fn composite_aggregation_term_multi_value_fields() -> crate::Result<()> { + let mut schema_builder = Schema::builder(); + let text_field = schema_builder.add_text_field("text", FAST | STRING); + let num_field = schema_builder.add_u64_field("num", FAST); + let index = Index::create_in_ram(schema_builder.build()); + + { + let mut index_writer = index.writer_with_num_threads(1, 20_000_000)?; + // Document with multiple values for text and num fields + index_writer.add_document(doc!( + text_field => "apple", + text_field => "banana", + num_field => 10u64, + num_field => 20u64, + ))?; + index_writer.add_document(doc!( + text_field => "cherry", + num_field => 30u64, + ))?; + // Multi valued document with duplicate values + index_writer.add_document(doc!( + text_field => "elderberry", + text_field => "date", + text_field => "elderberry", + num_field => 40u64, + ))?; + + index_writer.commit()?; + } + + exec_and_assert_all_paginations( + &index, + json!({ + "sources": [ + {"text_terms": {"terms": {"field": "text"}}} + ], + "size": 10 + }), + json!([ + {"key": {"text_terms": "apple"}, "doc_count": 1}, + {"key": {"text_terms": "banana"}, "doc_count": 1}, + {"key": {"text_terms": "cherry"}, "doc_count": 1}, + {"key": {"text_terms": "date"}, "doc_count": 1}, + // this is not the doc count but the term occurrence count + // https://github.com/quickwit-oss/tantivy/issues/2721 + {"key": {"text_terms": "elderberry"}, "doc_count": 2} + ]), + ); + + exec_and_assert_all_paginations( + &index, + json!({ + "sources": [ + {"num_terms": {"terms": {"field": "num"}}} + ], + "size": 10 + }), + json!([ + {"key": {"num_terms": 10}, "doc_count": 1}, + {"key": {"num_terms": 20}, "doc_count": 1}, + {"key": {"num_terms": 30}, "doc_count": 1}, + {"key": {"num_terms": 40}, "doc_count": 1} + ]), + ); + + Ok(()) + } + + #[test] + fn composite_aggregation_histogram_basic() -> crate::Result<()> { + let mut schema_builder = Schema::builder(); + let num_field = schema_builder.add_f64_field("value", FAST); + let index = Index::create_in_ram(schema_builder.build()); + { + let mut index_writer = index.writer_with_num_threads(1, 20_000_000)?; + index_writer.add_document(doc!(num_field => -0.5f64))?; + index_writer.add_document(doc!(num_field => 1.0f64))?; + index_writer.add_document(doc!(num_field => 2.0f64))?; + index_writer.add_document(doc!(num_field => 5.0f64))?; + index_writer.add_document(doc!(num_field => 7.0f64))?; + index_writer.add_document(doc!(num_field => 11.0f64))?; + index_writer.commit()?; + } + + // Histogram with interval 5 + exec_and_assert_all_paginations( + &index, + json!({ + "sources": [ + {"val_hist": {"histogram": {"field": "value", "interval": 5.0}}} + ], + "size": 10 + }), + json!([ + {"key": {"val_hist": -5.0}, "doc_count": 1}, + {"key": {"val_hist": 0.0}, "doc_count": 2}, + {"key": {"val_hist": 5.0}, "doc_count": 2}, + {"key": {"val_hist": 10.0}, "doc_count": 1} + ]), + ); + Ok(()) + } + + #[test] + fn composite_aggregation_histogram_json_mixed_types() -> crate::Result<()> { + let mut schema_builder = Schema::builder(); + let json_field = schema_builder.add_json_field("json_data", FAST); + let index = 
Index::create_in_ram(schema_builder.build()); + { + let mut index_writer = index.writer_with_num_threads(1, 20_000_000)?; + // this segment's numeric is i64 + index_writer.add_document(doc!(json_field => json!({"id": "doc1"})))?; + index_writer.add_document(doc!(json_field => json!({"id": 100})))?; + index_writer.add_document(doc!(json_field => json!({"id": true})))?; + index_writer.add_document(doc!(json_field => json!({"id": "doc2"})))?; + index_writer.add_document(doc!(json_field => json!({"id": 50})))?; + index_writer.add_document(doc!(json_field => json!({"id": false})))?; + index_writer.add_document(doc!(json_field => json!({"id": "doc3"})))?; + index_writer.commit()?; + // this segment's numeric is f64 + index_writer.add_document(doc!(json_field => json!({"id": 33.3})))?; + index_writer.add_document(doc!(json_field => json!({"id": 50})))?; + index_writer.add_document(doc!(json_field => json!({"id": -0.01})))?; + index_writer.commit()?; + } + + exec_and_assert_all_paginations( + &index, + json!({ + "sources": [ + {"id": {"histogram": {"field": "json_data.id", "interval": 50, "order": "asc"}}} + ], + "size": 10 + }), + json!([ + {"key": {"id": -50.0}, "doc_count": 1}, + {"key": {"id": 0.0}, "doc_count": 1}, + {"key": {"id": 50.0}, "doc_count": 2}, + {"key": {"id": 100.0}, "doc_count": 1}, + + ]), + ); + + // Test descending order + exec_and_assert_all_paginations( + &index, + json!({ + "sources": [ + {"id": {"histogram": {"field": "json_data.id", "interval": 50, "order": "desc"}}} + ], + "size": 10 + }), + json!([ + {"key": {"id": 100.0}, "doc_count": 1}, + {"key": {"id": 50.0}, "doc_count": 2}, + {"key": {"id": 0.0}, "doc_count": 1}, + {"key": {"id": -50.0}, "doc_count": 1}, + ]), + ); + + Ok(()) + } + + #[test] + fn composite_aggregation_date_histogram_calendar_interval() -> crate::Result<()> { + let mut schema_builder = Schema::builder(); + let date_field = schema_builder.add_date_field("dt", FAST); + let index = Index::create_in_ram(schema_builder.build()); + { + let mut index_writer = index.writer_with_num_threads(1, 20_000_000)?; + index_writer + .add_document(doc!(date_field => datetime_from_iso_str("2021-01-01T00:00:00Z")))?; + index_writer + .add_document(doc!(date_field => datetime_from_iso_str("2021-02-01T00:00:00Z")))?; + index_writer + .add_document(doc!(date_field => datetime_from_iso_str("2022-01-01T00:00:00Z")))?; + index_writer + .add_document(doc!(date_field => datetime_from_iso_str("2023-01-01T00:00:00Z")))?; + index_writer.commit()?; + } + + // Date histogram with calendar_interval = "year" + exec_and_assert_all_paginations( + &index, + json!({ + "sources": [ + {"dt_hist": {"date_histogram": {"field": "dt", "calendar_interval": "year"}}} + ], + "size": 10 + }), + json!([ + {"key": {"dt_hist": ms_timestamp_from_iso_str("2021-01-01T00:00:00Z")}, "doc_count": 2}, + {"key": {"dt_hist": ms_timestamp_from_iso_str("2022-01-01T00:00:00Z")}, "doc_count": 1}, + {"key": {"dt_hist": ms_timestamp_from_iso_str("2023-01-01T00:00:00Z")}, "doc_count": 1} + ]), + ); + Ok(()) + } + + #[test] + fn composite_aggregation_date_histogram_fixed_interval() -> crate::Result<()> { + let mut schema_builder = Schema::builder(); + let date_field = schema_builder.add_date_field("dt", FAST); + let index = Index::create_in_ram(schema_builder.build()); + { + let mut index_writer = index.writer_with_num_threads(1, 20_000_000)?; + index_writer + .add_document(doc!(date_field => datetime_from_iso_str("2021-01-01T00:00:00Z")))?; + index_writer + .add_document(doc!(date_field => 
datetime_from_iso_str("2021-01-01T05:30:00Z")))?; + index_writer + .add_document(doc!(date_field => datetime_from_iso_str("2021-01-01T06:00:00Z")))?; + index_writer + .add_document(doc!(date_field => datetime_from_iso_str("2021-01-01T12:00:00Z")))?; + index_writer + .add_document(doc!(date_field => datetime_from_iso_str("2021-01-01T18:00:00Z")))?; + index_writer.commit()?; + } + + exec_and_assert_all_paginations( + &index, + json!({ + "sources": [ + {"dt_hist": {"date_histogram": {"field": "dt", "fixed_interval": "6h"}}} + ], + "size": 10 + }), + json!([ + {"key": {"dt_hist": ms_timestamp_from_iso_str("2021-01-01T00:00:00Z")}, "doc_count": 2}, + {"key": {"dt_hist": ms_timestamp_from_iso_str("2021-01-01T06:00:00Z")}, "doc_count": 1}, + {"key": {"dt_hist": ms_timestamp_from_iso_str("2021-01-01T12:00:00Z")}, "doc_count": 1}, + {"key": {"dt_hist": ms_timestamp_from_iso_str("2021-01-01T18:00:00Z")}, "doc_count": 1} + ]), + ); + Ok(()) + } + + #[test] + fn composite_aggregation_mixed_term_and_date_histogram() -> crate::Result<()> { + let mut schema_builder = Schema::builder(); + let date_field = schema_builder.add_date_field("timestamp", FAST); + let category_field = schema_builder.add_text_field("category", STRING | FAST); + let index = Index::create_in_ram(schema_builder.build()); + + { + let mut index_writer = index.writer_with_num_threads(1, 20_000_000)?; + index_writer.add_document(doc!( + date_field => datetime_from_iso_str("2021-01-01T05:00:00Z"), + category_field => "electronics" + ))?; + index_writer.add_document(doc!( + date_field => datetime_from_iso_str("2021-01-15T10:30:00Z"), + category_field => "electronics" + ))?; + index_writer.add_document(doc!( + date_field => datetime_from_iso_str("2021-01-05T12:00:00Z"), + category_field => "books" + ))?; + index_writer.add_document(doc!( + date_field => datetime_from_iso_str("2021-02-10T08:45:00Z"), + category_field => "books" + ))?; + index_writer.add_document(doc!( + date_field => datetime_from_iso_str("2021-02-05T14:20:00Z"), + category_field => "clothing" + ))?; + index_writer.add_document(doc!( + date_field => datetime_from_iso_str("2021-02-20T09:15:00Z"), + category_field => "clothing" + ))?; + + index_writer.commit()?; + } + + exec_and_assert_all_paginations( + &index, + json!({ + "sources": [ + {"category": {"terms": {"field": "category"}}}, + {"month": {"date_histogram": {"field": "timestamp", "calendar_interval": "month"}}} + ], + "size": 10 + }), + json!([ + {"key": {"category": "books", "month": ms_timestamp_from_iso_str("2021-01-01T00:00:00Z")}, "doc_count": 1}, + {"key": {"category": "books", "month": ms_timestamp_from_iso_str("2021-02-01T00:00:00Z")}, "doc_count": 1}, + {"key": {"category": "clothing", "month": ms_timestamp_from_iso_str("2021-02-01T00:00:00Z")}, "doc_count": 2}, + {"key": {"category": "electronics", "month": ms_timestamp_from_iso_str("2021-01-01T00:00:00Z")}, "doc_count": 2} + ]), + ); + + // Test with different ordering for sources with a size limit + let agg_req: Aggregations = serde_json::from_value(json!({ + "my_composite": { + "composite": { + "sources": [ + {"month": {"date_histogram": {"field": "timestamp", "calendar_interval": "month"}}}, + {"category": {"terms": {"field": "category", "order": "desc"}}} + ], + "size": 3 + } + } + })) + .unwrap(); + let res = exec_request(agg_req, &index)?; + let buckets = &res["my_composite"]["buckets"]; + assert_eq!( + buckets, + &json!([ + {"key": {"month": ms_timestamp_from_iso_str("2021-01-01T00:00:00Z"), "category": "electronics"}, "doc_count": 2}, + {"key": 
{"month": ms_timestamp_from_iso_str("2021-01-01T00:00:00Z"), "category": "books"}, "doc_count": 1}, + {"key": {"month": ms_timestamp_from_iso_str("2021-02-01T00:00:00Z"), "category": "clothing"}, "doc_count": 2}, + ]), + ); + + // next page + let agg_req: Aggregations = serde_json::from_value(json!({ + "my_composite": { + "composite": { + "sources": [ + {"month": {"date_histogram": {"field": "timestamp", "calendar_interval": "month"}}}, + {"category": {"terms": {"field": "category", "order": "desc"}}} + ], + "size": 3, + "after": res["my_composite"]["after_key"] + } + } + })) + .unwrap(); + let res = exec_request(agg_req, &index)?; + let buckets = &res["my_composite"]["buckets"]; + assert_eq!( + buckets, + &json!([ + {"key": {"month": ms_timestamp_from_iso_str("2021-02-01T00:00:00Z"), "category": "books"}, "doc_count": 1}, + ]), + ); + assert!(res["my_composite"].get("after_key").is_none()); + + Ok(()) + } + + #[test] + fn composite_aggregation_no_matching_columns() -> crate::Result<()> { + let mut schema_builder = Schema::builder(); + let date_field = schema_builder.add_f64_field("dt", FAST); + let index = Index::create_in_ram(schema_builder.build()); + { + let mut index_writer = index.writer_with_num_threads(1, 20_000_000)?; + index_writer.add_document(doc!(date_field => 1.0))?; + index_writer.add_document(doc!(date_field => 2.0))?; + index_writer.commit()?; + } + + let agg_req: Aggregations = serde_json::from_value(json!({ + "my_composite": { + "composite": { + "sources": [ + {"dt_hist": {"date_histogram": {"field": "dt", "fixed_interval": "6h"}}} + ], + "size": 10 + } + } + })) + .unwrap(); + let res = exec_request(agg_req, &index)?; + let buckets = &res["my_composite"]["buckets"]; + assert_eq!(buckets, &json!([])); + + let agg_req: Aggregations = serde_json::from_value(json!({ + "my_composite": { + "composite": { + "sources": [ + {"dt_hist": {"date_histogram": {"field": "dt", "fixed_interval": "6h", "missing_bucket": true}}} + ], + "size": 10, + } + } + })) + .unwrap(); + + let res = exec_request(agg_req, &index)?; + let buckets = &res["my_composite"]["buckets"]; + + assert_eq!( + buckets, + &json!([{"key": {"dt_hist": null}, "doc_count": 2}]) + ); + Ok(()) + } +} diff --git a/src/aggregation/bucket/composite/numeric_types.rs b/src/aggregation/bucket/composite/numeric_types.rs new file mode 100644 index 0000000000..34dd74e52b --- /dev/null +++ b/src/aggregation/bucket/composite/numeric_types.rs @@ -0,0 +1,460 @@ +/// This modules helps comparing numerical values of different types (i64, u64 +/// and f64). 
+pub(super) mod num_cmp {
+    use std::cmp::Ordering;
+
+    use crate::TantivyError;
+
+    pub fn cmp_i64_f64(left_i: i64, right_f: f64) -> crate::Result<Ordering> {
+        if right_f.is_nan() {
+            return Err(TantivyError::InvalidArgument(
+                "NaN comparison is not supported".to_string(),
+            ));
+        }
+
+        // If right_f is < i64::MIN then left_i > right_f (i64::MIN=-2^63 can be
+        // exactly represented as f64)
+        if right_f < i64::MIN as f64 {
+            return Ok(Ordering::Greater);
+        }
+        // If right_f is >= i64::MAX then left_i < right_f (i64::MAX=2^63-1 cannot
+        // be exactly represented as f64)
+        if right_f >= i64::MAX as f64 {
+            return Ok(Ordering::Less);
+        }
+
+        // Now right_f is in (i64::MIN, i64::MAX), so `right_f as i64` is
+        // well-defined (truncation toward 0)
+        let right_as_i = right_f as i64;
+
+        let result = match left_i.cmp(&right_as_i) {
+            Ordering::Less => Ordering::Less,
+            Ordering::Greater => Ordering::Greater,
+            Ordering::Equal => {
+                // they have the same integer part, compare the fraction
+                let rem = right_f - (right_as_i as f64);
+                if rem == 0.0 {
+                    Ordering::Equal
+                } else if right_f > 0.0 {
+                    Ordering::Less
+                } else {
+                    Ordering::Greater
+                }
+            }
+        };
+        Ok(result)
+    }
+
+    pub fn cmp_u64_f64(left_u: u64, right_f: f64) -> crate::Result<Ordering> {
+        if right_f.is_nan() {
+            return Err(TantivyError::InvalidArgument(
+                "NaN comparison is not supported".to_string(),
+            ));
+        }
+
+        // Negative floats are always less than any u64 >= 0
+        if right_f < 0.0 {
+            return Ok(Ordering::Greater);
+        }
+
+        // If right_f is >= u64::MAX then left_u < right_f (u64::MAX=2^64-1 cannot be
+        // exactly represented as f64)
+        let max_as_f = u64::MAX as f64;
+        if right_f > max_as_f {
+            return Ok(Ordering::Less);
+        }
+
+        // Now right_f is in (0, u64::MAX), so `right_f as u64` is well-defined
+        // (truncation toward 0)
+        let right_as_u = right_f as u64;
+
+        let result = match left_u.cmp(&right_as_u) {
+            Ordering::Less => Ordering::Less,
+            Ordering::Greater => Ordering::Greater,
+            Ordering::Equal => {
+                // they have the same integer part, compare the fraction
+                let rem = right_f - (right_as_u as f64);
+                if rem == 0.0 {
+                    Ordering::Equal
+                } else {
+                    Ordering::Less
+                }
+            }
+        };
+        Ok(result)
+    }
+
+    pub fn cmp_i64_u64(left_i: i64, right_u: u64) -> Ordering {
+        if left_i < 0 {
+            Ordering::Less
+        } else {
+            let left_as_u = left_i as u64;
+            left_as_u.cmp(&right_u)
+        }
+    }
+}
+
+/// This module helps projecting numerical values to other numerical types.
+/// When the target value space cannot exactly represent the source value, the
+/// next representable value is returned (or AfterLast if the source value is
+/// larger than the largest representable value).
+///
+/// All functions in this module assume that f64 values are not NaN.
+pub(super) mod num_proj {
+    #[derive(Debug, PartialEq)]
+    pub enum ProjectedNumber<T> {
+        Exact(T),
+        Next(T),
+        AfterLast,
+    }
+
+    pub fn i64_to_u64(value: i64) -> ProjectedNumber<u64> {
+        if value < 0 {
+            ProjectedNumber::Next(0)
+        } else {
+            ProjectedNumber::Exact(value as u64)
+        }
+    }
+
+    pub fn u64_to_i64(value: u64) -> ProjectedNumber<i64> {
+        if value > i64::MAX as u64 {
+            ProjectedNumber::AfterLast
+        } else {
+            ProjectedNumber::Exact(value as i64)
+        }
+    }
+
+    pub fn f64_to_u64(value: f64) -> ProjectedNumber<u64> {
+        if value < 0.0 {
+            ProjectedNumber::Next(0)
+        } else if value > u64::MAX as f64 {
+            ProjectedNumber::AfterLast
+        } else if value.fract() == 0.0 {
+            ProjectedNumber::Exact(value as u64)
+        } else {
+            // casting f64 to u64 truncates toward zero
+            ProjectedNumber::Next(value as u64 + 1)
+        }
+    }
+
+    pub fn f64_to_i64(value: f64) -> ProjectedNumber<i64> {
+        if value < (i64::MIN as f64) {
+            return ProjectedNumber::Next(i64::MIN);
+        } else if value >= (i64::MAX as f64) {
+            return ProjectedNumber::AfterLast;
+        } else if value.fract() == 0.0 {
+            ProjectedNumber::Exact(value as i64)
+        } else if value > 0.0 {
+            // casting f64 to i64 truncates toward zero
+            ProjectedNumber::Next(value as i64 + 1)
+        } else {
+            ProjectedNumber::Next(value as i64)
+        }
+    }
+
+    pub fn i64_to_f64(value: i64) -> ProjectedNumber<f64> {
+        let value_f = value as f64;
+        let k_roundtrip = value_f as i64;
+        if k_roundtrip == value {
+            // between -2^53 and 2^53 all i64 are exactly represented as f64
+            ProjectedNumber::Exact(value_f)
+        } else {
+            // for very large/small i64 values, it is approximated to the closest f64
+            if k_roundtrip > value {
+                ProjectedNumber::Next(value_f)
+            } else {
+                ProjectedNumber::Next(value_f.next_up())
+            }
+        }
+    }
+
+    pub fn u64_to_f64(value: u64) -> ProjectedNumber<f64> {
+        let value_f = value as f64;
+        let k_roundtrip = value_f as u64;
+        if k_roundtrip == value {
+            // between 0 and 2^53 all u64 are exactly represented as f64
+            ProjectedNumber::Exact(value_f)
+        } else if k_roundtrip > value {
+            ProjectedNumber::Next(value_f)
+        } else {
+            ProjectedNumber::Next(value_f.next_up())
+        }
+    }
+}
+
+#[cfg(test)]
+mod num_cmp_tests {
+    use std::cmp::Ordering;
+
+    use super::num_cmp::*;
+
+    #[test]
+    fn test_cmp_u64_f64() {
+        // Basic comparisons
+        assert_eq!(cmp_u64_f64(5, 5.0).unwrap(), Ordering::Equal);
+        assert_eq!(cmp_u64_f64(5, 6.0).unwrap(), Ordering::Less);
+        assert_eq!(cmp_u64_f64(6, 5.0).unwrap(), Ordering::Greater);
+        assert_eq!(cmp_u64_f64(0, 0.0).unwrap(), Ordering::Equal);
+        assert_eq!(cmp_u64_f64(0, 0.1).unwrap(), Ordering::Less);
+
+        // Negative float values should always be less than any u64
+        assert_eq!(cmp_u64_f64(0, -0.1).unwrap(), Ordering::Greater);
+        assert_eq!(cmp_u64_f64(5, -5.0).unwrap(), Ordering::Greater);
+        assert_eq!(cmp_u64_f64(u64::MAX, -1e20).unwrap(), Ordering::Greater);
+
+        // Tests with extreme values
+        assert_eq!(cmp_u64_f64(u64::MAX, 1e20).unwrap(), Ordering::Less);
+
+        // Precision edge cases: large u64 that loses precision when converted to f64
+        // => 2^54, exactly represented as f64
+        let large_f64 = 18_014_398_509_481_984.0;
+        let large_u64 = 18_014_398_509_481_984;
+        // prove that large_u64 is exactly represented as f64
+        assert_eq!(large_u64 as f64, large_f64);
+        assert_eq!(cmp_u64_f64(large_u64, large_f64).unwrap(), Ordering::Equal);
+        // => (2^54 + 1) cannot be exactly represented in f64
+        let large_u64_plus_1 = 18_014_398_509_481_985;
+        // prove that it is represented as f64 by large_f64
+        assert_eq!(large_u64_plus_1 as f64, large_f64);
+        assert_eq!(
+
cmp_u64_f64(large_u64_plus_1, large_f64).unwrap(), + Ordering::Greater + ); + // => (2^54 - 1) cannot be exactly represented in f64 + let large_u64_minus_1 = 18_014_398_509_481_983; + // prove that it is also represented as f64 by large_f64 + assert_eq!(large_u64_minus_1 as f64, large_f64); + assert_eq!( + cmp_u64_f64(large_u64_minus_1, large_f64).unwrap(), + Ordering::Less + ); + + // NaN comparison results in an error + assert!(cmp_u64_f64(0, f64::NAN).is_err()); + } + + #[test] + fn test_cmp_i64_f64() { + // Basic comparisons + assert_eq!(cmp_i64_f64(5, 5.0).unwrap(), Ordering::Equal); + assert_eq!(cmp_i64_f64(5, 6.0).unwrap(), Ordering::Less); + assert_eq!(cmp_i64_f64(6, 5.0).unwrap(), Ordering::Greater); + assert_eq!(cmp_i64_f64(-5, -5.0).unwrap(), Ordering::Equal); + assert_eq!(cmp_i64_f64(-5, -4.0).unwrap(), Ordering::Less); + assert_eq!(cmp_i64_f64(-4, -5.0).unwrap(), Ordering::Greater); + assert_eq!(cmp_i64_f64(-5, 5.0).unwrap(), Ordering::Less); + assert_eq!(cmp_i64_f64(5, -5.0).unwrap(), Ordering::Greater); + assert_eq!(cmp_i64_f64(0, -0.1).unwrap(), Ordering::Greater); + assert_eq!(cmp_i64_f64(0, 0.1).unwrap(), Ordering::Less); + assert_eq!(cmp_i64_f64(-1, -0.5).unwrap(), Ordering::Less); + assert_eq!(cmp_i64_f64(-1, 0.0).unwrap(), Ordering::Less); + assert_eq!(cmp_i64_f64(0, 0.0).unwrap(), Ordering::Equal); + + // Tests with extreme values + assert_eq!(cmp_i64_f64(i64::MAX, 1e20).unwrap(), Ordering::Less); + assert_eq!(cmp_i64_f64(i64::MIN, -1e20).unwrap(), Ordering::Greater); + + // Precision edge cases: large i64 that loses precision when converted to f64 + // => 2^54, exactly represented as f64 + let large_f64 = 18_014_398_509_481_984.0; + let large_i64 = 18_014_398_509_481_984; + // prove that large_i64 is exactly represented as f64 + assert_eq!(large_i64 as f64, large_f64); + assert_eq!(cmp_i64_f64(large_i64, large_f64).unwrap(), Ordering::Equal); + // => (1_i64 << 54) + 1 cannot be exactly represented in f64 + let large_i64_plus_1 = 18_014_398_509_481_985; + // prove that it is represented as f64 by large_f64 + assert_eq!(large_i64_plus_1 as f64, large_f64); + assert_eq!( + cmp_i64_f64(large_i64_plus_1, large_f64).unwrap(), + Ordering::Greater + ); + // => (1_i64 << 54) - 1 cannot be exactly represented in f64 + let large_i64_minus_1 = 18_014_398_509_481_983; + // prove that it is also represented as f64 by large_f64 + assert_eq!(large_i64_minus_1 as f64, large_f64); + assert_eq!( + cmp_i64_f64(large_i64_minus_1, large_f64).unwrap(), + Ordering::Less + ); + + // Same precision edge case but with negative values + // => -2^54, exactly represented as f64 + let large_neg_f64 = -18_014_398_509_481_984.0; + let large_neg_i64 = -18_014_398_509_481_984; + // prove that large_neg_i64 is exactly represented as f64 + assert_eq!(large_neg_i64 as f64, large_neg_f64); + assert_eq!( + cmp_i64_f64(large_neg_i64, large_neg_f64).unwrap(), + Ordering::Equal + ); + // => (-2^54 + 1) cannot be exactly represented in f64 + let large_neg_i64_plus_1 = -18_014_398_509_481_985; + // prove that it is represented as f64 by large_neg_f64 + assert_eq!(large_neg_i64_plus_1 as f64, large_neg_f64); + assert_eq!( + cmp_i64_f64(large_neg_i64_plus_1, large_neg_f64).unwrap(), + Ordering::Less + ); + // => (-2^54 - 1) cannot be exactly represented in f64 + let large_neg_i64_minus_1 = -18_014_398_509_481_983; + // prove that it is also represented as f64 by large_neg_f64 + assert_eq!(large_neg_i64_minus_1 as f64, large_neg_f64); + assert_eq!( + cmp_i64_f64(large_neg_i64_minus_1, large_neg_f64).unwrap(), + 
Ordering::Greater + ); + + // NaN comparison results in an error + assert!(cmp_i64_f64(0, f64::NAN).is_err()); + } + + #[test] + fn test_cmp_i64_u64() { + // Test with negative i64 values (should always be less than any u64) + assert_eq!(cmp_i64_u64(-1, 0), Ordering::Less); + assert_eq!(cmp_i64_u64(i64::MIN, 0), Ordering::Less); + assert_eq!(cmp_i64_u64(i64::MIN, u64::MAX), Ordering::Less); + + // Test with positive i64 values + assert_eq!(cmp_i64_u64(0, 0), Ordering::Equal); + assert_eq!(cmp_i64_u64(1, 0), Ordering::Greater); + assert_eq!(cmp_i64_u64(1, 1), Ordering::Equal); + assert_eq!(cmp_i64_u64(0, 1), Ordering::Less); + assert_eq!(cmp_i64_u64(5, 10), Ordering::Less); + assert_eq!(cmp_i64_u64(10, 5), Ordering::Greater); + + // Test with values near i64::MAX and u64 conversion + assert_eq!(cmp_i64_u64(i64::MAX, i64::MAX as u64), Ordering::Equal); + assert_eq!(cmp_i64_u64(i64::MAX, (i64::MAX as u64) + 1), Ordering::Less); + assert_eq!(cmp_i64_u64(i64::MAX, u64::MAX), Ordering::Less); + } +} + +#[cfg(test)] +mod num_proj_tests { + use super::num_proj::{self, ProjectedNumber}; + + #[test] + fn test_i64_to_u64() { + assert_eq!(num_proj::i64_to_u64(-1), ProjectedNumber::Next(0)); + assert_eq!(num_proj::i64_to_u64(i64::MIN), ProjectedNumber::Next(0)); + assert_eq!(num_proj::i64_to_u64(0), ProjectedNumber::Exact(0)); + assert_eq!(num_proj::i64_to_u64(42), ProjectedNumber::Exact(42)); + assert_eq!( + num_proj::i64_to_u64(i64::MAX), + ProjectedNumber::Exact(i64::MAX as u64) + ); + } + + #[test] + fn test_u64_to_i64() { + assert_eq!(num_proj::u64_to_i64(0), ProjectedNumber::Exact(0)); + assert_eq!(num_proj::u64_to_i64(42), ProjectedNumber::Exact(42)); + assert_eq!( + num_proj::u64_to_i64(i64::MAX as u64), + ProjectedNumber::Exact(i64::MAX) + ); + assert_eq!( + num_proj::u64_to_i64((i64::MAX as u64) + 1), + ProjectedNumber::AfterLast + ); + assert_eq!(num_proj::u64_to_i64(u64::MAX), ProjectedNumber::AfterLast); + } + + #[test] + fn test_f64_to_u64() { + assert_eq!(num_proj::f64_to_u64(-1e25), ProjectedNumber::Next(0)); + assert_eq!(num_proj::f64_to_u64(-0.1), ProjectedNumber::Next(0)); + assert_eq!(num_proj::f64_to_u64(1e20), ProjectedNumber::AfterLast); + assert_eq!( + num_proj::f64_to_u64(f64::INFINITY), + ProjectedNumber::AfterLast + ); + assert_eq!(num_proj::f64_to_u64(0.0), ProjectedNumber::Exact(0)); + assert_eq!(num_proj::f64_to_u64(42.0), ProjectedNumber::Exact(42)); + assert_eq!(num_proj::f64_to_u64(0.5), ProjectedNumber::Next(1)); + assert_eq!(num_proj::f64_to_u64(42.1), ProjectedNumber::Next(43)); + } + + #[test] + fn test_f64_to_i64() { + assert_eq!(num_proj::f64_to_i64(-1e20), ProjectedNumber::Next(i64::MIN)); + assert_eq!( + num_proj::f64_to_i64(f64::NEG_INFINITY), + ProjectedNumber::Next(i64::MIN) + ); + assert_eq!(num_proj::f64_to_i64(1e20), ProjectedNumber::AfterLast); + assert_eq!( + num_proj::f64_to_i64(f64::INFINITY), + ProjectedNumber::AfterLast + ); + assert_eq!(num_proj::f64_to_i64(0.0), ProjectedNumber::Exact(0)); + assert_eq!(num_proj::f64_to_i64(42.0), ProjectedNumber::Exact(42)); + assert_eq!(num_proj::f64_to_i64(-42.0), ProjectedNumber::Exact(-42)); + assert_eq!(num_proj::f64_to_i64(0.5), ProjectedNumber::Next(1)); + assert_eq!(num_proj::f64_to_i64(42.1), ProjectedNumber::Next(43)); + assert_eq!(num_proj::f64_to_i64(-0.5), ProjectedNumber::Next(0)); + assert_eq!(num_proj::f64_to_i64(-42.1), ProjectedNumber::Next(-42)); + } + + #[test] + fn test_i64_to_f64() { + assert_eq!(num_proj::i64_to_f64(0), ProjectedNumber::Exact(0.0)); + assert_eq!(num_proj::i64_to_f64(42), 
ProjectedNumber::Exact(42.0)); + assert_eq!(num_proj::i64_to_f64(-42), ProjectedNumber::Exact(-42.0)); + + let max_exact = 9_007_199_254_740_992; // 2^53 + assert_eq!( + num_proj::i64_to_f64(max_exact), + ProjectedNumber::Exact(max_exact as f64) + ); + + // Test values that cannot be exactly represented as f64 (integers above 2^53) + let large_i64 = 9_007_199_254_740_993; // 2^53 + 1 + let closest_f64 = 9_007_199_254_740_992.0; + assert_eq!(large_i64 as f64, closest_f64); + if let ProjectedNumber::Next(val) = num_proj::i64_to_f64(large_i64) { + // Verify that the returned float is different from the direct cast + assert!(val > closest_f64); + assert!(val - closest_f64 < 2. * f64::EPSILON * closest_f64); + } else { + panic!("Expected ProjectedNumber::Next for large_i64"); + } + + // Test with very large negative value + let large_neg_i64 = -9_007_199_254_740_993; // -(2^53 + 1) + let closest_neg_f64 = -9_007_199_254_740_992.0; + assert_eq!(large_neg_i64 as f64, closest_neg_f64); + if let ProjectedNumber::Next(val) = num_proj::i64_to_f64(large_neg_i64) { + // Verify that the returned float is the closest representable f64 + assert_eq!(val, closest_neg_f64); + } else { + panic!("Expected ProjectedNumber::Next for large_neg_i64"); + } + } + + #[test] + fn test_u64_to_f64() { + assert_eq!(num_proj::u64_to_f64(0), ProjectedNumber::Exact(0.0)); + assert_eq!(num_proj::u64_to_f64(42), ProjectedNumber::Exact(42.0)); + + // Test the largest u64 value that can be exactly represented as f64 (2^53) + let max_exact = 9_007_199_254_740_992; // 2^53 + assert_eq!( + num_proj::u64_to_f64(max_exact), + ProjectedNumber::Exact(max_exact as f64) + ); + + // Test values that cannot be exactly represented as f64 (integers above 2^53) + let large_u64 = 9_007_199_254_740_993; // 2^53 + 1 + let closest_f64 = 9_007_199_254_740_992.0; + assert_eq!(large_u64 as f64, closest_f64); + if let ProjectedNumber::Next(val) = num_proj::u64_to_f64(large_u64) { + // Verify that the returned float is different from the direct cast + assert!(val > closest_f64); + assert!(val - closest_f64 < 2. * f64::EPSILON * closest_f64); + } else { + panic!("Expected ProjectedNumber::Next for large_u64"); + } + } +} diff --git a/src/aggregation/bucket/histogram/date_histogram.rs b/src/aggregation/bucket/histogram/date_histogram.rs index dd85cc51e0..aad9e3aeb9 100644 --- a/src/aggregation/bucket/histogram/date_histogram.rs +++ b/src/aggregation/bucket/histogram/date_histogram.rs @@ -207,7 +207,7 @@ fn parse_offset_into_milliseconds(input: &str) -> Result } } -fn parse_into_milliseconds(input: &str) -> Result { +pub(crate) fn parse_into_milliseconds(input: &str) -> Result { let split_boundary = input .as_bytes() .iter() diff --git a/src/aggregation/bucket/mod.rs b/src/aggregation/bucket/mod.rs index 52a952cb84..c3542cb736 100644 --- a/src/aggregation/bucket/mod.rs +++ b/src/aggregation/bucket/mod.rs @@ -22,6 +22,7 @@ //! - [Range](RangeAggregation) //! 
- [Terms](TermsAggregation) +mod composite; mod histogram; mod range; mod term_agg; @@ -30,6 +31,7 @@ mod term_missing_agg; use std::collections::HashMap; use std::fmt; +pub use composite::*; pub use histogram::*; pub use range::*; use serde::{de, Deserialize, Deserializer, Serialize, Serializer}; @@ -48,6 +50,12 @@ pub enum Order { Desc, } +impl Order { + pub(crate) fn asc() -> Self { + Order::Asc + } +} + #[derive(Clone, Debug, PartialEq)] /// Order property by which to apply the order #[derive(Default)] diff --git a/src/aggregation/date.rs b/src/aggregation/date.rs index 97befe7b9e..848f3a2d21 100644 --- a/src/aggregation/date.rs +++ b/src/aggregation/date.rs @@ -14,3 +14,38 @@ pub(crate) fn format_date(val: i64) -> crate::Result { .map_err(|_err| TantivyError::InvalidArgument("Could not serialize date".to_string()))?; Ok(key_as_string) } + +pub(crate) fn parse_date(date_string: &str) -> crate::Result { + OffsetDateTime::parse(date_string, &Rfc3339) + .map_err(|err| { + TantivyError::InvalidArgument(format!( + "Could not parse '{date_string}' as RFC3339 date, err: {err:?}" + )) + }) + .map(|datetime| datetime.unix_timestamp_nanos() as i64) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_date_roundtrip() -> crate::Result<()> { + let timestamp = 1697548800_001_001_001i64; // 2023-10-17T13:20:00Z + let date_string = format_date(timestamp)?; + let parsed_timestamp = parse_date(&date_string)?; + assert_eq!(timestamp, parsed_timestamp, "Roundtrip conversion failed"); + + Ok(()) + } + + #[test] + fn test_invalid_date_parsing() { + // Test with invalid date format + let result = parse_date("invalid date"); + assert!(result.is_err(), "Should error on invalid date format"); + + let result = parse_date("2023/10/17 12:00:00"); + assert!(result.is_err(), "Should error on non-RFC3339 format"); + } +} diff --git a/src/aggregation/intermediate_agg_result.rs b/src/aggregation/intermediate_agg_result.rs index f5f373bb04..c5a566d34f 100644 --- a/src/aggregation/intermediate_agg_result.rs +++ b/src/aggregation/intermediate_agg_result.rs @@ -24,8 +24,13 @@ use super::metric::{ }; use super::segment_agg_result::AggregationLimitsGuard; use super::{format_date, AggregationError, Key, SerializedKey}; -use crate::aggregation::agg_result::{AggregationResults, BucketEntries, BucketEntry}; -use crate::aggregation::bucket::TermsAggregationInternal; +use crate::aggregation::agg_result::{ + AggregationResults, BucketEntries, BucketEntry, CompositeBucketEntry, +}; +use crate::aggregation::bucket::{ + composite_intermediate_key_ordering, CompositeAggregation, MissingOrder, + TermsAggregationInternal, +}; use crate::aggregation::metric::CardinalityCollector; use crate::TantivyError; @@ -216,6 +221,11 @@ pub(crate) fn empty_from_req(req: &Aggregation) -> IntermediateAggregationResult is_date_agg: true, }) } + Composite(_) => { + IntermediateAggregationResult::Bucket(IntermediateBucketResult::Composite { + buckets: Default::default(), + }) + } Average(_) => IntermediateAggregationResult::Metric(IntermediateMetricResult::Average( IntermediateAverage::default(), )), @@ -432,6 +442,11 @@ pub enum IntermediateBucketResult { /// The term buckets buckets: IntermediateTermBucketResult, }, + /// Composite aggregation + Composite { + /// The composite buckets + buckets: IntermediateCompositeBucketResult, + }, } impl IntermediateBucketResult { @@ -515,6 +530,13 @@ impl IntermediateBucketResult { req.sub_aggregation(), limits, ), + IntermediateBucketResult::Composite { buckets } => buckets.into_final_result( + 
req.agg
+                    .as_composite()
+                    .expect("unexpected aggregation, expected composite aggregation"),
+                req.sub_aggregation(),
+                limits,
+            ),
         }
     }
@@ -568,6 +590,16 @@ impl IntermediateBucketResult {
                 *buckets_left = buckets?;
             }
+            (
+                IntermediateBucketResult::Composite {
+                    buckets: buckets_left,
+                },
+                IntermediateBucketResult::Composite {
+                    buckets: buckets_right,
+                },
+            ) => {
+                buckets_left.merge_fruits(buckets_right)?;
+            }
             (IntermediateBucketResult::Range(_), _) => {
                 panic!("try merge on different types")
             }
@@ -577,6 +609,9 @@ impl IntermediateBucketResult {
             (IntermediateBucketResult::Terms { .. }, _) => {
                 panic!("try merge on different types")
            }
+            (IntermediateBucketResult::Composite { .. }, _) => {
+                panic!("try merge on different types")
+            }
         }
         Ok(())
     }
@@ -805,6 +840,169 @@ pub struct IntermediateTermBucketEntry {
     pub sub_aggregation: IntermediateAggregationResults,
 }
 
+/// Entry for the composite bucket.
+pub type IntermediateCompositeBucketEntry = IntermediateTermBucketEntry;
+
+/// The fully typed key for composite aggregation
+#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
+pub enum CompositeIntermediateKey {
+    /// Bool key
+    Bool(bool),
+    /// String key
+    Str(String),
+    /// Float key
+    F64(f64),
+    /// Signed integer key
+    I64(i64),
+    /// Unsigned integer key
+    U64(u64),
+    /// DateTime key
+    DateTime(i64),
+    /// IP Address key
+    IpAddr(Ipv6Addr),
+    /// Missing value key
+    Null,
+}
+
+impl Eq for CompositeIntermediateKey {}
+
+impl std::hash::Hash for CompositeIntermediateKey {
+    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
+        core::mem::discriminant(self).hash(state);
+        match self {
+            CompositeIntermediateKey::Bool(val) => val.hash(state),
+            CompositeIntermediateKey::Str(text) => text.hash(state),
+            CompositeIntermediateKey::F64(val) => val.to_bits().hash(state),
+            CompositeIntermediateKey::U64(val) => val.hash(state),
+            CompositeIntermediateKey::I64(val) => val.hash(state),
+            CompositeIntermediateKey::DateTime(val) => val.hash(state),
+            CompositeIntermediateKey::IpAddr(val) => val.hash(state),
+            CompositeIntermediateKey::Null => {}
+        }
+    }
+}
+
+/// Composite aggregation page.
+#[derive(Default, Clone, Debug, PartialEq, Serialize, Deserialize)]
+pub struct IntermediateCompositeBucketResult {
+    pub(crate) entries: FxHashMap<Vec<CompositeIntermediateKey>, IntermediateCompositeBucketEntry>,
+    pub(crate) target_size: u32,
+    pub(crate) orders: Vec<(Order, MissingOrder)>,
+}
+
+impl IntermediateCompositeBucketResult {
+    pub(crate) fn into_final_result(
+        self,
+        req: &CompositeAggregation,
+        sub_aggregation_req: &Aggregations,
+        limits: &mut AggregationLimitsGuard,
+    ) -> crate::Result<BucketResult> {
+        let trimmed_entry_vec =
+            trim_composite_buckets(self.entries, &self.orders, self.target_size)?;
+        let buckets = trimmed_entry_vec
+            .into_iter()
+            .map(|(intermediate_key, entry)| {
+                let key = intermediate_key
+                    .into_iter()
+                    .enumerate()
+                    .map(|(idx, intermediate_key)| {
+                        let source = &req.sources[idx];
+                        (source.name().to_string(), intermediate_key.into())
+                    })
+                    .collect();
+                Ok(CompositeBucketEntry {
+                    key,
+                    doc_count: entry.doc_count as u64,
+                    sub_aggregation: entry
+                        .sub_aggregation
+                        .into_final_result_internal(sub_aggregation_req, limits)?,
+                })
+            })
+            .collect::<crate::Result<Vec<_>>>()?;
+
+        let after_key = if buckets.len() == req.size as usize {
+            buckets.last().map(|bucket| bucket.key.clone()).unwrap()
+        } else {
+            FxHashMap::default()
+        };
+
+        Ok(BucketResult::Composite { after_key, buckets })
+    }
+
+    fn merge_fruits(&mut self, other: IntermediateCompositeBucketResult) -> crate::Result<()> {
+        merge_maps(&mut self.entries, other.entries)?;
+        if self.entries.len() as u32 > 2 * self.target_size {
+            // 2x factor used to avoid trimming too often (expensive operation)
+            // an optimal threshold could probably be figured out
+            self.trim()?;
+        }
+        Ok(())
+    }
+
+    /// Trim the composite buckets to the target size, according to the ordering.
+    ///
+    /// Returns an error if the ordering comparison fails.
+    pub(crate) fn trim(&mut self) -> crate::Result<()> {
+        if self.entries.len() as u32 <= self.target_size {
+            return Ok(());
+        }
+
+        let sorted_entries = trim_composite_buckets(
+            std::mem::take(&mut self.entries),
+            &self.orders,
+            self.target_size,
+        )?;
+
+        self.entries = sorted_entries.into_iter().collect();
+        Ok(())
+    }
+}
+
+fn trim_composite_buckets(
+    entries: FxHashMap<Vec<CompositeIntermediateKey>, IntermediateCompositeBucketEntry>,
+    orders: &[(Order, MissingOrder)],
+    target_size: u32,
+) -> crate::Result<
+    Vec<(
+        Vec<CompositeIntermediateKey>,
+        IntermediateCompositeBucketEntry,
+    )>,
+> {
+    let mut entries: Vec<_> = entries.into_iter().collect();
+    let mut sort_error: Option<TantivyError> = None;
+    entries.sort_by(|(left_key, _), (right_key, _)| {
+        // Only attempt sorting if we haven't encountered an error yet
+        if sort_error.is_some() {
+            return Ordering::Equal; // Return a default, we'll handle the error after sorting
+        }
+
+        for i in 0..orders.len() {
+            match composite_intermediate_key_ordering(
+                &left_key[i],
+                &right_key[i],
+                orders[i].0,
+                orders[i].1,
+            ) {
+                Ok(ordering) if ordering != Ordering::Equal => return ordering,
+                Ok(_) => continue, // Equal, try next key
+                Err(err) => {
+                    sort_error = Some(err);
+                    break;
+                }
+            }
+        }
+        Ordering::Equal
+    });
+
+    // If we encountered an error during sorting, return it now
+    if let Some(err) = sort_error {
+        return Err(err);
+    }
+
+    entries.truncate(target_size as usize);
+    Ok(entries)
+}
+
 impl MergeFruits for IntermediateTermBucketEntry {
     fn merge_fruits(&mut self, other: IntermediateTermBucketEntry) -> crate::Result<()> {
         self.doc_count += other.doc_count;
diff --git a/sstable/src/lib.rs b/sstable/src/lib.rs
index 82461c3637..74e8ff80d6 100644
--- a/sstable/src/lib.rs
+++ b/sstable/src/lib.rs
@@ -51,7 +51,7 @@ mod sstable_index_v3;
 pub use sstable_index_v3::{BlockAddr, SSTableIndex, SSTableIndexBuilder, SSTableIndexV3};
 mod sstable_index_v2;
 pub(crate) mod vint;
-pub use dictionary::Dictionary;
+pub use dictionary::{Dictionary, TermOrdHit};
 pub use streamer::{Streamer, StreamerBuilder};
 mod block_reader;
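
For reference, a minimal sketch of how a caller pages through composite buckets by feeding the returned `after_key` back into the next request, mirroring the pagination exercised in the mixed term/date-histogram test above. It assumes the same test helpers (`exec_request`, `index`) and a `category` terms source; it is illustrative only and not part of this change.

    // Illustrative pagination loop over a composite aggregation (sketch only).
    // The response carries an `after_key` only when the page is full; its absence
    // signals the last page.
    let mut after_key: Option<serde_json::Value> = None;
    loop {
        let mut composite = json!({
            "sources": [
                {"category": {"terms": {"field": "category"}}}
            ],
            "size": 3
        });
        if let Some(key) = &after_key {
            composite["after"] = key.clone();
        }
        let agg_req: Aggregations =
            serde_json::from_value(json!({ "my_composite": { "composite": composite } })).unwrap();
        let res = exec_request(agg_req, &index)?;
        for bucket in res["my_composite"]["buckets"].as_array().unwrap() {
            // consume bucket["key"] and bucket["doc_count"] here
        }
        match res["my_composite"].get("after_key") {
            Some(key) => after_key = Some(key.clone()),
            None => break,
        }
    }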