@@ -113,9 +113,10 @@ void MergeJoin::initialize() {
113113 isSemiFilterJoin (joinType_)) {
114114 joinTracker_ = JoinTracker (outputBatchSize_, pool ());
115115 }
116- } else if (joinNode_->isAntiJoin ()) {
116+ } else if (joinNode_->isAntiJoin () || joinNode_-> isFullJoin () ) {
117117 // Anti join needs to track the left side rows that have no match on the
118- // right.
118+ // right. Full outer join needs to track the right side rows that have no
119+ // match on the left.
119120 joinTracker_ = JoinTracker (outputBatchSize_, pool ());
120121 }
121122
@@ -392,7 +393,8 @@ bool MergeJoin::tryAddOutputRow(
392393 const RowVectorPtr& leftBatch,
393394 vector_size_t leftRow,
394395 const RowVectorPtr& rightBatch,
395- vector_size_t rightRow) {
396+ vector_size_t rightRow,
397+ bool isRightJoinForFullOuter) {
396398 if (outputSize_ == outputBatchSize_) {
397399 return false ;
398400 }
@@ -426,12 +428,15 @@ bool MergeJoin::tryAddOutputRow(
426428 filterRightInputProjections_);
427429
428430 if (joinTracker_) {
429- if (isRightJoin (joinType_) || isRightSemiFilterJoin (joinType_)) {
431+ if (isRightJoin (joinType_) || isRightSemiFilterJoin (joinType_) ||
432+ (isFullJoin (joinType_) && isRightJoinForFullOuter)) {
430433 // Record right-side row with a match on the left-side.
431- joinTracker_->addMatch (rightBatch, rightRow, outputSize_);
434+ joinTracker_->addMatch (
435+ rightBatch, rightRow, outputSize_, isRightJoinForFullOuter);
432436 } else {
433437 // Record left-side row with a match on the right-side.
434- joinTracker_->addMatch (leftBatch, leftRow, outputSize_);
438+ joinTracker_->addMatch (
439+ leftBatch, leftRow, outputSize_, isRightJoinForFullOuter);
435440 }
436441 }
437442 }
@@ -441,7 +446,8 @@ bool MergeJoin::tryAddOutputRow(
441446 if (isAntiJoin (joinType_)) {
442447 VELOX_CHECK (joinTracker_.has_value ());
443448 // Record left-side row with a match on the right-side.
444- joinTracker_->addMatch (leftBatch, leftRow, outputSize_);
449+ joinTracker_->addMatch (
450+ leftBatch, leftRow, outputSize_, isRightJoinForFullOuter);
445451 }
446452
447453 ++outputSize_;
@@ -459,14 +465,15 @@ bool MergeJoin::prepareOutput(
459465 return true ;
460466 }
461467
462- if (isRightJoin (joinType_) && right != currentRight_) {
463- return true ;
464- }
465-
466468 // If there is a new right, we need to flatten the dictionary.
467469 if (!isRightFlattened_ && right && currentRight_ != right) {
468470 flattenRightProjections ();
469471 }
472+
473+ if (right != currentRight_) {
474+ return true ;
475+ }
476+
470477 return false ;
471478 }
472479
@@ -489,11 +496,15 @@ bool MergeJoin::prepareOutput(
489496 }
490497 } else {
491498 for (const auto & projection : leftProjections_) {
499+ auto column = left->childAt (projection.inputChannel );
500+ // Flatten the left column if it is already a DictionaryVector.
501+ if (column->wrappedVector ()->encoding () ==
502+ VectorEncoding::Simple::DICTIONARY) {
503+ BaseVector::flattenVector (column);
504+ }
505+ column->clearContainingLazyAndWrapped ();
492506 localColumns[projection.outputChannel ] = BaseVector::wrapInDictionary (
493- {},
494- leftOutputIndices_,
495- outputBatchSize_,
496- left->childAt (projection.inputChannel ));
507+ {}, leftOutputIndices_, outputBatchSize_, column);
497508 }
498509 }
499510 currentLeft_ = left;
@@ -509,11 +520,10 @@ bool MergeJoin::prepareOutput(
509520 isRightFlattened_ = true ;
510521 } else {
511522 for (const auto & projection : rightProjections_) {
523+ auto column = right->childAt (projection.inputChannel );
524+ column->clearContainingLazyAndWrapped ();
512525 localColumns[projection.outputChannel ] = BaseVector::wrapInDictionary (
513- {},
514- rightOutputIndices_,
515- outputBatchSize_,
516- right->childAt (projection.inputChannel ));
526+ {}, rightOutputIndices_, outputBatchSize_, column);
517527 }
518528 isRightFlattened_ = false ;
519529 }
@@ -577,6 +587,39 @@ bool MergeJoin::prepareOutput(
577587bool MergeJoin::addToOutput () {
578588 if (isRightJoin (joinType_) || isRightSemiFilterJoin (joinType_)) {
579589 return addToOutputForRightJoin ();
590+ } else if (isFullJoin (joinType_) && filter_) {
591+ if (!leftForRightJoinMatch_) {
592+ leftForRightJoinMatch_ = leftMatch_;
593+ rightForRightJoinMatch_ = rightMatch_;
594+ }
595+
596+ if (leftMatch_ && rightMatch_ && !leftJoinForFullFinished_) {
597+ auto left = addToOutputForLeftJoin ();
598+ if (!leftMatch_) {
599+ leftJoinForFullFinished_ = true ;
600+ }
601+ if (left) {
602+ if (!leftMatch_) {
603+ leftMatch_ = leftForRightJoinMatch_;
604+ rightMatch_ = rightForRightJoinMatch_;
605+ }
606+
607+ return true ;
608+ }
609+ }
610+
611+ if (!leftMatch_ && !rightJoinForFullFinished_) {
612+ leftMatch_ = leftForRightJoinMatch_;
613+ rightMatch_ = rightForRightJoinMatch_;
614+ rightJoinForFullFinished_ = true ;
615+ }
616+
617+ auto right = addToOutputForRightJoin ();
618+
619+ leftForRightJoinMatch_ = leftMatch_;
620+ rightForRightJoinMatch_ = rightMatch_;
621+
622+ return right;
580623 } else {
581624 return addToOutputForLeftJoin ();
582625 }
@@ -669,7 +712,13 @@ bool MergeJoin::addToOutputImpl() {
669712 } else {
670713 for (auto innerRow = innerStartRow; innerRow < innerEndRow;
671714 ++innerRow) {
672- if (!tryAddOutputRow (leftBatch, innerRow, rightBatch, outerRow)) {
715+ const auto isRightJoinForFullOuter = isFullJoin (joinType_);
716+ if (!tryAddOutputRow (
717+ leftBatch,
718+ innerRow,
719+ rightBatch,
720+ outerRow,
721+ isRightJoinForFullOuter)) {
673722 outerMatch->setCursor (outerBatchIndex, outerRow);
674723 innerMatch->setCursor (innerBatchIndex, innerRow);
675724 return true ;
@@ -938,7 +987,7 @@ RowVectorPtr MergeJoin::doGetOutput() {
938987 isFullJoin (joinType_)) {
939988 // If output_ is currently wrapping a different buffer, return it
940989 // first.
941- if (prepareOutput (input_, nullptr )) {
990+ if (prepareOutput (input_, rightInput_ )) {
942991 output_->resize (outputSize_);
943992 return std::move (output_);
944993 }
@@ -963,7 +1012,7 @@ RowVectorPtr MergeJoin::doGetOutput() {
9631012 if (isRightJoin (joinType_) || isFullJoin (joinType_)) {
9641013 // If output_ is currently wrapping a different buffer, return it
9651014 // first.
966- if (prepareOutput (nullptr , rightInput_)) {
1015+ if (prepareOutput (input_ , rightInput_)) {
9671016 output_->resize (outputSize_);
9681017 return std::move (output_);
9691018 }
@@ -1009,6 +1058,8 @@ RowVectorPtr MergeJoin::doGetOutput() {
10091058 endRightRow < rightInput_->size (),
10101059 std::nullopt };
10111060
1061+ leftJoinForFullFinished_ = false ;
1062+ rightJoinForFullFinished_ = false ;
10121063 if (!leftMatch_->complete || !rightMatch_->complete ) {
10131064 if (!leftMatch_->complete ) {
10141065 // Need to continue looking for the end of match.
@@ -1270,8 +1321,6 @@ void MergeJoin::clearRightInput() {
12701321RowVectorPtr MergeJoin::applyFilter (const RowVectorPtr& output) {
12711322 const auto numRows = output->size ();
12721323
1273- RowVectorPtr fullOuterOutput = nullptr ;
1274-
12751324 BufferPtr indices = allocateIndices (numRows, pool ());
12761325 auto * rawIndices = indices->asMutable <vector_size_t >();
12771326 vector_size_t numPassed = 0 ;
@@ -1288,84 +1337,41 @@ RowVectorPtr MergeJoin::applyFilter(const RowVectorPtr& output) {
12881337
12891338 // If all matches for a given left-side row fail the filter, add a row to
12901339 // the output with nulls for the right-side columns.
1291- const auto onMiss = [&](auto row) {
1340+ const auto onMiss = [&](auto row, bool isRightJoinForFullOuter ) {
12921341 if (isSemiFilterJoin (joinType_)) {
12931342 return ;
12941343 }
12951344 rawIndices[numPassed++] = row;
12961345
1297- if (isFullJoin (joinType_)) {
1298- // For filtered rows, it is necessary to insert additional data
1299- // to ensure the result set is complete. Specifically, we
1300- // need to generate two records: one record containing the
1301- // columns from the left table along with nulls for the
1302- // right table, and another record containing the columns
1303- // from the right table along with nulls for the left table.
1304- // For instance, the current output is filtered based on the condition
1305- // t > 1.
1306-
1307- // 1, 1
1308- // 2, 2
1309- // 3, 3
1310-
1311- // In this scenario, we need to additionally insert a record 1, 1.
1312- // Subsequently, we will set the values of the columns on the left to
1313- // null and the values of the columns on the right to null as well. By
1314- // doing so, we will obtain the final result set.
1315-
1316- // 1, null
1317- // null, 1
1318- // 2, 2
1319- // 3, 3
1320- fullOuterOutput = BaseVector::create<RowVector>(
1321- output->type (), output->size () + 1 , pool ());
1322-
1323- for (auto i = 0 ; i < row + 1 ; ++i) {
1324- for (auto j = 0 ; j < output->type ()->size (); ++j) {
1325- fullOuterOutput->childAt (j)->copy (
1326- output->childAt (j).get (), i, i, 1 );
1346+ if (!isRightJoin (joinType_)) {
1347+ if (isFullJoin (joinType_) && isRightJoinForFullOuter) {
1348+ for (auto & projection : leftProjections_) {
1349+ auto target = output->childAt (projection.outputChannel );
1350+ target->setNull (row, true );
13271351 }
1328- }
1329-
1330- for (auto j = 0 ; j < output->type ()->size (); ++j) {
1331- fullOuterOutput->childAt (j)->copy (
1332- output->childAt (j).get (), row + 1 , row, 1 );
1333- }
1334-
1335- for (auto i = row + 1 ; i < output->size (); ++i) {
1336- for (auto j = 0 ; j < output->type ()->size (); ++j) {
1337- fullOuterOutput->childAt (j)->copy (
1338- output->childAt (j).get (), i + 1 , i, 1 );
1352+ } else {
1353+ for (auto & projection : rightProjections_) {
1354+ auto target = output->childAt (projection.outputChannel );
1355+ target->setNull (row, true );
13391356 }
13401357 }
1341-
1342- for (auto & projection : leftProjections_) {
1343- auto & target = fullOuterOutput->childAt (projection.outputChannel );
1344- target->setNull (row, true );
1345- }
1346-
1347- for (auto & projection : rightProjections_) {
1348- auto & target = fullOuterOutput->childAt (projection.outputChannel );
1349- target->setNull (row + 1 , true );
1350- }
1351- } else if (!isRightJoin (joinType_)) {
1352- for (auto & projection : rightProjections_) {
1353- auto & target = output->childAt (projection.outputChannel );
1354- target->setNull (row, true );
1355- }
13561358 } else {
13571359 for (auto & projection : leftProjections_) {
1358- auto & target = output->childAt (projection.outputChannel );
1360+ auto target = output->childAt (projection.outputChannel );
13591361 target->setNull (row, true );
13601362 }
13611363 }
13621364 };
13631365
13641366 auto onMatch = [&](auto row, bool firstMatch) {
1365- const bool isNonSemiAntiJoin =
1366- !isSemiFilterJoin (joinType_) && !isAntiJoin (joinType_);
1367+ const bool isFullLeftJoin =
1368+ isFullJoin (joinType_) && !joinTracker_->isRightJoinForFullOuter (row);
1369+
1370+ const bool isNonSemiAntiFullJoin = !isSemiFilterJoin (joinType_) &&
1371+ !isAntiJoin (joinType_) && !isFullJoin (joinType_);
13671372
1368- if ((isSemiFilterJoin (joinType_) && firstMatch) || isNonSemiAntiJoin) {
1373+ if ((isSemiFilterJoin (joinType_) && firstMatch) ||
1374+ isNonSemiAntiFullJoin || isFullLeftJoin) {
13691375 rawIndices[numPassed++] = row;
13701376 }
13711377 };
@@ -1426,17 +1432,10 @@ RowVectorPtr MergeJoin::applyFilter(const RowVectorPtr& output) {
14261432
14271433 if (numPassed == numRows) {
14281434 // All rows passed.
1429- if (fullOuterOutput) {
1430- return fullOuterOutput;
1431- }
14321435 return output;
14331436 }
14341437
14351438 // Some, but not all rows passed.
1436- if (fullOuterOutput) {
1437- return wrap (numPassed, indices, fullOuterOutput);
1438- }
1439-
14401439 return wrap (numPassed, indices, output);
14411440}
14421441
0 commit comments