Optimize Armv8.0 Neon SAD4D 16xh, 32xh, and 64xh functions
authorJonathan Wright <jonathan.wright@arm.com>
Thu, 6 Apr 2023 15:14:51 +0000 (16:14 +0100)
committerJonathan Wright <jonathan.wright@arm.com>
Thu, 6 Apr 2023 16:41:01 +0000 (17:41 +0100)
Add a widening 4D reduction function operating on uint16x8_t vectors
and use it to optimize the final reduction in Armv8.0 Neon standard
bitdepth 16xh, 32xh and 64h SAD4D computations.

Also simplify the Armv8.0 Neon version of the sad64xhx4d_neon helper
function since VP9 block sizes are not large enough to require
widening to 32-bit accumulators before the final reduction.

Change-Id: I32b0a283d7688d8cdf21791add9476ed24c66a28

vpx_dsp/arm/sad4d_neon.c
vpx_dsp/arm/sum_neon.h

index ab00e0e..6ad6c96 100644 (file)
@@ -140,53 +140,43 @@ static INLINE void sad16_neon(uint8x16_t src, uint8x16_t ref,
 static INLINE void sad64xhx4d_neon(const uint8_t *src, int src_stride,
                                    const uint8_t *const ref[4], int ref_stride,
                                    uint32_t res[4], int h) {
-  int h_tmp = h > 64 ? 64 : h;
-  int i = 0;
-  vst1q_u32(res, vdupq_n_u32(0));
+  uint16x8_t sum_lo[4] = { vdupq_n_u16(0), vdupq_n_u16(0), vdupq_n_u16(0),
+                           vdupq_n_u16(0) };
+  uint16x8_t sum_hi[4] = { vdupq_n_u16(0), vdupq_n_u16(0), vdupq_n_u16(0),
+                           vdupq_n_u16(0) };
 
+  int i = 0;
   do {
-    uint16x8_t sum_lo[4] = { vdupq_n_u16(0), vdupq_n_u16(0), vdupq_n_u16(0),
-                             vdupq_n_u16(0) };
-    uint16x8_t sum_hi[4] = { vdupq_n_u16(0), vdupq_n_u16(0), vdupq_n_u16(0),
-                             vdupq_n_u16(0) };
-
-    do {
-      uint8x16_t s0, s1, s2, s3;
-
-      s0 = vld1q_u8(src + i * src_stride);
-      sad16_neon(s0, vld1q_u8(ref[0] + i * ref_stride), &sum_lo[0]);
-      sad16_neon(s0, vld1q_u8(ref[1] + i * ref_stride), &sum_lo[1]);
-      sad16_neon(s0, vld1q_u8(ref[2] + i * ref_stride), &sum_lo[2]);
-      sad16_neon(s0, vld1q_u8(ref[3] + i * ref_stride), &sum_lo[3]);
-
-      s1 = vld1q_u8(src + i * src_stride + 16);
-      sad16_neon(s1, vld1q_u8(ref[0] + i * ref_stride + 16), &sum_hi[0]);
-      sad16_neon(s1, vld1q_u8(ref[1] + i * ref_stride + 16), &sum_hi[1]);
-      sad16_neon(s1, vld1q_u8(ref[2] + i * ref_stride + 16), &sum_hi[2]);
-      sad16_neon(s1, vld1q_u8(ref[3] + i * ref_stride + 16), &sum_hi[3]);
-
-      s2 = vld1q_u8(src + i * src_stride + 32);
-      sad16_neon(s2, vld1q_u8(ref[0] + i * ref_stride + 32), &sum_lo[0]);
-      sad16_neon(s2, vld1q_u8(ref[1] + i * ref_stride + 32), &sum_lo[1]);
-      sad16_neon(s2, vld1q_u8(ref[2] + i * ref_stride + 32), &sum_lo[2]);
-      sad16_neon(s2, vld1q_u8(ref[3] + i * ref_stride + 32), &sum_lo[3]);
-
-      s3 = vld1q_u8(src + i * src_stride + 48);
-      sad16_neon(s3, vld1q_u8(ref[0] + i * ref_stride + 48), &sum_hi[0]);
-      sad16_neon(s3, vld1q_u8(ref[1] + i * ref_stride + 48), &sum_hi[1]);
-      sad16_neon(s3, vld1q_u8(ref[2] + i * ref_stride + 48), &sum_hi[2]);
-      sad16_neon(s3, vld1q_u8(ref[3] + i * ref_stride + 48), &sum_hi[3]);
-
-      i++;
-    } while (i < h_tmp);
-
-    res[0] += horizontal_long_add_uint16x8(sum_lo[0], sum_hi[0]);
-    res[1] += horizontal_long_add_uint16x8(sum_lo[1], sum_hi[1]);
-    res[2] += horizontal_long_add_uint16x8(sum_lo[2], sum_hi[2]);
-    res[3] += horizontal_long_add_uint16x8(sum_lo[3], sum_hi[3]);
-
-    h_tmp += 64;
+    uint8x16_t s0, s1, s2, s3;
+
+    s0 = vld1q_u8(src + i * src_stride);
+    sad16_neon(s0, vld1q_u8(ref[0] + i * ref_stride), &sum_lo[0]);
+    sad16_neon(s0, vld1q_u8(ref[1] + i * ref_stride), &sum_lo[1]);
+    sad16_neon(s0, vld1q_u8(ref[2] + i * ref_stride), &sum_lo[2]);
+    sad16_neon(s0, vld1q_u8(ref[3] + i * ref_stride), &sum_lo[3]);
+
+    s1 = vld1q_u8(src + i * src_stride + 16);
+    sad16_neon(s1, vld1q_u8(ref[0] + i * ref_stride + 16), &sum_hi[0]);
+    sad16_neon(s1, vld1q_u8(ref[1] + i * ref_stride + 16), &sum_hi[1]);
+    sad16_neon(s1, vld1q_u8(ref[2] + i * ref_stride + 16), &sum_hi[2]);
+    sad16_neon(s1, vld1q_u8(ref[3] + i * ref_stride + 16), &sum_hi[3]);
+
+    s2 = vld1q_u8(src + i * src_stride + 32);
+    sad16_neon(s2, vld1q_u8(ref[0] + i * ref_stride + 32), &sum_lo[0]);
+    sad16_neon(s2, vld1q_u8(ref[1] + i * ref_stride + 32), &sum_lo[1]);
+    sad16_neon(s2, vld1q_u8(ref[2] + i * ref_stride + 32), &sum_lo[2]);
+    sad16_neon(s2, vld1q_u8(ref[3] + i * ref_stride + 32), &sum_lo[3]);
+
+    s3 = vld1q_u8(src + i * src_stride + 48);
+    sad16_neon(s3, vld1q_u8(ref[0] + i * ref_stride + 48), &sum_hi[0]);
+    sad16_neon(s3, vld1q_u8(ref[1] + i * ref_stride + 48), &sum_hi[1]);
+    sad16_neon(s3, vld1q_u8(ref[2] + i * ref_stride + 48), &sum_hi[2]);
+    sad16_neon(s3, vld1q_u8(ref[3] + i * ref_stride + 48), &sum_hi[3]);
+
+    i++;
   } while (i < h);
+
+  vst1q_u32(res, horizontal_long_add_4d_uint16x8(sum_lo, sum_hi));
 }
 
 static INLINE void sad32xhx4d_neon(const uint8_t *src, int src_stride,
@@ -216,10 +206,7 @@ static INLINE void sad32xhx4d_neon(const uint8_t *src, int src_stride,
     i++;
   } while (i < h);
 
-  res[0] = horizontal_long_add_uint16x8(sum_lo[0], sum_hi[0]);
-  res[1] = horizontal_long_add_uint16x8(sum_lo[1], sum_hi[1]);
-  res[2] = horizontal_long_add_uint16x8(sum_lo[2], sum_hi[2]);
-  res[3] = horizontal_long_add_uint16x8(sum_lo[3], sum_hi[3]);
+  vst1q_u32(res, horizontal_long_add_4d_uint16x8(sum_lo, sum_hi));
 }
 
 static INLINE void sad16xhx4d_neon(const uint8_t *src, int src_stride,
@@ -239,10 +226,7 @@ static INLINE void sad16xhx4d_neon(const uint8_t *src, int src_stride,
     i++;
   } while (i < h);
 
-  res[0] = horizontal_add_uint16x8(sum[0]);
-  res[1] = horizontal_add_uint16x8(sum[1]);
-  res[2] = horizontal_add_uint16x8(sum[2]);
-  res[3] = horizontal_add_uint16x8(sum[3]);
+  vst1q_u32(res, horizontal_add_4d_uint16x8(sum));
 }
 
 #endif  // defined(__ARM_FEATURE_DOTPROD)
index 6259add..a0c72f9 100644 (file)
@@ -117,6 +117,31 @@ static INLINE uint32_t horizontal_long_add_uint16x8(const uint16x8_t vec_lo,
 #endif
 }
 
+static INLINE uint32x4_t horizontal_long_add_4d_uint16x8(
+    const uint16x8_t sum_lo[4], const uint16x8_t sum_hi[4]) {
+  const uint32x4_t a0 = vpaddlq_u16(sum_lo[0]);
+  const uint32x4_t a1 = vpaddlq_u16(sum_lo[1]);
+  const uint32x4_t a2 = vpaddlq_u16(sum_lo[2]);
+  const uint32x4_t a3 = vpaddlq_u16(sum_lo[3]);
+  const uint32x4_t b0 = vpadalq_u16(a0, sum_hi[0]);
+  const uint32x4_t b1 = vpadalq_u16(a1, sum_hi[1]);
+  const uint32x4_t b2 = vpadalq_u16(a2, sum_hi[2]);
+  const uint32x4_t b3 = vpadalq_u16(a3, sum_hi[3]);
+#if defined(__aarch64__)
+  const uint32x4_t c0 = vpaddq_u32(b0, b1);
+  const uint32x4_t c1 = vpaddq_u32(b2, b3);
+  return vpaddq_u32(c0, c1);
+#else
+  const uint32x2_t c0 = vadd_u32(vget_low_u32(b0), vget_high_u32(b0));
+  const uint32x2_t c1 = vadd_u32(vget_low_u32(b1), vget_high_u32(b1));
+  const uint32x2_t c2 = vadd_u32(vget_low_u32(b2), vget_high_u32(b2));
+  const uint32x2_t c3 = vadd_u32(vget_low_u32(b3), vget_high_u32(b3));
+  const uint32x2_t d0 = vpadd_u32(c0, c1);
+  const uint32x2_t d1 = vpadd_u32(c2, c3);
+  return vcombine_u32(d0, d1);
+#endif
+}
+
 static INLINE int32_t horizontal_add_int32x2(const int32x2_t a) {
 #if defined(__aarch64__)
   return vaddv_s32(a);