Implement AVX Insert/Extract helper-intrinsics in managed code
author: Fei Peng <fei.peng@intel.com>
Wed, 14 Mar 2018 00:02:14 +0000 (17:02 -0700)
committer: Tanner Gooding <tagoo@outlook.com>
Fri, 16 Mar 2018 13:50:46 +0000 (06:50 -0700)
src/mscorlib/src/System/Runtime/Intrinsics/X86/Avx.cs

index 7fa6926..23f4127 100644 (file)
@@ -4,6 +4,7 @@
 
 using System;
 using System.Runtime.Intrinsics;
+using System.Runtime.CompilerServices;
 
 namespace System.Runtime.Intrinsics.X86
 {
@@ -238,42 +239,145 @@ namespace System.Runtime.Intrinsics.X86
         /// __int8 _mm256_extract_epi8 (__m256i a, const int index)
         ///   HELPER
         /// </summary>
-        public static sbyte Extract(Vector256<sbyte> value, byte index) => Extract(value, index);
+        [MethodImpl(MethodImplOptions.AggressiveInlining)]
+        public static sbyte Extract(Vector256<sbyte> value, byte index)
+        {
+            index &= 0x1F; // Keep only the low 5 bits, so any index wraps into 0..31.
+            if (index > 15) // 16..31: the element is in the upper 128-bit half (ExtractVector128 with index 1).
+            {
+                return Sse41.Extract(ExtractVector128(value, 1), (byte)(index - 16)); // Rebase to a 0..15 index within that half.
+            }
+            else
+            {
+                return Sse41.Extract(GetLowerHalf(value), index); // 0..15: the index is already relative to the lower half.
+            }
+        }
+
         /// <summary>
         /// __int8 _mm256_extract_epi8 (__m256i a, const int index)
         ///   HELPER
         /// </summary>
-        public static byte Extract(Vector256<byte> value, byte index) => Extract(value, index);
+        [MethodImpl(MethodImplOptions.AggressiveInlining)]
+        public static byte Extract(Vector256<byte> value, byte index)
+        {
+            index &= 0x1F; // Keep only the low 5 bits, so any index wraps into 0..31.
+            if (index > 15) // 16..31: the element is in the upper 128-bit half (ExtractVector128 with index 1).
+            {
+                return Sse41.Extract(ExtractVector128(value, 1), (byte)(index - 16)); // Rebase to a 0..15 index within that half.
+            }
+            else
+            {
+                return Sse41.Extract(GetLowerHalf(value), index); // 0..15: the index is already relative to the lower half.
+            }
+        }
+
         /// <summary>
         /// __int16 _mm256_extract_epi16 (__m256i a, const int index)
         ///   HELPER
         /// </summary>
-        public static short Extract(Vector256<short> value, byte index) => Extract(value, index);
+        [MethodImpl(MethodImplOptions.AggressiveInlining)]
+        public static short Extract(Vector256<short> value, byte index)
+        {
+            index &= 0xF; // Keep only the low 4 bits, so any index wraps into 0..15.
+            if (index > 7) // 8..15: the element is in the upper 128-bit half (ExtractVector128 with index 1).
+            {
+                return Sse2.Extract(ExtractVector128(value, 1), (byte)(index - 8)); // Rebase to a 0..7 index within that half.
+            }
+            else
+            {
+                return Sse2.Extract(GetLowerHalf(value), index); // 0..7: the index is already relative to the lower half.
+            }
+        }
+
         /// <summary>
         /// __int16 _mm256_extract_epi16 (__m256i a, const int index)
         ///   HELPER
         /// </summary>
-        public static ushort Extract(Vector256<ushort> value, byte index) => Extract(value, index);
+        [MethodImpl(MethodImplOptions.AggressiveInlining)]
+        public static ushort Extract(Vector256<ushort> value, byte index)
+        {
+            index &= 0xF; // Keep only the low 4 bits, so any index wraps into 0..15.
+            if (index > 7) // 8..15: the element is in the upper 128-bit half (ExtractVector128 with index 1).
+            {
+                return Sse2.Extract(ExtractVector128(value, 1), (byte)(index - 8)); // Rebase to a 0..7 index within that half.
+            }
+            else
+            {
+                return Sse2.Extract(GetLowerHalf(value), index); // 0..7: the index is already relative to the lower half.
+            }
+        }
+
         /// <summary>
         /// __int32 _mm256_extract_epi32 (__m256i a, const int index)
         ///   HELPER
         /// </summary>
-        public static int Extract(Vector256<int> value, byte index) => Extract(value, index);
+        [MethodImpl(MethodImplOptions.AggressiveInlining)]
+        public static int Extract(Vector256<int> value, byte index)
+        {
+            index &= 0x7; // Keep only the low 3 bits, so any index wraps into 0..7.
+            if (index > 3) // 4..7: the element is in the upper 128-bit half (ExtractVector128 with index 1).
+            {
+                return Sse41.Extract(ExtractVector128(value, 1), (byte)(index - 4)); // Rebase to a 0..3 index within that half.
+            }
+            else
+            {
+                return Sse41.Extract(GetLowerHalf(value), index); // 0..3: the index is already relative to the lower half.
+            }
+        }
+
         /// <summary>
         /// __int32 _mm256_extract_epi32 (__m256i a, const int index)
         ///   HELPER
         /// </summary>
-        public static uint Extract(Vector256<uint> value, byte index) => Extract(value, index);
+        [MethodImpl(MethodImplOptions.AggressiveInlining)]
+        public static uint Extract(Vector256<uint> value, byte index)
+        {
+            index &= 0x7; // Keep only the low 3 bits, so any index wraps into 0..7.
+            if (index > 3) // 4..7: the element is in the upper 128-bit half (ExtractVector128 with index 1).
+            {
+                return Sse41.Extract(ExtractVector128(value, 1), (byte)(index - 4)); // Rebase to a 0..3 index within that half.
+            }
+            else
+            {
+                return Sse41.Extract(GetLowerHalf(value), index); // 0..3: the index is already relative to the lower half.
+            }
+        }
+
         /// <summary>
         /// __int64 _mm256_extract_epi64 (__m256i a, const int index)
         ///   HELPER
         /// </summary>
-        public static long Extract(Vector256<long> value, byte index) => Extract(value, index);
+        [MethodImpl(MethodImplOptions.AggressiveInlining)]
+        public static long Extract(Vector256<long> value, byte index)
+        {
+            index &= 0x3; // Keep only the low 2 bits, so any index wraps into 0..3.
+            if (index > 1) // 2..3: the element is in the upper 128-bit half (ExtractVector128 with index 1).
+            {
+                return Sse41.Extract(ExtractVector128(value, 1), (byte)(index - 2)); // Rebase to a 0..1 index within that half.
+            }
+            else
+            {
+                return Sse41.Extract(GetLowerHalf(value), index); // 0..1: the index is already relative to the lower half.
+            }
+        }
+
         /// <summary>
         /// __int64 _mm256_extract_epi64 (__m256i a, const int index)
         ///   HELPER
         /// </summary>
-        public static ulong Extract(Vector256<ulong> value, byte index) => Extract(value, index);
+        [MethodImpl(MethodImplOptions.AggressiveInlining)]
+        public static ulong Extract(Vector256<ulong> value, byte index)
+        {
+            index &= 0x3; // Keep only the low 2 bits, so any index wraps into 0..3.
+            if (index > 1) // 2..3: the element is in the upper 128-bit half (ExtractVector128 with index 1).
+            {
+                return Sse41.Extract(ExtractVector128(value, 1), (byte)(index - 2)); // Rebase to a 0..1 index within that half.
+            }
+            else
+            {
+                return Sse41.Extract(GetLowerHalf(value), index); // 0..1: the index is already relative to the lower half.
+            }
+        }
 
         /// <summary>
         /// __m128 _mm256_extractf128_ps (__m256 a, const int imm8)
@@ -402,42 +506,169 @@ namespace System.Runtime.Intrinsics.X86
         /// __m256i _mm256_insert_epi8 (__m256i a, __int8 i, const int index)
         ///   HELPER
         /// </summary>
-        public static Vector256<sbyte> Insert(Vector256<sbyte> value, sbyte data, byte index) => Insert(value, data, index);
+        [MethodImpl(MethodImplOptions.AggressiveInlining)]
+        public static Vector256<sbyte> Insert(Vector256<sbyte> value, sbyte data, byte index)
+        {
+            index &= 0x1F; // Keep only the low 5 bits, so any index wraps into 0..31.
+            if (index > 15) // 16..31: the target element is in the upper 128-bit half.
+            {
+                Vector128<sbyte> half = ExtractVector128(value, 1); // Pull out the upper half,
+                half = Sse41.Insert(half, data, (byte)(index - 16)); // patch the element at its 0..15 position within that half,
+                return InsertVector128(value, half, 1); // and put the patched half back at position 1 of value.
+            }
+            else
+            {
+                Vector128<sbyte> half = Sse41.Insert(GetLowerHalf(value), data, index); // 0..15: patch the element directly in the lower half.
+                return StaticCast<float, sbyte>(Blend(StaticCast<sbyte, float>(value), StaticCast<sbyte, float>(ExtendToVector256(half)), 15)); // NOTE(review): mask 15 (0b1111) presumably takes the four low float lanes (low 128 bits) from the patched half and the rest from value - confirm against Blend's semantics.
+            }
+        }
+
         /// <summary>
         /// __m256i _mm256_insert_epi8 (__m256i a, __int8 i, const int index)
         ///   HELPER
         /// </summary>
-        public static Vector256<byte> Insert(Vector256<byte> value, byte data, byte index) => Insert(value, data, index);
+        [MethodImpl(MethodImplOptions.AggressiveInlining)]
+        public static Vector256<byte> Insert(Vector256<byte> value, byte data, byte index)
+        {
+            index &= 0x1F; // Keep only the low 5 bits, so any index wraps into 0..31.
+            if (index > 15) // 16..31: the target element is in the upper 128-bit half.
+            {
+                Vector128<byte> half = ExtractVector128(value, 1); // Pull out the upper half,
+                half = Sse41.Insert(half, data, (byte)(index - 16)); // patch the element at its 0..15 position within that half,
+                return InsertVector128(value, half, 1); // and put the patched half back at position 1 of value.
+            }
+            else
+            {
+                Vector128<byte> half = Sse41.Insert(GetLowerHalf(value), data, index); // 0..15: patch the element directly in the lower half.
+                return StaticCast<float, byte>(Blend(StaticCast<byte, float>(value), StaticCast<byte, float>(ExtendToVector256(half)), 15)); // NOTE(review): mask 15 (0b1111) presumably takes the four low float lanes (low 128 bits) from the patched half and the rest from value - confirm against Blend's semantics.
+            }
+        }
+        
         /// <summary>
         /// __m256i _mm256_insert_epi16 (__m256i a, __int16 i, const int index)
         ///   HELPER
         /// </summary>
-        public static Vector256<short> Insert(Vector256<short> value, short data, byte index) => Insert(value, data, index);
+        [MethodImpl(MethodImplOptions.AggressiveInlining)]
+        public static Vector256<short> Insert(Vector256<short> value, short data, byte index)
+        {
+            index &= 0xF; // Keep only the low 4 bits, so any index wraps into 0..15.
+            if (index > 7) // 8..15: the target element is in the upper 128-bit half.
+            {
+                Vector128<short> half = ExtractVector128(value, 1); // Pull out the upper half,
+                half = Sse2.Insert(half, data, (byte)(index - 8)); // patch the element at its 0..7 position within that half,
+                return InsertVector128(value, half, 1); // and put the patched half back at position 1 of value.
+            }
+            else
+            {
+                Vector128<short> half = Sse2.Insert(GetLowerHalf(value), data, index); // 0..7: patch the element directly in the lower half.
+                return StaticCast<float, short>(Blend(StaticCast<short, float>(value), StaticCast<short, float>(ExtendToVector256(half)), 15)); // NOTE(review): mask 15 (0b1111) presumably takes the four low float lanes (low 128 bits) from the patched half and the rest from value - confirm against Blend's semantics.
+            }
+        }
+
         /// <summary>
         /// __m256i _mm256_insert_epi16 (__m256i a, __int16 i, const int index)
         ///   HELPER
         /// </summary>
-        public static Vector256<ushort> Insert(Vector256<ushort> value, ushort data, byte index) => Insert(value, data, index);
+        [MethodImpl(MethodImplOptions.AggressiveInlining)]
+        public static Vector256<ushort> Insert(Vector256<ushort> value, ushort data, byte index)
+        {
+            index &= 0xF; // Keep only the low 4 bits, so any index wraps into 0..15.
+            if (index > 7) // 8..15: the target element is in the upper 128-bit half.
+            {
+                Vector128<ushort> half = ExtractVector128(value, 1); // Pull out the upper half,
+                half = Sse2.Insert(half, data, (byte)(index - 8)); // patch the element at its 0..7 position within that half,
+                return InsertVector128(value, half, 1); // and put the patched half back at position 1 of value.
+            }
+            else
+            {
+                Vector128<ushort> half = Sse2.Insert(GetLowerHalf(value), data, index); // 0..7: patch the element directly in the lower half.
+                return StaticCast<float, ushort>(Blend(StaticCast<ushort, float>(value), StaticCast<ushort, float>(ExtendToVector256(half)), 15)); // NOTE(review): mask 15 (0b1111) presumably takes the four low float lanes (low 128 bits) from the patched half and the rest from value - confirm against Blend's semantics.
+            }
+        }
+
         /// <summary>
         /// __m256i _mm256_insert_epi32 (__m256i a, __int32 i, const int index)
         ///   HELPER
         /// </summary>
-        public static Vector256<int> Insert(Vector256<int> value, int data, byte index) => Insert(value, data, index);
+        [MethodImpl(MethodImplOptions.AggressiveInlining)]
+        public static Vector256<int> Insert(Vector256<int> value, int data, byte index)
+        {
+            index &= 0x7; // Keep only the low 3 bits, so any index wraps into 0..7.
+            if (index > 3) // 4..7: the target element is in the upper 128-bit half.
+            {
+                Vector128<int> half = ExtractVector128(value, 1); // Pull out the upper half,
+                half = Sse41.Insert(half, data, (byte)(index - 4)); // patch the element at its 0..3 position within that half,
+                return InsertVector128(value, half, 1); // and put the patched half back at position 1 of value.
+            }
+            else
+            {
+                Vector128<int> half = Sse41.Insert(GetLowerHalf(value), data, index); // 0..3: patch the element directly in the lower half.
+                return StaticCast<float, int>(Blend(StaticCast<int, float>(value), StaticCast<int, float>(ExtendToVector256(half)), 15)); // NOTE(review): mask 15 (0b1111) presumably takes the four low float lanes (low 128 bits) from the patched half and the rest from value - confirm against Blend's semantics.
+            }
+        }
+        
         /// <summary>
         /// __m256i _mm256_insert_epi32 (__m256i a, __int32 i, const int index)
         ///   HELPER
         /// </summary>
-        public static Vector256<uint> Insert(Vector256<uint> value, uint data, byte index) => Insert(value, data, index);
+        [MethodImpl(MethodImplOptions.AggressiveInlining)]
+        public static Vector256<uint> Insert(Vector256<uint> value, uint data, byte index)
+        {
+            index &= 0x7; // Keep only the low 3 bits, so any index wraps into 0..7.
+            if (index > 3) // 4..7: the target element is in the upper 128-bit half.
+            {
+                Vector128<uint> half = ExtractVector128(value, 1); // Pull out the upper half,
+                half = Sse41.Insert(half, data, (byte)(index - 4)); // patch the element at its 0..3 position within that half,
+                return InsertVector128(value, half, 1); // and put the patched half back at position 1 of value.
+            }
+            else
+            {
+                Vector128<uint> half = Sse41.Insert(GetLowerHalf(value), data, index); // 0..3: patch the element directly in the lower half.
+                return StaticCast<float, uint>(Blend(StaticCast<uint, float>(value), StaticCast<uint, float>(ExtendToVector256(half)), 15)); // NOTE(review): mask 15 (0b1111) presumably takes the four low float lanes (low 128 bits) from the patched half and the rest from value - confirm against Blend's semantics.
+            }
+        }
+
         /// <summary>
         /// __m256i _mm256_insert_epi64 (__m256i a, __int64 i, const int index)
         ///   HELPER
         /// </summary>
-        public static Vector256<long> Insert(Vector256<long> value, long data, byte index) => Insert(value, data, index);
+        [MethodImpl(MethodImplOptions.AggressiveInlining)]
+        public static Vector256<long> Insert(Vector256<long> value, long data, byte index)
+        {
+            index &= 0x3; // Keep only the low 2 bits, so any index wraps into 0..3.
+            if (index > 1) // 2..3: the target element is in the upper 128-bit half.
+            {
+                Vector128<long> half = ExtractVector128(value, 1); // Pull out the upper half,
+                half = Sse41.Insert(half, data, (byte)(index - 2)); // patch the element at its 0..1 position within that half,
+                return InsertVector128(value, half, 1); // and put the patched half back at position 1 of value.
+            }
+            else
+            {
+                Vector128<long> half = Sse41.Insert(GetLowerHalf(value), data, index); // 0..1: patch the element directly in the lower half.
+                return StaticCast<float, long>(Blend(StaticCast<long, float>(value), StaticCast<long, float>(ExtendToVector256(half)), 15)); // NOTE(review): mask 15 (0b1111) presumably takes the four low float lanes (low 128 bits) from the patched half and the rest from value - confirm against Blend's semantics.
+            }
+        }
+
         /// <summary>
         /// __m256i _mm256_insert_epi64 (__m256i a, __int64 i, const int index)
         ///   HELPER
         /// </summary>
-        public static Vector256<ulong> Insert(Vector256<ulong> value, ulong data, byte index) => Insert(value, data, index);
+        [MethodImpl(MethodImplOptions.AggressiveInlining)]
+        public static Vector256<ulong> Insert(Vector256<ulong> value, ulong data, byte index)
+        {
+            index &= 0x3; // Keep only the low 2 bits, so any index wraps into 0..3.
+            if (index > 1) // 2..3: the target element is in the upper 128-bit half.
+            {
+                Vector128<ulong> half = ExtractVector128(value, 1); // Pull out the upper half,
+                half = Sse41.Insert(half, data, (byte)(index - 2)); // patch the element at its 0..1 position within that half,
+                return InsertVector128(value, half, 1); // and put the patched half back at position 1 of value.
+            }
+            else
+            {
+                Vector128<ulong> half = Sse41.Insert(GetLowerHalf(value), data, index); // 0..1: patch the element directly in the lower half.
+                return StaticCast<float, ulong>(Blend(StaticCast<ulong, float>(value), StaticCast<ulong, float>(ExtendToVector256(half)), 15)); // NOTE(review): mask 15 (0b1111) presumably takes the four low float lanes (low 128 bits) from the patched half and the rest from value - confirm against Blend's semantics.
+            }
+        }
 
         /// <summary>
         /// __m256 _mm256_insertf128_ps (__m256 a, __m128 b, int imm8)