Add 2D-specific Neon horizontal convolution functions
authorJonathan Wright <jonathan.wright@arm.com>
Thu, 4 May 2023 15:33:38 +0000 (16:33 +0100)
committerJonathan Wright <jonathan.wright@arm.com>
Sat, 13 May 2023 19:43:20 +0000 (20:43 +0100)
2D 8-tap convolution filtering is performed in two passes -
horizontal and vertical. The horizontal pass must produce enough
input data for the subsequent vertical pass - 3 rows above and 4 rows
below, in addition to the actual block height.

At present, all Neon horizontal convolution algorithms process 4 rows
at a time, but this means we end up doing at least 1 row too much
work in the 2D first pass case where we need h + 7, not h + 8 rows of
output.

This patch adds additional dot-product (SDOT and USDOT) Neon paths
that process h + 7 rows of data exactly, saving the work of the
unnecessary extra row. It is impractical to take a similar approach
for the Armv8.0 MLA paths since we have to transpose the data block
both before and after calling the convolution helper functions.

vpx_convolve_neon performance impact: we observe a speedup of ~9% for
smaller (and wider) blocks, and a speedup of 0-3% for larger blocks.
This is to be expected since the proportion of redundant work
decreases as the block height increases.

Change-Id: Ie77ad1848707d2d48bb8851345a469aae9d097e1

vpx_dsp/arm/mem_neon.h
vpx_dsp/arm/vpx_convolve8_neon.c
vpx_dsp/arm/vpx_convolve8_neon.h
vpx_dsp/arm/vpx_convolve_neon.c

index 1a20da7..586bfb8 100644 (file)
@@ -263,6 +263,16 @@ static INLINE void store_u8(uint8_t *buf, ptrdiff_t stride, const uint8x8_t a) {
   vst1_lane_u32((uint32_t *)buf, a_u32, 1);
 }
 
+static INLINE void store_u8_8x3(uint8_t *s, const ptrdiff_t p,
+                                const uint8x8_t s0, const uint8x8_t s1,
+                                const uint8x8_t s2) {
+  vst1_u8(s, s0);
+  s += p;
+  vst1_u8(s, s1);
+  s += p;
+  vst1_u8(s, s2);
+}
+
 static INLINE void load_u8_8x4(const uint8_t *s, const ptrdiff_t p,
                                uint8x8_t *const s0, uint8x8_t *const s1,
                                uint8x8_t *const s2, uint8x8_t *const s3) {
@@ -287,6 +297,16 @@ static INLINE void store_u8_8x4(uint8_t *s, const ptrdiff_t p,
   vst1_u8(s, s3);
 }
 
+static INLINE void load_u8_16x3(const uint8_t *s, const ptrdiff_t p,
+                                uint8x16_t *const s0, uint8x16_t *const s1,
+                                uint8x16_t *const s2) {
+  *s0 = vld1q_u8(s);
+  s += p;
+  *s1 = vld1q_u8(s);
+  s += p;
+  *s2 = vld1q_u8(s);
+}
+
 static INLINE void load_u8_16x4(const uint8_t *s, const ptrdiff_t p,
                                 uint8x16_t *const s0, uint8x16_t *const s1,
                                 uint8x16_t *const s2, uint8x16_t *const s3) {
index f217a3f..505d067 100644 (file)
@@ -57,6 +57,111 @@ DECLARE_ALIGNED(16, static const uint8_t, dot_prod_merge_block_tbl[48]) = {
 
 #if defined(__ARM_FEATURE_MATMUL_INT8)
 
+void vpx_convolve8_2d_horiz_neon(const uint8_t *src, ptrdiff_t src_stride,
+                                 uint8_t *dst, ptrdiff_t dst_stride,
+                                 const InterpKernel *filter, int x0_q4,
+                                 int x_step_q4, int y0_q4, int y_step_q4, int w,
+                                 int h) {
+  const int8x8_t filters = vmovn_s16(vld1q_s16(filter[x0_q4]));
+  uint8x16_t s0, s1, s2, s3;
+
+  assert((intptr_t)dst % 4 == 0);
+  assert(dst_stride % 4 == 0);
+  assert(x_step_q4 == 16);
+  assert(h % 4 == 3);
+
+  (void)x_step_q4;
+  (void)y0_q4;
+  (void)y_step_q4;
+
+  src -= 3;
+
+  if (w == 4) {
+    const uint8x16x2_t perm_tbl = vld1q_u8_x2(dot_prod_permute_tbl);
+    int16x4_t d0, d1, d2, d3;
+    uint8x8_t d01, d23;
+
+    do {
+      load_u8_16x4(src, src_stride, &s0, &s1, &s2, &s3);
+
+      d0 = convolve8_4_usdot(s0, filters, perm_tbl);
+      d1 = convolve8_4_usdot(s1, filters, perm_tbl);
+      d2 = convolve8_4_usdot(s2, filters, perm_tbl);
+      d3 = convolve8_4_usdot(s3, filters, perm_tbl);
+      d01 = vqrshrun_n_s16(vcombine_s16(d0, d1), FILTER_BITS);
+      d23 = vqrshrun_n_s16(vcombine_s16(d2, d3), FILTER_BITS);
+
+      store_u8(dst + 0 * dst_stride, dst_stride, d01);
+      store_u8(dst + 2 * dst_stride, dst_stride, d23);
+
+      src += 4 * src_stride;
+      dst += 4 * dst_stride;
+      h -= 4;
+    } while (h > 3);
+
+    /* Process final three rows (h % 4 == 3). See vpx_convolve_neon.c for
+     * further details on possible values of block height. */
+    load_u8_16x3(src, src_stride, &s0, &s1, &s2);
+
+    d0 = convolve8_4_usdot(s0, filters, perm_tbl);
+    d1 = convolve8_4_usdot(s1, filters, perm_tbl);
+    d2 = convolve8_4_usdot(s2, filters, perm_tbl);
+    d01 = vqrshrun_n_s16(vcombine_s16(d0, d1), FILTER_BITS);
+    d23 = vqrshrun_n_s16(vcombine_s16(d2, vdup_n_s16(0)), FILTER_BITS);
+
+    store_u8(dst + 0 * dst_stride, dst_stride, d01);
+    store_u8_4x1(dst + 2 * dst_stride, d23);
+  } else {
+    const uint8x16x3_t perm_tbl = vld1q_u8_x3(dot_prod_permute_tbl);
+    const uint8_t *s;
+    uint8_t *d;
+    int width;
+    uint8x8_t d0, d1, d2, d3;
+
+    do {
+      width = w;
+      s = src;
+      d = dst;
+      do {
+        load_u8_16x4(s, src_stride, &s0, &s1, &s2, &s3);
+
+        d0 = convolve8_8_usdot(s0, filters, perm_tbl);
+        d1 = convolve8_8_usdot(s1, filters, perm_tbl);
+        d2 = convolve8_8_usdot(s2, filters, perm_tbl);
+        d3 = convolve8_8_usdot(s3, filters, perm_tbl);
+
+        store_u8_8x4(d, dst_stride, d0, d1, d2, d3);
+
+        s += 8;
+        d += 8;
+        width -= 8;
+      } while (width > 0);
+      src += 4 * src_stride;
+      dst += 4 * dst_stride;
+      h -= 4;
+    } while (h > 3);
+
+    /* Process final three rows (h % 4 == 3). See vpx_convolve_neon.c for
+     * further details on possible values of block height. */
+    width = w;
+    s = src;
+    d = dst;
+    do {
+      load_u8_16x3(s, src_stride, &s0, &s1, &s2);
+
+      d0 = convolve8_8_usdot(s0, filters, perm_tbl);
+      d1 = convolve8_8_usdot(s1, filters, perm_tbl);
+      d2 = convolve8_8_usdot(s2, filters, perm_tbl);
+
+      store_u8_8x3(d, dst_stride, d0, d1, d2);
+
+      s += 8;
+      d += 8;
+      width -= 8;
+    } while (width > 0);
+  }
+}
+
 void vpx_convolve8_horiz_neon(const uint8_t *src, ptrdiff_t src_stride,
                               uint8_t *dst, ptrdiff_t dst_stride,
                               const InterpKernel *filter, int x0_q4,
@@ -96,7 +201,7 @@ void vpx_convolve8_horiz_neon(const uint8_t *src, ptrdiff_t src_stride,
       src += 4 * src_stride;
       dst += 4 * dst_stride;
       h -= 4;
-    } while (h > 0);
+    } while (h != 0);
   } else {
     const uint8x16x3_t perm_tbl = vld1q_u8_x3(dot_prod_permute_tbl);
     const uint8_t *s;
@@ -125,7 +230,7 @@ void vpx_convolve8_horiz_neon(const uint8_t *src, ptrdiff_t src_stride,
       src += 4 * src_stride;
       dst += 4 * dst_stride;
       h -= 4;
-    } while (h > 0);
+    } while (h != 0);
   }
 }
 
@@ -611,6 +716,114 @@ void vpx_convolve8_avg_vert_neon(const uint8_t *src, ptrdiff_t src_stride,
 
 #else  // !defined(__ARM_FEATURE_MATMUL_INT8)
 
+void vpx_convolve8_2d_horiz_neon(const uint8_t *src, ptrdiff_t src_stride,
+                                 uint8_t *dst, ptrdiff_t dst_stride,
+                                 const InterpKernel *filter, int x0_q4,
+                                 int x_step_q4, int y0_q4, int y_step_q4, int w,
+                                 int h) {
+  const int8x8_t filters = vmovn_s16(vld1q_s16(filter[x0_q4]));
+  const int16x8_t correct_tmp = vmulq_n_s16(vld1q_s16(filter[x0_q4]), 128);
+  const int32x4_t correction = vdupq_n_s32((int32_t)vaddvq_s16(correct_tmp));
+  const uint8x16_t range_limit = vdupq_n_u8(128);
+  uint8x16_t s0, s1, s2, s3;
+
+  assert((intptr_t)dst % 4 == 0);
+  assert(dst_stride % 4 == 0);
+  assert(x_step_q4 == 16);
+  assert(h % 4 == 3);
+
+  (void)x_step_q4;
+  (void)y0_q4;
+  (void)y_step_q4;
+
+  src -= 3;
+
+  if (w == 4) {
+    const uint8x16x2_t perm_tbl = vld1q_u8_x2(dot_prod_permute_tbl);
+    int16x4_t d0, d1, d2, d3;
+    uint8x8_t d01, d23;
+
+    do {
+      load_u8_16x4(src, src_stride, &s0, &s1, &s2, &s3);
+
+      d0 = convolve8_4_sdot(s0, filters, correction, range_limit, perm_tbl);
+      d1 = convolve8_4_sdot(s1, filters, correction, range_limit, perm_tbl);
+      d2 = convolve8_4_sdot(s2, filters, correction, range_limit, perm_tbl);
+      d3 = convolve8_4_sdot(s3, filters, correction, range_limit, perm_tbl);
+      d01 = vqrshrun_n_s16(vcombine_s16(d0, d1), FILTER_BITS);
+      d23 = vqrshrun_n_s16(vcombine_s16(d2, d3), FILTER_BITS);
+
+      store_u8(dst + 0 * dst_stride, dst_stride, d01);
+      store_u8(dst + 2 * dst_stride, dst_stride, d23);
+
+      src += 4 * src_stride;
+      dst += 4 * dst_stride;
+      h -= 4;
+    } while (h > 3);
+
+    /* Process final three rows (h % 4 == 3). See vpx_convolve_neon.c for
+     * further details on possible values of block height. */
+    load_u8_16x3(src, src_stride, &s0, &s1, &s2);
+
+    d0 = convolve8_4_sdot(s0, filters, correction, range_limit, perm_tbl);
+    d1 = convolve8_4_sdot(s1, filters, correction, range_limit, perm_tbl);
+    d2 = convolve8_4_sdot(s2, filters, correction, range_limit, perm_tbl);
+    d01 = vqrshrun_n_s16(vcombine_s16(d0, d1), FILTER_BITS);
+    d23 = vqrshrun_n_s16(vcombine_s16(d2, vdup_n_s16(0)), FILTER_BITS);
+
+    store_u8(dst + 0 * dst_stride, dst_stride, d01);
+    store_u8_4x1(dst + 2 * dst_stride, d23);
+  } else {
+    const uint8x16x3_t perm_tbl = vld1q_u8_x3(dot_prod_permute_tbl);
+    const uint8_t *s;
+    uint8_t *d;
+    int width;
+    uint8x8_t d0, d1, d2, d3;
+
+    do {
+      width = w;
+      s = src;
+      d = dst;
+      do {
+        load_u8_16x4(s, src_stride, &s0, &s1, &s2, &s3);
+
+        d0 = convolve8_8_sdot(s0, filters, correction, range_limit, perm_tbl);
+        d1 = convolve8_8_sdot(s1, filters, correction, range_limit, perm_tbl);
+        d2 = convolve8_8_sdot(s2, filters, correction, range_limit, perm_tbl);
+        d3 = convolve8_8_sdot(s3, filters, correction, range_limit, perm_tbl);
+
+        store_u8_8x4(d, dst_stride, d0, d1, d2, d3);
+
+        s += 8;
+        d += 8;
+        width -= 8;
+      } while (width != 0);
+      src += 4 * src_stride;
+      dst += 4 * dst_stride;
+      h -= 4;
+    } while (h > 3);
+
+    /* Process final three rows (h % 4 == 3). See vpx_convolve_neon.c for
+     * further details on possible values of block height. */
+    width = w;
+    s = src;
+    d = dst;
+    do {
+      load_u8_16x3(s, src_stride, &s0, &s1, &s2);
+
+      d0 = convolve8_8_sdot(s0, filters, correction, range_limit, perm_tbl);
+      d1 = convolve8_8_sdot(s1, filters, correction, range_limit, perm_tbl);
+      d2 = convolve8_8_sdot(s2, filters, correction, range_limit, perm_tbl);
+
+      store_u8_8x3(d, dst_stride, d0, d1, d2);
+
+      s += 8;
+      d += 8;
+      width -= 8;
+    } while (width != 0);
+  }
+}
+
 void vpx_convolve8_horiz_neon(const uint8_t *src, ptrdiff_t src_stride,
                               uint8_t *dst, ptrdiff_t dst_stride,
                               const InterpKernel *filter, int x0_q4,
@@ -653,7 +866,7 @@ void vpx_convolve8_horiz_neon(const uint8_t *src, ptrdiff_t src_stride,
       src += 4 * src_stride;
       dst += 4 * dst_stride;
       h -= 4;
-    } while (h > 0);
+    } while (h != 0);
   } else {
     const uint8x16x3_t perm_tbl = vld1q_u8_x3(dot_prod_permute_tbl);
     const uint8_t *s;
@@ -682,7 +895,7 @@ void vpx_convolve8_horiz_neon(const uint8_t *src, ptrdiff_t src_stride,
       src += 4 * src_stride;
       dst += 4 * dst_stride;
       h -= 4;
-    } while (h > 0);
+    } while (h != 0);
   }
 }
 
index c838d40..2f78583 100644 (file)
 #include "./vpx_dsp_rtcd.h"
 #include "vpx_dsp/vpx_filter.h"
 
+#if VPX_ARCH_AARCH64 && \
+    (defined(__ARM_FEATURE_DOTPROD) || defined(__ARM_FEATURE_MATMUL_INT8))
+void vpx_convolve8_2d_horiz_neon(const uint8_t *src, ptrdiff_t src_stride,
+                                 uint8_t *dst, ptrdiff_t dst_stride,
+                                 const InterpKernel *filter, int x0_q4,
+                                 int x_step_q4, int y0_q4, int y_step_q4, int w,
+                                 int h);
+#endif
+
 #if VPX_ARCH_AARCH64 && defined(__ARM_FEATURE_DOTPROD)
 
 static INLINE int16x4_t convolve8_4_sdot_partial(const int8x16_t samples_lo,
index 830f317..f7db3e6 100644 (file)
 #include "vpx_dsp/vpx_dsp_common.h"
 #include "vpx_ports/mem.h"
 
+#if VPX_ARCH_AARCH64 && \
+    (defined(__ARM_FEATURE_DOTPROD) || defined(__ARM_FEATURE_MATMUL_INT8))
+#include "vpx_dsp/arm/vpx_convolve8_neon.h"
+
+void vpx_convolve8_neon(const uint8_t *src, ptrdiff_t src_stride, uint8_t *dst,
+                        ptrdiff_t dst_stride, const InterpKernel *filter,
+                        int x0_q4, int x_step_q4, int y0_q4, int y_step_q4,
+                        int w, int h) {
+  /* Given our constraints: w <= 64, h <= 64, taps == 8 we can reduce the
+   * maximum buffer size to 64 * (64 + 7). */
+  uint8_t temp[64 * 71];
+
+  /* Account for the vertical phase needing 3 lines prior and 4 lines post. */
+  const int intermediate_height = h + 7;
+
+  assert(y_step_q4 == 16);
+  assert(x_step_q4 == 16);
+
+  /* Filter starting 3 lines back. */
+  vpx_convolve8_2d_horiz_neon(src - src_stride * 3, src_stride, temp, w, filter,
+                              x0_q4, x_step_q4, y0_q4, y_step_q4, w,
+                              intermediate_height);
+
+  /* Step into the temp buffer 3 lines to get the actual frame data */
+  vpx_convolve8_vert_neon(temp + w * 3, w, dst, dst_stride, filter, x0_q4,
+                          x_step_q4, y0_q4, y_step_q4, w, h);
+}
+
+void vpx_convolve8_avg_neon(const uint8_t *src, ptrdiff_t src_stride,
+                            uint8_t *dst, ptrdiff_t dst_stride,
+                            const InterpKernel *filter, int x0_q4,
+                            int x_step_q4, int y0_q4, int y_step_q4, int w,
+                            int h) {
+  uint8_t temp[64 * 71];
+  const int intermediate_height = h + 7;
+
+  assert(y_step_q4 == 16);
+  assert(x_step_q4 == 16);
+
+  vpx_convolve8_2d_horiz_neon(src - src_stride * 3, src_stride, temp, w, filter,
+                              x0_q4, x_step_q4, y0_q4, y_step_q4, w,
+                              intermediate_height);
+
+  vpx_convolve8_avg_vert_neon(temp + w * 3, w, dst, dst_stride, filter, x0_q4,
+                              x_step_q4, y0_q4, y_step_q4, w, h);
+}
+
+#else  // !(VPX_ARCH_AARCH64 &&
+       //   (defined(__ARM_FEATURE_DOTPROD) ||
+       //    defined(__ARM_FEATURE_MATMUL_INT8)))
+
 void vpx_convolve8_neon(const uint8_t *src, ptrdiff_t src_stride, uint8_t *dst,
                         ptrdiff_t dst_stride, const InterpKernel *filter,
                         int x0_q4, int x_step_q4, int y0_q4, int y_step_q4,
@@ -63,3 +114,7 @@ void vpx_convolve8_avg_neon(const uint8_t *src, ptrdiff_t src_stride,
   vpx_convolve8_avg_vert_neon(temp + w * 3, w, dst, dst_stride, filter, x0_q4,
                               x_step_q4, y0_q4, y_step_q4, w, h);
 }
+
+#endif  // #if VPX_ARCH_AARCH64 &&
+        //     (defined(__ARM_FEATURE_DOTPROD) ||
+        //      defined(__ARM_FEATURE_MATMUL_INT8))