Vector.Sum(Vector<T>) API implementation for horizontal add. (#53527)
author Ivan Zlatanov <ivan@zlatanov.net>
Fri, 11 Jun 2021 16:50:09 +0000 (19:50 +0300)
committer GitHub <noreply@github.com>
Fri, 11 Jun 2021 16:50:09 +0000 (09:50 -0700)
* Vector.Sum(Vector<T>) API implementation for horizontal add.

* Fixed incorrect reference to Arm64 AddAcross intrinsic function.

* Added implementation for hardware accelerated Vector<T>.Sum for long, ulong, float, double on ARM64.

* Fixed formatting issue.

* Correctness.

* Fixed compiler error for ARM64.

* Formatting issue.

* More explicit switch statement. Fixed wrong simd size for NI_Vector64_ToScalar.

* Fixed auto formatting issue.

* Use AddPairwiseScalar for double, long and ulong on ARM64 for VectorT128_Sum.

* Forgot ToScalar call after AddPairwiseScalar.

* Fixed wrong return type.

src/coreclr/jit/simdashwintrinsic.cpp
src/coreclr/jit/simdashwintrinsiclistarm64.h
src/coreclr/jit/simdashwintrinsiclistxarch.h
src/libraries/System.Numerics.Vectors/ref/System.Numerics.Vectors.cs
src/libraries/System.Numerics.Vectors/tests/GenericVectorTests.cs
src/libraries/System.Private.CoreLib/src/System/Numerics/Vector.cs
src/libraries/System.Private.CoreLib/src/System/Numerics/Vector_1.cs

index e8f742c..8e8b760 100644 (file)
@@ -719,12 +719,118 @@ GenTree* Compiler::impSimdAsHWIntrinsicSpecial(NamedIntrinsic       intrinsic,
                     }
                     break;
                 }
+                case NI_VectorT128_Sum:
+                {
+                    if (compOpportunisticallyDependsOn(InstructionSet_SSSE3))
+                    {
+                        GenTree* tmp;
+                        unsigned vectorLength = getSIMDVectorLength(simdSize, simdBaseType);
+                        int      haddCount    = genLog2(vectorLength);
+
+                        for (int i = 0; i < haddCount; i++)
+                        {
+                            op1 = impCloneExpr(op1, &tmp, clsHnd, (unsigned)CHECK_SPILL_ALL,
+                                               nullptr DEBUGARG("Clone op1 for Vector<T>.Sum"));
+                            op1 = gtNewSimdAsHWIntrinsicNode(simdType, op1, tmp, NI_SSSE3_HorizontalAdd,
+                                                             simdBaseJitType, simdSize);
+                        }
+
+                        return gtNewSimdAsHWIntrinsicNode(retType, op1, NI_Vector128_ToScalar, simdBaseJitType,
+                                                          simdSize);
+                    }
+
+                    return nullptr;
+                }
+                case NI_VectorT256_Sum:
+                {
+                    // HorizontalAdd combines pairs so we need log2(vectorLength) passes to sum all elements together.
+                    unsigned vectorLength = getSIMDVectorLength(simdSize, simdBaseType);
+                    int haddCount = genLog2(vectorLength) - 1; // Minus 1 because for the last pass we split the vector
+                                                               // to low / high and add them together.
+                    GenTree*       tmp;
+                    NamedIntrinsic horizontalAdd = NI_AVX2_HorizontalAdd;
+                    NamedIntrinsic add           = NI_SSE2_Add;
+
+                    if (simdBaseType == TYP_DOUBLE)
+                    {
+                        horizontalAdd = NI_AVX_HorizontalAdd;
+                    }
+                    else if (simdBaseType == TYP_FLOAT)
+                    {
+                        horizontalAdd = NI_AVX_HorizontalAdd;
+                        add           = NI_SSE_Add;
+                    }
+
+                    for (int i = 0; i < haddCount; i++)
+                    {
+                        op1 = impCloneExpr(op1, &tmp, clsHnd, (unsigned)CHECK_SPILL_ALL,
+                                           nullptr DEBUGARG("Clone op1 for Vector<T>.Sum"));
+                        op1 = gtNewSimdAsHWIntrinsicNode(simdType, op1, tmp, horizontalAdd, simdBaseJitType, simdSize);
+                    }
+
+                    op1 = impCloneExpr(op1, &tmp, clsHnd, (unsigned)CHECK_SPILL_ALL,
+                                       nullptr DEBUGARG("Clone op1 for Vector<T>.Sum"));
+                    op1 = gtNewSimdAsHWIntrinsicNode(TYP_SIMD16, op1, gtNewIconNode(0x01, TYP_INT),
+                                                     NI_AVX_ExtractVector128, simdBaseJitType, simdSize);
+
+                    tmp = gtNewSimdAsHWIntrinsicNode(simdType, tmp, NI_Vector256_GetLower, simdBaseJitType, simdSize);
+                    op1 = gtNewSimdAsHWIntrinsicNode(TYP_SIMD16, op1, tmp, add, simdBaseJitType, 16);
+
+                    return gtNewSimdAsHWIntrinsicNode(retType, op1, NI_Vector128_ToScalar, simdBaseJitType, 16);
+                }
 #elif defined(TARGET_ARM64)
                 case NI_VectorT128_Abs:
                 {
                     assert(varTypeIsUnsigned(simdBaseType));
                     return op1;
                 }
+                case NI_VectorT128_Sum: // Builds the Vector<T>.Sum tree on ARM64; the intrinsic sequence depends on simdBaseType.
+                {
+                    GenTree* tmp;
+
+                    switch (simdBaseType)
+                    {
+                        case TYP_BYTE:
+                        case TYP_UBYTE:
+                        case TYP_SHORT:
+                        case TYP_USHORT:
+                        case TYP_INT:
+                        case TYP_UINT:
+                        {
+                            tmp = gtNewSimdAsHWIntrinsicNode(simdType, op1, NI_AdvSimd_Arm64_AddAcross, simdBaseJitType,
+                                                             simdSize); // 8/16/32-bit integers: a single AddAcross over op1
+                            // NOTE(review): this node is typed simdType but is consumed below as a Vector64 (size 8),
+                            // while the AddPairwiseScalar path uses TYP_SIMD8 - confirm the intended node type.
+                            return gtNewSimdAsHWIntrinsicNode(retType, tmp, NI_Vector64_ToScalar, simdBaseJitType, 8);
+                        }
+                        case TYP_FLOAT:
+                        {
+                            unsigned vectorLength = getSIMDVectorLength(simdSize, simdBaseType);
+                            int      haddCount    = genLog2(vectorLength); // one pairwise-add pass per halving of the element count
+
+                            for (int i = 0; i < haddCount; i++)
+                            {
+                                op1 = impCloneExpr(op1, &tmp, clsHnd, (unsigned)CHECK_SPILL_ALL,
+                                                   nullptr DEBUGARG("Clone op1 for Vector<T>.Sum"));
+                                op1 = gtNewSimdAsHWIntrinsicNode(simdType, op1, tmp, NI_AdvSimd_Arm64_AddPairwise,
+                                                                 simdBaseJitType, simdSize); // pairwise add of op1 with its own clone
+                            }
+
+                            return gtNewSimdAsHWIntrinsicNode(retType, op1, NI_Vector128_ToScalar, simdBaseJitType,
+                                                              simdSize);
+                        }
+                        case TYP_DOUBLE:
+                        case TYP_LONG:
+                        case TYP_ULONG:
+                        {
+                            op1 = gtNewSimdAsHWIntrinsicNode(TYP_SIMD8, op1, NI_AdvSimd_Arm64_AddPairwiseScalar,
+                                                             simdBaseJitType, simdSize); // 64-bit elements: one AddPairwiseScalar, typed as an 8-byte vector
+                            return gtNewSimdAsHWIntrinsicNode(retType, op1, NI_Vector64_ToScalar, simdBaseJitType, 8);
+                        }
+                        default:
+                        {
+                            unreached(); // every base type handled by this case is listed above
+                        }
+                    }
+                }
 #else
 #error Unsupported platform
 #endif // !TARGET_XARCH && !TARGET_ARM64
index 4eba541..2292228 100644 (file)
@@ -132,6 +132,7 @@ SIMD_AS_HWINTRINSIC_ID(VectorT128,  op_Inequality,
 SIMD_AS_HWINTRINSIC_ID(VectorT128,  op_Multiply,                                            2,         {NI_VectorT128_op_Multiply,                     NI_VectorT128_op_Multiply,                      NI_VectorT128_op_Multiply,                      NI_VectorT128_op_Multiply,                      NI_VectorT128_op_Multiply,                      NI_VectorT128_op_Multiply,                      NI_Illegal,                                     NI_Illegal,                                     NI_VectorT128_op_Multiply,                      NI_VectorT128_op_Multiply},                     SimdAsHWIntrinsicFlag::None)
 SIMD_AS_HWINTRINSIC_ID(VectorT128,  op_Subtraction,                                         2,         {NI_AdvSimd_Subtract,                           NI_AdvSimd_Subtract,                            NI_AdvSimd_Subtract,                            NI_AdvSimd_Subtract,                            NI_AdvSimd_Subtract,                            NI_AdvSimd_Subtract,                            NI_AdvSimd_Subtract,                            NI_AdvSimd_Subtract,                            NI_AdvSimd_Subtract,                            NI_AdvSimd_Arm64_Subtract},                     SimdAsHWIntrinsicFlag::None)
 SIMD_AS_HWINTRINSIC_ID(VectorT128,  SquareRoot,                                             1,         {NI_Illegal,                                    NI_Illegal,                                     NI_Illegal,                                     NI_Illegal,                                     NI_Illegal,                                     NI_Illegal,                                     NI_Illegal,                                     NI_Illegal,                                     NI_AdvSimd_Arm64_Sqrt,                          NI_AdvSimd_Arm64_Sqrt},                         SimdAsHWIntrinsicFlag::None)
+SIMD_AS_HWINTRINSIC_ID(VectorT128,  Sum,                                                    1,         {NI_VectorT128_Sum,                             NI_VectorT128_Sum,                              NI_VectorT128_Sum,                              NI_VectorT128_Sum,                              NI_VectorT128_Sum,                              NI_VectorT128_Sum,                              NI_VectorT128_Sum,                              NI_VectorT128_Sum,                              NI_VectorT128_Sum,                              NI_VectorT128_Sum},                             SimdAsHWIntrinsicFlag::None)
 
 #undef SIMD_AS_HWINTRINSIC_NM
 #undef SIMD_AS_HWINTRINSIC_ID
index af75fb7..92d665c 100644 (file)
@@ -132,6 +132,7 @@ SIMD_AS_HWINTRINSIC_ID(VectorT128,  op_Inequality,
 SIMD_AS_HWINTRINSIC_ID(VectorT128,  op_Multiply,                                            2,         {NI_Illegal,                                NI_Illegal,                                 NI_VectorT128_op_Multiply,                  NI_VectorT128_op_Multiply,                  NI_VectorT128_op_Multiply,                  NI_VectorT128_op_Multiply,                  NI_Illegal,                                 NI_Illegal,                                 NI_VectorT128_op_Multiply,                  NI_VectorT128_op_Multiply},                 SimdAsHWIntrinsicFlag::None)
 SIMD_AS_HWINTRINSIC_ID(VectorT128,  op_Subtraction,                                         2,         {NI_SSE2_Subtract,                          NI_SSE2_Subtract,                           NI_SSE2_Subtract,                           NI_SSE2_Subtract,                           NI_SSE2_Subtract,                           NI_SSE2_Subtract,                           NI_SSE2_Subtract,                           NI_SSE2_Subtract,                           NI_SSE_Subtract,                            NI_SSE2_Subtract},                          SimdAsHWIntrinsicFlag::None)
 SIMD_AS_HWINTRINSIC_ID(VectorT128,  SquareRoot,                                             1,         {NI_Illegal,                                NI_Illegal,                                 NI_Illegal,                                 NI_Illegal,                                 NI_Illegal,                                 NI_Illegal,                                 NI_Illegal,                                 NI_Illegal,                                 NI_SSE_Sqrt,                                NI_SSE2_Sqrt},                              SimdAsHWIntrinsicFlag::None)
+SIMD_AS_HWINTRINSIC_ID(VectorT128,  Sum,                                                    1,         {NI_Illegal,                                NI_Illegal,                                 NI_VectorT128_Sum,                          NI_VectorT128_Sum,                          NI_VectorT128_Sum,                          NI_VectorT128_Sum,                          NI_Illegal,                                 NI_Illegal,                                 NI_VectorT128_Sum,                          NI_VectorT128_Sum},                         SimdAsHWIntrinsicFlag::None)
 
 // *************************************************************************************************************************************************************************************************************************************************************************************************************************************************************************************************************************************************************************************************************************************************************************
 //                     ISA          ID                          Name                        NumArg                                                                                                                                                                                                      Instructions                                                                                                                                                                                                                                           Flags
@@ -170,6 +171,7 @@ SIMD_AS_HWINTRINSIC_ID(VectorT256,  op_Inequality,
 SIMD_AS_HWINTRINSIC_ID(VectorT256,  op_Multiply,                                            2,         {NI_Illegal,                                NI_Illegal,                                 NI_VectorT256_op_Multiply,                  NI_VectorT256_op_Multiply,                  NI_VectorT256_op_Multiply,                  NI_VectorT256_op_Multiply,                  NI_Illegal,                                 NI_Illegal,                                 NI_VectorT256_op_Multiply,                  NI_VectorT256_op_Multiply},                 SimdAsHWIntrinsicFlag::None)
 SIMD_AS_HWINTRINSIC_ID(VectorT256,  op_Subtraction,                                         2,         {NI_AVX2_Subtract,                          NI_AVX2_Subtract,                           NI_AVX2_Subtract,                           NI_AVX2_Subtract,                           NI_AVX2_Subtract,                           NI_AVX2_Subtract,                           NI_AVX2_Subtract,                           NI_AVX2_Subtract,                           NI_AVX_Subtract,                            NI_AVX_Subtract},                           SimdAsHWIntrinsicFlag::None)
 SIMD_AS_HWINTRINSIC_ID(VectorT256,  SquareRoot,                                             1,         {NI_Illegal,                                NI_Illegal,                                 NI_Illegal,                                 NI_Illegal,                                 NI_Illegal,                                 NI_Illegal,                                 NI_Illegal,                                 NI_Illegal,                                 NI_AVX_Sqrt,                                NI_AVX_Sqrt},                               SimdAsHWIntrinsicFlag::None)
+SIMD_AS_HWINTRINSIC_ID(VectorT256,  Sum,                                                    1,         {NI_Illegal,                                NI_Illegal,                                 NI_VectorT256_Sum,                          NI_VectorT256_Sum,                          NI_VectorT256_Sum,                          NI_VectorT256_Sum,                          NI_Illegal,                                 NI_Illegal,                                 NI_VectorT256_Sum,                          NI_VectorT256_Sum},                         SimdAsHWIntrinsicFlag::None)
 
 #undef SIMD_AS_HWINTRINSIC_NM
 #undef SIMD_AS_HWINTRINSIC_ID
index 4e9121f..1824bd3 100644 (file)
@@ -300,6 +300,7 @@ namespace System.Numerics
         [System.CLSCompliantAttribute(false)]
         public static void Widen(System.Numerics.Vector<System.UInt32> source, out System.Numerics.Vector<System.UInt64> low, out System.Numerics.Vector<System.UInt64> high) { throw null; }
         public static System.Numerics.Vector<T> Xor<T>(System.Numerics.Vector<T> left, System.Numerics.Vector<T> right) where T : struct { throw null; }
+        public static T Sum<T>(System.Numerics.Vector<T> value) where T : struct { throw null; } // API-surface declaration for Vector.Sum<T>; the body is only a `throw null` placeholder.
     }
     public partial struct Vector2 : System.IEquatable<System.Numerics.Vector2>, System.IFormattable
     {
index 6124cd6..aa8113c 100644 (file)
@@ -3137,6 +3137,49 @@ namespace System.Numerics.Tests
         }
         #endregion
 
+        #region Sum
+
+        [Fact]
+        public void SumInt32()
+        {
+            // Reference value: left-to-right fold of the generated elements.
+            TestSum<int>(elements => elements.Aggregate((acc, next) => acc + next));
+        }
+
+        [Fact]
+        public void SumInt64()
+        {
+            TestSum<long>(elements => elements.Aggregate((acc, next) => acc + next));
+        }
+
+        [Fact]
+        public void SumSingle()
+        {
+            TestSum<float>(elements => elements.Aggregate((acc, next) => acc + next));
+        }
+
+        [Fact]
+        public void SumDouble()
+        {
+            TestSum<double>(elements => elements.Aggregate((acc, next) => acc + next));
+        }
+
+        [Fact]
+        public void SumUInt32()
+        {
+            TestSum<uint>(elements => elements.Aggregate((acc, next) => acc + next));
+        }
+
+        [Fact]
+        public void SumUInt64()
+        {
+            TestSum<ulong>(elements => elements.Aggregate((acc, next) => acc + next));
+        }
+
+        [Fact]
+        public void SumByte()
+        {
+            // Small integer types are promoted to int by '+', so each step is cast back to the element type.
+            TestSum<byte>(elements => elements.Aggregate((acc, next) => (byte)(acc + next)));
+        }
+
+        [Fact]
+        public void SumSByte()
+        {
+            TestSum<sbyte>(elements => elements.Aggregate((acc, next) => (sbyte)(acc + next)));
+        }
+
+        [Fact]
+        public void SumInt16()
+        {
+            TestSum<short>(elements => elements.Aggregate((acc, next) => (short)(acc + next)));
+        }
+
+        [Fact]
+        public void SumUInt16()
+        {
+            TestSum<ushort>(elements => elements.Aggregate((acc, next) => (ushort)(acc + next)));
+        }
+
+        private static void TestSum<T>(Func<T[], T> expected) where T : struct, IEquatable<T>
+        {
+            // Sum a vector built from GenerateRandomValuesForVector<T>() with Vector.Sum, then compare
+            // the result against the caller-supplied reference computation over the same elements.
+            T[] elements = GenerateRandomValuesForVector<T>();
+            T actual = Vector.Sum(new Vector<T>(elements));
+            T reference = expected(elements);
+
+            AssertEqual(reference, actual, "Sum");
+        }
+
+        #endregion
+
         #region Helper Methods
         private static void AssertEqual<T>(T expected, T actual, string operation, int precision = -1) where T : IEquatable<T>
         {
index c563a5d..c0464fe 100644 (file)
@@ -1292,5 +1292,14 @@ namespace System.Numerics
 
             return Unsafe.As<Vector<TFrom>, Vector<TTo>>(ref vector);
         }
+
+        /// <summary>
+        /// Computes the sum of every element in <paramref name="value"/>.
+        /// </summary>
+        /// <typeparam name="T">The element type of the vector.</typeparam>
+        /// <param name="value">The vector whose elements are added together.</param>
+        /// <returns>The sum of all elements of <paramref name="value"/>.</returns>
+        [Intrinsic]
+        public static T Sum<T>(Vector<T> value) where T : struct => Vector<T>.Sum(value);
     }
 }
index 3bd503a..054cf5e 100644 (file)
@@ -823,6 +823,19 @@ namespace System.Numerics
         }
 
         [Intrinsic]
+        internal static T Sum(Vector<T> value)
+        {
+            T sum = default;
+
+            for (nint index = 0; index < Count; index++)
+            {
+                sum = ScalarAdd(sum, value.GetElement(index));
+            }
+
+            return sum;
+        }
+
+        [Intrinsic]
         internal static unsafe Vector<T> SquareRoot(Vector<T> value)
         {
             Vector<T> result = default;