[DAGCombiner] generalize binop-of-splats scalarization
authorSanjay Patel <spatel@rotateright.com>
Tue, 23 Apr 2019 13:16:41 +0000 (13:16 +0000)
committerSanjay Patel <spatel@rotateright.com>
Tue, 23 Apr 2019 13:16:41 +0000 (13:16 +0000)
If we only match build vectors, we can miss some patterns
that use shuffles as seen in the affected tests.

Note that the underlying calls within getSplatSourceVector()
have the potential for compile-time explosion because of
exponential recursion looking through binop opcodes, but
currently the list of supported opcodes is very limited.
Both of those problems should be addressed in follow-up
patches.

llvm-svn: 358984

llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
llvm/test/CodeGen/X86/scalarize-fp.ll

index 29926a4..d1da7e3 100644 (file)
@@ -18764,59 +18764,51 @@ SDValue DAGCombiner::XformToShuffleWithZero(SDNode *N) {
   return SDValue();
 }
 
-/// If a vector binop is performed on build vector operands that only have one
-/// non-undef element, it may be profitable to extract, scalarize, and insert.
-static SDValue scalarizeBinOpOfBuildVectors(SDNode *N, SelectionDAG &DAG) {
+/// If a vector binop is performed on splat values, it may be profitable to
+/// extract, scalarize, and insert/splat.
+static SDValue scalarizeBinOpOfSplats(SDNode *N, SelectionDAG &DAG) {
   SDValue N0 = N->getOperand(0);
   SDValue N1 = N->getOperand(1);
-  if (N0.getOpcode() != ISD::BUILD_VECTOR || N0.getOpcode() != N1.getOpcode())
-    return SDValue();
-
-  // Return the index of exactly one scalar element in an otherwise undefined
-  // build vector.
-  auto getScalarIndex = [](SDValue V) {
-    int NotUndefIndex = -1;
-    for (unsigned i = 0, e = V.getNumOperands(); i != e; ++i) {
-      // Ignore undef elements.
-      if (V.getOperand(i).isUndef())
-        continue;
-      // There can be only one.
-      if (NotUndefIndex >= 0)
-        return -1;
-      // This might be the only non-undef operand.
-      NotUndefIndex = i;
-    }
-    return NotUndefIndex;
-  };
-  int N0Index = getScalarIndex(N0);
-  if (N0Index == -1)
-    return SDValue();
-  int N1Index = getScalarIndex(N1);
-  if (N1Index == -1)
-    return SDValue();
-
-  SDValue X = N0.getOperand(N0Index);
-  SDValue Y = N1.getOperand(N1Index);
-  EVT ScalarVT = X.getValueType();
-  if (ScalarVT != Y.getValueType())
-    return SDValue();
+  unsigned Opcode = N->getOpcode();
+  EVT VT = N->getValueType(0);
+  EVT EltVT = VT.getVectorElementType();
+  const TargetLowering &TLI = DAG.getTargetLoweringInfo();
 
   // TODO: Remove/replace the extract cost check? If the elements are available
   //       as scalars, then there may be no extract cost. Should we ask if
   //       inserting a scalar back into a vector is cheap instead?
-  EVT VT = N->getValueType(0);
-  const TargetLowering &TLI = DAG.getTargetLoweringInfo();
-  if (N0Index != N1Index || !TLI.isExtractVecEltCheap(VT, N0Index) ||
-      !TLI.isOperationLegalOrCustom(N->getOpcode(), ScalarVT))
+  int Index0, Index1;
+  SDValue Src0 = DAG.getSplatSourceVector(N0, Index0);
+  SDValue Src1 = DAG.getSplatSourceVector(N1, Index1);
+  if (!Src0 || !Src1 || Index0 != Index1 ||
+      Src0.getValueType().getVectorElementType() != EltVT ||
+      Src1.getValueType().getVectorElementType() != EltVT ||
+      !TLI.isExtractVecEltCheap(VT, Index0) ||
+      !TLI.isOperationLegalOrCustom(Opcode, EltVT))
     return SDValue();
 
-  // bo (build_vec ...undef, x, undef...), (build_vec ...undef, y, undef...) -->
-  // build_vec ...undef, (bo x, y), undef...
-  SDValue ScalarBO = DAG.getNode(N->getOpcode(), SDLoc(N), ScalarVT, X, Y,
-                                 N->getFlags());
-  SmallVector<SDValue, 8> Ops(N0.getNumOperands(), DAG.getUNDEF(ScalarVT));
-  Ops[N0Index] = ScalarBO;
-  return DAG.getBuildVector(VT, SDLoc(N), Ops);
+  SDLoc DL(N);
+  SDValue IndexC =
+      DAG.getConstant(Index0, DL, TLI.getVectorIdxTy(DAG.getDataLayout()));
+  SDValue X = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, N0, IndexC);
+  SDValue Y = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, N1, IndexC);
+  SDValue ScalarBO = DAG.getNode(Opcode, DL, EltVT, X, Y, N->getFlags());
+
+  // If all lanes but 1 are undefined, no need to splat the scalar result.
+  // TODO: Keep track of undefs and use that info in the general case.
+  if (N0.getOpcode() == ISD::BUILD_VECTOR && N0.getOpcode() == N1.getOpcode() &&
+      count_if(N0->ops(), [](SDValue V) { return !V.isUndef(); }) == 1 &&
+      count_if(N1->ops(), [](SDValue V) { return !V.isUndef(); }) == 1) {
+    // bo (build_vec ..undef, X, undef...), (build_vec ..undef, Y, undef...) -->
+    // build_vec ..undef, (bo X, Y), undef...
+    SmallVector<SDValue, 8> Ops(VT.getVectorNumElements(), DAG.getUNDEF(EltVT));
+    Ops[Index0] = ScalarBO;
+    return DAG.getBuildVector(VT, DL, Ops);
+  }
+
+  // bo (splat X, Index), (splat Y, Index) --> splat (bo X, Y), Index
+  SmallVector<SDValue, 8> Ops(VT.getVectorNumElements(), ScalarBO);
+  return DAG.getBuildVector(VT, DL, Ops);
 }
 
 /// Visit a binary vector operation, like ADD.
@@ -18881,7 +18873,7 @@ SDValue DAGCombiner::SimplifyVBinOp(SDNode *N) {
     }
   }
 
-  if (SDValue V = scalarizeBinOpOfBuildVectors(N, DAG))
+  if (SDValue V = scalarizeBinOpOfSplats(N, DAG))
     return V;
 
   return SDValue();
index 40eed1f..650b948 100644 (file)
@@ -507,14 +507,14 @@ define <2 x i64> @add_splat_splat_v2i64(<2 x i64> %vx, <2 x i64> %vy) {
 define <2 x double> @fadd_splat_const_op1_v2f64(<2 x double> %vx) {
 ; SSE-LABEL: fadd_splat_const_op1_v2f64:
 ; SSE:       # %bb.0:
+; SSE-NEXT:    addsd {{.*}}(%rip), %xmm0
 ; SSE-NEXT:    unpcklpd {{.*#+}} xmm0 = xmm0[0,0]
-; SSE-NEXT:    addpd {{.*}}(%rip), %xmm0
 ; SSE-NEXT:    retq
 ;
 ; AVX-LABEL: fadd_splat_const_op1_v2f64:
 ; AVX:       # %bb.0:
+; AVX-NEXT:    vaddsd {{.*}}(%rip), %xmm0, %xmm0
 ; AVX-NEXT:    vmovddup {{.*#+}} xmm0 = xmm0[0,0]
-; AVX-NEXT:    vaddpd {{.*}}(%rip), %xmm0, %xmm0
 ; AVX-NEXT:    retq
   %splatx = shufflevector <2 x double> %vx, <2 x double> undef, <2 x i32> zeroinitializer
   %r = fadd <2 x double> %splatx, <double 42.0, double 42.0>
@@ -548,14 +548,14 @@ define <4 x double> @fsub_const_op0_splat_v4f64(double %x) {
 define <4 x float> @fmul_splat_const_op1_v4f32(<4 x float> %vx, <4 x float> %vy) {
 ; SSE-LABEL: fmul_splat_const_op1_v4f32:
 ; SSE:       # %bb.0:
+; SSE-NEXT:    mulss {{.*}}(%rip), %xmm0
 ; SSE-NEXT:    shufps {{.*#+}} xmm0 = xmm0[0,0,0,0]
-; SSE-NEXT:    mulps {{.*}}(%rip), %xmm0
 ; SSE-NEXT:    retq
 ;
 ; AVX-LABEL: fmul_splat_const_op1_v4f32:
 ; AVX:       # %bb.0:
+; AVX-NEXT:    vmulss {{.*}}(%rip), %xmm0, %xmm0
 ; AVX-NEXT:    vpermilps {{.*#+}} xmm0 = xmm0[0,0,0,0]
-; AVX-NEXT:    vmulps {{.*}}(%rip), %xmm0, %xmm0
 ; AVX-NEXT:    retq
   %splatx = shufflevector <4 x float> %vx, <4 x float> undef, <4 x i32> zeroinitializer
   %r = fmul fast <4 x float> %splatx, <float 17.0, float 17.0, float 17.0, float 17.0>
@@ -565,28 +565,18 @@ define <4 x float> @fmul_splat_const_op1_v4f32(<4 x float> %vx, <4 x float> %vy)
 define <8 x float> @fdiv_splat_const_op0_v8f32(<8 x float> %vy) {
 ; SSE-LABEL: fdiv_splat_const_op0_v8f32:
 ; SSE:       # %bb.0:
-; SSE-NEXT:    shufps {{.*#+}} xmm0 = xmm0[0,0,0,0]
-; SSE-NEXT:    rcpps %xmm0, %xmm2
-; SSE-NEXT:    mulps %xmm2, %xmm0
-; SSE-NEXT:    movaps {{.*#+}} xmm1 = [1.0E+0,1.0E+0,1.0E+0,1.0E+0]
-; SSE-NEXT:    subps %xmm0, %xmm1
-; SSE-NEXT:    mulps %xmm2, %xmm1
-; SSE-NEXT:    addps %xmm2, %xmm1
-; SSE-NEXT:    mulps {{.*}}(%rip), %xmm1
+; SSE-NEXT:    movss {{.*#+}} xmm1 = mem[0],zero,zero,zero
+; SSE-NEXT:    divss %xmm0, %xmm1
+; SSE-NEXT:    shufps {{.*#+}} xmm1 = xmm1[0,0,0,0]
 ; SSE-NEXT:    movaps %xmm1, %xmm0
 ; SSE-NEXT:    retq
 ;
 ; AVX-LABEL: fdiv_splat_const_op0_v8f32:
 ; AVX:       # %bb.0:
+; AVX-NEXT:    vmovss {{.*#+}} xmm1 = mem[0],zero,zero,zero
+; AVX-NEXT:    vdivss %xmm0, %xmm1, %xmm0
 ; AVX-NEXT:    vpermilps {{.*#+}} xmm0 = xmm0[0,0,0,0]
 ; AVX-NEXT:    vinsertf128 $1, %xmm0, %ymm0, %ymm0
-; AVX-NEXT:    vrcpps %ymm0, %ymm1
-; AVX-NEXT:    vmulps %ymm1, %ymm0, %ymm0
-; AVX-NEXT:    vmovaps {{.*#+}} ymm2 = [1.0E+0,1.0E+0,1.0E+0,1.0E+0,1.0E+0,1.0E+0,1.0E+0,1.0E+0]
-; AVX-NEXT:    vsubps %ymm0, %ymm2, %ymm0
-; AVX-NEXT:    vmulps %ymm0, %ymm1, %ymm0
-; AVX-NEXT:    vaddps %ymm0, %ymm1, %ymm0
-; AVX-NEXT:    vmulps {{.*}}(%rip), %ymm0, %ymm0
 ; AVX-NEXT:    retq
   %splatx = shufflevector <8 x float> <float 4.5, float 1.0, float 2.0, float 3.0, float 4.0, float 5.0, float 6.0, float 7.0>, <8 x float> undef, <8 x i32> zeroinitializer
   %splaty = shufflevector <8 x float> %vy, <8 x float> undef, <8 x i32> zeroinitializer
@@ -597,22 +587,18 @@ define <8 x float> @fdiv_splat_const_op0_v8f32(<8 x float> %vy) {
 define <8 x float> @fdiv_const_op1_splat_v8f32(<8 x float> %vx) {
 ; SSE-LABEL: fdiv_const_op1_splat_v8f32:
 ; SSE:       # %bb.0:
-; SSE-NEXT:    shufps {{.*#+}} xmm0 = xmm0[0,0,0,0]
 ; SSE-NEXT:    xorps %xmm1, %xmm1
-; SSE-NEXT:    rcpps %xmm1, %xmm1
-; SSE-NEXT:    addps %xmm1, %xmm1
-; SSE-NEXT:    mulps %xmm0, %xmm1
-; SSE-NEXT:    movaps %xmm1, %xmm0
+; SSE-NEXT:    divss %xmm1, %xmm0
+; SSE-NEXT:    shufps {{.*#+}} xmm0 = xmm0[0,0,0,0]
+; SSE-NEXT:    movaps %xmm0, %xmm1
 ; SSE-NEXT:    retq
 ;
 ; AVX-LABEL: fdiv_const_op1_splat_v8f32:
 ; AVX:       # %bb.0:
+; AVX-NEXT:    vxorps %xmm1, %xmm1, %xmm1
+; AVX-NEXT:    vdivss %xmm1, %xmm0, %xmm0
 ; AVX-NEXT:    vpermilps {{.*#+}} xmm0 = xmm0[0,0,0,0]
 ; AVX-NEXT:    vinsertf128 $1, %xmm0, %ymm0, %ymm0
-; AVX-NEXT:    vxorps %xmm1, %xmm1, %xmm1
-; AVX-NEXT:    vrcpps %ymm1, %ymm1
-; AVX-NEXT:    vaddps %ymm1, %ymm1, %ymm1
-; AVX-NEXT:    vmulps %ymm1, %ymm0, %ymm0
 ; AVX-NEXT:    retq
   %splatx = shufflevector <8 x float> %vx, <8 x float> undef, <8 x i32> zeroinitializer
   %splaty = shufflevector <8 x float> <float 0.0, float 1.0, float 2.0, float 3.0, float 4.0, float 5.0, float 6.0, float 7.0>, <8 x float> undef, <8 x i32> zeroinitializer