Skip to content

Commit f60a7db

Browse files
JkSelfGlutenPerfBot
authored andcommitted
[11771] [11772] Fix smj result mismatch issue
1 parent 226a130 commit f60a7db

File tree

5 files changed

+310
-114
lines changed

5 files changed

+310
-114
lines changed

velox/exec/MergeJoin.cpp

Lines changed: 102 additions & 104 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818
#include "velox/exec/Task.h"
1919
#include "velox/expression/FieldReference.h"
2020

21+
#include <iostream>
22+
2123
namespace facebook::velox::exec {
2224

2325
MergeJoin::MergeJoin(
@@ -93,7 +95,7 @@ void MergeJoin::initialize() {
9395
joinNode_->isRightJoin() || joinNode_->isFullJoin()) {
9496
joinTracker_ = JoinTracker(outputBatchSize_, pool());
9597
}
96-
} else if (joinNode_->isAntiJoin()) {
98+
} else if (joinNode_->isAntiJoin() || joinNode_->isFullJoin()) {
9799
// Anti join needs to track the left side rows that have no match on the
98100
// right.
99101
joinTracker_ = JoinTracker(outputBatchSize_, pool());
@@ -387,7 +389,8 @@ bool MergeJoin::tryAddOutputRow(
387389
const RowVectorPtr& leftBatch,
388390
vector_size_t leftRow,
389391
const RowVectorPtr& rightBatch,
390-
vector_size_t rightRow) {
392+
vector_size_t rightRow,
393+
bool isRightJoinForFullOuter) {
391394
if (outputSize_ == outputBatchSize_) {
392395
return false;
393396
}
@@ -421,12 +424,15 @@ bool MergeJoin::tryAddOutputRow(
421424
filterRightInputProjections_);
422425

423426
if (joinTracker_) {
424-
if (isRightJoin(joinType_)) {
427+
if (isRightJoin(joinType_) ||
428+
(isFullJoin(joinType_) && isRightJoinForFullOuter)) {
425429
// Record right-side row with a match on the left-side.
426-
joinTracker_->addMatch(rightBatch, rightRow, outputSize_);
430+
joinTracker_->addMatch(
431+
rightBatch, rightRow, outputSize_, isRightJoinForFullOuter);
427432
} else {
428433
// Record left-side row with a match on the right-side.
429-
joinTracker_->addMatch(leftBatch, leftRow, outputSize_);
434+
joinTracker_->addMatch(
435+
leftBatch, leftRow, outputSize_, isRightJoinForFullOuter);
430436
}
431437
}
432438
}
@@ -436,7 +442,8 @@ bool MergeJoin::tryAddOutputRow(
436442
if (isAntiJoin(joinType_)) {
437443
VELOX_CHECK(joinTracker_.has_value());
438444
// Record left-side row with a match on the right-side.
439-
joinTracker_->addMatch(leftBatch, leftRow, outputSize_);
445+
joinTracker_->addMatch(
446+
leftBatch, leftRow, outputSize_, isRightJoinForFullOuter);
440447
}
441448

442449
++outputSize_;
@@ -455,14 +462,14 @@ bool MergeJoin::prepareOutput(
455462
return true;
456463
}
457464

458-
if (isRightJoin(joinType_) && right != currentRight_) {
459-
return true;
460-
}
461-
462465
// If there is a new right, we need to flatten the dictionary.
463466
if (!isRightFlattened_ && right && currentRight_ != right) {
464467
flattenRightProjections();
465468
}
469+
470+
if (right != currentRight_) {
471+
return true;
472+
}
466473
return false;
467474
}
468475

@@ -574,6 +581,39 @@ bool MergeJoin::prepareOutput(
574581
bool MergeJoin::addToOutput() {
575582
if (isRightJoin(joinType_) || isRightSemiFilterJoin(joinType_)) {
576583
return addToOutputForRightJoin();
584+
} else if (isFullJoin(joinType_) && filter_) {
585+
if (!leftForRightJoinMatch_) {
586+
leftForRightJoinMatch_ = leftMatch_;
587+
rightForRightJoinMatch_ = rightMatch_;
588+
}
589+
590+
if (leftMatch_ && rightMatch_ && !leftJoinForFullFinished_) {
591+
auto left = addToOutputForLeftJoin();
592+
if (!leftMatch_) {
593+
leftJoinForFullFinished_ = true;
594+
}
595+
if (left) {
596+
if (!leftMatch_) {
597+
leftMatch_ = leftForRightJoinMatch_;
598+
rightMatch_ = rightForRightJoinMatch_;
599+
}
600+
601+
return true;
602+
}
603+
}
604+
605+
if (!leftMatch_ && !rightJoinForFullFinished_) {
606+
leftMatch_ = leftForRightJoinMatch_;
607+
rightMatch_ = rightForRightJoinMatch_;
608+
rightJoinForFullFinished_ = true;
609+
}
610+
611+
auto right = addToOutputForRightJoin();
612+
613+
leftForRightJoinMatch_ = leftMatch_;
614+
rightForRightJoinMatch_ = rightMatch_;
615+
616+
return right;
577617
} else {
578618
return addToOutputForLeftJoin();
579619
}
@@ -720,7 +760,13 @@ bool MergeJoin::addToOutputForRightJoin() {
720760
}
721761

722762
for (auto j = leftStartRow; j < leftEndRow; ++j) {
723-
if (!tryAddOutputRow(leftBatch, j, rightBatch, i)) {
763+
auto isRightJoinForFullOuter = false;
764+
if (isFullJoin(joinType_)) {
765+
isRightJoinForFullOuter = true;
766+
}
767+
768+
if (!tryAddOutputRow(
769+
leftBatch, j, rightBatch, i, isRightJoinForFullOuter)) {
724770
// If we run out of space in the current output_, we will need to
725771
// produce a buffer and continue processing left later. In this
726772
// case, we cannot leave left as a lazy vector, since we cannot have
@@ -818,7 +864,7 @@ RowVectorPtr MergeJoin::getOutput() {
818864
continue;
819865
} else if (isAntiJoin(joinType_)) {
820866
output = filterOutputForAntiJoin(output);
821-
if (output) {
867+
if (output != nullptr && output->size() > 0) {
822868
return output;
823869
}
824870

@@ -926,6 +972,8 @@ RowVectorPtr MergeJoin::doGetOutput() {
926972
// results from the current match.
927973
if (addToOutput()) {
928974
return std::move(output_);
975+
} else {
976+
previousLeftMatch_ = leftMatch_;
929977
}
930978
}
931979

@@ -990,6 +1038,8 @@ RowVectorPtr MergeJoin::doGetOutput() {
9901038

9911039
if (addToOutput()) {
9921040
return std::move(output_);
1041+
} else {
1042+
previousLeftMatch_ = leftMatch_;
9931043
}
9941044
}
9951045

@@ -1134,7 +1184,7 @@ RowVectorPtr MergeJoin::doGetOutput() {
11341184
isFullJoin(joinType_)) {
11351185
// If output_ is currently wrapping a different buffer, return it
11361186
// first.
1137-
if (prepareOutput(input_, nullptr)) {
1187+
if (prepareOutput(input_, rightInput_)) {
11381188
output_->resize(outputSize_);
11391189
return std::move(output_);
11401190
}
@@ -1159,7 +1209,7 @@ RowVectorPtr MergeJoin::doGetOutput() {
11591209
if (isRightJoin(joinType_) || isFullJoin(joinType_)) {
11601210
// If output_ is currently wrapping a different buffer, return it
11611211
// first.
1162-
if (prepareOutput(nullptr, rightInput_)) {
1212+
if (prepareOutput(input_, rightInput_)) {
11631213
output_->resize(outputSize_);
11641214
return std::move(output_);
11651215
}
@@ -1211,6 +1261,8 @@ RowVectorPtr MergeJoin::doGetOutput() {
12111261
endRightRow < rightInput_->size(),
12121262
std::nullopt};
12131263

1264+
leftJoinForFullFinished_ = false;
1265+
rightJoinForFullFinished_ = false;
12141266
if (!leftMatch_->complete || !rightMatch_->complete) {
12151267
if (!leftMatch_->complete) {
12161268
// Need to continue looking for the end of match.
@@ -1239,6 +1291,8 @@ RowVectorPtr MergeJoin::doGetOutput() {
12391291

12401292
if (addToOutput()) {
12411293
return std::move(output_);
1294+
} else {
1295+
previousLeftMatch_ = leftMatch_;
12421296
}
12431297

12441298
if (!rightInput_) {
@@ -1255,8 +1309,6 @@ RowVectorPtr MergeJoin::doGetOutput() {
12551309
RowVectorPtr MergeJoin::applyFilter(const RowVectorPtr& output) {
12561310
const auto numRows = output->size();
12571311

1258-
RowVectorPtr fullOuterOutput = nullptr;
1259-
12601312
BufferPtr indices = allocateIndices(numRows, pool());
12611313
auto* rawIndices = indices->asMutable<vector_size_t>();
12621314
vector_size_t numPassed = 0;
@@ -1273,76 +1325,29 @@ RowVectorPtr MergeJoin::applyFilter(const RowVectorPtr& output) {
12731325

12741326
// If all matches for a given left-side row fail the filter, add a row to
12751327
// the output with nulls for the right-side columns.
1276-
const auto onMiss = [&](auto row) {
1277-
if (isAntiJoin(joinType_)) {
1278-
return;
1279-
}
1280-
rawIndices[numPassed++] = row;
1281-
1282-
if (isFullJoin(joinType_)) {
1283-
// For filtered rows, it is necessary to insert additional data
1284-
// to ensure the result set is complete. Specifically, we
1285-
// need to generate two records: one record containing the
1286-
// columns from the left table along with nulls for the
1287-
// right table, and another record containing the columns
1288-
// from the right table along with nulls for the left table.
1289-
// For instance, the current output is filtered based on the condition
1290-
// t > 1.
1291-
1292-
// 1, 1
1293-
// 2, 2
1294-
// 3, 3
1295-
1296-
// In this scenario, we need to additionally insert a record 1, 1.
1297-
// Subsequently, we will set the values of the columns on the left to
1298-
// null and the values of the columns on the right to null as well. By
1299-
// doing so, we will obtain the final result set.
1300-
1301-
// 1, null
1302-
// null, 1
1303-
// 2, 2
1304-
// 3, 3
1305-
fullOuterOutput = BaseVector::create<RowVector>(
1306-
output->type(), output->size() + 1, pool());
1307-
1308-
for (auto i = 0; i < row + 1; ++i) {
1309-
for (auto j = 0; j < output->type()->size(); ++j) {
1310-
fullOuterOutput->childAt(j)->copy(
1311-
output->childAt(j).get(), i, i, 1);
1328+
auto onMiss = [&](auto row, bool flag) {
1329+
if (!isLeftSemiFilterJoin(joinType_) &&
1330+
!isRightSemiFilterJoin(joinType_)) {
1331+
rawIndices[numPassed++] = row;
1332+
1333+
if (!isRightJoin(joinType_)) {
1334+
if (isFullJoin(joinType_) && flag) {
1335+
for (auto& projection : leftProjections_) {
1336+
auto target = output->childAt(projection.outputChannel);
1337+
target->setNull(row, true);
1338+
}
1339+
} else {
1340+
for (auto& projection : rightProjections_) {
1341+
auto target = output->childAt(projection.outputChannel);
1342+
target->setNull(row, true);
1343+
}
13121344
}
1313-
}
1314-
1315-
for (auto j = 0; j < output->type()->size(); ++j) {
1316-
fullOuterOutput->childAt(j)->copy(
1317-
output->childAt(j).get(), row + 1, row, 1);
1318-
}
1319-
1320-
for (auto i = row + 1; i < output->size(); ++i) {
1321-
for (auto j = 0; j < output->type()->size(); ++j) {
1322-
fullOuterOutput->childAt(j)->copy(
1323-
output->childAt(j).get(), i + 1, i, 1);
1345+
} else {
1346+
for (auto& projection : leftProjections_) {
1347+
auto target = output->childAt(projection.outputChannel);
1348+
target->setNull(row, true);
13241349
}
13251350
}
1326-
1327-
for (auto& projection : leftProjections_) {
1328-
auto& target = fullOuterOutput->childAt(projection.outputChannel);
1329-
target->setNull(row, true);
1330-
}
1331-
1332-
for (auto& projection : rightProjections_) {
1333-
auto& target = fullOuterOutput->childAt(projection.outputChannel);
1334-
target->setNull(row + 1, true);
1335-
}
1336-
} else if (!isRightJoin(joinType_)) {
1337-
for (auto& projection : rightProjections_) {
1338-
auto& target = output->childAt(projection.outputChannel);
1339-
target->setNull(row, true);
1340-
}
1341-
} else {
1342-
for (auto& projection : leftProjections_) {
1343-
auto& target = output->childAt(projection.outputChannel);
1344-
target->setNull(row, true);
1345-
}
13461351
}
13471352
};
13481353

@@ -1353,12 +1358,8 @@ RowVectorPtr MergeJoin::applyFilter(const RowVectorPtr& output) {
13531358

13541359
joinTracker_->processFilterResult(i, passed, onMiss);
13551360

1356-
if (isAntiJoin(joinType_)) {
1357-
if (!passed) {
1358-
rawIndices[numPassed++] = i;
1359-
}
1360-
} else {
1361-
if (passed) {
1361+
if (!isAntiJoin(joinType_)) {
1362+
if (passed && !joinTracker_->isRightJoinForFullOuter(i)) {
13621363
rawIndices[numPassed++] = i;
13631364
}
13641365
}
@@ -1371,26 +1372,30 @@ RowVectorPtr MergeJoin::applyFilter(const RowVectorPtr& output) {
13711372

13721373
// Every time we start a new left key match, `processFilterResult()` will
13731374
// check if at least one row from the previous match passed the filter. If
1374-
// none did, it calls onMiss to add a record with null right projections to
1375-
// the output.
1375+
// none did, it calls onMiss to add a record with null right projections
1376+
// to the output.
13761377
//
13771378
// Before we leave the current buffer, since we may not have seen the next
1378-
// left key match yet, the last key match may still be pending to produce a
1379-
// row (because `processFilterResult()` was not called yet).
1379+
// left key match yet, the last key match may still be pending to produce
1380+
// a row (because `processFilterResult()` was not called yet).
13801381
//
13811382
// To handle this, we need to call `noMoreFilterResults()` unless the
1382-
// same current left key match may continue in the next buffer. So there are
1383-
// two cases to check:
1383+
// same current left key match may continue in the next buffer. So there
1384+
// are two cases to check:
13841385
//
1385-
// 1. If leftMatch_ is nullopt, there for sure the next buffer will contain
1386-
// a different key match.
1386+
// 1. If leftMatch_ is nullopt, there for sure the next buffer will
1387+
// contain a different key match.
13871388
//
13881389
// 2. leftMatch_ may not be nullopt, but may be related to a different
13891390
// (subsequent) left key. So we check if the last row in the batch has the
13901391
// same left row number as the last key match.
13911392
if (!leftMatch_ || !joinTracker_->isCurrentLeftMatch(numRows - 1)) {
13921393
joinTracker_->noMoreFilterResults(onMiss);
13931394
}
1395+
1396+
if (isAntiJoin(joinType_) && leftMatch_ && !previousLeftMatch_) {
1397+
joinTracker_->noMoreFilterResults(onMiss);
1398+
}
13941399
} else {
13951400
filterRows_.resize(numRows);
13961401
filterRows_.setAll();
@@ -1412,17 +1417,10 @@ RowVectorPtr MergeJoin::applyFilter(const RowVectorPtr& output) {
14121417

14131418
if (numPassed == numRows) {
14141419
// All rows passed.
1415-
if (fullOuterOutput) {
1416-
return fullOuterOutput;
1417-
}
14181420
return output;
14191421
}
14201422

14211423
// Some, but not all rows passed.
1422-
if (fullOuterOutput) {
1423-
return wrap(numPassed, indices, fullOuterOutput);
1424-
}
1425-
14261424
return wrap(numPassed, indices, output);
14271425
}
14281426

0 commit comments

Comments
 (0)