Allow non-uniform above array in d63 predictor Neon impl
authorGeorge Steed <george.steed@arm.com>
Fri, 17 Mar 2023 19:55:17 +0000 (19:55 +0000)
committerGeorge Steed <george.steed@arm.com>
Tue, 28 Mar 2023 13:27:22 +0000 (13:27 +0000)
The existing standard bitdepth implementation doesn't appear to manifest
as a failure in any of the predictor or MD5 tests, but it does rely on
the predictor tests filling the second `bs` elements of the `above`
input array with copies of `above[bs - 1]` in order to match the C
implementation.

This patch adjusts the Neon implementation to correctly match the C
implementation in the case where the elements of the `above` array all
differ.

The geomean of performance for the predictor is approximately a 2%
slowdown compared to the previous vectorized implementation. This is
still considerably faster than the unspecialized naive C implementation.

Bug: webm:1797
Change-Id: I8fb00a154288d54b24a72a7ff63c816bdcf3aca3

vpx_dsp/arm/intrapred_neon.c

index 7c225f6..3d117fa 100644 (file)
@@ -499,12 +499,16 @@ void vpx_d63_predictor_8x8_neon(uint8_t *dst, ptrdiff_t stride,
 
   vst1_u8(dst + 0 * stride, d0);
   vst1_u8(dst + 1 * stride, d1);
-  vst1_u8(dst + 2 * stride, vext_u8(d0, a7, 1));
-  vst1_u8(dst + 3 * stride, vext_u8(d1, a7, 1));
-  vst1_u8(dst + 4 * stride, vext_u8(d0, a7, 2));
-  vst1_u8(dst + 5 * stride, vext_u8(d1, a7, 2));
-  vst1_u8(dst + 6 * stride, vext_u8(d0, a7, 3));
-  vst1_u8(dst + 7 * stride, vext_u8(d1, a7, 3));
+
+  d0 = vext_u8(d0, d0, 7);
+  d1 = vext_u8(d1, d1, 7);
+
+  vst1_u8(dst + 2 * stride, vext_u8(d0, a7, 2));
+  vst1_u8(dst + 3 * stride, vext_u8(d1, a7, 2));
+  vst1_u8(dst + 4 * stride, vext_u8(d0, a7, 3));
+  vst1_u8(dst + 5 * stride, vext_u8(d1, a7, 3));
+  vst1_u8(dst + 6 * stride, vext_u8(d0, a7, 4));
+  vst1_u8(dst + 7 * stride, vext_u8(d1, a7, 4));
 }
 
 void vpx_d63_predictor_16x16_neon(uint8_t *dst, ptrdiff_t stride,
@@ -522,20 +526,24 @@ void vpx_d63_predictor_16x16_neon(uint8_t *dst, ptrdiff_t stride,
 
   vst1q_u8(dst + 0 * stride, d0);
   vst1q_u8(dst + 1 * stride, d1);
-  vst1q_u8(dst + 2 * stride, vextq_u8(d0, a15, 1));
-  vst1q_u8(dst + 3 * stride, vextq_u8(d1, a15, 1));
-  vst1q_u8(dst + 4 * stride, vextq_u8(d0, a15, 2));
-  vst1q_u8(dst + 5 * stride, vextq_u8(d1, a15, 2));
-  vst1q_u8(dst + 6 * stride, vextq_u8(d0, a15, 3));
-  vst1q_u8(dst + 7 * stride, vextq_u8(d1, a15, 3));
-  vst1q_u8(dst + 8 * stride, vextq_u8(d0, a15, 4));
-  vst1q_u8(dst + 9 * stride, vextq_u8(d1, a15, 4));
-  vst1q_u8(dst + 10 * stride, vextq_u8(d0, a15, 5));
-  vst1q_u8(dst + 11 * stride, vextq_u8(d1, a15, 5));
-  vst1q_u8(dst + 12 * stride, vextq_u8(d0, a15, 6));
-  vst1q_u8(dst + 13 * stride, vextq_u8(d1, a15, 6));
-  vst1q_u8(dst + 14 * stride, vextq_u8(d0, a15, 7));
-  vst1q_u8(dst + 15 * stride, vextq_u8(d1, a15, 7));
+
+  d0 = vextq_u8(d0, d0, 15);
+  d1 = vextq_u8(d1, d1, 15);
+
+  vst1q_u8(dst + 2 * stride, vextq_u8(d0, a15, 2));
+  vst1q_u8(dst + 3 * stride, vextq_u8(d1, a15, 2));
+  vst1q_u8(dst + 4 * stride, vextq_u8(d0, a15, 3));
+  vst1q_u8(dst + 5 * stride, vextq_u8(d1, a15, 3));
+  vst1q_u8(dst + 6 * stride, vextq_u8(d0, a15, 4));
+  vst1q_u8(dst + 7 * stride, vextq_u8(d1, a15, 4));
+  vst1q_u8(dst + 8 * stride, vextq_u8(d0, a15, 5));
+  vst1q_u8(dst + 9 * stride, vextq_u8(d1, a15, 5));
+  vst1q_u8(dst + 10 * stride, vextq_u8(d0, a15, 6));
+  vst1q_u8(dst + 11 * stride, vextq_u8(d1, a15, 6));
+  vst1q_u8(dst + 12 * stride, vextq_u8(d0, a15, 7));
+  vst1q_u8(dst + 13 * stride, vextq_u8(d1, a15, 7));
+  vst1q_u8(dst + 14 * stride, vextq_u8(d0, a15, 8));
+  vst1q_u8(dst + 15 * stride, vextq_u8(d1, a15, 8));
 }
 
 void vpx_d63_predictor_32x32_neon(uint8_t *dst, ptrdiff_t stride,
@@ -560,66 +568,72 @@ void vpx_d63_predictor_32x32_neon(uint8_t *dst, ptrdiff_t stride,
   vst1q_u8(dst + 0 * stride + 16, d0_hi);
   vst1q_u8(dst + 1 * stride + 0, d1_lo);
   vst1q_u8(dst + 1 * stride + 16, d1_hi);
-  vst1q_u8(dst + 2 * stride + 0, vextq_u8(d0_lo, d0_hi, 1));
-  vst1q_u8(dst + 2 * stride + 16, vextq_u8(d0_hi, a31, 1));
-  vst1q_u8(dst + 3 * stride + 0, vextq_u8(d1_lo, d1_hi, 1));
-  vst1q_u8(dst + 3 * stride + 16, vextq_u8(d1_hi, a31, 1));
-  vst1q_u8(dst + 4 * stride + 0, vextq_u8(d0_lo, d0_hi, 2));
-  vst1q_u8(dst + 4 * stride + 16, vextq_u8(d0_hi, a31, 2));
-  vst1q_u8(dst + 5 * stride + 0, vextq_u8(d1_lo, d1_hi, 2));
-  vst1q_u8(dst + 5 * stride + 16, vextq_u8(d1_hi, a31, 2));
-  vst1q_u8(dst + 6 * stride + 0, vextq_u8(d0_lo, d0_hi, 3));
-  vst1q_u8(dst + 6 * stride + 16, vextq_u8(d0_hi, a31, 3));
-  vst1q_u8(dst + 7 * stride + 0, vextq_u8(d1_lo, d1_hi, 3));
-  vst1q_u8(dst + 7 * stride + 16, vextq_u8(d1_hi, a31, 3));
-  vst1q_u8(dst + 8 * stride + 0, vextq_u8(d0_lo, d0_hi, 4));
-  vst1q_u8(dst + 8 * stride + 16, vextq_u8(d0_hi, a31, 4));
-  vst1q_u8(dst + 9 * stride + 0, vextq_u8(d1_lo, d1_hi, 4));
-  vst1q_u8(dst + 9 * stride + 16, vextq_u8(d1_hi, a31, 4));
-  vst1q_u8(dst + 10 * stride + 0, vextq_u8(d0_lo, d0_hi, 5));
-  vst1q_u8(dst + 10 * stride + 16, vextq_u8(d0_hi, a31, 5));
-  vst1q_u8(dst + 11 * stride + 0, vextq_u8(d1_lo, d1_hi, 5));
-  vst1q_u8(dst + 11 * stride + 16, vextq_u8(d1_hi, a31, 5));
-  vst1q_u8(dst + 12 * stride + 0, vextq_u8(d0_lo, d0_hi, 6));
-  vst1q_u8(dst + 12 * stride + 16, vextq_u8(d0_hi, a31, 6));
-  vst1q_u8(dst + 13 * stride + 0, vextq_u8(d1_lo, d1_hi, 6));
-  vst1q_u8(dst + 13 * stride + 16, vextq_u8(d1_hi, a31, 6));
-  vst1q_u8(dst + 14 * stride + 0, vextq_u8(d0_lo, d0_hi, 7));
-  vst1q_u8(dst + 14 * stride + 16, vextq_u8(d0_hi, a31, 7));
-  vst1q_u8(dst + 15 * stride + 0, vextq_u8(d1_lo, d1_hi, 7));
-  vst1q_u8(dst + 15 * stride + 16, vextq_u8(d1_hi, a31, 7));
-  vst1q_u8(dst + 16 * stride + 0, vextq_u8(d0_lo, d0_hi, 8));
-  vst1q_u8(dst + 16 * stride + 16, vextq_u8(d0_hi, a31, 8));
-  vst1q_u8(dst + 17 * stride + 0, vextq_u8(d1_lo, d1_hi, 8));
-  vst1q_u8(dst + 17 * stride + 16, vextq_u8(d1_hi, a31, 8));
-  vst1q_u8(dst + 18 * stride + 0, vextq_u8(d0_lo, d0_hi, 9));
-  vst1q_u8(dst + 18 * stride + 16, vextq_u8(d0_hi, a31, 9));
-  vst1q_u8(dst + 19 * stride + 0, vextq_u8(d1_lo, d1_hi, 9));
-  vst1q_u8(dst + 19 * stride + 16, vextq_u8(d1_hi, a31, 9));
-  vst1q_u8(dst + 20 * stride + 0, vextq_u8(d0_lo, d0_hi, 10));
-  vst1q_u8(dst + 20 * stride + 16, vextq_u8(d0_hi, a31, 10));
-  vst1q_u8(dst + 21 * stride + 0, vextq_u8(d1_lo, d1_hi, 10));
-  vst1q_u8(dst + 21 * stride + 16, vextq_u8(d1_hi, a31, 10));
-  vst1q_u8(dst + 22 * stride + 0, vextq_u8(d0_lo, d0_hi, 11));
-  vst1q_u8(dst + 22 * stride + 16, vextq_u8(d0_hi, a31, 11));
-  vst1q_u8(dst + 23 * stride + 0, vextq_u8(d1_lo, d1_hi, 11));
-  vst1q_u8(dst + 23 * stride + 16, vextq_u8(d1_hi, a31, 11));
-  vst1q_u8(dst + 24 * stride + 0, vextq_u8(d0_lo, d0_hi, 12));
-  vst1q_u8(dst + 24 * stride + 16, vextq_u8(d0_hi, a31, 12));
-  vst1q_u8(dst + 25 * stride + 0, vextq_u8(d1_lo, d1_hi, 12));
-  vst1q_u8(dst + 25 * stride + 16, vextq_u8(d1_hi, a31, 12));
-  vst1q_u8(dst + 26 * stride + 0, vextq_u8(d0_lo, d0_hi, 13));
-  vst1q_u8(dst + 26 * stride + 16, vextq_u8(d0_hi, a31, 13));
-  vst1q_u8(dst + 27 * stride + 0, vextq_u8(d1_lo, d1_hi, 13));
-  vst1q_u8(dst + 27 * stride + 16, vextq_u8(d1_hi, a31, 13));
-  vst1q_u8(dst + 28 * stride + 0, vextq_u8(d0_lo, d0_hi, 14));
-  vst1q_u8(dst + 28 * stride + 16, vextq_u8(d0_hi, a31, 14));
-  vst1q_u8(dst + 29 * stride + 0, vextq_u8(d1_lo, d1_hi, 14));
-  vst1q_u8(dst + 29 * stride + 16, vextq_u8(d1_hi, a31, 14));
-  vst1q_u8(dst + 30 * stride + 0, vextq_u8(d0_lo, d0_hi, 15));
-  vst1q_u8(dst + 30 * stride + 16, vextq_u8(d0_hi, a31, 15));
-  vst1q_u8(dst + 31 * stride + 0, vextq_u8(d1_lo, d1_hi, 15));
-  vst1q_u8(dst + 31 * stride + 16, vextq_u8(d1_hi, a31, 15));
+
+  d0_hi = vextq_u8(d0_lo, d0_hi, 15);
+  d0_lo = vextq_u8(d0_lo, d0_lo, 15);
+  d1_hi = vextq_u8(d1_lo, d1_hi, 15);
+  d1_lo = vextq_u8(d1_lo, d1_lo, 15);
+
+  vst1q_u8(dst + 2 * stride + 0, vextq_u8(d0_lo, d0_hi, 2));
+  vst1q_u8(dst + 2 * stride + 16, vextq_u8(d0_hi, a31, 2));
+  vst1q_u8(dst + 3 * stride + 0, vextq_u8(d1_lo, d1_hi, 2));
+  vst1q_u8(dst + 3 * stride + 16, vextq_u8(d1_hi, a31, 2));
+  vst1q_u8(dst + 4 * stride + 0, vextq_u8(d0_lo, d0_hi, 3));
+  vst1q_u8(dst + 4 * stride + 16, vextq_u8(d0_hi, a31, 3));
+  vst1q_u8(dst + 5 * stride + 0, vextq_u8(d1_lo, d1_hi, 3));
+  vst1q_u8(dst + 5 * stride + 16, vextq_u8(d1_hi, a31, 3));
+  vst1q_u8(dst + 6 * stride + 0, vextq_u8(d0_lo, d0_hi, 4));
+  vst1q_u8(dst + 6 * stride + 16, vextq_u8(d0_hi, a31, 4));
+  vst1q_u8(dst + 7 * stride + 0, vextq_u8(d1_lo, d1_hi, 4));
+  vst1q_u8(dst + 7 * stride + 16, vextq_u8(d1_hi, a31, 4));
+  vst1q_u8(dst + 8 * stride + 0, vextq_u8(d0_lo, d0_hi, 5));
+  vst1q_u8(dst + 8 * stride + 16, vextq_u8(d0_hi, a31, 5));
+  vst1q_u8(dst + 9 * stride + 0, vextq_u8(d1_lo, d1_hi, 5));
+  vst1q_u8(dst + 9 * stride + 16, vextq_u8(d1_hi, a31, 5));
+  vst1q_u8(dst + 10 * stride + 0, vextq_u8(d0_lo, d0_hi, 6));
+  vst1q_u8(dst + 10 * stride + 16, vextq_u8(d0_hi, a31, 6));
+  vst1q_u8(dst + 11 * stride + 0, vextq_u8(d1_lo, d1_hi, 6));
+  vst1q_u8(dst + 11 * stride + 16, vextq_u8(d1_hi, a31, 6));
+  vst1q_u8(dst + 12 * stride + 0, vextq_u8(d0_lo, d0_hi, 7));
+  vst1q_u8(dst + 12 * stride + 16, vextq_u8(d0_hi, a31, 7));
+  vst1q_u8(dst + 13 * stride + 0, vextq_u8(d1_lo, d1_hi, 7));
+  vst1q_u8(dst + 13 * stride + 16, vextq_u8(d1_hi, a31, 7));
+  vst1q_u8(dst + 14 * stride + 0, vextq_u8(d0_lo, d0_hi, 8));
+  vst1q_u8(dst + 14 * stride + 16, vextq_u8(d0_hi, a31, 8));
+  vst1q_u8(dst + 15 * stride + 0, vextq_u8(d1_lo, d1_hi, 8));
+  vst1q_u8(dst + 15 * stride + 16, vextq_u8(d1_hi, a31, 8));
+  vst1q_u8(dst + 16 * stride + 0, vextq_u8(d0_lo, d0_hi, 9));
+  vst1q_u8(dst + 16 * stride + 16, vextq_u8(d0_hi, a31, 9));
+  vst1q_u8(dst + 17 * stride + 0, vextq_u8(d1_lo, d1_hi, 9));
+  vst1q_u8(dst + 17 * stride + 16, vextq_u8(d1_hi, a31, 9));
+  vst1q_u8(dst + 18 * stride + 0, vextq_u8(d0_lo, d0_hi, 10));
+  vst1q_u8(dst + 18 * stride + 16, vextq_u8(d0_hi, a31, 10));
+  vst1q_u8(dst + 19 * stride + 0, vextq_u8(d1_lo, d1_hi, 10));
+  vst1q_u8(dst + 19 * stride + 16, vextq_u8(d1_hi, a31, 10));
+  vst1q_u8(dst + 20 * stride + 0, vextq_u8(d0_lo, d0_hi, 11));
+  vst1q_u8(dst + 20 * stride + 16, vextq_u8(d0_hi, a31, 11));
+  vst1q_u8(dst + 21 * stride + 0, vextq_u8(d1_lo, d1_hi, 11));
+  vst1q_u8(dst + 21 * stride + 16, vextq_u8(d1_hi, a31, 11));
+  vst1q_u8(dst + 22 * stride + 0, vextq_u8(d0_lo, d0_hi, 12));
+  vst1q_u8(dst + 22 * stride + 16, vextq_u8(d0_hi, a31, 12));
+  vst1q_u8(dst + 23 * stride + 0, vextq_u8(d1_lo, d1_hi, 12));
+  vst1q_u8(dst + 23 * stride + 16, vextq_u8(d1_hi, a31, 12));
+  vst1q_u8(dst + 24 * stride + 0, vextq_u8(d0_lo, d0_hi, 13));
+  vst1q_u8(dst + 24 * stride + 16, vextq_u8(d0_hi, a31, 13));
+  vst1q_u8(dst + 25 * stride + 0, vextq_u8(d1_lo, d1_hi, 13));
+  vst1q_u8(dst + 25 * stride + 16, vextq_u8(d1_hi, a31, 13));
+  vst1q_u8(dst + 26 * stride + 0, vextq_u8(d0_lo, d0_hi, 14));
+  vst1q_u8(dst + 26 * stride + 16, vextq_u8(d0_hi, a31, 14));
+  vst1q_u8(dst + 27 * stride + 0, vextq_u8(d1_lo, d1_hi, 14));
+  vst1q_u8(dst + 27 * stride + 16, vextq_u8(d1_hi, a31, 14));
+  vst1q_u8(dst + 28 * stride + 0, vextq_u8(d0_lo, d0_hi, 15));
+  vst1q_u8(dst + 28 * stride + 16, vextq_u8(d0_hi, a31, 15));
+  vst1q_u8(dst + 29 * stride + 0, vextq_u8(d1_lo, d1_hi, 15));
+  vst1q_u8(dst + 29 * stride + 16, vextq_u8(d1_hi, a31, 15));
+  vst1q_u8(dst + 30 * stride + 0, d0_hi);
+  vst1q_u8(dst + 30 * stride + 16, a31);
+  vst1q_u8(dst + 31 * stride + 0, d1_hi);
+  vst1q_u8(dst + 31 * stride + 16, a31);
 }
 
 // -----------------------------------------------------------------------------