[ARM] Remove FlattenVectorShuffle and add PerformVQDMULHCombine.
authorDavid Green <david.green@arm.com>
Sun, 5 Feb 2023 20:59:49 +0000 (20:59 +0000)
committerDavid Green <david.green@arm.com>
Sun, 5 Feb 2023 20:59:49 +0000 (20:59 +0000)
This removes the FlattenVectorShuffle that folds shuffles through certain
binops. This is now handled by generic DAG combines for all but ARMISD::VQDMULH
where a PerformVQDMULHCombine is added to compensate. It pushes identical
shuffles down through the operation, in a similar way to the other combines in
DAG.

llvm/lib/Target/ARM/ARMISelLowering.cpp

index 07fa829..24ac30c 100644 (file)
@@ -15468,51 +15468,6 @@ static SDValue PerformSignExtendInregCombine(SDNode *N, SelectionDAG &DAG) {
   return SDValue();
 }
 
-// When lowering complex nodes that we recognize, like VQDMULH and MULH, we
-// can end up with shuffle(binop(shuffle, shuffle)), that can be simplified to
-// binop as the shuffles cancel out.
-static SDValue FlattenVectorShuffle(ShuffleVectorSDNode *N, SelectionDAG &DAG) {
-  EVT VT = N->getValueType(0);
-  if (!N->getOperand(1).isUndef() || N->getOperand(0).getValueType() != VT)
-    return SDValue();
-  SDValue Op = N->getOperand(0);
-
-  // Looking for binary operators that will have been folded from
-  // truncates/extends.
-  switch (Op.getOpcode()) {
-  case ARMISD::VQDMULH:
-  case ISD::MULHS:
-  case ISD::MULHU:
-  case ISD::ABDS:
-  case ISD::ABDU:
-  case ISD::AVGFLOORS:
-  case ISD::AVGFLOORU:
-  case ISD::AVGCEILS:
-  case ISD::AVGCEILU:
-    break;
-  default:
-    return SDValue();
-  }
-
-  ShuffleVectorSDNode *Op0 = dyn_cast<ShuffleVectorSDNode>(Op.getOperand(0));
-  ShuffleVectorSDNode *Op1 = dyn_cast<ShuffleVectorSDNode>(Op.getOperand(1));
-  if (!Op0 || !Op1 || !Op0->getOperand(1).isUndef() ||
-      !Op1->getOperand(1).isUndef() || Op0->getMask() != Op1->getMask() ||
-      Op0->getOperand(0).getValueType() != VT)
-    return SDValue();
-
-  // Check the mask turns into an identity shuffle.
-  ArrayRef<int> NMask = N->getMask();
-  ArrayRef<int> OpMask = Op0->getMask();
-  for (int i = 0, e = NMask.size(); i != e; i++) {
-    if (NMask[i] > 0 && OpMask[NMask[i]] > 0 && OpMask[NMask[i]] != i)
-      return SDValue();
-  }
-
-  return DAG.getNode(Op.getOpcode(), SDLoc(Op), Op.getValueType(),
-                     Op0->getOperand(0), Op1->getOperand(0));
-}
-
 static SDValue
 PerformInsertSubvectorCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
   SDValue Vec = N->getOperand(0);
@@ -15581,8 +15536,6 @@ static SDValue PerformShuffleVMOVNCombine(ShuffleVectorSDNode *N,
 /// PerformVECTOR_SHUFFLECombine - Target-specific dag combine xforms for
 /// ISD::VECTOR_SHUFFLE.
 static SDValue PerformVECTOR_SHUFFLECombine(SDNode *N, SelectionDAG &DAG) {
-  if (SDValue R = FlattenVectorShuffle(cast<ShuffleVectorSDNode>(N), DAG))
-    return R;
   if (SDValue R = PerformShuffleVMOVNCombine(cast<ShuffleVectorSDNode>(N), DAG))
     return R;
 
@@ -17227,6 +17180,27 @@ static SDValue PerformVQMOVNCombine(SDNode *N,
   return SDValue();
 }
 
+static SDValue PerformVQDMULHCombine(SDNode *N,
+                                     TargetLowering::DAGCombinerInfo &DCI) {
+  EVT VT = N->getValueType(0);
+  SDValue LHS = N->getOperand(0);
+  SDValue RHS = N->getOperand(1);
+
+  auto *Shuf0 = dyn_cast<ShuffleVectorSDNode>(LHS);
+  auto *Shuf1 = dyn_cast<ShuffleVectorSDNode>(RHS);
+  // Turn VQDMULH(shuffle, shuffle) -> shuffle(VQDMULH)
+  if (Shuf0 && Shuf1 && Shuf0->getMask().equals(Shuf1->getMask()) &&
+      LHS.getOperand(1).isUndef() && RHS.getOperand(1).isUndef() &&
+      (LHS.hasOneUse() || RHS.hasOneUse() || LHS == RHS)) {
+    SDLoc DL(N);
+    SDValue NewBinOp = DCI.DAG.getNode(N->getOpcode(), DL, VT,
+                                       LHS.getOperand(0), RHS.getOperand(0));
+    SDValue UndefV = LHS.getOperand(1);
+    return DCI.DAG.getVectorShuffle(VT, DL, NewBinOp, UndefV, Shuf0->getMask());
+  }
+  return SDValue();
+}
+
 static SDValue PerformLongShiftCombine(SDNode *N, SelectionDAG &DAG) {
   SDLoc DL(N);
   SDValue Op0 = N->getOperand(0);
@@ -18755,6 +18729,8 @@ SDValue ARMTargetLowering::PerformDAGCombine(SDNode *N,
   case ARMISD::VQMOVNs:
   case ARMISD::VQMOVNu:
     return PerformVQMOVNCombine(N, DCI);
+  case ARMISD::VQDMULH:
+    return PerformVQDMULHCombine(N, DCI);
   case ARMISD::ASRL:
   case ARMISD::LSRL:
   case ARMISD::LSLL: