[X86][SSE] Generalized SplitBinaryOpsAndApply to SplitOpsAndApply to support any...
authorSimon Pilgrim <llvm-dev@redking.me.uk>
Sun, 11 Mar 2018 14:04:53 +0000 (14:04 +0000)
committerSimon Pilgrim <llvm-dev@redking.me.uk>
Sun, 11 Mar 2018 14:04:53 +0000 (14:04 +0000)
I've kept SplitBinaryOpsAndApply as a wrapper to avoid a lot of makeArrayRef code.

llvm-svn: 327240

llvm/lib/Target/X86/X86ISelLowering.cpp

index 89ad2bb..3836158 100644 (file)
@@ -5095,18 +5095,17 @@ static SDValue widenSubVector(MVT VT, SDValue Vec, bool ZeroNewElements,
                      DAG.getIntPtrConstant(0, dl));
 }
 
-// Helper for splitting operands of a binary operation to legal target size and
+// Helper for splitting operands of an operation to legal target size and
 // apply a function on each part.
 // Useful for operations that are available on SSE2 in 128-bit, on AVX2 in
-// 256-bit and on AVX512BW in 512-bit.
-// The argument VT is the type used for deciding if/how to split the operands
-// Op0 and Op1. Op0 and Op1 do *not* have to be of type VT.
-// The argument Builder is a function that will be applied on each split psrt:
-// SDValue Builder(SelectionDAG&G, SDLoc, SDValue, SDValue)
+// 256-bit and on AVX512BW in 512-bit. The argument VT is the type used for
+// deciding if/how to split Ops. Ops elements do *not* have to be of type VT.
+// The argument Builder is a function that will be applied on each split part:
+// SDValue Builder(SelectionDAG&G, SDLoc, ArrayRef<SDValue>)
 template <typename F>
-SDValue SplitBinaryOpsAndApply(SelectionDAG &DAG, const X86Subtarget &Subtarget,
-                               const SDLoc &DL, EVT VT, SDValue Op0,
-                               SDValue Op1, F Builder) {
+SDValue SplitOpsAndApply(SelectionDAG &DAG, const X86Subtarget &Subtarget,
+                         const SDLoc &DL, EVT VT, ArrayRef<SDValue> Ops,
+                         F Builder) {
   assert(Subtarget.hasSSE2() && "Target assumed to support at least SSE2");
   unsigned NumSubs = 1;
   if (Subtarget.useBWIRegs()) {
@@ -5127,21 +5126,32 @@ SDValue SplitBinaryOpsAndApply(SelectionDAG &DAG, const X86Subtarget &Subtarget,
   }
 
   if (NumSubs == 1)
-    return Builder(DAG, DL, Op0, Op1);
+    return Builder(DAG, DL, Ops);
 
   SmallVector<SDValue, 4> Subs;
-  EVT InVT = Op0.getValueType();
-  EVT SubVT = EVT::getVectorVT(*DAG.getContext(), InVT.getScalarType(),
-                               InVT.getVectorNumElements() / NumSubs);
   for (unsigned i = 0; i != NumSubs; ++i) {
-    unsigned Idx = i * SubVT.getVectorNumElements();
-    SDValue LHS = extractSubVector(Op0, Idx, DAG, DL, SubVT.getSizeInBits());
-    SDValue RHS = extractSubVector(Op1, Idx, DAG, DL, SubVT.getSizeInBits());
-    Subs.push_back(Builder(DAG, DL, LHS, RHS));
+    SmallVector<SDValue, 2> SubOps;
+    for (SDValue Op : Ops) {
+      EVT OpVT = Op.getValueType();
+      unsigned NumSubElts = OpVT.getVectorNumElements() / NumSubs;
+      unsigned SizeSub = OpVT.getSizeInBits() / NumSubs;
+      SubOps.push_back(extractSubVector(Op, i * NumSubElts, DAG, DL, SizeSub));
+    }
+    Subs.push_back(Builder(DAG, DL, SubOps));
   }
   return DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, Subs);
 }
 
+// Helper for splitting operands of a binary operation to legal target size and
+// apply a function on each part.
+template <typename F>
+SDValue SplitBinaryOpsAndApply(SelectionDAG &DAG, const X86Subtarget &Subtarget,
+                               const SDLoc &DL, EVT VT, SDValue Op0,
+                               SDValue Op1, F Builder) {
+  return SplitOpsAndApply(DAG, Subtarget, DL, VT, makeArrayRef({Op0, Op1}),
+                          Builder);
+}
+
 // Return true if the instruction zeroes the unused upper part of the
 // destination and accepts mask.
 static bool isMaskedZeroUpperBitsvXi1(unsigned int Opcode) {
@@ -31249,10 +31259,10 @@ static SDValue createPSADBW(SelectionDAG &DAG, const SDValue &Zext0,
   SDValue SadOp1 = DAG.getNode(ISD::CONCAT_VECTORS, DL, ExtendedVT, Ops);
 
   // Actually build the SAD, split as 128/256/512 bits for SSE/AVX2/AVX512BW.
-  auto PSADBWBuilder = [](SelectionDAG &DAG, const SDLoc &DL, SDValue Op0,
-                          SDValue Op1) {
-    MVT VT = MVT::getVectorVT(MVT::i64, Op0.getValueSizeInBits() / 64);
-    return DAG.getNode(X86ISD::PSADBW, DL, VT, Op0, Op1);
+  auto PSADBWBuilder = [](SelectionDAG &DAG, const SDLoc &DL,
+                          ArrayRef<SDValue> Ops) {
+    MVT VT = MVT::getVectorVT(MVT::i64, Ops[0].getValueSizeInBits() / 64);
+    return DAG.getNode(X86ISD::PSADBW, DL, VT, Ops);
   };
   MVT SadVT = MVT::getVectorVT(MVT::i64, RegSize / 64);
   return SplitBinaryOpsAndApply(DAG, Subtarget, DL, SadVT, SadOp0, SadOp1,
@@ -32079,9 +32089,9 @@ static SDValue combineSelect(SDNode *N, SelectionDAG &DAG,
       SDValue OpLHS = Other->getOperand(0), OpRHS = Other->getOperand(1);
       SDValue CondRHS = Cond->getOperand(1);
 
-      auto SUBUSBuilder = [](SelectionDAG &DAG, const SDLoc &DL, SDValue Op0,
-                             SDValue Op1) {
-        return DAG.getNode(X86ISD::SUBUS, DL, Op0.getValueType(), Op0, Op1);
+      auto SUBUSBuilder = [](SelectionDAG &DAG, const SDLoc &DL,
+                             ArrayRef<SDValue> Ops) {
+        return DAG.getNode(X86ISD::SUBUS, DL, Ops[0].getValueType(), Ops);
       };
 
       // Look for a general sub with unsigned saturation first.
@@ -33035,10 +33045,10 @@ static SDValue combineMulToPMADDWD(SDNode *N, SelectionDAG &DAG,
     return SDValue();
 
   // Use SplitBinaryOpsAndApply to handle AVX splitting.
-  auto PMADDWDBuilder = [](SelectionDAG &DAG, const SDLoc &DL, SDValue Op0,
-                           SDValue Op1) {
-    MVT VT = MVT::getVectorVT(MVT::i32, Op0.getValueSizeInBits() / 32);
-    return DAG.getNode(X86ISD::VPMADDWD, DL, VT, Op0, Op1);
+  auto PMADDWDBuilder = [](SelectionDAG &DAG, const SDLoc &DL,
+                           ArrayRef<SDValue> Ops) {
+    MVT VT = MVT::getVectorVT(MVT::i32, Ops[0].getValueSizeInBits() / 32);
+    return DAG.getNode(X86ISD::VPMADDWD, DL, VT, Ops);
   };
   return SplitBinaryOpsAndApply(DAG, Subtarget, SDLoc(N), VT,
                                 DAG.getBitcast(WVT, N0),
@@ -34672,9 +34682,9 @@ static SDValue detectAVGPattern(SDValue In, EVT VT, SelectionDAG &DAG,
   Operands[0] = LHS.getOperand(0);
   Operands[1] = LHS.getOperand(1);
 
-  auto AVGBuilder = [](SelectionDAG &DAG, const SDLoc &DL, SDValue Op0,
-                       SDValue Op1) {
-    return DAG.getNode(X86ISD::AVG, DL, Op0.getValueType(), Op0, Op1);
+  auto AVGBuilder = [](SelectionDAG &DAG, const SDLoc &DL,
+                       ArrayRef<SDValue> Ops) {
+    return DAG.getNode(X86ISD::AVG, DL, Ops[0].getValueType(), Ops);
   };
 
   // Take care of the case when one of the operands is a constant vector whose
@@ -37705,10 +37715,10 @@ static SDValue combineLoopMAddPattern(SDNode *N, SelectionDAG &DAG,
   SDValue N1 = DAG.getNode(ISD::TRUNCATE, DL, ReducedVT, MulOp->getOperand(1));
 
   // Madd vector size is half of the original vector size
-  auto PMADDWDBuilder = [](SelectionDAG &DAG, const SDLoc &DL, SDValue Op0,
-                           SDValue Op1) {
-    MVT VT = MVT::getVectorVT(MVT::i32, Op0.getValueSizeInBits() / 32);
-    return DAG.getNode(X86ISD::VPMADDWD, DL, VT, Op0, Op1);
+  auto PMADDWDBuilder = [](SelectionDAG &DAG, const SDLoc &DL,
+                           ArrayRef<SDValue> Ops) {
+    MVT VT = MVT::getVectorVT(MVT::i32, Ops[0].getValueSizeInBits() / 32);
+    return DAG.getNode(X86ISD::VPMADDWD, DL, VT, Ops);
   };
   SDValue Madd = SplitBinaryOpsAndApply(DAG, Subtarget, DL, MAddVT, N0, N1,
                                         PMADDWDBuilder);
@@ -37904,21 +37914,21 @@ static SDValue matchPMADDWD(SelectionDAG &DAG, SDValue Op0, SDValue Op1,
   if (!canReduceVMulWidth(Mul.getNode(), DAG, Mode) || Mode == MULU16)
     return SDValue();
 
-  auto PMADDBuilder = [](SelectionDAG &DAG, const SDLoc &DL, SDValue Op0,
-                         SDValue Op1) {
+  auto PMADDBuilder = [](SelectionDAG &DAG, const SDLoc &DL,
+                         ArrayRef<SDValue> Ops) {
     // Shrink by adding truncate nodes and let DAGCombine fold with the
     // sources.
-    EVT InVT = Op0.getValueType();
+    EVT InVT = Ops[0].getValueType();
     assert(InVT.getScalarType() == MVT::i32 &&
            "Unexpected scalar element type");
-    assert(InVT == Op1.getValueType() && "Operands' types mismatch");
+    assert(InVT == Ops[1].getValueType() && "Operands' types mismatch");
     EVT ResVT = EVT::getVectorVT(*DAG.getContext(), MVT::i32,
                                  InVT.getVectorNumElements() / 2);
     EVT TruncVT = EVT::getVectorVT(*DAG.getContext(), MVT::i16,
                                    InVT.getVectorNumElements());
     return DAG.getNode(X86ISD::VPMADDWD, DL, ResVT,
-                       DAG.getNode(ISD::TRUNCATE, DL, TruncVT, Op0),
-                       DAG.getNode(ISD::TRUNCATE, DL, TruncVT, Op1));
+                       DAG.getNode(ISD::TRUNCATE, DL, TruncVT, Ops[0]),
+                       DAG.getNode(ISD::TRUNCATE, DL, TruncVT, Ops[1]));
   };
   return SplitBinaryOpsAndApply(DAG, Subtarget, DL, VT, Mul.getOperand(0),
                                 Mul.getOperand(1), PMADDBuilder);
@@ -37994,9 +38004,9 @@ static SDValue combineSubToSubus(SDNode *N, SelectionDAG &DAG,
   } else
     return SDValue();
 
-  auto SUBUSBuilder = [](SelectionDAG &DAG, const SDLoc &DL, SDValue Op0,
-                         SDValue Op1) {
-    return DAG.getNode(X86ISD::SUBUS, DL, Op0.getValueType(), Op0, Op1);
+  auto SUBUSBuilder = [](SelectionDAG &DAG, const SDLoc &DL,
+                         ArrayRef<SDValue> Ops) {
+    return DAG.getNode(X86ISD::SUBUS, DL, Ops[0].getValueType(), Ops);
   };
 
   // PSUBUS doesn't support v8i32/v8i64/v16i32, but it can be enabled with