[ARM] Move add(VMLALVA(A, X, Y), B) to VMLALVA(add(A, B), X, Y)
authorDavid Green <david.green@arm.com>
Wed, 14 Jul 2021 19:06:49 +0000 (20:06 +0100)
committerDavid Green <david.green@arm.com>
Wed, 14 Jul 2021 19:06:49 +0000 (20:06 +0100)
For i64 reductions we currently try and convert add(VMLALV(X, Y), B) to
VMLALVA(B, X, Y), incorporating the addition into the VMLALVA. If we
have an add of an existing VMLALVA, this patch pushes the add up above
the VMLALVA so that it may potentially be simplified further, for
example being folded into another VMLALV.

Differential Revision: https://reviews.llvm.org/D105686

llvm/lib/Target/ARM/ARMISelLowering.cpp
llvm/test/CodeGen/Thumb2/mve-vecreduce-mla.ll

index 5894afbb578eb52c3de7b22a6dbc5cb7c1bccd99..63959eef1827ef87f730b47e16db86d1a180f954 100644 (file)
@@ -13056,24 +13056,35 @@ static SDValue PerformADDVecReduce(SDNode *N,
   //   t1: i32,i32 = ARMISD::VADDLVs x
   //   t2: i64 = build_pair t1, t1:1
   //   t3: i64 = add t2, y
+  // Otherwise we try to push the add up above VADDLVAx, to potentially allow
+  // the add to be simplified seperately.
   // We also need to check for sext / zext and commutitive adds.
   auto MakeVecReduce = [&](unsigned Opcode, unsigned OpcodeA, SDValue NA,
                            SDValue NB) {
     if (NB->getOpcode() != ISD::BUILD_PAIR)
       return SDValue();
     SDValue VecRed = NB->getOperand(0);
-    if (VecRed->getOpcode() != Opcode || VecRed.getResNo() != 0 ||
+    if ((VecRed->getOpcode() != Opcode && VecRed->getOpcode() != OpcodeA) ||
+        VecRed.getResNo() != 0 ||
         NB->getOperand(1) != SDValue(VecRed.getNode(), 1))
       return SDValue();
 
     SDLoc dl(N);
+    if (VecRed->getOpcode() == OpcodeA) {
+      // add(NA, VADDLVA(Inp), Y) -> VADDLVA(add(NA, Inp), Y)
+      SDValue Inp = DCI.DAG.getNode(ISD::BUILD_PAIR, dl, MVT::i64,
+                                    VecRed.getOperand(0), VecRed.getOperand(1));
+      NA = DCI.DAG.getNode(ISD::ADD, dl, MVT::i64, Inp, NA);
+    }
+
     SmallVector<SDValue, 4> Ops;
     Ops.push_back(DCI.DAG.getNode(ISD::EXTRACT_ELEMENT, dl, MVT::i32, NA,
                                   DCI.DAG.getConstant(0, dl, MVT::i32)));
     Ops.push_back(DCI.DAG.getNode(ISD::EXTRACT_ELEMENT, dl, MVT::i32, NA,
                                   DCI.DAG.getConstant(1, dl, MVT::i32)));
-    for (unsigned i = 0, e = VecRed.getNumOperands(); i < e; i++)
-      Ops.push_back(VecRed->getOperand(i));
+    unsigned S = VecRed->getOpcode() == OpcodeA ? 2 : 0;
+    for (unsigned I = S, E = VecRed.getNumOperands(); I < E; I++)
+      Ops.push_back(VecRed->getOperand(I));
     SDValue Red = DCI.DAG.getNode(OpcodeA, dl,
                                   DCI.DAG.getVTList({MVT::i32, MVT::i32}), Ops);
     return DCI.DAG.getNode(ISD::BUILD_PAIR, dl, MVT::i64, Red,
index 44222e58cc9bf1fd70b02a2b5a2b9ac403f3848b..14ce320da3c232b04bc24ea81a59c59a99313a48 100644 (file)
@@ -1224,24 +1224,20 @@ entry:
 define arm_aapcs_vfpcc i64 @add_v16i8_v16i64_acc_zext(<16 x i8> %x, <16 x i8> %y, i64 %a) {
 ; CHECK-LABEL: add_v16i8_v16i64_acc_zext:
 ; CHECK:       @ %bb.0: @ %entry
-; CHECK-NEXT:    .save {r4, r5, r7, lr}
-; CHECK-NEXT:    push {r4, r5, r7, lr}
 ; CHECK-NEXT:    .pad #32
 ; CHECK-NEXT:    sub sp, #32
-; CHECK-NEXT:    add r4, sp, #16
+; CHECK-NEXT:    add r2, sp, #16
 ; CHECK-NEXT:    mov r3, sp
-; CHECK-NEXT:    vstrw.32 q1, [r4]
+; CHECK-NEXT:    vstrw.32 q1, [r2]
 ; CHECK-NEXT:    vstrw.32 q0, [r3]
-; CHECK-NEXT:    vldrb.u16 q0, [r4]
+; CHECK-NEXT:    vldrb.u16 q0, [r2]
 ; CHECK-NEXT:    vldrb.u16 q1, [r3]
-; CHECK-NEXT:    vmlalv.u16 r2, r5, q1, q0
-; CHECK-NEXT:    vldrb.u16 q0, [r4, #8]
+; CHECK-NEXT:    vmlalva.u16 r0, r1, q1, q0
+; CHECK-NEXT:    vldrb.u16 q0, [r2, #8]
 ; CHECK-NEXT:    vldrb.u16 q1, [r3, #8]
-; CHECK-NEXT:    vmlalva.u16 r2, r5, q1, q0
-; CHECK-NEXT:    adds r0, r0, r2
-; CHECK-NEXT:    adcs r1, r5
+; CHECK-NEXT:    vmlalva.u16 r0, r1, q1, q0
 ; CHECK-NEXT:    add sp, #32
-; CHECK-NEXT:    pop {r4, r5, r7, pc}
+; CHECK-NEXT:    bx lr
 entry:
   %xx = zext <16 x i8> %x to <16 x i64>
   %yy = zext <16 x i8> %y to <16 x i64>
@@ -1254,24 +1250,20 @@ entry:
 define arm_aapcs_vfpcc i64 @add_v16i8_v16i64_acc_sext(<16 x i8> %x, <16 x i8> %y, i64 %a) {
 ; CHECK-LABEL: add_v16i8_v16i64_acc_sext:
 ; CHECK:       @ %bb.0: @ %entry
-; CHECK-NEXT:    .save {r4, r5, r7, lr}
-; CHECK-NEXT:    push {r4, r5, r7, lr}
 ; CHECK-NEXT:    .pad #32
 ; CHECK-NEXT:    sub sp, #32
-; CHECK-NEXT:    add r4, sp, #16
+; CHECK-NEXT:    add r2, sp, #16
 ; CHECK-NEXT:    mov r3, sp
-; CHECK-NEXT:    vstrw.32 q1, [r4]
+; CHECK-NEXT:    vstrw.32 q1, [r2]
 ; CHECK-NEXT:    vstrw.32 q0, [r3]
-; CHECK-NEXT:    vldrb.s16 q0, [r4]
+; CHECK-NEXT:    vldrb.s16 q0, [r2]
 ; CHECK-NEXT:    vldrb.s16 q1, [r3]
-; CHECK-NEXT:    vmlalv.s16 r2, r5, q1, q0
-; CHECK-NEXT:    vldrb.s16 q0, [r4, #8]
+; CHECK-NEXT:    vmlalva.s16 r0, r1, q1, q0
+; CHECK-NEXT:    vldrb.s16 q0, [r2, #8]
 ; CHECK-NEXT:    vldrb.s16 q1, [r3, #8]
-; CHECK-NEXT:    vmlalva.s16 r2, r5, q1, q0
-; CHECK-NEXT:    adds r0, r0, r2
-; CHECK-NEXT:    adcs r1, r5
+; CHECK-NEXT:    vmlalva.s16 r0, r1, q1, q0
 ; CHECK-NEXT:    add sp, #32
-; CHECK-NEXT:    pop {r4, r5, r7, pc}
+; CHECK-NEXT:    bx lr
 entry:
   %xx = sext <16 x i8> %x to <16 x i64>
   %yy = sext <16 x i8> %y to <16 x i64>
@@ -1284,17 +1276,15 @@ entry:
 define arm_aapcs_vfpcc i64 @add_v16i8_v16i64_acc_zext_load(<16 x i8> *%xp, <16 x i8> *%yp, i64 %a) {
 ; CHECK-LABEL: add_v16i8_v16i64_acc_zext_load:
 ; CHECK:       @ %bb.0: @ %entry
-; CHECK-NEXT:    .save {r5, lr}
-; CHECK-NEXT:    push {r5, lr}
 ; CHECK-NEXT:    vldrb.u16 q0, [r1]
 ; CHECK-NEXT:    vldrb.u16 q1, [r0]
-; CHECK-NEXT:    vmlalv.u16 r12, r5, q1, q0
+; CHECK-NEXT:    vmlalva.u16 r2, r3, q1, q0
 ; CHECK-NEXT:    vldrb.u16 q0, [r1, #8]
 ; CHECK-NEXT:    vldrb.u16 q1, [r0, #8]
-; CHECK-NEXT:    vmlalva.u16 r12, r5, q1, q0
-; CHECK-NEXT:    adds.w r0, r12, r2
-; CHECK-NEXT:    adc.w r1, r5, r3
-; CHECK-NEXT:    pop {r5, pc}
+; CHECK-NEXT:    vmlalva.u16 r2, r3, q1, q0
+; CHECK-NEXT:    mov r0, r2
+; CHECK-NEXT:    mov r1, r3
+; CHECK-NEXT:    bx lr
 entry:
   %x = load <16 x i8>, <16 x i8>* %xp
   %y = load <16 x i8>, <16 x i8>* %yp
@@ -1309,17 +1299,15 @@ entry:
 define arm_aapcs_vfpcc i64 @add_v16i8_v16i64_acc_sext_load(<16 x i8> *%xp, <16 x i8> *%yp, i64 %a) {
 ; CHECK-LABEL: add_v16i8_v16i64_acc_sext_load:
 ; CHECK:       @ %bb.0: @ %entry
-; CHECK-NEXT:    .save {r5, lr}
-; CHECK-NEXT:    push {r5, lr}
 ; CHECK-NEXT:    vldrb.s16 q0, [r1]
 ; CHECK-NEXT:    vldrb.s16 q1, [r0]
-; CHECK-NEXT:    vmlalv.s16 r12, r5, q1, q0
+; CHECK-NEXT:    vmlalva.s16 r2, r3, q1, q0
 ; CHECK-NEXT:    vldrb.s16 q0, [r1, #8]
 ; CHECK-NEXT:    vldrb.s16 q1, [r0, #8]
-; CHECK-NEXT:    vmlalva.s16 r12, r5, q1, q0
-; CHECK-NEXT:    adds.w r0, r12, r2
-; CHECK-NEXT:    adc.w r1, r5, r3
-; CHECK-NEXT:    pop {r5, pc}
+; CHECK-NEXT:    vmlalva.s16 r2, r3, q1, q0
+; CHECK-NEXT:    mov r0, r2
+; CHECK-NEXT:    mov r1, r3
+; CHECK-NEXT:    bx lr
 entry:
   %x = load <16 x i8>, <16 x i8>* %xp
   %y = load <16 x i8>, <16 x i8>* %yp