[Hexagon] Use BUILD_PAIR instead of HexagonISD::COMBINE in lowering
authorKrzysztof Parzyszek <kparzysz@quicinc.com>
Thu, 17 Nov 2022 18:55:00 +0000 (10:55 -0800)
committerKrzysztof Parzyszek <kparzysz@quicinc.com>
Thu, 17 Nov 2022 20:31:48 +0000 (12:31 -0800)
llvm/lib/Target/Hexagon/HexagonISelLowering.cpp
llvm/lib/Target/Hexagon/HexagonISelLowering.h
llvm/lib/Target/Hexagon/HexagonISelLoweringHVX.cpp

index 046f82c..45a3b3c 100644 (file)
@@ -1808,6 +1808,7 @@ HexagonTargetLowering::HexagonTargetLowering(const TargetMachine &TM,
     setOperationAction(ISD::FMUL,    MVT::f64, Legal);
   }
 
+  setTargetDAGCombine(ISD::OR);
   setTargetDAGCombine(ISD::TRUNCATE);
   setTargetDAGCombine(ISD::VSELECT);
 
@@ -2289,15 +2290,15 @@ HexagonTargetLowering::LowerVECTOR_SHUFFLE(SDValue Op, SelectionDAG &DAG)
     }
 
     // Byte packs.
-    SDValue Concat10 = DAG.getNode(HexagonISD::COMBINE, dl,
-                                   typeJoin({ty(Op1), ty(Op0)}), {Op1, Op0});
+    SDValue Concat10 =
+        getCombine(Op1, Op0, dl, typeJoin({ty(Op1), ty(Op0)}), DAG);
     if (MaskIdx == (0x06040200 | MaskUnd))
       return getInstr(Hexagon::S2_vtrunehb, dl, VecTy, {Concat10}, DAG);
     if (MaskIdx == (0x07050301 | MaskUnd))
       return getInstr(Hexagon::S2_vtrunohb, dl, VecTy, {Concat10}, DAG);
 
-    SDValue Concat01 = DAG.getNode(HexagonISD::COMBINE, dl,
-                                   typeJoin({ty(Op0), ty(Op1)}), {Op0, Op1});
+    SDValue Concat01 =
+        getCombine(Op0, Op1, dl, typeJoin({ty(Op0), ty(Op1)}), DAG);
     if (MaskIdx == (0x02000604 | MaskUnd))
       return getInstr(Hexagon::S2_vtrunehb, dl, VecTy, {Concat01}, DAG);
     if (MaskIdx == (0x03010705 | MaskUnd))
@@ -2630,7 +2631,7 @@ HexagonTargetLowering::buildVector64(ArrayRef<SDValue> Elem, const SDLoc &dl,
   SDValue H = (ElemTy == MVT::i32)
                 ? Elem[1]
                 : buildVector32(Elem.drop_front(Num/2), dl, HalfTy, DAG);
-  return DAG.getNode(HexagonISD::COMBINE, dl, VecTy, {H, L});
+  return getCombine(H, L, dl, VecTy, DAG);
 }
 
 SDValue
@@ -2748,8 +2749,7 @@ HexagonTargetLowering::insertVector(SDValue VecV, SDValue ValV, SDValue IdxV,
 
     for (unsigned R = Scale; R > 1; R /= 2) {
       ValR = contractPredicate(ValR, dl, DAG);
-      ValR = DAG.getNode(HexagonISD::COMBINE, dl, MVT::i64,
-                         DAG.getUNDEF(MVT::i32), ValR);
+      ValR = getCombine(DAG.getUNDEF(MVT::i32), ValR, dl, MVT::i64, DAG);
     }
     // The longest possible subvector is at most 32 bits, so it is always
     // contained in the low subregister.
@@ -2857,6 +2857,28 @@ HexagonTargetLowering::appendUndef(SDValue Val, MVT ResTy, SelectionDAG &DAG)
 }
 
 SDValue
+HexagonTargetLowering::getCombine(SDValue Hi, SDValue Lo, const SDLoc &dl,
+                                  MVT ResTy, SelectionDAG &DAG) const {
+  MVT ElemTy = ty(Hi);
+  assert(ElemTy == ty(Lo));
+
+  if (!ElemTy.isVector()) {
+    assert(ElemTy.isScalarInteger());
+    MVT PairTy = MVT::getIntegerVT(2 * ElemTy.getSizeInBits());
+    SDValue Pair = DAG.getNode(ISD::BUILD_PAIR, dl, PairTy, Lo, Hi);
+    return DAG.getBitcast(ResTy, Pair);
+  }
+
+  unsigned Width = ElemTy.getSizeInBits();
+  MVT IntTy = MVT::getIntegerVT(Width);
+  MVT PairTy = MVT::getIntegerVT(2 * Width);
+  SDValue Pair =
+      DAG.getNode(ISD::BUILD_PAIR, dl, PairTy,
+                  {DAG.getBitcast(IntTy, Lo), DAG.getBitcast(IntTy, Hi)});
+  return DAG.getBitcast(ResTy, Pair);
+}
+
+SDValue
 HexagonTargetLowering::LowerBUILD_VECTOR(SDValue Op, SelectionDAG &DAG) const {
   MVT VecTy = ty(Op);
   unsigned BW = VecTy.getSizeInBits();
@@ -2917,8 +2939,7 @@ HexagonTargetLowering::LowerCONCAT_VECTORS(SDValue Op,
   const SDLoc &dl(Op);
   if (VecTy.getSizeInBits() == 64) {
     assert(Op.getNumOperands() == 2);
-    return DAG.getNode(HexagonISD::COMBINE, dl, VecTy, Op.getOperand(1),
-                       Op.getOperand(0));
+    return getCombine(Op.getOperand(1), Op.getOperand(0), dl, VecTy, DAG);
   }
 
   MVT ElemTy = VecTy.getVectorElementType();
@@ -2941,8 +2962,7 @@ HexagonTargetLowering::LowerCONCAT_VECTORS(SDValue Op,
       SDValue W = DAG.getNode(HexagonISD::P2D, dl, MVT::i64, P);
       for (unsigned R = Scale; R > 1; R /= 2) {
         W = contractPredicate(W, dl, DAG);
-        W = DAG.getNode(HexagonISD::COMBINE, dl, MVT::i64,
-                        DAG.getUNDEF(MVT::i32), W);
+        W = getCombine(DAG.getUNDEF(MVT::i32), W, dl, MVT::i64, DAG);
       }
       W = DAG.getTargetExtractSubreg(Hexagon::isub_lo, dl, MVT::i32, W);
       Words[IdxW].push_back(W);
@@ -2966,8 +2986,7 @@ HexagonTargetLowering::LowerCONCAT_VECTORS(SDValue Op,
     // At this point there should only be two words left, and Scale should be 2.
     assert(Scale == 2 && Words[IdxW].size() == 2);
 
-    SDValue WW = DAG.getNode(HexagonISD::COMBINE, dl, MVT::i64,
-                             Words[IdxW][1], Words[IdxW][0]);
+    SDValue WW = getCombine(Words[IdxW][1], Words[IdxW][0], dl, MVT::i64, DAG);
     return DAG.getNode(HexagonISD::D2P, dl, VecTy, WW);
   }
 
@@ -3376,8 +3395,8 @@ HexagonTargetLowering::ReplaceNodeResults(SDNode *N,
 }
 
 SDValue
-HexagonTargetLowering::PerformDAGCombine(SDNode *N, DAGCombinerInfo &DCI)
-      const {
+HexagonTargetLowering::PerformDAGCombine(SDNode *N,
+                                         DAGCombinerInfo &DCI) const {
   if (isHvxOperation(N, DCI.DAG)) {
     if (SDValue V = PerformHvxDAGCombine(N, DCI))
       return V;
@@ -3409,12 +3428,12 @@ HexagonTargetLowering::PerformDAGCombine(SDNode *N, DAGCombinerInfo &DCI)
   if (Opc == HexagonISD::P2D) {
     SDValue P = Op.getOperand(0);
     switch (P.getOpcode()) {
-      case HexagonISD::PTRUE:
-        return DCI.DAG.getConstant(-1, dl, ty(Op));
-      case HexagonISD::PFALSE:
-        return getZero(dl, ty(Op), DCI.DAG);
-      default:
-        break;
+    case HexagonISD::PTRUE:
+      return DCI.DAG.getConstant(-1, dl, ty(Op));
+    case HexagonISD::PFALSE:
+      return getZero(dl, ty(Op), DCI.DAG);
+    default:
+      break;
     }
   } else if (Opc == ISD::VSELECT) {
     // This is pretty much duplicated in HexagonISelLoweringHVX...
@@ -3442,6 +3461,36 @@ HexagonTargetLowering::PerformDAGCombine(SDNode *N, DAGCombinerInfo &DCI)
       if (ty(Elem0).bitsGT(TruncTy))
         return DCI.DAG.getNode(ISD::TRUNCATE, dl, TruncTy, Elem0);
     }
+  } else if (Opc == ISD::OR) {
+    // fold (or (shl xx, s), (zext y)) -> (COMBINE (shl xx, s-32), y)
+    // if s >= 32
+    auto fold0 = [&, this](SDValue Op) {
+      if (ty(Op) != MVT::i64)
+        return SDValue();
+      SDValue Shl = Op.getOperand(0);
+      SDValue Zxt = Op.getOperand(1);
+      if (Shl.getOpcode() != ISD::SHL)
+        std::swap(Shl, Zxt);
+
+      if (Shl.getOpcode() != ISD::SHL || Zxt.getOpcode() != ISD::ZERO_EXTEND)
+        return SDValue();
+
+      SDValue Z = Zxt.getOperand(0);
+      auto *Amt = dyn_cast<ConstantSDNode>(Shl.getOperand(1));
+      if (Amt && Amt->getZExtValue() >= 32 && ty(Z).getSizeInBits() <= 32) {
+        unsigned A = Amt->getZExtValue();
+        SDValue S = Shl.getOperand(0);
+        SDValue T0 = DCI.DAG.getNode(ISD::SHL, dl, ty(S), S,
+                                     DCI.DAG.getConstant(32 - A, dl, MVT::i32));
+        SDValue T1 = DCI.DAG.getZExtOrTrunc(T0, dl, MVT::i32);
+        SDValue T2 = DCI.DAG.getZExtOrTrunc(Z, dl, MVT::i32);
+        return DCI.DAG.getNode(HexagonISD::COMBINE, dl, MVT::i64, {T1, T2});
+      }
+      return SDValue();
+    };
+
+    if (SDValue R = fold0(Op))
+      return R;
   }
 
   return SDValue();
index feeb180..5db849d 100644 (file)
@@ -387,6 +387,8 @@ private:
   SDValue getSplatValue(SDValue Op, SelectionDAG &DAG) const;
   SDValue getVectorShiftByInt(SDValue Op, SelectionDAG &DAG) const;
   SDValue appendUndef(SDValue Val, MVT ResTy, SelectionDAG &DAG) const;
+  SDValue getCombine(SDValue Hi, SDValue Lo, const SDLoc &dl, MVT ResTy,
+                     SelectionDAG &DAG) const;
 
   bool isUndef(SDValue Op) const {
     if (Op.isMachineOpcode())
index ad0e8cd..2d50b62 100644 (file)
@@ -1295,7 +1295,7 @@ HexagonTargetLowering::extractHvxSubvectorReg(SDValue VecV, SDValue IdxV,
 
   SDValue W1Idx = DAG.getConstant(WordIdx+1, dl, MVT::i32);
   SDValue W1 = extractHvxElementReg(WordVec, W1Idx, dl, MVT::i32, DAG);
-  SDValue WW = DAG.getNode(HexagonISD::COMBINE, dl, MVT::i64, {W1, W0});
+  SDValue WW = getCombine(W1, W0, dl, MVT::i64, DAG);
   return DAG.getBitcast(ResTy, WW);
 }
 
@@ -1358,7 +1358,7 @@ HexagonTargetLowering::extractHvxSubvectorPred(SDValue VecV, SDValue IdxV,
   SDValue W0 = DAG.getNode(HexagonISD::VEXTRACTW, dl, MVT::i32, {ShuffV, Zero});
   SDValue W1 = DAG.getNode(HexagonISD::VEXTRACTW, dl, MVT::i32,
                            {ShuffV, DAG.getConstant(4, dl, MVT::i32)});
-  SDValue Vec64 = DAG.getNode(HexagonISD::COMBINE, dl, MVT::v8i8, {W1, W0});
+  SDValue Vec64 = getCombine(W1, W0, dl, MVT::v8i8, DAG);
   return getInstr(Hexagon::A4_vcmpbgtui, dl, ResTy,
                   {Vec64, DAG.getTargetConstant(0, dl, MVT::i32)}, DAG);
 }
@@ -1995,8 +1995,7 @@ HexagonTargetLowering::LowerHvxBitcast(SDValue Op, SelectionDAG &DAG) const {
     SmallVector<SDValue,2> Combines;
     assert(Words.size() % 2 == 0);
     for (unsigned i = 0, e = Words.size(); i < e; i += 2) {
-      SDValue C = DAG.getNode(
-          HexagonISD::COMBINE, dl, MVT::i64, {Words[i+1], Words[i]});
+      SDValue C = getCombine(Words[i+1], Words[i], dl, MVT::i64, DAG);
       Combines.push_back(C);
     }