diff --git a/datafusion/physical-expr-common/src/sort_expr.rs b/datafusion/physical-expr-common/src/sort_expr.rs
index e8558c7643d07..db30dd6ed26e2 100644
--- a/datafusion/physical-expr-common/src/sort_expr.rs
+++ b/datafusion/physical-expr-common/src/sort_expr.rs
@@ -457,6 +457,17 @@ impl LexOrdering {
             req.expr.eq(&cur.expr)
                 && is_reversed_sort_options(&req.options, &cur.options)
         })
     }
+
+    /// Returns the sort options for the given expression if one is defined in this `LexOrdering`.
+    pub fn get_sort_options(&self, expr: &dyn PhysicalExpr) -> Option<SortOptions> {
+        for e in self {
+            if e.expr.as_ref().dyn_eq(expr) {
+                return Some(e.options);
+            }
+        }
+
+        None
+    }
 }
 
 /// Check if two SortOptions represent reversed orderings.
diff --git a/datafusion/physical-plan/src/aggregates/group_values/mod.rs b/datafusion/physical-plan/src/aggregates/group_values/mod.rs
index f419328d11252..2f3b1a19e7d73 100644
--- a/datafusion/physical-plan/src/aggregates/group_values/mod.rs
+++ b/datafusion/physical-plan/src/aggregates/group_values/mod.rs
@@ -22,7 +22,7 @@ use arrow::array::types::{
     Time64MicrosecondType, Time64NanosecondType, TimestampMicrosecondType,
     TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType,
 };
-use arrow::array::{ArrayRef, RecordBatch, downcast_primitive};
+use arrow::array::{ArrayRef, downcast_primitive};
 use arrow::datatypes::{DataType, SchemaRef, TimeUnit};
 use datafusion_common::Result;
 
@@ -112,7 +112,7 @@ pub trait GroupValues: Send {
     fn emit(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>>;
 
     /// Clear the contents and shrink the capacity to the size of the batch (free up memory usage)
-    fn clear_shrink(&mut self, batch: &RecordBatch);
+    fn clear_shrink(&mut self, num_rows: usize);
 }
 
 /// Return a specialized implementation of [`GroupValues`] for the given schema.
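Usage sketch (not part of the patch): `get_sort_options` lets a caller reuse the concrete `SortOptions` that an existing ordering declares for a column instead of assuming `SortOptions::default()`, which is what the spill-ordering change in `row_hash.rs` below relies on. The `spill_options_for` helper and the exact import paths are illustrative assumptions, not code from this diff:

    use std::sync::Arc;

    use arrow::compute::SortOptions;
    use datafusion_physical_expr::PhysicalExpr;
    use datafusion_physical_expr::expressions::Column;
    use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr};

    /// Pick the sort options to use for a column when building a spill ordering:
    /// reuse the options from the declared output ordering if the column takes
    /// part in it, otherwise fall back to the defaults.
    fn spill_options_for(column: &Column, output_ordering: Option<&LexOrdering>) -> SortOptions {
        output_ordering
            .and_then(|ordering| ordering.get_sort_options(column))
            .unwrap_or_default()
    }

    fn main() {
        // An ordering on column "b" (descending), as a plan might report it.
        let col_b: Arc<dyn PhysicalExpr> = Arc::new(Column::new("b", 0));
        let ordering = LexOrdering::new([PhysicalSortExpr::new(
            Arc::clone(&col_b),
            SortOptions::default().desc(),
        )])
        .expect("ordering is non-empty");

        // Columns covered by the ordering keep their declared options ...
        assert!(spill_options_for(&Column::new("b", 0), Some(&ordering)).descending);
        // ... anything else falls back to the default (ascending) options.
        assert!(!spill_options_for(&Column::new("x", 1), Some(&ordering)).descending);
    }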
diff --git a/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/mod.rs b/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/mod.rs
index b62bc11aff018..4c9e376fc4008 100644
--- a/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/mod.rs
+++ b/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/mod.rs
@@ -30,7 +30,7 @@ use crate::aggregates::group_values::multi_group_by::{
     bytes_view::ByteViewGroupValueBuilder, primitive::PrimitiveGroupValueBuilder,
 };
 use ahash::RandomState;
-use arrow::array::{Array, ArrayRef, RecordBatch};
+use arrow::array::{Array, ArrayRef};
 use arrow::compute::cast;
 use arrow::datatypes::{
     BinaryViewType, DataType, Date32Type, Date64Type, Decimal128Type, Float32Type,
@@ -1181,14 +1181,13 @@ impl<const STREAMING: bool> GroupValues for GroupValuesColumn<STREAMING> {
         Ok(output)
     }
 
-    fn clear_shrink(&mut self, batch: &RecordBatch) {
-        let count = batch.num_rows();
+    fn clear_shrink(&mut self, num_rows: usize) {
         self.group_values.clear();
         self.map.clear();
-        self.map.shrink_to(count, |_| 0); // hasher does not matter since the map is cleared
+        self.map.shrink_to(num_rows, |_| 0); // hasher does not matter since the map is cleared
         self.map_size = self.map.capacity() * size_of::<(u64, usize)>();
         self.hashes_buffer.clear();
-        self.hashes_buffer.shrink_to(count);
+        self.hashes_buffer.shrink_to(num_rows);
 
         // Such structures are only used in `non-streaming` case
         if !STREAMING {
diff --git a/datafusion/physical-plan/src/aggregates/group_values/row.rs b/datafusion/physical-plan/src/aggregates/group_values/row.rs
index a5e5c16006028..dd794c957350d 100644
--- a/datafusion/physical-plan/src/aggregates/group_values/row.rs
+++ b/datafusion/physical-plan/src/aggregates/group_values/row.rs
@@ -17,7 +17,7 @@
 use crate::aggregates::group_values::GroupValues;
 use ahash::RandomState;
-use arrow::array::{Array, ArrayRef, ListArray, RecordBatch, StructArray};
+use arrow::array::{Array, ArrayRef, ListArray, StructArray};
 use arrow::compute::cast;
 use arrow::datatypes::{DataType, SchemaRef};
 use arrow::row::{RowConverter, Rows, SortField};
 
@@ -243,17 +243,16 @@ impl GroupValues for GroupValuesRows {
         Ok(output)
     }
 
-    fn clear_shrink(&mut self, batch: &RecordBatch) {
-        let count = batch.num_rows();
+    fn clear_shrink(&mut self, num_rows: usize) {
         self.group_values = self.group_values.take().map(|mut rows| {
             rows.clear();
             rows
         });
         self.map.clear();
-        self.map.shrink_to(count, |_| 0); // hasher does not matter since the map is cleared
+        self.map.shrink_to(num_rows, |_| 0); // hasher does not matter since the map is cleared
         self.map_size = self.map.capacity() * size_of::<(u64, usize)>();
         self.hashes_buffer.clear();
-        self.hashes_buffer.shrink_to(count);
+        self.hashes_buffer.shrink_to(num_rows);
     }
 }
 
diff --git a/datafusion/physical-plan/src/aggregates/group_values/single_group_by/boolean.rs b/datafusion/physical-plan/src/aggregates/group_values/single_group_by/boolean.rs
index 44b763a91f523..e993c0c53d199 100644
--- a/datafusion/physical-plan/src/aggregates/group_values/single_group_by/boolean.rs
+++ b/datafusion/physical-plan/src/aggregates/group_values/single_group_by/boolean.rs
@@ -19,7 +19,6 @@
 use crate::aggregates::group_values::GroupValues;
 use arrow::array::{
     ArrayRef, AsArray as _, BooleanArray, BooleanBufferBuilder, NullBufferBuilder,
-    RecordBatch,
 };
 use datafusion_common::Result;
 use datafusion_expr::EmitTo;
@@ -146,7 +145,7 @@ impl GroupValues for GroupValuesBoolean {
         Ok(vec![Arc::new(BooleanArray::new(values, nulls)) as _])
     }
 
-    fn clear_shrink(&mut self, _batch: &RecordBatch) {
+    fn clear_shrink(&mut self, _num_rows: usize) {
         self.false_group = None;
         self.true_group = None;
         self.null_group = None;
diff --git a/datafusion/physical-plan/src/aggregates/group_values/single_group_by/bytes.rs b/datafusion/physical-plan/src/aggregates/group_values/single_group_by/bytes.rs
index b901aee313fb7..b881a51b25474 100644
--- a/datafusion/physical-plan/src/aggregates/group_values/single_group_by/bytes.rs
+++ b/datafusion/physical-plan/src/aggregates/group_values/single_group_by/bytes.rs
@@ -19,7 +19,7 @@
 use std::mem::size_of;
 
 use crate::aggregates::group_values::GroupValues;
-use arrow::array::{Array, ArrayRef, OffsetSizeTrait, RecordBatch};
+use arrow::array::{Array, ArrayRef, OffsetSizeTrait};
 use datafusion_common::Result;
 use datafusion_expr::EmitTo;
 use datafusion_physical_expr_common::binary_map::{ArrowBytesMap, OutputType};
@@ -120,7 +120,7 @@ impl<O: OffsetSizeTrait> GroupValues for GroupValuesBytes<O> {
         Ok(vec![group_values])
     }
 
-    fn clear_shrink(&mut self, _batch: &RecordBatch) {
+    fn clear_shrink(&mut self, _num_rows: usize) {
         // in theory we could potentially avoid this reallocation and clear the
         // contents of the maps, but for now we just reset the map from the beginning
         self.map.take();
diff --git a/datafusion/physical-plan/src/aggregates/group_values/single_group_by/bytes_view.rs b/datafusion/physical-plan/src/aggregates/group_values/single_group_by/bytes_view.rs
index be9a0334e3ee6..7a56f7c52c11a 100644
--- a/datafusion/physical-plan/src/aggregates/group_values/single_group_by/bytes_view.rs
+++ b/datafusion/physical-plan/src/aggregates/group_values/single_group_by/bytes_view.rs
@@ -16,7 +16,7 @@
 // under the License.
 
 use crate::aggregates::group_values::GroupValues;
-use arrow::array::{Array, ArrayRef, RecordBatch};
+use arrow::array::{Array, ArrayRef};
 use datafusion_expr::EmitTo;
 use datafusion_physical_expr::binary_map::OutputType;
 use datafusion_physical_expr_common::binary_view_map::ArrowBytesViewMap;
@@ -122,7 +122,7 @@ impl GroupValues for GroupValuesBytesView {
         Ok(vec![group_values])
     }
 
-    fn clear_shrink(&mut self, _batch: &RecordBatch) {
+    fn clear_shrink(&mut self, _num_rows: usize) {
         // in theory we could potentially avoid this reallocation and clear the
         // contents of the maps, but for now we just reset the map from the beginning
         self.map.take();
diff --git a/datafusion/physical-plan/src/aggregates/group_values/single_group_by/primitive.rs b/datafusion/physical-plan/src/aggregates/group_values/single_group_by/primitive.rs
index 41d34218f6a0e..c46cde8786eb4 100644
--- a/datafusion/physical-plan/src/aggregates/group_values/single_group_by/primitive.rs
+++ b/datafusion/physical-plan/src/aggregates/group_values/single_group_by/primitive.rs
@@ -23,7 +23,6 @@ use arrow::array::{
     cast::AsArray,
 };
 use arrow::datatypes::{DataType, i256};
-use arrow::record_batch::RecordBatch;
 use datafusion_common::Result;
 use datafusion_execution::memory_pool::proxy::VecAllocExt;
 use datafusion_expr::EmitTo;
@@ -213,11 +212,10 @@ where
         Ok(vec![Arc::new(array.with_data_type(self.data_type.clone()))])
     }
 
-    fn clear_shrink(&mut self, batch: &RecordBatch) {
-        let count = batch.num_rows();
+    fn clear_shrink(&mut self, num_rows: usize) {
         self.values.clear();
-        self.values.shrink_to(count);
+        self.values.shrink_to(num_rows);
         self.map.clear();
-        self.map.shrink_to(count, |_| 0); // hasher does not matter since the map is cleared
+        self.map.shrink_to(num_rows, |_| 0); // hasher does not matter since the map is cleared
     }
 }
diff --git a/datafusion/physical-plan/src/aggregates/mod.rs b/datafusion/physical-plan/src/aggregates/mod.rs
index b0d432a9deffe..06f12a90195d2 100644
--- a/datafusion/physical-plan/src/aggregates/mod.rs
+++ b/datafusion/physical-plan/src/aggregates/mod.rs
@@ -31,7 +31,6 @@ use crate::filter_pushdown::{
     FilterPushdownPropagation, PushedDownPredicate,
 };
 use crate::metrics::{ExecutionPlanMetricsSet, MetricsSet};
-use crate::windows::get_ordered_partition_by_indices;
 use crate::{
     DisplayFormatType, Distribution, ExecutionPlan, InputOrderMode,
     SendableRecordBatchStream, Statistics,
@@ -613,12 +612,13 @@ impl AggregateExec {
         // If existing ordering satisfies a prefix of the GROUP BY expressions,
         // prefix requirements with this section. In this case, aggregation will
         // work more efficiently.
-        let indices = get_ordered_partition_by_indices(&groupby_exprs, &input)?;
-        let mut new_requirements = indices
-            .iter()
-            .map(|&idx| {
-                PhysicalSortRequirement::new(Arc::clone(&groupby_exprs[idx]), None)
-            })
+        // Copy the `PhysicalSortExpr`s to retain the sort options.
+        let (new_sort_exprs, indices) =
+            input_eq_properties.find_longest_permutation(&groupby_exprs)?;
+
+        let mut new_requirements = new_sort_exprs
+            .into_iter()
+            .map(PhysicalSortRequirement::from)
             .collect::<Vec<_>>();
 
         let req = get_finer_aggregate_exprs_requirement(
@@ -1815,7 +1815,7 @@ mod tests {
     use crate::test::exec::{BlockingExec, assert_strong_count_converges_to_zero};
 
     use arrow::array::{
-        DictionaryArray, Float32Array, Float64Array, Int32Array, StructArray,
+        DictionaryArray, Float32Array, Float64Array, Int32Array, Int64Array, StructArray,
         UInt32Array, UInt64Array,
     };
     use arrow::compute::{SortOptions, concat_batches};
@@ -1837,6 +1837,8 @@ mod tests {
     use datafusion_physical_expr::expressions::Literal;
     use datafusion_physical_expr::expressions::lit;
 
+    use crate::projection::ProjectionExec;
+    use datafusion_physical_expr::projection::ProjectionExpr;
     use futures::{FutureExt, Stream};
     use insta::{allow_duplicates, assert_snapshot};
 
@@ -2484,7 +2486,7 @@ mod tests {
         ] {
             let n_aggr = aggregates.len();
             let partial_aggregate = Arc::new(AggregateExec::try_new(
-                AggregateMode::Partial,
+                AggregateMode::Single,
                 groups,
                 aggregates,
                 vec![None; n_aggr],
@@ -3420,6 +3422,117 @@ mod tests {
         Ok(())
     }
 
+    #[tokio::test]
+    async fn test_grouped_aggregation_respects_memory_limit() -> Result<()> {
+        // test with spill
+        fn create_record_batch(
+            schema: &Arc<Schema>,
+            data: (Vec<u32>, Vec<f64>),
+        ) -> Result<RecordBatch> {
+            Ok(RecordBatch::try_new(
+                Arc::clone(schema),
+                vec![
+                    Arc::new(UInt32Array::from(data.0)),
+                    Arc::new(Float64Array::from(data.1)),
+                ],
+            )?)
+        }
+
+        let schema = Arc::new(Schema::new(vec![
+            Field::new("a", DataType::UInt32, false),
+            Field::new("b", DataType::Float64, false),
+        ]));
+
+        let batches = vec![
+            create_record_batch(&schema, (vec![2, 3, 4, 4], vec![1.0, 2.0, 3.0, 4.0]))?,
+            create_record_batch(&schema, (vec![2, 3, 4, 4], vec![1.0, 2.0, 3.0, 4.0]))?,
+        ];
+        let plan: Arc<dyn ExecutionPlan> =
+            TestMemoryExec::try_new_exec(&[batches], Arc::clone(&schema), None)?;
+        let proj = ProjectionExec::try_new(
+            vec![
+                ProjectionExpr::new(lit("0"), "l".to_string()),
+                ProjectionExpr::new_from_expression(col("a", &schema)?, &schema)?,
+                ProjectionExpr::new_from_expression(col("b", &schema)?, &schema)?,
+            ],
+            plan,
+        )?;
+        let plan: Arc<dyn ExecutionPlan> = Arc::new(proj);
+        let schema = plan.schema();
+
+        let grouping_set = PhysicalGroupBy::new(
+            vec![
+                (col("l", &schema)?, "l".to_string()),
+                (col("a", &schema)?, "a".to_string()),
+            ],
+            vec![],
+            vec![vec![false, false]],
+            false,
+        );
+
+        // Test with MIN for simple intermediate state (min) and AVG for multiple intermediate states (partial sum, partial count).
+        let aggregates: Vec<Arc<AggregateFunctionExpr>> = vec![
+            Arc::new(
+                AggregateExprBuilder::new(
+                    datafusion_functions_aggregate::min_max::min_udaf(),
+                    vec![col("b", &schema)?],
+                )
+                .schema(Arc::clone(&schema))
+                .alias("MIN(b)")
+                .build()?,
+            ),
+            Arc::new(
+                AggregateExprBuilder::new(avg_udaf(), vec![col("b", &schema)?])
+                    .schema(Arc::clone(&schema))
+                    .alias("AVG(b)")
+                    .build()?,
+            ),
+        ];
+
+        let single_aggregate = Arc::new(AggregateExec::try_new(
+            AggregateMode::Single,
+            grouping_set,
+            aggregates,
+            vec![None, None],
+            plan,
+            Arc::clone(&schema),
+        )?);
+
+        let batch_size = 2;
+        let memory_pool = Arc::new(FairSpillPool::new(2000));
+        let task_ctx = Arc::new(
+            TaskContext::default()
+                .with_session_config(SessionConfig::new().with_batch_size(batch_size))
+                .with_runtime(Arc::new(
+                    RuntimeEnvBuilder::new()
+                        .with_memory_pool(memory_pool)
+                        .build()?,
+                )),
+        );
+
+        let result = collect(single_aggregate.execute(0, Arc::clone(&task_ctx))?).await;
+        match result {
+            Ok(result) => {
+                assert_spill_count_metric(true, single_aggregate);
+
+                allow_duplicates! {
+                    assert_snapshot!(batches_to_string(&result), @r"
+                    +---+---+--------+--------+
+                    | l | a | MIN(b) | AVG(b) |
+                    +---+---+--------+--------+
+                    | 0 | 2 | 1.0    | 1.0    |
+                    | 0 | 3 | 2.0    | 2.0    |
+                    | 0 | 4 | 3.0    | 3.5    |
+                    +---+---+--------+--------+
+                    ");
+                }
+            }
+            Err(e) => assert!(matches!(e, DataFusionError::ResourcesExhausted(_))),
+        }
+
+        Ok(())
+    }
+
     #[tokio::test]
     async fn test_aggregate_statistics_edge_cases() -> Result<()> {
         use crate::test::exec::StatisticsExec;
@@ -3492,4 +3605,87 @@ mod tests {
 
         Ok(())
     }
+
+    #[tokio::test]
+    async fn test_order_is_retained_when_spilling() -> Result<()> {
+        let schema = Arc::new(Schema::new(vec![
+            Field::new("a", DataType::Int64, false),
+            Field::new("b", DataType::Int64, false),
+            Field::new("c", DataType::Int64, false),
+        ]));
+
+        let batches = vec![vec![
+            RecordBatch::try_new(
+                Arc::clone(&schema),
+                vec![
+                    Arc::new(Int64Array::from(vec![2])),
+                    Arc::new(Int64Array::from(vec![2])),
+                    Arc::new(Int64Array::from(vec![1])),
+                ],
+            )?,
+            RecordBatch::try_new(
+                Arc::clone(&schema),
+                vec![
+                    Arc::new(Int64Array::from(vec![1])),
+                    Arc::new(Int64Array::from(vec![1])),
+                    Arc::new(Int64Array::from(vec![1])),
+                ],
+            )?,
+            RecordBatch::try_new(
+                Arc::clone(&schema),
+                vec![
+                    Arc::new(Int64Array::from(vec![0])),
+                    Arc::new(Int64Array::from(vec![0])),
+                    Arc::new(Int64Array::from(vec![1])),
+                ],
+            )?,
+        ]];
+        let scan = TestMemoryExec::try_new(&batches, Arc::clone(&schema), None)?;
+        let scan = scan.try_with_sort_information(vec![
+            LexOrdering::new([PhysicalSortExpr::new(
+                col("b", schema.as_ref())?,
+                SortOptions::default().desc(),
+            )])
+            .unwrap(),
+        ])?;
+
+        let aggr = Arc::new(AggregateExec::try_new(
+            AggregateMode::Single,
+            PhysicalGroupBy::new(
+                vec![
+                    (col("b", schema.as_ref())?, "b".to_string()),
+                    (col("c", schema.as_ref())?, "c".to_string()),
+                ],
+                vec![],
+                vec![vec![false, false]],
+                false,
+            ),
+            vec![Arc::new(
+                AggregateExprBuilder::new(sum_udaf(), vec![col("c", schema.as_ref())?])
+                    .schema(Arc::clone(&schema))
+                    .alias("SUM(c)")
+                    .build()?,
+            )],
+            vec![None],
+            Arc::new(scan) as Arc<dyn ExecutionPlan>,
+            Arc::clone(&schema),
+        )?);
+
+        let task_ctx = new_spill_ctx(1, 600);
+        let result = collect(aggr.execute(0, Arc::clone(&task_ctx))?).await?;
+        assert_spill_count_metric(true, aggr);
+
+        allow_duplicates! {
+            assert_snapshot!(batches_to_string(&result), @r"
+            +---+---+--------+
+            | b | c | SUM(c) |
+            +---+---+--------+
+            | 2 | 1 | 1      |
+            | 1 | 1 | 1      |
+            | 0 | 1 | 1      |
+            +---+---+--------+
+            ");
+        }
+        Ok(())
+    }
 }
diff --git a/datafusion/physical-plan/src/aggregates/row_hash.rs b/datafusion/physical-plan/src/aggregates/row_hash.rs
index 1e7757de4aac2..cb22fbf9a06a1 100644
--- a/datafusion/physical-plan/src/aggregates/row_hash.rs
+++ b/datafusion/physical-plan/src/aggregates/row_hash.rs
@@ -33,7 +33,6 @@ use crate::metrics::{BaselineMetrics, MetricBuilder, RecordOutput};
 use crate::sorts::sort::sort_batch;
 use crate::sorts::streaming_merge::{SortedSpillFile, StreamingMergeBuilder};
 use crate::spill::spill_manager::SpillManager;
-use crate::stream::RecordBatchStreamAdapter;
 use crate::{PhysicalExpr, aggregates, metrics};
 use crate::{RecordBatchStream, SendableRecordBatchStream};
 
@@ -210,6 +209,17 @@ impl SkipAggregationProbe {
     }
 }
 
+/// Controls the behavior when an out-of-memory condition occurs.
+#[derive(PartialEq, Debug)]
+enum OutOfMemoryMode {
+    /// When out of memory occurs, spill state to disk
+    Spill,
+    /// When out of memory occurs, attempt to emit group values early
+    EmitEarly,
+    /// When out of memory occurs, immediately report the error
+    ReportError,
+}
+
 /// HashTable based Grouping Aggregator
 ///
 /// # Design Goals
@@ -433,6 +443,9 @@ pub(crate) struct GroupedHashAggregateStream {
     /// The memory reservation for this grouping
     reservation: MemoryReservation,
 
+    /// The behavior to trigger when out of memory occurs
+    oom_mode: OutOfMemoryMode,
+
     /// Execution metrics
     baseline_metrics: BaselineMetrics,
 
@@ -510,12 +523,12 @@ impl GroupedHashAggregateStream {
         // Therefore, when we spill these intermediate states or pass them to another
         // aggregation operator, we must use a schema that includes both the group
        // columns **and** the partial-state columns.
-        let partial_agg_schema = create_schema(
+        let spill_schema = Arc::new(create_schema(
             &agg.input().schema(),
             &agg_group_by,
             &aggregate_exprs,
             AggregateMode::Partial,
-        )?;
+        )?);
 
         // Need to update the GROUP BY expressions to point to the correct column after schema change
         let merging_group_by_expr = agg_group_by
@@ -527,20 +540,25 @@
             })
             .collect();
 
-        let partial_agg_schema = Arc::new(partial_agg_schema);
+        let output_ordering = agg.cache.output_ordering();
 
-        let spill_expr =
+        let spill_sort_exprs =
             group_schema
                 .fields
                 .into_iter()
                 .enumerate()
                 .map(|(idx, field)| {
-                    PhysicalSortExpr::new_default(Arc::new(Column::new(
-                        field.name().as_str(),
-                        idx,
-                    )) as _)
+                    let output_expr = Column::new(field.name().as_str(), idx);
+
+                    // Try to use the sort options from the output ordering, if available.
+                    // This ensures that spilled state is sorted in the required order as well.
+                    let sort_options = output_ordering
+                        .and_then(|o| o.get_sort_options(&output_expr))
+                        .unwrap_or_default();
+
+                    PhysicalSortExpr::new(Arc::new(output_expr), sort_options)
                 });
-        let Some(spill_expr) = LexOrdering::new(spill_expr) else {
+        let Some(spill_ordering) = LexOrdering::new(spill_sort_exprs) else {
             return internal_err!("Spill expression is empty");
         };
 
@@ -550,11 +568,35 @@ impl GroupedHashAggregateStream {
             .collect::<Vec<_>>()
             .join(", ");
         let name = format!("GroupedHashAggregateStream[{partition}] ({agg_fn_names})");
-        let reservation = MemoryConsumer::new(name)
-            .with_can_spill(true)
-            .register(context.memory_pool());
 
         let group_ordering = GroupOrdering::try_new(&agg.input_order_mode)?;
+        let oom_mode = match (agg.mode, &group_ordering) {
+            // In partial aggregation mode, always prefer to emit incomplete results early.
+            (AggregateMode::Partial, _) => OutOfMemoryMode::EmitEarly,
+            // For non-partial aggregation modes, emitting incomplete results is not an option.
+            // Instead, use disk spilling to store sorted, incomplete results, and merge them
+            // afterwards.
+            (_, GroupOrdering::None | GroupOrdering::Partial(_))
+                if context.runtime_env().disk_manager.tmp_files_enabled() =>
+            {
+                OutOfMemoryMode::Spill
+            }
+            // For `GroupOrdering::Full`, the incoming stream is already sorted. This ensures the
+            // number of incomplete groups can be kept small at all times. If we still hit
+            // an out-of-memory condition, spilling to disk would not be beneficial since the same
+            // situation is likely to reoccur when reading back the spilled data.
+            // Therefore, we fall back to simply reporting the error immediately.
+            // This mode will also be used if the `DiskManager` is not configured to allow spilling
+            // to disk.
+            _ => OutOfMemoryMode::ReportError,
+        };
+        let group_values = new_group_values(group_schema, &group_ordering)?;
+        let reservation = MemoryConsumer::new(name)
+            // We interpret 'can spill' as 'can handle memory back pressure'.
+            // This value needs to be set to true for the default memory pool implementations
+            // to ensure fair application of back pressure amongst the memory consumers.
+            .with_can_spill(oom_mode != OutOfMemoryMode::ReportError)
+            .register(context.memory_pool());
 
         timer.done();
 
         let exec_state = ExecutionState::ReadingInput;
@@ -562,14 +604,14 @@ impl GroupedHashAggregateStream {
         let spill_manager = SpillManager::new(
             context.runtime_env(),
             metrics::SpillMetrics::new(&agg.metrics, partition),
-            Arc::clone(&partial_agg_schema),
+            Arc::clone(&spill_schema),
         )
         .with_compression_type(context.session_config().spill_compression());
 
         let spill_state = SpillState {
             spills: vec![],
-            spill_expr,
-            spill_schema: partial_agg_schema,
+            spill_expr: spill_ordering,
+            spill_schema,
             is_stream_merging: false,
             merging_aggregate_arguments,
             merging_group_by: PhysicalGroupBy::new_single(merging_group_by_expr),
@@ -627,6 +669,7 @@ impl GroupedHashAggregateStream {
             filter_expressions,
             group_by: agg_group_by,
             reservation,
+            oom_mode,
             group_values,
             current_group_indices: Default::default(),
             exec_state,
@@ -676,21 +719,24 @@ impl Stream for GroupedHashAggregateStream {
         match &self.exec_state {
             ExecutionState::ReadingInput => 'reading_input: {
                 match ready!(self.input.poll_next_unpin(cx)) {
-                    // New batch to aggregate in partial aggregation operator
-                    Some(Ok(batch)) if self.mode == AggregateMode::Partial => {
+                    // New batch to aggregate
+                    Some(Ok(batch)) => {
                         let timer = elapsed_compute.timer();
                         let input_rows = batch.num_rows();
 
-                        if let Some(reduction_factor) = self.reduction_factor.as_ref()
+                        if self.mode == AggregateMode::Partial
+                            && let Some(reduction_factor) =
+                                self.reduction_factor.as_ref()
                         {
                             reduction_factor.add_total(input_rows);
                         }
 
-                        // Do the grouping
+                        // Do the grouping.
+                        // `group_aggregate_batch` will _not_ have updated the memory reservation yet.
+                        // The rest of the code will first try to reduce memory usage by
+                        // already emitting results.
                         self.group_aggregate_batch(&batch)?;
 
-                        // If we can begin emitting rows, do so,
-                        // otherwise keep consuming input
                         assert!(!self.input_done);
 
                         // If the number of group values equals or exceeds the soft limit,
@@ -702,7 +748,13 @@ impl Stream for GroupedHashAggregateStream {
                             break 'reading_input;
                         }
 
-                        if let Some(to_emit) = self.group_ordering.emit_to() {
+                        // Try to emit completed groups if possible.
+                        // If we already started spilling, we can no longer emit since
+                        // this might lead to incorrect output ordering
+                        if (self.spill_state.spills.is_empty()
+                            || self.spill_state.is_stream_merging)
+                            && let Some(to_emit) = self.group_ordering.emit_to()
+                        {
                             timer.done();
                             if let Some(batch) = self.emit(to_emit, false)? {
                                 self.exec_state =
@@ -712,18 +764,28 @@ impl Stream for GroupedHashAggregateStream {
                             break 'reading_input;
                         }
 
-                        // Check if we should switch to skip aggregation mode
-                        // It's important that we do this before we early emit since we've
-                        // already updated the probe.
-                        self.update_skip_aggregation_probe(input_rows);
-                        if let Some(new_state) = self.switch_to_skip_aggregation()? {
-                            timer.done();
-                            self.exec_state = new_state;
-                            break 'reading_input;
+                        if self.mode == AggregateMode::Partial {
+                            // Spilling should never be activated in partial aggregation mode.
+                            assert!(!self.spill_state.is_stream_merging);
+
+                            // Check if we should switch to skip aggregation mode
+                            // It's important that we do this before we early emit since we've
+                            // already updated the probe.
+                            self.update_skip_aggregation_probe(input_rows);
+                            if let Some(new_state) =
+                                self.switch_to_skip_aggregation()?
+                            {
+                                timer.done();
+                                self.exec_state = new_state;
+                                break 'reading_input;
+                            }
                         }
 
-                        // Check if we need to emit early due to memory pressure
-                        if let Some(new_state) = self.emit_early_if_necessary()? {
+                        // If we reach this point, try to update the memory reservation,
+                        // handling out-of-memory conditions as determined by the OOM mode.
+                        if let Some(new_state) =
+                            self.try_update_memory_reservation()?
+                        {
                             timer.done();
                             self.exec_state = new_state;
                             break 'reading_input;
@@ -732,43 +794,6 @@ impl Stream for GroupedHashAggregateStream {
 
                         timer.done();
                     }
-                    // New batch to aggregate in terminal aggregation operator
-                    // (Final/FinalPartitioned/Single/SinglePartitioned)
-                    Some(Ok(batch)) => {
-                        let timer = elapsed_compute.timer();
-
-                        // Make sure we have enough capacity for `batch`, otherwise spill
-                        self.spill_previous_if_necessary(&batch)?;
-
-                        // Do the grouping
-                        self.group_aggregate_batch(&batch)?;
-
-                        // If we can begin emitting rows, do so,
-                        // otherwise keep consuming input
-                        assert!(!self.input_done);
-
-                        // If the number of group values equals or exceeds the soft limit,
-                        // emit all groups and switch to producing output
-                        if self.hit_soft_group_limit() {
-                            timer.done();
-                            self.set_input_done_and_produce_output()?;
-                            // make sure the exec_state just set is not overwritten below
-                            break 'reading_input;
-                        }
-
-                        if let Some(to_emit) = self.group_ordering.emit_to() {
-                            timer.done();
-                            if let Some(batch) = self.emit(to_emit, false)? {
-                                self.exec_state =
-                                    ExecutionState::ProducingOutput(batch);
-                            };
-                            // make sure the exec_state just set is not overwritten below
-                            break 'reading_input;
-                        }
-
-                        timer.done();
-                    }
-
                     // Found error from input stream
                     Some(Err(e)) => {
                         // inner had error, return to caller
@@ -987,25 +1012,56 @@ impl GroupedHashAggregateStream {
             }
         }
 
-        match self.update_memory_reservation() {
-            // Here we can ignore `insufficient_capacity_err` because we will spill later,
-            // but at least one batch should fit in the memory
-            Err(DataFusionError::ResourcesExhausted(_))
-                if self.group_values.len() >= self.batch_size =>
-            {
-                Ok(())
+        Ok(())
+    }
+
+    /// Attempts to update the memory reservation. If that fails due to a
+    /// [DataFusionError::ResourcesExhausted] error, an attempt will be made to resolve
+    /// the out-of-memory condition based on the [out-of-memory handling mode](OutOfMemoryMode).
+    ///
+    /// If the out-of-memory condition cannot be resolved, an `Err` value will be returned.
+    ///
+    /// Returns `Ok(Some(ExecutionState))` if the state should be changed, `Ok(None)` otherwise.
+    fn try_update_memory_reservation(&mut self) -> Result<Option<ExecutionState>> {
+        let oom = match self.update_memory_reservation() {
+            Err(e @ DataFusionError::ResourcesExhausted(_)) => e,
+            Err(e) => return Err(e),
+            Ok(_) => return Ok(None),
+        };
+
+        match self.oom_mode {
+            OutOfMemoryMode::Spill if !self.group_values.is_empty() => {
+                self.spill()?;
+                self.clear_shrink(self.batch_size);
+                self.update_memory_reservation()?;
+                Ok(None)
             }
-            other => other,
+            OutOfMemoryMode::EmitEarly if self.group_values.len() > 1 => {
+                let n = if self.group_values.len() >= self.batch_size {
+                    // Try to emit an integer multiple of batch size if possible
+                    self.group_values.len() / self.batch_size * self.batch_size
+                } else {
+                    // Otherwise emit whatever we can
+                    self.group_values.len()
+                };
+
+                if let Some(batch) = self.emit(EmitTo::First(n), false)? {
+                    Ok(Some(ExecutionState::ProducingOutput(batch)))
+                } else {
+                    Err(oom)
+                }
+            }
+            _ => Err(oom),
         }
     }
 
     fn update_memory_reservation(&mut self) -> Result<()> {
         let acc = self.accumulators.iter().map(|x| x.size()).sum::<usize>();
-        let reservation_result = self.reservation.try_resize(
-            acc + self.group_values.size()
-                + self.group_ordering.size()
-                + self.current_group_indices.allocated_size(),
-        );
+        let new_size = acc
+            + self.group_values.size()
+            + self.group_ordering.size()
+            + self.current_group_indices.allocated_size();
+        let reservation_result = self.reservation.try_resize(new_size);
 
         if reservation_result.is_ok() {
             self.spill_state
@@ -1060,24 +1116,6 @@ impl GroupedHashAggregateStream {
         Ok(Some(batch))
     }
 
-    /// Optimistically, [`Self::group_aggregate_batch`] allows to exceed the memory target slightly
-    /// (~ 1 [`RecordBatch`]) for simplicity. In such cases, spill the data to disk and clear the
-    /// memory. Currently only [`GroupOrdering::None`] is supported for spilling.
-    fn spill_previous_if_necessary(&mut self, batch: &RecordBatch) -> Result<()> {
-        // TODO: support group_ordering for spilling
-        if !self.group_values.is_empty()
-            && batch.num_rows() > 0
-            && matches!(self.group_ordering, GroupOrdering::None)
-            && !self.spill_state.is_stream_merging
-            && self.update_memory_reservation().is_err()
-        {
-            assert_ne!(self.mode, AggregateMode::Partial);
-            self.spill()?;
-            self.clear_shrink(batch);
-        }
-        Ok(())
-    }
-
     /// Emit all intermediate aggregation states, sort them, and store them on disk.
     /// This process helps in reducing memory pressure by allowing the data to be
     /// read back with streaming merge.
@@ -1115,72 +1153,15 @@ impl GroupedHashAggregateStream {
     }
 
     /// Clear memory and shrink capacities to the size of the batch.
-    fn clear_shrink(&mut self, batch: &RecordBatch) {
-        self.group_values.clear_shrink(batch);
+    fn clear_shrink(&mut self, num_rows: usize) {
+        self.group_values.clear_shrink(num_rows);
         self.current_group_indices.clear();
-        self.current_group_indices.shrink_to(batch.num_rows());
+        self.current_group_indices.shrink_to(num_rows);
     }
 
     /// Clear memory and shrink capacities to zero.
     fn clear_all(&mut self) {
-        let s = self.schema();
-        self.clear_shrink(&RecordBatch::new_empty(s));
-    }
-
-    /// Emit if the used memory exceeds the target for partial aggregation.
-    /// Currently only [`GroupOrdering::None`] is supported for early emitting.
-    /// TODO: support group_ordering for early emitting
-    ///
-    /// Returns `Some(ExecutionState)` if the state should be changed, None otherwise.
-    fn emit_early_if_necessary(&mut self) -> Result<Option<ExecutionState>> {
-        if self.group_values.len() >= self.batch_size
-            && matches!(self.group_ordering, GroupOrdering::None)
-            && self.update_memory_reservation().is_err()
-        {
-            assert_eq!(self.mode, AggregateMode::Partial);
-            let n = self.group_values.len() / self.batch_size * self.batch_size;
-            if let Some(batch) = self.emit(EmitTo::First(n), false)? {
-                return Ok(Some(ExecutionState::ProducingOutput(batch)));
-            };
-        }
-        Ok(None)
-    }
-
-    /// At this point, all the inputs are read and there are some spills.
-    /// Emit the remaining rows and create a batch.
-    /// Conduct a streaming merge sort between the batch and spilled data. Since the stream is fully
-    /// sorted, set `self.group_ordering` to Full, then later we can read with [`EmitTo::First`].
-    fn update_merged_stream(&mut self) -> Result<()> {
-        let Some(batch) = self.emit(EmitTo::All, true)? else {
-            return Ok(());
-        };
-        // clear up memory for streaming_merge
-        self.clear_all();
-        self.update_memory_reservation()?;
-        let mut streams: Vec<SendableRecordBatchStream> = vec![];
-        let expr = self.spill_state.spill_expr.clone();
-        let schema = batch.schema();
-        streams.push(Box::pin(RecordBatchStreamAdapter::new(
-            Arc::clone(&schema),
-            futures::stream::once(futures::future::lazy(move |_| {
-                sort_batch(&batch, &expr, None)
-            })),
-        )));
-
-        self.spill_state.is_stream_merging = true;
-        self.input = StreamingMergeBuilder::new()
-            .with_streams(streams)
-            .with_schema(schema)
-            .with_spill_manager(self.spill_state.spill_manager.clone())
-            .with_sorted_spill_files(std::mem::take(&mut self.spill_state.spills))
-            .with_expressions(&self.spill_state.spill_expr)
-            .with_metrics(self.baseline_metrics.clone())
-            .with_batch_size(self.batch_size)
-            .with_reservation(self.reservation.new_empty())
-            .build()?;
-        self.input_done = false;
-        self.group_ordering = GroupOrdering::Full(GroupOrderingFull::new());
-        Ok(())
+        self.clear_shrink(0);
     }
 
     /// returns true if there is a soft groups limit and the number of distinct
@@ -1192,18 +1173,60 @@ impl GroupedHashAggregateStream {
         group_values_soft_limit <= self.group_values.len()
     }
 
-    /// common function for signalling end of processing of the input stream
+    /// Finalizes reading of the input stream and prepares for producing output values.
+    ///
+    /// This method is called both when the original input stream and,
+    /// in case of disk spilling, the SPM stream have been drained.
     fn set_input_done_and_produce_output(&mut self) -> Result<()> {
         self.input_done = true;
         self.group_ordering.input_done();
         let elapsed_compute = self.baseline_metrics.elapsed_compute().clone();
         let timer = elapsed_compute.timer();
         self.exec_state = if self.spill_state.spills.is_empty() {
+            // Input has been entirely processed without spilling to disk.
+
+            // Flush any remaining group values.
             let batch = self.emit(EmitTo::All, false)?;
+
+            // If there are none, we're done; otherwise switch to emitting them
            batch.map_or(ExecutionState::Done, ExecutionState::ProducingOutput)
         } else {
-            // If spill files exist, stream-merge them.
-            self.update_merged_stream()?;
+            // Spill any remaining data to disk. There is some performance overhead in
+            // writing out this last chunk of data and reading it back. The benefit of
+            // doing this is that memory usage for this stream is reduced, and the more
+            // sophisticated memory handling in `MultiLevelMergeBuilder` can take over
+            // instead.
+            // Spilling to disk and reading back also ensures batch size is consistent
+            // rather than potentially having one significantly larger last batch.
+            self.spill()?;
+
+            // Mark that we're switching to stream merging mode.
+            self.spill_state.is_stream_merging = true;
+
+            self.input = StreamingMergeBuilder::new()
+                .with_schema(Arc::clone(&self.spill_state.spill_schema))
+                .with_spill_manager(self.spill_state.spill_manager.clone())
+                .with_sorted_spill_files(std::mem::take(&mut self.spill_state.spills))
+                .with_expressions(&self.spill_state.spill_expr)
+                .with_metrics(self.baseline_metrics.clone())
+                .with_batch_size(self.batch_size)
+                .with_reservation(self.reservation.new_empty())
+                .build()?;
+            self.input_done = false;
+
+            // Reset the group values collectors.
+            self.clear_all();
+
+            // We can now use `GroupOrdering::Full` since the spill files are sorted
+            // on the grouping columns.
+            self.group_ordering = GroupOrdering::Full(GroupOrderingFull::new());
+
+            // Use `OutOfMemoryMode::ReportError` from this point on
+            // to ensure we don't spill the spilled data to disk again.
+            self.oom_mode = OutOfMemoryMode::ReportError;
+
+            self.update_memory_reservation()?;
+
             ExecutionState::ReadingInput
         };
         timer.done();
diff --git a/datafusion/physical-plan/src/test.rs b/datafusion/physical-plan/src/test.rs
index f2336920b3571..c94b5a4131397 100644
--- a/datafusion/physical-plan/src/test.rs
+++ b/datafusion/physical-plan/src/test.rs
@@ -342,6 +342,7 @@ impl TestMemoryExec {
         }
 
         self.sort_information = sort_information;
+        self.cache = self.compute_properties();
 
         Ok(self)
     }
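To exercise the new out-of-memory handling end to end, the aggregation can be run under a deliberately small memory pool, mirroring the test setup above. This is a sketch under the assumption that the `RuntimeEnvBuilder` defaults leave temporary files enabled (so a Single/Final aggregation picks the spill path rather than reporting an error); the 2 kB limit, the batch size, and the import paths are illustrative and not taken from this diff:

    use std::sync::Arc;

    use datafusion_common::Result;
    use datafusion_execution::TaskContext;
    use datafusion_execution::config::SessionConfig;
    use datafusion_execution::memory_pool::FairSpillPool;
    use datafusion_execution::runtime_env::RuntimeEnvBuilder;

    /// Builds a TaskContext whose memory pool is small enough to push a grouped
    /// hash aggregation into its out-of-memory handling (spilling or early emit).
    fn memory_limited_ctx() -> Result<Arc<TaskContext>> {
        let memory_pool = Arc::new(FairSpillPool::new(2_000));
        let runtime = RuntimeEnvBuilder::new()
            .with_memory_pool(memory_pool)
            .build()?;
        Ok(Arc::new(
            TaskContext::default()
                .with_session_config(SessionConfig::new().with_batch_size(2))
                .with_runtime(Arc::new(runtime)),
        ))
    }

Executing the `AggregateExec` with such a context makes the spill metrics asserted by `assert_spill_count_metric` non-zero once the reservation can no longer grow.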