[DAGCombine] Add undef shuffle elt support to partitionShuffleOfConcats
authorSimon Pilgrim <llvm-dev@redking.me.uk>
Mon, 25 Feb 2019 16:02:01 +0000 (16:02 +0000)
committerSimon Pilgrim <llvm-dev@redking.me.uk>
Mon, 25 Feb 2019 16:02:01 +0000 (16:02 +0000)
Support undef shuffle mask indices in the shuffle(concat_vectors, concat_vectors) -> concat_vectors fold

Differential Revision: https://reviews.llvm.org/D58585

llvm-svn: 354793

llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
llvm/test/CodeGen/X86/subvector-broadcast.ll

index 07214d7..c172fb0 100644 (file)
@@ -17366,20 +17366,24 @@ static SDValue partitionShuffleOfConcats(SDNode *N, SelectionDAG &DAG) {
   SDValue N0 = N->getOperand(0);
   SDValue N1 = N->getOperand(1);
   ShuffleVectorSDNode *SVN = cast<ShuffleVectorSDNode>(N);
+  ArrayRef<int> Mask = SVN->getMask();
 
   SmallVector<SDValue, 4> Ops;
   EVT ConcatVT = N0.getOperand(0).getValueType();
   unsigned NumElemsPerConcat = ConcatVT.getVectorNumElements();
   unsigned NumConcats = NumElts / NumElemsPerConcat;
 
+  auto IsUndefMaskElt = [](int i) { return i == -1; };
+
   // Special case: shuffle(concat(A,B)) can be more efficiently represented
   // as concat(shuffle(A,B),UNDEF) if the shuffle doesn't set any of the high
   // half vector elements.
   if (NumElemsPerConcat * 2 == NumElts && N1.isUndef() &&
-      std::all_of(SVN->getMask().begin() + NumElemsPerConcat,
-                  SVN->getMask().end(), [](int i) { return i == -1; })) {
-    N0 = DAG.getVectorShuffle(ConcatVT, SDLoc(N), N0.getOperand(0), N0.getOperand(1),
-                              makeArrayRef(SVN->getMask().begin(), NumElemsPerConcat));
+      llvm::all_of(Mask.slice(NumElemsPerConcat, NumElemsPerConcat),
+                   IsUndefMaskElt)) {
+    N0 = DAG.getVectorShuffle(ConcatVT, SDLoc(N), N0.getOperand(0),
+                              N0.getOperand(1),
+                              Mask.slice(0, NumElemsPerConcat));
     N1 = DAG.getUNDEF(ConcatVT);
     return DAG.getNode(ISD::CONCAT_VECTORS, SDLoc(N), VT, N0, N1);
   }
@@ -17387,35 +17391,32 @@ static SDValue partitionShuffleOfConcats(SDNode *N, SelectionDAG &DAG) {
   // Look at every vector that's inserted. We're looking for exact
   // subvector-sized copies from a concatenated vector
   for (unsigned I = 0; I != NumConcats; ++I) {
-    // Make sure we're dealing with a copy.
     unsigned Begin = I * NumElemsPerConcat;
-    bool AllUndef = true, NoUndef = true;
-    for (unsigned J = Begin; J != Begin + NumElemsPerConcat; ++J) {
-      if (SVN->getMaskElt(J) >= 0)
-        AllUndef = false;
-      else
-        NoUndef = false;
+    ArrayRef<int> SubMask = Mask.slice(Begin, NumElemsPerConcat);
+
+    // Make sure we're dealing with a copy.
+    if (llvm::all_of(SubMask, IsUndefMaskElt)) {
+      Ops.push_back(DAG.getUNDEF(ConcatVT));
+      continue;
     }
 
-    if (NoUndef) {
-      if (SVN->getMaskElt(Begin) % NumElemsPerConcat != 0)
+    int OpIdx = -1;
+    for (int i = 0; i != (int)NumElemsPerConcat; ++i) {
+      if (IsUndefMaskElt(SubMask[i]))
+        continue;
+      if ((SubMask[i] % NumElemsPerConcat) != i)
         return SDValue();
-
-      for (unsigned J = 1; J != NumElemsPerConcat; ++J)
-        if (SVN->getMaskElt(Begin + J - 1) + 1 != SVN->getMaskElt(Begin + J))
-          return SDValue();
-
-      unsigned FirstElt = SVN->getMaskElt(Begin) / NumElemsPerConcat;
-      if (FirstElt < N0.getNumOperands())
-        Ops.push_back(N0.getOperand(FirstElt));
-      else
-        Ops.push_back(N1.getOperand(FirstElt - N0.getNumOperands()));
-
-    } else if (AllUndef) {
-      Ops.push_back(DAG.getUNDEF(N0.getOperand(0).getValueType()));
-    } else { // Mixed with general masks and undefs, can't do optimization.
-      return SDValue();
+      int EltOpIdx = SubMask[i] / NumElemsPerConcat;
+      if (0 <= OpIdx && EltOpIdx != OpIdx)
+        return SDValue();
+      OpIdx = EltOpIdx;
     }
+    assert(0 <= OpIdx && "Unknown concat_vectors op");
+
+    if (OpIdx < (int)N0.getNumOperands())
+      Ops.push_back(N0.getOperand(OpIdx));
+    else
+      Ops.push_back(N1.getOperand(OpIdx - N0.getNumOperands()));
   }
 
   return DAG.getNode(ISD::CONCAT_VECTORS, SDLoc(N), VT, Ops);
index 1685376..fc75d6c 100644 (file)
@@ -1676,20 +1676,12 @@ define <8 x float> @broadcast_v8f32_v2f32_u1uu0uEu(<2 x float>* %vp, <8 x float>
 }
 
 define <8 x double> @broadcast_v8f64_v2f64_u1u10101(<2 x double>* %vp) {
-; X32-AVX1-LABEL: broadcast_v8f64_v2f64_u1u10101:
-; X32-AVX1:       # %bb.0:
-; X32-AVX1-NEXT:    movl {{[0-9]+}}(%esp), %eax
-; X32-AVX1-NEXT:    vbroadcastf128 {{.*#+}} ymm0 = mem[0,1,0,1]
-; X32-AVX1-NEXT:    vmovaps %ymm0, %ymm1
-; X32-AVX1-NEXT:    retl
-;
-; X32-AVX2-LABEL: broadcast_v8f64_v2f64_u1u10101:
-; X32-AVX2:       # %bb.0:
-; X32-AVX2-NEXT:    movl {{[0-9]+}}(%esp), %eax
-; X32-AVX2-NEXT:    vmovaps (%eax), %xmm0
-; X32-AVX2-NEXT:    vinsertf128 $1, %xmm0, %ymm0, %ymm1
-; X32-AVX2-NEXT:    vpermpd {{.*#+}} ymm0 = ymm0[0,1,2,1]
-; X32-AVX2-NEXT:    retl
+; X32-AVX-LABEL: broadcast_v8f64_v2f64_u1u10101:
+; X32-AVX:       # %bb.0:
+; X32-AVX-NEXT:    movl {{[0-9]+}}(%esp), %eax
+; X32-AVX-NEXT:    vbroadcastf128 {{.*#+}} ymm0 = mem[0,1,0,1]
+; X32-AVX-NEXT:    vmovaps %ymm0, %ymm1
+; X32-AVX-NEXT:    retl
 ;
 ; X32-AVX512-LABEL: broadcast_v8f64_v2f64_u1u10101:
 ; X32-AVX512:       # %bb.0:
@@ -1697,18 +1689,11 @@ define <8 x double> @broadcast_v8f64_v2f64_u1u10101(<2 x double>* %vp) {
 ; X32-AVX512-NEXT:    vbroadcastf32x4 {{.*#+}} zmm0 = mem[0,1,2,3,0,1,2,3,0,1,2,3,0,1,2,3]
 ; X32-AVX512-NEXT:    retl
 ;
-; X64-AVX1-LABEL: broadcast_v8f64_v2f64_u1u10101:
-; X64-AVX1:       # %bb.0:
-; X64-AVX1-NEXT:    vbroadcastf128 {{.*#+}} ymm0 = mem[0,1,0,1]
-; X64-AVX1-NEXT:    vmovaps %ymm0, %ymm1
-; X64-AVX1-NEXT:    retq
-;
-; X64-AVX2-LABEL: broadcast_v8f64_v2f64_u1u10101:
-; X64-AVX2:       # %bb.0:
-; X64-AVX2-NEXT:    vmovaps (%rdi), %xmm0
-; X64-AVX2-NEXT:    vinsertf128 $1, %xmm0, %ymm0, %ymm1
-; X64-AVX2-NEXT:    vpermpd {{.*#+}} ymm0 = ymm0[0,1,2,1]
-; X64-AVX2-NEXT:    retq
+; X64-AVX-LABEL: broadcast_v8f64_v2f64_u1u10101:
+; X64-AVX:       # %bb.0:
+; X64-AVX-NEXT:    vbroadcastf128 {{.*#+}} ymm0 = mem[0,1,0,1]
+; X64-AVX-NEXT:    vmovaps %ymm0, %ymm1
+; X64-AVX-NEXT:    retq
 ;
 ; X64-AVX512-LABEL: broadcast_v8f64_v2f64_u1u10101:
 ; X64-AVX512:       # %bb.0: