Skip to content

Commit 88e5a7e

Browse files
natashasehgalfacebook-github-bot
authored andcommitted
feat: Add support for Array types in Greatest and Least functions
Differential Revision: D86489915
1 parent 66bea5d commit 88e5a7e

File tree

3 files changed

+281
-0
lines changed

3 files changed

+281
-0
lines changed

velox/functions/prestosql/GreatestLeast.h

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,5 +109,86 @@ using LeastFunction = details::ExtremeValueFunction<TExec, T, true>;
109109

110110
template <typename TExec, typename T>
111111
using GreatestFunction = details::ExtremeValueFunction<TExec, T, false>;
112+
// Array overloads for least() and greatest()
113+
template <typename TExec, typename T, bool isLeast>
114+
struct ArrayExtremeValueFunction {
115+
VELOX_DEFINE_FUNCTION_TYPES(TExec);
116+
117+
FOLLY_ALWAYS_INLINE bool call(
118+
out_type<T>& result,
119+
const arg_type<Array<T>>& array) {
120+
if (array.size() == 0) {
121+
// Empty array returns null
122+
return false;
123+
}
124+
125+
bool hasValue = false;
126+
std::optional<typename arg_type<Array<T>>::element_t> currentValue;
127+
128+
for (const auto& element : array) {
129+
if (element.has_value()) {
130+
if (!hasValue) {
131+
currentValue = element.value();
132+
hasValue = true;
133+
} else {
134+
if constexpr (isLeast) {
135+
if (smallerThan(element.value(), currentValue.value())) {
136+
currentValue = element.value();
137+
}
138+
} else {
139+
if (greaterThan(element.value(), currentValue.value())) {
140+
currentValue = element.value();
141+
}
142+
}
143+
}
144+
}
145+
}
146+
147+
if (!hasValue) {
148+
// Array contains only nulls
149+
return false;
150+
}
151+
152+
result = currentValue.value();
153+
return true;
154+
}
155+
156+
private:
157+
template <typename K>
158+
bool greaterThan(const K& lhs, const K& rhs) const {
159+
if constexpr (std::is_same_v<K, double> || std::is_same_v<K, float>) {
160+
if (std::isnan(lhs)) {
161+
return true;
162+
}
163+
164+
if (std::isnan(rhs)) {
165+
return false;
166+
}
167+
}
168+
169+
return lhs > rhs;
170+
}
171+
172+
template <typename K>
173+
bool smallerThan(const K& lhs, const K& rhs) const {
174+
if constexpr (std::is_same_v<K, double> || std::is_same_v<K, float>) {
175+
if (std::isnan(lhs)) {
176+
return false;
177+
}
178+
179+
if (std::isnan(rhs)) {
180+
return true;
181+
}
182+
}
183+
184+
return lhs < rhs;
185+
}
186+
};
187+
188+
template <typename TExec, typename T>
189+
using ArrayLeastFunction = ArrayExtremeValueFunction<TExec, T, true>;
190+
191+
template <typename TExec, typename T>
192+
using ArrayGreatestFunction = ArrayExtremeValueFunction<TExec, T, false>;
112193

113194
} // namespace facebook::velox::functions

velox/functions/prestosql/registration/GeneralFunctionsRegistration.cpp

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,80 @@ void registerAllGreatestLeastFunctions(const std::string& prefix) {
6060
registerGreatestLeastFunction<Timestamp>(prefix);
6161
registerGreatestLeastFunction<TimestampWithTimezone>(prefix);
6262
registerGreatestLeastFunction<IPAddress>(prefix);
63+
64+
registerFunction<
65+
ParameterBinder<ArrayGreatestFunction, int8_t>,
66+
int8_t,
67+
Array<int8_t>>({prefix + "greatest"});
68+
registerFunction<
69+
ParameterBinder<ArrayGreatestFunction, int16_t>,
70+
int16_t,
71+
Array<int16_t>>({prefix + "greatest"});
72+
registerFunction<
73+
ParameterBinder<ArrayGreatestFunction, int32_t>,
74+
int32_t,
75+
Array<int32_t>>({prefix + "greatest"});
76+
registerFunction<
77+
ParameterBinder<ArrayGreatestFunction, int64_t>,
78+
int64_t,
79+
Array<int64_t>>({prefix + "greatest"});
80+
registerFunction<
81+
ParameterBinder<ArrayGreatestFunction, float>,
82+
float,
83+
Array<float>>({prefix + "greatest"});
84+
registerFunction<
85+
ParameterBinder<ArrayGreatestFunction, double>,
86+
double,
87+
Array<double>>({prefix + "greatest"});
88+
registerFunction<
89+
ParameterBinder<ArrayGreatestFunction, Varchar>,
90+
Varchar,
91+
Array<Varchar>>({prefix + "greatest"});
92+
registerFunction<
93+
ParameterBinder<ArrayGreatestFunction, Date>,
94+
Date,
95+
Array<Date>>({prefix + "greatest"});
96+
registerFunction<
97+
ParameterBinder<ArrayGreatestFunction, Timestamp>,
98+
Timestamp,
99+
Array<Timestamp>>({prefix + "greatest"});
100+
101+
registerFunction<
102+
ParameterBinder<ArrayLeastFunction, int8_t>,
103+
int8_t,
104+
Array<int8_t>>({prefix + "least"});
105+
registerFunction<
106+
ParameterBinder<ArrayLeastFunction, int16_t>,
107+
int16_t,
108+
Array<int16_t>>({prefix + "least"});
109+
registerFunction<
110+
ParameterBinder<ArrayLeastFunction, int32_t>,
111+
int32_t,
112+
Array<int32_t>>({prefix + "least"});
113+
registerFunction<
114+
ParameterBinder<ArrayLeastFunction, int64_t>,
115+
int64_t,
116+
Array<int64_t>>({prefix + "least"});
117+
registerFunction<
118+
ParameterBinder<ArrayLeastFunction, float>,
119+
float,
120+
Array<float>>({prefix + "least"});
121+
registerFunction<
122+
ParameterBinder<ArrayLeastFunction, double>,
123+
double,
124+
Array<double>>({prefix + "least"});
125+
registerFunction<
126+
ParameterBinder<ArrayLeastFunction, Varchar>,
127+
Varchar,
128+
Array<Varchar>>({prefix + "least"});
129+
registerFunction<
130+
ParameterBinder<ArrayLeastFunction, Date>,
131+
Date,
132+
Array<Date>>({prefix + "least"});
133+
registerFunction<
134+
ParameterBinder<ArrayLeastFunction, Timestamp>,
135+
Timestamp,
136+
Array<Timestamp>>({prefix + "least"});
63137
}
64138
} // namespace
65139

velox/functions/prestosql/tests/GreatestLeastTest.cpp

Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -439,3 +439,129 @@ TEST_F(GreatestLeastTest, boolean) {
439439
{true, true, true, false, std::nullopt, std::nullopt});
440440
test::assertEqualVectors(expected, result);
441441
}
442+
// Array overload tests for least() and greatest()
443+
TEST_F(GreatestLeastTest, arrayLeastBigint) {
444+
auto data = makeRowVector({
445+
makeArrayVector<int64_t>({
446+
{1, 2, 3},
447+
{-100, 50, 0},
448+
}),
449+
});
450+
451+
auto result = evaluate("least(c0)", data);
452+
auto expected = makeNullableFlatVector<int64_t>({1, -100});
453+
test::assertEqualVectors(expected, result);
454+
}
455+
456+
TEST_F(GreatestLeastTest, arrayGreatestBigint) {
457+
auto data = makeRowVector({
458+
makeArrayVector<int64_t>({
459+
{1, 2, 3},
460+
{-100, 50, 0},
461+
}),
462+
});
463+
464+
auto result = evaluate("greatest(c0)", data);
465+
auto expected = makeNullableFlatVector<int64_t>({3, 50});
466+
test::assertEqualVectors(expected, result);
467+
}
468+
469+
TEST_F(GreatestLeastTest, arrayWithNulls) {
470+
auto data = makeRowVector({
471+
makeNullableArrayVector<int64_t>({
472+
{{1, std::nullopt, 3}},
473+
{{std::nullopt, std::nullopt}},
474+
std::nullopt,
475+
}),
476+
});
477+
478+
auto result = evaluate("least(c0)", data);
479+
auto expected =
480+
makeNullableFlatVector<int64_t>({1, std::nullopt, std::nullopt});
481+
test::assertEqualVectors(expected, result);
482+
483+
result = evaluate("greatest(c0)", data);
484+
expected = makeNullableFlatVector<int64_t>({3, std::nullopt, std::nullopt});
485+
test::assertEqualVectors(expected, result);
486+
}
487+
488+
TEST_F(GreatestLeastTest, arrayEmpty) {
489+
auto data = makeRowVector({
490+
makeArrayVector<int64_t>({{}}),
491+
});
492+
493+
auto result = evaluate("least(c0)", data);
494+
auto expected = makeNullableFlatVector<int64_t>({std::nullopt});
495+
test::assertEqualVectors(expected, result);
496+
497+
result = evaluate("greatest(c0)", data);
498+
test::assertEqualVectors(expected, result);
499+
}
500+
501+
TEST_F(GreatestLeastTest, arrayDouble) {
502+
auto data = makeRowVector({
503+
makeArrayVector<double>({
504+
{1.1, 2.2, 3.3},
505+
{-100.5, 50.5, 0.0},
506+
}),
507+
});
508+
509+
auto result = evaluate("least(c0)", data);
510+
auto expected = makeNullableFlatVector<double>({1.1, -100.5});
511+
test::assertEqualVectors(expected, result);
512+
513+
result = evaluate("greatest(c0)", data);
514+
expected = makeNullableFlatVector<double>({3.3, 50.5});
515+
test::assertEqualVectors(expected, result);
516+
}
517+
518+
TEST_F(GreatestLeastTest, arrayNaN) {
519+
auto data = makeRowVector({
520+
makeArrayVector<double>({
521+
{1.0, std::nan("1"), 2.0},
522+
{std::nan("1"), -std::numeric_limits<double>::infinity(), 0.0},
523+
}),
524+
});
525+
526+
auto result = evaluate("greatest(c0)", data);
527+
EXPECT_TRUE(std::isnan(result->asFlatVector<double>()->valueAt(0)));
528+
EXPECT_TRUE(std::isnan(result->asFlatVector<double>()->valueAt(1)));
529+
530+
result = evaluate("least(c0)", data);
531+
EXPECT_EQ(result->asFlatVector<double>()->valueAt(0), 1.0);
532+
EXPECT_EQ(
533+
result->asFlatVector<double>()->valueAt(1),
534+
-std::numeric_limits<double>::infinity());
535+
}
536+
537+
TEST_F(GreatestLeastTest, arrayVarchar) {
538+
auto data = makeRowVector({
539+
makeArrayVector<StringView>({
540+
{"apple"_sv, "banana"_sv, "cherry"_sv},
541+
{"zebra"_sv, "aardvark"_sv, "monkey"_sv},
542+
}),
543+
});
544+
545+
auto result = evaluate("least(c0)", data);
546+
auto expected =
547+
makeNullableFlatVector<StringView>({"apple"_sv, "aardvark"_sv});
548+
test::assertEqualVectors(expected, result);
549+
550+
result = evaluate("greatest(c0)", data);
551+
expected = makeNullableFlatVector<StringView>({"cherry"_sv, "zebra"_sv});
552+
test::assertEqualVectors(expected, result);
553+
}
554+
555+
TEST_F(GreatestLeastTest, arrayDate) {
556+
auto data = makeRowVector({
557+
makeArrayVector<int32_t>({{0, 5, -10}, {100, -50, 25}}, ARRAY(DATE())),
558+
});
559+
560+
auto result = evaluate("least(c0)", data);
561+
auto expected = makeNullableFlatVector<int32_t>({-10, -50}, DATE());
562+
test::assertEqualVectors(expected, result);
563+
564+
result = evaluate("greatest(c0)", data);
565+
expected = makeNullableFlatVector<int32_t>({5, 100}, DATE());
566+
test::assertEqualVectors(expected, result);
567+
}

0 commit comments

Comments
 (0)