[DAG] visitVECTOR_SHUFFLE - move shuffle legality check into MergeInnerShuffle lamda...
authorSimon Pilgrim <llvm-dev@redking.me.uk>
Mon, 8 Feb 2021 13:46:31 +0000 (13:46 +0000)
committerSimon Pilgrim <llvm-dev@redking.me.uk>
Mon, 8 Feb 2021 14:25:16 +0000 (14:25 +0000)
This is going to be necessary for a future reuse of MergeInnerShuffle

llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp

index 3a3ae67..a17ac6f 100644 (file)
@@ -20836,10 +20836,10 @@ SDValue DAGCombiner::visitVECTOR_SHUFFLE(SDNode *N) {
   // Compute the combined shuffle mask for a shuffle with SV0 as the first
   // operand, and SV1 as the second operand.
   // i.e. Merge SVN(OtherSVN, N1) -> shuffle(SV0, SV1, Mask).
-  auto MergeInnerShuffle = [NumElts](ShuffleVectorSDNode *SVN,
-                                     ShuffleVectorSDNode *OtherSVN, SDValue N1,
-                                     SDValue &SV0, SDValue &SV1,
-                                     SmallVectorImpl<int> &Mask) -> bool {
+  auto MergeInnerShuffle =
+      [NumElts, &VT](ShuffleVectorSDNode *SVN, ShuffleVectorSDNode *OtherSVN,
+                     SDValue N1, const TargetLowering &TLI, SDValue &SV0,
+                     SDValue &SV1, SmallVectorImpl<int> &Mask) -> bool {
     // Don't try to fold splats; they're likely to simplify somehow, or they
     // might be free.
     if (OtherSVN->isSplat())
@@ -20926,7 +20926,23 @@ SDValue DAGCombiner::visitVECTOR_SHUFFLE(SDNode *N) {
       // Bail out if we cannot convert the shuffle pair into a single shuffle.
       return false;
     }
-    return true;
+
+    if (llvm::all_of(Mask, [](int M) { return M < 0; }))
+      return true;
+
+    // Avoid introducing shuffles with illegal mask.
+    //   shuffle(shuffle(A, B, M0), C, M1) -> shuffle(A, B, M2)
+    //   shuffle(shuffle(A, B, M0), C, M1) -> shuffle(A, C, M2)
+    //   shuffle(shuffle(A, B, M0), C, M1) -> shuffle(B, C, M2)
+    //   shuffle(shuffle(A, B, M0), C, M1) -> shuffle(B, A, M2)
+    //   shuffle(shuffle(A, B, M0), C, M1) -> shuffle(C, A, M2)
+    //   shuffle(shuffle(A, B, M0), C, M1) -> shuffle(C, B, M2)
+    if (TLI.isShuffleMaskLegal(Mask, VT))
+      return true;
+
+    std::swap(SV0, SV1);
+    ShuffleVectorSDNode::commuteMask(Mask);
+    return TLI.isShuffleMaskLegal(Mask, VT);
   };
 
   // Try to fold according to rules:
@@ -20937,33 +20953,21 @@ SDValue DAGCombiner::visitVECTOR_SHUFFLE(SDNode *N) {
   // Only fold if this shuffle is the only user of the other shuffle.
   if (N0.getOpcode() == ISD::VECTOR_SHUFFLE && N->isOnlyUserOf(N0.getNode()) &&
       Level < AfterLegalizeDAG && TLI.isTypeLegal(VT)) {
-    ShuffleVectorSDNode *OtherSV = cast<ShuffleVectorSDNode>(N0);
-
     // The incoming shuffle must be of the same type as the result of the
     // current shuffle.
+    auto *OtherSV = cast<ShuffleVectorSDNode>(N0);
     assert(OtherSV->getOperand(0).getValueType() == VT &&
            "Shuffle types don't match");
 
     SDValue SV0, SV1;
     SmallVector<int, 4> Mask;
-    if (MergeInnerShuffle(SVN, OtherSV, N1, SV0, SV1, Mask)) {
+    if (MergeInnerShuffle(SVN, OtherSV, N1, TLI, SV0, SV1, Mask)) {
       // Check if all indices in Mask are Undef. In case, propagate Undef.
       if (llvm::all_of(Mask, [](int M) { return M < 0; }))
         return DAG.getUNDEF(VT);
 
-      if (!SV0.getNode())
-        SV0 = DAG.getUNDEF(VT);
-      if (!SV1.getNode())
-        SV1 = DAG.getUNDEF(VT);
-
-      // Avoid introducing shuffles with illegal mask.
-      //   shuffle(shuffle(A, B, M0), C, M1) -> shuffle(A, B, M2)
-      //   shuffle(shuffle(A, B, M0), C, M1) -> shuffle(A, C, M2)
-      //   shuffle(shuffle(A, B, M0), C, M1) -> shuffle(B, C, M2)
-      //   shuffle(shuffle(A, B, M0), C, M1) -> shuffle(B, A, M2)
-      //   shuffle(shuffle(A, B, M0), C, M1) -> shuffle(C, A, M2)
-      //   shuffle(shuffle(A, B, M0), C, M1) -> shuffle(C, B, M2)
-      return TLI.buildLegalVectorShuffle(VT, SDLoc(N), SV0, SV1, Mask, DAG);
+      return DAG.getVectorShuffle(VT, SDLoc(N), SV0 ? SV0 : DAG.getUNDEF(VT),
+                                  SV1 ? SV1 : DAG.getUNDEF(VT), Mask);
     }
   }