Skip to content

Commit 80deccd

Browse files
committed
Add Spark atan2 function (7113)
1 parent bec32c0 commit 80deccd

File tree

4 files changed

+22
-1
lines changed

4 files changed

+22
-1
lines changed

velox/docs/functions/spark/math.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,10 @@ Mathematical Functions
1818
1919
Returns inverse hyperbolic sine of ``x``.
2020

21+
.. spark:function:: atan2(x, y) -> double
22+
23+
Returns the angle in radians between the positive x-axis of a plane and the point given by the coordinates(x, y).
24+
2125
.. spark:function:: atanh(x) -> double
2226
2327
Returns inverse hyperbolic tangent of ``x``.

velox/functions/sparksql/Arithmetic.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -280,4 +280,13 @@ struct Log10Function {
280280
return true;
281281
}
282282
};
283+
284+
template <typename T>
285+
struct Atan2Function {
286+
template <typename TInput>
287+
FOLLY_ALWAYS_INLINE void call(TInput& result, TInput y, TInput x) {
288+
result = std::atan2(y + 0.0, x + 0.0);
289+
}
290+
};
291+
283292
} // namespace facebook::velox::functions::sparksql

velox/functions/sparksql/RegisterArithmetic.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,7 @@ void registerArithmeticFunctions(const std::string& prefix) {
9595
VELOX_REGISTER_VECTOR_FUNCTION(udf_decimal_sub, prefix + "subtract");
9696
VELOX_REGISTER_VECTOR_FUNCTION(udf_decimal_mul, prefix + "multiply");
9797
VELOX_REGISTER_VECTOR_FUNCTION(udf_decimal_div, prefix + "divide");
98+
registerFunction<Atan2Function, double, double, double>({prefix + "atan2"});
9899
}
99100

100101
} // namespace facebook::velox::functions::sparksql

velox/functions/sparksql/tests/ArithmeticTest.cpp

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -376,6 +376,14 @@ TEST_F(ArithmeticTest, cot) {
376376
EXPECT_EQ(cot(0), 1 / std::tan(0));
377377
}
378378

379+
TEST_F(ArithmeticTest, atan2) {
380+
const auto atan2 = [&](std::optional<double> y, std::optional<double> x) {
381+
return evaluateOnce<double>("atan2(c0, c1)", y, x);
382+
};
383+
384+
EXPECT_EQ(atan2(0, 0), 0.0);
385+
}
386+
379387
class LogNTest : public SparkFunctionBaseTest {
380388
protected:
381389
static constexpr float kInf = std::numeric_limits<double>::infinity();
@@ -400,6 +408,5 @@ TEST_F(LogNTest, log10) {
400408
EXPECT_EQ(log10(-1.0), std::nullopt);
401409
EXPECT_EQ(log10(kInf), kInf);
402410
}
403-
404411
} // namespace
405412
} // namespace facebook::velox::functions::sparksql::test

0 commit comments

Comments
 (0)