Optimize Neon implementation of high bitdepth MSE functions
authorSalome Thirot <salome.thirot@arm.com>
Mon, 27 Feb 2023 17:58:18 +0000 (17:58 +0000)
committerSalome Thirot <salome.thirot@arm.com>
Wed, 1 Mar 2023 13:35:03 +0000 (13:35 +0000)
Currently MSE functions just call the variance helpers but don't
actually use the computed sum. This patch adds dedicated helpers to
perform the computation of sse.

Add the corresponding tests as well.

Change-Id: I96a8590e3410e84d77f7187344688e02efe03902

test/variance_test.cc
vpx_dsp/arm/highbd_variance_neon.c

index a68cfad..1359bc4 100644 (file)
@@ -1508,6 +1508,22 @@ INSTANTIATE_TEST_SUITE_P(
 
 #if CONFIG_VP9_HIGHBITDEPTH
 INSTANTIATE_TEST_SUITE_P(
+    NEON, VpxHBDMseTest,
+    ::testing::Values(
+        MseParams(4, 4, &vpx_highbd_12_mse16x16_neon, VPX_BITS_12),
+        MseParams(4, 3, &vpx_highbd_12_mse16x8_neon, VPX_BITS_12),
+        MseParams(3, 4, &vpx_highbd_12_mse8x16_neon, VPX_BITS_12),
+        MseParams(3, 3, &vpx_highbd_12_mse8x8_neon, VPX_BITS_12),
+        MseParams(4, 4, &vpx_highbd_10_mse16x16_neon, VPX_BITS_10),
+        MseParams(4, 3, &vpx_highbd_10_mse16x8_neon, VPX_BITS_10),
+        MseParams(3, 4, &vpx_highbd_10_mse8x16_neon, VPX_BITS_10),
+        MseParams(3, 3, &vpx_highbd_10_mse8x8_neon, VPX_BITS_10),
+        MseParams(4, 4, &vpx_highbd_8_mse16x16_neon, VPX_BITS_8),
+        MseParams(4, 3, &vpx_highbd_8_mse16x8_neon, VPX_BITS_8),
+        MseParams(3, 4, &vpx_highbd_8_mse8x16_neon, VPX_BITS_8),
+        MseParams(3, 3, &vpx_highbd_8_mse8x8_neon, VPX_BITS_8)));
+
+INSTANTIATE_TEST_SUITE_P(
     NEON, VpxHBDVarianceTest,
     ::testing::Values(
         VarianceParams(6, 6, &vpx_highbd_12_variance64x64_neon, 12),
index 89bd5c5..d0b366c 100644 (file)
@@ -351,50 +351,159 @@ HBD_VARIANCE_WXH_12_XLARGE_NEON(64, 64)
     *sum = (int)ROUND_POWER_OF_TWO(sum_long, 4);                      \
   }
 
-#define HIGHBD_MSE(w, h)                                              \
-  uint32_t vpx_highbd_8_mse##w##x##h##_neon(                          \
-      const uint8_t *src_ptr, int src_stride, const uint8_t *ref_ptr, \
-      int ref_stride, uint32_t *sse) {                                \
-    uint64_t sse_long = 0;                                            \
-    int64_t sum_long = 0;                                             \
-    uint16_t *src = CONVERT_TO_SHORTPTR(src_ptr);                     \
-    uint16_t *ref = CONVERT_TO_SHORTPTR(ref_ptr);                     \
-    highbd_variance_##w##xh_neon(src, src_stride, ref, ref_stride, h, \
-                                 &sse_long, &sum_long);               \
-    *sse = (uint32_t)sse_long;                                        \
-    return *sse;                                                      \
-  }                                                                   \
-                                                                      \
-  uint32_t vpx_highbd_10_mse##w##x##h##_neon(                         \
-      const uint8_t *src_ptr, int src_stride, const uint8_t *ref_ptr, \
-      int ref_stride, uint32_t *sse) {                                \
-    uint64_t sse_long = 0;                                            \
-    int64_t sum_long = 0;                                             \
-    uint16_t *src = CONVERT_TO_SHORTPTR(src_ptr);                     \
-    uint16_t *ref = CONVERT_TO_SHORTPTR(ref_ptr);                     \
-    highbd_variance_##w##xh_neon(src, src_stride, ref, ref_stride, h, \
-                                 &sse_long, &sum_long);               \
-    *sse = (uint32_t)ROUND_POWER_OF_TWO(sse_long, 4);                 \
-    return *sse;                                                      \
-  }                                                                   \
-                                                                      \
-  uint32_t vpx_highbd_12_mse##w##x##h##_neon(                         \
-      const uint8_t *src_ptr, int src_stride, const uint8_t *ref_ptr, \
-      int ref_stride, uint32_t *sse) {                                \
-    uint64_t sse_long = 0;                                            \
-    int64_t sum_long = 0;                                             \
-    uint16_t *src = CONVERT_TO_SHORTPTR(src_ptr);                     \
-    uint16_t *ref = CONVERT_TO_SHORTPTR(ref_ptr);                     \
-    highbd_variance_##w##xh_neon(src, src_stride, ref, ref_stride, h, \
-                                 &sse_long, &sum_long);               \
-    *sse = (uint32_t)ROUND_POWER_OF_TWO(sse_long, 8);                 \
-    return *sse;                                                      \
-  }
-
 HIGHBD_GET_VAR(8)
 HIGHBD_GET_VAR(16)
 
-HIGHBD_MSE(16, 16)
-HIGHBD_MSE(16, 8)
-HIGHBD_MSE(8, 16)
-HIGHBD_MSE(8, 8)
+static INLINE uint32_t highbd_mse_wxh_neon(const uint16_t *src_ptr,
+                                           int src_stride,
+                                           const uint16_t *ref_ptr,
+                                           int ref_stride, int w, int h,
+                                           unsigned int *sse) {
+  uint32x4_t sse_u32[2] = { vdupq_n_u32(0), vdupq_n_u32(0) };
+
+  int i = h;
+  do {
+    int j = 0;
+    do {
+      uint16x8_t s = vld1q_u16(src_ptr + j);
+      uint16x8_t r = vld1q_u16(ref_ptr + j);
+
+      uint16x8_t diff = vabdq_u16(s, r);
+
+      sse_u32[0] =
+          vmlal_u16(sse_u32[0], vget_low_u16(diff), vget_low_u16(diff));
+      sse_u32[1] =
+          vmlal_u16(sse_u32[1], vget_high_u16(diff), vget_high_u16(diff));
+
+      j += 8;
+    } while (j < w);
+
+    src_ptr += src_stride;
+    ref_ptr += ref_stride;
+  } while (--i != 0);
+
+  *sse = horizontal_add_uint32x4(vaddq_u32(sse_u32[0], sse_u32[1]));
+  return *sse;
+}
+
+#if defined(__ARM_FEATURE_DOTPROD)
+
+static INLINE uint32_t highbd_mse8_8xh_neon(const uint16_t *src_ptr,
+                                            int src_stride,
+                                            const uint16_t *ref_ptr,
+                                            int ref_stride, int h,
+                                            unsigned int *sse) {
+  uint32x4_t sse_u32 = vdupq_n_u32(0);
+
+  int i = h / 2;
+  do {
+    uint16x8_t s0, s1, r0, r1;
+    uint8x16_t s, r, diff;
+
+    s0 = vld1q_u16(src_ptr);
+    src_ptr += src_stride;
+    s1 = vld1q_u16(src_ptr);
+    src_ptr += src_stride;
+    r0 = vld1q_u16(ref_ptr);
+    ref_ptr += ref_stride;
+    r1 = vld1q_u16(ref_ptr);
+    ref_ptr += ref_stride;
+
+    s = vcombine_u8(vmovn_u16(s0), vmovn_u16(s1));
+    r = vcombine_u8(vmovn_u16(r0), vmovn_u16(r1));
+
+    diff = vabdq_u8(s, r);
+    sse_u32 = vdotq_u32(sse_u32, diff, diff);
+  } while (--i != 0);
+
+  *sse = horizontal_add_uint32x4(sse_u32);
+  return *sse;
+}
+
+static INLINE uint32_t highbd_mse8_16xh_neon(const uint16_t *src_ptr,
+                                             int src_stride,
+                                             const uint16_t *ref_ptr,
+                                             int ref_stride, int h,
+                                             unsigned int *sse) {
+  uint32x4_t sse_u32 = vdupq_n_u32(0);
+
+  int i = h;
+  do {
+    uint16x8_t s0, s1, r0, r1;
+    uint8x16_t s, r, diff;
+
+    s0 = vld1q_u16(src_ptr);
+    s1 = vld1q_u16(src_ptr + 8);
+    r0 = vld1q_u16(ref_ptr);
+    r1 = vld1q_u16(ref_ptr + 8);
+
+    s = vcombine_u8(vmovn_u16(s0), vmovn_u16(s1));
+    r = vcombine_u8(vmovn_u16(r0), vmovn_u16(r1));
+
+    diff = vabdq_u8(s, r);
+    sse_u32 = vdotq_u32(sse_u32, diff, diff);
+
+    src_ptr += src_stride;
+    ref_ptr += ref_stride;
+  } while (--i != 0);
+
+  *sse = horizontal_add_uint32x4(sse_u32);
+  return *sse;
+}
+
+#else  // !defined(__ARM_FEATURE_DOTPROD)
+
+static INLINE uint32_t highbd_mse8_8xh_neon(const uint16_t *src_ptr,
+                                            int src_stride,
+                                            const uint16_t *ref_ptr,
+                                            int ref_stride, int h,
+                                            unsigned int *sse) {
+  return highbd_mse_wxh_neon(src_ptr, src_stride, ref_ptr, ref_stride, 8, h,
+                             sse);
+}
+
+static INLINE uint32_t highbd_mse8_16xh_neon(const uint16_t *src_ptr,
+                                             int src_stride,
+                                             const uint16_t *ref_ptr,
+                                             int ref_stride, int h,
+                                             unsigned int *sse) {
+  return highbd_mse_wxh_neon(src_ptr, src_stride, ref_ptr, ref_stride, 16, h,
+                             sse);
+}
+
+#endif  // defined(__ARM_FEATURE_DOTPROD)
+
+#define HIGHBD_MSE_WXH_NEON(w, h)                                       \
+  uint32_t vpx_highbd_8_mse##w##x##h##_neon(                            \
+      const uint8_t *src_ptr, int src_stride, const uint8_t *ref_ptr,   \
+      int ref_stride, uint32_t *sse) {                                  \
+    uint16_t *src = CONVERT_TO_SHORTPTR(src_ptr);                       \
+    uint16_t *ref = CONVERT_TO_SHORTPTR(ref_ptr);                       \
+    highbd_mse8_##w##xh_neon(src, src_stride, ref, ref_stride, h, sse); \
+    return *sse;                                                        \
+  }                                                                     \
+                                                                        \
+  uint32_t vpx_highbd_10_mse##w##x##h##_neon(                           \
+      const uint8_t *src_ptr, int src_stride, const uint8_t *ref_ptr,   \
+      int ref_stride, uint32_t *sse) {                                  \
+    uint16_t *src = CONVERT_TO_SHORTPTR(src_ptr);                       \
+    uint16_t *ref = CONVERT_TO_SHORTPTR(ref_ptr);                       \
+    highbd_mse_wxh_neon(src, src_stride, ref, ref_stride, w, h, sse);   \
+    *sse = ROUND_POWER_OF_TWO(*sse, 4);                                 \
+    return *sse;                                                        \
+  }                                                                     \
+                                                                        \
+  uint32_t vpx_highbd_12_mse##w##x##h##_neon(                           \
+      const uint8_t *src_ptr, int src_stride, const uint8_t *ref_ptr,   \
+      int ref_stride, uint32_t *sse) {                                  \
+    uint16_t *src = CONVERT_TO_SHORTPTR(src_ptr);                       \
+    uint16_t *ref = CONVERT_TO_SHORTPTR(ref_ptr);                       \
+    highbd_mse_wxh_neon(src, src_stride, ref, ref_stride, w, h, sse);   \
+    *sse = ROUND_POWER_OF_TWO(*sse, 8);                                 \
+    return *sse;                                                        \
+  }
+
+HIGHBD_MSE_WXH_NEON(16, 16)
+HIGHBD_MSE_WXH_NEON(16, 8)
+HIGHBD_MSE_WXH_NEON(8, 16)
+HIGHBD_MSE_WXH_NEON(8, 8)