Commit 4cdbb77

Author: weixiuli (committed)
Support null values in the bloom_filter_agg Spark aggregate function
Parent: 8a6ef2b
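
In effect, null input rows are now skipped instead of raising "First argument of bloom_filter_agg cannot be null", matching the usual Spark convention that aggregate functions ignore NULL inputs. Below is a minimal, library-free sketch of the resulting behavior; the std::optional column and the std::unordered_set standing in for the bloom-filter accumulator are illustrative assumptions, not Velox types.

// Illustrative sketch only: std::optional<int64_t> models a nullable BIGINT
// column and std::unordered_set stands in for the bloom-filter accumulator;
// neither is a Velox type.
#include <cstdint>
#include <iostream>
#include <optional>
#include <unordered_set>
#include <vector>

int main() {
  std::vector<std::optional<int64_t>> column = {1, std::nullopt, 2, std::nullopt, 2};
  std::unordered_set<int64_t> accumulator;

  for (const auto& value : column) {
    // Before this commit, a null row raised a user error;
    // after it, null rows are simply skipped.
    if (!value.has_value()) {
      continue;
    }
    accumulator.insert(*value);
  }

  // Only the non-null values (1 and 2) were added.
  std::cout << "values inserted: " << accumulator.size() << "\n";
  return 0;
}

Null rows therefore contribute nothing to the filter, which is what the reworked nullInput test below exercises.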

2 files changed (+19, -28 lines)

velox/functions/sparksql/aggregates/BloomFilterAggAggregate.cpp

Lines changed: 13 additions & 20 deletions
@@ -85,14 +85,6 @@ class BloomFilterAggAggregate : public exec::Aggregate {
     }
   }
 
-  static FOLLY_ALWAYS_INLINE void checkBloomFilterNotNull(
-      DecodedVector& decoded,
-      vector_size_t idx) {
-    VELOX_USER_CHECK(
-        !decoded.isNullAt(idx),
-        "First argument of bloom_filter_agg cannot be null");
-  }
-
   void addRawInput(
       char** groups,
       const SelectivityVector& rows,
@@ -102,8 +94,8 @@ class BloomFilterAggAggregate : public exec::Aggregate {
     computeCapacity();
     auto mayHaveNulls = decodedRaw_.mayHaveNulls();
     rows.applyToSelected([&](vector_size_t row) {
-      if (mayHaveNulls) {
-        checkBloomFilterNotNull(decodedRaw_, row);
+      if (mayHaveNulls && decodedRaw_.isNullAt(row)) {
+        return;
       }
       auto group = groups[row];
       auto tracker = trackRowSize(group);
@@ -144,17 +136,18 @@ class BloomFilterAggAggregate : public exec::Aggregate {
     accumulator->init(capacity_);
     if (decodedRaw_.isConstantMapping()) {
       // All values are same, just do for the first.
-      checkBloomFilterNotNull(decodedRaw_, 0);
-      accumulator->insert(decodedRaw_.valueAt<int64_t>(0));
-      return;
-    }
-    auto mayHaveNulls = decodedRaw_.mayHaveNulls();
-    rows.applyToSelected([&](vector_size_t row) {
-      if (mayHaveNulls) {
-        checkBloomFilterNotNull(decodedRaw_, row);
+      if (!decodedRaw_.isNullAt(0)) {
+        accumulator->insert(decodedRaw_.valueAt<int64_t>(0));
       }
-      accumulator->insert(decodedRaw_.valueAt<int64_t>(row));
-    });
+    } else {
+      auto mayHaveNulls = decodedRaw_.mayHaveNulls();
+      rows.applyToSelected([&](vector_size_t row) {
+        if (mayHaveNulls && decodedRaw_.isNullAt(row)) {
+          return;
+        }
+        accumulator->insert(decodedRaw_.valueAt<int64_t>(row));
+      });
+    };
   }
 
   void addSingleGroupIntermediateResults(

velox/functions/sparksql/aggregates/tests/BloomFilterAggAggregateTest.cpp

Lines changed: 6 additions & 8 deletions
@@ -71,14 +71,12 @@ TEST_F(BloomFilterAggAggregateTest, emptyInput) {
   testAggregations(vectors, {}, {"bloom_filter_agg(c0, 5, 64)"}, expected);
 }
 
-TEST_F(BloomFilterAggAggregateTest, nullBloomFilter) {
-  auto vectors = {makeRowVector({makeAllNullFlatVector<int64_t>(2)})};
-  auto expectedFake = {makeRowVector(
-      {makeNullableFlatVector<StringView>({std::nullopt}, VARBINARY())})};
-  VELOX_ASSERT_THROW(
-      testAggregations(
-          vectors, {}, {"bloom_filter_agg(c0, 5, 64)"}, expectedFake),
-      "First argument of bloom_filter_agg cannot be null");
+TEST_F(BloomFilterAggAggregateTest, nullInput) {
+  auto vectors = {makeRowVector(
+      {makeFlatVector<int64_t>(100, [](vector_size_t row) { return row % 9; }),
+       makeAllNullFlatVector<int64_t>(1)})};
+  auto expected = {makeRowVector({getSerializedBloomFilter(4)})};
+  testAggregations(vectors, {}, {"bloom_filter_agg(c0, 5, 64)"}, expected);
 }
 
 TEST_F(BloomFilterAggAggregateTest, config) {
