Skip to content

Commit 74a7a98

Browse files
committed
adding param allowPrecisionLoss
Signed-off-by: Yuan Zhou <[email protected]>
1 parent 0034e47 commit 74a7a98

File tree

2 files changed

+43
-11
lines changed

2 files changed

+43
-11
lines changed

velox/core/QueryConfig.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,9 @@ class QueryConfig {
8080
// truncating the decimal part instead of rounding.
8181
static constexpr const char* kCastToIntByTruncate = "cast_to_int_by_truncate";
8282

83+
// This flags forces to bound the decimal precision.
84+
static constexpr const char* kAllowPrecisionLoss = "allow_precision_loss";
85+
8386
/// Used for backpressure to block local exchange producers when the local
8487
/// exchange buffer reaches or exceeds this size.
8588
static constexpr const char* kMaxLocalExchangeBufferSize =
@@ -329,6 +332,10 @@ class QueryConfig {
329332
return get<bool>(kCastToIntByTruncate, false);
330333
}
331334

335+
bool isAllowPrecisionLoss() const {
336+
return get<bool>(kAllowPrecisionLoss, true);
337+
}
338+
332339
bool codegenEnabled() const {
333340
return get<bool>(kCodegenEnabled, false);
334341
}

velox/functions/sparksql/DecimalArithmetic.cpp

Lines changed: 36 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,12 @@
2323
namespace facebook::velox::functions::sparksql {
2424
namespace {
2525

26+
inline static std::pair<uint8_t, uint8_t> bounded(
27+
const uint8_t rPrecision,
28+
const uint8_t rScale) {
29+
return {std::min(rPrecision, 38), std::min(rScale, 38)};
30+
}
31+
2632
inline static std::pair<uint8_t, uint8_t> adjustPrecisionScale(
2733
const uint8_t rPrecision,
2834
const uint8_t rScale) {
@@ -385,11 +391,13 @@ class Addition {
385391
const uint8_t aPrecision,
386392
const uint8_t aScale,
387393
const uint8_t bPrecision,
388-
const uint8_t bScale) {
394+
const uint8_t bScale,
395+
const allowPrecisionLoss) {
389396
auto precision = std::max(aPrecision - aScale, bPrecision - bScale) +
390397
std::max(aScale, bScale) + 1;
391398
auto scale = std::max(aScale, bScale);
392-
return adjustPrecisionScale(precision, scale);
399+
return allowPrecisionLoss ? adjustPrecisionScale(precision, scale)
400+
: bounded(precision, scale);
393401
}
394402
};
395403

@@ -433,7 +441,8 @@ class Subtraction {
433441
const uint8_t aPrecision,
434442
const uint8_t aScale,
435443
const uint8_t bPrecision,
436-
const uint8_t bScale) {
444+
const uint8_t bScale,
445+
const bool allowPrecisionLoss) {
437446
return Addition::computeResultPrecisionScale(
438447
aPrecision, aScale, bPrecision, bScale);
439448
}
@@ -539,8 +548,11 @@ class Multiply {
539548
const uint8_t aPrecision,
540549
const uint8_t aScale,
541550
const uint8_t bPrecision,
542-
const uint8_t bScale) {
543-
return adjustPrecisionScale(aPrecision + bPrecision + 1, aScale + bScale);
551+
const uint8_t bScale,
552+
const bool allowPrecisionLoss) {
553+
return allowPrecisionLoss
554+
? adjustPrecisionScale(aPrecision + bPrecision + 1, aScale + bScale)
555+
: bounded(aPrecision + bPrecision + 1, aScale + bScale)
544556
}
545557

546558
private:
@@ -591,10 +603,22 @@ class Divide {
591603
const uint8_t aPrecision,
592604
const uint8_t aScale,
593605
const uint8_t bPrecision,
594-
const uint8_t bScale) {
595-
auto scale = std::max(6, aScale + bPrecision + 1);
596-
auto precision = aPrecision - aScale + bScale + scale;
597-
return adjustPrecisionScale(precision, scale);
606+
const uint8_t bScale,
607+
const bool allowPrecisionLoss) {
608+
if (allowPrecisionLoss) {
609+
auto scale = std::max(6, aScale + bPrecision + 1);
610+
auto precision = aPrecision - aScale + bScale + scale;
611+
return adjustPrecisionScale(precision, scale);
612+
} else {
613+
auto intDig = std::min(38, aPrecision - aScale + bScale);
614+
auto decDig =
615+
std::min(38, std::max(6, aScale + bPrecision + 1)) auto diff =
616+
(intDig + decDig) - 38;
617+
if (diff > 0) {
618+
decDig -= diff / 2 + 1 intDig = 38 - decDig
619+
}
620+
return bounded(intDig + decDig, decDig);
621+
}
598622
}
599623
};
600624

@@ -664,13 +688,14 @@ template <typename Operation>
664688
std::shared_ptr<exec::VectorFunction> createDecimalFunction(
665689
const std::string& name,
666690
const std::vector<exec::VectorFunctionArg>& inputArgs,
667-
const core::QueryConfig& /*config*/) {
691+
const core::QueryConfig& config) {
668692
auto aType = inputArgs[0].type;
669693
auto bType = inputArgs[1].type;
670694
auto [aPrecision, aScale] = getDecimalPrecisionScale(*aType);
671695
auto [bPrecision, bScale] = getDecimalPrecisionScale(*bType);
696+
const bool allowPrecisionLoss = config.isAllowPrecisionLoss();
672697
auto [rPrecision, rScale] = Operation::computeResultPrecisionScale(
673-
aPrecision, aScale, bPrecision, bScale);
698+
aPrecision, aScale, bPrecision, bScale, allowPrecisionLoss);
674699
uint8_t aRescale = Operation::computeRescaleFactor(aScale, bScale, rScale);
675700
uint8_t bRescale = Operation::computeRescaleFactor(bScale, aScale, rScale);
676701
if (aType->isShortDecimal()) {

0 commit comments

Comments
 (0)