Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 82 additions & 0 deletions velox/functions/prestosql/GreatestLeast.h
Original file line number Diff line number Diff line change
Expand Up @@ -109,5 +109,87 @@ using LeastFunction = details::ExtremeValueFunction<TExec, T, true>;

template <typename TExec, typename T>
using GreatestFunction = details::ExtremeValueFunction<TExec, T, false>;
// Array overloads for least() and greatest()
/// Reduces an array to a single element: the smallest one when `isLeast` is
/// true, the largest one otherwise.
///
/// Null elements are skipped. `call` returns false (i.e. a null result) when
/// no non-null element exists, which covers both an empty array and an array
/// made up entirely of nulls. For float/double elements NaN is ordered above
/// every other value, so it wins for "greatest" and loses for "least".
template <typename TExec, typename T, bool isLeast>
struct ArrayExtremeValueFunction {
  VELOX_DEFINE_FUNCTION_TYPES(TExec);

  /// @param result Receives the selected element when the return value is
  /// true; left untouched otherwise.
  /// @param array Input array whose elements may individually be null.
  /// @return true if `result` was set, false to signal a null result.
  FOLLY_ALWAYS_INLINE bool call(
      out_type<T>& result,
      const arg_type<Array<T>>& array) {
    // Best element seen so far; stays disengaged until the first non-null
    // element is met. An empty array never enters the loop body.
    std::optional<typename arg_type<Array<T>>::element_t> best;

    for (const auto& item : array) {
      if (!item.has_value()) {
        continue;
      }

      // The first non-null element is taken unconditionally; later ones must
      // strictly beat the current best.
      if (!best.has_value() || beats(item.value(), best.value())) {
        best = item.value();
      }
    }

    if (!best.has_value()) {
      return false;
    }

    result = best.value();
    return true;
  }

 private:
  /// Returns true when `candidate` should replace `incumbent` as the running
  /// extreme value for the direction selected by `isLeast`.
  template <typename K>
  static bool beats(const K& candidate, const K& incumbent) {
    if constexpr (std::is_same_v<K, double> || std::is_same_v<K, float>) {
      // NaN sorts as the largest value: a NaN candidate always wins for
      // "greatest" and never wins for "least".
      if (std::isnan(candidate)) {
        return !isLeast;
      }

      // Symmetrically, a NaN incumbent is always displaced for "least" and
      // never displaced for "greatest".
      if (std::isnan(incumbent)) {
        return isLeast;
      }
    }

    if constexpr (isLeast) {
      return candidate < incumbent;
    } else {
      return candidate > incumbent;
    }
  }
};

/// Array flavour of least(): ArrayExtremeValueFunction fixed to pick the
/// smallest element.
template <typename TExec, typename T>
struct ArrayLeastFunction : ArrayExtremeValueFunction<TExec, T, true> {};

/// Array flavour of greatest(): ArrayExtremeValueFunction fixed to pick the
/// largest element.
template <typename TExec, typename T>
struct ArrayGreatestFunction : ArrayExtremeValueFunction<TExec, T, false> {};

} // namespace facebook::velox::functions
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,73 @@ void registerAllGreatestLeastFunctions(const std::string& prefix) {
registerGreatestLeastFunction<Timestamp>(prefix);
registerGreatestLeastFunction<TimestampWithTimezone>(prefix);
registerGreatestLeastFunction<IPAddress>(prefix);

// Register array overloads
registerFunction<
ParameterBinder<ArrayGreatestFunction, int8_t>,
int8_t,
Array<int8_t>>({prefix + "greatest"});
registerFunction<
ParameterBinder<ArrayGreatestFunction, int16_t>,
int16_t,
Array<int16_t>>({prefix + "greatest"});
registerFunction<
ParameterBinder<ArrayGreatestFunction, int32_t>,
int32_t,
Array<int32_t>>({prefix + "greatest"});
registerFunction<
ParameterBinder<ArrayGreatestFunction, int64_t>,
int64_t,
Array<int64_t>>({prefix + "greatest"});
registerFunction<
ParameterBinder<ArrayGreatestFunction, float>,
float,
Array<float>>({prefix + "greatest"});
registerFunction<
ParameterBinder<ArrayGreatestFunction, double>,
double,
Array<double>>({prefix + "greatest"});
registerFunction<
ParameterBinder<ArrayGreatestFunction, Varchar>,
Varchar,
Array<Varchar>>({prefix + "greatest"});
registerFunction<
ParameterBinder<ArrayGreatestFunction, Timestamp>,
Timestamp,
Array<Timestamp>>({prefix + "greatest"});

registerFunction<
ParameterBinder<ArrayLeastFunction, int8_t>,
int8_t,
Array<int8_t>>({prefix + "least"});
registerFunction<
ParameterBinder<ArrayLeastFunction, int16_t>,
int16_t,
Array<int16_t>>({prefix + "least"});
registerFunction<
ParameterBinder<ArrayLeastFunction, int32_t>,
int32_t,
Array<int32_t>>({prefix + "least"});
registerFunction<
ParameterBinder<ArrayLeastFunction, int64_t>,
int64_t,
Array<int64_t>>({prefix + "least"});
registerFunction<
ParameterBinder<ArrayLeastFunction, float>,
float,
Array<float>>({prefix + "least"});
registerFunction<
ParameterBinder<ArrayLeastFunction, double>,
double,
Array<double>>({prefix + "least"});
registerFunction<
ParameterBinder<ArrayLeastFunction, Varchar>,
Varchar,
Array<Varchar>>({prefix + "least"});
registerFunction<
ParameterBinder<ArrayLeastFunction, Timestamp>,
Timestamp,
Array<Timestamp>>({prefix + "least"});
}
} // namespace

Expand Down
112 changes: 112 additions & 0 deletions velox/functions/prestosql/tests/GreatestLeastTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -439,3 +439,115 @@ TEST_F(GreatestLeastTest, boolean) {
{true, true, true, false, std::nullopt, std::nullopt});
test::assertEqualVectors(expected, result);
}
// Array overload tests for least() and greatest()
TEST_F(GreatestLeastTest, arrayLeastBigint) {
  // One BIGINT array per row; least() is expected to yield each row's minimum.
  const auto input = makeRowVector({
      makeArrayVector<int64_t>({
          {1, 2, 3},
          {-100, 50, 0},
      }),
  });

  const auto actual = evaluate("least(c0)", input);
  const auto minima = makeNullableFlatVector<int64_t>({1, -100});
  test::assertEqualVectors(minima, actual);
}

TEST_F(GreatestLeastTest, arrayGreatestBigint) {
  // One BIGINT array per row; greatest() is expected to yield each row's
  // maximum.
  const auto input = makeRowVector({
      makeArrayVector<int64_t>({
          {1, 2, 3},
          {-100, 50, 0},
      }),
  });

  const auto actual = evaluate("greatest(c0)", input);
  const auto maxima = makeNullableFlatVector<int64_t>({3, 50});
  test::assertEqualVectors(maxima, actual);
}

TEST_F(GreatestLeastTest, arrayWithNulls) {
  // Row 0: array with a null element in the middle.
  // Row 1: array whose elements are all null.
  // Row 2: the array itself is null.
  const auto input = makeRowVector({
      makeNullableArrayVector<int64_t>({
          {{1, std::nullopt, 3}},
          {{std::nullopt, std::nullopt}},
          std::nullopt,
      }),
  });

  // Null elements are ignored; rows without any non-null element give null.
  const auto leastActual = evaluate("least(c0)", input);
  const auto leastExpected =
      makeNullableFlatVector<int64_t>({1, std::nullopt, std::nullopt});
  test::assertEqualVectors(leastExpected, leastActual);

  const auto greatestActual = evaluate("greatest(c0)", input);
  const auto greatestExpected =
      makeNullableFlatVector<int64_t>({3, std::nullopt, std::nullopt});
  test::assertEqualVectors(greatestExpected, greatestActual);
}

TEST_F(GreatestLeastTest, arrayEmpty) {
  // A single row holding a zero-length array.
  const auto input = makeRowVector({
      makeArrayVector<int64_t>({{}}),
  });

  // Both functions are expected to produce null for an empty array.
  const auto singleNull = makeNullableFlatVector<int64_t>({std::nullopt});

  for (const char* expression : {"least(c0)", "greatest(c0)"}) {
    const auto actual = evaluate(expression, input);
    test::assertEqualVectors(singleNull, actual);
  }
}

TEST_F(GreatestLeastTest, arrayDouble) {
  // DOUBLE arrays containing only ordinary (non-NaN) values.
  const auto input = makeRowVector({
      makeArrayVector<double>({
          {1.1, 2.2, 3.3},
          {-100.5, 50.5, 0.0},
      }),
  });

  const auto leastActual = evaluate("least(c0)", input);
  const auto leastExpected = makeNullableFlatVector<double>({1.1, -100.5});
  test::assertEqualVectors(leastExpected, leastActual);

  const auto greatestActual = evaluate("greatest(c0)", input);
  const auto greatestExpected = makeNullableFlatVector<double>({3.3, 50.5});
  test::assertEqualVectors(greatestExpected, greatestActual);
}

TEST_F(GreatestLeastTest, arrayNaN) {
  // NaN is expected to order above every other double, including +/-infinity:
  // it must be picked by greatest() and never by least().
  auto data = makeRowVector({
      makeArrayVector<double>({
          {1.0, std::nan("1"), 2.0},
          {std::nan("1"), -std::numeric_limits<double>::infinity(), 0.0},
      }),
  });

  // NaN never compares equal to itself, so the rows are inspected one by one
  // with std::isnan instead of going through a whole-vector equality check.
  auto greatestResult = evaluate("greatest(c0)", data);
  auto* greatestFlat = greatestResult->asFlatVector<double>();
  // asFlatVector() yields nullptr when the result is not a flat double
  // vector; stop here instead of dereferencing it.
  ASSERT_TRUE(greatestFlat != nullptr);
  ASSERT_EQ(2, greatestFlat->size());
  // valueAt() is only meaningful for non-null rows.
  ASSERT_FALSE(greatestFlat->isNullAt(0));
  ASSERT_FALSE(greatestFlat->isNullAt(1));
  EXPECT_TRUE(std::isnan(greatestFlat->valueAt(0)));
  EXPECT_TRUE(std::isnan(greatestFlat->valueAt(1)));

  auto leastResult = evaluate("least(c0)", data);
  auto* leastFlat = leastResult->asFlatVector<double>();
  ASSERT_TRUE(leastFlat != nullptr);
  ASSERT_EQ(2, leastFlat->size());
  ASSERT_FALSE(leastFlat->isNullAt(0));
  ASSERT_FALSE(leastFlat->isNullAt(1));
  EXPECT_EQ(leastFlat->valueAt(0), 1.0);
  EXPECT_EQ(
      leastFlat->valueAt(1), -std::numeric_limits<double>::infinity());
}

TEST_F(GreatestLeastTest, arrayVarchar) {
  // VARCHAR arrays; the expected values below follow the ordering of the
  // strings, not their position in the array.
  const auto input = makeRowVector({
      makeArrayVector<StringView>({
          {"apple"_sv, "banana"_sv, "cherry"_sv},
          {"zebra"_sv, "aardvark"_sv, "monkey"_sv},
      }),
  });

  const auto leastActual = evaluate("least(c0)", input);
  const auto leastExpected =
      makeNullableFlatVector<StringView>({"apple"_sv, "aardvark"_sv});
  test::assertEqualVectors(leastExpected, leastActual);

  const auto greatestActual = evaluate("greatest(c0)", input);
  const auto greatestExpected =
      makeNullableFlatVector<StringView>({"cherry"_sv, "zebra"_sv});
  test::assertEqualVectors(greatestExpected, greatestActual);
}
Loading