@@ -113,9 +113,10 @@ void MergeJoin::initialize() {
113113 isSemiFilterJoin (joinType_)) {
114114 joinTracker_ = JoinTracker (outputBatchSize_, pool ());
115115 }
116- } else if (joinNode_->isAntiJoin ()) {
116+ } else if (joinNode_->isAntiJoin () || joinNode_-> isFullJoin () ) {
117117 // Anti join needs to track the left side rows that have no match on the
118- // right.
118+ // right. Full outer join needs to track the right side rows that have no
119+ // match on the left.
119120 joinTracker_ = JoinTracker (outputBatchSize_, pool ());
120121 }
121122
@@ -383,7 +384,8 @@ bool MergeJoin::tryAddOutputRow(
383384 const RowVectorPtr& leftBatch,
384385 vector_size_t leftRow,
385386 const RowVectorPtr& rightBatch,
386- vector_size_t rightRow) {
387+ vector_size_t rightRow,
388+ bool isRightJoinForFullOuter) {
387389 if (outputSize_ == outputBatchSize_) {
388390 return false ;
389391 }
@@ -417,12 +419,15 @@ bool MergeJoin::tryAddOutputRow(
417419 filterRightInputProjections_);
418420
419421 if (joinTracker_) {
420- if (isRightJoin (joinType_) || isRightSemiFilterJoin (joinType_)) {
422+ if (isRightJoin (joinType_) || isRightSemiFilterJoin (joinType_) ||
423+ (isFullJoin (joinType_) && isRightJoinForFullOuter)) {
421424 // Record right-side row with a match on the left-side.
422- joinTracker_->addMatch (rightBatch, rightRow, outputSize_);
425+ joinTracker_->addMatch (
426+ rightBatch, rightRow, outputSize_, isRightJoinForFullOuter);
423427 } else {
424428 // Record left-side row with a match on the right-side.
425- joinTracker_->addMatch (leftBatch, leftRow, outputSize_);
429+ joinTracker_->addMatch (
430+ leftBatch, leftRow, outputSize_, isRightJoinForFullOuter);
426431 }
427432 }
428433 }
@@ -432,7 +437,8 @@ bool MergeJoin::tryAddOutputRow(
432437 if (isAntiJoin (joinType_)) {
433438 VELOX_CHECK (joinTracker_.has_value ());
434439 // Record left-side row with a match on the right-side.
435- joinTracker_->addMatch (leftBatch, leftRow, outputSize_);
440+ joinTracker_->addMatch (
441+ leftBatch, leftRow, outputSize_, isRightJoinForFullOuter);
436442 }
437443
438444 ++outputSize_;
@@ -450,14 +456,15 @@ bool MergeJoin::prepareOutput(
450456 return true ;
451457 }
452458
453- if (isRightJoin (joinType_) && right != currentRight_) {
454- return true ;
455- }
456-
457459 // If there is a new right, we need to flatten the dictionary.
458460 if (!isRightFlattened_ && right && currentRight_ != right) {
459461 flattenRightProjections ();
460462 }
463+
464+ if (right != currentRight_) {
465+ return true ;
466+ }
467+
461468 return false ;
462469 }
463470
@@ -480,11 +487,15 @@ bool MergeJoin::prepareOutput(
480487 }
481488 } else {
482489 for (const auto & projection : leftProjections_) {
490+ auto column = left->childAt (projection.inputChannel );
491+ // Flatten the left column if the column already is DictionaryVector.
492+ if (column->wrappedVector ()->encoding () ==
493+ VectorEncoding::Simple::DICTIONARY) {
494+ BaseVector::flattenVector (column);
495+ }
496+ column->clearContainingLazyAndWrapped ();
483497 localColumns[projection.outputChannel ] = BaseVector::wrapInDictionary (
484- {},
485- leftOutputIndices_,
486- outputBatchSize_,
487- left->childAt (projection.inputChannel ));
498+ {}, leftOutputIndices_, outputBatchSize_, column);
488499 }
489500 }
490501 currentLeft_ = left;
@@ -500,11 +511,10 @@ bool MergeJoin::prepareOutput(
500511 isRightFlattened_ = true ;
501512 } else {
502513 for (const auto & projection : rightProjections_) {
514+ auto column = right->childAt (projection.inputChannel );
515+ column->clearContainingLazyAndWrapped ();
503516 localColumns[projection.outputChannel ] = BaseVector::wrapInDictionary (
504- {},
505- rightOutputIndices_,
506- outputBatchSize_,
507- right->childAt (projection.inputChannel ));
517+ {}, rightOutputIndices_, outputBatchSize_, column);
508518 }
509519 isRightFlattened_ = false ;
510520 }
@@ -568,6 +578,39 @@ bool MergeJoin::prepareOutput(
568578bool MergeJoin::addToOutput () {
569579 if (isRightJoin (joinType_) || isRightSemiFilterJoin (joinType_)) {
570580 return addToOutputForRightJoin ();
581+ } else if (isFullJoin (joinType_) && filter_) {
582+ if (!leftForRightJoinMatch_) {
583+ leftForRightJoinMatch_ = leftMatch_;
584+ rightForRightJoinMatch_ = rightMatch_;
585+ }
586+
587+ if (leftMatch_ && rightMatch_ && !leftJoinForFullFinished_) {
588+ auto left = addToOutputForLeftJoin ();
589+ if (!leftMatch_) {
590+ leftJoinForFullFinished_ = true ;
591+ }
592+ if (left) {
593+ if (!leftMatch_) {
594+ leftMatch_ = leftForRightJoinMatch_;
595+ rightMatch_ = rightForRightJoinMatch_;
596+ }
597+
598+ return true ;
599+ }
600+ }
601+
602+ if (!leftMatch_ && !rightJoinForFullFinished_) {
603+ leftMatch_ = leftForRightJoinMatch_;
604+ rightMatch_ = rightForRightJoinMatch_;
605+ rightJoinForFullFinished_ = true ;
606+ }
607+
608+ auto right = addToOutputForRightJoin ();
609+
610+ leftForRightJoinMatch_ = leftMatch_;
611+ rightForRightJoinMatch_ = rightMatch_;
612+
613+ return right;
571614 } else {
572615 return addToOutputForLeftJoin ();
573616 }
@@ -660,7 +703,13 @@ bool MergeJoin::addToOutputImpl() {
660703 } else {
661704 for (auto innerRow = innerStartRow; innerRow < innerEndRow;
662705 ++innerRow) {
663- if (!tryAddOutputRow (leftBatch, innerRow, rightBatch, outerRow)) {
706+ const auto isRightJoinForFullOuter = isFullJoin (joinType_);
707+ if (!tryAddOutputRow (
708+ leftBatch,
709+ innerRow,
710+ rightBatch,
711+ outerRow,
712+ isRightJoinForFullOuter)) {
664713 outerMatch->setCursor (outerBatchIndex, outerRow);
665714 innerMatch->setCursor (innerBatchIndex, innerRow);
666715 return true ;
@@ -931,7 +980,7 @@ RowVectorPtr MergeJoin::doGetOutput() {
931980 isFullJoin (joinType_)) {
932981 // If output_ is currently wrapping a different buffer, return it
933982 // first.
934- if (prepareOutput (input_, nullptr )) {
983+ if (prepareOutput (input_, rightInput_ )) {
935984 output_->resize (outputSize_);
936985 return std::move (output_);
937986 }
@@ -956,7 +1005,7 @@ RowVectorPtr MergeJoin::doGetOutput() {
9561005 if (isRightJoin (joinType_) || isFullJoin (joinType_)) {
9571006 // If output_ is currently wrapping a different buffer, return it
9581007 // first.
959- if (prepareOutput (nullptr , rightInput_)) {
1008+ if (prepareOutput (input_ , rightInput_)) {
9601009 output_->resize (outputSize_);
9611010 return std::move (output_);
9621011 }
@@ -1003,6 +1052,8 @@ RowVectorPtr MergeJoin::doGetOutput() {
10031052 endRightRow < rightInput_->size (),
10041053 std::nullopt };
10051054
1055+ leftJoinForFullFinished_ = false ;
1056+ rightJoinForFullFinished_ = false ;
10061057 if (!leftMatch_->complete || !rightMatch_->complete ) {
10071058 if (!leftMatch_->complete ) {
10081059 // Need to continue looking for the end of match.
@@ -1264,8 +1315,6 @@ void MergeJoin::clearRightInput() {
12641315RowVectorPtr MergeJoin::applyFilter (const RowVectorPtr& output) {
12651316 const auto numRows = output->size ();
12661317
1267- RowVectorPtr fullOuterOutput = nullptr ;
1268-
12691318 BufferPtr indices = allocateIndices (numRows, pool ());
12701319 auto * rawIndices = indices->asMutable <vector_size_t >();
12711320 vector_size_t numPassed = 0 ;
@@ -1282,84 +1331,41 @@ RowVectorPtr MergeJoin::applyFilter(const RowVectorPtr& output) {
12821331
12831332 // If all matches for a given left-side row fail the filter, add a row to
12841333 // the output with nulls for the right-side columns.
1285- const auto onMiss = [&](auto row) {
1334+ const auto onMiss = [&](auto row, bool isRightJoinForFullOuter ) {
12861335 if (isSemiFilterJoin (joinType_)) {
12871336 return ;
12881337 }
12891338 rawIndices[numPassed++] = row;
12901339
1291- if (isFullJoin (joinType_)) {
1292- // For filtered rows, it is necessary to insert additional data
1293- // to ensure the result set is complete. Specifically, we
1294- // need to generate two records: one record containing the
1295- // columns from the left table along with nulls for the
1296- // right table, and another record containing the columns
1297- // from the right table along with nulls for the left table.
1298- // For instance, the current output is filtered based on the condition
1299- // t > 1.
1300-
1301- // 1, 1
1302- // 2, 2
1303- // 3, 3
1304-
1305- // In this scenario, we need to additionally insert a record 1, 1.
1306- // Subsequently, we will set the values of the columns on the left to
1307- // null and the values of the columns on the right to null as well. By
1308- // doing so, we will obtain the final result set.
1309-
1310- // 1, null
1311- // null, 1
1312- // 2, 2
1313- // 3, 3
1314- fullOuterOutput = BaseVector::create<RowVector>(
1315- output->type (), output->size () + 1 , pool ());
1316-
1317- for (auto i = 0 ; i < row + 1 ; ++i) {
1318- for (auto j = 0 ; j < output->type ()->size (); ++j) {
1319- fullOuterOutput->childAt (j)->copy (
1320- output->childAt (j).get (), i, i, 1 );
1340+ if (!isRightJoin (joinType_)) {
1341+ if (isFullJoin (joinType_) && isRightJoinForFullOuter) {
1342+ for (auto & projection : leftProjections_) {
1343+ auto target = output->childAt (projection.outputChannel );
1344+ target->setNull (row, true );
13211345 }
1322- }
1323-
1324- for (auto j = 0 ; j < output->type ()->size (); ++j) {
1325- fullOuterOutput->childAt (j)->copy (
1326- output->childAt (j).get (), row + 1 , row, 1 );
1327- }
1328-
1329- for (auto i = row + 1 ; i < output->size (); ++i) {
1330- for (auto j = 0 ; j < output->type ()->size (); ++j) {
1331- fullOuterOutput->childAt (j)->copy (
1332- output->childAt (j).get (), i + 1 , i, 1 );
1346+ } else {
1347+ for (auto & projection : rightProjections_) {
1348+ auto target = output->childAt (projection.outputChannel );
1349+ target->setNull (row, true );
13331350 }
13341351 }
1335-
1336- for (auto & projection : leftProjections_) {
1337- auto & target = fullOuterOutput->childAt (projection.outputChannel );
1338- target->setNull (row, true );
1339- }
1340-
1341- for (auto & projection : rightProjections_) {
1342- auto & target = fullOuterOutput->childAt (projection.outputChannel );
1343- target->setNull (row + 1 , true );
1344- }
1345- } else if (!isRightJoin (joinType_)) {
1346- for (auto & projection : rightProjections_) {
1347- auto & target = output->childAt (projection.outputChannel );
1348- target->setNull (row, true );
1349- }
13501352 } else {
13511353 for (auto & projection : leftProjections_) {
1352- auto & target = output->childAt (projection.outputChannel );
1354+ auto target = output->childAt (projection.outputChannel );
13531355 target->setNull (row, true );
13541356 }
13551357 }
13561358 };
13571359
13581360 auto onMatch = [&](auto row, bool firstMatch) {
1359- const bool isNonSemiAntiJoin =
1360- !isSemiFilterJoin (joinType_) && !isAntiJoin (joinType_);
1361+ const bool isFullLeftJoin =
1362+ isFullJoin (joinType_) && !joinTracker_->isRightJoinForFullOuter (row);
1363+
1364+ const bool isNonSemiAntiFullJoin = !isSemiFilterJoin (joinType_) &&
1365+ !isAntiJoin (joinType_) && !isFullJoin (joinType_);
13611366
1362- if ((isSemiFilterJoin (joinType_) && firstMatch) || isNonSemiAntiJoin) {
1367+ if ((isSemiFilterJoin (joinType_) && firstMatch) ||
1368+ isNonSemiAntiFullJoin || isFullLeftJoin) {
13631369 rawIndices[numPassed++] = row;
13641370 }
13651371 };
@@ -1420,17 +1426,10 @@ RowVectorPtr MergeJoin::applyFilter(const RowVectorPtr& output) {
14201426
14211427 if (numPassed == numRows) {
14221428 // All rows passed.
1423- if (fullOuterOutput) {
1424- return fullOuterOutput;
1425- }
14261429 return output;
14271430 }
14281431
14291432 // Some, but not all rows passed.
1430- if (fullOuterOutput) {
1431- return wrap (numPassed, indices, fullOuterOutput);
1432- }
1433-
14341433 return wrap (numPassed, indices, output);
14351434}
14361435
0 commit comments