[ARM] Extra MVE VMLAV reduction patterns
authorDavid Green <david.green@arm.com>
Fri, 29 May 2020 13:45:08 +0000 (14:45 +0100)
committerDavid Green <david.green@arm.com>
Fri, 29 May 2020 15:23:24 +0000 (16:23 +0100)
These patterns for i8 and i16 VMLA's were missing. They end up from
legalized vector.reduce.add.v8i16 and vector.reduce.add.v16i8, and
although the instruction works differently (the mul and add are
performed in a higher precision), I believe it is OK because only an
i8/i16 are demanded from them, and so the results will be the same. At
least, they pass any testing I can think to run on them.

There are some tests that end up looking worse, but are quite artificial
due to passing half vector types through a call boundary. I would not
expect the vmull to realistically come up like that, and a vmlava is
likely better a lot of the time.

Differential Revision: https://reviews.llvm.org/D80524

llvm/lib/Target/ARM/ARMInstrMVE.td
llvm/test/CodeGen/Thumb2/mve-vecreduce-mla.ll

index a5ea45b..4f72730 100644 (file)
@@ -1019,22 +1019,32 @@ def ARMVMLALVAu      : SDNode<"ARMISD::VMLALVAu", SDTVecReduce2LA>;
 let Predicates = [HasMVEInt] in {
   def : Pat<(i32 (vecreduce_add (mul (v4i32 MQPR:$src1), (v4i32 MQPR:$src2)))),
             (i32 (MVE_VMLADAVu32 $src1, $src2))>;
-  def : Pat<(i32 (ARMVMLAVs (v16i8 MQPR:$val1), (v16i8 MQPR:$val2))),
-            (i32 (MVE_VMLADAVs8 (v16i8 MQPR:$val1), (v16i8 MQPR:$val2)))>;
-  def : Pat<(i32 (ARMVMLAVu (v16i8 MQPR:$val1), (v16i8 MQPR:$val2))),
-            (i32 (MVE_VMLADAVu8 (v16i8 MQPR:$val1), (v16i8 MQPR:$val2)))>;
+  def : Pat<(i32 (vecreduce_add (mul (v8i16 MQPR:$src1), (v8i16 MQPR:$src2)))),
+            (i32 (MVE_VMLADAVu16 $src1, $src2))>;
   def : Pat<(i32 (ARMVMLAVs (v8i16 MQPR:$val1), (v8i16 MQPR:$val2))),
             (i32 (MVE_VMLADAVs16 (v8i16 MQPR:$val1), (v8i16 MQPR:$val2)))>;
   def : Pat<(i32 (ARMVMLAVu (v8i16 MQPR:$val1), (v8i16 MQPR:$val2))),
             (i32 (MVE_VMLADAVu16 (v8i16 MQPR:$val1), (v8i16 MQPR:$val2)))>;
+  def : Pat<(i32 (vecreduce_add (mul (v16i8 MQPR:$src1), (v16i8 MQPR:$src2)))),
+            (i32 (MVE_VMLADAVu8 $src1, $src2))>;
+  def : Pat<(i32 (ARMVMLAVs (v16i8 MQPR:$val1), (v16i8 MQPR:$val2))),
+            (i32 (MVE_VMLADAVs8 (v16i8 MQPR:$val1), (v16i8 MQPR:$val2)))>;
+  def : Pat<(i32 (ARMVMLAVu (v16i8 MQPR:$val1), (v16i8 MQPR:$val2))),
+            (i32 (MVE_VMLADAVu8 (v16i8 MQPR:$val1), (v16i8 MQPR:$val2)))>;
 
   def : Pat<(i32 (add (i32 (vecreduce_add (mul (v4i32 MQPR:$src1), (v4i32 MQPR:$src2)))),
                                           (i32 tGPREven:$src3))),
             (i32 (MVE_VMLADAVau32 $src3, $src1, $src2))>;
+  def : Pat<(i32 (add (i32 (vecreduce_add (mul (v8i16 MQPR:$src1), (v8i16 MQPR:$src2)))),
+                                          (i32 tGPREven:$src3))),
+            (i32 (MVE_VMLADAVau16 $src3, $src1, $src2))>;
   def : Pat<(i32 (add (ARMVMLAVs (v8i16 MQPR:$val1), (v8i16 MQPR:$val2)), tGPREven:$Rd)),
             (i32 (MVE_VMLADAVas16 tGPREven:$Rd, (v8i16 MQPR:$val1), (v8i16 MQPR:$val2)))>;
   def : Pat<(i32 (add (ARMVMLAVu (v8i16 MQPR:$val1), (v8i16 MQPR:$val2)), tGPREven:$Rd)),
             (i32 (MVE_VMLADAVau16 tGPREven:$Rd, (v8i16 MQPR:$val1), (v8i16 MQPR:$val2)))>;
+  def : Pat<(i32 (add (i32 (vecreduce_add (mul (v16i8 MQPR:$src1), (v16i8 MQPR:$src2)))),
+                                          (i32 tGPREven:$src3))),
+            (i32 (MVE_VMLADAVau8 $src3, $src1, $src2))>;
   def : Pat<(i32 (add (ARMVMLAVs (v16i8 MQPR:$val1), (v16i8 MQPR:$val2)), tGPREven:$Rd)),
             (i32 (MVE_VMLADAVas8 tGPREven:$Rd, (v16i8 MQPR:$val1), (v16i8 MQPR:$val2)))>;
   def : Pat<(i32 (add (ARMVMLAVu (v16i8 MQPR:$val1), (v16i8 MQPR:$val2)), tGPREven:$Rd)),
index 0716f58..67a0075 100644 (file)
@@ -135,8 +135,7 @@ entry:
 define arm_aapcs_vfpcc zeroext i16 @add_v8i16_v8i16(<8 x i16> %x, <8 x i16> %y) {
 ; CHECK-LABEL: add_v8i16_v8i16:
 ; CHECK:       @ %bb.0: @ %entry
-; CHECK-NEXT:    vmul.i16 q0, q0, q1
-; CHECK-NEXT:    vaddv.u16 r0, q0
+; CHECK-NEXT:    vmlav.u16 r0, q0, q1
 ; CHECK-NEXT:    uxth r0, r0
 ; CHECK-NEXT:    bx lr
 entry:
@@ -438,8 +437,9 @@ entry:
 define arm_aapcs_vfpcc zeroext i16 @add_v8i8_v8i16_zext(<8 x i8> %x, <8 x i8> %y) {
 ; CHECK-LABEL: add_v8i8_v8i16_zext:
 ; CHECK:       @ %bb.0: @ %entry
-; CHECK-NEXT:    vmullb.u8 q0, q0, q1
-; CHECK-NEXT:    vaddv.u16 r0, q0
+; CHECK-NEXT:    vmovlb.u8 q1, q1
+; CHECK-NEXT:    vmovlb.u8 q0, q0
+; CHECK-NEXT:    vmlav.u16 r0, q0, q1
 ; CHECK-NEXT:    uxth r0, r0
 ; CHECK-NEXT:    bx lr
 entry:
@@ -453,8 +453,9 @@ entry:
 define arm_aapcs_vfpcc signext i16 @add_v8i8_v8i16_sext(<8 x i8> %x, <8 x i8> %y) {
 ; CHECK-LABEL: add_v8i8_v8i16_sext:
 ; CHECK:       @ %bb.0: @ %entry
-; CHECK-NEXT:    vmullb.s8 q0, q0, q1
-; CHECK-NEXT:    vaddv.u16 r0, q0
+; CHECK-NEXT:    vmovlb.s8 q1, q1
+; CHECK-NEXT:    vmovlb.s8 q0, q0
+; CHECK-NEXT:    vmlav.u16 r0, q0, q1
 ; CHECK-NEXT:    sxth r0, r0
 ; CHECK-NEXT:    bx lr
 entry:
@@ -468,8 +469,7 @@ entry:
 define arm_aapcs_vfpcc zeroext i8 @add_v16i8_v16i8(<16 x i8> %x, <16 x i8> %y) {
 ; CHECK-LABEL: add_v16i8_v16i8:
 ; CHECK:       @ %bb.0: @ %entry
-; CHECK-NEXT:    vmul.i8 q0, q0, q1
-; CHECK-NEXT:    vaddv.u8 r0, q0
+; CHECK-NEXT:    vmlav.u8 r0, q0, q1
 ; CHECK-NEXT:    uxtb r0, r0
 ; CHECK-NEXT:    bx lr
 entry:
@@ -1086,8 +1086,7 @@ entry:
 define arm_aapcs_vfpcc zeroext i16 @add_v8i16_v8i16_acc(<8 x i16> %x, <8 x i16> %y, i16 %a) {
 ; CHECK-LABEL: add_v8i16_v8i16_acc:
 ; CHECK:       @ %bb.0: @ %entry
-; CHECK-NEXT:    vmul.i16 q0, q0, q1
-; CHECK-NEXT:    vaddva.u16 r0, q0
+; CHECK-NEXT:    vmlava.u16 r0, q0, q1
 ; CHECK-NEXT:    uxth r0, r0
 ; CHECK-NEXT:    bx lr
 entry:
@@ -1408,8 +1407,9 @@ entry:
 define arm_aapcs_vfpcc zeroext i16 @add_v8i8_v8i16_acc_zext(<8 x i8> %x, <8 x i8> %y, i16 %a) {
 ; CHECK-LABEL: add_v8i8_v8i16_acc_zext:
 ; CHECK:       @ %bb.0: @ %entry
-; CHECK-NEXT:    vmullb.u8 q0, q0, q1
-; CHECK-NEXT:    vaddva.u16 r0, q0
+; CHECK-NEXT:    vmovlb.u8 q1, q1
+; CHECK-NEXT:    vmovlb.u8 q0, q0
+; CHECK-NEXT:    vmlava.u16 r0, q0, q1
 ; CHECK-NEXT:    uxth r0, r0
 ; CHECK-NEXT:    bx lr
 entry:
@@ -1424,8 +1424,9 @@ entry:
 define arm_aapcs_vfpcc signext i16 @add_v8i8_v8i16_acc_sext(<8 x i8> %x, <8 x i8> %y, i16 %a) {
 ; CHECK-LABEL: add_v8i8_v8i16_acc_sext:
 ; CHECK:       @ %bb.0: @ %entry
-; CHECK-NEXT:    vmullb.s8 q0, q0, q1
-; CHECK-NEXT:    vaddva.u16 r0, q0
+; CHECK-NEXT:    vmovlb.s8 q1, q1
+; CHECK-NEXT:    vmovlb.s8 q0, q0
+; CHECK-NEXT:    vmlava.u16 r0, q0, q1
 ; CHECK-NEXT:    sxth r0, r0
 ; CHECK-NEXT:    bx lr
 entry:
@@ -1440,8 +1441,7 @@ entry:
 define arm_aapcs_vfpcc zeroext i8 @add_v16i8_v16i8_acc(<16 x i8> %x, <16 x i8> %y, i8 %a) {
 ; CHECK-LABEL: add_v16i8_v16i8_acc:
 ; CHECK:       @ %bb.0: @ %entry
-; CHECK-NEXT:    vmul.i8 q0, q0, q1
-; CHECK-NEXT:    vaddva.u8 r0, q0
+; CHECK-NEXT:    vmlava.u8 r0, q0, q1
 ; CHECK-NEXT:    uxtb r0, r0
 ; CHECK-NEXT:    bx lr
 entry: