Skip to content

Commit f67f4bf

Browse files
committed
Fix array_union on NaN
1 parent 8fd999d commit f67f4bf

File tree

2 files changed

+90
-0
lines changed

2 files changed

+90
-0
lines changed
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
/*
2+
* Copyright (c) Facebook, Inc. and its affiliates.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
#pragma once
17+
18+
#include "velox/expression/ComplexViewTypes.h"
19+
#include "velox/functions/Udf.h"
20+
#include "velox/functions/prestosql/CheckedArithmetic.h"
21+
#include "velox/type/Conversions.h"
22+
23+
namespace facebook::velox::functions::sparksql {
24+
25+
/// This class implements the array union function.
26+
///
27+
/// DEFINITION:
28+
/// array_union(x, y) → array
29+
/// Returns an array of the elements in the union of x and y, without
30+
/// duplicates.
31+
template <typename T>
32+
struct ArrayUnionFunction {
33+
VELOX_DEFINE_FUNCTION_TYPES(T)
34+
35+
// Fast path for primitives.
36+
template <typename Out, typename In>
37+
void call(Out& out, const In& inputArray1, const In& inputArray2) {
38+
folly::F14FastSet<typename In::element_t> elementSet;
39+
bool nullAdded = false;
40+
bool nanAdded = false;
41+
auto addItems = [&](auto& inputArray) {
42+
for (const auto& item : inputArray) {
43+
if (item.has_value()) {
44+
if constexpr (
45+
std::is_same_v<In, arg_type<Array<float>>> ||
46+
std::is_same_v<In, arg_type<Array<double>>>) {
47+
bool isNaN = std::isnan(item.value());
48+
if ((!nanAdded || !isNaN) &&
49+
elementSet.insert(item.value()).second) {
50+
auto& newItem = out.add_item();
51+
newItem = item.value();
52+
}
53+
if (!nanAdded && isNaN) {
54+
nanAdded = true;
55+
}
56+
} else if (elementSet.insert(item.value()).second) {
57+
auto& newItem = out.add_item();
58+
newItem = item.value();
59+
}
60+
} else if (!nullAdded) {
61+
nullAdded = true;
62+
out.add_null();
63+
}
64+
}
65+
};
66+
addItems(inputArray1);
67+
addItems(inputArray2);
68+
}
69+
};
70+
} // namespace facebook::velox::functions::sparksql

velox/functions/sparksql/Register.cpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
#include "velox/functions/prestosql/JsonFunctions.h"
2525
#include "velox/functions/prestosql/Rand.h"
2626
#include "velox/functions/prestosql/StringFunctions.h"
27+
#include "velox/functions/sparksql/ArrayFunctions.h"
2728
#include "velox/functions/sparksql/ArraySort.h"
2829
#include "velox/functions/sparksql/Bitwise.h"
2930
#include "velox/functions/sparksql/DateTimeFunctions.h"
@@ -95,6 +96,12 @@ void registerExpressionGeneralFunctions(const std::string& prefix) {
9596
makeUnscaledValue());
9697
}
9798

99+
template <typename T>
100+
inline void registerArrayUnionFunctions(const std::string& prefix) {
101+
registerFunction<sparksql::ArrayUnionFunction, Array<T>, Array<T>, Array<T>>(
102+
{prefix + "array_union"});
103+
}
104+
98105
void registerFunctions(const std::string& prefix) {
99106
registerAllSpecialFormGeneralFunctions();
100107
registerFunction<RandFunction, double>({prefix + "rand"});
@@ -265,6 +272,19 @@ void registerFunctions(const std::string& prefix) {
265272

266273
// Register expression general functions.
267274
registerExpressionGeneralFunctions(prefix);
275+
276+
// Array union.
277+
registerArrayUnionFunctions<int8_t>(prefix);
278+
registerArrayUnionFunctions<int16_t>(prefix);
279+
registerArrayUnionFunctions<int32_t>(prefix);
280+
registerArrayUnionFunctions<int64_t>(prefix);
281+
registerArrayUnionFunctions<int128_t>(prefix);
282+
registerArrayUnionFunctions<float>(prefix);
283+
registerArrayUnionFunctions<double>(prefix);
284+
registerArrayUnionFunctions<bool>(prefix);
285+
registerArrayUnionFunctions<Timestamp>(prefix);
286+
registerArrayUnionFunctions<Date>(prefix);
287+
registerArrayUnionFunctions<Varbinary>(prefix);
268288
}
269289

270290
} // namespace sparksql

0 commit comments

Comments
 (0)