[DAG] Fold shuffle(bop(shuffle(x,y),shuffle(z,w)),bop(shuffle(a,b),shuffle(c,d)))
authorSimon Pilgrim <llvm-dev@redking.me.uk>
Tue, 16 Feb 2021 15:24:23 +0000 (15:24 +0000)
committerSimon Pilgrim <llvm-dev@redking.me.uk>
Tue, 16 Feb 2021 15:46:34 +0000 (15:46 +0000)
Fold shuffle(bop(shuffle(x,y),shuffle(z,w)),bop(shuffle(a,b),shuffle(c,d))) -> bop(shuffle(x,y),shuffle(z,w)),bop(shuffle(a,b),shuffle(c,d))

Attempt to fold from a shuffle of a pair of binops to a binop of shuffles, as long as one/both of the binop sources are also shuffles that can be merged with the outer shuffle. This should guarantee that we remove one binop without introducing any additional shuffles.

Technically there's potential for a merged shuffle's lowering to be poorer than the original shuffle, but it could also be better, and I'm not seeing any regressions as long as we keep the 'don't merge splats' rule already present in MergeInnerShuffle.

This expands and generalizes an existing X86 combine and attempts to merge either of each binop's sources (with an on-the-fly commutation of the shuffle mask) - we couldn't do that in the x86 version as it had to stay in a form that DAGCombine's MergeInnerShuffle would still recognise.

Differential Revision: https://reviews.llvm.org/D96345

llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
llvm/lib/Target/X86/X86ISelLowering.cpp
llvm/test/CodeGen/X86/bitcast-and-setcc-256.ll
llvm/test/CodeGen/X86/promote-cmp.ll

index e3e79cb..586b3b9 100644 (file)
@@ -20919,11 +20919,13 @@ SDValue DAGCombiner::visitVECTOR_SHUFFLE(SDNode *N) {
 
   // Compute the combined shuffle mask for a shuffle with SV0 as the first
   // operand, and SV1 as the second operand.
-  // i.e. Merge SVN(OtherSVN, N1) -> shuffle(SV0, SV1, Mask).
+  // i.e. Merge SVN(OtherSVN, N1) -> shuffle(SV0, SV1, Mask) iff Commute = false
+  //      Merge SVN(N1, OtherSVN) -> shuffle(SV0, SV1, Mask') iff Commute = true
   auto MergeInnerShuffle =
-      [NumElts, &VT](ShuffleVectorSDNode *SVN, ShuffleVectorSDNode *OtherSVN,
-                     SDValue N1, const TargetLowering &TLI, SDValue &SV0,
-                     SDValue &SV1, SmallVectorImpl<int> &Mask) -> bool {
+      [NumElts, &VT](bool Commute, ShuffleVectorSDNode *SVN,
+                     ShuffleVectorSDNode *OtherSVN, SDValue N1,
+                     const TargetLowering &TLI, SDValue &SV0, SDValue &SV1,
+                     SmallVectorImpl<int> &Mask) -> bool {
     // Don't try to fold splats; they're likely to simplify somehow, or they
     // might be free.
     if (OtherSVN->isSplat())
@@ -20940,6 +20942,9 @@ SDValue DAGCombiner::visitVECTOR_SHUFFLE(SDNode *N) {
         continue;
       }
 
+      if (Commute)
+        Idx = (Idx < (int)NumElts) ? (Idx + NumElts) : (Idx - NumElts);
+
       SDValue CurrentVec;
       if (Idx < (int)NumElts) {
         // This shuffle index refers to the inner shuffle N0. Lookup the inner
@@ -21045,7 +21050,7 @@ SDValue DAGCombiner::visitVECTOR_SHUFFLE(SDNode *N) {
 
     SDValue SV0, SV1;
     SmallVector<int, 4> Mask;
-    if (MergeInnerShuffle(SVN, OtherSV, N1, TLI, SV0, SV1, Mask)) {
+    if (MergeInnerShuffle(false, SVN, OtherSV, N1, TLI, SV0, SV1, Mask)) {
       // Check if all indices in Mask are Undef. In case, propagate Undef.
       if (llvm::all_of(Mask, [](int M) { return M < 0; }))
         return DAG.getUNDEF(VT);
@@ -21055,6 +21060,77 @@ SDValue DAGCombiner::visitVECTOR_SHUFFLE(SDNode *N) {
     }
   }
 
+  // Merge shuffles through binops if we are able to merge it with at least one
+  // other shuffles.
+  // shuffle(bop(shuffle(x,y),shuffle(z,w)),bop(shuffle(a,b),shuffle(c,d)))
+  if (Level < AfterLegalizeDAG && TLI.isTypeLegal(VT)) {
+    unsigned SrcOpcode = N0.getOpcode();
+    if (SrcOpcode == N1.getOpcode() && TLI.isBinOp(SrcOpcode) &&
+        N->isOnlyUserOf(N0.getNode()) && N->isOnlyUserOf(N1.getNode())) {
+      SDValue Op00 = N0.getOperand(0);
+      SDValue Op10 = N1.getOperand(0);
+      SDValue Op01 = N0.getOperand(1);
+      SDValue Op11 = N1.getOperand(1);
+      // TODO: We might be able to relax the VT check but we don't currently
+      // have any isBinOp() that has different result/ops VTs so play safe until
+      // we have test coverage.
+      if (Op00.getValueType() == VT && Op10.getValueType() == VT &&
+          Op01.getValueType() == VT && Op11.getValueType() == VT &&
+          (Op00.getOpcode() == ISD::VECTOR_SHUFFLE ||
+           Op10.getOpcode() == ISD::VECTOR_SHUFFLE ||
+           Op01.getOpcode() == ISD::VECTOR_SHUFFLE ||
+           Op11.getOpcode() == ISD::VECTOR_SHUFFLE)) {
+        auto CanMergeInnerShuffle = [&](SDValue &SV0, SDValue &SV1,
+                                        SmallVectorImpl<int> &Mask, bool LeftOp,
+                                        bool Commute) {
+          SDValue InnerN = Commute ? N1 : N0;
+          SDValue Op0 = LeftOp ? Op00 : Op01;
+          SDValue Op1 = LeftOp ? Op10 : Op11;
+          if (Commute)
+            std::swap(Op0, Op1);
+          return Op0.getOpcode() == ISD::VECTOR_SHUFFLE &&
+                 InnerN->isOnlyUserOf(Op0.getNode()) &&
+                 MergeInnerShuffle(Commute, SVN, cast<ShuffleVectorSDNode>(Op0),
+                                   Op1, TLI, SV0, SV1, Mask) &&
+                 llvm::none_of(Mask, [](int M) { return M < 0; });
+        };
+
+        // Ensure we don't increase the number of shuffles - we must merge a
+        // shuffle from at least one of the LHS and RHS ops.
+        bool MergedLeft = false;
+        SDValue LeftSV0, LeftSV1;
+        SmallVector<int, 4> LeftMask;
+        if (CanMergeInnerShuffle(LeftSV0, LeftSV1, LeftMask, true, false) ||
+            CanMergeInnerShuffle(LeftSV0, LeftSV1, LeftMask, true, true)) {
+          MergedLeft = true;
+        } else {
+          LeftMask.assign(SVN->getMask().begin(), SVN->getMask().end());
+          LeftSV0 = Op00, LeftSV1 = Op10;
+        }
+
+        bool MergedRight = false;
+        SDValue RightSV0, RightSV1;
+        SmallVector<int, 4> RightMask;
+        if (CanMergeInnerShuffle(RightSV0, RightSV1, RightMask, false, false) ||
+            CanMergeInnerShuffle(RightSV0, RightSV1, RightMask, false, true)) {
+          MergedRight = true;
+        } else {
+          RightMask.assign(SVN->getMask().begin(), SVN->getMask().end());
+          RightSV0 = Op01, RightSV1 = Op11;
+        }
+
+        if (MergedLeft || MergedRight) {
+          SDLoc DL(N);
+          SDValue LHS =
+              DAG.getVectorShuffle(VT, DL, LeftSV0, LeftSV1, LeftMask);
+          SDValue RHS =
+              DAG.getVectorShuffle(VT, DL, RightSV0, RightSV1, RightMask);
+          return DAG.getNode(SrcOpcode, DL, VT, LHS, RHS);
+        }
+      }
+    }
+  }
+
   if (SDValue V = foldShuffleOfConcatUndefs(SVN, DAG))
     return V;
 
index 2e5fb35..501274d 100644 (file)
@@ -38029,34 +38029,6 @@ static SDValue combineShuffle(SDNode *N, SelectionDAG &DAG,
 
     if (SDValue HAddSub = foldShuffleOfHorizOp(N, DAG))
       return HAddSub;
-
-    // Merge shuffles through binops if its likely we'll be able to merge it
-    // with other shuffles (as long as they aren't splats).
-    // shuffle(bop(shuffle(x,y),shuffle(z,w)),bop(shuffle(a,b),shuffle(c,d)))
-    // TODO: We might be able to move this to DAGCombiner::visitVECTOR_SHUFFLE.
-    if (auto *SVN = dyn_cast<ShuffleVectorSDNode>(N)) {
-      unsigned SrcOpcode = N->getOperand(0).getOpcode();
-      if (SrcOpcode == N->getOperand(1).getOpcode() && TLI.isBinOp(SrcOpcode) &&
-          N->isOnlyUserOf(N->getOperand(0).getNode()) &&
-          N->isOnlyUserOf(N->getOperand(1).getNode())) {
-        SDValue Op00 = N->getOperand(0).getOperand(0);
-        SDValue Op10 = N->getOperand(1).getOperand(0);
-        SDValue Op01 = N->getOperand(0).getOperand(1);
-        SDValue Op11 = N->getOperand(1).getOperand(1);
-        auto *SVN00 = dyn_cast<ShuffleVectorSDNode>(Op00);
-        auto *SVN10 = dyn_cast<ShuffleVectorSDNode>(Op10);
-        auto *SVN01 = dyn_cast<ShuffleVectorSDNode>(Op01);
-        auto *SVN11 = dyn_cast<ShuffleVectorSDNode>(Op11);
-        if (((SVN00 && !SVN00->isSplat()) || (SVN10 && !SVN10->isSplat())) &&
-            ((SVN01 && !SVN01->isSplat()) || (SVN11 && !SVN11->isSplat()))) {
-          SDLoc DL(N);
-          ArrayRef<int> Mask = SVN->getMask();
-          SDValue LHS = DAG.getVectorShuffle(VT, DL, Op00, Op10, Mask);
-          SDValue RHS = DAG.getVectorShuffle(VT, DL, Op01, Op11, Mask);
-          return DAG.getNode(SrcOpcode, DL, VT, LHS, RHS);
-        }
-      }
-    }
   }
 
   // Attempt to combine into a vector load/broadcast.
index aa8b123..225ebac 100644 (file)
@@ -9,47 +9,41 @@
 define i4 @v4i64(<4 x i64> %a, <4 x i64> %b, <4 x i64> %c, <4 x i64> %d) {
 ; SSE2-SSSE3-LABEL: v4i64:
 ; SSE2-SSSE3:       # %bb.0:
-; SSE2-SSSE3-NEXT:    movdqa {{.*#+}} xmm8 = [2147483648,2147483648]
-; SSE2-SSSE3-NEXT:    pxor %xmm8, %xmm3
-; SSE2-SSSE3-NEXT:    pxor %xmm8, %xmm1
-; SSE2-SSSE3-NEXT:    movdqa %xmm1, %xmm9
-; SSE2-SSSE3-NEXT:    pcmpgtd %xmm3, %xmm9
+; SSE2-SSSE3-NEXT:    movdqa {{.*#+}} xmm9 = [2147483648,2147483648]
+; SSE2-SSSE3-NEXT:    pxor %xmm9, %xmm3
+; SSE2-SSSE3-NEXT:    pxor %xmm9, %xmm1
+; SSE2-SSSE3-NEXT:    movdqa %xmm1, %xmm10
+; SSE2-SSSE3-NEXT:    pcmpgtd %xmm3, %xmm10
+; SSE2-SSSE3-NEXT:    pxor %xmm9, %xmm2
+; SSE2-SSSE3-NEXT:    pxor %xmm9, %xmm0
+; SSE2-SSSE3-NEXT:    movdqa %xmm0, %xmm8
+; SSE2-SSSE3-NEXT:    pcmpgtd %xmm2, %xmm8
+; SSE2-SSSE3-NEXT:    movdqa %xmm8, %xmm11
+; SSE2-SSSE3-NEXT:    shufps {{.*#+}} xmm11 = xmm11[0,2],xmm10[0,2]
 ; SSE2-SSSE3-NEXT:    pcmpeqd %xmm3, %xmm1
-; SSE2-SSSE3-NEXT:    pshufd {{.*#+}} xmm1 = xmm1[1,1,3,3]
-; SSE2-SSSE3-NEXT:    pand %xmm9, %xmm1
-; SSE2-SSSE3-NEXT:    pshufd {{.*#+}} xmm3 = xmm9[1,1,3,3]
-; SSE2-SSSE3-NEXT:    por %xmm1, %xmm3
-; SSE2-SSSE3-NEXT:    pxor %xmm8, %xmm2
-; SSE2-SSSE3-NEXT:    pxor %xmm8, %xmm0
-; SSE2-SSSE3-NEXT:    movdqa %xmm0, %xmm1
-; SSE2-SSSE3-NEXT:    pcmpgtd %xmm2, %xmm1
 ; SSE2-SSSE3-NEXT:    pcmpeqd %xmm2, %xmm0
-; SSE2-SSSE3-NEXT:    pshufd {{.*#+}} xmm2 = xmm0[1,1,3,3]
-; SSE2-SSSE3-NEXT:    pand %xmm1, %xmm2
-; SSE2-SSSE3-NEXT:    pshufd {{.*#+}} xmm0 = xmm1[1,1,3,3]
-; SSE2-SSSE3-NEXT:    por %xmm2, %xmm0
-; SSE2-SSSE3-NEXT:    shufps {{.*#+}} xmm0 = xmm0[0,2],xmm3[0,2]
-; SSE2-SSSE3-NEXT:    pxor %xmm8, %xmm7
-; SSE2-SSSE3-NEXT:    pxor %xmm8, %xmm5
-; SSE2-SSSE3-NEXT:    movdqa %xmm5, %xmm1
-; SSE2-SSSE3-NEXT:    pcmpgtd %xmm7, %xmm1
+; SSE2-SSSE3-NEXT:    shufps {{.*#+}} xmm0 = xmm0[1,3],xmm1[1,3]
+; SSE2-SSSE3-NEXT:    andps %xmm11, %xmm0
+; SSE2-SSSE3-NEXT:    shufps {{.*#+}} xmm8 = xmm8[1,3],xmm10[1,3]
+; SSE2-SSSE3-NEXT:    orps %xmm0, %xmm8
+; SSE2-SSSE3-NEXT:    pxor %xmm9, %xmm7
+; SSE2-SSSE3-NEXT:    pxor %xmm9, %xmm5
+; SSE2-SSSE3-NEXT:    movdqa %xmm5, %xmm0
+; SSE2-SSSE3-NEXT:    pcmpgtd %xmm7, %xmm0
+; SSE2-SSSE3-NEXT:    pxor %xmm9, %xmm6
+; SSE2-SSSE3-NEXT:    pxor %xmm9, %xmm4
+; SSE2-SSSE3-NEXT:    movdqa %xmm4, %xmm1
+; SSE2-SSSE3-NEXT:    pcmpgtd %xmm6, %xmm1
+; SSE2-SSSE3-NEXT:    movdqa %xmm1, %xmm2
+; SSE2-SSSE3-NEXT:    shufps {{.*#+}} xmm2 = xmm2[0,2],xmm0[0,2]
 ; SSE2-SSSE3-NEXT:    pcmpeqd %xmm7, %xmm5
-; SSE2-SSSE3-NEXT:    pshufd {{.*#+}} xmm2 = xmm5[1,1,3,3]
-; SSE2-SSSE3-NEXT:    pand %xmm1, %xmm2
-; SSE2-SSSE3-NEXT:    pshufd {{.*#+}} xmm1 = xmm1[1,1,3,3]
-; SSE2-SSSE3-NEXT:    por %xmm2, %xmm1
-; SSE2-SSSE3-NEXT:    pxor %xmm8, %xmm6
-; SSE2-SSSE3-NEXT:    pxor %xmm8, %xmm4
-; SSE2-SSSE3-NEXT:    movdqa %xmm4, %xmm2
-; SSE2-SSSE3-NEXT:    pcmpgtd %xmm6, %xmm2
 ; SSE2-SSSE3-NEXT:    pcmpeqd %xmm6, %xmm4
-; SSE2-SSSE3-NEXT:    pshufd {{.*#+}} xmm3 = xmm4[1,1,3,3]
-; SSE2-SSSE3-NEXT:    pand %xmm2, %xmm3
-; SSE2-SSSE3-NEXT:    pshufd {{.*#+}} xmm2 = xmm2[1,1,3,3]
-; SSE2-SSSE3-NEXT:    por %xmm3, %xmm2
-; SSE2-SSSE3-NEXT:    shufps {{.*#+}} xmm2 = xmm2[0,2],xmm1[0,2]
-; SSE2-SSSE3-NEXT:    andps %xmm0, %xmm2
-; SSE2-SSSE3-NEXT:    movmskps %xmm2, %eax
+; SSE2-SSSE3-NEXT:    shufps {{.*#+}} xmm4 = xmm4[1,3],xmm5[1,3]
+; SSE2-SSSE3-NEXT:    andps %xmm2, %xmm4
+; SSE2-SSSE3-NEXT:    shufps {{.*#+}} xmm1 = xmm1[1,3],xmm0[1,3]
+; SSE2-SSSE3-NEXT:    orps %xmm4, %xmm1
+; SSE2-SSSE3-NEXT:    andps %xmm8, %xmm1
+; SSE2-SSSE3-NEXT:    movmskps %xmm1, %eax
 ; SSE2-SSSE3-NEXT:    # kill: def $al killed $al killed $eax
 ; SSE2-SSSE3-NEXT:    retq
 ;
index ecf475e..c59f808 100644 (file)
@@ -8,36 +8,34 @@ define <4 x i64> @PR45808(<4 x i64> %0, <4 x i64> %1) {
 ; SSE2-LABEL: PR45808:
 ; SSE2:       # %bb.0:
 ; SSE2-NEXT:    movdqa {{.*#+}} xmm4 = [2147483648,2147483648]
-; SSE2-NEXT:    movdqa %xmm3, %xmm5
-; SSE2-NEXT:    pxor %xmm4, %xmm5
+; SSE2-NEXT:    movdqa %xmm3, %xmm9
+; SSE2-NEXT:    pxor %xmm4, %xmm9
 ; SSE2-NEXT:    movdqa %xmm1, %xmm6
 ; SSE2-NEXT:    pxor %xmm4, %xmm6
-; SSE2-NEXT:    movdqa %xmm6, %xmm7
-; SSE2-NEXT:    pcmpgtd %xmm5, %xmm7
-; SSE2-NEXT:    pcmpeqd %xmm5, %xmm6
-; SSE2-NEXT:    pshufd {{.*#+}} xmm5 = xmm6[1,1,3,3]
-; SSE2-NEXT:    pand %xmm7, %xmm5
-; SSE2-NEXT:    pshufd {{.*#+}} xmm6 = xmm7[1,1,3,3]
-; SSE2-NEXT:    por %xmm5, %xmm6
-; SSE2-NEXT:    movdqa %xmm2, %xmm5
-; SSE2-NEXT:    pxor %xmm4, %xmm5
+; SSE2-NEXT:    movdqa %xmm6, %xmm8
+; SSE2-NEXT:    pcmpgtd %xmm9, %xmm8
+; SSE2-NEXT:    movdqa %xmm2, %xmm7
+; SSE2-NEXT:    pxor %xmm4, %xmm7
 ; SSE2-NEXT:    pxor %xmm0, %xmm4
-; SSE2-NEXT:    movdqa %xmm4, %xmm7
-; SSE2-NEXT:    pcmpgtd %xmm5, %xmm7
-; SSE2-NEXT:    pcmpeqd %xmm5, %xmm4
+; SSE2-NEXT:    movdqa %xmm4, %xmm5
+; SSE2-NEXT:    pcmpgtd %xmm7, %xmm5
+; SSE2-NEXT:    movdqa %xmm5, %xmm10
+; SSE2-NEXT:    shufps {{.*#+}} xmm10 = xmm10[0,2],xmm8[0,2]
+; SSE2-NEXT:    pcmpeqd %xmm9, %xmm6
+; SSE2-NEXT:    pcmpeqd %xmm7, %xmm4
+; SSE2-NEXT:    shufps {{.*#+}} xmm4 = xmm4[1,3],xmm6[1,3]
+; SSE2-NEXT:    andps %xmm10, %xmm4
+; SSE2-NEXT:    shufps {{.*#+}} xmm5 = xmm5[1,3],xmm8[1,3]
+; SSE2-NEXT:    orps %xmm4, %xmm5
+; SSE2-NEXT:    pshufd {{.*#+}} xmm4 = xmm5[2,1,3,3]
+; SSE2-NEXT:    pxor {{.*}}(%rip), %xmm5
+; SSE2-NEXT:    psllq $63, %xmm4
+; SSE2-NEXT:    psrad $31, %xmm4
 ; SSE2-NEXT:    pshufd {{.*#+}} xmm4 = xmm4[1,1,3,3]
-; SSE2-NEXT:    pand %xmm7, %xmm4
-; SSE2-NEXT:    pshufd {{.*#+}} xmm5 = xmm7[1,1,3,3]
-; SSE2-NEXT:    por %xmm4, %xmm5
-; SSE2-NEXT:    pshufd {{.*#+}} xmm4 = xmm5[0,2,2,3]
-; SSE2-NEXT:    pxor {{.*}}(%rip), %xmm4
-; SSE2-NEXT:    psllq $63, %xmm6
-; SSE2-NEXT:    psrad $31, %xmm6
-; SSE2-NEXT:    pshufd {{.*#+}} xmm5 = xmm6[1,1,3,3]
-; SSE2-NEXT:    pand %xmm5, %xmm1
-; SSE2-NEXT:    pandn %xmm3, %xmm5
-; SSE2-NEXT:    por %xmm5, %xmm1
-; SSE2-NEXT:    pshufd {{.*#+}} xmm3 = xmm4[0,1,1,3]
+; SSE2-NEXT:    pand %xmm4, %xmm1
+; SSE2-NEXT:    pandn %xmm3, %xmm4
+; SSE2-NEXT:    por %xmm4, %xmm1
+; SSE2-NEXT:    pshufd {{.*#+}} xmm3 = xmm5[0,1,1,3]
 ; SSE2-NEXT:    psllq $63, %xmm3
 ; SSE2-NEXT:    psrad $31, %xmm3
 ; SSE2-NEXT:    pshufd {{.*#+}} xmm3 = xmm3[1,1,3,3]