Refine 8-bit intra prediction NEON optimization (mode dc)
authorLinfeng Zhang <linfengz@google.com>
Wed, 19 Oct 2016 20:37:26 +0000 (13:37 -0700)
committerLinfeng Zhang <linfengz@google.com>
Mon, 24 Oct 2016 20:18:51 +0000 (13:18 -0700)
dst += stride behaving better with gcc/clang
Expanding inline function dc_SIZExSIZE() save intructions for
vpx_dc_predictor_SIZExSIZE_neon().

Change-Id: Id0ccbd58b6a31df539141fd33bdf28633339150d

vpx_dsp/arm/intrapred_neon.c

index 38e79ed..da2b2c9 100644 (file)
 //------------------------------------------------------------------------------
 // DC 4x4
 
-// 'do_above' and 'do_left' facilitate branch removal when inlined.
-static INLINE void dc_4x4(uint8_t *dst, ptrdiff_t stride, const uint8_t *above,
-                          const uint8_t *left, int do_above, int do_left) {
-  uint16x4_t sum_top;
-  uint16x4_t sum_left;
-  uint16x4_t dc0;
-
-  if (do_above) {
-    const uint8x8_t A = vld1_u8(above);  // top row
-    const uint16x4_t p0 = vpaddl_u8(A);  // cascading summation of the top
-    sum_top = vpadd_u16(p0, p0);
-  }
-
-  if (do_left) {
-    const uint8x8_t L = vld1_u8(left);   // left border
-    const uint16x4_t p0 = vpaddl_u8(L);  // cascading summation of the left
-    sum_left = vpadd_u16(p0, p0);
-  }
-
-  if (do_above && do_left) {
-    const uint16x4_t sum = vadd_u16(sum_left, sum_top);
-    dc0 = vrshr_n_u16(sum, 3);
-  } else if (do_above) {
-    dc0 = vrshr_n_u16(sum_top, 2);
-  } else if (do_left) {
-    dc0 = vrshr_n_u16(sum_left, 2);
-  } else {
-    dc0 = vdup_n_u16(0x80);
-  }
+static INLINE uint16x4_t dc_sum_4(const uint8_t *ref) {
+  const uint8x8_t ref_u8 = vld1_u8(ref);
+  const uint16x4_t p0 = vpaddl_u8(ref_u8);
+  return vpadd_u16(p0, p0);
+}
 
-  {
-    const uint8x8_t dc = vdup_lane_u8(vreinterpret_u8_u16(dc0), 0);
-    int i;
-    for (i = 0; i < 4; ++i) {
-      vst1_lane_u32((uint32_t *)(dst + i * stride), vreinterpret_u32_u8(dc), 0);
-    }
+static INLINE void dc_store_4x4(uint8_t *dst, ptrdiff_t stride,
+                                const uint8x8_t dc) {
+  const uint8x8_t dc_dup = vdup_lane_u8(dc, 0);
+  int i;
+  for (i = 0; i < 4; ++i, dst += stride) {
+    vst1_lane_u32((uint32_t *)dst, vreinterpret_u32_u8(dc_dup), 0);
   }
 }
 
 void vpx_dc_predictor_4x4_neon(uint8_t *dst, ptrdiff_t stride,
                                const uint8_t *above, const uint8_t *left) {
-  dc_4x4(dst, stride, above, left, 1, 1);
+  const uint8x8_t a = vld1_u8(above);
+  const uint8x8_t l = vld1_u8(left);
+  const uint16x8_t al = vaddl_u8(a, l);
+  uint16x4_t sum;
+  uint8x8_t dc;
+  sum = vpadd_u16(vget_low_u16(al), vget_low_u16(al));
+  sum = vpadd_u16(sum, sum);
+  dc = vreinterpret_u8_u16(vrshr_n_u16(sum, 3));
+  dc_store_4x4(dst, stride, dc);
 }
 
 void vpx_dc_left_predictor_4x4_neon(uint8_t *dst, ptrdiff_t stride,
                                     const uint8_t *above, const uint8_t *left) {
+  const uint16x4_t sum = dc_sum_4(left);
+  const uint8x8_t dc = vreinterpret_u8_u16(vrshr_n_u16(sum, 2));
   (void)above;
-  dc_4x4(dst, stride, NULL, left, 0, 1);
+  dc_store_4x4(dst, stride, dc);
 }
 
 void vpx_dc_top_predictor_4x4_neon(uint8_t *dst, ptrdiff_t stride,
                                    const uint8_t *above, const uint8_t *left) {
+  const uint16x4_t sum = dc_sum_4(above);
+  const uint8x8_t dc = vreinterpret_u8_u16(vrshr_n_u16(sum, 2));
   (void)left;
-  dc_4x4(dst, stride, above, NULL, 1, 0);
+  dc_store_4x4(dst, stride, dc);
 }
 
 void vpx_dc_128_predictor_4x4_neon(uint8_t *dst, ptrdiff_t stride,
                                    const uint8_t *above, const uint8_t *left) {
+  const uint8x8_t dc = vdup_n_u8(0x80);
   (void)above;
   (void)left;
-  dc_4x4(dst, stride, NULL, NULL, 0, 0);
+  dc_store_4x4(dst, stride, dc);
 }
 
 //------------------------------------------------------------------------------
 // DC 8x8
 
-// 'do_above' and 'do_left' facilitate branch removal when inlined.
-static INLINE void dc_8x8(uint8_t *dst, ptrdiff_t stride, const uint8_t *above,
-                          const uint8_t *left, int do_above, int do_left) {
-  uint16x8_t sum_top;
-  uint16x8_t sum_left;
-  uint8x8_t dc0;
-
-  if (do_above) {
-    const uint8x8_t A = vld1_u8(above);  // top row
-    const uint16x4_t p0 = vpaddl_u8(A);  // cascading summation of the top
-    const uint16x4_t p1 = vpadd_u16(p0, p0);
-    const uint16x4_t p2 = vpadd_u16(p1, p1);
-    sum_top = vcombine_u16(p2, p2);
-  }
-
-  if (do_left) {
-    const uint8x8_t L = vld1_u8(left);   // left border
-    const uint16x4_t p0 = vpaddl_u8(L);  // cascading summation of the left
-    const uint16x4_t p1 = vpadd_u16(p0, p0);
-    const uint16x4_t p2 = vpadd_u16(p1, p1);
-    sum_left = vcombine_u16(p2, p2);
-  }
-
-  if (do_above && do_left) {
-    const uint16x8_t sum = vaddq_u16(sum_left, sum_top);
-    dc0 = vrshrn_n_u16(sum, 4);
-  } else if (do_above) {
-    dc0 = vrshrn_n_u16(sum_top, 3);
-  } else if (do_left) {
-    dc0 = vrshrn_n_u16(sum_left, 3);
-  } else {
-    dc0 = vdup_n_u8(0x80);
-  }
+static INLINE uint16x4_t dc_sum_8(const uint8_t *ref) {
+  const uint8x8_t ref_u8 = vld1_u8(ref);
+  uint16x4_t sum = vpaddl_u8(ref_u8);
+  sum = vpadd_u16(sum, sum);
+  return vpadd_u16(sum, sum);
+}
 
-  {
-    const uint8x8_t dc = vdup_lane_u8(dc0, 0);
-    int i;
-    for (i = 0; i < 8; ++i) {
-      vst1_u32((uint32_t *)(dst + i * stride), vreinterpret_u32_u8(dc));
-    }
+static INLINE void dc_store_8x8(uint8_t *dst, ptrdiff_t stride,
+                                const uint8x8_t dc) {
+  const uint8x8_t dc_dup = vdup_lane_u8(dc, 0);
+  int i;
+  for (i = 0; i < 8; ++i, dst += stride) {
+    vst1_u8(dst, dc_dup);
   }
 }
 
 void vpx_dc_predictor_8x8_neon(uint8_t *dst, ptrdiff_t stride,
                                const uint8_t *above, const uint8_t *left) {
-  dc_8x8(dst, stride, above, left, 1, 1);
+  const uint8x8_t above_u8 = vld1_u8(above);
+  const uint8x8_t left_u8 = vld1_u8(left);
+  const uint8x16_t above_and_left = vcombine_u8(above_u8, left_u8);
+  const uint16x8_t p0 = vpaddlq_u8(above_and_left);
+  uint16x4_t sum = vadd_u16(vget_low_u16(p0), vget_high_u16(p0));
+  uint8x8_t dc;
+  sum = vpadd_u16(sum, sum);
+  sum = vpadd_u16(sum, sum);
+  dc = vreinterpret_u8_u16(vrshr_n_u16(sum, 4));
+  dc_store_8x8(dst, stride, dc);
 }
 
 void vpx_dc_left_predictor_8x8_neon(uint8_t *dst, ptrdiff_t stride,
                                     const uint8_t *above, const uint8_t *left) {
+  const uint16x4_t sum = dc_sum_8(left);
+  const uint8x8_t dc = vreinterpret_u8_u16(vrshr_n_u16(sum, 3));
   (void)above;
-  dc_8x8(dst, stride, NULL, left, 0, 1);
+  dc_store_8x8(dst, stride, dc);
 }
 
 void vpx_dc_top_predictor_8x8_neon(uint8_t *dst, ptrdiff_t stride,
                                    const uint8_t *above, const uint8_t *left) {
+  const uint16x4_t sum = dc_sum_8(above);
+  const uint8x8_t dc = vreinterpret_u8_u16(vrshr_n_u16(sum, 3));
   (void)left;
-  dc_8x8(dst, stride, above, NULL, 1, 0);
+  dc_store_8x8(dst, stride, dc);
 }
 
 void vpx_dc_128_predictor_8x8_neon(uint8_t *dst, ptrdiff_t stride,
                                    const uint8_t *above, const uint8_t *left) {
+  const uint8x8_t dc = vdup_n_u8(0x80);
   (void)above;
   (void)left;
-  dc_8x8(dst, stride, NULL, NULL, 0, 0);
+  dc_store_8x8(dst, stride, dc);
 }
 
 //------------------------------------------------------------------------------
 // DC 16x16
 
-// 'do_above' and 'do_left' facilitate branch removal when inlined.
-static INLINE void dc_16x16(uint8_t *dst, ptrdiff_t stride,
-                            const uint8_t *above, const uint8_t *left,
-                            int do_above, int do_left) {
-  uint16x8_t sum_top;
-  uint16x8_t sum_left;
-  uint8x8_t dc0;
-
-  if (do_above) {
-    const uint8x16_t A = vld1q_u8(above);  // top row
-    const uint16x8_t p0 = vpaddlq_u8(A);   // cascading summation of the top
-    const uint16x4_t p1 = vadd_u16(vget_low_u16(p0), vget_high_u16(p0));
-    const uint16x4_t p2 = vpadd_u16(p1, p1);
-    const uint16x4_t p3 = vpadd_u16(p2, p2);
-    sum_top = vcombine_u16(p3, p3);
-  }
-
-  if (do_left) {
-    const uint8x16_t L = vld1q_u8(left);  // left row
-    const uint16x8_t p0 = vpaddlq_u8(L);  // cascading summation of the left
-    const uint16x4_t p1 = vadd_u16(vget_low_u16(p0), vget_high_u16(p0));
-    const uint16x4_t p2 = vpadd_u16(p1, p1);
-    const uint16x4_t p3 = vpadd_u16(p2, p2);
-    sum_left = vcombine_u16(p3, p3);
-  }
-
-  if (do_above && do_left) {
-    const uint16x8_t sum = vaddq_u16(sum_left, sum_top);
-    dc0 = vrshrn_n_u16(sum, 5);
-  } else if (do_above) {
-    dc0 = vrshrn_n_u16(sum_top, 4);
-  } else if (do_left) {
-    dc0 = vrshrn_n_u16(sum_left, 4);
-  } else {
-    dc0 = vdup_n_u8(0x80);
-  }
+static INLINE uint16x4_t dc_sum_16(const uint8_t *ref) {
+  const uint8x16_t ref_u8 = vld1q_u8(ref);
+  const uint16x8_t p0 = vpaddlq_u8(ref_u8);
+  uint16x4_t sum = vadd_u16(vget_low_u16(p0), vget_high_u16(p0));
+  sum = vpadd_u16(sum, sum);
+  return vpadd_u16(sum, sum);
+}
 
-  {
-    const uint8x16_t dc = vdupq_lane_u8(dc0, 0);
-    int i;
-    for (i = 0; i < 16; ++i) {
-      vst1q_u8(dst + i * stride, dc);
-    }
+static INLINE void dc_store_16x16(uint8_t *dst, ptrdiff_t stride,
+                                  const uint8x8_t dc) {
+  const uint8x16_t dc_dup = vdupq_lane_u8(dc, 0);
+  int i;
+  for (i = 0; i < 16; ++i, dst += stride) {
+    vst1q_u8(dst, dc_dup);
   }
 }
 
 void vpx_dc_predictor_16x16_neon(uint8_t *dst, ptrdiff_t stride,
                                  const uint8_t *above, const uint8_t *left) {
-  dc_16x16(dst, stride, above, left, 1, 1);
+  const uint8x16_t ref0 = vld1q_u8(above);
+  const uint8x16_t ref1 = vld1q_u8(left);
+  const uint16x8_t p0 = vpaddlq_u8(ref0);
+  const uint16x8_t p1 = vpaddlq_u8(ref1);
+  const uint16x8_t p2 = vaddq_u16(p0, p1);
+  uint16x4_t sum = vadd_u16(vget_low_u16(p2), vget_high_u16(p2));
+  uint8x8_t dc;
+  sum = vpadd_u16(sum, sum);
+  sum = vpadd_u16(sum, sum);
+  dc = vreinterpret_u8_u16(vrshr_n_u16(sum, 5));
+  dc_store_16x16(dst, stride, dc);
 }
 
 void vpx_dc_left_predictor_16x16_neon(uint8_t *dst, ptrdiff_t stride,
                                       const uint8_t *above,
                                       const uint8_t *left) {
+  const uint16x4_t sum = dc_sum_16(left);
+  const uint8x8_t dc = vreinterpret_u8_u16(vrshr_n_u16(sum, 4));
   (void)above;
-  dc_16x16(dst, stride, NULL, left, 0, 1);
+  dc_store_16x16(dst, stride, dc);
 }
 
 void vpx_dc_top_predictor_16x16_neon(uint8_t *dst, ptrdiff_t stride,
                                      const uint8_t *above,
                                      const uint8_t *left) {
+  const uint16x4_t sum = dc_sum_16(above);
+  const uint8x8_t dc = vreinterpret_u8_u16(vrshr_n_u16(sum, 4));
   (void)left;
-  dc_16x16(dst, stride, above, NULL, 1, 0);
+  dc_store_16x16(dst, stride, dc);
 }
 
 void vpx_dc_128_predictor_16x16_neon(uint8_t *dst, ptrdiff_t stride,
                                      const uint8_t *above,
                                      const uint8_t *left) {
+  const uint8x8_t dc = vdup_n_u8(0x80);
   (void)above;
   (void)left;
-  dc_16x16(dst, stride, NULL, NULL, 0, 0);
+  dc_store_16x16(dst, stride, dc);
 }
 
 //------------------------------------------------------------------------------
 // DC 32x32
 
-// 'do_above' and 'do_left' facilitate branch removal when inlined.
-static INLINE void dc_32x32(uint8_t *dst, ptrdiff_t stride,
-                            const uint8_t *above, const uint8_t *left,
-                            int do_above, int do_left) {
-  uint16x8_t sum_top;
-  uint16x8_t sum_left;
-  uint8x8_t dc0;
-
-  if (do_above) {
-    const uint8x16_t A0 = vld1q_u8(above);  // top row
-    const uint8x16_t A1 = vld1q_u8(above + 16);
-    const uint16x8_t p0 = vpaddlq_u8(A0);  // cascading summation of the top
-    const uint16x8_t p1 = vpaddlq_u8(A1);
-    const uint16x8_t p2 = vaddq_u16(p0, p1);
-    const uint16x4_t p3 = vadd_u16(vget_low_u16(p2), vget_high_u16(p2));
-    const uint16x4_t p4 = vpadd_u16(p3, p3);
-    const uint16x4_t p5 = vpadd_u16(p4, p4);
-    sum_top = vcombine_u16(p5, p5);
-  }
-
-  if (do_left) {
-    const uint8x16_t L0 = vld1q_u8(left);  // left row
-    const uint8x16_t L1 = vld1q_u8(left + 16);
-    const uint16x8_t p0 = vpaddlq_u8(L0);  // cascading summation of the left
-    const uint16x8_t p1 = vpaddlq_u8(L1);
-    const uint16x8_t p2 = vaddq_u16(p0, p1);
-    const uint16x4_t p3 = vadd_u16(vget_low_u16(p2), vget_high_u16(p2));
-    const uint16x4_t p4 = vpadd_u16(p3, p3);
-    const uint16x4_t p5 = vpadd_u16(p4, p4);
-    sum_left = vcombine_u16(p5, p5);
-  }
+static INLINE uint16x4_t dc_sum_32(const uint8_t *ref) {
+  const uint8x16x2_t r = vld2q_u8(ref);
+  const uint16x8_t p0 = vpaddlq_u8(r.val[0]);
+  const uint16x8_t p1 = vpaddlq_u8(r.val[1]);
+  const uint16x8_t p2 = vaddq_u16(p0, p1);
+  uint16x4_t sum = vadd_u16(vget_low_u16(p2), vget_high_u16(p2));
+  sum = vpadd_u16(sum, sum);
+  return vpadd_u16(sum, sum);
+}
 
-  if (do_above && do_left) {
-    const uint16x8_t sum = vaddq_u16(sum_left, sum_top);
-    dc0 = vrshrn_n_u16(sum, 6);
-  } else if (do_above) {
-    dc0 = vrshrn_n_u16(sum_top, 5);
-  } else if (do_left) {
-    dc0 = vrshrn_n_u16(sum_left, 5);
-  } else {
-    dc0 = vdup_n_u8(0x80);
-  }
+static INLINE void dc_store_32x32(uint8_t *dst, ptrdiff_t stride,
+                                  const uint8x8_t dc) {
+  uint8x16x2_t dc_dup;
+  int i;
+  dc_dup.val[0] = dc_dup.val[1] = vdupq_lane_u8(dc, 0);
 
-  {
-    const uint8x16_t dc = vdupq_lane_u8(dc0, 0);
-    int i;
-    for (i = 0; i < 32; ++i) {
-      vst1q_u8(dst + i * stride, dc);
-      vst1q_u8(dst + i * stride + 16, dc);
-    }
+  for (i = 0; i < 32; ++i, dst += stride) {
+    vst2q_u8(dst, dc_dup);
   }
 }
 
 void vpx_dc_predictor_32x32_neon(uint8_t *dst, ptrdiff_t stride,
                                  const uint8_t *above, const uint8_t *left) {
-  dc_32x32(dst, stride, above, left, 1, 1);
+  const uint8x16x2_t a = vld2q_u8(above);
+  const uint8x16x2_t l = vld2q_u8(left);
+  const uint16x8_t pa0 = vpaddlq_u8(a.val[0]);
+  const uint16x8_t pl0 = vpaddlq_u8(l.val[0]);
+  const uint16x8_t pa1 = vpaddlq_u8(a.val[1]);
+  const uint16x8_t pl1 = vpaddlq_u8(l.val[1]);
+  const uint16x8_t pa = vaddq_u16(pa0, pa1);
+  const uint16x8_t pl = vaddq_u16(pl0, pl1);
+  const uint16x8_t pal = vaddq_u16(pa, pl);
+  uint16x4_t sum = vadd_u16(vget_low_u16(pal), vget_high_u16(pal));
+  uint8x8_t dc;
+  sum = vpadd_u16(sum, sum);
+  sum = vpadd_u16(sum, sum);
+  dc = vreinterpret_u8_u16(vrshr_n_u16(sum, 6));
+  dc_store_32x32(dst, stride, dc);
 }
 
 void vpx_dc_left_predictor_32x32_neon(uint8_t *dst, ptrdiff_t stride,
                                       const uint8_t *above,
                                       const uint8_t *left) {
+  const uint16x4_t sum = dc_sum_32(left);
+  const uint8x8_t dc = vreinterpret_u8_u16(vrshr_n_u16(sum, 5));
   (void)above;
-  dc_32x32(dst, stride, NULL, left, 0, 1);
+  dc_store_32x32(dst, stride, dc);
 }
 
 void vpx_dc_top_predictor_32x32_neon(uint8_t *dst, ptrdiff_t stride,
                                      const uint8_t *above,
                                      const uint8_t *left) {
+  const uint16x4_t sum = dc_sum_32(above);
+  const uint8x8_t dc = vreinterpret_u8_u16(vrshr_n_u16(sum, 5));
   (void)left;
-  dc_32x32(dst, stride, above, NULL, 1, 0);
+  dc_store_32x32(dst, stride, dc);
 }
 
 void vpx_dc_128_predictor_32x32_neon(uint8_t *dst, ptrdiff_t stride,
                                      const uint8_t *above,
                                      const uint8_t *left) {
+  const uint8x8_t dc = vdup_n_u8(0x80);
   (void)above;
   (void)left;
-  dc_32x32(dst, stride, NULL, NULL, 0, 0);
+  dc_store_32x32(dst, stride, dc);
 }
 
 // -----------------------------------------------------------------------------