[X86][SSE] Replace combineShuffleWithHorizOp with canonicalizeShuffleMaskWithHorizOp
authorSimon Pilgrim <llvm-dev@redking.me.uk>
Sun, 16 Aug 2020 11:26:09 +0000 (12:26 +0100)
committerSimon Pilgrim <llvm-dev@redking.me.uk>
Sun, 16 Aug 2020 11:26:27 +0000 (12:26 +0100)
Instead of just attempting to fold shuffle(HOP,HOP) for a specific target shuffle, make this part of combineX86ShufflesRecursively so we can perform this on the combined shuffle chain, which is particularly useful for recognising more cases of where we're performing multiple HOPs that can be merged and pre-AVX where we don't have good blend/unary target shuffle support.

llvm/lib/Target/X86/X86ISelLowering.cpp
llvm/test/CodeGen/X86/haddsub-shuf.ll
llvm/test/CodeGen/X86/haddsub-undef.ll
llvm/test/CodeGen/X86/phaddsub.ll

index e72d089..ce90883 100644 (file)
@@ -35328,6 +35328,110 @@ static SDValue combineX86ShuffleChainWithExtract(
   return SDValue();
 }
 
+// Canonicalize the combined shuffle mask chain with horizontal ops.
+// NOTE: This may update the Ops and Mask.
+static SDValue canonicalizeShuffleMaskWithHorizOp(
+    MutableArrayRef<SDValue> Ops, MutableArrayRef<int> Mask,
+    unsigned RootSizeInBits, const SDLoc &DL, SelectionDAG &DAG,
+    const X86Subtarget &Subtarget) {
+
+  // Combine binary shuffle of 2 similar 'Horizontal' instructions into a
+  // single instruction. Attempt to match a v2X64 repeating shuffle pattern that
+  // represents the LHS/RHS inputs for the lower/upper halves.
+  if (Mask.empty() || Ops.empty() || 2 < Ops.size())
+    return SDValue();
+
+  SDValue BC0 = peekThroughBitcasts(Ops.front());
+  SDValue BC1 = peekThroughBitcasts(Ops.back());
+  EVT VT0 = BC0.getValueType();
+  EVT VT1 = BC1.getValueType();
+  unsigned Opcode0 = BC0.getOpcode();
+  unsigned Opcode1 = BC1.getOpcode();
+  if (Opcode0 != Opcode1 || VT0 != VT1 || VT0.getSizeInBits() != RootSizeInBits)
+    return SDValue();
+
+  bool isHoriz = (Opcode0 == X86ISD::FHADD || Opcode0 == X86ISD::HADD ||
+                  Opcode0 == X86ISD::FHSUB || Opcode0 == X86ISD::HSUB);
+  bool isPack = (Opcode0 == X86ISD::PACKSS || Opcode0 == X86ISD::PACKUS);
+  if (!isHoriz && !isPack)
+    return SDValue();
+
+  if (Mask.size() == VT0.getVectorNumElements()) {
+    int NumElts = VT0.getVectorNumElements();
+    int NumLanes = VT0.getSizeInBits() / 128;
+    int NumEltsPerLane = NumElts / NumLanes;
+    int NumHalfEltsPerLane = NumEltsPerLane / 2;
+
+    // Canonicalize binary shuffles of horizontal ops that use the
+    // same sources to an unary shuffle.
+    // TODO: Try to perform this fold even if the shuffle remains.
+    if (Ops.size() == 2) {
+      auto ContainsOps = [](SDValue HOp, SDValue Op) {
+        return Op == HOp.getOperand(0) || Op == HOp.getOperand(1);
+      };
+      // Commute if all BC0's ops are contained in BC1.
+      if (ContainsOps(BC1, BC0.getOperand(0)) &&
+          ContainsOps(BC1, BC0.getOperand(1))) {
+        ShuffleVectorSDNode::commuteMask(Mask);
+        std::swap(Ops[0], Ops[1]);
+        std::swap(BC0, BC1);
+      }
+
+      // If BC1 can be represented by BC0, then convert to unary shuffle.
+      if (ContainsOps(BC0, BC1.getOperand(0)) &&
+          ContainsOps(BC0, BC1.getOperand(1))) {
+        for (int &M : Mask) {
+          if (M < NumElts) // BC0 element or UNDEF/Zero sentinel.
+            continue;
+          int SubLane = ((M % NumEltsPerLane) >= NumHalfEltsPerLane) ? 1 : 0;
+          M -= NumElts + (SubLane * NumHalfEltsPerLane);
+          if (BC1.getOperand(SubLane) != BC0.getOperand(0))
+            M += NumHalfEltsPerLane;
+        }
+      }
+    }
+
+    // Canonicalize unary horizontal ops to only refer to lower halves.
+    for (int i = 0; i != NumElts; ++i) {
+      int &M = Mask[i];
+      if (isUndefOrZero(M))
+        continue;
+      if (M < NumElts && BC0.getOperand(0) == BC0.getOperand(1) &&
+          (M % NumEltsPerLane) >= NumHalfEltsPerLane)
+        M -= NumHalfEltsPerLane;
+      if (NumElts <= M && BC1.getOperand(0) == BC1.getOperand(1) &&
+          (M % NumEltsPerLane) >= NumHalfEltsPerLane)
+        M -= NumHalfEltsPerLane;
+    }
+  }
+
+  unsigned EltSizeInBits = RootSizeInBits / Mask.size();
+  SmallVector<int, 16> TargetMask128, WideMask128;
+  if (isRepeatedTargetShuffleMask(128, EltSizeInBits, Mask, TargetMask128) &&
+      scaleShuffleElements(TargetMask128, 2, WideMask128)) {
+    assert(isUndefOrZeroOrInRange(WideMask128, 0, 4) && "Illegal shuffle");
+    bool SingleOp = (Ops.size() == 1);
+    if (!isHoriz || shouldUseHorizontalOp(SingleOp, DAG, Subtarget)) {
+      SDValue Lo = isInRange(WideMask128[0], 0, 2) ? BC0 : BC1;
+      SDValue Hi = isInRange(WideMask128[1], 0, 2) ? BC0 : BC1;
+      Lo = Lo.getOperand(WideMask128[0] & 1);
+      Hi = Hi.getOperand(WideMask128[1] & 1);
+      if (SingleOp) {
+        MVT SrcVT = BC0.getOperand(0).getSimpleValueType();
+        SDValue Undef = DAG.getUNDEF(SrcVT);
+        SDValue Zero = getZeroVector(SrcVT, Subtarget, DAG, DL);
+        Lo = (WideMask128[0] == SM_SentinelZero ? Zero : Lo);
+        Hi = (WideMask128[1] == SM_SentinelZero ? Zero : Hi);
+        Lo = (WideMask128[0] == SM_SentinelUndef ? Undef : Lo);
+        Hi = (WideMask128[1] == SM_SentinelUndef ? Undef : Hi);
+      }
+      return DAG.getNode(Opcode0, DL, VT0, Lo, Hi);
+    }
+  }
+
+  return SDValue();
+}
+
 // Attempt to constant fold all of the constant source ops.
 // Returns true if the entire shuffle is folded to a constant.
 // TODO: Extend this to merge multiple constant Ops and update the mask.
@@ -35685,6 +35789,12 @@ static SDValue combineX86ShufflesRecursively(
           Ops, Mask, Root, HasVariableMask, DAG, Subtarget))
     return Cst;
 
+  // Canonicalize the combined shuffle mask chain with horizontal ops.
+  // NOTE: This will update the Ops and Mask.
+  if (SDValue HOp = canonicalizeShuffleMaskWithHorizOp(
+          Ops, Mask, RootSizeInBits, SDLoc(Root), DAG, Subtarget))
+    return DAG.getBitcast(Root.getValueType(), HOp);
+
   // We can only combine unary and binary shuffle mask cases.
   if (Ops.size() <= 2) {
     // Minor canonicalization of the accumulated shuffle mask to make it easier
@@ -35900,113 +36010,6 @@ combineRedundantDWordShuffle(SDValue N, MutableArrayRef<int> Mask,
   return V;
 }
 
-// TODO: Merge with foldShuffleOfHorizOp.
-static SDValue combineShuffleWithHorizOp(SDValue N, MVT VT, const SDLoc &DL,
-                                         SelectionDAG &DAG,
-                                         const X86Subtarget &Subtarget) {
-  bool IsUnary;
-  SmallVector<int, 64> TargetMask;
-  SmallVector<SDValue, 2> TargetOps;
-  if (!isTargetShuffle(N.getOpcode()) ||
-      !getTargetShuffleMask(N.getNode(), VT, true, TargetOps, TargetMask,
-                            IsUnary))
-    return SDValue();
-
-  // Combine binary shuffle of 2 similar 'Horizontal' instructions into a
-  // single instruction. Attempt to match a v2X64 repeating shuffle pattern that
-  // represents the LHS/RHS inputs for the lower/upper halves.
-  if (TargetMask.empty() || TargetOps.empty() || 2 < TargetOps.size())
-    return SDValue();
-
-  SDValue BC0 = peekThroughBitcasts(TargetOps.front());
-  SDValue BC1 = peekThroughBitcasts(TargetOps.back());
-  EVT VT0 = BC0.getValueType();
-  EVT VT1 = BC1.getValueType();
-  unsigned Opcode0 = BC0.getOpcode();
-  unsigned Opcode1 = BC1.getOpcode();
-  if (Opcode0 != Opcode1 || VT0 != VT1)
-    return SDValue();
-
-  bool isHoriz = (Opcode0 == X86ISD::FHADD || Opcode0 == X86ISD::HADD ||
-                  Opcode0 == X86ISD::FHSUB || Opcode0 == X86ISD::HSUB);
-  bool isPack = (Opcode0 == X86ISD::PACKSS || Opcode0 == X86ISD::PACKUS);
-  if (!isHoriz && !isPack)
-    return SDValue();
-
-  if (TargetMask.size() == VT0.getVectorNumElements()) {
-    int NumElts = VT0.getVectorNumElements();
-    int NumLanes = VT0.getSizeInBits() / 128;
-    int NumEltsPerLane = NumElts / NumLanes;
-    int NumHalfEltsPerLane = NumEltsPerLane / 2;
-
-    // Canonicalize binary shuffles of horizontal ops that use the
-    // same sources to an unary shuffle.
-    // TODO: Try to perform this fold even if the shuffle remains.
-    if (BC0 != BC1) {
-      auto ContainsOps = [](SDValue HOp, SDValue Op) {
-        return Op == HOp.getOperand(0) || Op == HOp.getOperand(1);
-      };
-      // Commute if all BC0's ops are contained in BC1.
-      if (ContainsOps(BC1, BC0.getOperand(0)) &&
-          ContainsOps(BC1, BC0.getOperand(1))) {
-        ShuffleVectorSDNode::commuteMask(TargetMask);
-        std::swap(BC0, BC1);
-      }
-      // If BC1 can be represented by BC0, then convert to unary shuffle.
-      if (ContainsOps(BC0, BC1.getOperand(0)) &&
-          ContainsOps(BC0, BC1.getOperand(1))) {
-        for (int &M : TargetMask) {
-          if (M < NumElts) // BC0 element or UNDEF/Zero sentinel.
-            continue;
-          int SubLane = ((M % NumEltsPerLane) >= NumHalfEltsPerLane) ? 1 : 0;
-          M -= NumElts + (SubLane * NumHalfEltsPerLane);
-          if (BC1.getOperand(SubLane) != BC0.getOperand(0))
-            M += NumHalfEltsPerLane;
-        }
-      }
-    }
-
-    // Canonicalize unary horizontal ops to only refer to lower halves.
-    for (int i = 0; i != NumElts; ++i) {
-      int &M = TargetMask[i];
-      if (isUndefOrZero(M))
-        continue;
-      if (M < NumElts && BC0.getOperand(0) == BC0.getOperand(1) &&
-          (M % NumEltsPerLane) >= NumHalfEltsPerLane)
-        M -= NumHalfEltsPerLane;
-      if (NumElts <= M && BC1.getOperand(0) == BC1.getOperand(1) &&
-          (M % NumEltsPerLane) >= NumHalfEltsPerLane)
-        M -= NumHalfEltsPerLane;
-    }
-  }
-
-  SmallVector<int, 16> TargetMask128, WideMask128;
-  if (isRepeatedTargetShuffleMask(128, VT, TargetMask, TargetMask128) &&
-      scaleShuffleElements(TargetMask128, 2, WideMask128)) {
-    assert(isUndefOrZeroOrInRange(WideMask128, 0, 4) && "Illegal shuffle");
-    bool SingleOp = (TargetOps.size() == 1);
-    if (!isHoriz || shouldUseHorizontalOp(SingleOp, DAG, Subtarget)) {
-      SDValue Lo = isInRange(WideMask128[0], 0, 2) ? BC0 : BC1;
-      SDValue Hi = isInRange(WideMask128[1], 0, 2) ? BC0 : BC1;
-      Lo = Lo.getOperand(WideMask128[0] & 1);
-      Hi = Hi.getOperand(WideMask128[1] & 1);
-      if (SingleOp) {
-        MVT SrcVT = BC0.getOperand(0).getSimpleValueType();
-        SDValue Undef = DAG.getUNDEF(SrcVT);
-        SDValue Zero = getZeroVector(SrcVT, Subtarget, DAG, DL);
-        Lo = (WideMask128[0] == SM_SentinelZero ? Zero : Lo);
-        Hi = (WideMask128[1] == SM_SentinelZero ? Zero : Hi);
-        Lo = (WideMask128[0] == SM_SentinelUndef ? Undef : Lo);
-        Hi = (WideMask128[1] == SM_SentinelUndef ? Undef : Hi);
-      }
-      SDValue Horiz = DAG.getNode(Opcode0, DL, VT0, Lo, Hi);
-      return DAG.getBitcast(VT, Horiz);
-    }
-  }
-
-  return SDValue();
-}
-
 // Attempt to commute shufps LHS loads:
 // permilps(shufps(load(),x)) --> permilps(shufps(x,load()))
 static SDValue combineCommutableSHUFP(SDValue N, MVT VT, const SDLoc &DL,
@@ -36069,9 +36072,6 @@ static SDValue combineTargetShuffle(SDValue N, SelectionDAG &DAG,
   SmallVector<int, 4> Mask;
   unsigned Opcode = N.getOpcode();
 
-  if (SDValue R = combineShuffleWithHorizOp(N, VT, DL, DAG, Subtarget))
-    return R;
-
   if (SDValue R = combineCommutableSHUFP(N, VT, DL, DAG))
     return R;
 
index 4f7528b..f78f55c 100644 (file)
@@ -889,9 +889,6 @@ define <4 x float> @PR34724_1(<4 x float> %a, <4 x float> %b) {
 ; SSSE3_FAST-LABEL: PR34724_1:
 ; SSSE3_FAST:       # %bb.0:
 ; SSSE3_FAST-NEXT:    haddps %xmm1, %xmm0
-; SSSE3_FAST-NEXT:    haddps %xmm1, %xmm1
-; SSSE3_FAST-NEXT:    shufps {{.*#+}} xmm1 = xmm1[3,0],xmm0[2,0]
-; SSSE3_FAST-NEXT:    shufps {{.*#+}} xmm0 = xmm0[0,1],xmm1[2,0]
 ; SSSE3_FAST-NEXT:    retq
 ;
 ; AVX1_SLOW-LABEL: PR34724_1:
@@ -942,9 +939,6 @@ define <4 x float> @PR34724_2(<4 x float> %a, <4 x float> %b) {
 ; SSSE3_FAST-LABEL: PR34724_2:
 ; SSSE3_FAST:       # %bb.0:
 ; SSSE3_FAST-NEXT:    haddps %xmm1, %xmm0
-; SSSE3_FAST-NEXT:    haddps %xmm1, %xmm1
-; SSSE3_FAST-NEXT:    shufps {{.*#+}} xmm1 = xmm1[3,0],xmm0[2,0]
-; SSSE3_FAST-NEXT:    shufps {{.*#+}} xmm0 = xmm0[0,1],xmm1[2,0]
 ; SSSE3_FAST-NEXT:    retq
 ;
 ; AVX1_SLOW-LABEL: PR34724_2:
index 90b5bc4..cb0eea1 100644 (file)
@@ -201,7 +201,7 @@ define <4 x float> @test8_undef(<4 x float> %a, <4 x float> %b) {
 ; SSE-FAST-LABEL: test8_undef:
 ; SSE-FAST:       # %bb.0:
 ; SSE-FAST-NEXT:    haddps %xmm0, %xmm0
-; SSE-FAST-NEXT:    shufps {{.*#+}} xmm0 = xmm0[0,1,1,3]
+; SSE-FAST-NEXT:    shufps {{.*#+}} xmm0 = xmm0[0,1,1,1]
 ; SSE-FAST-NEXT:    retq
 ;
 ; AVX-SLOW-LABEL: test8_undef:
@@ -588,11 +588,8 @@ define <4 x float> @add_ps_016(<4 x float> %0, <4 x float> %1) {
 ;
 ; SSE-FAST-LABEL: add_ps_016:
 ; SSE-FAST:       # %bb.0:
-; SSE-FAST-NEXT:    movaps %xmm1, %xmm2
-; SSE-FAST-NEXT:    haddps %xmm0, %xmm2
-; SSE-FAST-NEXT:    haddps %xmm1, %xmm1
-; SSE-FAST-NEXT:    shufps {{.*#+}} xmm1 = xmm1[3,1],xmm2[0,0]
-; SSE-FAST-NEXT:    shufps {{.*#+}} xmm1 = xmm1[0,2],xmm2[3,3]
+; SSE-FAST-NEXT:    haddps %xmm0, %xmm1
+; SSE-FAST-NEXT:    shufps {{.*#+}} xmm1 = xmm1[1,0,3,3]
 ; SSE-FAST-NEXT:    movaps %xmm1, %xmm0
 ; SSE-FAST-NEXT:    retq
 ;
@@ -608,9 +605,7 @@ define <4 x float> @add_ps_016(<4 x float> %0, <4 x float> %1) {
 ; AVX-FAST-LABEL: add_ps_016:
 ; AVX-FAST:       # %bb.0:
 ; AVX-FAST-NEXT:    vhaddps %xmm0, %xmm1, %xmm0
-; AVX-FAST-NEXT:    vhaddps %xmm1, %xmm1, %xmm1
-; AVX-FAST-NEXT:    vshufps {{.*#+}} xmm0 = xmm0[0,3],xmm1[3,3]
-; AVX-FAST-NEXT:    vpermilps {{.*#+}} xmm0 = xmm0[2,0,1,3]
+; AVX-FAST-NEXT:    vpermilps {{.*#+}} xmm0 = xmm0[1,0,3,1]
 ; AVX-FAST-NEXT:    retq
   %3 = shufflevector <4 x float> %1, <4 x float> %0, <2 x i32> <i32 0, i32 6>
   %4 = shufflevector <4 x float> %1, <4 x float> %0, <2 x i32> <i32 1, i32 7>
@@ -662,11 +657,23 @@ define <4 x float> @add_ps_018(<4 x float> %x) {
 ; SSE-NEXT:    movsldup {{.*#+}} xmm0 = xmm0[0,0,2,2]
 ; SSE-NEXT:    retq
 ;
-; AVX-LABEL: add_ps_018:
-; AVX:       # %bb.0:
-; AVX-NEXT:    vhaddps %xmm0, %xmm0, %xmm0
-; AVX-NEXT:    vmovsldup {{.*#+}} xmm0 = xmm0[0,0,2,2]
-; AVX-NEXT:    retq
+; AVX1-SLOW-LABEL: add_ps_018:
+; AVX1-SLOW:       # %bb.0:
+; AVX1-SLOW-NEXT:    vhaddps %xmm0, %xmm0, %xmm0
+; AVX1-SLOW-NEXT:    vmovsldup {{.*#+}} xmm0 = xmm0[0,0,2,2]
+; AVX1-SLOW-NEXT:    retq
+;
+; AVX1-FAST-LABEL: add_ps_018:
+; AVX1-FAST:       # %bb.0:
+; AVX1-FAST-NEXT:    vhaddps %xmm0, %xmm0, %xmm0
+; AVX1-FAST-NEXT:    vmovsldup {{.*#+}} xmm0 = xmm0[0,0,2,2]
+; AVX1-FAST-NEXT:    retq
+;
+; AVX512-LABEL: add_ps_018:
+; AVX512:       # %bb.0:
+; AVX512-NEXT:    vhaddps %xmm0, %xmm0, %xmm0
+; AVX512-NEXT:    vbroadcastss %xmm0, %xmm0
+; AVX512-NEXT:    retq
   %l = shufflevector <4 x float> %x, <4 x float> undef, <4 x i32> <i32 undef, i32 undef, i32 0, i32 undef>
   %r = shufflevector <4 x float> %x, <4 x float> undef, <4 x i32> <i32 undef, i32 undef, i32 1, i32 undef>
   %add = fadd <4 x float> %l, %r
@@ -1011,9 +1018,6 @@ define <4 x float> @PR34724_add_v4f32_u123(<4 x float> %0, <4 x float> %1) {
 ; SSE-FAST-LABEL: PR34724_add_v4f32_u123:
 ; SSE-FAST:       # %bb.0:
 ; SSE-FAST-NEXT:    haddps %xmm1, %xmm0
-; SSE-FAST-NEXT:    haddps %xmm1, %xmm1
-; SSE-FAST-NEXT:    shufps {{.*#+}} xmm1 = xmm1[3,0],xmm0[2,0]
-; SSE-FAST-NEXT:    shufps {{.*#+}} xmm0 = xmm0[0,1],xmm1[2,0]
 ; SSE-FAST-NEXT:    retq
 ;
 ; AVX-SLOW-LABEL: PR34724_add_v4f32_u123:
index 163631c..911fa8d 100644 (file)
@@ -507,7 +507,7 @@ define <8 x i16> @phaddw_single_source2(<8 x i16> %x) {
 ; AVX2-SHUF-LABEL: phaddw_single_source2:
 ; AVX2-SHUF:       # %bb.0:
 ; AVX2-SHUF-NEXT:    vphaddw %xmm0, %xmm0, %xmm0
-; AVX2-SHUF-NEXT:    vpshufb {{.*#+}} xmm0 = xmm0[10,11,8,9,10,11,8,9,10,11,8,9,10,11,8,9]
+; AVX2-SHUF-NEXT:    vpshufb {{.*#+}} xmm0 = xmm0[2,3,0,1,2,3,0,1,2,3,0,1,2,3,0,1]
 ; AVX2-SHUF-NEXT:    retq
   %l = shufflevector <8 x i16> %x, <8 x i16> undef, <8 x i32> <i32 undef, i32 undef, i32 undef, i32 undef, i32 0, i32 2, i32 4, i32 6>
   %r = shufflevector <8 x i16> %x, <8 x i16> undef, <8 x i32> <i32 undef, i32 undef, i32 undef, i32 undef, i32 1, i32 3, i32 5, i32 7>