[X86][SSE] combineX86ShuffleChain add 'CanonicalizeShuffleInput' helper. NFCI.
authorSimon Pilgrim <llvm-dev@redking.me.uk>
Tue, 6 Oct 2020 16:32:35 +0000 (17:32 +0100)
committerSimon Pilgrim <llvm-dev@redking.me.uk>
Tue, 6 Oct 2020 16:47:24 +0000 (17:47 +0100)
As part of PR45974, we're getting closer to not creating 'padded' vectors on-the-fly in combineX86ShufflesRecursively, and only pad the source inputs if we have a definite match inside combineX86ShuffleChain.

At the moment combineX86ShuffleChain just has to bitcast an input to the correct shuffle type, but eventually we'll need to pad them as well. So, move the bitcast into a 'CanonicalizeShuffleInput helper for now, making the diff for future padding support a lot smaller.

llvm/lib/Target/X86/X86ISelLowering.cpp

index bd80812..66986a1 100644 (file)
@@ -35013,6 +35013,12 @@ static SDValue combineX86ShuffleChain(ArrayRef<SDValue> Inputs, SDValue Root,
   unsigned RootSizeInBits = RootVT.getSizeInBits();
   unsigned NumRootElts = RootVT.getVectorNumElements();
 
+  // Canonicalize shuffle input op to the requested type.
+  // TODO: Support cases where Op is smaller than VT.
+  auto CanonicalizeShuffleInput = [&](MVT VT, SDValue Op) {
+    return DAG.getBitcast(VT, Op);
+  };
+
   // Find the inputs that enter the chain. Note that multiple uses are OK
   // here, we're not going to remove the operands we find.
   bool UnaryShuffle = (Inputs.size() == 1);
@@ -35031,7 +35037,7 @@ static SDValue combineX86ShuffleChain(ArrayRef<SDValue> Inputs, SDValue Root,
   unsigned NumBaseMaskElts = BaseMask.size();
   if (NumBaseMaskElts == 1) {
     assert(BaseMask[0] == 0 && "Invalid shuffle index found!");
-    return DAG.getBitcast(RootVT, V1);
+    return CanonicalizeShuffleInput(RootVT, V1);
   }
 
   bool OptForSize = DAG.shouldOptForSize();
@@ -35055,8 +35061,9 @@ static SDValue combineX86ShuffleChain(ArrayRef<SDValue> Inputs, SDValue Root,
   // we can just use the broadcast directly. This works for smaller broadcast
   // elements as well as they already repeat across each mask element
   if (UnaryShuffle && isTargetShuffleSplat(V1) && !isAnyZero(BaseMask) &&
-      (BaseMaskEltSizeInBits % V1.getScalarValueSizeInBits()) == 0) {
-    return DAG.getBitcast(RootVT, V1);
+      (BaseMaskEltSizeInBits % V1.getScalarValueSizeInBits()) == 0 &&
+      V1.getValueSizeInBits() >= RootSizeInBits) {
+    return CanonicalizeShuffleInput(RootVT, V1);
   }
 
   // Attempt to match a subvector broadcast.
@@ -35089,7 +35096,7 @@ static SDValue combineX86ShuffleChain(ArrayRef<SDValue> Inputs, SDValue Root,
         return SDValue(); // Nothing to do!
       assert(isInRange(BaseMask[0], 0, NumBaseMaskElts) &&
              "Unexpected lane shuffle");
-      Res = DAG.getBitcast(ShuffleVT, V1);
+      Res = CanonicalizeShuffleInput(ShuffleVT, V1);
       unsigned SubIdx = BaseMask[0] * (8 / NumBaseMaskElts);
       bool UseZero = isAnyZero(BaseMask);
       Res = extractSubVector(Res, SubIdx, DAG, DL, BaseMaskEltSizeInBits);
@@ -35103,8 +35110,8 @@ static SDValue combineX86ShuffleChain(ArrayRef<SDValue> Inputs, SDValue Root,
     narrowShuffleMaskElts(BaseMaskEltSizeInBits / 128, BaseMask, Mask);
 
     // Try to lower to vshuf64x2/vshuf32x4.
-    auto MatchSHUF128 = [](MVT ShuffleVT, const SDLoc &DL, ArrayRef<int> Mask,
-                           SDValue V1, SDValue V2, SelectionDAG &DAG) {
+    auto MatchSHUF128 = [&](MVT ShuffleVT, const SDLoc &DL, ArrayRef<int> Mask,
+                            SDValue V1, SDValue V2, SelectionDAG &DAG) {
       unsigned PermMask = 0;
       // Insure elements came from the same Op.
       SDValue Ops[2] = {DAG.getUNDEF(ShuffleVT), DAG.getUNDEF(ShuffleVT)};
@@ -35127,8 +35134,8 @@ static SDValue combineX86ShuffleChain(ArrayRef<SDValue> Inputs, SDValue Root,
       }
 
       return DAG.getNode(X86ISD::SHUF128, DL, ShuffleVT,
-                         DAG.getBitcast(ShuffleVT, Ops[0]),
-                         DAG.getBitcast(ShuffleVT, Ops[1]),
+                         CanonicalizeShuffleInput(ShuffleVT, Ops[0]),
+                         CanonicalizeShuffleInput(ShuffleVT, Ops[1]),
                          DAG.getTargetConstant(PermMask, DL, MVT::i8));
     };
 
@@ -35161,7 +35168,7 @@ static SDValue combineX86ShuffleChain(ArrayRef<SDValue> Inputs, SDValue Root,
       if (Depth == 0 && Root.getOpcode() == ISD::INSERT_SUBVECTOR)
         return SDValue(); // Nothing to do!
       assert(isInRange(BaseMask[0], 0, 2) && "Unexpected lane shuffle");
-      Res = DAG.getBitcast(ShuffleVT, V1);
+      Res = CanonicalizeShuffleInput(ShuffleVT, V1);
       Res = extract128BitVector(Res, BaseMask[0] * 2, DAG, DL);
       Res = widenSubVector(Res, BaseMask[1] == SM_SentinelZero, Subtarget, DAG,
                            DL, 256);
@@ -35181,7 +35188,7 @@ static SDValue combineX86ShuffleChain(ArrayRef<SDValue> Inputs, SDValue Root,
       PermMask |= ((BaseMask[0] < 0 ? 0x8 : (BaseMask[0] & 1)) << 0);
       PermMask |= ((BaseMask[1] < 0 ? 0x8 : (BaseMask[1] & 1)) << 4);
 
-      Res = DAG.getBitcast(ShuffleVT, V1);
+      Res = CanonicalizeShuffleInput(ShuffleVT, V1);
       Res = DAG.getNode(X86ISD::VPERM2X128, DL, ShuffleVT, Res,
                         DAG.getUNDEF(ShuffleVT),
                         DAG.getTargetConstant(PermMask, DL, MVT::i8));
@@ -35202,11 +35209,12 @@ static SDValue combineX86ShuffleChain(ArrayRef<SDValue> Inputs, SDValue Root,
         PermMask |= ((BaseMask[0] & 3) << 0);
         PermMask |= ((BaseMask[1] & 3) << 4);
 
-        Res = DAG.getNode(
-            X86ISD::VPERM2X128, DL, ShuffleVT,
-            DAG.getBitcast(ShuffleVT, isInRange(BaseMask[0], 0, 2) ? V1 : V2),
-            DAG.getBitcast(ShuffleVT, isInRange(BaseMask[1], 0, 2) ? V1 : V2),
-            DAG.getTargetConstant(PermMask, DL, MVT::i8));
+        SDValue LHS = isInRange(BaseMask[0], 0, 2) ? V1 : V2;
+        SDValue RHS = isInRange(BaseMask[1], 0, 2) ? V1 : V2;
+        Res = DAG.getNode(X86ISD::VPERM2X128, DL, ShuffleVT,
+                          CanonicalizeShuffleInput(ShuffleVT, LHS),
+                          CanonicalizeShuffleInput(ShuffleVT, RHS),
+                          DAG.getTargetConstant(PermMask, DL, MVT::i8));
         return DAG.getBitcast(RootVT, Res);
       }
     }
@@ -35282,7 +35290,7 @@ static SDValue combineX86ShuffleChain(ArrayRef<SDValue> Inputs, SDValue Root,
         if (Subtarget.hasAVX2()) {
           if (Depth == 0 && Root.getOpcode() == X86ISD::VBROADCAST)
             return SDValue(); // Nothing to do!
-          Res = DAG.getBitcast(MaskVT, V1);
+          Res = CanonicalizeShuffleInput(MaskVT, V1);
           Res = DAG.getNode(X86ISD::VBROADCAST, DL, MaskVT, Res);
           return DAG.getBitcast(RootVT, Res);
         }
@@ -35297,7 +35305,7 @@ static SDValue combineX86ShuffleChain(ArrayRef<SDValue> Inputs, SDValue Root,
          (NumRootElts == ShuffleVT.getVectorNumElements()))) {
       if (Depth == 0 && Root.getOpcode() == Shuffle)
         return SDValue(); // Nothing to do!
-      Res = DAG.getBitcast(ShuffleSrcVT, NewV1);
+      Res = CanonicalizeShuffleInput(ShuffleSrcVT, NewV1);
       Res = DAG.getNode(Shuffle, DL, ShuffleVT, Res);
       return DAG.getBitcast(RootVT, Res);
     }
@@ -35309,7 +35317,7 @@ static SDValue combineX86ShuffleChain(ArrayRef<SDValue> Inputs, SDValue Root,
          (NumRootElts == ShuffleVT.getVectorNumElements()))) {
       if (Depth == 0 && Root.getOpcode() == Shuffle)
         return SDValue(); // Nothing to do!
-      Res = DAG.getBitcast(ShuffleVT, V1);
+      Res = CanonicalizeShuffleInput(ShuffleVT, V1);
       Res = DAG.getNode(Shuffle, DL, ShuffleVT, Res,
                         DAG.getTargetConstant(PermuteImm, DL, MVT::i8));
       return DAG.getBitcast(RootVT, Res);
@@ -35330,8 +35338,8 @@ static SDValue combineX86ShuffleChain(ArrayRef<SDValue> Inputs, SDValue Root,
         if (Depth == 0 && Root.getOpcode() == X86ISD::INSERTPS)
           return SDValue(); // Nothing to do!
         Res = DAG.getNode(X86ISD::INSERTPS, DL, MVT::v4f32,
-                          DAG.getBitcast(MVT::v4f32, SrcV1),
-                          DAG.getBitcast(MVT::v4f32, SrcV2),
+                          CanonicalizeShuffleInput(MVT::v4f32, SrcV1),
+                          CanonicalizeShuffleInput(MVT::v4f32, SrcV2),
                           DAG.getTargetConstant(PermuteImm, DL, MVT::i8));
         return DAG.getBitcast(RootVT, Res);
       }
@@ -35344,8 +35352,8 @@ static SDValue combineX86ShuffleChain(ArrayRef<SDValue> Inputs, SDValue Root,
         return SDValue(); // Nothing to do!
       PermuteImm = (/*DstIdx*/2 << 4) | (/*SrcIdx*/0 << 0);
       Res = DAG.getNode(X86ISD::INSERTPS, DL, MVT::v4f32,
-                        DAG.getBitcast(MVT::v4f32, V1),
-                        DAG.getBitcast(MVT::v4f32, V2),
+                        CanonicalizeShuffleInput(MVT::v4f32, V1),
+                        CanonicalizeShuffleInput(MVT::v4f32, V2),
                         DAG.getTargetConstant(PermuteImm, DL, MVT::i8));
       return DAG.getBitcast(RootVT, Res);
     }
@@ -35359,8 +35367,8 @@ static SDValue combineX86ShuffleChain(ArrayRef<SDValue> Inputs, SDValue Root,
       (!IsMaskedShuffle || (NumRootElts == ShuffleVT.getVectorNumElements()))) {
     if (Depth == 0 && Root.getOpcode() == Shuffle)
       return SDValue(); // Nothing to do!
-    NewV1 = DAG.getBitcast(ShuffleSrcVT, NewV1);
-    NewV2 = DAG.getBitcast(ShuffleSrcVT, NewV2);
+    NewV1 = CanonicalizeShuffleInput(ShuffleSrcVT, NewV1);
+    NewV2 = CanonicalizeShuffleInput(ShuffleSrcVT, NewV2);
     Res = DAG.getNode(Shuffle, DL, ShuffleVT, NewV1, NewV2);
     return DAG.getBitcast(RootVT, Res);
   }
@@ -35373,8 +35381,8 @@ static SDValue combineX86ShuffleChain(ArrayRef<SDValue> Inputs, SDValue Root,
       (!IsMaskedShuffle || (NumRootElts == ShuffleVT.getVectorNumElements()))) {
     if (Depth == 0 && Root.getOpcode() == Shuffle)
       return SDValue(); // Nothing to do!
-    NewV1 = DAG.getBitcast(ShuffleVT, NewV1);
-    NewV2 = DAG.getBitcast(ShuffleVT, NewV2);
+    NewV1 = CanonicalizeShuffleInput(ShuffleVT, NewV1);
+    NewV2 = CanonicalizeShuffleInput(ShuffleVT, NewV2);
     Res = DAG.getNode(Shuffle, DL, ShuffleVT, NewV1, NewV2,
                       DAG.getTargetConstant(PermuteImm, DL, MVT::i8));
     return DAG.getBitcast(RootVT, Res);
@@ -35391,7 +35399,7 @@ static SDValue combineX86ShuffleChain(ArrayRef<SDValue> Inputs, SDValue Root,
                             Zeroable)) {
       if (Depth == 0 && Root.getOpcode() == X86ISD::EXTRQI)
         return SDValue(); // Nothing to do!
-      V1 = DAG.getBitcast(IntMaskVT, V1);
+      V1 = CanonicalizeShuffleInput(IntMaskVT, V1);
       Res = DAG.getNode(X86ISD::EXTRQI, DL, IntMaskVT, V1,
                         DAG.getTargetConstant(BitLen, DL, MVT::i8),
                         DAG.getTargetConstant(BitIdx, DL, MVT::i8));
@@ -35401,8 +35409,8 @@ static SDValue combineX86ShuffleChain(ArrayRef<SDValue> Inputs, SDValue Root,
     if (matchShuffleAsINSERTQ(IntMaskVT, V1, V2, Mask, BitLen, BitIdx)) {
       if (Depth == 0 && Root.getOpcode() == X86ISD::INSERTQI)
         return SDValue(); // Nothing to do!
-      V1 = DAG.getBitcast(IntMaskVT, V1);
-      V2 = DAG.getBitcast(IntMaskVT, V2);
+      V1 = CanonicalizeShuffleInput(IntMaskVT, V1);
+      V2 = CanonicalizeShuffleInput(IntMaskVT, V2);
       Res = DAG.getNode(X86ISD::INSERTQI, DL, IntMaskVT, V1, V2,
                         DAG.getTargetConstant(BitLen, DL, MVT::i8),
                         DAG.getTargetConstant(BitIdx, DL, MVT::i8));
@@ -35421,7 +35429,7 @@ static SDValue combineX86ShuffleChain(ArrayRef<SDValue> Inputs, SDValue Root,
           IsTRUNCATE ? (unsigned)ISD::TRUNCATE : (unsigned)X86ISD::VTRUNC;
       if (Depth == 0 && Root.getOpcode() == Opc)
         return SDValue(); // Nothing to do!
-      V1 = DAG.getBitcast(ShuffleSrcVT, V1);
+      V1 = CanonicalizeShuffleInput(ShuffleSrcVT, V1);
       Res = DAG.getNode(Opc, DL, ShuffleVT, V1);
       if (ShuffleVT.getSizeInBits() < RootSizeInBits)
         Res = widenSubVector(Res, true, Subtarget, DAG, DL, RootSizeInBits);
@@ -35438,8 +35446,8 @@ static SDValue combineX86ShuffleChain(ArrayRef<SDValue> Inputs, SDValue Root,
         return SDValue(); // Nothing to do!
       ShuffleSrcVT = MVT::getIntegerVT(MaskEltSizeInBits * 2);
       ShuffleSrcVT = MVT::getVectorVT(ShuffleSrcVT, NumMaskElts / 2);
-      V1 = DAG.getBitcast(ShuffleSrcVT, V1);
-      V2 = DAG.getBitcast(ShuffleSrcVT, V2);
+      V1 = CanonicalizeShuffleInput(ShuffleSrcVT, V1);
+      V2 = CanonicalizeShuffleInput(ShuffleSrcVT, V2);
       ShuffleSrcVT = MVT::getIntegerVT(MaskEltSizeInBits * 2);
       ShuffleSrcVT = MVT::getVectorVT(ShuffleSrcVT, NumMaskElts);
       Res = DAG.getNode(ISD::CONCAT_VECTORS, DL, ShuffleSrcVT, V1, V2);
@@ -35468,7 +35476,7 @@ static SDValue combineX86ShuffleChain(ArrayRef<SDValue> Inputs, SDValue Root,
       if (Subtarget.hasAVX2() &&
           (MaskVT == MVT::v8f32 || MaskVT == MVT::v8i32)) {
         SDValue VPermMask = getConstVector(Mask, IntMaskVT, DAG, DL, true);
-        Res = DAG.getBitcast(MaskVT, V1);
+        Res = CanonicalizeShuffleInput(MaskVT, V1);
         Res = DAG.getNode(X86ISD::VPERMV, DL, MaskVT, VPermMask, Res);
         return DAG.getBitcast(RootVT, Res);
       }
@@ -35480,7 +35488,7 @@ static SDValue combineX86ShuffleChain(ArrayRef<SDValue> Inputs, SDValue Root,
            (MaskVT == MVT::v16i16 || MaskVT == MVT::v32i16)) ||
           (Subtarget.hasVBMI() &&
            (MaskVT == MVT::v32i8 || MaskVT == MVT::v64i8))) {
-        V1 = DAG.getBitcast(MaskVT, V1);
+        V1 = CanonicalizeShuffleInput(MaskVT, V1);
         V2 = DAG.getUNDEF(MaskVT);
         Res = lowerShuffleWithPERMV(DL, MaskVT, Mask, V1, V2, Subtarget, DAG);
         return DAG.getBitcast(RootVT, Res);
@@ -35503,7 +35511,7 @@ static SDValue combineX86ShuffleChain(ArrayRef<SDValue> Inputs, SDValue Root,
       for (unsigned i = 0; i != NumMaskElts; ++i)
         if (Mask[i] == SM_SentinelZero)
           Mask[i] = NumMaskElts + i;
-      V1 = DAG.getBitcast(MaskVT, V1);
+      V1 = CanonicalizeShuffleInput(MaskVT, V1);
       V2 = getZeroVector(MaskVT, Subtarget, DAG, DL);
       Res = lowerShuffleWithPERMV(DL, MaskVT, Mask, V1, V2, Subtarget, DAG);
       return DAG.getBitcast(RootVT, Res);
@@ -35528,8 +35536,8 @@ static SDValue combineX86ShuffleChain(ArrayRef<SDValue> Inputs, SDValue Root,
           (MaskVT == MVT::v16i16 || MaskVT == MVT::v32i16)) ||
          (Subtarget.hasVBMI() && AllowBWIVPERMV3 &&
           (MaskVT == MVT::v32i8 || MaskVT == MVT::v64i8)))) {
-      V1 = DAG.getBitcast(MaskVT, V1);
-      V2 = DAG.getBitcast(MaskVT, V2);
+      V1 = CanonicalizeShuffleInput(MaskVT, V1);
+      V2 = CanonicalizeShuffleInput(MaskVT, V2);
       Res = lowerShuffleWithPERMV(DL, MaskVT, Mask, V1, V2, Subtarget, DAG);
       return DAG.getBitcast(RootVT, Res);
     }
@@ -35556,7 +35564,7 @@ static SDValue combineX86ShuffleChain(ArrayRef<SDValue> Inputs, SDValue Root,
       EltBits[i] = AllOnes;
     }
     SDValue BitMask = getConstVector(EltBits, UndefElts, MaskVT, DAG, DL);
-    Res = DAG.getBitcast(MaskVT, V1);
+    Res = CanonicalizeShuffleInput(MaskVT, V1);
     unsigned AndOpcode =
         MaskVT.isFloatingPoint() ? unsigned(X86ISD::FAND) : unsigned(ISD::AND);
     Res = DAG.getNode(AndOpcode, DL, MaskVT, Res, BitMask);
@@ -35576,7 +35584,7 @@ static SDValue combineX86ShuffleChain(ArrayRef<SDValue> Inputs, SDValue Root,
       VPermIdx.push_back(Idx);
     }
     SDValue VPermMask = DAG.getBuildVector(IntMaskVT, DL, VPermIdx);
-    Res = DAG.getBitcast(MaskVT, V1);
+    Res = CanonicalizeShuffleInput(MaskVT, V1);
     Res = DAG.getNode(X86ISD::VPERMILPV, DL, MaskVT, Res, VPermMask);
     return DAG.getBitcast(RootVT, Res);
   }
@@ -35608,8 +35616,8 @@ static SDValue combineX86ShuffleChain(ArrayRef<SDValue> Inputs, SDValue Root,
       Index = (MaskVT.getScalarSizeInBits() == 64 ? Index << 1 : Index);
       VPerm2Idx.push_back(Index);
     }
-    V1 = DAG.getBitcast(MaskVT, V1);
-    V2 = DAG.getBitcast(MaskVT, V2);
+    V1 = CanonicalizeShuffleInput(MaskVT, V1);
+    V2 = CanonicalizeShuffleInput(MaskVT, V2);
     SDValue VPerm2MaskOp = getConstVector(VPerm2Idx, IntMaskVT, DAG, DL, true);
     Res = DAG.getNode(X86ISD::VPERMIL2, DL, MaskVT, V1, V2, VPerm2MaskOp,
                       DAG.getTargetConstant(M2ZImm, DL, MVT::i8));
@@ -35643,7 +35651,7 @@ static SDValue combineX86ShuffleChain(ArrayRef<SDValue> Inputs, SDValue Root,
       PSHUFBMask.push_back(DAG.getConstant(M, DL, MVT::i8));
     }
     MVT ByteVT = MVT::getVectorVT(MVT::i8, NumBytes);
-    Res = DAG.getBitcast(ByteVT, V1);
+    Res = CanonicalizeShuffleInput(ByteVT, V1);
     SDValue PSHUFBMaskOp = DAG.getBuildVector(ByteVT, DL, PSHUFBMask);
     Res = DAG.getNode(X86ISD::PSHUFB, DL, ByteVT, Res, PSHUFBMaskOp);
     return DAG.getBitcast(RootVT, Res);
@@ -35673,8 +35681,8 @@ static SDValue combineX86ShuffleChain(ArrayRef<SDValue> Inputs, SDValue Root,
       VPPERMMask.push_back(DAG.getConstant(M, DL, MVT::i8));
     }
     MVT ByteVT = MVT::v16i8;
-    V1 = DAG.getBitcast(ByteVT, V1);
-    V2 = DAG.getBitcast(ByteVT, V2);
+    V1 = CanonicalizeShuffleInput(ByteVT, V1);
+    V2 = CanonicalizeShuffleInput(ByteVT, V2);
     SDValue VPPERMMaskOp = DAG.getBuildVector(ByteVT, DL, VPPERMMask);
     Res = DAG.getNode(X86ISD::VPPERM, DL, ByteVT, V1, V2, VPPERMMaskOp);
     return DAG.getBitcast(RootVT, Res);
@@ -35700,8 +35708,8 @@ static SDValue combineX86ShuffleChain(ArrayRef<SDValue> Inputs, SDValue Root,
         (MaskVT == MVT::v8i16 || MaskVT == MVT::v16i16 || MaskVT == MVT::v32i16)) ||
        (Subtarget.hasVBMI() && AllowBWIVPERMV3 &&
         (MaskVT == MVT::v16i8 || MaskVT == MVT::v32i8 || MaskVT == MVT::v64i8)))) {
-    V1 = DAG.getBitcast(MaskVT, V1);
-    V2 = DAG.getBitcast(MaskVT, V2);
+    V1 = CanonicalizeShuffleInput(MaskVT, V1);
+    V2 = CanonicalizeShuffleInput(MaskVT, V2);
     Res = lowerShuffleWithPERMV(DL, MaskVT, Mask, V1, V2, Subtarget, DAG);
     return DAG.getBitcast(RootVT, Res);
   }