Add vpx_sse and vpx_highbd_sse
authorWan-Teh Chang <wtc@google.com>
Fri, 29 Sep 2023 17:45:32 +0000 (10:45 -0700)
committerWan-Teh Chang <wtc@google.com>
Mon, 20 Nov 2023 22:59:27 +0000 (14:59 -0800)
The code is ported from libaom's aom_sse and aom_highbd_sse at
commit 1e20d2da96515524864b21010dbe23809cff2e9b.

The vpx_sse and vpx_highbd_sse functions will be used by vpx_dsp/psnr.c.

Bug: webm:1819
Change-Id: I4fbffa9000ab92755de5387b1ddd4370cb7020f7

test/sum_squares_test.cc
vpx_dsp/arm/highbd_sse_neon.c [new file with mode: 0644]
vpx_dsp/arm/sse_neon.c [new file with mode: 0644]
vpx_dsp/arm/sse_neon_dotprod.c [new file with mode: 0644]
vpx_dsp/arm/sum_neon.h
vpx_dsp/sse.c [new file with mode: 0644]
vpx_dsp/vpx_dsp.mk
vpx_dsp/vpx_dsp_rtcd_defs.pl
vpx_dsp/x86/sse_avx2.c [new file with mode: 0644]
vpx_dsp/x86/sse_sse4.c [new file with mode: 0644]

index 5abb464..725d5eb 100644 (file)
 #include "test/clear_system_state.h"
 #include "test/register_state_check.h"
 #include "test/util.h"
+#include "vpx_mem/vpx_mem.h"
 #include "vpx_ports/mem.h"
+#include "vpx_ports/vpx_timer.h"
 
 using libvpx_test::ACMRandom;
+using ::testing::Combine;
+using ::testing::Range;
+using ::testing::ValuesIn;
 
 namespace {
 const int kNumIterations = 10000;
@@ -126,4 +131,210 @@ INSTANTIATE_TEST_SUITE_P(
     ::testing::Values(make_tuple(&vpx_sum_squares_2d_i16_c,
                                  &vpx_sum_squares_2d_i16_msa)));
 #endif  // HAVE_MSA
+
+typedef int64_t (*SSEFunc)(const uint8_t *a, int a_stride, const uint8_t *b,
+                           int b_stride, int width, int height);
+
+struct TestSSEFuncs {
+  TestSSEFuncs(SSEFunc ref = nullptr, SSEFunc tst = nullptr, int depth = 0)
+      : ref_func(ref), tst_func(tst), bit_depth(depth) {}
+  SSEFunc ref_func;  // Pointer to reference function
+  SSEFunc tst_func;  // Pointer to tested function
+  int bit_depth;
+};
+
+typedef std::tuple<TestSSEFuncs, int> SSETestParam;
+
+class SSETest : public ::testing::TestWithParam<SSETestParam> {
+ public:
+  ~SSETest() override = default;
+  void SetUp() override {
+    params_ = GET_PARAM(0);
+    width_ = GET_PARAM(1);
+    is_hbd_ =
+#if CONFIG_VP9_HIGHBITDEPTH
+        params_.ref_func == vpx_highbd_sse_c;
+#else
+        false;
+#endif
+    rnd_.Reset(ACMRandom::DeterministicSeed());
+    src_ = reinterpret_cast<uint8_t *>(vpx_memalign(32, 256 * 256 * 2));
+    ref_ = reinterpret_cast<uint8_t *>(vpx_memalign(32, 256 * 256 * 2));
+    ASSERT_NE(src_, nullptr);
+    ASSERT_NE(ref_, nullptr);
+  }
+
+  void TearDown() override {
+    vpx_free(src_);
+    vpx_free(ref_);
+  }
+  void RunTest(bool is_random, int width, int height, int run_times);
+
+  void GenRandomData(int width, int height, int stride) {
+    uint16_t *src16 = reinterpret_cast<uint16_t *>(src_);
+    uint16_t *ref16 = reinterpret_cast<uint16_t *>(ref_);
+    const int msb = 11;  // Up to 12 bit input
+    const int limit = 1 << (msb + 1);
+    for (int ii = 0; ii < height; ii++) {
+      for (int jj = 0; jj < width; jj++) {
+        if (!is_hbd_) {
+          src_[ii * stride + jj] = rnd_.Rand8();
+          ref_[ii * stride + jj] = rnd_.Rand8();
+        } else {
+          src16[ii * stride + jj] = rnd_(limit);
+          ref16[ii * stride + jj] = rnd_(limit);
+        }
+      }
+    }
+  }
+
+  void GenExtremeData(int width, int height, int stride, uint8_t *data,
+                      int16_t val) {
+    uint16_t *data16 = reinterpret_cast<uint16_t *>(data);
+    for (int ii = 0; ii < height; ii++) {
+      for (int jj = 0; jj < width; jj++) {
+        if (!is_hbd_) {
+          data[ii * stride + jj] = static_cast<uint8_t>(val);
+        } else {
+          data16[ii * stride + jj] = val;
+        }
+      }
+    }
+  }
+
+ protected:
+  bool is_hbd_;
+  int width_;
+  TestSSEFuncs params_;
+  uint8_t *src_;
+  uint8_t *ref_;
+  ACMRandom rnd_;
+};
+GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(SSETest);
+
+void SSETest::RunTest(bool is_random, int width, int height, int run_times) {
+  int failed = 0;
+  vpx_usec_timer ref_timer, test_timer;
+  for (int k = 0; k < 3; k++) {
+    int stride = 4 << rnd_(7);  // Up to 256 stride
+    while (stride < width) {    // Make sure it's valid
+      stride = 4 << rnd_(7);
+    }
+    if (is_random) {
+      GenRandomData(width, height, stride);
+    } else {
+      const int msb = is_hbd_ ? 12 : 8;  // Up to 12 bit input
+      const int limit = (1 << msb) - 1;
+      if (k == 0) {
+        GenExtremeData(width, height, stride, src_, 0);
+        GenExtremeData(width, height, stride, ref_, limit);
+      } else {
+        GenExtremeData(width, height, stride, src_, limit);
+        GenExtremeData(width, height, stride, ref_, 0);
+      }
+    }
+    int64_t res_ref, res_tst;
+    uint8_t *src = src_;
+    uint8_t *ref = ref_;
+#if CONFIG_VP9_HIGHBITDEPTH
+    if (is_hbd_) {
+      src = CONVERT_TO_BYTEPTR(src_);
+      ref = CONVERT_TO_BYTEPTR(ref_);
+    }
+#endif
+    res_ref = params_.ref_func(src, stride, ref, stride, width, height);
+    res_tst = params_.tst_func(src, stride, ref, stride, width, height);
+    if (run_times > 1) {
+      vpx_usec_timer_start(&ref_timer);
+      for (int j = 0; j < run_times; j++) {
+        params_.ref_func(src, stride, ref, stride, width, height);
+      }
+      vpx_usec_timer_mark(&ref_timer);
+      const int elapsed_time_c =
+          static_cast<int>(vpx_usec_timer_elapsed(&ref_timer));
+
+      vpx_usec_timer_start(&test_timer);
+      for (int j = 0; j < run_times; j++) {
+        params_.tst_func(src, stride, ref, stride, width, height);
+      }
+      vpx_usec_timer_mark(&test_timer);
+      const int elapsed_time_simd =
+          static_cast<int>(vpx_usec_timer_elapsed(&test_timer));
+
+      printf(
+          "c_time=%d \t simd_time=%d \t "
+          "gain=%d\n",
+          elapsed_time_c, elapsed_time_simd,
+          (elapsed_time_c / elapsed_time_simd));
+    } else {
+      if (!failed) {
+        failed = res_ref != res_tst;
+        EXPECT_EQ(res_ref, res_tst)
+            << "Error:" << (is_hbd_ ? "hbd " : " ") << k << " SSE Test ["
+            << width << "x" << height
+            << "] C output does not match optimized output.";
+      }
+    }
+  }
+}
+
+TEST_P(SSETest, OperationCheck) {
+  for (int height = 4; height <= 128; height += 4) {
+    RunTest(true, width_, height, 1);  // GenRandomData
+  }
+}
+
+TEST_P(SSETest, ExtremeValues) {
+  for (int height = 4; height <= 128; height += 4) {
+    RunTest(false, width_, height, 1);
+  }
+}
+
+TEST_P(SSETest, DISABLED_Speed) {
+  for (int height = 4; height <= 128; height += 4) {
+    RunTest(true, width_, height, 100);
+  }
+}
+
+#if HAVE_NEON
+TestSSEFuncs sse_neon[] = {
+  TestSSEFuncs(&vpx_sse_c, &vpx_sse_neon),
+#if CONFIG_VP9_HIGHBITDEPTH
+  TestSSEFuncs(&vpx_highbd_sse_c, &vpx_highbd_sse_neon)
+#endif
+};
+INSTANTIATE_TEST_SUITE_P(NEON, SSETest,
+                         Combine(ValuesIn(sse_neon), Range(4, 129, 4)));
+#endif  // HAVE_NEON
+
+#if HAVE_NEON_DOTPROD
+TestSSEFuncs sse_neon_dotprod[] = {
+  TestSSEFuncs(&vpx_sse_c, &vpx_sse_neon_dotprod),
+};
+INSTANTIATE_TEST_SUITE_P(NEON_DOTPROD, SSETest,
+                         Combine(ValuesIn(sse_neon_dotprod), Range(4, 129, 4)));
+#endif  // HAVE_NEON_DOTPROD
+
+#if HAVE_SSE4_1
+TestSSEFuncs sse_sse4[] = {
+  TestSSEFuncs(&vpx_sse_c, &vpx_sse_sse4_1),
+#if CONFIG_VP9_HIGHBITDEPTH
+  TestSSEFuncs(&vpx_highbd_sse_c, &vpx_highbd_sse_sse4_1)
+#endif
+};
+INSTANTIATE_TEST_SUITE_P(SSE4_1, SSETest,
+                         Combine(ValuesIn(sse_sse4), Range(4, 129, 4)));
+#endif  // HAVE_SSE4_1
+
+#if HAVE_AVX2
+
+TestSSEFuncs sse_avx2[] = {
+  TestSSEFuncs(&vpx_sse_c, &vpx_sse_avx2),
+#if CONFIG_VP9_HIGHBITDEPTH
+  TestSSEFuncs(&vpx_highbd_sse_c, &vpx_highbd_sse_avx2)
+#endif
+};
+INSTANTIATE_TEST_SUITE_P(AVX2, SSETest,
+                         Combine(ValuesIn(sse_avx2), Range(4, 129, 4)));
+#endif  // HAVE_AVX2
 }  // namespace
diff --git a/vpx_dsp/arm/highbd_sse_neon.c b/vpx_dsp/arm/highbd_sse_neon.c
new file mode 100644 (file)
index 0000000..717ad6b
--- /dev/null
@@ -0,0 +1,288 @@
+/*
+ *  Copyright (c) 2023 The WebM project authors. All Rights Reserved.
+ *
+ *  Use of this source code is governed by a BSD-style license
+ *  that can be found in the LICENSE file in the root of the source
+ *  tree. An additional intellectual property rights grant can be found
+ *  in the file PATENTS.  All contributing project authors may
+ *  be found in the AUTHORS file in the root of the source tree.
+ */
+
+#include <arm_neon.h>
+
+#include "./vpx_dsp_rtcd.h"
+#include "vpx_dsp/arm/sum_neon.h"
+
+static INLINE void highbd_sse_8x1_init_neon(const uint16_t *src,
+                                            const uint16_t *ref,
+                                            uint32x4_t *sse_acc0,
+                                            uint32x4_t *sse_acc1) {
+  uint16x8_t s = vld1q_u16(src);
+  uint16x8_t r = vld1q_u16(ref);
+
+  uint16x8_t abs_diff = vabdq_u16(s, r);
+  uint16x4_t abs_diff_lo = vget_low_u16(abs_diff);
+  uint16x4_t abs_diff_hi = vget_high_u16(abs_diff);
+
+  *sse_acc0 = vmull_u16(abs_diff_lo, abs_diff_lo);
+  *sse_acc1 = vmull_u16(abs_diff_hi, abs_diff_hi);
+}
+
+static INLINE void highbd_sse_8x1_neon(const uint16_t *src, const uint16_t *ref,
+                                       uint32x4_t *sse_acc0,
+                                       uint32x4_t *sse_acc1) {
+  uint16x8_t s = vld1q_u16(src);
+  uint16x8_t r = vld1q_u16(ref);
+
+  uint16x8_t abs_diff = vabdq_u16(s, r);
+  uint16x4_t abs_diff_lo = vget_low_u16(abs_diff);
+  uint16x4_t abs_diff_hi = vget_high_u16(abs_diff);
+
+  *sse_acc0 = vmlal_u16(*sse_acc0, abs_diff_lo, abs_diff_lo);
+  *sse_acc1 = vmlal_u16(*sse_acc1, abs_diff_hi, abs_diff_hi);
+}
+
+static INLINE int64_t highbd_sse_128xh_neon(const uint16_t *src, int src_stride,
+                                            const uint16_t *ref, int ref_stride,
+                                            int height) {
+  uint32x4_t sse[16];
+  highbd_sse_8x1_init_neon(src + 0 * 8, ref + 0 * 8, &sse[0], &sse[1]);
+  highbd_sse_8x1_init_neon(src + 1 * 8, ref + 1 * 8, &sse[2], &sse[3]);
+  highbd_sse_8x1_init_neon(src + 2 * 8, ref + 2 * 8, &sse[4], &sse[5]);
+  highbd_sse_8x1_init_neon(src + 3 * 8, ref + 3 * 8, &sse[6], &sse[7]);
+  highbd_sse_8x1_init_neon(src + 4 * 8, ref + 4 * 8, &sse[8], &sse[9]);
+  highbd_sse_8x1_init_neon(src + 5 * 8, ref + 5 * 8, &sse[10], &sse[11]);
+  highbd_sse_8x1_init_neon(src + 6 * 8, ref + 6 * 8, &sse[12], &sse[13]);
+  highbd_sse_8x1_init_neon(src + 7 * 8, ref + 7 * 8, &sse[14], &sse[15]);
+  highbd_sse_8x1_neon(src + 8 * 8, ref + 8 * 8, &sse[0], &sse[1]);
+  highbd_sse_8x1_neon(src + 9 * 8, ref + 9 * 8, &sse[2], &sse[3]);
+  highbd_sse_8x1_neon(src + 10 * 8, ref + 10 * 8, &sse[4], &sse[5]);
+  highbd_sse_8x1_neon(src + 11 * 8, ref + 11 * 8, &sse[6], &sse[7]);
+  highbd_sse_8x1_neon(src + 12 * 8, ref + 12 * 8, &sse[8], &sse[9]);
+  highbd_sse_8x1_neon(src + 13 * 8, ref + 13 * 8, &sse[10], &sse[11]);
+  highbd_sse_8x1_neon(src + 14 * 8, ref + 14 * 8, &sse[12], &sse[13]);
+  highbd_sse_8x1_neon(src + 15 * 8, ref + 15 * 8, &sse[14], &sse[15]);
+
+  src += src_stride;
+  ref += ref_stride;
+
+  while (--height != 0) {
+    highbd_sse_8x1_neon(src + 0 * 8, ref + 0 * 8, &sse[0], &sse[1]);
+    highbd_sse_8x1_neon(src + 1 * 8, ref + 1 * 8, &sse[2], &sse[3]);
+    highbd_sse_8x1_neon(src + 2 * 8, ref + 2 * 8, &sse[4], &sse[5]);
+    highbd_sse_8x1_neon(src + 3 * 8, ref + 3 * 8, &sse[6], &sse[7]);
+    highbd_sse_8x1_neon(src + 4 * 8, ref + 4 * 8, &sse[8], &sse[9]);
+    highbd_sse_8x1_neon(src + 5 * 8, ref + 5 * 8, &sse[10], &sse[11]);
+    highbd_sse_8x1_neon(src + 6 * 8, ref + 6 * 8, &sse[12], &sse[13]);
+    highbd_sse_8x1_neon(src + 7 * 8, ref + 7 * 8, &sse[14], &sse[15]);
+    highbd_sse_8x1_neon(src + 8 * 8, ref + 8 * 8, &sse[0], &sse[1]);
+    highbd_sse_8x1_neon(src + 9 * 8, ref + 9 * 8, &sse[2], &sse[3]);
+    highbd_sse_8x1_neon(src + 10 * 8, ref + 10 * 8, &sse[4], &sse[5]);
+    highbd_sse_8x1_neon(src + 11 * 8, ref + 11 * 8, &sse[6], &sse[7]);
+    highbd_sse_8x1_neon(src + 12 * 8, ref + 12 * 8, &sse[8], &sse[9]);
+    highbd_sse_8x1_neon(src + 13 * 8, ref + 13 * 8, &sse[10], &sse[11]);
+    highbd_sse_8x1_neon(src + 14 * 8, ref + 14 * 8, &sse[12], &sse[13]);
+    highbd_sse_8x1_neon(src + 15 * 8, ref + 15 * 8, &sse[14], &sse[15]);
+
+    src += src_stride;
+    ref += ref_stride;
+  }
+
+  return horizontal_long_add_uint32x4_x16(sse);
+}
+
+static INLINE int64_t highbd_sse_64xh_neon(const uint16_t *src, int src_stride,
+                                           const uint16_t *ref, int ref_stride,
+                                           int height) {
+  uint32x4_t sse[8];
+  highbd_sse_8x1_init_neon(src + 0 * 8, ref + 0 * 8, &sse[0], &sse[1]);
+  highbd_sse_8x1_init_neon(src + 1 * 8, ref + 1 * 8, &sse[2], &sse[3]);
+  highbd_sse_8x1_init_neon(src + 2 * 8, ref + 2 * 8, &sse[4], &sse[5]);
+  highbd_sse_8x1_init_neon(src + 3 * 8, ref + 3 * 8, &sse[6], &sse[7]);
+  highbd_sse_8x1_neon(src + 4 * 8, ref + 4 * 8, &sse[0], &sse[1]);
+  highbd_sse_8x1_neon(src + 5 * 8, ref + 5 * 8, &sse[2], &sse[3]);
+  highbd_sse_8x1_neon(src + 6 * 8, ref + 6 * 8, &sse[4], &sse[5]);
+  highbd_sse_8x1_neon(src + 7 * 8, ref + 7 * 8, &sse[6], &sse[7]);
+
+  src += src_stride;
+  ref += ref_stride;
+
+  while (--height != 0) {
+    highbd_sse_8x1_neon(src + 0 * 8, ref + 0 * 8, &sse[0], &sse[1]);
+    highbd_sse_8x1_neon(src + 1 * 8, ref + 1 * 8, &sse[2], &sse[3]);
+    highbd_sse_8x1_neon(src + 2 * 8, ref + 2 * 8, &sse[4], &sse[5]);
+    highbd_sse_8x1_neon(src + 3 * 8, ref + 3 * 8, &sse[6], &sse[7]);
+    highbd_sse_8x1_neon(src + 4 * 8, ref + 4 * 8, &sse[0], &sse[1]);
+    highbd_sse_8x1_neon(src + 5 * 8, ref + 5 * 8, &sse[2], &sse[3]);
+    highbd_sse_8x1_neon(src + 6 * 8, ref + 6 * 8, &sse[4], &sse[5]);
+    highbd_sse_8x1_neon(src + 7 * 8, ref + 7 * 8, &sse[6], &sse[7]);
+
+    src += src_stride;
+    ref += ref_stride;
+  }
+
+  return horizontal_long_add_uint32x4_x8(sse);
+}
+
+static INLINE int64_t highbd_sse_32xh_neon(const uint16_t *src, int src_stride,
+                                           const uint16_t *ref, int ref_stride,
+                                           int height) {
+  uint32x4_t sse[8];
+  highbd_sse_8x1_init_neon(src + 0 * 8, ref + 0 * 8, &sse[0], &sse[1]);
+  highbd_sse_8x1_init_neon(src + 1 * 8, ref + 1 * 8, &sse[2], &sse[3]);
+  highbd_sse_8x1_init_neon(src + 2 * 8, ref + 2 * 8, &sse[4], &sse[5]);
+  highbd_sse_8x1_init_neon(src + 3 * 8, ref + 3 * 8, &sse[6], &sse[7]);
+
+  src += src_stride;
+  ref += ref_stride;
+
+  while (--height != 0) {
+    highbd_sse_8x1_neon(src + 0 * 8, ref + 0 * 8, &sse[0], &sse[1]);
+    highbd_sse_8x1_neon(src + 1 * 8, ref + 1 * 8, &sse[2], &sse[3]);
+    highbd_sse_8x1_neon(src + 2 * 8, ref + 2 * 8, &sse[4], &sse[5]);
+    highbd_sse_8x1_neon(src + 3 * 8, ref + 3 * 8, &sse[6], &sse[7]);
+
+    src += src_stride;
+    ref += ref_stride;
+  }
+
+  return horizontal_long_add_uint32x4_x8(sse);
+}
+
+static INLINE int64_t highbd_sse_16xh_neon(const uint16_t *src, int src_stride,
+                                           const uint16_t *ref, int ref_stride,
+                                           int height) {
+  uint32x4_t sse[4];
+  highbd_sse_8x1_init_neon(src + 0 * 8, ref + 0 * 8, &sse[0], &sse[1]);
+  highbd_sse_8x1_init_neon(src + 1 * 8, ref + 1 * 8, &sse[2], &sse[3]);
+
+  src += src_stride;
+  ref += ref_stride;
+
+  while (--height != 0) {
+    highbd_sse_8x1_neon(src + 0 * 8, ref + 0 * 8, &sse[0], &sse[1]);
+    highbd_sse_8x1_neon(src + 1 * 8, ref + 1 * 8, &sse[2], &sse[3]);
+
+    src += src_stride;
+    ref += ref_stride;
+  }
+
+  return horizontal_long_add_uint32x4_x4(sse);
+}
+
+static INLINE int64_t highbd_sse_8xh_neon(const uint16_t *src, int src_stride,
+                                          const uint16_t *ref, int ref_stride,
+                                          int height) {
+  uint32x4_t sse[2];
+  highbd_sse_8x1_init_neon(src, ref, &sse[0], &sse[1]);
+
+  src += src_stride;
+  ref += ref_stride;
+
+  while (--height != 0) {
+    highbd_sse_8x1_neon(src, ref, &sse[0], &sse[1]);
+
+    src += src_stride;
+    ref += ref_stride;
+  }
+
+  return horizontal_long_add_uint32x4_x2(sse);
+}
+
+static INLINE int64_t highbd_sse_4xh_neon(const uint16_t *src, int src_stride,
+                                          const uint16_t *ref, int ref_stride,
+                                          int height) {
+  // Peel the first loop iteration.
+  uint16x4_t s = vld1_u16(src);
+  uint16x4_t r = vld1_u16(ref);
+
+  uint16x4_t abs_diff = vabd_u16(s, r);
+  uint32x4_t sse = vmull_u16(abs_diff, abs_diff);
+
+  src += src_stride;
+  ref += ref_stride;
+
+  while (--height != 0) {
+    s = vld1_u16(src);
+    r = vld1_u16(ref);
+
+    abs_diff = vabd_u16(s, r);
+    sse = vmlal_u16(sse, abs_diff, abs_diff);
+
+    src += src_stride;
+    ref += ref_stride;
+  }
+
+  return horizontal_long_add_uint32x4(sse);
+}
+
+static INLINE int64_t highbd_sse_wxh_neon(const uint16_t *src, int src_stride,
+                                          const uint16_t *ref, int ref_stride,
+                                          int width, int height) {
+  // { 0, 1, 2, 3, 4, 5, 6, 7 }
+  uint16x8_t k01234567 = vmovl_u8(vcreate_u8(0x0706050403020100));
+  uint16x8_t remainder_mask = vcltq_u16(k01234567, vdupq_n_u16(width & 7));
+  uint64_t sse = 0;
+
+  do {
+    int w = width;
+    int offset = 0;
+
+    do {
+      uint16x8_t s = vld1q_u16(src + offset);
+      uint16x8_t r = vld1q_u16(ref + offset);
+      uint16x8_t abs_diff;
+      uint16x4_t abs_diff_lo;
+      uint16x4_t abs_diff_hi;
+      uint32x4_t sse_u32;
+
+      if (w < 8) {
+        // Mask out-of-range elements.
+        s = vandq_u16(s, remainder_mask);
+        r = vandq_u16(r, remainder_mask);
+      }
+
+      abs_diff = vabdq_u16(s, r);
+      abs_diff_lo = vget_low_u16(abs_diff);
+      abs_diff_hi = vget_high_u16(abs_diff);
+
+      sse_u32 = vmull_u16(abs_diff_lo, abs_diff_lo);
+      sse_u32 = vmlal_u16(sse_u32, abs_diff_hi, abs_diff_hi);
+
+      sse += horizontal_long_add_uint32x4(sse_u32);
+
+      offset += 8;
+      w -= 8;
+    } while (w > 0);
+
+    src += src_stride;
+    ref += ref_stride;
+  } while (--height != 0);
+
+  return sse;
+}
+
+int64_t vpx_highbd_sse_neon(const uint8_t *src8, int src_stride,
+                            const uint8_t *ref8, int ref_stride, int width,
+                            int height) {
+  uint16_t *src = CONVERT_TO_SHORTPTR(src8);
+  uint16_t *ref = CONVERT_TO_SHORTPTR(ref8);
+
+  switch (width) {
+    case 4:
+      return highbd_sse_4xh_neon(src, src_stride, ref, ref_stride, height);
+    case 8:
+      return highbd_sse_8xh_neon(src, src_stride, ref, ref_stride, height);
+    case 16:
+      return highbd_sse_16xh_neon(src, src_stride, ref, ref_stride, height);
+    case 32:
+      return highbd_sse_32xh_neon(src, src_stride, ref, ref_stride, height);
+    case 64:
+      return highbd_sse_64xh_neon(src, src_stride, ref, ref_stride, height);
+    case 128:
+      return highbd_sse_128xh_neon(src, src_stride, ref, ref_stride, height);
+    default:
+      return highbd_sse_wxh_neon(src, src_stride, ref, ref_stride, width,
+                                 height);
+  }
+}
diff --git a/vpx_dsp/arm/sse_neon.c b/vpx_dsp/arm/sse_neon.c
new file mode 100644 (file)
index 0000000..0b4a6e5
--- /dev/null
@@ -0,0 +1,210 @@
+/*
+ *  Copyright (c) 2023 The WebM project authors. All Rights Reserved.
+ *
+ *  Use of this source code is governed by a BSD-style license
+ *  that can be found in the LICENSE file in the root of the source
+ *  tree. An additional intellectual property rights grant can be found
+ *  in the file PATENTS.  All contributing project authors may
+ *  be found in the AUTHORS file in the root of the source tree.
+ */
+
+#include <arm_neon.h>
+
+#include "./vpx_dsp_rtcd.h"
+#include "vpx_dsp/arm/mem_neon.h"
+#include "vpx_dsp/arm/sum_neon.h"
+
+static INLINE void sse_16x1_neon(const uint8_t *src, const uint8_t *ref,
+                                 uint32x4_t *sse) {
+  uint8x16_t s = vld1q_u8(src);
+  uint8x16_t r = vld1q_u8(ref);
+
+  uint8x16_t abs_diff = vabdq_u8(s, r);
+  uint8x8_t abs_diff_lo = vget_low_u8(abs_diff);
+  uint8x8_t abs_diff_hi = vget_high_u8(abs_diff);
+
+  *sse = vpadalq_u16(*sse, vmull_u8(abs_diff_lo, abs_diff_lo));
+  *sse = vpadalq_u16(*sse, vmull_u8(abs_diff_hi, abs_diff_hi));
+}
+
+static INLINE void sse_8x1_neon(const uint8_t *src, const uint8_t *ref,
+                                uint32x4_t *sse) {
+  uint8x8_t s = vld1_u8(src);
+  uint8x8_t r = vld1_u8(ref);
+
+  uint8x8_t abs_diff = vabd_u8(s, r);
+
+  *sse = vpadalq_u16(*sse, vmull_u8(abs_diff, abs_diff));
+}
+
+static INLINE void sse_4x2_neon(const uint8_t *src, int src_stride,
+                                const uint8_t *ref, int ref_stride,
+                                uint32x4_t *sse) {
+  uint8x8_t s = load_unaligned_u8(src, src_stride);
+  uint8x8_t r = load_unaligned_u8(ref, ref_stride);
+
+  uint8x8_t abs_diff = vabd_u8(s, r);
+
+  *sse = vpadalq_u16(*sse, vmull_u8(abs_diff, abs_diff));
+}
+
+static INLINE uint32_t sse_wxh_neon(const uint8_t *src, int src_stride,
+                                    const uint8_t *ref, int ref_stride,
+                                    int width, int height) {
+  uint32x4_t sse = vdupq_n_u32(0);
+
+  if ((width & 0x07) && ((width & 0x07) < 5)) {
+    int i = height;
+    do {
+      int j = 0;
+      do {
+        sse_8x1_neon(src + j, ref + j, &sse);
+        sse_8x1_neon(src + j + src_stride, ref + j + ref_stride, &sse);
+        j += 8;
+      } while (j + 4 < width);
+
+      sse_4x2_neon(src + j, src_stride, ref + j, ref_stride, &sse);
+      src += 2 * src_stride;
+      ref += 2 * ref_stride;
+      i -= 2;
+    } while (i != 0);
+  } else {
+    int i = height;
+    do {
+      int j = 0;
+      do {
+        sse_8x1_neon(src + j, ref + j, &sse);
+        j += 8;
+      } while (j < width);
+
+      src += src_stride;
+      ref += ref_stride;
+    } while (--i != 0);
+  }
+  return horizontal_add_uint32x4(sse);
+}
+
+static INLINE uint32_t sse_128xh_neon(const uint8_t *src, int src_stride,
+                                      const uint8_t *ref, int ref_stride,
+                                      int height) {
+  uint32x4_t sse[2] = { vdupq_n_u32(0), vdupq_n_u32(0) };
+
+  int i = height;
+  do {
+    sse_16x1_neon(src, ref, &sse[0]);
+    sse_16x1_neon(src + 16, ref + 16, &sse[1]);
+    sse_16x1_neon(src + 32, ref + 32, &sse[0]);
+    sse_16x1_neon(src + 48, ref + 48, &sse[1]);
+    sse_16x1_neon(src + 64, ref + 64, &sse[0]);
+    sse_16x1_neon(src + 80, ref + 80, &sse[1]);
+    sse_16x1_neon(src + 96, ref + 96, &sse[0]);
+    sse_16x1_neon(src + 112, ref + 112, &sse[1]);
+
+    src += src_stride;
+    ref += ref_stride;
+  } while (--i != 0);
+
+  return horizontal_add_uint32x4(vaddq_u32(sse[0], sse[1]));
+}
+
+static INLINE uint32_t sse_64xh_neon(const uint8_t *src, int src_stride,
+                                     const uint8_t *ref, int ref_stride,
+                                     int height) {
+  uint32x4_t sse[2] = { vdupq_n_u32(0), vdupq_n_u32(0) };
+
+  int i = height;
+  do {
+    sse_16x1_neon(src, ref, &sse[0]);
+    sse_16x1_neon(src + 16, ref + 16, &sse[1]);
+    sse_16x1_neon(src + 32, ref + 32, &sse[0]);
+    sse_16x1_neon(src + 48, ref + 48, &sse[1]);
+
+    src += src_stride;
+    ref += ref_stride;
+  } while (--i != 0);
+
+  return horizontal_add_uint32x4(vaddq_u32(sse[0], sse[1]));
+}
+
+static INLINE uint32_t sse_32xh_neon(const uint8_t *src, int src_stride,
+                                     const uint8_t *ref, int ref_stride,
+                                     int height) {
+  uint32x4_t sse[2] = { vdupq_n_u32(0), vdupq_n_u32(0) };
+
+  int i = height;
+  do {
+    sse_16x1_neon(src, ref, &sse[0]);
+    sse_16x1_neon(src + 16, ref + 16, &sse[1]);
+
+    src += src_stride;
+    ref += ref_stride;
+  } while (--i != 0);
+
+  return horizontal_add_uint32x4(vaddq_u32(sse[0], sse[1]));
+}
+
+static INLINE uint32_t sse_16xh_neon(const uint8_t *src, int src_stride,
+                                     const uint8_t *ref, int ref_stride,
+                                     int height) {
+  uint32x4_t sse[2] = { vdupq_n_u32(0), vdupq_n_u32(0) };
+
+  int i = height;
+  do {
+    sse_16x1_neon(src, ref, &sse[0]);
+    src += src_stride;
+    ref += ref_stride;
+    sse_16x1_neon(src, ref, &sse[1]);
+    src += src_stride;
+    ref += ref_stride;
+    i -= 2;
+  } while (i != 0);
+
+  return horizontal_add_uint32x4(vaddq_u32(sse[0], sse[1]));
+}
+
+static INLINE uint32_t sse_8xh_neon(const uint8_t *src, int src_stride,
+                                    const uint8_t *ref, int ref_stride,
+                                    int height) {
+  uint32x4_t sse = vdupq_n_u32(0);
+
+  int i = height;
+  do {
+    sse_8x1_neon(src, ref, &sse);
+
+    src += src_stride;
+    ref += ref_stride;
+  } while (--i != 0);
+
+  return horizontal_add_uint32x4(sse);
+}
+
+static INLINE uint32_t sse_4xh_neon(const uint8_t *src, int src_stride,
+                                    const uint8_t *ref, int ref_stride,
+                                    int height) {
+  uint32x4_t sse = vdupq_n_u32(0);
+
+  int i = height;
+  do {
+    sse_4x2_neon(src, src_stride, ref, ref_stride, &sse);
+
+    src += 2 * src_stride;
+    ref += 2 * ref_stride;
+    i -= 2;
+  } while (i != 0);
+
+  return horizontal_add_uint32x4(sse);
+}
+
+int64_t vpx_sse_neon(const uint8_t *src, int src_stride, const uint8_t *ref,
+                     int ref_stride, int width, int height) {
+  switch (width) {
+    case 4: return sse_4xh_neon(src, src_stride, ref, ref_stride, height);
+    case 8: return sse_8xh_neon(src, src_stride, ref, ref_stride, height);
+    case 16: return sse_16xh_neon(src, src_stride, ref, ref_stride, height);
+    case 32: return sse_32xh_neon(src, src_stride, ref, ref_stride, height);
+    case 64: return sse_64xh_neon(src, src_stride, ref, ref_stride, height);
+    case 128: return sse_128xh_neon(src, src_stride, ref, ref_stride, height);
+    default:
+      return sse_wxh_neon(src, src_stride, ref, ref_stride, width, height);
+  }
+}
diff --git a/vpx_dsp/arm/sse_neon_dotprod.c b/vpx_dsp/arm/sse_neon_dotprod.c
new file mode 100644 (file)
index 0000000..0f11b7c
--- /dev/null
@@ -0,0 +1,223 @@
+/*
+ *  Copyright (c) 2023 The WebM project authors. All Rights Reserved.
+ *
+ *  Use of this source code is governed by a BSD-style license
+ *  that can be found in the LICENSE file in the root of the source
+ *  tree. An additional intellectual property rights grant can be found
+ *  in the file PATENTS.  All contributing project authors may
+ *  be found in the AUTHORS file in the root of the source tree.
+ */
+
+#include <arm_neon.h>
+
+#include "./vpx_dsp_rtcd.h"
+#include "vpx_dsp/arm/mem_neon.h"
+#include "vpx_dsp/arm/sum_neon.h"
+
+static INLINE void sse_16x1_neon_dotprod(const uint8_t *src, const uint8_t *ref,
+                                         uint32x4_t *sse) {
+  uint8x16_t s = vld1q_u8(src);
+  uint8x16_t r = vld1q_u8(ref);
+
+  uint8x16_t abs_diff = vabdq_u8(s, r);
+
+  *sse = vdotq_u32(*sse, abs_diff, abs_diff);
+}
+
+static INLINE void sse_8x1_neon_dotprod(const uint8_t *src, const uint8_t *ref,
+                                        uint32x2_t *sse) {
+  uint8x8_t s = vld1_u8(src);
+  uint8x8_t r = vld1_u8(ref);
+
+  uint8x8_t abs_diff = vabd_u8(s, r);
+
+  *sse = vdot_u32(*sse, abs_diff, abs_diff);
+}
+
+static INLINE void sse_4x2_neon_dotprod(const uint8_t *src, int src_stride,
+                                        const uint8_t *ref, int ref_stride,
+                                        uint32x2_t *sse) {
+  uint8x8_t s = load_unaligned_u8(src, src_stride);
+  uint8x8_t r = load_unaligned_u8(ref, ref_stride);
+
+  uint8x8_t abs_diff = vabd_u8(s, r);
+
+  *sse = vdot_u32(*sse, abs_diff, abs_diff);
+}
+
+static INLINE uint32_t sse_wxh_neon_dotprod(const uint8_t *src, int src_stride,
+                                            const uint8_t *ref, int ref_stride,
+                                            int width, int height) {
+  uint32x2_t sse[2] = { vdup_n_u32(0), vdup_n_u32(0) };
+
+  if ((width & 0x07) && ((width & 0x07) < 5)) {
+    int i = height;
+    do {
+      int j = 0;
+      do {
+        sse_8x1_neon_dotprod(src + j, ref + j, &sse[0]);
+        sse_8x1_neon_dotprod(src + j + src_stride, ref + j + ref_stride,
+                             &sse[1]);
+        j += 8;
+      } while (j + 4 < width);
+
+      sse_4x2_neon_dotprod(src + j, src_stride, ref + j, ref_stride, &sse[0]);
+      src += 2 * src_stride;
+      ref += 2 * ref_stride;
+      i -= 2;
+    } while (i != 0);
+  } else {
+    int i = height;
+    do {
+      int j = 0;
+      do {
+        sse_8x1_neon_dotprod(src + j, ref + j, &sse[0]);
+        sse_8x1_neon_dotprod(src + j + src_stride, ref + j + ref_stride,
+                             &sse[1]);
+        j += 8;
+      } while (j < width);
+
+      src += 2 * src_stride;
+      ref += 2 * ref_stride;
+      i -= 2;
+    } while (i != 0);
+  }
+  return horizontal_add_uint32x4(vcombine_u32(sse[0], sse[1]));
+}
+
+static INLINE uint32_t sse_128xh_neon_dotprod(const uint8_t *src,
+                                              int src_stride,
+                                              const uint8_t *ref,
+                                              int ref_stride, int height) {
+  uint32x4_t sse[2] = { vdupq_n_u32(0), vdupq_n_u32(0) };
+
+  int i = height;
+  do {
+    sse_16x1_neon_dotprod(src, ref, &sse[0]);
+    sse_16x1_neon_dotprod(src + 16, ref + 16, &sse[1]);
+    sse_16x1_neon_dotprod(src + 32, ref + 32, &sse[0]);
+    sse_16x1_neon_dotprod(src + 48, ref + 48, &sse[1]);
+    sse_16x1_neon_dotprod(src + 64, ref + 64, &sse[0]);
+    sse_16x1_neon_dotprod(src + 80, ref + 80, &sse[1]);
+    sse_16x1_neon_dotprod(src + 96, ref + 96, &sse[0]);
+    sse_16x1_neon_dotprod(src + 112, ref + 112, &sse[1]);
+
+    src += src_stride;
+    ref += ref_stride;
+  } while (--i != 0);
+
+  return horizontal_add_uint32x4(vaddq_u32(sse[0], sse[1]));
+}
+
+static INLINE uint32_t sse_64xh_neon_dotprod(const uint8_t *src, int src_stride,
+                                             const uint8_t *ref, int ref_stride,
+                                             int height) {
+  uint32x4_t sse[2] = { vdupq_n_u32(0), vdupq_n_u32(0) };
+
+  int i = height;
+  do {
+    sse_16x1_neon_dotprod(src, ref, &sse[0]);
+    sse_16x1_neon_dotprod(src + 16, ref + 16, &sse[1]);
+    sse_16x1_neon_dotprod(src + 32, ref + 32, &sse[0]);
+    sse_16x1_neon_dotprod(src + 48, ref + 48, &sse[1]);
+
+    src += src_stride;
+    ref += ref_stride;
+  } while (--i != 0);
+
+  return horizontal_add_uint32x4(vaddq_u32(sse[0], sse[1]));
+}
+
+static INLINE uint32_t sse_32xh_neon_dotprod(const uint8_t *src, int src_stride,
+                                             const uint8_t *ref, int ref_stride,
+                                             int height) {
+  uint32x4_t sse[2] = { vdupq_n_u32(0), vdupq_n_u32(0) };
+
+  int i = height;
+  do {
+    sse_16x1_neon_dotprod(src, ref, &sse[0]);
+    sse_16x1_neon_dotprod(src + 16, ref + 16, &sse[1]);
+
+    src += src_stride;
+    ref += ref_stride;
+  } while (--i != 0);
+
+  return horizontal_add_uint32x4(vaddq_u32(sse[0], sse[1]));
+}
+
+static INLINE uint32_t sse_16xh_neon_dotprod(const uint8_t *src, int src_stride,
+                                             const uint8_t *ref, int ref_stride,
+                                             int height) {
+  uint32x4_t sse[2] = { vdupq_n_u32(0), vdupq_n_u32(0) };
+
+  int i = height;
+  do {
+    sse_16x1_neon_dotprod(src, ref, &sse[0]);
+    src += src_stride;
+    ref += ref_stride;
+    sse_16x1_neon_dotprod(src, ref, &sse[1]);
+    src += src_stride;
+    ref += ref_stride;
+    i -= 2;
+  } while (i != 0);
+
+  return horizontal_add_uint32x4(vaddq_u32(sse[0], sse[1]));
+}
+
+static INLINE uint32_t sse_8xh_neon_dotprod(const uint8_t *src, int src_stride,
+                                            const uint8_t *ref, int ref_stride,
+                                            int height) {
+  uint32x2_t sse[2] = { vdup_n_u32(0), vdup_n_u32(0) };
+
+  int i = height;
+  do {
+    sse_8x1_neon_dotprod(src, ref, &sse[0]);
+    src += src_stride;
+    ref += ref_stride;
+    sse_8x1_neon_dotprod(src, ref, &sse[1]);
+    src += src_stride;
+    ref += ref_stride;
+    i -= 2;
+  } while (i != 0);
+
+  return horizontal_add_uint32x4(vcombine_u32(sse[0], sse[1]));
+}
+
+static INLINE uint32_t sse_4xh_neon_dotprod(const uint8_t *src, int src_stride,
+                                            const uint8_t *ref, int ref_stride,
+                                            int height) {
+  uint32x2_t sse = vdup_n_u32(0);
+
+  int i = height;
+  do {
+    sse_4x2_neon_dotprod(src, src_stride, ref, ref_stride, &sse);
+
+    src += 2 * src_stride;
+    ref += 2 * ref_stride;
+    i -= 2;
+  } while (i != 0);
+
+  return horizontal_add_uint32x2(sse);
+}
+
+int64_t vpx_sse_neon_dotprod(const uint8_t *src, int src_stride,
+                             const uint8_t *ref, int ref_stride, int width,
+                             int height) {
+  switch (width) {
+    case 4:
+      return sse_4xh_neon_dotprod(src, src_stride, ref, ref_stride, height);
+    case 8:
+      return sse_8xh_neon_dotprod(src, src_stride, ref, ref_stride, height);
+    case 16:
+      return sse_16xh_neon_dotprod(src, src_stride, ref, ref_stride, height);
+    case 32:
+      return sse_32xh_neon_dotprod(src, src_stride, ref, ref_stride, height);
+    case 64:
+      return sse_64xh_neon_dotprod(src, src_stride, ref, ref_stride, height);
+    case 128:
+      return sse_128xh_neon_dotprod(src, src_stride, ref, ref_stride, height);
+    default:
+      return sse_wxh_neon_dotprod(src, src_stride, ref, ref_stride, width,
+                                  height);
+  }
+}
index 75c170d..11821dc 100644 (file)
@@ -221,4 +221,55 @@ static INLINE uint64_t horizontal_add_uint64x2(const uint64x2_t a) {
 #endif
 }
 
+static INLINE uint64_t horizontal_long_add_uint32x4_x2(const uint32x4_t a[2]) {
+  return horizontal_long_add_uint32x4(a[0]) +
+         horizontal_long_add_uint32x4(a[1]);
+}
+
+static INLINE uint64_t horizontal_long_add_uint32x4_x4(const uint32x4_t a[4]) {
+  uint64x2_t sum = vpaddlq_u32(a[0]);
+  sum = vpadalq_u32(sum, a[1]);
+  sum = vpadalq_u32(sum, a[2]);
+  sum = vpadalq_u32(sum, a[3]);
+
+  return horizontal_add_uint64x2(sum);
+}
+
+static INLINE uint64_t horizontal_long_add_uint32x4_x8(const uint32x4_t a[8]) {
+  uint64x2_t sum[2];
+  sum[0] = vpaddlq_u32(a[0]);
+  sum[1] = vpaddlq_u32(a[1]);
+  sum[0] = vpadalq_u32(sum[0], a[2]);
+  sum[1] = vpadalq_u32(sum[1], a[3]);
+  sum[0] = vpadalq_u32(sum[0], a[4]);
+  sum[1] = vpadalq_u32(sum[1], a[5]);
+  sum[0] = vpadalq_u32(sum[0], a[6]);
+  sum[1] = vpadalq_u32(sum[1], a[7]);
+
+  return horizontal_add_uint64x2(vaddq_u64(sum[0], sum[1]));
+}
+
+static INLINE uint64_t
+horizontal_long_add_uint32x4_x16(const uint32x4_t a[16]) {
+  uint64x2_t sum[2];
+  sum[0] = vpaddlq_u32(a[0]);
+  sum[1] = vpaddlq_u32(a[1]);
+  sum[0] = vpadalq_u32(sum[0], a[2]);
+  sum[1] = vpadalq_u32(sum[1], a[3]);
+  sum[0] = vpadalq_u32(sum[0], a[4]);
+  sum[1] = vpadalq_u32(sum[1], a[5]);
+  sum[0] = vpadalq_u32(sum[0], a[6]);
+  sum[1] = vpadalq_u32(sum[1], a[7]);
+  sum[0] = vpadalq_u32(sum[0], a[8]);
+  sum[1] = vpadalq_u32(sum[1], a[9]);
+  sum[0] = vpadalq_u32(sum[0], a[10]);
+  sum[1] = vpadalq_u32(sum[1], a[11]);
+  sum[0] = vpadalq_u32(sum[0], a[12]);
+  sum[1] = vpadalq_u32(sum[1], a[13]);
+  sum[0] = vpadalq_u32(sum[0], a[14]);
+  sum[1] = vpadalq_u32(sum[1], a[15]);
+
+  return horizontal_add_uint64x2(vaddq_u64(sum[0], sum[1]));
+}
+
 #endif  // VPX_VPX_DSP_ARM_SUM_NEON_H_
diff --git a/vpx_dsp/sse.c b/vpx_dsp/sse.c
new file mode 100644 (file)
index 0000000..6cb4b70
--- /dev/null
@@ -0,0 +1,58 @@
+/*
+ *  Copyright (c) 2023 The WebM project authors. All Rights Reserved.
+ *
+ *  Use of this source code is governed by a BSD-style license
+ *  that can be found in the LICENSE file in the root of the source
+ *  tree. An additional intellectual property rights grant can be found
+ *  in the file PATENTS.  All contributing project authors may
+ *  be found in the AUTHORS file in the root of the source tree.
+ */
+
+/*
+ * Sum the square of the difference between every corresponding element of the
+ * buffers.
+ */
+
+#include <stdlib.h>
+
+#include "./vpx_config.h"
+#include "./vpx_dsp_rtcd.h"
+
+#include "vpx/vpx_integer.h"
+
+int64_t vpx_sse_c(const uint8_t *a, int a_stride, const uint8_t *b,
+                  int b_stride, int width, int height) {
+  int y, x;
+  int64_t sse = 0;
+
+  for (y = 0; y < height; y++) {
+    for (x = 0; x < width; x++) {
+      const int32_t diff = abs(a[x] - b[x]);
+      sse += diff * diff;
+    }
+
+    a += a_stride;
+    b += b_stride;
+  }
+  return sse;
+}
+
+#if CONFIG_VP9_HIGHBITDEPTH
+int64_t vpx_highbd_sse_c(const uint8_t *a8, int a_stride, const uint8_t *b8,
+                         int b_stride, int width, int height) {
+  int y, x;
+  int64_t sse = 0;
+  uint16_t *a = CONVERT_TO_SHORTPTR(a8);
+  uint16_t *b = CONVERT_TO_SHORTPTR(b8);
+  for (y = 0; y < height; y++) {
+    for (x = 0; x < width; x++) {
+      const int32_t diff = (int32_t)(a[x]) - (int32_t)(b[x]);
+      sse += diff * diff;
+    }
+
+    a += a_stride;
+    b += b_stride;
+  }
+  return sse;
+}
+#endif
index 84fd969..93abf39 100644 (file)
@@ -31,10 +31,15 @@ DSP_SRCS-yes += bitwriter_buffer.c
 DSP_SRCS-yes += bitwriter_buffer.h
 DSP_SRCS-yes += psnr.c
 DSP_SRCS-yes += psnr.h
+DSP_SRCS-yes += sse.c
 DSP_SRCS-$(CONFIG_INTERNAL_STATS) += ssim.c
 DSP_SRCS-$(CONFIG_INTERNAL_STATS) += ssim.h
 DSP_SRCS-$(CONFIG_INTERNAL_STATS) += psnrhvs.c
 DSP_SRCS-$(CONFIG_INTERNAL_STATS) += fastssim.c
+DSP_SRCS-$(HAVE_NEON) += arm/sse_neon.c
+DSP_SRCS-$(HAVE_NEON_DOTPROD) += arm/sse_neon_dotprod.c
+DSP_SRCS-$(HAVE_SSE4_1) += x86/sse_sse4.c
+DSP_SRCS-$(HAVE_AVX2) += x86/sse_avx2.c
 endif
 
 ifeq ($(CONFIG_DECODERS),yes)
@@ -447,6 +452,7 @@ DSP_SRCS-$(HAVE_SSE2)   += x86/highbd_variance_sse2.c
 DSP_SRCS-$(HAVE_SSE2)   += x86/highbd_variance_impl_sse2.asm
 DSP_SRCS-$(HAVE_SSE2)   += x86/highbd_subpel_variance_impl_sse2.asm
 DSP_SRCS-$(HAVE_NEON)   += arm/highbd_avg_pred_neon.c
+DSP_SRCS-$(HAVE_NEON)   += arm/highbd_sse_neon.c
 DSP_SRCS-$(HAVE_NEON)   += arm/highbd_variance_neon.c
 DSP_SRCS-$(HAVE_NEON)   += arm/highbd_subpel_variance_neon.c
 endif  # CONFIG_VP9_HIGHBITDEPTH
index c9cdc28..e9d63f6 100644 (file)
@@ -744,6 +744,9 @@ if (vpx_config("CONFIG_ENCODERS") eq "yes") {
 add_proto qw/void vpx_subtract_block/, "int rows, int cols, int16_t *diff_ptr, ptrdiff_t diff_stride, const uint8_t *src_ptr, ptrdiff_t src_stride, const uint8_t *pred_ptr, ptrdiff_t pred_stride";
 specialize qw/vpx_subtract_block neon msa mmi sse2 avx2 vsx lsx/;
 
+add_proto qw/int64_t/, "vpx_sse", "const uint8_t *a, int a_stride, const uint8_t *b,int b_stride, int width, int height";
+specialize qw/vpx_sse sse4_1 avx2 neon neon_dotprod/;
+
 #
 # Single block SAD
 #
@@ -1026,6 +1029,9 @@ if (vpx_config("CONFIG_VP9_HIGHBITDEPTH") eq "yes") {
   add_proto qw/void vpx_highbd_subtract_block/, "int rows, int cols, int16_t *diff_ptr, ptrdiff_t diff_stride, const uint8_t *src8_ptr, ptrdiff_t src_stride, const uint8_t *pred8_ptr, ptrdiff_t pred_stride, int bd";
   specialize qw/vpx_highbd_subtract_block neon avx2/;
 
+  add_proto qw/int64_t/, "vpx_highbd_sse", "const uint8_t *a8, int a_stride, const uint8_t *b8,int b_stride, int width, int height";
+  specialize qw/vpx_highbd_sse sse4_1 avx2 neon/;
+
   #
   # Single block SAD
   #
diff --git a/vpx_dsp/x86/sse_avx2.c b/vpx_dsp/x86/sse_avx2.c
new file mode 100644 (file)
index 0000000..9754467
--- /dev/null
@@ -0,0 +1,401 @@
+/*
+ *  Copyright (c) 2023 The WebM project authors. All Rights Reserved.
+ *
+ *  Use of this source code is governed by a BSD-style license
+ *  that can be found in the LICENSE file in the root of the source
+ *  tree. An additional intellectual property rights grant can be found
+ *  in the file PATENTS.  All contributing project authors may
+ *  be found in the AUTHORS file in the root of the source tree.
+ */
+
+#include <smmintrin.h>
+#include <immintrin.h>
+
+#include "./vpx_config.h"
+#include "./vpx_dsp_rtcd.h"
+
+#include "vpx_ports/mem.h"
+#include "vpx_dsp/x86/mem_sse2.h"
+
+static INLINE void sse_w32_avx2(__m256i *sum, const uint8_t *a,
+                                const uint8_t *b) {
+  const __m256i v_a0 = _mm256_loadu_si256((const __m256i *)a);
+  const __m256i v_b0 = _mm256_loadu_si256((const __m256i *)b);
+  const __m256i zero = _mm256_setzero_si256();
+  const __m256i v_a00_w = _mm256_unpacklo_epi8(v_a0, zero);
+  const __m256i v_a01_w = _mm256_unpackhi_epi8(v_a0, zero);
+  const __m256i v_b00_w = _mm256_unpacklo_epi8(v_b0, zero);
+  const __m256i v_b01_w = _mm256_unpackhi_epi8(v_b0, zero);
+  const __m256i v_d00_w = _mm256_sub_epi16(v_a00_w, v_b00_w);
+  const __m256i v_d01_w = _mm256_sub_epi16(v_a01_w, v_b01_w);
+  *sum = _mm256_add_epi32(*sum, _mm256_madd_epi16(v_d00_w, v_d00_w));
+  *sum = _mm256_add_epi32(*sum, _mm256_madd_epi16(v_d01_w, v_d01_w));
+}
+
+static INLINE int64_t summary_all_avx2(const __m256i *sum_all) {
+  int64_t sum;
+  __m256i zero = _mm256_setzero_si256();
+  const __m256i sum0_4x64 = _mm256_unpacklo_epi32(*sum_all, zero);
+  const __m256i sum1_4x64 = _mm256_unpackhi_epi32(*sum_all, zero);
+  const __m256i sum_4x64 = _mm256_add_epi64(sum0_4x64, sum1_4x64);
+  const __m128i sum_2x64 = _mm_add_epi64(_mm256_castsi256_si128(sum_4x64),
+                                         _mm256_extracti128_si256(sum_4x64, 1));
+  const __m128i sum_1x64 = _mm_add_epi64(sum_2x64, _mm_srli_si128(sum_2x64, 8));
+  _mm_storel_epi64((__m128i *)&sum, sum_1x64);
+  return sum;
+}
+
+#if CONFIG_VP9_HIGHBITDEPTH
+static INLINE void summary_32_avx2(const __m256i *sum32, __m256i *sum) {
+  const __m256i sum0_4x64 =
+      _mm256_cvtepu32_epi64(_mm256_castsi256_si128(*sum32));
+  const __m256i sum1_4x64 =
+      _mm256_cvtepu32_epi64(_mm256_extracti128_si256(*sum32, 1));
+  const __m256i sum_4x64 = _mm256_add_epi64(sum0_4x64, sum1_4x64);
+  *sum = _mm256_add_epi64(*sum, sum_4x64);
+}
+
+static INLINE int64_t summary_4x64_avx2(const __m256i sum_4x64) {
+  int64_t sum;
+  const __m128i sum_2x64 = _mm_add_epi64(_mm256_castsi256_si128(sum_4x64),
+                                         _mm256_extracti128_si256(sum_4x64, 1));
+  const __m128i sum_1x64 = _mm_add_epi64(sum_2x64, _mm_srli_si128(sum_2x64, 8));
+
+  _mm_storel_epi64((__m128i *)&sum, sum_1x64);
+  return sum;
+}
+#endif
+
+static INLINE void sse_w4x4_avx2(const uint8_t *a, int a_stride,
+                                 const uint8_t *b, int b_stride, __m256i *sum) {
+  const __m128i v_a0 = load_unaligned_u32(a);
+  const __m128i v_a1 = load_unaligned_u32(a + a_stride);
+  const __m128i v_a2 = load_unaligned_u32(a + a_stride * 2);
+  const __m128i v_a3 = load_unaligned_u32(a + a_stride * 3);
+  const __m128i v_b0 = load_unaligned_u32(b);
+  const __m128i v_b1 = load_unaligned_u32(b + b_stride);
+  const __m128i v_b2 = load_unaligned_u32(b + b_stride * 2);
+  const __m128i v_b3 = load_unaligned_u32(b + b_stride * 3);
+  const __m128i v_a0123 = _mm_unpacklo_epi64(_mm_unpacklo_epi32(v_a0, v_a1),
+                                             _mm_unpacklo_epi32(v_a2, v_a3));
+  const __m128i v_b0123 = _mm_unpacklo_epi64(_mm_unpacklo_epi32(v_b0, v_b1),
+                                             _mm_unpacklo_epi32(v_b2, v_b3));
+  const __m256i v_a_w = _mm256_cvtepu8_epi16(v_a0123);
+  const __m256i v_b_w = _mm256_cvtepu8_epi16(v_b0123);
+  const __m256i v_d_w = _mm256_sub_epi16(v_a_w, v_b_w);
+  *sum = _mm256_add_epi32(*sum, _mm256_madd_epi16(v_d_w, v_d_w));
+}
+
+static INLINE void sse_w8x2_avx2(const uint8_t *a, int a_stride,
+                                 const uint8_t *b, int b_stride, __m256i *sum) {
+  const __m128i v_a0 = _mm_loadl_epi64((const __m128i *)a);
+  const __m128i v_a1 = _mm_loadl_epi64((const __m128i *)(a + a_stride));
+  const __m128i v_b0 = _mm_loadl_epi64((const __m128i *)b);
+  const __m128i v_b1 = _mm_loadl_epi64((const __m128i *)(b + b_stride));
+  const __m256i v_a_w = _mm256_cvtepu8_epi16(_mm_unpacklo_epi64(v_a0, v_a1));
+  const __m256i v_b_w = _mm256_cvtepu8_epi16(_mm_unpacklo_epi64(v_b0, v_b1));
+  const __m256i v_d_w = _mm256_sub_epi16(v_a_w, v_b_w);
+  *sum = _mm256_add_epi32(*sum, _mm256_madd_epi16(v_d_w, v_d_w));
+}
+
+int64_t vpx_sse_avx2(const uint8_t *a, int a_stride, const uint8_t *b,
+                     int b_stride, int width, int height) {
+  int32_t y = 0;
+  int64_t sse = 0;
+  __m256i sum = _mm256_setzero_si256();
+  __m256i zero = _mm256_setzero_si256();
+  switch (width) {
+    case 4:
+      do {
+        sse_w4x4_avx2(a, a_stride, b, b_stride, &sum);
+        a += a_stride << 2;
+        b += b_stride << 2;
+        y += 4;
+      } while (y < height);
+      sse = summary_all_avx2(&sum);
+      break;
+    case 8:
+      do {
+        sse_w8x2_avx2(a, a_stride, b, b_stride, &sum);
+        a += a_stride << 1;
+        b += b_stride << 1;
+        y += 2;
+      } while (y < height);
+      sse = summary_all_avx2(&sum);
+      break;
+    case 16:
+      do {
+        const __m128i v_a0 = _mm_loadu_si128((const __m128i *)a);
+        const __m128i v_a1 = _mm_loadu_si128((const __m128i *)(a + a_stride));
+        const __m128i v_b0 = _mm_loadu_si128((const __m128i *)b);
+        const __m128i v_b1 = _mm_loadu_si128((const __m128i *)(b + b_stride));
+        const __m256i v_a =
+            _mm256_insertf128_si256(_mm256_castsi128_si256(v_a0), v_a1, 0x01);
+        const __m256i v_b =
+            _mm256_insertf128_si256(_mm256_castsi128_si256(v_b0), v_b1, 0x01);
+        const __m256i v_al = _mm256_unpacklo_epi8(v_a, zero);
+        const __m256i v_au = _mm256_unpackhi_epi8(v_a, zero);
+        const __m256i v_bl = _mm256_unpacklo_epi8(v_b, zero);
+        const __m256i v_bu = _mm256_unpackhi_epi8(v_b, zero);
+        const __m256i v_asub = _mm256_sub_epi16(v_al, v_bl);
+        const __m256i v_bsub = _mm256_sub_epi16(v_au, v_bu);
+        const __m256i temp =
+            _mm256_add_epi32(_mm256_madd_epi16(v_asub, v_asub),
+                             _mm256_madd_epi16(v_bsub, v_bsub));
+        sum = _mm256_add_epi32(sum, temp);
+        a += a_stride << 1;
+        b += b_stride << 1;
+        y += 2;
+      } while (y < height);
+      sse = summary_all_avx2(&sum);
+      break;
+    case 32:
+      do {
+        sse_w32_avx2(&sum, a, b);
+        a += a_stride;
+        b += b_stride;
+        y += 1;
+      } while (y < height);
+      sse = summary_all_avx2(&sum);
+      break;
+    case 64:
+      do {
+        sse_w32_avx2(&sum, a, b);
+        sse_w32_avx2(&sum, a + 32, b + 32);
+        a += a_stride;
+        b += b_stride;
+        y += 1;
+      } while (y < height);
+      sse = summary_all_avx2(&sum);
+      break;
+    case 128:
+      do {
+        sse_w32_avx2(&sum, a, b);
+        sse_w32_avx2(&sum, a + 32, b + 32);
+        sse_w32_avx2(&sum, a + 64, b + 64);
+        sse_w32_avx2(&sum, a + 96, b + 96);
+        a += a_stride;
+        b += b_stride;
+        y += 1;
+      } while (y < height);
+      sse = summary_all_avx2(&sum);
+      break;
+    default:
+      if ((width & 0x07) == 0) {
+        do {
+          int i = 0;
+          do {
+            sse_w8x2_avx2(a + i, a_stride, b + i, b_stride, &sum);
+            i += 8;
+          } while (i < width);
+          a += a_stride << 1;
+          b += b_stride << 1;
+          y += 2;
+        } while (y < height);
+      } else {
+        do {
+          int i = 0;
+          do {
+            const uint8_t *a2;
+            const uint8_t *b2;
+            sse_w8x2_avx2(a + i, a_stride, b + i, b_stride, &sum);
+            a2 = a + i + (a_stride << 1);
+            b2 = b + i + (b_stride << 1);
+            sse_w8x2_avx2(a2, a_stride, b2, b_stride, &sum);
+            i += 8;
+          } while (i + 4 < width);
+          sse_w4x4_avx2(a + i, a_stride, b + i, b_stride, &sum);
+          a += a_stride << 2;
+          b += b_stride << 2;
+          y += 4;
+        } while (y < height);
+      }
+      sse = summary_all_avx2(&sum);
+      break;
+  }
+
+  return sse;
+}
+
+#if CONFIG_VP9_HIGHBITDEPTH
+static INLINE void highbd_sse_w16_avx2(__m256i *sum, const uint16_t *a,
+                                       const uint16_t *b) {
+  const __m256i v_a_w = _mm256_loadu_si256((const __m256i *)a);
+  const __m256i v_b_w = _mm256_loadu_si256((const __m256i *)b);
+  const __m256i v_d_w = _mm256_sub_epi16(v_a_w, v_b_w);
+  *sum = _mm256_add_epi32(*sum, _mm256_madd_epi16(v_d_w, v_d_w));
+}
+
+static INLINE void highbd_sse_w4x4_avx2(__m256i *sum, const uint16_t *a,
+                                        int a_stride, const uint16_t *b,
+                                        int b_stride) {
+  const __m128i v_a0 = _mm_loadl_epi64((const __m128i *)a);
+  const __m128i v_a1 = _mm_loadl_epi64((const __m128i *)(a + a_stride));
+  const __m128i v_a2 = _mm_loadl_epi64((const __m128i *)(a + a_stride * 2));
+  const __m128i v_a3 = _mm_loadl_epi64((const __m128i *)(a + a_stride * 3));
+  const __m128i v_b0 = _mm_loadl_epi64((const __m128i *)b);
+  const __m128i v_b1 = _mm_loadl_epi64((const __m128i *)(b + b_stride));
+  const __m128i v_b2 = _mm_loadl_epi64((const __m128i *)(b + b_stride * 2));
+  const __m128i v_b3 = _mm_loadl_epi64((const __m128i *)(b + b_stride * 3));
+  const __m128i v_a_hi = _mm_unpacklo_epi64(v_a0, v_a1);
+  const __m128i v_a_lo = _mm_unpacklo_epi64(v_a2, v_a3);
+  const __m256i v_a_w =
+      _mm256_insertf128_si256(_mm256_castsi128_si256(v_a_lo), v_a_hi, 1);
+  const __m128i v_b_hi = _mm_unpacklo_epi64(v_b0, v_b1);
+  const __m128i v_b_lo = _mm_unpacklo_epi64(v_b2, v_b3);
+  const __m256i v_b_w =
+      _mm256_insertf128_si256(_mm256_castsi128_si256(v_b_lo), v_b_hi, 1);
+  const __m256i v_d_w = _mm256_sub_epi16(v_a_w, v_b_w);
+  *sum = _mm256_add_epi32(*sum, _mm256_madd_epi16(v_d_w, v_d_w));
+}
+
+static INLINE void highbd_sse_w8x2_avx2(__m256i *sum, const uint16_t *a,
+                                        int a_stride, const uint16_t *b,
+                                        int b_stride) {
+  const __m128i v_a_hi = _mm_loadu_si128((const __m128i *)(a + a_stride));
+  const __m128i v_a_lo = _mm_loadu_si128((const __m128i *)a);
+  const __m256i v_a_w =
+      _mm256_insertf128_si256(_mm256_castsi128_si256(v_a_lo), v_a_hi, 1);
+  const __m128i v_b_hi = _mm_loadu_si128((const __m128i *)(b + b_stride));
+  const __m128i v_b_lo = _mm_loadu_si128((const __m128i *)b);
+  const __m256i v_b_w =
+      _mm256_insertf128_si256(_mm256_castsi128_si256(v_b_lo), v_b_hi, 1);
+  const __m256i v_d_w = _mm256_sub_epi16(v_a_w, v_b_w);
+  *sum = _mm256_add_epi32(*sum, _mm256_madd_epi16(v_d_w, v_d_w));
+}
+
+int64_t vpx_highbd_sse_avx2(const uint8_t *a8, int a_stride, const uint8_t *b8,
+                            int b_stride, int width, int height) {
+  int32_t y = 0;
+  int64_t sse = 0;
+  uint16_t *a = CONVERT_TO_SHORTPTR(a8);
+  uint16_t *b = CONVERT_TO_SHORTPTR(b8);
+  __m256i sum = _mm256_setzero_si256();
+  switch (width) {
+    case 4:
+      do {
+        highbd_sse_w4x4_avx2(&sum, a, a_stride, b, b_stride);
+        a += a_stride << 2;
+        b += b_stride << 2;
+        y += 4;
+      } while (y < height);
+      sse = summary_all_avx2(&sum);
+      break;
+    case 8:
+      do {
+        highbd_sse_w8x2_avx2(&sum, a, a_stride, b, b_stride);
+        a += a_stride << 1;
+        b += b_stride << 1;
+        y += 2;
+      } while (y < height);
+      sse = summary_all_avx2(&sum);
+      break;
+    case 16:
+      do {
+        highbd_sse_w16_avx2(&sum, a, b);
+        a += a_stride;
+        b += b_stride;
+        y += 1;
+      } while (y < height);
+      sse = summary_all_avx2(&sum);
+      break;
+    case 32:
+      do {
+        int l = 0;
+        __m256i sum32 = _mm256_setzero_si256();
+        do {
+          highbd_sse_w16_avx2(&sum32, a, b);
+          highbd_sse_w16_avx2(&sum32, a + 16, b + 16);
+          a += a_stride;
+          b += b_stride;
+          l += 1;
+        } while (l < 64 && l < (height - y));
+        summary_32_avx2(&sum32, &sum);
+        y += 64;
+      } while (y < height);
+      sse = summary_4x64_avx2(sum);
+      break;
+    case 64:
+      do {
+        int l = 0;
+        __m256i sum32 = _mm256_setzero_si256();
+        do {
+          highbd_sse_w16_avx2(&sum32, a, b);
+          highbd_sse_w16_avx2(&sum32, a + 16 * 1, b + 16 * 1);
+          highbd_sse_w16_avx2(&sum32, a + 16 * 2, b + 16 * 2);
+          highbd_sse_w16_avx2(&sum32, a + 16 * 3, b + 16 * 3);
+          a += a_stride;
+          b += b_stride;
+          l += 1;
+        } while (l < 32 && l < (height - y));
+        summary_32_avx2(&sum32, &sum);
+        y += 32;
+      } while (y < height);
+      sse = summary_4x64_avx2(sum);
+      break;
+    case 128:
+      do {
+        int l = 0;
+        __m256i sum32 = _mm256_setzero_si256();
+        do {
+          highbd_sse_w16_avx2(&sum32, a, b);
+          highbd_sse_w16_avx2(&sum32, a + 16 * 1, b + 16 * 1);
+          highbd_sse_w16_avx2(&sum32, a + 16 * 2, b + 16 * 2);
+          highbd_sse_w16_avx2(&sum32, a + 16 * 3, b + 16 * 3);
+          highbd_sse_w16_avx2(&sum32, a + 16 * 4, b + 16 * 4);
+          highbd_sse_w16_avx2(&sum32, a + 16 * 5, b + 16 * 5);
+          highbd_sse_w16_avx2(&sum32, a + 16 * 6, b + 16 * 6);
+          highbd_sse_w16_avx2(&sum32, a + 16 * 7, b + 16 * 7);
+          a += a_stride;
+          b += b_stride;
+          l += 1;
+        } while (l < 16 && l < (height - y));
+        summary_32_avx2(&sum32, &sum);
+        y += 16;
+      } while (y < height);
+      sse = summary_4x64_avx2(sum);
+      break;
+    default:
+      if (width & 0x7) {
+        do {
+          int i = 0;
+          __m256i sum32 = _mm256_setzero_si256();
+          do {
+            const uint16_t *a2;
+            const uint16_t *b2;
+            highbd_sse_w8x2_avx2(&sum32, a + i, a_stride, b + i, b_stride);
+            a2 = a + i + (a_stride << 1);
+            b2 = b + i + (b_stride << 1);
+            highbd_sse_w8x2_avx2(&sum32, a2, a_stride, b2, b_stride);
+            i += 8;
+          } while (i + 4 < width);
+          highbd_sse_w4x4_avx2(&sum32, a + i, a_stride, b + i, b_stride);
+          summary_32_avx2(&sum32, &sum);
+          a += a_stride << 2;
+          b += b_stride << 2;
+          y += 4;
+        } while (y < height);
+      } else {
+        do {
+          int l = 0;
+          __m256i sum32 = _mm256_setzero_si256();
+          do {
+            int i = 0;
+            do {
+              highbd_sse_w8x2_avx2(&sum32, a + i, a_stride, b + i, b_stride);
+              i += 8;
+            } while (i < width);
+            a += a_stride << 1;
+            b += b_stride << 1;
+            l += 2;
+          } while (l < 8 && l < (height - y));
+          summary_32_avx2(&sum32, &sum);
+          y += 8;
+        } while (y < height);
+      }
+      sse = summary_4x64_avx2(sum);
+      break;
+  }
+  return sse;
+}
+#endif  // CONFIG_VP9_HIGHBITDEPTH
diff --git a/vpx_dsp/x86/sse_sse4.c b/vpx_dsp/x86/sse_sse4.c
new file mode 100644 (file)
index 0000000..1c2744e
--- /dev/null
@@ -0,0 +1,359 @@
+/*
+ *  Copyright (c) 2023 The WebM project authors. All Rights Reserved.
+ *
+ *  Use of this source code is governed by a BSD-style license
+ *  that can be found in the LICENSE file in the root of the source
+ *  tree. An additional intellectual property rights grant can be found
+ *  in the file PATENTS.  All contributing project authors may
+ *  be found in the AUTHORS file in the root of the source tree.
+ */
+
+#include <assert.h>
+#include <smmintrin.h>
+
+#include "./vpx_config.h"
+#include "./vpx_dsp_rtcd.h"
+
+#include "vpx_ports/mem.h"
+#include "vpx/vpx_integer.h"
+#include "vpx_dsp/x86/mem_sse2.h"
+
+static INLINE int64_t summary_all_sse4(const __m128i *sum_all) {
+  int64_t sum;
+  const __m128i sum0 = _mm_cvtepu32_epi64(*sum_all);
+  const __m128i sum1 = _mm_cvtepu32_epi64(_mm_srli_si128(*sum_all, 8));
+  const __m128i sum_2x64 = _mm_add_epi64(sum0, sum1);
+  const __m128i sum_1x64 = _mm_add_epi64(sum_2x64, _mm_srli_si128(sum_2x64, 8));
+  _mm_storel_epi64((__m128i *)&sum, sum_1x64);
+  return sum;
+}
+
+#if CONFIG_VP9_HIGHBITDEPTH
+static INLINE void summary_32_sse4(const __m128i *sum32, __m128i *sum64) {
+  const __m128i sum0 = _mm_cvtepu32_epi64(*sum32);
+  const __m128i sum1 = _mm_cvtepu32_epi64(_mm_srli_si128(*sum32, 8));
+  *sum64 = _mm_add_epi64(sum0, *sum64);
+  *sum64 = _mm_add_epi64(sum1, *sum64);
+}
+#endif
+
+static INLINE void sse_w16_sse4_1(__m128i *sum, const uint8_t *a,
+                                  const uint8_t *b) {
+  const __m128i v_a0 = _mm_loadu_si128((const __m128i *)a);
+  const __m128i v_b0 = _mm_loadu_si128((const __m128i *)b);
+  const __m128i v_a00_w = _mm_cvtepu8_epi16(v_a0);
+  const __m128i v_a01_w = _mm_cvtepu8_epi16(_mm_srli_si128(v_a0, 8));
+  const __m128i v_b00_w = _mm_cvtepu8_epi16(v_b0);
+  const __m128i v_b01_w = _mm_cvtepu8_epi16(_mm_srli_si128(v_b0, 8));
+  const __m128i v_d00_w = _mm_sub_epi16(v_a00_w, v_b00_w);
+  const __m128i v_d01_w = _mm_sub_epi16(v_a01_w, v_b01_w);
+  *sum = _mm_add_epi32(*sum, _mm_madd_epi16(v_d00_w, v_d00_w));
+  *sum = _mm_add_epi32(*sum, _mm_madd_epi16(v_d01_w, v_d01_w));
+}
+
+static INLINE void sse4x2_sse4_1(const uint8_t *a, int a_stride,
+                                 const uint8_t *b, int b_stride, __m128i *sum) {
+  const __m128i v_a0 = load_unaligned_u32(a);
+  const __m128i v_a1 = load_unaligned_u32(a + a_stride);
+  const __m128i v_b0 = load_unaligned_u32(b);
+  const __m128i v_b1 = load_unaligned_u32(b + b_stride);
+  const __m128i v_a_w = _mm_cvtepu8_epi16(_mm_unpacklo_epi32(v_a0, v_a1));
+  const __m128i v_b_w = _mm_cvtepu8_epi16(_mm_unpacklo_epi32(v_b0, v_b1));
+  const __m128i v_d_w = _mm_sub_epi16(v_a_w, v_b_w);
+  *sum = _mm_add_epi32(*sum, _mm_madd_epi16(v_d_w, v_d_w));
+}
+
+static INLINE void sse8_sse4_1(const uint8_t *a, const uint8_t *b,
+                               __m128i *sum) {
+  const __m128i v_a0 = _mm_loadl_epi64((const __m128i *)a);
+  const __m128i v_b0 = _mm_loadl_epi64((const __m128i *)b);
+  const __m128i v_a_w = _mm_cvtepu8_epi16(v_a0);
+  const __m128i v_b_w = _mm_cvtepu8_epi16(v_b0);
+  const __m128i v_d_w = _mm_sub_epi16(v_a_w, v_b_w);
+  *sum = _mm_add_epi32(*sum, _mm_madd_epi16(v_d_w, v_d_w));
+}
+
+int64_t vpx_sse_sse4_1(const uint8_t *a, int a_stride, const uint8_t *b,
+                       int b_stride, int width, int height) {
+  int y = 0;
+  int64_t sse = 0;
+  __m128i sum = _mm_setzero_si128();
+  switch (width) {
+    case 4:
+      do {
+        sse4x2_sse4_1(a, a_stride, b, b_stride, &sum);
+        a += a_stride << 1;
+        b += b_stride << 1;
+        y += 2;
+      } while (y < height);
+      sse = summary_all_sse4(&sum);
+      break;
+    case 8:
+      do {
+        sse8_sse4_1(a, b, &sum);
+        a += a_stride;
+        b += b_stride;
+        y += 1;
+      } while (y < height);
+      sse = summary_all_sse4(&sum);
+      break;
+    case 16:
+      do {
+        sse_w16_sse4_1(&sum, a, b);
+        a += a_stride;
+        b += b_stride;
+        y += 1;
+      } while (y < height);
+      sse = summary_all_sse4(&sum);
+      break;
+    case 32:
+      do {
+        sse_w16_sse4_1(&sum, a, b);
+        sse_w16_sse4_1(&sum, a + 16, b + 16);
+        a += a_stride;
+        b += b_stride;
+        y += 1;
+      } while (y < height);
+      sse = summary_all_sse4(&sum);
+      break;
+    case 64:
+      do {
+        sse_w16_sse4_1(&sum, a, b);
+        sse_w16_sse4_1(&sum, a + 16 * 1, b + 16 * 1);
+        sse_w16_sse4_1(&sum, a + 16 * 2, b + 16 * 2);
+        sse_w16_sse4_1(&sum, a + 16 * 3, b + 16 * 3);
+        a += a_stride;
+        b += b_stride;
+        y += 1;
+      } while (y < height);
+      sse = summary_all_sse4(&sum);
+      break;
+    case 128:
+      do {
+        sse_w16_sse4_1(&sum, a, b);
+        sse_w16_sse4_1(&sum, a + 16 * 1, b + 16 * 1);
+        sse_w16_sse4_1(&sum, a + 16 * 2, b + 16 * 2);
+        sse_w16_sse4_1(&sum, a + 16 * 3, b + 16 * 3);
+        sse_w16_sse4_1(&sum, a + 16 * 4, b + 16 * 4);
+        sse_w16_sse4_1(&sum, a + 16 * 5, b + 16 * 5);
+        sse_w16_sse4_1(&sum, a + 16 * 6, b + 16 * 6);
+        sse_w16_sse4_1(&sum, a + 16 * 7, b + 16 * 7);
+        a += a_stride;
+        b += b_stride;
+        y += 1;
+      } while (y < height);
+      sse = summary_all_sse4(&sum);
+      break;
+    default:
+      if (width & 0x07) {
+        do {
+          int i = 0;
+          do {
+            sse8_sse4_1(a + i, b + i, &sum);
+            sse8_sse4_1(a + i + a_stride, b + i + b_stride, &sum);
+            i += 8;
+          } while (i + 4 < width);
+          sse4x2_sse4_1(a + i, a_stride, b + i, b_stride, &sum);
+          a += (a_stride << 1);
+          b += (b_stride << 1);
+          y += 2;
+        } while (y < height);
+      } else {
+        do {
+          int i = 0;
+          do {
+            sse8_sse4_1(a + i, b + i, &sum);
+            i += 8;
+          } while (i < width);
+          a += a_stride;
+          b += b_stride;
+          y += 1;
+        } while (y < height);
+      }
+      sse = summary_all_sse4(&sum);
+      break;
+  }
+
+  return sse;
+}
+
+#if CONFIG_VP9_HIGHBITDEPTH
+static INLINE void highbd_sse_w4x2_sse4_1(__m128i *sum, const uint16_t *a,
+                                          int a_stride, const uint16_t *b,
+                                          int b_stride) {
+  const __m128i v_a0 = _mm_loadl_epi64((const __m128i *)a);
+  const __m128i v_a1 = _mm_loadl_epi64((const __m128i *)(a + a_stride));
+  const __m128i v_b0 = _mm_loadl_epi64((const __m128i *)b);
+  const __m128i v_b1 = _mm_loadl_epi64((const __m128i *)(b + b_stride));
+  const __m128i v_a_w = _mm_unpacklo_epi64(v_a0, v_a1);
+  const __m128i v_b_w = _mm_unpacklo_epi64(v_b0, v_b1);
+  const __m128i v_d_w = _mm_sub_epi16(v_a_w, v_b_w);
+  *sum = _mm_add_epi32(*sum, _mm_madd_epi16(v_d_w, v_d_w));
+}
+
+static INLINE void highbd_sse_w8_sse4_1(__m128i *sum, const uint16_t *a,
+                                        const uint16_t *b) {
+  const __m128i v_a_w = _mm_loadu_si128((const __m128i *)a);
+  const __m128i v_b_w = _mm_loadu_si128((const __m128i *)b);
+  const __m128i v_d_w = _mm_sub_epi16(v_a_w, v_b_w);
+  *sum = _mm_add_epi32(*sum, _mm_madd_epi16(v_d_w, v_d_w));
+}
+
+int64_t vpx_highbd_sse_sse4_1(const uint8_t *a8, int a_stride,
+                              const uint8_t *b8, int b_stride, int width,
+                              int height) {
+  int32_t y = 0;
+  int64_t sse = 0;
+  uint16_t *a = CONVERT_TO_SHORTPTR(a8);
+  uint16_t *b = CONVERT_TO_SHORTPTR(b8);
+  __m128i sum = _mm_setzero_si128();
+  switch (width) {
+    case 4:
+      do {
+        highbd_sse_w4x2_sse4_1(&sum, a, a_stride, b, b_stride);
+        a += a_stride << 1;
+        b += b_stride << 1;
+        y += 2;
+      } while (y < height);
+      sse = summary_all_sse4(&sum);
+      break;
+    case 8:
+      do {
+        highbd_sse_w8_sse4_1(&sum, a, b);
+        a += a_stride;
+        b += b_stride;
+        y += 1;
+      } while (y < height);
+      sse = summary_all_sse4(&sum);
+      break;
+    case 16:
+      do {
+        int l = 0;
+        __m128i sum32 = _mm_setzero_si128();
+        do {
+          highbd_sse_w8_sse4_1(&sum32, a, b);
+          highbd_sse_w8_sse4_1(&sum32, a + 8, b + 8);
+          a += a_stride;
+          b += b_stride;
+          l += 1;
+        } while (l < 64 && l < (height - y));
+        summary_32_sse4(&sum32, &sum);
+        y += 64;
+      } while (y < height);
+      _mm_storel_epi64((__m128i *)&sse,
+                       _mm_add_epi64(sum, _mm_srli_si128(sum, 8)));
+      break;
+    case 32:
+      do {
+        int l = 0;
+        __m128i sum32 = _mm_setzero_si128();
+        do {
+          highbd_sse_w8_sse4_1(&sum32, a, b);
+          highbd_sse_w8_sse4_1(&sum32, a + 8 * 1, b + 8 * 1);
+          highbd_sse_w8_sse4_1(&sum32, a + 8 * 2, b + 8 * 2);
+          highbd_sse_w8_sse4_1(&sum32, a + 8 * 3, b + 8 * 3);
+          a += a_stride;
+          b += b_stride;
+          l += 1;
+        } while (l < 32 && l < (height - y));
+        summary_32_sse4(&sum32, &sum);
+        y += 32;
+      } while (y < height);
+      _mm_storel_epi64((__m128i *)&sse,
+                       _mm_add_epi64(sum, _mm_srli_si128(sum, 8)));
+      break;
+    case 64:
+      do {
+        int l = 0;
+        __m128i sum32 = _mm_setzero_si128();
+        do {
+          highbd_sse_w8_sse4_1(&sum32, a, b);
+          highbd_sse_w8_sse4_1(&sum32, a + 8 * 1, b + 8 * 1);
+          highbd_sse_w8_sse4_1(&sum32, a + 8 * 2, b + 8 * 2);
+          highbd_sse_w8_sse4_1(&sum32, a + 8 * 3, b + 8 * 3);
+          highbd_sse_w8_sse4_1(&sum32, a + 8 * 4, b + 8 * 4);
+          highbd_sse_w8_sse4_1(&sum32, a + 8 * 5, b + 8 * 5);
+          highbd_sse_w8_sse4_1(&sum32, a + 8 * 6, b + 8 * 6);
+          highbd_sse_w8_sse4_1(&sum32, a + 8 * 7, b + 8 * 7);
+          a += a_stride;
+          b += b_stride;
+          l += 1;
+        } while (l < 16 && l < (height - y));
+        summary_32_sse4(&sum32, &sum);
+        y += 16;
+      } while (y < height);
+      _mm_storel_epi64((__m128i *)&sse,
+                       _mm_add_epi64(sum, _mm_srli_si128(sum, 8)));
+      break;
+    case 128:
+      do {
+        int l = 0;
+        __m128i sum32 = _mm_setzero_si128();
+        do {
+          highbd_sse_w8_sse4_1(&sum32, a, b);
+          highbd_sse_w8_sse4_1(&sum32, a + 8 * 1, b + 8 * 1);
+          highbd_sse_w8_sse4_1(&sum32, a + 8 * 2, b + 8 * 2);
+          highbd_sse_w8_sse4_1(&sum32, a + 8 * 3, b + 8 * 3);
+          highbd_sse_w8_sse4_1(&sum32, a + 8 * 4, b + 8 * 4);
+          highbd_sse_w8_sse4_1(&sum32, a + 8 * 5, b + 8 * 5);
+          highbd_sse_w8_sse4_1(&sum32, a + 8 * 6, b + 8 * 6);
+          highbd_sse_w8_sse4_1(&sum32, a + 8 * 7, b + 8 * 7);
+          highbd_sse_w8_sse4_1(&sum32, a + 8 * 8, b + 8 * 8);
+          highbd_sse_w8_sse4_1(&sum32, a + 8 * 9, b + 8 * 9);
+          highbd_sse_w8_sse4_1(&sum32, a + 8 * 10, b + 8 * 10);
+          highbd_sse_w8_sse4_1(&sum32, a + 8 * 11, b + 8 * 11);
+          highbd_sse_w8_sse4_1(&sum32, a + 8 * 12, b + 8 * 12);
+          highbd_sse_w8_sse4_1(&sum32, a + 8 * 13, b + 8 * 13);
+          highbd_sse_w8_sse4_1(&sum32, a + 8 * 14, b + 8 * 14);
+          highbd_sse_w8_sse4_1(&sum32, a + 8 * 15, b + 8 * 15);
+          a += a_stride;
+          b += b_stride;
+          l += 1;
+        } while (l < 8 && l < (height - y));
+        summary_32_sse4(&sum32, &sum);
+        y += 8;
+      } while (y < height);
+      _mm_storel_epi64((__m128i *)&sse,
+                       _mm_add_epi64(sum, _mm_srli_si128(sum, 8)));
+      break;
+    default:
+      if (width & 0x7) {
+        do {
+          __m128i sum32 = _mm_setzero_si128();
+          int i = 0;
+          do {
+            highbd_sse_w8_sse4_1(&sum32, a + i, b + i);
+            highbd_sse_w8_sse4_1(&sum32, a + i + a_stride, b + i + b_stride);
+            i += 8;
+          } while (i + 4 < width);
+          highbd_sse_w4x2_sse4_1(&sum32, a + i, a_stride, b + i, b_stride);
+          a += (a_stride << 1);
+          b += (b_stride << 1);
+          y += 2;
+          summary_32_sse4(&sum32, &sum);
+        } while (y < height);
+      } else {
+        do {
+          int l = 0;
+          __m128i sum32 = _mm_setzero_si128();
+          do {
+            int i = 0;
+            do {
+              highbd_sse_w8_sse4_1(&sum32, a + i, b + i);
+              i += 8;
+            } while (i < width);
+            a += a_stride;
+            b += b_stride;
+            l += 1;
+          } while (l < 8 && l < (height - y));
+          summary_32_sse4(&sum32, &sum);
+          y += 8;
+        } while (y < height);
+      }
+      _mm_storel_epi64((__m128i *)&sse,
+                       _mm_add_epi64(sum, _mm_srli_si128(sum, 8)));
+      break;
+  }
+  return sse;
+}
+#endif  // CONFIG_VP9_HIGHBITDEPTH