Vectorize Enumerable.Range initialization, take 2 (#87992)
authorneon-sunset <neon-sunset@outlook.com>
Fri, 21 Jul 2023 15:10:47 +0000 (18:10 +0300)
committerGitHub <noreply@github.com>
Fri, 21 Jul 2023 15:10:47 +0000 (11:10 -0400)
* Vectorize Enumerable.Range initialization

* Address PR feedback

---------

Co-authored-by: Stephen Toub <stoub@microsoft.com>
src/libraries/System.Linq/src/System/Linq/Range.SpeedOpt.cs
src/libraries/System.Linq/tests/RangeTests.cs

index e15a7e2..c058142 100644 (file)
@@ -2,6 +2,9 @@
 // The .NET Foundation licenses this file to you under the MIT license.
 
 using System.Collections.Generic;
+using System.Numerics;
+using System.Runtime.CompilerServices;
+using System.Runtime.InteropServices;
 
 namespace System.Linq
 {
@@ -16,15 +19,17 @@ namespace System.Linq
 
             public int[] ToArray()
             {
-                int[] array = new int[_end - _start];
-                Fill(array, _start);
+                int start = _start;
+                int[] array = new int[_end - start];
+                Fill(array, start);
                 return array;
             }
 
             public List<int> ToList()
             {
-                List<int> list = new List<int>(_end - _start);
-                Fill(SetCountAndGetSpan(list, _end - _start), _start);
+                (int start, int end) = (_start, _end);
+                List<int> list = new List<int>(end - start);
+                Fill(SetCountAndGetSpan(list, end - start), start);
                 return list;
             }
 
@@ -33,9 +38,33 @@ namespace System.Linq
 
             private static void Fill(Span<int> destination, int value)
             {
-                for (int i = 0; i < destination.Length; i++, value++)
+                ref int pos = ref MemoryMarshal.GetReference(destination);
+                ref int end = ref Unsafe.Add(ref pos, destination.Length);
+
+                if (Vector.IsHardwareAccelerated &&
+                    Vector<int>.Count <= 8 &&
+                    destination.Length >= Vector<int>.Count)
+                {
+                    Vector<int> init = new Vector<int>((ReadOnlySpan<int>)new int[] { 0, 1, 2, 3, 4, 5, 6, 7 });
+                    Vector<int> current = new Vector<int>(value) + init;
+                    Vector<int> increment = new Vector<int>(Vector<int>.Count);
+
+                    ref int oneVectorFromEnd = ref Unsafe.Subtract(ref end, Vector<int>.Count);
+                    do
+                    {
+                        current.StoreUnsafe(ref pos);
+                        current += increment;
+                        pos = ref Unsafe.Add(ref pos, Vector<int>.Count);
+                    }
+                    while (!Unsafe.IsAddressGreaterThan(ref pos, ref oneVectorFromEnd));
+
+                    value = current[0];
+                }
+
+                while (Unsafe.IsAddressLessThan(ref pos, ref end))
                 {
-                    destination[i] = value;
+                    pos = value++;
+                    pos = ref Unsafe.Add(ref pos, 1);
                 }
             }
 
index ab6e1b9..8421a66 100644 (file)
@@ -1,11 +1,7 @@
 // Licensed to the .NET Foundation under one or more agreements.
 // The .NET Foundation licenses this file to you under the MIT license.
 
-using System;
 using System.Collections.Generic;
-using System.Linq;
-using System.Text;
-using System.Threading.Tasks;
 using Xunit;
 
 namespace System.Linq.Tests
@@ -26,11 +22,20 @@ namespace System.Linq.Tests
             Assert.Equal(100, expected);
         }
 
-        [Fact]
-        public void Range_ToArray_ProduceCorrectResult()
+        public static IEnumerable<object[]> Range_ToArray_ProduceCorrectResult_MemberData()
+        {
+            for (int i = 0; i < 64; i++)
+            {
+                yield return new object[] { i };
+            }
+        }
+
+        [Theory]
+        [MemberData(nameof(Range_ToArray_ProduceCorrectResult_MemberData))]
+        public void Range_ToArray_ProduceCorrectResult(int length)
         {
-            var array = Enumerable.Range(1, 100).ToArray();
-            Assert.Equal(100, array.Length);
+            var array = Enumerable.Range(1, length).ToArray();
+            Assert.Equal(length, array.Length);
             for (var i = 0; i < array.Length; i++)
                 Assert.Equal(i + 1, array[i]);
         }