1818#include " velox/exec/Task.h"
1919#include " velox/expression/FieldReference.h"
2020
21+ #include < iostream>
22+
2123namespace facebook ::velox::exec {
2224
2325MergeJoin::MergeJoin (
@@ -93,7 +95,7 @@ void MergeJoin::initialize() {
9395 joinNode_->isRightJoin () || joinNode_->isFullJoin ()) {
9496 joinTracker_ = JoinTracker (outputBatchSize_, pool ());
9597 }
96- } else if (joinNode_->isAntiJoin ()) {
98+ } else if (joinNode_->isAntiJoin () || joinNode_-> isFullJoin () ) {
9799 // Anti join needs to track the left side rows that have no match on the
98100 // right.
99101 joinTracker_ = JoinTracker (outputBatchSize_, pool ());
@@ -387,7 +389,8 @@ bool MergeJoin::tryAddOutputRow(
387389 const RowVectorPtr& leftBatch,
388390 vector_size_t leftRow,
389391 const RowVectorPtr& rightBatch,
390- vector_size_t rightRow) {
392+ vector_size_t rightRow,
393+ bool isRightJoinForFullOuter) {
391394 if (outputSize_ == outputBatchSize_) {
392395 return false ;
393396 }
@@ -421,12 +424,15 @@ bool MergeJoin::tryAddOutputRow(
421424 filterRightInputProjections_);
422425
423426 if (joinTracker_) {
424- if (isRightJoin (joinType_)) {
427+ if (isRightJoin (joinType_) ||
428+ (isFullJoin (joinType_) && isRightJoinForFullOuter)) {
425429 // Record right-side row with a match on the left-side.
426- joinTracker_->addMatch (rightBatch, rightRow, outputSize_);
430+ joinTracker_->addMatch (
431+ rightBatch, rightRow, outputSize_, isRightJoinForFullOuter);
427432 } else {
428433 // Record left-side row with a match on the right-side.
429- joinTracker_->addMatch (leftBatch, leftRow, outputSize_);
434+ joinTracker_->addMatch (
435+ leftBatch, leftRow, outputSize_, isRightJoinForFullOuter);
430436 }
431437 }
432438 }
@@ -436,7 +442,8 @@ bool MergeJoin::tryAddOutputRow(
436442 if (isAntiJoin (joinType_)) {
437443 VELOX_CHECK (joinTracker_.has_value ());
438444 // Record left-side row with a match on the right-side.
439- joinTracker_->addMatch (leftBatch, leftRow, outputSize_);
445+ joinTracker_->addMatch (
446+ leftBatch, leftRow, outputSize_, isRightJoinForFullOuter);
440447 }
441448
442449 ++outputSize_;
@@ -455,14 +462,14 @@ bool MergeJoin::prepareOutput(
455462 return true ;
456463 }
457464
458- if (isRightJoin (joinType_) && right != currentRight_) {
459- return true ;
460- }
461-
462465 // If there is a new right, we need to flatten the dictionary.
463466 if (!isRightFlattened_ && right && currentRight_ != right) {
464467 flattenRightProjections ();
465468 }
469+
470+ if (right != currentRight_) {
471+ return true ;
472+ }
466473 return false ;
467474 }
468475
@@ -574,6 +581,39 @@ bool MergeJoin::prepareOutput(
574581bool MergeJoin::addToOutput () {
575582 if (isRightJoin (joinType_) || isRightSemiFilterJoin (joinType_)) {
576583 return addToOutputForRightJoin ();
584+ } else if (isFullJoin (joinType_) && filter_) {
585+ if (!leftForRightJoinMatch_) {
586+ leftForRightJoinMatch_ = leftMatch_;
587+ rightForRightJoinMatch_ = rightMatch_;
588+ }
589+
590+ if (leftMatch_ && rightMatch_ && !leftJoinForFullFinished_) {
591+ auto left = addToOutputForLeftJoin ();
592+ if (!leftMatch_) {
593+ leftJoinForFullFinished_ = true ;
594+ }
595+ if (left) {
596+ if (!leftMatch_) {
597+ leftMatch_ = leftForRightJoinMatch_;
598+ rightMatch_ = rightForRightJoinMatch_;
599+ }
600+
601+ return true ;
602+ }
603+ }
604+
605+ if (!leftMatch_ && !rightJoinForFullFinished_) {
606+ leftMatch_ = leftForRightJoinMatch_;
607+ rightMatch_ = rightForRightJoinMatch_;
608+ rightJoinForFullFinished_ = true ;
609+ }
610+
611+ auto right = addToOutputForRightJoin ();
612+
613+ leftForRightJoinMatch_ = leftMatch_;
614+ rightForRightJoinMatch_ = rightMatch_;
615+
616+ return right;
577617 } else {
578618 return addToOutputForLeftJoin ();
579619 }
@@ -720,7 +760,13 @@ bool MergeJoin::addToOutputForRightJoin() {
720760 }
721761
722762 for (auto j = leftStartRow; j < leftEndRow; ++j) {
723- if (!tryAddOutputRow (leftBatch, j, rightBatch, i)) {
763+ auto isRightJoinForFullOuter = false ;
764+ if (isFullJoin (joinType_)) {
765+ isRightJoinForFullOuter = true ;
766+ }
767+
768+ if (!tryAddOutputRow (
769+ leftBatch, j, rightBatch, i, isRightJoinForFullOuter)) {
724770 // If we run out of space in the current output_, we will need to
725771 // produce a buffer and continue processing left later. In this
726772 // case, we cannot leave left as a lazy vector, since we cannot have
@@ -818,7 +864,7 @@ RowVectorPtr MergeJoin::getOutput() {
818864 continue ;
819865 } else if (isAntiJoin (joinType_)) {
820866 output = filterOutputForAntiJoin (output);
821- if (output) {
867+ if (output != nullptr && output-> size () > 0 ) {
822868 return output;
823869 }
824870
@@ -926,6 +972,8 @@ RowVectorPtr MergeJoin::doGetOutput() {
926972 // results from the current match.
927973 if (addToOutput ()) {
928974 return std::move (output_);
975+ } else {
976+ previousLeftMatch_ = leftMatch_;
929977 }
930978 }
931979
@@ -990,6 +1038,8 @@ RowVectorPtr MergeJoin::doGetOutput() {
9901038
9911039 if (addToOutput ()) {
9921040 return std::move (output_);
1041+ } else {
1042+ previousLeftMatch_ = leftMatch_;
9931043 }
9941044 }
9951045
@@ -1134,7 +1184,7 @@ RowVectorPtr MergeJoin::doGetOutput() {
11341184 isFullJoin (joinType_)) {
11351185 // If output_ is currently wrapping a different buffer, return it
11361186 // first.
1137- if (prepareOutput (input_, nullptr )) {
1187+ if (prepareOutput (input_, rightInput_ )) {
11381188 output_->resize (outputSize_);
11391189 return std::move (output_);
11401190 }
@@ -1159,7 +1209,7 @@ RowVectorPtr MergeJoin::doGetOutput() {
11591209 if (isRightJoin (joinType_) || isFullJoin (joinType_)) {
11601210 // If output_ is currently wrapping a different buffer, return it
11611211 // first.
1162- if (prepareOutput (nullptr , rightInput_)) {
1212+ if (prepareOutput (input_ , rightInput_)) {
11631213 output_->resize (outputSize_);
11641214 return std::move (output_);
11651215 }
@@ -1211,6 +1261,8 @@ RowVectorPtr MergeJoin::doGetOutput() {
12111261 endRightRow < rightInput_->size (),
12121262 std::nullopt };
12131263
1264+ leftJoinForFullFinished_ = false ;
1265+ rightJoinForFullFinished_ = false ;
12141266 if (!leftMatch_->complete || !rightMatch_->complete ) {
12151267 if (!leftMatch_->complete ) {
12161268 // Need to continue looking for the end of match.
@@ -1239,6 +1291,8 @@ RowVectorPtr MergeJoin::doGetOutput() {
12391291
12401292 if (addToOutput ()) {
12411293 return std::move (output_);
1294+ } else {
1295+ previousLeftMatch_ = leftMatch_;
12421296 }
12431297
12441298 if (!rightInput_) {
@@ -1255,8 +1309,6 @@ RowVectorPtr MergeJoin::doGetOutput() {
12551309RowVectorPtr MergeJoin::applyFilter (const RowVectorPtr& output) {
12561310 const auto numRows = output->size ();
12571311
1258- RowVectorPtr fullOuterOutput = nullptr ;
1259-
12601312 BufferPtr indices = allocateIndices (numRows, pool ());
12611313 auto * rawIndices = indices->asMutable <vector_size_t >();
12621314 vector_size_t numPassed = 0 ;
@@ -1273,76 +1325,29 @@ RowVectorPtr MergeJoin::applyFilter(const RowVectorPtr& output) {
12731325
12741326 // If all matches for a given left-side row fail the filter, add a row to
12751327 // the output with nulls for the right-side columns.
1276- const auto onMiss = [&](auto row) {
1277- if (isAntiJoin (joinType_)) {
1278- return ;
1279- }
1280- rawIndices[numPassed++] = row;
1281-
1282- if (isFullJoin (joinType_)) {
1283- // For filtered rows, it is necessary to insert additional data
1284- // to ensure the result set is complete. Specifically, we
1285- // need to generate two records: one record containing the
1286- // columns from the left table along with nulls for the
1287- // right table, and another record containing the columns
1288- // from the right table along with nulls for the left table.
1289- // For instance, the current output is filtered based on the condition
1290- // t > 1.
1291-
1292- // 1, 1
1293- // 2, 2
1294- // 3, 3
1295-
1296- // In this scenario, we need to additionally insert a record 1, 1.
1297- // Subsequently, we will set the values of the columns on the left to
1298- // null and the values of the columns on the right to null as well. By
1299- // doing so, we will obtain the final result set.
1300-
1301- // 1, null
1302- // null, 1
1303- // 2, 2
1304- // 3, 3
1305- fullOuterOutput = BaseVector::create<RowVector>(
1306- output->type (), output->size () + 1 , pool ());
1307-
1308- for (auto i = 0 ; i < row + 1 ; ++i) {
1309- for (auto j = 0 ; j < output->type ()->size (); ++j) {
1310- fullOuterOutput->childAt (j)->copy (
1311- output->childAt (j).get (), i, i, 1 );
1328+ auto onMiss = [&](auto row, bool flag) {
1329+ if (!isLeftSemiFilterJoin (joinType_) &&
1330+ !isRightSemiFilterJoin (joinType_)) {
1331+ rawIndices[numPassed++] = row;
1332+
1333+ if (!isRightJoin (joinType_)) {
1334+ if (isFullJoin (joinType_) && flag) {
1335+ for (auto & projection : leftProjections_) {
1336+ auto target = output->childAt (projection.outputChannel );
1337+ target->setNull (row, true );
1338+ }
1339+ } else {
1340+ for (auto & projection : rightProjections_) {
1341+ auto target = output->childAt (projection.outputChannel );
1342+ target->setNull (row, true );
1343+ }
13121344 }
1313- }
1314-
1315- for (auto j = 0 ; j < output->type ()->size (); ++j) {
1316- fullOuterOutput->childAt (j)->copy (
1317- output->childAt (j).get (), row + 1 , row, 1 );
1318- }
1319-
1320- for (auto i = row + 1 ; i < output->size (); ++i) {
1321- for (auto j = 0 ; j < output->type ()->size (); ++j) {
1322- fullOuterOutput->childAt (j)->copy (
1323- output->childAt (j).get (), i + 1 , i, 1 );
1345+ } else {
1346+ for (auto & projection : leftProjections_) {
1347+ auto target = output->childAt (projection.outputChannel );
1348+ target->setNull (row, true );
13241349 }
13251350 }
1326-
1327- for (auto & projection : leftProjections_) {
1328- auto & target = fullOuterOutput->childAt (projection.outputChannel );
1329- target->setNull (row, true );
1330- }
1331-
1332- for (auto & projection : rightProjections_) {
1333- auto & target = fullOuterOutput->childAt (projection.outputChannel );
1334- target->setNull (row + 1 , true );
1335- }
1336- } else if (!isRightJoin (joinType_)) {
1337- for (auto & projection : rightProjections_) {
1338- auto & target = output->childAt (projection.outputChannel );
1339- target->setNull (row, true );
1340- }
1341- } else {
1342- for (auto & projection : leftProjections_) {
1343- auto & target = output->childAt (projection.outputChannel );
1344- target->setNull (row, true );
1345- }
13461351 }
13471352 };
13481353
@@ -1353,12 +1358,8 @@ RowVectorPtr MergeJoin::applyFilter(const RowVectorPtr& output) {
13531358
13541359 joinTracker_->processFilterResult (i, passed, onMiss);
13551360
1356- if (isAntiJoin (joinType_)) {
1357- if (!passed) {
1358- rawIndices[numPassed++] = i;
1359- }
1360- } else {
1361- if (passed) {
1361+ if (!isAntiJoin (joinType_)) {
1362+ if (passed && !joinTracker_->isRightJoinForFullOuter (i)) {
13621363 rawIndices[numPassed++] = i;
13631364 }
13641365 }
@@ -1371,26 +1372,30 @@ RowVectorPtr MergeJoin::applyFilter(const RowVectorPtr& output) {
13711372
13721373 // Every time we start a new left key match, `processFilterResult()` will
13731374 // check if at least one row from the previous match passed the filter. If
1374- // none did, it calls onMiss to add a record with null right projections to
1375- // the output.
1375+ // none did, it calls onMiss to add a record with null right projections
1376+ // to the output.
13761377 //
13771378 // Before we leave the current buffer, since we may not have seen the next
1378- // left key match yet, the last key match may still be pending to produce a
1379- // row (because `processFilterResult()` was not called yet).
1379+ // left key match yet, the last key match may still be pending to produce
1380+ // a row (because `processFilterResult()` was not called yet).
13801381 //
13811382 // To handle this, we need to call `noMoreFilterResults()` unless the
1382- // same current left key match may continue in the next buffer. So there are
1383- // two cases to check:
1383+ // same current left key match may continue in the next buffer. So there
1384+ // are two cases to check:
13841385 //
1385- // 1. If leftMatch_ is nullopt, there for sure the next buffer will contain
1386- // a different key match.
1386+ // 1. If leftMatch_ is nullopt, there for sure the next buffer will
1387+ // contain a different key match.
13871388 //
13881389 // 2. leftMatch_ may not be nullopt, but may be related to a different
13891390 // (subsequent) left key. So we check if the last row in the batch has the
13901391 // same left row number as the last key match.
13911392 if (!leftMatch_ || !joinTracker_->isCurrentLeftMatch (numRows - 1 )) {
13921393 joinTracker_->noMoreFilterResults (onMiss);
13931394 }
1395+
1396+ if (isAntiJoin (joinType_) && leftMatch_ && !previousLeftMatch_) {
1397+ joinTracker_->noMoreFilterResults (onMiss);
1398+ }
13941399 } else {
13951400 filterRows_.resize (numRows);
13961401 filterRows_.setAll ();
@@ -1412,17 +1417,10 @@ RowVectorPtr MergeJoin::applyFilter(const RowVectorPtr& output) {
14121417
14131418 if (numPassed == numRows) {
14141419 // All rows passed.
1415- if (fullOuterOutput) {
1416- return fullOuterOutput;
1417- }
14181420 return output;
14191421 }
14201422
14211423 // Some, but not all rows passed.
1422- if (fullOuterOutput) {
1423- return wrap (numPassed, indices, fullOuterOutput);
1424- }
1425-
14261424 return wrap (numPassed, indices, output);
14271425}
14281426
0 commit comments