1818#include " velox/expression/FunctionSignature.h"
1919#include " velox/vector/FlatVector.h"
2020
21+ #include " velox/functions/prestosql/aggregates/SumAggregate.h"
22+
2123namespace facebook ::velox::functions::aggregate::sparksql {
2224
2325struct DecimalSum {
@@ -377,9 +379,19 @@ class DecimalSumAggregate : public exec::Aggregate {
377379 TypePtr sumType_;
378380};
379381
380- exec::AggregateRegistrationResult registerDecimalSumAggregate (
382+ exec::AggregateRegistrationResult registerSumAggregate (
381383 const std::string& name) {
382384 std::vector<std::shared_ptr<exec::AggregateFunctionSignature>> signatures{
385+ exec::AggregateFunctionSignatureBuilder ()
386+ .returnType (" real" )
387+ .intermediateType (" double" )
388+ .argumentType (" real" )
389+ .build (),
390+ exec::AggregateFunctionSignatureBuilder ()
391+ .returnType (" double" )
392+ .intermediateType (" double" )
393+ .argumentType (" double" )
394+ .build (),
383395 exec::AggregateFunctionSignatureBuilder ()
384396 .integerVariable (" a_precision" )
385397 .integerVariable (" a_scale" )
@@ -388,7 +400,16 @@ exec::AggregateRegistrationResult registerDecimalSumAggregate(
388400 .argumentType (" DECIMAL(a_precision, a_scale)" )
389401 .intermediateType (" ROW(DECIMAL(r_precision, r_scale), boolean)" )
390402 .returnType (" DECIMAL(r_precision, r_scale)" )
391- .build ()};
403+ .build (),
404+ };
405+
406+ for (const auto & inputType : {" tinyint" , " smallint" , " integer" , " bigint" }) {
407+ signatures.push_back (exec::AggregateFunctionSignatureBuilder ()
408+ .returnType (" bigint" )
409+ .intermediateType (" bigint" )
410+ .argumentType (inputType)
411+ .build ());
412+ }
392413
393414 return exec::registerAggregateFunction (
394415 name,
@@ -401,32 +422,73 @@ exec::AggregateRegistrationResult registerDecimalSumAggregate(
401422 -> std::unique_ptr<exec::Aggregate> {
402423 VELOX_CHECK_EQ (argTypes.size (), 1 , " {} takes only one argument" , name);
403424 auto & inputType = argTypes[0 ];
404- auto sumType =
405- exec::isPartialOutput (step) ? resultType->childAt (0 ) : resultType;
406425 switch (inputType->kind ()) {
426+ case TypeKind::TINYINT:
427+ return std::make_unique<velox::aggregate::prestosql::
428+ SumAggregate<int8_t , int64_t , int64_t >>(
429+ BIGINT ());
430+ case TypeKind::SMALLINT:
431+ return std::make_unique<
432+ velox::aggregate::prestosql::
433+ SumAggregate<int16_t , int64_t , int64_t >>(BIGINT ());
434+ case TypeKind::INTEGER:
435+ return std::make_unique<
436+ velox::aggregate::prestosql::
437+ SumAggregate<int32_t , int64_t , int64_t >>(BIGINT ());
407438 case TypeKind::BIGINT: {
408- DCHECK (exec::isRawInput (step));
409439 if (inputType->isShortDecimal ()) {
440+ auto sumType = exec::isPartialOutput (step)
441+ ? resultType->childAt (0 )
442+ : resultType;
410443 if (sumType->isShortDecimal ()) {
411444 return std::make_unique<DecimalSumAggregate<int64_t , int64_t >>(
412445 resultType, sumType);
413446 } else if (sumType->isLongDecimal ()) {
414447 return std::make_unique<DecimalSumAggregate<int64_t , int128_t >>(
415448 resultType, sumType);
416449 }
450+ VELOX_UNREACHABLE ();
417451 }
452+ return std::make_unique<
453+ velox::aggregate::prestosql::
454+ SumAggregate<int64_t , int64_t , int64_t >>(BIGINT ());
418455 }
419- case TypeKind::HUGEINT:
456+ case TypeKind::HUGEINT: {
420457 if (inputType->isLongDecimal ()) {
458+ auto sumType = exec::isPartialOutput (step)
459+ ? resultType->childAt (0 )
460+ : resultType;
421461 // If inputType is long decimal,
422462 // its output type always be long decimal.
423463 return std::make_unique<DecimalSumAggregate<int128_t , int128_t >>(
424464 resultType, sumType);
425465 }
466+ VELOX_NYI ();
467+ }
468+ case TypeKind::REAL:
469+ if (resultType->kind () == TypeKind::REAL) {
470+ return std::make_unique<velox::aggregate::prestosql::
471+ SumAggregate<float , double , float >>(
472+ resultType);
473+ }
474+ return std::make_unique<velox::aggregate::prestosql::
475+ SumAggregate<float , double , double >>(
476+ DOUBLE ());
477+ case TypeKind::DOUBLE:
478+ if (resultType->kind () == TypeKind::REAL) {
479+ return std::make_unique<velox::aggregate::prestosql::
480+ SumAggregate<double , double , float >>(
481+ resultType);
482+ }
483+ return std::make_unique<velox::aggregate::prestosql::
484+ SumAggregate<double , double , double >>(
485+ DOUBLE ());
426486 case TypeKind::ROW: {
427487 DCHECK (!exec::isRawInput (step));
428488 // For intermediate input agg, input intermediate sum type
429489 // is equal to final result sum type.
490+ auto sumType = exec::isPartialOutput (step) ? resultType->childAt (0 )
491+ : resultType;
430492 if (inputType->childAt (0 )->isShortDecimal ()) {
431493 return std::make_unique<DecimalSumAggregate<int64_t , int64_t >>(
432494 resultType, sumType);
@@ -443,7 +505,8 @@ exec::AggregateRegistrationResult registerDecimalSumAggregate(
443505 inputType->kindName ());
444506 }
445507 },
446- true );
508+ /* registerCompanionFunctions*/ true ,
509+ /* overwrite*/ true );
447510}
448511
449512} // namespace facebook::velox::functions::aggregate::sparksql
0 commit comments