[ARM] Remove reduce(shuffle) if all the lanes are used
authorDavid Green <david.green@arm.com>
Tue, 7 Feb 2023 10:44:35 +0000 (10:44 +0000)
committerDavid Green <david.green@arm.com>
Tue, 7 Feb 2023 10:44:35 +0000 (10:44 +0000)
This looks for vaddv(shuffle) or vmlav(shuffle, shuffle), with a shuffle where
all the lanes are used once. Due to the reduction being commutative the shuffle
can be removed.

Differential Revision: https://reviews.llvm.org/D143382

llvm/lib/Target/ARM/ARMISelLowering.cpp
llvm/test/CodeGen/Thumb2/mve-vecreduce-add-combine.ll

index 24ac30c9648d1297268c083c74c20a0370743b98..fa5175d726691e19f3de67625606328eb2215c67 100644 (file)
@@ -17124,6 +17124,42 @@ static SDValue PerformVECREDUCE_ADDCombine(SDNode *N, SelectionDAG &DAG,
   return SDValue();
 }
 
+// Looks for vaddv(shuffle) or vmlav(shuffle, shuffle), with a shuffle where all
+// the lanes are used. Due to the reduction being commutative the shuffle can be
+// removed.
+static SDValue PerformReduceShuffleCombine(SDNode *N, SelectionDAG &DAG) {
+  unsigned VecOp = N->getOperand(0).getValueType().isVector() ? 0 : 2;
+  auto *Shuf = dyn_cast<ShuffleVectorSDNode>(N->getOperand(VecOp));
+  if (!Shuf || !Shuf->getOperand(1).isUndef())
+    return SDValue();
+
+  // Check all elements are used once in the mask.
+  ArrayRef<int> Mask = Shuf->getMask();
+  APInt SetElts(Mask.size(), 0);
+  for (int E : Mask) {
+    if (E < 0 || E >= (int)Mask.size())
+      return SDValue();
+    SetElts |= 1 << E;
+  }
+  if (!SetElts.isAllOnes())
+    return SDValue();
+
+  if (N->getNumOperands() != VecOp + 1) {
+    auto *Shuf2 = dyn_cast<ShuffleVectorSDNode>(N->getOperand(VecOp + 1));
+    if (!Shuf2 || !Shuf2->getOperand(1).isUndef() || Shuf2->getMask() != Mask)
+      return SDValue();
+  }
+
+  SmallVector<SDValue> Ops;
+  for (SDValue Op : N->ops()) {
+    if (Op.getValueType().isVector())
+      Ops.push_back(Op.getOperand(0));
+    else
+      Ops.push_back(Op);
+  }
+  return DAG.getNode(N->getOpcode(), SDLoc(N), N->getVTList(), Ops);
+}
+
 static SDValue PerformVMOVNCombine(SDNode *N,
                                    TargetLowering::DAGCombinerInfo &DCI) {
   SDValue Op0 = N->getOperand(0);
@@ -18724,6 +18760,19 @@ SDValue ARMTargetLowering::PerformDAGCombine(SDNode *N,
     return PerformVCMPCombine(N, DCI.DAG, Subtarget);
   case ISD::VECREDUCE_ADD:
     return PerformVECREDUCE_ADDCombine(N, DCI.DAG, Subtarget);
+  case ARMISD::VADDVs:
+  case ARMISD::VADDVu:
+  case ARMISD::VADDLVs:
+  case ARMISD::VADDLVu:
+  case ARMISD::VADDLVAs:
+  case ARMISD::VADDLVAu:
+  case ARMISD::VMLAVs:
+  case ARMISD::VMLAVu:
+  case ARMISD::VMLALVs:
+  case ARMISD::VMLALVu:
+  case ARMISD::VMLALVAs:
+  case ARMISD::VMLALVAu:
+    return PerformReduceShuffleCombine(N, DCI.DAG);
   case ARMISD::VMOVN:
     return PerformVMOVNCombine(N, DCI);
   case ARMISD::VQMOVNs:
index 6454310f85e4c297a61218b30c802ba030fa8039..8a8b6c5b6ea20b0b4b0e54ba868e6ffd385d0acc 100644 (file)
@@ -103,39 +103,7 @@ entry:
 define arm_aapcs_vfpcc i16 @vaddv_shuffle_v16i8(<16 x i8> %s0) {
 ; CHECK-LABEL: vaddv_shuffle_v16i8:
 ; CHECK:       @ %bb.0: @ %entry
-; CHECK-NEXT:    vmov.u8 r0, q0[0]
-; CHECK-NEXT:    vmov.8 q1[0], r0
-; CHECK-NEXT:    vmov.u8 r0, q0[2]
-; CHECK-NEXT:    vmov.8 q1[1], r0
-; CHECK-NEXT:    vmov.u8 r0, q0[4]
-; CHECK-NEXT:    vmov.8 q1[2], r0
-; CHECK-NEXT:    vmov.u8 r0, q0[6]
-; CHECK-NEXT:    vmov.8 q1[3], r0
-; CHECK-NEXT:    vmov.u8 r0, q0[8]
-; CHECK-NEXT:    vmov.8 q1[4], r0
-; CHECK-NEXT:    vmov.u8 r0, q0[10]
-; CHECK-NEXT:    vmov.8 q1[5], r0
-; CHECK-NEXT:    vmov.u8 r0, q0[12]
-; CHECK-NEXT:    vmov.8 q1[6], r0
-; CHECK-NEXT:    vmov.u8 r0, q0[14]
-; CHECK-NEXT:    vmov.8 q1[7], r0
-; CHECK-NEXT:    vmov.u8 r0, q0[1]
-; CHECK-NEXT:    vmov.8 q1[8], r0
-; CHECK-NEXT:    vmov.u8 r0, q0[3]
-; CHECK-NEXT:    vmov.8 q1[9], r0
-; CHECK-NEXT:    vmov.u8 r0, q0[5]
-; CHECK-NEXT:    vmov.8 q1[10], r0
-; CHECK-NEXT:    vmov.u8 r0, q0[7]
-; CHECK-NEXT:    vmov.8 q1[11], r0
-; CHECK-NEXT:    vmov.u8 r0, q0[9]
-; CHECK-NEXT:    vmov.8 q1[12], r0
-; CHECK-NEXT:    vmov.u8 r0, q0[11]
-; CHECK-NEXT:    vmov.8 q1[13], r0
-; CHECK-NEXT:    vmov.u8 r0, q0[13]
-; CHECK-NEXT:    vmov.8 q1[14], r0
-; CHECK-NEXT:    vmov.u8 r0, q0[15]
-; CHECK-NEXT:    vmov.8 q1[15], r0
-; CHECK-NEXT:    vaddv.u8 r0, q1
+; CHECK-NEXT:    vaddv.u8 r0, q0
 ; CHECK-NEXT:    bx lr
 entry:
   %s2 = shufflevector <16 x i8> %s0, <16 x i8> %s0, <16 x i32> <i32 0, i32 2, i32 4, i32 6, i32 8, i32 10, i32 12, i32 14, i32 1, i32 3, i32 5, i32 7, i32 9, i32 11, i32 13, i32 15>
@@ -232,11 +200,7 @@ entry:
 define arm_aapcs_vfpcc i64 @vaddv_shuffle_v4i32_long(<4 x i32> %s0) {
 ; CHECK-LABEL: vaddv_shuffle_v4i32_long:
 ; CHECK:       @ %bb.0: @ %entry
-; CHECK-NEXT:    vmov.f32 s4, s3
-; CHECK-NEXT:    vmov.f32 s5, s2
-; CHECK-NEXT:    vmov.f32 s6, s1
-; CHECK-NEXT:    vmov.f32 s7, s0
-; CHECK-NEXT:    vaddlv.u32 r0, r1, q1
+; CHECK-NEXT:    vaddlv.u32 r0, r1, q0
 ; CHECK-NEXT:    bx lr
 entry:
   %s2 = shufflevector <4 x i32> %s0, <4 x i32> %s0, <4 x i32> <i32 3, i32 2, i32 1, i32 0>
@@ -248,11 +212,7 @@ entry:
 define arm_aapcs_vfpcc i64 @vaddv_shuffle_v4i32_long_a(<4 x i32> %s0, i64 %a) {
 ; CHECK-LABEL: vaddv_shuffle_v4i32_long_a:
 ; CHECK:       @ %bb.0: @ %entry
-; CHECK-NEXT:    vmov.f32 s4, s3
-; CHECK-NEXT:    vmov.f32 s5, s2
-; CHECK-NEXT:    vmov.f32 s6, s1
-; CHECK-NEXT:    vmov.f32 s7, s0
-; CHECK-NEXT:    vaddlva.u32 r0, r1, q1
+; CHECK-NEXT:    vaddlva.u32 r0, r1, q0
 ; CHECK-NEXT:    bx lr
 entry:
   %s2 = shufflevector <4 x i32> %s0, <4 x i32> %s0, <4 x i32> <i32 3, i32 2, i32 1, i32 0>
@@ -265,71 +225,7 @@ entry:
 define arm_aapcs_vfpcc i16 @vmla_shuffle_v16i8(<16 x i8> %s0, <16 x i8> %s0b) {
 ; CHECK-LABEL: vmla_shuffle_v16i8:
 ; CHECK:       @ %bb.0: @ %entry
-; CHECK-NEXT:    vmov.u8 r0, q1[0]
-; CHECK-NEXT:    vmov.8 q2[0], r0
-; CHECK-NEXT:    vmov.u8 r0, q1[2]
-; CHECK-NEXT:    vmov.8 q2[1], r0
-; CHECK-NEXT:    vmov.u8 r0, q1[4]
-; CHECK-NEXT:    vmov.8 q2[2], r0
-; CHECK-NEXT:    vmov.u8 r0, q1[6]
-; CHECK-NEXT:    vmov.8 q2[3], r0
-; CHECK-NEXT:    vmov.u8 r0, q1[8]
-; CHECK-NEXT:    vmov.8 q2[4], r0
-; CHECK-NEXT:    vmov.u8 r0, q1[10]
-; CHECK-NEXT:    vmov.8 q2[5], r0
-; CHECK-NEXT:    vmov.u8 r0, q1[12]
-; CHECK-NEXT:    vmov.8 q2[6], r0
-; CHECK-NEXT:    vmov.u8 r0, q1[14]
-; CHECK-NEXT:    vmov.8 q2[7], r0
-; CHECK-NEXT:    vmov.u8 r0, q1[1]
-; CHECK-NEXT:    vmov.8 q2[8], r0
-; CHECK-NEXT:    vmov.u8 r0, q1[3]
-; CHECK-NEXT:    vmov.8 q2[9], r0
-; CHECK-NEXT:    vmov.u8 r0, q1[5]
-; CHECK-NEXT:    vmov.8 q2[10], r0
-; CHECK-NEXT:    vmov.u8 r0, q1[7]
-; CHECK-NEXT:    vmov.8 q2[11], r0
-; CHECK-NEXT:    vmov.u8 r0, q1[9]
-; CHECK-NEXT:    vmov.8 q2[12], r0
-; CHECK-NEXT:    vmov.u8 r0, q1[11]
-; CHECK-NEXT:    vmov.8 q2[13], r0
-; CHECK-NEXT:    vmov.u8 r0, q1[13]
-; CHECK-NEXT:    vmov.8 q2[14], r0
-; CHECK-NEXT:    vmov.u8 r0, q1[15]
-; CHECK-NEXT:    vmov.8 q2[15], r0
-; CHECK-NEXT:    vmov.u8 r0, q0[0]
-; CHECK-NEXT:    vmov.8 q1[0], r0
-; CHECK-NEXT:    vmov.u8 r0, q0[2]
-; CHECK-NEXT:    vmov.8 q1[1], r0
-; CHECK-NEXT:    vmov.u8 r0, q0[4]
-; CHECK-NEXT:    vmov.8 q1[2], r0
-; CHECK-NEXT:    vmov.u8 r0, q0[6]
-; CHECK-NEXT:    vmov.8 q1[3], r0
-; CHECK-NEXT:    vmov.u8 r0, q0[8]
-; CHECK-NEXT:    vmov.8 q1[4], r0
-; CHECK-NEXT:    vmov.u8 r0, q0[10]
-; CHECK-NEXT:    vmov.8 q1[5], r0
-; CHECK-NEXT:    vmov.u8 r0, q0[12]
-; CHECK-NEXT:    vmov.8 q1[6], r0
-; CHECK-NEXT:    vmov.u8 r0, q0[14]
-; CHECK-NEXT:    vmov.8 q1[7], r0
-; CHECK-NEXT:    vmov.u8 r0, q0[1]
-; CHECK-NEXT:    vmov.8 q1[8], r0
-; CHECK-NEXT:    vmov.u8 r0, q0[3]
-; CHECK-NEXT:    vmov.8 q1[9], r0
-; CHECK-NEXT:    vmov.u8 r0, q0[5]
-; CHECK-NEXT:    vmov.8 q1[10], r0
-; CHECK-NEXT:    vmov.u8 r0, q0[7]
-; CHECK-NEXT:    vmov.8 q1[11], r0
-; CHECK-NEXT:    vmov.u8 r0, q0[9]
-; CHECK-NEXT:    vmov.8 q1[12], r0
-; CHECK-NEXT:    vmov.u8 r0, q0[11]
-; CHECK-NEXT:    vmov.8 q1[13], r0
-; CHECK-NEXT:    vmov.u8 r0, q0[13]
-; CHECK-NEXT:    vmov.8 q1[14], r0
-; CHECK-NEXT:    vmov.u8 r0, q0[15]
-; CHECK-NEXT:    vmov.8 q1[15], r0
-; CHECK-NEXT:    vmlav.s8 r0, q1, q2
+; CHECK-NEXT:    vmlav.s8 r0, q0, q1
 ; CHECK-NEXT:    bx lr
 entry:
   %s2a = shufflevector <16 x i8> %s0, <16 x i8> %s0, <16 x i32> <i32 0, i32 2, i32 4, i32 6, i32 8, i32 10, i32 12, i32 14, i32 1, i32 3, i32 5, i32 7, i32 9, i32 11, i32 13, i32 15>
@@ -423,15 +319,7 @@ entry:
 define arm_aapcs_vfpcc i64 @vmla_shuffle_v4i32_long(<4 x i32> %s0, <4 x i32> %s0b) {
 ; CHECK-LABEL: vmla_shuffle_v4i32_long:
 ; CHECK:       @ %bb.0: @ %entry
-; CHECK-NEXT:    vmov.f32 s8, s7
-; CHECK-NEXT:    vmov.f32 s9, s6
-; CHECK-NEXT:    vmov.f32 s10, s5
-; CHECK-NEXT:    vmov.f32 s11, s4
-; CHECK-NEXT:    vmov.f32 s4, s3
-; CHECK-NEXT:    vmov.f32 s5, s2
-; CHECK-NEXT:    vmov.f32 s6, s1
-; CHECK-NEXT:    vmov.f32 s7, s0
-; CHECK-NEXT:    vmlalv.u32 r0, r1, q1, q2
+; CHECK-NEXT:    vmlalv.u32 r0, r1, q0, q1
 ; CHECK-NEXT:    bx lr
 entry:
   %s2a = shufflevector <4 x i32> %s0, <4 x i32> %s0, <4 x i32> <i32 3, i32 2, i32 1, i32 0>
@@ -446,15 +334,7 @@ entry:
 define arm_aapcs_vfpcc i64 @vmla_shuffle_v4i32_long_a(<4 x i32> %s0, <4 x i32> %s0b, i64 %a) {
 ; CHECK-LABEL: vmla_shuffle_v4i32_long_a:
 ; CHECK:       @ %bb.0: @ %entry
-; CHECK-NEXT:    vmov.f32 s8, s7
-; CHECK-NEXT:    vmov.f32 s9, s6
-; CHECK-NEXT:    vmov.f32 s10, s5
-; CHECK-NEXT:    vmov.f32 s11, s4
-; CHECK-NEXT:    vmov.f32 s4, s3
-; CHECK-NEXT:    vmov.f32 s5, s2
-; CHECK-NEXT:    vmov.f32 s6, s1
-; CHECK-NEXT:    vmov.f32 s7, s0
-; CHECK-NEXT:    vmlalva.u32 r0, r1, q1, q2
+; CHECK-NEXT:    vmlalva.u32 r0, r1, q0, q1
 ; CHECK-NEXT:    bx lr
 entry:
   %s2a = shufflevector <4 x i32> %s0, <4 x i32> %s0, <4 x i32> <i32 3, i32 2, i32 1, i32 0>