[ARM] Move add(VMLALVA(A, X, Y), B) to VMLALVA(add(A, B), X, Y)
authorDavid Green <david.green@arm.com>
Wed, 14 Jul 2021 19:06:49 +0000 (20:06 +0100)
committerDavid Green <david.green@arm.com>
Wed, 14 Jul 2021 19:06:49 +0000 (20:06 +0100)
For i64 reductions we currently try and convert add(VMLALV(X, Y), B) to
VMLALVA(B, X, Y), incorporating the addition into the VMLALVA. If we
have an add of an existing VMLALVA, this patch pushes the add up above
the VMLALVA so that it may potentially be simplified further, for
example being folded into another VMLALV.

Differential Revision: https://reviews.llvm.org/D105686

llvm/lib/Target/ARM/ARMISelLowering.cpp
llvm/test/CodeGen/Thumb2/mve-vecreduce-mla.ll

index 5894afb..63959ee 100644 (file)
@@ -13056,24 +13056,35 @@ static SDValue PerformADDVecReduce(SDNode *N,
   //   t1: i32,i32 = ARMISD::VADDLVs x
   //   t2: i64 = build_pair t1, t1:1
   //   t3: i64 = add t2, y
+  // Otherwise we try to push the add up above VADDLVAx, to potentially allow
+  // the add to be simplified seperately.
   // We also need to check for sext / zext and commutitive adds.
   auto MakeVecReduce = [&](unsigned Opcode, unsigned OpcodeA, SDValue NA,
                            SDValue NB) {
     if (NB->getOpcode() != ISD::BUILD_PAIR)
       return SDValue();
     SDValue VecRed = NB->getOperand(0);
-    if (VecRed->getOpcode() != Opcode || VecRed.getResNo() != 0 ||
+    if ((VecRed->getOpcode() != Opcode && VecRed->getOpcode() != OpcodeA) ||
+        VecRed.getResNo() != 0 ||
         NB->getOperand(1) != SDValue(VecRed.getNode(), 1))
       return SDValue();
 
     SDLoc dl(N);
+    if (VecRed->getOpcode() == OpcodeA) {
+      // add(NA, VADDLVA(Inp), Y) -> VADDLVA(add(NA, Inp), Y)
+      SDValue Inp = DCI.DAG.getNode(ISD::BUILD_PAIR, dl, MVT::i64,
+                                    VecRed.getOperand(0), VecRed.getOperand(1));
+      NA = DCI.DAG.getNode(ISD::ADD, dl, MVT::i64, Inp, NA);
+    }
+
     SmallVector<SDValue, 4> Ops;
     Ops.push_back(DCI.DAG.getNode(ISD::EXTRACT_ELEMENT, dl, MVT::i32, NA,
                                   DCI.DAG.getConstant(0, dl, MVT::i32)));
     Ops.push_back(DCI.DAG.getNode(ISD::EXTRACT_ELEMENT, dl, MVT::i32, NA,
                                   DCI.DAG.getConstant(1, dl, MVT::i32)));
-    for (unsigned i = 0, e = VecRed.getNumOperands(); i < e; i++)
-      Ops.push_back(VecRed->getOperand(i));
+    unsigned S = VecRed->getOpcode() == OpcodeA ? 2 : 0;
+    for (unsigned I = S, E = VecRed.getNumOperands(); I < E; I++)
+      Ops.push_back(VecRed->getOperand(I));
     SDValue Red = DCI.DAG.getNode(OpcodeA, dl,
                                   DCI.DAG.getVTList({MVT::i32, MVT::i32}), Ops);
     return DCI.DAG.getNode(ISD::BUILD_PAIR, dl, MVT::i64, Red,
index 44222e5..14ce320 100644 (file)
@@ -1224,24 +1224,20 @@ entry:
 define arm_aapcs_vfpcc i64 @add_v16i8_v16i64_acc_zext(<16 x i8> %x, <16 x i8> %y, i64 %a) {
 ; CHECK-LABEL: add_v16i8_v16i64_acc_zext:
 ; CHECK:       @ %bb.0: @ %entry
-; CHECK-NEXT:    .save {r4, r5, r7, lr}
-; CHECK-NEXT:    push {r4, r5, r7, lr}
 ; CHECK-NEXT:    .pad #32
 ; CHECK-NEXT:    sub sp, #32
-; CHECK-NEXT:    add r4, sp, #16
+; CHECK-NEXT:    add r2, sp, #16
 ; CHECK-NEXT:    mov r3, sp
-; CHECK-NEXT:    vstrw.32 q1, [r4]
+; CHECK-NEXT:    vstrw.32 q1, [r2]
 ; CHECK-NEXT:    vstrw.32 q0, [r3]
-; CHECK-NEXT:    vldrb.u16 q0, [r4]
+; CHECK-NEXT:    vldrb.u16 q0, [r2]
 ; CHECK-NEXT:    vldrb.u16 q1, [r3]
-; CHECK-NEXT:    vmlalv.u16 r2, r5, q1, q0
-; CHECK-NEXT:    vldrb.u16 q0, [r4, #8]
+; CHECK-NEXT:    vmlalva.u16 r0, r1, q1, q0
+; CHECK-NEXT:    vldrb.u16 q0, [r2, #8]
 ; CHECK-NEXT:    vldrb.u16 q1, [r3, #8]
-; CHECK-NEXT:    vmlalva.u16 r2, r5, q1, q0
-; CHECK-NEXT:    adds r0, r0, r2
-; CHECK-NEXT:    adcs r1, r5
+; CHECK-NEXT:    vmlalva.u16 r0, r1, q1, q0
 ; CHECK-NEXT:    add sp, #32
-; CHECK-NEXT:    pop {r4, r5, r7, pc}
+; CHECK-NEXT:    bx lr
 entry:
   %xx = zext <16 x i8> %x to <16 x i64>
   %yy = zext <16 x i8> %y to <16 x i64>
@@ -1254,24 +1250,20 @@ entry:
 define arm_aapcs_vfpcc i64 @add_v16i8_v16i64_acc_sext(<16 x i8> %x, <16 x i8> %y, i64 %a) {
 ; CHECK-LABEL: add_v16i8_v16i64_acc_sext:
 ; CHECK:       @ %bb.0: @ %entry
-; CHECK-NEXT:    .save {r4, r5, r7, lr}
-; CHECK-NEXT:    push {r4, r5, r7, lr}
 ; CHECK-NEXT:    .pad #32
 ; CHECK-NEXT:    sub sp, #32
-; CHECK-NEXT:    add r4, sp, #16
+; CHECK-NEXT:    add r2, sp, #16
 ; CHECK-NEXT:    mov r3, sp
-; CHECK-NEXT:    vstrw.32 q1, [r4]
+; CHECK-NEXT:    vstrw.32 q1, [r2]
 ; CHECK-NEXT:    vstrw.32 q0, [r3]
-; CHECK-NEXT:    vldrb.s16 q0, [r4]
+; CHECK-NEXT:    vldrb.s16 q0, [r2]
 ; CHECK-NEXT:    vldrb.s16 q1, [r3]
-; CHECK-NEXT:    vmlalv.s16 r2, r5, q1, q0
-; CHECK-NEXT:    vldrb.s16 q0, [r4, #8]
+; CHECK-NEXT:    vmlalva.s16 r0, r1, q1, q0
+; CHECK-NEXT:    vldrb.s16 q0, [r2, #8]
 ; CHECK-NEXT:    vldrb.s16 q1, [r3, #8]
-; CHECK-NEXT:    vmlalva.s16 r2, r5, q1, q0
-; CHECK-NEXT:    adds r0, r0, r2
-; CHECK-NEXT:    adcs r1, r5
+; CHECK-NEXT:    vmlalva.s16 r0, r1, q1, q0
 ; CHECK-NEXT:    add sp, #32
-; CHECK-NEXT:    pop {r4, r5, r7, pc}
+; CHECK-NEXT:    bx lr
 entry:
   %xx = sext <16 x i8> %x to <16 x i64>
   %yy = sext <16 x i8> %y to <16 x i64>
@@ -1284,17 +1276,15 @@ entry:
 define arm_aapcs_vfpcc i64 @add_v16i8_v16i64_acc_zext_load(<16 x i8> *%xp, <16 x i8> *%yp, i64 %a) {
 ; CHECK-LABEL: add_v16i8_v16i64_acc_zext_load:
 ; CHECK:       @ %bb.0: @ %entry
-; CHECK-NEXT:    .save {r5, lr}
-; CHECK-NEXT:    push {r5, lr}
 ; CHECK-NEXT:    vldrb.u16 q0, [r1]
 ; CHECK-NEXT:    vldrb.u16 q1, [r0]
-; CHECK-NEXT:    vmlalv.u16 r12, r5, q1, q0
+; CHECK-NEXT:    vmlalva.u16 r2, r3, q1, q0
 ; CHECK-NEXT:    vldrb.u16 q0, [r1, #8]
 ; CHECK-NEXT:    vldrb.u16 q1, [r0, #8]
-; CHECK-NEXT:    vmlalva.u16 r12, r5, q1, q0
-; CHECK-NEXT:    adds.w r0, r12, r2
-; CHECK-NEXT:    adc.w r1, r5, r3
-; CHECK-NEXT:    pop {r5, pc}
+; CHECK-NEXT:    vmlalva.u16 r2, r3, q1, q0
+; CHECK-NEXT:    mov r0, r2
+; CHECK-NEXT:    mov r1, r3
+; CHECK-NEXT:    bx lr
 entry:
   %x = load <16 x i8>, <16 x i8>* %xp
   %y = load <16 x i8>, <16 x i8>* %yp
@@ -1309,17 +1299,15 @@ entry:
 define arm_aapcs_vfpcc i64 @add_v16i8_v16i64_acc_sext_load(<16 x i8> *%xp, <16 x i8> *%yp, i64 %a) {
 ; CHECK-LABEL: add_v16i8_v16i64_acc_sext_load:
 ; CHECK:       @ %bb.0: @ %entry
-; CHECK-NEXT:    .save {r5, lr}
-; CHECK-NEXT:    push {r5, lr}
 ; CHECK-NEXT:    vldrb.s16 q0, [r1]
 ; CHECK-NEXT:    vldrb.s16 q1, [r0]
-; CHECK-NEXT:    vmlalv.s16 r12, r5, q1, q0
+; CHECK-NEXT:    vmlalva.s16 r2, r3, q1, q0
 ; CHECK-NEXT:    vldrb.s16 q0, [r1, #8]
 ; CHECK-NEXT:    vldrb.s16 q1, [r0, #8]
-; CHECK-NEXT:    vmlalva.s16 r12, r5, q1, q0
-; CHECK-NEXT:    adds.w r0, r12, r2
-; CHECK-NEXT:    adc.w r1, r5, r3
-; CHECK-NEXT:    pop {r5, pc}
+; CHECK-NEXT:    vmlalva.s16 r2, r3, q1, q0
+; CHECK-NEXT:    mov r0, r2
+; CHECK-NEXT:    mov r1, r3
+; CHECK-NEXT:    bx lr
 entry:
   %x = load <16 x i8>, <16 x i8>* %xp
   %y = load <16 x i8>, <16 x i8>* %yp