|
23 | 23 | namespace facebook::velox::functions::sparksql { |
24 | 24 | namespace { |
25 | 25 |
|
| 26 | +inline static std::pair<uint8_t, uint8_t> bounded( |
| 27 | + const uint8_t rPrecision, |
| 28 | + const uint8_t rScale) { |
| 29 | + return {std::min(rPrecision, 38), std::min(rScale, 38)}; |
| 30 | +} |
| 31 | + |
26 | 32 | inline static std::pair<uint8_t, uint8_t> adjustPrecisionScale( |
27 | 33 | const uint8_t rPrecision, |
28 | 34 | const uint8_t rScale) { |
@@ -385,11 +391,13 @@ class Addition { |
385 | 391 | const uint8_t aPrecision, |
386 | 392 | const uint8_t aScale, |
387 | 393 | const uint8_t bPrecision, |
388 | | - const uint8_t bScale) { |
| 394 | + const uint8_t bScale, |
| 395 | + const allowPrecisionLoss) { |
389 | 396 | auto precision = std::max(aPrecision - aScale, bPrecision - bScale) + |
390 | 397 | std::max(aScale, bScale) + 1; |
391 | 398 | auto scale = std::max(aScale, bScale); |
392 | | - return adjustPrecisionScale(precision, scale); |
| 399 | + return allowPrecisionLoss ? adjustPrecisionScale(precision, scale) |
| 400 | + : bounded(precision, scale); |
393 | 401 | } |
394 | 402 | }; |
395 | 403 |
|
@@ -433,7 +441,8 @@ class Subtraction { |
433 | 441 | const uint8_t aPrecision, |
434 | 442 | const uint8_t aScale, |
435 | 443 | const uint8_t bPrecision, |
436 | | - const uint8_t bScale) { |
| 444 | + const uint8_t bScale, |
| 445 | + const bool allowPrecisionLoss) { |
437 | 446 | return Addition::computeResultPrecisionScale( |
438 | 447 | aPrecision, aScale, bPrecision, bScale); |
439 | 448 | } |
@@ -539,8 +548,11 @@ class Multiply { |
539 | 548 | const uint8_t aPrecision, |
540 | 549 | const uint8_t aScale, |
541 | 550 | const uint8_t bPrecision, |
542 | | - const uint8_t bScale) { |
543 | | - return adjustPrecisionScale(aPrecision + bPrecision + 1, aScale + bScale); |
| 551 | + const uint8_t bScale, |
| 552 | + const bool allowPrecisionLoss) { |
| 553 | + return allowPrecisionLoss |
| 554 | + ? adjustPrecisionScale(aPrecision + bPrecision + 1, aScale + bScale) |
| 555 | + : bounded(aPrecision + bPrecision + 1, aScale + bScale) |
544 | 556 | } |
545 | 557 |
|
546 | 558 | private: |
@@ -591,10 +603,22 @@ class Divide { |
591 | 603 | const uint8_t aPrecision, |
592 | 604 | const uint8_t aScale, |
593 | 605 | const uint8_t bPrecision, |
594 | | - const uint8_t bScale) { |
595 | | - auto scale = std::max(6, aScale + bPrecision + 1); |
596 | | - auto precision = aPrecision - aScale + bScale + scale; |
597 | | - return adjustPrecisionScale(precision, scale); |
| 606 | + const uint8_t bScale, |
| 607 | + const bool allowPrecisionLoss) { |
| 608 | + if (allowPrecisionLoss) { |
| 609 | + auto scale = std::max(6, aScale + bPrecision + 1); |
| 610 | + auto precision = aPrecision - aScale + bScale + scale; |
| 611 | + return adjustPrecisionScale(precision, scale); |
| 612 | + } else { |
| 613 | + auto intDig = std::min(38, aPrecision - aScale + bScale); |
| 614 | + auto decDig = |
| 615 | + std::min(38, std::max(6, aScale + bPrecision + 1)) auto diff = |
| 616 | + (intDig + decDig) - 38; |
| 617 | + if (diff > 0) { |
| 618 | + decDig -= diff / 2 + 1 intDig = 38 - decDig |
| 619 | + } |
| 620 | + return bounded(intDig + decDig, decDig); |
| 621 | + } |
598 | 622 | } |
599 | 623 | }; |
600 | 624 |
|
@@ -664,13 +688,14 @@ template <typename Operation> |
664 | 688 | std::shared_ptr<exec::VectorFunction> createDecimalFunction( |
665 | 689 | const std::string& name, |
666 | 690 | const std::vector<exec::VectorFunctionArg>& inputArgs, |
667 | | - const core::QueryConfig& /*config*/) { |
| 691 | + const core::QueryConfig& config) { |
668 | 692 | auto aType = inputArgs[0].type; |
669 | 693 | auto bType = inputArgs[1].type; |
670 | 694 | auto [aPrecision, aScale] = getDecimalPrecisionScale(*aType); |
671 | 695 | auto [bPrecision, bScale] = getDecimalPrecisionScale(*bType); |
| 696 | + const bool allowPrecisionLoss = config.isAllowPrecisionLoss(); |
672 | 697 | auto [rPrecision, rScale] = Operation::computeResultPrecisionScale( |
673 | | - aPrecision, aScale, bPrecision, bScale); |
| 698 | + aPrecision, aScale, bPrecision, bScale, allowPrecisionLoss); |
674 | 699 | uint8_t aRescale = Operation::computeRescaleFactor(aScale, bScale, rScale); |
675 | 700 | uint8_t bRescale = Operation::computeRescaleFactor(bScale, aScale, rScale); |
676 | 701 | if (aType->isShortDecimal()) { |
|
0 commit comments