@@ -1027,6 +1027,7 @@ SITargetLowering::SITargetLowering(const TargetMachine &TM,
10271027 ISD::SCALAR_TO_VECTOR,
10281028 ISD::ZERO_EXTEND,
10291029 ISD::SIGN_EXTEND_INREG,
1030+ ISD::ANY_EXTEND,
10301031 ISD::EXTRACT_VECTOR_ELT,
10311032 ISD::INSERT_VECTOR_ELT,
10321033 ISD::FCOPYSIGN});
@@ -13289,6 +13290,20 @@ static uint32_t getPermuteMask(SDValue V) {
1328913290 return ~0;
1329013291}
1329113292
13293+ static SDValue matchPERM(SDNode *N, TargetLowering::DAGCombinerInfo &DCI);
13294+
13295+ SDValue SITargetLowering::performLeftShiftCombine(SDNode *N,
13296+ DAGCombinerInfo &DCI) const {
13297+ if (DCI.getDAGCombineLevel() < AfterLegalizeTypes)
13298+ return SDValue();
13299+
13300+ EVT VT = N->getValueType(0);
13301+ if (VT != MVT::i32)
13302+ return SDValue();
13303+
13304+ return matchPERM(N, DCI);
13305+ }
13306+
1329213307SDValue SITargetLowering::performAndCombine(SDNode *N,
1329313308 DAGCombinerInfo &DCI) const {
1329413309 if (DCI.isBeforeLegalize())
@@ -14330,10 +14345,11 @@ SDValue SITargetLowering::performXorCombine(SDNode *N,
1433014345 return SDValue();
1433114346}
1433214347
14333- SDValue SITargetLowering::performZeroExtendCombine(SDNode *N,
14334- DAGCombinerInfo &DCI) const {
14348+ SDValue
14349+ SITargetLowering::performZeroOrAnyExtendCombine(SDNode *N,
14350+ DAGCombinerInfo &DCI) const {
1433514351 if (!Subtarget->has16BitInsts() ||
14336- DCI.getDAGCombineLevel() < AfterLegalizeDAG )
14352+ DCI.getDAGCombineLevel() < AfterLegalizeTypes )
1433714353 return SDValue();
1433814354
1433914355 EVT VT = N->getValueType(0);
@@ -14344,7 +14360,41 @@ SDValue SITargetLowering::performZeroExtendCombine(SDNode *N,
1434414360 if (Src.getValueType() != MVT::i16)
1434514361 return SDValue();
1434614362
14347- return SDValue();
14363+ // TODO: We bail out below if SrcOffset is not in the first dword (>= 4). It's
14364+ // possible we're missing out on some combine opportunities, but we'd need to
14365+ // weigh the cost of extracting the byte from the upper dwords.
14366+
14367+ std::optional<ByteProvider<SDValue>> BP0 =
14368+ calculateByteProvider(SDValue(N, 0), 0, 0, 0);
14369+ if (!BP0.has_value() || 4 <= BP0->SrcOffset)
14370+ return SDValue();
14371+ SDValue V0 = BP0->Src.value_or(SDValue());
14372+
14373+ std::optional<ByteProvider<SDValue>> BP1 =
14374+ calculateByteProvider(SDValue(N, 0), 1, 0, 1);
14375+ if (!BP1.has_value() || 4 <= BP1->SrcOffset)
14376+ return SDValue();
14377+ SDValue V1 = BP1->Src.value_or(SDValue());
14378+
14379+ if (!V0 || !V1 || V0 == V1)
14380+ return SDValue();
14381+
14382+ SelectionDAG &DAG = DCI.DAG;
14383+ SDLoc DL(N);
14384+ uint32_t PermMask = 0x0c0c0c0c;
14385+ if (V0) {
14386+ V0 = DAG.getBitcastedAnyExtOrTrunc(V0, DL, MVT::i32);
14387+ PermMask = (PermMask & ~0xFF) | (BP0->SrcOffset + 4);
14388+ }
14389+
14390+ if (V1) {
14391+ V1 = DAG.getBitcastedAnyExtOrTrunc(V1, DL, MVT::i32);
14392+ PermMask = (PermMask & ~(0xFF << 8)) | (BP1->SrcOffset << 8);
14393+ }
14394+
14395+ SDValue P = DAG.getNode(AMDGPUISD::PERM, DL, MVT::i32, V0, V1,
14396+ DAG.getConstant(PermMask, DL, MVT::i32));
14397+ return P;
1434814398}
1434914399
1435014400SDValue
@@ -17012,6 +17062,10 @@ SDValue SITargetLowering::PerformDAGCombine(SDNode *N,
1701217062 return performMinMaxCombine(N, DCI);
1701317063 case ISD::FMA:
1701417064 return performFMACombine(N, DCI);
17065+
17066+ case ISD::SHL:
17067+ return performLeftShiftCombine(N, DCI);
17068+
1701517069 case ISD::AND:
1701617070 return performAndCombine(N, DCI);
1701717071 case ISD::OR:
@@ -17026,8 +17080,9 @@ SDValue SITargetLowering::PerformDAGCombine(SDNode *N,
1702617080 }
1702717081 case ISD::XOR:
1702817082 return performXorCombine(N, DCI);
17083+ case ISD::ANY_EXTEND:
1702917084 case ISD::ZERO_EXTEND:
17030- return performZeroExtendCombine (N, DCI);
17085+ return performZeroOrAnyExtendCombine (N, DCI);
1703117086 case ISD::SIGN_EXTEND_INREG:
1703217087 return performSignExtendInRegCombine(N, DCI);
1703317088 case AMDGPUISD::FP_CLASS:
0 commit comments