native/spark-expr/src/bloom_filter/spark_bloom_filter.rs (20 additions, 4 deletions)
@@ -160,11 +160,27 @@ impl SparkBloomFilter {
     }
 
     pub fn merge_filter(&mut self, other: &[u8]) {
+        // Extract bits data if other is in Spark's full serialization format.
+        // We need to compute the expected size and extract data before borrowing self.bits mutably.
+        let expected_bits_size = self.bits.byte_size();
+        const SPARK_HEADER_SIZE: usize = 12; // version (4) + num_hash_functions (4) + num_words (4)
+
+        let bits_data = if other.len() == SPARK_HEADER_SIZE + expected_bits_size {
+            // This is Spark's full format, extract bits data (skip header)
+            &other[SPARK_HEADER_SIZE..]
+        } else {
+            // This is already just bits data (Comet format)
+            other
+        };
+
         assert_eq!(
-            other.len(),
-            self.bits.byte_size(),
-            "Cannot merge SparkBloomFilters with different lengths."
+            bits_data.len(),
+            expected_bits_size,
+            "Cannot merge SparkBloomFilters with different lengths. Expected {} bytes, got {} bytes (full buffer size: {} bytes)",
+            expected_bits_size,
+            bits_data.len(),
+            other.len()
         );
-        self.bits.merge_bits(other);
+        self.bits.merge_bits(bits_data);
     }
 }
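For readers unfamiliar with the two buffer shapes being distinguished above, the sketch below builds a buffer in the full Spark layout that merge_filter now detects: a 12-byte header (version, num_hash_functions, num_words) followed by the raw bit words. This is an illustration only, not code from this PR; the helper name is made up, and the big-endian field encoding is an assumption based on Spark writing these fields through Java's DataOutputStream.

```rust
// Illustrative sketch only (not part of the Comet codebase). Builds a buffer
// shaped like Spark's full bloom filter serialization, per the header layout
// described in merge_filter(): version (4) + num_hash_functions (4) +
// num_words (4), then the bit words. Big-endian encoding is an assumption.
fn spark_format_buffer(num_hash_functions: i32, words: &[u64]) -> Vec<u8> {
    let mut buf = Vec::with_capacity(12 + words.len() * 8);
    buf.extend_from_slice(&1i32.to_be_bytes()); // version (assumed V1)
    buf.extend_from_slice(&num_hash_functions.to_be_bytes()); // num_hash_functions
    buf.extend_from_slice(&(words.len() as i32).to_be_bytes()); // num_words
    for w in words {
        buf.extend_from_slice(&w.to_be_bytes()); // the bits themselves
    }
    buf
}

fn main() {
    let words = [0xDEAD_BEEF_u64, 0x0123_4567_89AB_CDEF];
    let expected_bits_size = words.len() * 8; // what self.bits.byte_size() would report
    let full = spark_format_buffer(3, &words);

    // merge_filter's detection rule: a buffer exactly SPARK_HEADER_SIZE bytes
    // longer than the expected bits payload is treated as Spark's full format
    // and the header is skipped; anything else is treated as raw bits.
    assert_eq!(full.len(), 12 + expected_bits_size);
    assert_eq!(full[12..].len(), expected_bits_size);
}
```

Note that the detection is purely length-based, which is also why the new assert message reports both the payload size and the full buffer size.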
@@ -1069,7 +1069,13 @@ trait CometBaseAggregate {
     val multiMode = modes.size > 1
     // For a final mode HashAggregate, we only need to transform the HashAggregate
     // if there is Comet partial aggregation.
-    val sparkFinalMode = modes.contains(Final) && findCometPartialAgg(aggregate.child).isEmpty
+    // Exception: BloomFilterAggregate supports Spark partial / Comet final because
+    // merge_filter() handles Spark's serialization format (12-byte header + bits).
+    val hasBloomFilterAgg = aggregate.aggregateExpressions.exists(expr =>
+      expr.aggregateFunction.getClass.getSimpleName == "BloomFilterAggregate")
+    val sparkFinalMode = modes.contains(Final) &&
+      findCometPartialAgg(aggregate.child).isEmpty &&
+      !hasBloomFilterAgg
 
     if (multiMode || sparkFinalMode) {
       return None
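As a complement to the construction sketch above, here is a hedged sketch of reading that 12-byte header back out of a Spark-format buffer, which is the format the planner comment references. Field meanings follow the header comment in merge_filter(); the function name and the big-endian encoding are again assumptions for illustration, not Comet APIs.

```rust
// Illustrative sketch only: split a Spark-format buffer into its header
// fields and bits payload. Field order follows the merge_filter() comment
// (version, num_hash_functions, num_words); big-endian encoding is assumed.
fn parse_spark_header(buf: &[u8]) -> Option<(i32, i32, i32, &[u8])> {
    if buf.len() < 12 {
        return None;
    }
    let int_at = |i: usize| i32::from_be_bytes(buf[i..i + 4].try_into().unwrap());
    let (version, num_hash_functions, num_words) = (int_at(0), int_at(4), int_at(8));
    let bits = &buf[12..];
    // Each word is 8 bytes, so the payload length must match num_words * 8.
    if bits.len() != num_words as usize * 8 {
        return None;
    }
    Some((version, num_hash_functions, num_words, bits))
}

fn main() {
    let mut buf = Vec::new();
    buf.extend_from_slice(&1i32.to_be_bytes()); // version
    buf.extend_from_slice(&3i32.to_be_bytes()); // num_hash_functions
    buf.extend_from_slice(&1i32.to_be_bytes()); // num_words
    buf.extend_from_slice(&0u64.to_be_bytes()); // one word of bit data
    let parsed = parse_spark_header(&buf).map(|(v, h, w, bits)| (v, h, w, bits.len()));
    assert_eq!(parsed, Some((1, 3, 1, 8)));
}
```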
spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala (64 additions, 1 deletion)
@@ -36,7 +36,8 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateMode, Bloom
 import org.apache.spark.sql.comet._
 import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle, CometShuffleExchangeExec}
 import org.apache.spark.sql.execution.{CollectLimitExec, ProjectExec, SQLExecution, UnionExec}
-import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, BroadcastQueryStageExec}
+import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, BroadcastQueryStageExec, QueryStageExec}
+import org.apache.spark.sql.execution.aggregate.HashAggregateExec
 import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
 import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec}
 import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, BroadcastNestedLoopJoinExec, CartesianProductExec, SortMergeJoinExec}
@@ -1149,6 +1150,68 @@ class CometExecSuite extends CometTestBase {
     spark.sessionState.functionRegistry.dropFunction(funcId_bloom_filter_agg)
   }
 
+  test("bloom_filter_agg - Spark partial / Comet final merge") {
+    // This test exercises the merge_filter() fix that handles Spark's full serialization
+    // format (12-byte header + bits) when merging from Spark partial to Comet final aggregates.
+    val funcId_bloom_filter_agg = new FunctionIdentifier("bloom_filter_agg")
+    spark.sessionState.functionRegistry.registerFunction(
+      funcId_bloom_filter_agg,
+      new ExpressionInfo(classOf[BloomFilterAggregate].getName, "bloom_filter_agg"),
+      (children: Seq[Expression]) =>
+        children.size match {
+          case 1 => new BloomFilterAggregate(children.head)
+          case 2 => new BloomFilterAggregate(children.head, children(1))
+          case 3 => new BloomFilterAggregate(children.head, children(1), children(2))
+        })
+
+    // Helper to count operators of a given class in the plan, recursing into AQE query stages
+    def countOperators(plan: SparkPlan, opClass: Class[_]): Int = {
+      stripAQEPlan(plan).collect {
+        case stage: QueryStageExec =>
+          countOperators(stage.plan, opClass)
+        case op if opClass.isInstance(op) => 1
+      }.sum
+    }
+
+    withParquetTable(
+      (0 until 1000)
+        .map(_ => (Random.nextInt(1000), Random.nextInt(100))),
+      "tbl") {
+
+      withSQLConf(
+        // Disable Comet partial aggregates to force Spark partial / Comet final scenario
+        CometConf.COMET_ENABLE_PARTIAL_HASH_AGGREGATE.key -> "false",
+        CometConf.COMET_EXEC_AGGREGATE_ENABLED.key -> "true",
+        CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true") {
+
+        val df = sql(
+          "SELECT bloom_filter_agg(cast(_2 as long), cast(1000 as long)) FROM tbl GROUP BY _1")
+
+        // Verify the query executes successfully (tests merge_filter compatibility)
+        checkSparkAnswer(df)
+
+        // Verify we have Spark partial aggregates and Comet final aggregates
+        val plan = stripAQEPlan(df.queryExecution.executedPlan)
+        val sparkPartialAggs = plan.collect {
+          case agg: HashAggregateExec if agg.aggregateExpressions.exists(_.mode == Partial) => agg
+        }
+        val cometFinalAggs = plan.collect {
+          case agg: CometHashAggregateExec if agg.aggregateExpressions.exists(_.mode == Final) =>
+            agg
+        }
+
+        assert(
+          sparkPartialAggs.nonEmpty,
+          s"Expected Spark partial aggregates but found none. Plan: $plan")
+        assert(
+          cometFinalAggs.nonEmpty,
+          s"Expected Comet final aggregates but found none. Plan: $plan")
+      }
+    }
+
+    spark.sessionState.functionRegistry.dropFunction(funcId_bloom_filter_agg)
+  }
+
   test("sort (non-global)") {
     withParquetTable((0 until 5).map(i => (i, i + 1)), "tbl") {
       val df = sql("SELECT * FROM tbl").sortWithinPartitions($"_1".desc)