Optimize transpose_neon.h helper functions
authorJonathan Wright <jonathan.wright@arm.com>
Sat, 25 Feb 2023 00:43:46 +0000 (00:43 +0000)
committerJonathan Wright <jonathan.wright@arm.com>
Mon, 27 Feb 2023 09:49:02 +0000 (09:49 +0000)
1) Use vtrn[12]q_[su]64 in vpx_vtrnq_[su]64* helpers on AArch64
   targets. This produces half as many TRN1/2 instructions compared to
   the number of MOVs that result from vcombine.

2) Use vpx_vtrnq_[su]64* helpers wherever applicable.

3) Refactor transpose_4x8_s16 to operate on 128-bit vectors.

Change-Id: I9a8b1c1fe2a98a429e0c5f39def5eb2f65759127

vpx_dsp/arm/transpose_neon.h

index 48292c6..518278f 100644 (file)
@@ -39,26 +39,45 @@ static INLINE int16x8x2_t vpx_vtrnq_s64_to_s16(int32x4_t a0, int32x4_t a1) {
 
 static INLINE int32x4x2_t vpx_vtrnq_s64_to_s32(int32x4_t a0, int32x4_t a1) {
   int32x4x2_t b0;
+#if defined(__aarch64__)
+  b0.val[0] = vreinterpretq_s32_s64(
+      vtrn1q_s64(vreinterpretq_s64_s32(a0), vreinterpretq_s64_s32(a1)));
+  b0.val[1] = vreinterpretq_s32_s64(
+      vtrn2q_s64(vreinterpretq_s64_s32(a0), vreinterpretq_s64_s32(a1)));
+#else
   b0.val[0] = vcombine_s32(vget_low_s32(a0), vget_low_s32(a1));
   b0.val[1] = vcombine_s32(vget_high_s32(a0), vget_high_s32(a1));
+#endif
   return b0;
 }
 
 static INLINE int64x2x2_t vpx_vtrnq_s64(int32x4_t a0, int32x4_t a1) {
   int64x2x2_t b0;
+#if defined(__aarch64__)
+  b0.val[0] = vtrn1q_s64(vreinterpretq_s64_s32(a0), vreinterpretq_s64_s32(a1));
+  b0.val[1] = vtrn2q_s64(vreinterpretq_s64_s32(a0), vreinterpretq_s64_s32(a1));
+#else
   b0.val[0] = vcombine_s64(vreinterpret_s64_s32(vget_low_s32(a0)),
                            vreinterpret_s64_s32(vget_low_s32(a1)));
   b0.val[1] = vcombine_s64(vreinterpret_s64_s32(vget_high_s32(a0)),
                            vreinterpret_s64_s32(vget_high_s32(a1)));
+#endif
   return b0;
 }
 
 static INLINE uint8x16x2_t vpx_vtrnq_u64_to_u8(uint32x4_t a0, uint32x4_t a1) {
   uint8x16x2_t b0;
+#if defined(__aarch64__)
+  b0.val[0] = vreinterpretq_u8_u64(
+      vtrn1q_u64(vreinterpretq_u64_u32(a0), vreinterpretq_u64_u32(a1)));
+  b0.val[1] = vreinterpretq_u8_u64(
+      vtrn2q_u64(vreinterpretq_u64_u32(a0), vreinterpretq_u64_u32(a1)));
+#else
   b0.val[0] = vcombine_u8(vreinterpret_u8_u32(vget_low_u32(a0)),
                           vreinterpret_u8_u32(vget_low_u32(a1)));
   b0.val[1] = vcombine_u8(vreinterpret_u8_u32(vget_high_u32(a0)),
                           vreinterpret_u8_u32(vget_high_u32(a1)));
+#endif
   return b0;
 }
 
@@ -155,17 +174,13 @@ static INLINE void transpose_s16_4x4q(int16x8_t *a0, int16x8_t *a1) {
   // c0: 00 01 20 21  02 03 22 23
   // c1: 10 11 30 31  12 13 32 33
 
-  const int32x4_t c0 =
-      vcombine_s32(vget_low_s32(b0.val[0]), vget_low_s32(b0.val[1]));
-  const int32x4_t c1 =
-      vcombine_s32(vget_high_s32(b0.val[0]), vget_high_s32(b0.val[1]));
+  const int16x8x2_t c0 = vpx_vtrnq_s64_to_s16(b0.val[0], b0.val[1]);
 
   // Swap 16 bit elements resulting in:
   // d0.val[0]: 00 10 20 30  02 12 22 32
   // d0.val[1]: 01 11 21 31  03 13 23 33
 
-  const int16x8x2_t d0 =
-      vtrnq_s16(vreinterpretq_s16_s32(c0), vreinterpretq_s16_s32(c1));
+  const int16x8x2_t d0 = vtrnq_s16(c0.val[0], c0.val[1]);
 
   *a0 = d0.val[0];
   *a1 = d0.val[1];
@@ -186,17 +201,13 @@ static INLINE void transpose_u16_4x4q(uint16x8_t *a0, uint16x8_t *a1) {
   // c0: 00 01 20 21  02 03 22 23
   // c1: 10 11 30 31  12 13 32 33
 
-  const uint32x4_t c0 =
-      vcombine_u32(vget_low_u32(b0.val[0]), vget_low_u32(b0.val[1]));
-  const uint32x4_t c1 =
-      vcombine_u32(vget_high_u32(b0.val[0]), vget_high_u32(b0.val[1]));
+  const uint16x8x2_t c0 = vpx_vtrnq_u64_to_u16(b0.val[0], b0.val[1]);
 
   // Swap 16 bit elements resulting in:
   // d0.val[0]: 00 10 20 30  02 12 22 32
   // d0.val[1]: 01 11 21 31  03 13 23 33
 
-  const uint16x8x2_t d0 =
-      vtrnq_u16(vreinterpretq_u16_u32(c0), vreinterpretq_u16_u32(c1));
+  const uint16x8x2_t d0 = vtrnq_u16(c0.val[0], c0.val[1]);
 
   *a0 = d0.val[0];
   *a1 = d0.val[1];
@@ -295,7 +306,7 @@ static INLINE void transpose_s16_4x8(const int16x4_t a0, const int16x4_t a1,
                                      const int16x4_t a6, const int16x4_t a7,
                                      int16x8_t *const o0, int16x8_t *const o1,
                                      int16x8_t *const o2, int16x8_t *const o3) {
-  // Swap 16 bit elements. Goes from:
+  // Combine rows. Goes from:
   // a0: 00 01 02 03
   // a1: 10 11 12 13
   // a2: 20 21 22 23
@@ -305,53 +316,40 @@ static INLINE void transpose_s16_4x8(const int16x4_t a0, const int16x4_t a1,
   // a6: 60 61 62 63
   // a7: 70 71 72 73
   // to:
-  // b0.val[0]: 00 10 02 12
-  // b0.val[1]: 01 11 03 13
-  // b1.val[0]: 20 30 22 32
-  // b1.val[1]: 21 31 23 33
-  // b2.val[0]: 40 50 42 52
-  // b2.val[1]: 41 51 43 53
-  // b3.val[0]: 60 70 62 72
-  // b3.val[1]: 61 71 63 73
+  // b0: 00 01 02 03 40 41 42 43
+  // b1: 10 11 12 13 50 51 52 53
+  // b2: 20 21 22 23 60 61 62 63
+  // b3: 30 31 32 33 70 71 72 73
+
+  const int16x8_t b0 = vcombine_s16(a0, a4);
+  const int16x8_t b1 = vcombine_s16(a1, a5);
+  const int16x8_t b2 = vcombine_s16(a2, a6);
+  const int16x8_t b3 = vcombine_s16(a3, a7);
 
-  const int16x4x2_t b0 = vtrn_s16(a0, a1);
-  const int16x4x2_t b1 = vtrn_s16(a2, a3);
-  const int16x4x2_t b2 = vtrn_s16(a4, a5);
-  const int16x4x2_t b3 = vtrn_s16(a6, a7);
+  // Swap 16 bit elements resulting in:
+  // c0.val[0]: 00 10 02 12 40 50 42 52
+  // c0.val[1]: 01 11 03 13 41 51 43 53
+  // c1.val[0]: 20 30 22 32 60 70 62 72
+  // c1.val[1]: 21 31 23 33 61 71 63 73
+
+  const int16x8x2_t c0 = vtrnq_s16(b0, b1);
+  const int16x8x2_t c1 = vtrnq_s16(b2, b3);
 
   // Swap 32 bit elements resulting in:
-  // c0.val[0]: 00 10 20 30
-  // c0.val[1]: 02 12 22 32
-  // c1.val[0]: 01 11 21 31
-  // c1.val[1]: 03 13 23 33
-  // c2.val[0]: 40 50 60 70
-  // c2.val[1]: 42 52 62 72
-  // c3.val[0]: 41 51 61 71
-  // c3.val[1]: 43 53 63 73
+  // d0.val[0]: 00 10 20 30 40 50 60 70
+  // d0.val[1]: 02 12 22 32 42 52 62 72
+  // d1.val[0]: 01 11 21 31 41 51 61 71
+  // d1.val[1]: 03 13 23 33 43 53 63 73
 
-  const int32x2x2_t c0 = vtrn_s32(vreinterpret_s32_s16(b0.val[0]),
-                                  vreinterpret_s32_s16(b1.val[0]));
-  const int32x2x2_t c1 = vtrn_s32(vreinterpret_s32_s16(b0.val[1]),
-                                  vreinterpret_s32_s16(b1.val[1]));
-  const int32x2x2_t c2 = vtrn_s32(vreinterpret_s32_s16(b2.val[0]),
-                                  vreinterpret_s32_s16(b3.val[0]));
-  const int32x2x2_t c3 = vtrn_s32(vreinterpret_s32_s16(b2.val[1]),
-                                  vreinterpret_s32_s16(b3.val[1]));
+  const int32x4x2_t d0 = vtrnq_s32(vreinterpretq_s32_s16(c0.val[0]),
+                                   vreinterpretq_s32_s16(c1.val[0]));
+  const int32x4x2_t d1 = vtrnq_s32(vreinterpretq_s32_s16(c0.val[1]),
+                                   vreinterpretq_s32_s16(c1.val[1]));
 
-  // Swap 64 bit elements resulting in:
-  // o0: 00 10 20 30 40 50 60 70
-  // o1: 01 11 21 31 41 51 61 71
-  // o2: 02 12 22 32 42 52 62 72
-  // o3: 03 13 23 33 43 53 63 73
-
-  *o0 = vcombine_s16(vreinterpret_s16_s32(c0.val[0]),
-                     vreinterpret_s16_s32(c2.val[0]));
-  *o1 = vcombine_s16(vreinterpret_s16_s32(c1.val[0]),
-                     vreinterpret_s16_s32(c3.val[0]));
-  *o2 = vcombine_s16(vreinterpret_s16_s32(c0.val[1]),
-                     vreinterpret_s16_s32(c2.val[1]));
-  *o3 = vcombine_s16(vreinterpret_s16_s32(c1.val[1]),
-                     vreinterpret_s16_s32(c3.val[1]));
+  *o0 = vreinterpretq_s16_s32(d0.val[0]);
+  *o1 = vreinterpretq_s16_s32(d1.val[0]);
+  *o2 = vreinterpretq_s16_s32(d0.val[1]);
+  *o3 = vreinterpretq_s16_s32(d1.val[1]);
 }
 
 static INLINE void transpose_s32_4x8(int32x4_t *const a0, int32x4_t *const a1,