Skip to content

Commit d885538

Browse files
committed
fix: Fix smj result mismatch issue in semi, anit and full outer join
Signed-off-by: Yuan <[email protected]>
1 parent 0beb95f commit d885538

File tree

3 files changed

+155
-97
lines changed

3 files changed

+155
-97
lines changed

velox/exec/MergeJoin.cpp

Lines changed: 91 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -113,9 +113,10 @@ void MergeJoin::initialize() {
113113
isSemiFilterJoin(joinType_)) {
114114
joinTracker_ = JoinTracker(outputBatchSize_, pool());
115115
}
116-
} else if (joinNode_->isAntiJoin()) {
116+
} else if (joinNode_->isAntiJoin() || joinNode_->isFullJoin()) {
117117
// Anti join needs to track the left side rows that have no match on the
118-
// right.
118+
// right. Full outer join needs to track the right side rows that have no
119+
// match on the left.
119120
joinTracker_ = JoinTracker(outputBatchSize_, pool());
120121
}
121122

@@ -383,7 +384,8 @@ bool MergeJoin::tryAddOutputRow(
383384
const RowVectorPtr& leftBatch,
384385
vector_size_t leftRow,
385386
const RowVectorPtr& rightBatch,
386-
vector_size_t rightRow) {
387+
vector_size_t rightRow,
388+
bool isRightJoinForFullOuter) {
387389
if (outputSize_ == outputBatchSize_) {
388390
return false;
389391
}
@@ -417,12 +419,15 @@ bool MergeJoin::tryAddOutputRow(
417419
filterRightInputProjections_);
418420

419421
if (joinTracker_) {
420-
if (isRightJoin(joinType_) || isRightSemiFilterJoin(joinType_)) {
422+
if (isRightJoin(joinType_) || isRightSemiFilterJoin(joinType_) ||
423+
(isFullJoin(joinType_) && isRightJoinForFullOuter)) {
421424
// Record right-side row with a match on the left-side.
422-
joinTracker_->addMatch(rightBatch, rightRow, outputSize_);
425+
joinTracker_->addMatch(
426+
rightBatch, rightRow, outputSize_, isRightJoinForFullOuter);
423427
} else {
424428
// Record left-side row with a match on the right-side.
425-
joinTracker_->addMatch(leftBatch, leftRow, outputSize_);
429+
joinTracker_->addMatch(
430+
leftBatch, leftRow, outputSize_, isRightJoinForFullOuter);
426431
}
427432
}
428433
}
@@ -432,7 +437,8 @@ bool MergeJoin::tryAddOutputRow(
432437
if (isAntiJoin(joinType_)) {
433438
VELOX_CHECK(joinTracker_.has_value());
434439
// Record left-side row with a match on the right-side.
435-
joinTracker_->addMatch(leftBatch, leftRow, outputSize_);
440+
joinTracker_->addMatch(
441+
leftBatch, leftRow, outputSize_, isRightJoinForFullOuter);
436442
}
437443

438444
++outputSize_;
@@ -450,14 +456,15 @@ bool MergeJoin::prepareOutput(
450456
return true;
451457
}
452458

453-
if (isRightJoin(joinType_) && right != currentRight_) {
454-
return true;
455-
}
456-
457459
// If there is a new right, we need to flatten the dictionary.
458460
if (!isRightFlattened_ && right && currentRight_ != right) {
459461
flattenRightProjections();
460462
}
463+
464+
if (right != currentRight_) {
465+
return true;
466+
}
467+
461468
return false;
462469
}
463470

@@ -480,11 +487,15 @@ bool MergeJoin::prepareOutput(
480487
}
481488
} else {
482489
for (const auto& projection : leftProjections_) {
490+
auto column = left->childAt(projection.inputChannel);
491+
// Flatten the left column if the column already is DictionaryVector.
492+
if (column->wrappedVector()->encoding() ==
493+
VectorEncoding::Simple::DICTIONARY) {
494+
BaseVector::flattenVector(column);
495+
}
496+
column->clearContainingLazyAndWrapped();
483497
localColumns[projection.outputChannel] = BaseVector::wrapInDictionary(
484-
{},
485-
leftOutputIndices_,
486-
outputBatchSize_,
487-
left->childAt(projection.inputChannel));
498+
{}, leftOutputIndices_, outputBatchSize_, column);
488499
}
489500
}
490501
currentLeft_ = left;
@@ -500,11 +511,10 @@ bool MergeJoin::prepareOutput(
500511
isRightFlattened_ = true;
501512
} else {
502513
for (const auto& projection : rightProjections_) {
514+
auto column = right->childAt(projection.inputChannel);
515+
column->clearContainingLazyAndWrapped();
503516
localColumns[projection.outputChannel] = BaseVector::wrapInDictionary(
504-
{},
505-
rightOutputIndices_,
506-
outputBatchSize_,
507-
right->childAt(projection.inputChannel));
517+
{}, rightOutputIndices_, outputBatchSize_, column);
508518
}
509519
isRightFlattened_ = false;
510520
}
@@ -568,6 +578,39 @@ bool MergeJoin::prepareOutput(
568578
bool MergeJoin::addToOutput() {
569579
if (isRightJoin(joinType_) || isRightSemiFilterJoin(joinType_)) {
570580
return addToOutputForRightJoin();
581+
} else if (isFullJoin(joinType_) && filter_) {
582+
if (!leftForRightJoinMatch_) {
583+
leftForRightJoinMatch_ = leftMatch_;
584+
rightForRightJoinMatch_ = rightMatch_;
585+
}
586+
587+
if (leftMatch_ && rightMatch_ && !leftJoinForFullFinished_) {
588+
auto left = addToOutputForLeftJoin();
589+
if (!leftMatch_) {
590+
leftJoinForFullFinished_ = true;
591+
}
592+
if (left) {
593+
if (!leftMatch_) {
594+
leftMatch_ = leftForRightJoinMatch_;
595+
rightMatch_ = rightForRightJoinMatch_;
596+
}
597+
598+
return true;
599+
}
600+
}
601+
602+
if (!leftMatch_ && !rightJoinForFullFinished_) {
603+
leftMatch_ = leftForRightJoinMatch_;
604+
rightMatch_ = rightForRightJoinMatch_;
605+
rightJoinForFullFinished_ = true;
606+
}
607+
608+
auto right = addToOutputForRightJoin();
609+
610+
leftForRightJoinMatch_ = leftMatch_;
611+
rightForRightJoinMatch_ = rightMatch_;
612+
613+
return right;
571614
} else {
572615
return addToOutputForLeftJoin();
573616
}
@@ -660,7 +703,13 @@ bool MergeJoin::addToOutputImpl() {
660703
} else {
661704
for (auto innerRow = innerStartRow; innerRow < innerEndRow;
662705
++innerRow) {
663-
if (!tryAddOutputRow(leftBatch, innerRow, rightBatch, outerRow)) {
706+
const auto isRightJoinForFullOuter = isFullJoin(joinType_);
707+
if (!tryAddOutputRow(
708+
leftBatch,
709+
innerRow,
710+
rightBatch,
711+
outerRow,
712+
isRightJoinForFullOuter)) {
664713
outerMatch->setCursor(outerBatchIndex, outerRow);
665714
innerMatch->setCursor(innerBatchIndex, innerRow);
666715
return true;
@@ -931,7 +980,7 @@ RowVectorPtr MergeJoin::doGetOutput() {
931980
isFullJoin(joinType_)) {
932981
// If output_ is currently wrapping a different buffer, return it
933982
// first.
934-
if (prepareOutput(input_, nullptr)) {
983+
if (prepareOutput(input_, rightInput_)) {
935984
output_->resize(outputSize_);
936985
return std::move(output_);
937986
}
@@ -956,7 +1005,7 @@ RowVectorPtr MergeJoin::doGetOutput() {
9561005
if (isRightJoin(joinType_) || isFullJoin(joinType_)) {
9571006
// If output_ is currently wrapping a different buffer, return it
9581007
// first.
959-
if (prepareOutput(nullptr, rightInput_)) {
1008+
if (prepareOutput(input_, rightInput_)) {
9601009
output_->resize(outputSize_);
9611010
return std::move(output_);
9621011
}
@@ -1003,6 +1052,8 @@ RowVectorPtr MergeJoin::doGetOutput() {
10031052
endRightRow < rightInput_->size(),
10041053
std::nullopt};
10051054

1055+
leftJoinForFullFinished_ = false;
1056+
rightJoinForFullFinished_ = false;
10061057
if (!leftMatch_->complete || !rightMatch_->complete) {
10071058
if (!leftMatch_->complete) {
10081059
// Need to continue looking for the end of match.
@@ -1264,8 +1315,6 @@ void MergeJoin::clearRightInput() {
12641315
RowVectorPtr MergeJoin::applyFilter(const RowVectorPtr& output) {
12651316
const auto numRows = output->size();
12661317

1267-
RowVectorPtr fullOuterOutput = nullptr;
1268-
12691318
BufferPtr indices = allocateIndices(numRows, pool());
12701319
auto* rawIndices = indices->asMutable<vector_size_t>();
12711320
vector_size_t numPassed = 0;
@@ -1282,84 +1331,41 @@ RowVectorPtr MergeJoin::applyFilter(const RowVectorPtr& output) {
12821331

12831332
// If all matches for a given left-side row fail the filter, add a row to
12841333
// the output with nulls for the right-side columns.
1285-
const auto onMiss = [&](auto row) {
1334+
const auto onMiss = [&](auto row, bool isRightJoinForFullOuter) {
12861335
if (isSemiFilterJoin(joinType_)) {
12871336
return;
12881337
}
12891338
rawIndices[numPassed++] = row;
12901339

1291-
if (isFullJoin(joinType_)) {
1292-
// For filtered rows, it is necessary to insert additional data
1293-
// to ensure the result set is complete. Specifically, we
1294-
// need to generate two records: one record containing the
1295-
// columns from the left table along with nulls for the
1296-
// right table, and another record containing the columns
1297-
// from the right table along with nulls for the left table.
1298-
// For instance, the current output is filtered based on the condition
1299-
// t > 1.
1300-
1301-
// 1, 1
1302-
// 2, 2
1303-
// 3, 3
1304-
1305-
// In this scenario, we need to additionally insert a record 1, 1.
1306-
// Subsequently, we will set the values of the columns on the left to
1307-
// null and the values of the columns on the right to null as well. By
1308-
// doing so, we will obtain the final result set.
1309-
1310-
// 1, null
1311-
// null, 1
1312-
// 2, 2
1313-
// 3, 3
1314-
fullOuterOutput = BaseVector::create<RowVector>(
1315-
output->type(), output->size() + 1, pool());
1316-
1317-
for (auto i = 0; i < row + 1; ++i) {
1318-
for (auto j = 0; j < output->type()->size(); ++j) {
1319-
fullOuterOutput->childAt(j)->copy(
1320-
output->childAt(j).get(), i, i, 1);
1340+
if (!isRightJoin(joinType_)) {
1341+
if (isFullJoin(joinType_) && isRightJoinForFullOuter) {
1342+
for (auto& projection : leftProjections_) {
1343+
auto target = output->childAt(projection.outputChannel);
1344+
target->setNull(row, true);
13211345
}
1322-
}
1323-
1324-
for (auto j = 0; j < output->type()->size(); ++j) {
1325-
fullOuterOutput->childAt(j)->copy(
1326-
output->childAt(j).get(), row + 1, row, 1);
1327-
}
1328-
1329-
for (auto i = row + 1; i < output->size(); ++i) {
1330-
for (auto j = 0; j < output->type()->size(); ++j) {
1331-
fullOuterOutput->childAt(j)->copy(
1332-
output->childAt(j).get(), i + 1, i, 1);
1346+
} else {
1347+
for (auto& projection : rightProjections_) {
1348+
auto target = output->childAt(projection.outputChannel);
1349+
target->setNull(row, true);
13331350
}
13341351
}
1335-
1336-
for (auto& projection : leftProjections_) {
1337-
auto& target = fullOuterOutput->childAt(projection.outputChannel);
1338-
target->setNull(row, true);
1339-
}
1340-
1341-
for (auto& projection : rightProjections_) {
1342-
auto& target = fullOuterOutput->childAt(projection.outputChannel);
1343-
target->setNull(row + 1, true);
1344-
}
1345-
} else if (!isRightJoin(joinType_)) {
1346-
for (auto& projection : rightProjections_) {
1347-
auto& target = output->childAt(projection.outputChannel);
1348-
target->setNull(row, true);
1349-
}
13501352
} else {
13511353
for (auto& projection : leftProjections_) {
1352-
auto& target = output->childAt(projection.outputChannel);
1354+
auto target = output->childAt(projection.outputChannel);
13531355
target->setNull(row, true);
13541356
}
13551357
}
13561358
};
13571359

13581360
auto onMatch = [&](auto row, bool firstMatch) {
1359-
const bool isNonSemiAntiJoin =
1360-
!isSemiFilterJoin(joinType_) && !isAntiJoin(joinType_);
1361+
const bool isFullLeftJoin =
1362+
isFullJoin(joinType_) && !joinTracker_->isRightJoinForFullOuter(row);
1363+
1364+
const bool isNonSemiAntiFullJoin = !isSemiFilterJoin(joinType_) &&
1365+
!isAntiJoin(joinType_) && !isFullJoin(joinType_);
13611366

1362-
if ((isSemiFilterJoin(joinType_) && firstMatch) || isNonSemiAntiJoin) {
1367+
if ((isSemiFilterJoin(joinType_) && firstMatch) ||
1368+
isNonSemiAntiFullJoin || isFullLeftJoin) {
13631369
rawIndices[numPassed++] = row;
13641370
}
13651371
};
@@ -1420,17 +1426,10 @@ RowVectorPtr MergeJoin::applyFilter(const RowVectorPtr& output) {
14201426

14211427
if (numPassed == numRows) {
14221428
// All rows passed.
1423-
if (fullOuterOutput) {
1424-
return fullOuterOutput;
1425-
}
14261429
return output;
14271430
}
14281431

14291432
// Some, but not all rows passed.
1430-
if (fullOuterOutput) {
1431-
return wrap(numPassed, indices, fullOuterOutput);
1432-
}
1433-
14341433
return wrap(numPassed, indices, output);
14351434
}
14361435

0 commit comments

Comments
 (0)