[X86][AVX] Add support for narrowing 128-bit+ shuffle mask elements to 64-bits to...
authorSimon Pilgrim <llvm-dev@redking.me.uk>
Thu, 14 Jul 2016 12:58:04 +0000 (12:58 +0000)
committerSimon Pilgrim <llvm-dev@redking.me.uk>
Thu, 14 Jul 2016 12:58:04 +0000 (12:58 +0000)
Primarily this is to allow blend with zero instead of having to use vperm2f128, but we can use this in the future to deal with AVX512 cases where we need to keep the original element size to correctly fold masked operations.

llvm-svn: 275406

llvm/lib/Target/X86/X86ISelLowering.cpp
llvm/test/CodeGen/X86/vector-shuffle-combining-avx.ll
llvm/test/CodeGen/X86/vector-shuffle-combining-avx2.ll

index 5f88671..29dd133 100644 (file)
@@ -25017,11 +25017,11 @@ static bool matchBinaryVectorShuffle(MVT SrcVT, ArrayRef<int> Mask,
 /// for this operation, or into a PSHUFB instruction which is a fully general
 /// instruction but should only be used to replace chains over a certain depth.
 static bool combineX86ShuffleChain(SDValue Input, SDValue Root,
-                                   ArrayRef<int> Mask, int Depth,
+                                   ArrayRef<int> BaseMask, int Depth,
                                    bool HasVariableMask, SelectionDAG &DAG,
                                    TargetLowering::DAGCombinerInfo &DCI,
                                    const X86Subtarget &Subtarget) {
-  assert(!Mask.empty() && "Cannot combine an empty shuffle mask!");
+  assert(!BaseMask.empty() && "Cannot combine an empty shuffle mask!");
 
   // Find the operand that enters the chain. Note that multiple uses are OK
   // here, we're not going to remove the operand we find.
@@ -25033,23 +25033,24 @@ static bool combineX86ShuffleChain(SDValue Input, SDValue Root,
 
   SDValue Res;
 
-  unsigned NumMaskElts = Mask.size();
-  if (NumMaskElts == 1) {
-    assert(Mask[0] == 0 && "Invalid shuffle index found!");
+  unsigned NumBaseMaskElts = BaseMask.size();
+  if (NumBaseMaskElts == 1) {
+    assert(BaseMask[0] == 0 && "Invalid shuffle index found!");
     DCI.CombineTo(Root.getNode(), DAG.getBitcast(RootVT, Input),
                   /*AddTo*/ true);
     return true;
   }
 
   unsigned RootSizeInBits = RootVT.getSizeInBits();
-  unsigned MaskEltSizeInBits = RootSizeInBits / NumMaskElts;
+  unsigned BaseMaskEltSizeInBits = RootSizeInBits / NumBaseMaskElts;
 
   // Don't combine if we are a AVX512/EVEX target and the mask element size
   // is different from the root element size - this would prevent writemasks
   // from being reused.
+  // TODO - this currently prevents all lane shuffles from occurring.
   // TODO - check for writemasks usage instead of always preventing combining.
   // TODO - attempt to narrow Mask back to writemask size.
-  if (RootVT.getScalarSizeInBits() != MaskEltSizeInBits &&
+  if (RootVT.getScalarSizeInBits() != BaseMaskEltSizeInBits &&
       (RootSizeInBits == 512 ||
        (Subtarget.hasVLX() && RootSizeInBits >= 128))) {
     return false;
@@ -25058,16 +25059,15 @@ static bool combineX86ShuffleChain(SDValue Input, SDValue Root,
   // TODO - handle 128/256-bit lane shuffles of 512-bit vectors.
 
   // Handle 128-bit lane shuffles of 256-bit vectors.
-  // TODO - handle blend with zero cases.
-  if (VT.is256BitVector() && Mask.size() == 2 &&
-      !isSequentialOrUndefOrZeroInRange(Mask, 0, 2, 0)) {
+  if (VT.is256BitVector() && NumBaseMaskElts == 2 &&
+      !isSequentialOrUndefOrZeroInRange(BaseMask, 0, 2, 0)) {
     if (Depth == 1 && Root.getOpcode() == X86ISD::VPERM2X128)
       return false; // Nothing to do!
     MVT ShuffleVT = (VT.isFloatingPoint() || !Subtarget.hasAVX2() ? MVT::v4f64
                                                                   : MVT::v4i64);
     unsigned PermMask = 0;
-    PermMask |= ((Mask[0] < 0 ? 0x8 : (Mask[0] & 1)) << 0);
-    PermMask |= ((Mask[1] < 0 ? 0x8 : (Mask[1] & 1)) << 4);
+    PermMask |= ((BaseMask[0] < 0 ? 0x8 : (BaseMask[0] & 1)) << 0);
+    PermMask |= ((BaseMask[1] < 0 ? 0x8 : (BaseMask[1] & 1)) << 4);
 
     Res = DAG.getBitcast(ShuffleVT, Input);
     DCI.AddToWorklist(Res.getNode());
@@ -25080,8 +25080,19 @@ static bool combineX86ShuffleChain(SDValue Input, SDValue Root,
     return true;
   }
 
-  if (MaskEltSizeInBits > 64)
-    return false;
+  // For masks that have been widened to 128-bit elements or more,
+  // narrow back down to 64-bit elements.
+  SmallVector<int, 64> Mask;
+  if (BaseMaskEltSizeInBits > 64) {
+    assert((BaseMaskEltSizeInBits % 64) == 0 && "Illegal mask size");
+    int MaskScale = BaseMaskEltSizeInBits / 64;
+    scaleShuffleMask(MaskScale, BaseMask, Mask);
+  } else {
+    Mask = SmallVector<int, 64>(BaseMask.begin(), BaseMask.end());
+  }
+
+  unsigned NumMaskElts = Mask.size();
+  unsigned MaskEltSizeInBits = RootSizeInBits / NumMaskElts;
 
   // Determine the effective mask value type.
   bool FloatDomain =
index 2dc50cf..ac18bba 100644 (file)
@@ -124,25 +124,11 @@ define <8 x float> @combine_vpermilvar_vperm2f128_zero_8f32(<8 x float> %a0) {
 }
 
 define <4 x double> @combine_vperm2f128_vpermilvar_as_vpblendpd(<4 x double> %a0) {
-; AVX1-LABEL: combine_vperm2f128_vpermilvar_as_vpblendpd:
-; AVX1:       # BB#0:
-; AVX1-NEXT:    vpermilpd {{.*#+}} ymm0 = ymm0[1,0,3,2]
-; AVX1-NEXT:    vxorpd %ymm1, %ymm1, %ymm1
-; AVX1-NEXT:    vblendpd {{.*#+}} ymm0 = ymm0[0,1],ymm1[2,3]
-; AVX1-NEXT:    vpermilpd {{.*#+}} ymm0 = ymm0[1,0,3,2]
-; AVX1-NEXT:    retq
-;
-; AVX2-LABEL: combine_vperm2f128_vpermilvar_as_vpblendpd:
-; AVX2:       # BB#0:
-; AVX2-NEXT:    vpermilpd {{.*#+}} ymm0 = ymm0[1,0,3,2]
-; AVX2-NEXT:    vpshufb {{.*#+}} ymm0 = ymm0[8,9,10,11,12,13,14,15,0,1,2,3,4,5,6,7],zero,zero,zero,zero,zero,zero,zero,zero,zero,zero,zero,zero,zero,zero,zero,zero
-; AVX2-NEXT:    retq
-;
-; AVX512F-LABEL: combine_vperm2f128_vpermilvar_as_vpblendpd:
-; AVX512F:       # BB#0:
-; AVX512F-NEXT:    vpermilpd {{.*#+}} ymm0 = ymm0[1,0,3,2]
-; AVX512F-NEXT:    vpshufb {{.*#+}} ymm0 = ymm0[8,9,10,11,12,13,14,15,0,1,2,3,4,5,6,7],zero,zero,zero,zero,zero,zero,zero,zero,zero,zero,zero,zero,zero,zero,zero,zero
-; AVX512F-NEXT:    retq
+; ALL-LABEL: combine_vperm2f128_vpermilvar_as_vpblendpd:
+; ALL:       # BB#0:
+; ALL-NEXT:    vxorpd %ymm1, %ymm1, %ymm1
+; ALL-NEXT:    vblendpd {{.*#+}} ymm0 = ymm0[0,1],ymm1[2,3]
+; ALL-NEXT:    retq
   %1 = tail call <4 x double> @llvm.x86.avx.vpermilvar.pd.256(<4 x double> %a0, <4 x i64> <i64 2, i64 0, i64 2, i64 0>)
   %2 = shufflevector <4 x double> %1, <4 x double> zeroinitializer, <4 x i32> <i32 0, i32 1, i32 4, i32 5>
   %3 = tail call <4 x double> @llvm.x86.avx.vpermilvar.pd.256(<4 x double> %2, <4 x i64> <i64 2, i64 0, i64 2, i64 0>)
index 97492dd..5137a12 100644 (file)
@@ -65,8 +65,8 @@ define <4 x i64> @combine_permq_pshufb_as_vperm2i128(<4 x i64> %a0) {
 define <32 x i8> @combine_permq_pshufb_as_vpblendd(<4 x i64> %a0) {
 ; CHECK-LABEL: combine_permq_pshufb_as_vpblendd:
 ; CHECK:       # BB#0:
-; CHECK-NEXT:    vpshufd {{.*#+}} ymm0 = ymm0[2,3,0,1,6,7,4,5]
-; CHECK-NEXT:    vpshufb {{.*#+}} ymm0 = ymm0[8,9,10,11,12,13,14,15,0,1,2,3,4,5,6,7],zero,zero,zero,zero,zero,zero,zero,zero,zero,zero,zero,zero,zero,zero,zero,zero
+; CHECK-NEXT:    vpxor %ymm1, %ymm1, %ymm1
+; CHECK-NEXT:    vpblendd {{.*#+}} ymm0 = ymm0[0,1,2,3],ymm1[4,5,6,7]
 ; CHECK-NEXT:    retq
   %1 = shufflevector <4 x i64> %a0, <4 x i64> undef, <4 x i32> <i32 1, i32 0, i32 3, i32 2>
   %2 = bitcast <4 x i64> %1 to <32 x i8>