Skip to content

Commit 51a05a1

Browse files
committed
Register companion functions
1 parent 9304499 commit 51a05a1

File tree

11 files changed

+104
-17
lines changed

11 files changed

+104
-17
lines changed

velox/expression/SwitchExpr.cpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -253,12 +253,17 @@ TypePtr SwitchExpr::resolveType(const std::vector<TypePtr>& argTypes) {
253253
if (hasElse) {
254254
auto& elseClauseType = argTypes.back();
255255

256-
VELOX_CHECK(
256+
if (elseClauseType->isDecimal()) {
257+
// Regard decimals as the same type regardless of precision and scale.
258+
VELOX_CHECK(expressionType->isDecimal());
259+
} else {
260+
VELOX_CHECK(
257261
*elseClauseType == *expressionType,
258262
"Else clause of a SWITCH statement must have the same type as 'then' clauses. "
259263
"Expected {}, but got {}.",
260264
expressionType->toString(),
261265
elseClauseType->toString());
266+
}
262267
}
263268

264269
return expressionType;

velox/functions/lib/aggregates/BitwiseAggregateBase.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,9 @@ exec::AggregateRegistrationResult registerBitwise(const std::string& name) {
106106
name,
107107
inputType->kindName());
108108
}
109-
});
109+
},
110+
/*registerCompanionFunctions*/ true,
111+
/*overwrite*/ true);
110112
}
111113

112114
} // namespace facebook::velox::functions::aggregate

velox/functions/prestosql/aggregates/CountAggregate.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,8 @@ exec::AggregateRegistrationResult registerCount(const std::string& name) {
172172
VELOX_CHECK_LE(
173173
argTypes.size(), 1, "{} takes at most one argument", name);
174174
return std::make_unique<CountAggregate>();
175-
});
175+
},
176+
true);
176177
}
177178

178179
} // namespace

velox/functions/prestosql/aggregates/CovarianceAggregates.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -608,7 +608,8 @@ exec::AggregateRegistrationResult registerCovariance(const std::string& name) {
608608
"Unsupported raw input type: {}. Expected DOUBLE or REAL.",
609609
rawInputType->toString())
610610
}
611-
});
611+
},
612+
true);
612613
}
613614

614615
} // namespace

velox/functions/prestosql/aggregates/MinMaxAggregates.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -974,7 +974,8 @@ exec::AggregateRegistrationResult registerMinMax(const std::string& name) {
974974
inputType->kindName());
975975
}
976976
}
977-
});
977+
},
978+
/*registerCompanionFunctions*/ true);
978979
}
979980

980981
} // namespace

velox/functions/prestosql/aggregates/SumAggregate.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -272,7 +272,8 @@ exec::AggregateRegistrationResult registerSum(const std::string& name) {
272272
name,
273273
inputType->kindName());
274274
}
275-
});
275+
},
276+
true);
276277
}
277278

278279
} // namespace facebook::velox::aggregate::prestosql

velox/functions/prestosql/aggregates/VarianceAggregates.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -508,7 +508,8 @@ exec::AggregateRegistrationResult registerVariance(const std::string& name) {
508508
"(count:bigint, mean:double, m2:double) struct");
509509
return std::make_unique<TClass<int64_t>>(resultType);
510510
}
511-
});
511+
},
512+
/*registerCompanionFunctions*/ true);
512513
}
513514

514515
} // namespace

velox/functions/sparksql/aggregates/AverageAggregate.cpp

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -379,6 +379,15 @@ exec::AggregateRegistrationResult registerAverage(const std::string& name) {
379379
.returnType("DECIMAL(r_precision, r_scale)")
380380
.build());
381381

382+
signatures.push_back(
383+
exec::AggregateFunctionSignatureBuilder()
384+
.integerVariable("a_precision")
385+
.integerVariable("a_scale")
386+
.argumentType("DECIMAL(a_precision, a_scale)")
387+
.intermediateType("ROW(DECIMAL(a_precision, a_scale), BIGINT)")
388+
.returnType("DECIMAL(a_precision, a_scale)")
389+
.build());
390+
382391
return exec::registerAggregateFunction(
383392
name,
384393
std::move(signatures),
@@ -465,7 +474,8 @@ exec::AggregateRegistrationResult registerAverage(const std::string& name) {
465474
}
466475
}
467476
},
468-
/*registerCompanionFunctions*/ true);
477+
/*registerCompanionFunctions*/ true,
478+
/*overwrite*/ true);
469479
}
470480

471481
} // namespace facebook::velox::functions::aggregate::sparksql

velox/functions/sparksql/aggregates/DecimalSumAggregate.h

Lines changed: 70 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818
#include "velox/expression/FunctionSignature.h"
1919
#include "velox/vector/FlatVector.h"
2020

21+
#include "velox/functions/prestosql/aggregates/SumAggregate.h"
22+
2123
namespace facebook::velox::functions::aggregate::sparksql {
2224

2325
struct DecimalSum {
@@ -377,9 +379,19 @@ class DecimalSumAggregate : public exec::Aggregate {
377379
TypePtr sumType_;
378380
};
379381

380-
exec::AggregateRegistrationResult registerDecimalSumAggregate(
382+
exec::AggregateRegistrationResult registerSumAggregate(
381383
const std::string& name) {
382384
std::vector<std::shared_ptr<exec::AggregateFunctionSignature>> signatures{
385+
exec::AggregateFunctionSignatureBuilder()
386+
.returnType("real")
387+
.intermediateType("double")
388+
.argumentType("real")
389+
.build(),
390+
exec::AggregateFunctionSignatureBuilder()
391+
.returnType("double")
392+
.intermediateType("double")
393+
.argumentType("double")
394+
.build(),
383395
exec::AggregateFunctionSignatureBuilder()
384396
.integerVariable("a_precision")
385397
.integerVariable("a_scale")
@@ -388,7 +400,16 @@ exec::AggregateRegistrationResult registerDecimalSumAggregate(
388400
.argumentType("DECIMAL(a_precision, a_scale)")
389401
.intermediateType("ROW(DECIMAL(r_precision, r_scale), boolean)")
390402
.returnType("DECIMAL(r_precision, r_scale)")
391-
.build()};
403+
.build(),
404+
};
405+
406+
for (const auto& inputType : {"tinyint", "smallint", "integer", "bigint"}) {
407+
signatures.push_back(exec::AggregateFunctionSignatureBuilder()
408+
.returnType("bigint")
409+
.intermediateType("bigint")
410+
.argumentType(inputType)
411+
.build());
412+
}
392413

393414
return exec::registerAggregateFunction(
394415
name,
@@ -401,32 +422,73 @@ exec::AggregateRegistrationResult registerDecimalSumAggregate(
401422
-> std::unique_ptr<exec::Aggregate> {
402423
VELOX_CHECK_EQ(argTypes.size(), 1, "{} takes only one argument", name);
403424
auto& inputType = argTypes[0];
404-
auto sumType =
405-
exec::isPartialOutput(step) ? resultType->childAt(0) : resultType;
406425
switch (inputType->kind()) {
426+
case TypeKind::TINYINT:
427+
return std::make_unique<velox::aggregate::prestosql::
428+
SumAggregate<int8_t, int64_t, int64_t>>(
429+
BIGINT());
430+
case TypeKind::SMALLINT:
431+
return std::make_unique<
432+
velox::aggregate::prestosql::
433+
SumAggregate<int16_t, int64_t, int64_t>>(BIGINT());
434+
case TypeKind::INTEGER:
435+
return std::make_unique<
436+
velox::aggregate::prestosql::
437+
SumAggregate<int32_t, int64_t, int64_t>>(BIGINT());
407438
case TypeKind::BIGINT: {
408-
DCHECK(exec::isRawInput(step));
409439
if (inputType->isShortDecimal()) {
440+
auto sumType = exec::isPartialOutput(step)
441+
? resultType->childAt(0)
442+
: resultType;
410443
if (sumType->isShortDecimal()) {
411444
return std::make_unique<DecimalSumAggregate<int64_t, int64_t>>(
412445
resultType, sumType);
413446
} else if (sumType->isLongDecimal()) {
414447
return std::make_unique<DecimalSumAggregate<int64_t, int128_t>>(
415448
resultType, sumType);
416449
}
450+
VELOX_UNREACHABLE();
417451
}
452+
return std::make_unique<
453+
velox::aggregate::prestosql::
454+
SumAggregate<int64_t, int64_t, int64_t>>(BIGINT());
418455
}
419-
case TypeKind::HUGEINT:
456+
case TypeKind::HUGEINT: {
420457
if (inputType->isLongDecimal()) {
458+
auto sumType = exec::isPartialOutput(step)
459+
? resultType->childAt(0)
460+
: resultType;
421461
// If inputType is long decimal,
422462
// its output type always be long decimal.
423463
return std::make_unique<DecimalSumAggregate<int128_t, int128_t>>(
424464
resultType, sumType);
425465
}
466+
VELOX_NYI();
467+
}
468+
case TypeKind::REAL:
469+
if (resultType->kind() == TypeKind::REAL) {
470+
return std::make_unique<velox::aggregate::prestosql::
471+
SumAggregate<float, double, float>>(
472+
resultType);
473+
}
474+
return std::make_unique<velox::aggregate::prestosql::
475+
SumAggregate<float, double, double>>(
476+
DOUBLE());
477+
case TypeKind::DOUBLE:
478+
if (resultType->kind() == TypeKind::REAL) {
479+
return std::make_unique<velox::aggregate::prestosql::
480+
SumAggregate<double, double, float>>(
481+
resultType);
482+
}
483+
return std::make_unique<velox::aggregate::prestosql::
484+
SumAggregate<double, double, double>>(
485+
DOUBLE());
426486
case TypeKind::ROW: {
427487
DCHECK(!exec::isRawInput(step));
428488
// For intermediate input agg, input intermediate sum type
429489
// is equal to final result sum type.
490+
auto sumType = exec::isPartialOutput(step) ? resultType->childAt(0)
491+
: resultType;
430492
if (inputType->childAt(0)->isShortDecimal()) {
431493
return std::make_unique<DecimalSumAggregate<int64_t, int64_t>>(
432494
resultType, sumType);
@@ -443,7 +505,8 @@ exec::AggregateRegistrationResult registerDecimalSumAggregate(
443505
inputType->kindName());
444506
}
445507
},
446-
true);
508+
/*registerCompanionFunctions*/ true,
509+
/*overwrite*/ true);
447510
}
448511

449512
} // namespace facebook::velox::functions::aggregate::sparksql

velox/functions/sparksql/aggregates/FirstLastAggregate.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -499,7 +499,9 @@ AggregateRegistrationResult registerFirstLast(const std::string& name) {
499499
name,
500500
inputType->toString());
501501
}
502-
});
502+
},
503+
/*registerCompanionFunctions*/ true,
504+
/*overwrite*/ true);
503505
}
504506

505507
void registerFirstLastAggregates(const std::string& prefix) {

0 commit comments

Comments
 (0)