[NEON] highbd partial DCT functions
authorKonstantinos Margaritis <konma@vectorcamp.gr>
Thu, 6 Oct 2022 10:26:05 +0000 (10:26 +0000)
committerKonstantinos Margaritis <konma@vectorcamp.gr>
Mon, 10 Oct 2022 11:47:39 +0000 (11:47 +0000)
Change-Id: I7dd4e698469562f5b1f948cc36f8403b490dcb6a

test/dct_partial_test.cc
vpx_dsp/arm/fdct_partial_neon.c
vpx_dsp/vpx_dsp_rtcd_defs.pl

index 8d0e3a9..e57fa0f 100644 (file)
@@ -145,11 +145,17 @@ INSTANTIATE_TEST_SUITE_P(
 #if CONFIG_VP9_HIGHBITDEPTH
 INSTANTIATE_TEST_SUITE_P(
     NEON, PartialFdctTest,
-    ::testing::Values(make_tuple(&vpx_fdct32x32_1_neon, 32, VPX_BITS_8),
-                      make_tuple(&vpx_fdct16x16_1_neon, 16, VPX_BITS_8),
+    ::testing::Values(make_tuple(&vpx_highbd_fdct32x32_1_neon, 32, VPX_BITS_12),
+                      make_tuple(&vpx_highbd_fdct32x32_1_neon, 32, VPX_BITS_10),
+                      make_tuple(&vpx_highbd_fdct32x32_1_neon, 32, VPX_BITS_8),
+                      make_tuple(&vpx_highbd_fdct16x16_1_neon, 16, VPX_BITS_12),
+                      make_tuple(&vpx_highbd_fdct16x16_1_neon, 16, VPX_BITS_10),
+                      make_tuple(&vpx_highbd_fdct16x16_1_neon, 16, VPX_BITS_8),
                       make_tuple(&vpx_fdct8x8_1_neon, 8, VPX_BITS_12),
                       make_tuple(&vpx_fdct8x8_1_neon, 8, VPX_BITS_10),
                       make_tuple(&vpx_fdct8x8_1_neon, 8, VPX_BITS_8),
+                      make_tuple(&vpx_fdct4x4_1_neon, 4, VPX_BITS_12),
+                      make_tuple(&vpx_fdct4x4_1_neon, 4, VPX_BITS_10),
                       make_tuple(&vpx_fdct4x4_1_neon, 4, VPX_BITS_8)));
 #else
 INSTANTIATE_TEST_SUITE_P(
index 0a1cdca..718dba0 100644 (file)
@@ -101,3 +101,68 @@ void vpx_fdct32x32_1_neon(const int16_t *input, tran_low_t *output,
   output[0] = (tran_low_t)(sum >> 3);
   output[1] = 0;
 }
+
+#if CONFIG_VP9_HIGHBITDEPTH
+
+void vpx_highbd_fdct16x16_1_neon(const int16_t *input, tran_low_t *output,
+                                 int stride) {
+  int32x4_t partial_sum[4] = { vdupq_n_s32(0), vdupq_n_s32(0), vdupq_n_s32(0),
+                               vdupq_n_s32(0) };
+  int32_t sum;
+
+  int r = 0;
+  do {
+    const int16x8_t a = vld1q_s16(input);
+    const int16x8_t b = vld1q_s16(input + 8);
+    input += stride;
+    partial_sum[0] = vaddw_s16(partial_sum[0], vget_low_s16(a));
+    partial_sum[1] = vaddw_s16(partial_sum[1], vget_high_s16(a));
+    partial_sum[2] = vaddw_s16(partial_sum[2], vget_low_s16(b));
+    partial_sum[3] = vaddw_s16(partial_sum[3], vget_high_s16(b));
+    r++;
+  } while (r < 16);
+
+  partial_sum[0] = vaddq_s32(partial_sum[0], partial_sum[1]);
+  partial_sum[2] = vaddq_s32(partial_sum[2], partial_sum[3]);
+  partial_sum[0] = vaddq_s32(partial_sum[0], partial_sum[2]);
+  sum = horizontal_add_int32x4(partial_sum[0]);
+
+  output[0] = (tran_low_t)(sum >> 1);
+  output[1] = 0;
+}
+
+void vpx_highbd_fdct32x32_1_neon(const int16_t *input, tran_low_t *output,
+                                 int stride) {
+  int32x4_t partial_sum[4] = { vdupq_n_s32(0), vdupq_n_s32(0), vdupq_n_s32(0),
+                               vdupq_n_s32(0) };
+
+  int32_t sum;
+
+  int r = 0;
+  do {
+    const int16x8_t a0 = vld1q_s16(input);
+    const int16x8_t a1 = vld1q_s16(input + 8);
+    const int16x8_t a2 = vld1q_s16(input + 16);
+    const int16x8_t a3 = vld1q_s16(input + 24);
+    input += stride;
+    partial_sum[0] = vaddw_s16(partial_sum[0], vget_low_s16(a0));
+    partial_sum[0] = vaddw_s16(partial_sum[0], vget_high_s16(a0));
+    partial_sum[1] = vaddw_s16(partial_sum[1], vget_low_s16(a1));
+    partial_sum[1] = vaddw_s16(partial_sum[1], vget_high_s16(a1));
+    partial_sum[2] = vaddw_s16(partial_sum[2], vget_low_s16(a2));
+    partial_sum[2] = vaddw_s16(partial_sum[2], vget_high_s16(a2));
+    partial_sum[3] = vaddw_s16(partial_sum[3], vget_low_s16(a3));
+    partial_sum[3] = vaddw_s16(partial_sum[3], vget_high_s16(a3));
+    r++;
+  } while (r < 32);
+
+  partial_sum[0] = vaddq_s32(partial_sum[0], partial_sum[1]);
+  partial_sum[2] = vaddq_s32(partial_sum[2], partial_sum[3]);
+  partial_sum[0] = vaddq_s32(partial_sum[0], partial_sum[2]);
+  sum = horizontal_add_int32x4(partial_sum[0]);
+
+  output[0] = (tran_low_t)(sum >> 3);
+  output[1] = 0;
+}
+
+#endif  // CONFIG_VP9_HIGHBITDEPTH
index 004afb3..5dad78c 100644 (file)
@@ -527,6 +527,8 @@ if (vpx_config("CONFIG_VP9_HIGHBITDEPTH") eq "yes") {
 
   add_proto qw/void vpx_fdct4x4_1/, "const int16_t *input, tran_low_t *output, int stride";
   specialize qw/vpx_fdct4x4_1 sse2 neon/;
+  specialize qw/vpx_highbd_fdct4x4_1 neon/;
+  $vpx_highbd_fdct4x4_1_neon=vpx_fdct4x4_1_neon;
 
   add_proto qw/void vpx_fdct8x8/, "const int16_t *input, tran_low_t *output, int stride";
   specialize qw/vpx_fdct8x8 neon sse2/;
@@ -563,6 +565,7 @@ if (vpx_config("CONFIG_VP9_HIGHBITDEPTH") eq "yes") {
   specialize qw/vpx_highbd_fdct16x16 sse2/;
 
   add_proto qw/void vpx_highbd_fdct16x16_1/, "const int16_t *input, tran_low_t *output, int stride";
+  specialize qw/vpx_highbd_fdct16x16_1 neon/;
 
   add_proto qw/void vpx_highbd_fdct32x32/, "const int16_t *input, tran_low_t *output, int stride";
   specialize qw/vpx_highbd_fdct32x32 sse2/;
@@ -571,6 +574,7 @@ if (vpx_config("CONFIG_VP9_HIGHBITDEPTH") eq "yes") {
   specialize qw/vpx_highbd_fdct32x32_rd sse2/;
 
   add_proto qw/void vpx_highbd_fdct32x32_1/, "const int16_t *input, tran_low_t *output, int stride";
+  specialize qw/vpx_highbd_fdct32x32_1 neon/;
 } else {
   add_proto qw/void vpx_fdct4x4/, "const int16_t *input, tran_low_t *output, int stride";
   specialize qw/vpx_fdct4x4 neon sse2 msa lsx/;