Avoid mod operator when fast alternative available (dotnet/coreclr#27299)
author Ben Adams <thundercat@illyriad.co.uk>
Sat, 26 Oct 2019 13:57:12 +0000 (14:57 +0100)
committer Jan Kotas <jkotas@microsoft.com>
Sat, 26 Oct 2019 13:57:12 +0000 (06:57 -0700)
Commit migrated from https://github.com/dotnet/coreclr/commit/e532bf642a3a381d53ff52c234f29deb7d11e7a0

src/libraries/System.Private.CoreLib/src/System/Collections/Generic/Dictionary.cs
src/libraries/System.Private.CoreLib/src/System/Collections/HashHelpers.cs

index 5eca70b..e7726f9 100644 (file)
@@ -51,6 +51,9 @@ namespace System.Collections.Generic
 
         private int[]? _buckets;
         private Entry[]? _entries;
+#if BIT64
+        private ulong _fastModMultiplier;
+#endif
         private int _count;
         private int _freeList;
         private int _freeCount;
@@ -330,16 +333,15 @@ namespace System.Collections.Generic
                 ThrowHelper.ThrowArgumentNullException(ExceptionArgument.key);
             }
 
-            int[]? buckets = _buckets;
             ref Entry entry = ref Unsafe.NullRef<Entry>();
-            if (buckets != null)
+            if (_buckets != null)
             {
                 Debug.Assert(_entries != null, "expected entries to be != null");
                 IEqualityComparer<TKey>? comparer = _comparer;
                 if (comparer == null)
                 {
                     uint hashCode = (uint)key.GetHashCode();
-                    int i = buckets[hashCode % (uint)buckets.Length];
+                    int i = GetBucket(hashCode);
                     Entry[]? entries = _entries;
                     uint collisionCount = 0;
                     if (default(TKey)! != null) // TODO-NULLABLE: default(T) == null warning (https://github.com/dotnet/roslyn/issues/34757)
@@ -407,7 +409,7 @@ namespace System.Collections.Generic
                 else
                 {
                     uint hashCode = (uint)comparer.GetHashCode(key);
-                    int i = buckets[hashCode % (uint)buckets.Length];
+                    int i = GetBucket(hashCode);
                     Entry[]? entries = _entries;
                     uint collisionCount = 0;
                     // Value in _buckets is 1-based; subtract 1 from i. We do it here so it fuses with the following conditional.
@@ -453,10 +455,16 @@ namespace System.Collections.Generic
         private int Initialize(int capacity)
         {
             int size = HashHelpers.GetPrime(capacity);
+            int[] buckets = new int[size];
+            Entry[] entries = new Entry[size];
 
+            // Assign member variables after both arrays allocated to guard against corruption from OOM if second fails
             _freeList = -1;
-            _buckets = new int[size];
-            _entries = new Entry[size];
+#if BIT64
+            _fastModMultiplier = HashHelpers.GetFastModMultiplier((uint)size);
+#endif
+            _buckets = buckets;
+            _entries = entries;
 
             return size;
         }
@@ -481,7 +489,7 @@ namespace System.Collections.Generic
             uint hashCode = (uint)((comparer == null) ? key.GetHashCode() : comparer.GetHashCode(key));
 
             uint collisionCount = 0;
-            ref int bucket = ref _buckets[hashCode % (uint)_buckets.Length];
+            ref int bucket = ref GetBucket(hashCode);
             // Value in _buckets is 1-based
             int i = bucket - 1;
 
@@ -625,7 +633,7 @@ namespace System.Collections.Generic
                 if (count == entries.Length)
                 {
                     Resize();
-                    bucket = ref _buckets[hashCode % (uint)_buckets.Length];
+                    bucket = ref GetBucket(hashCode);
                 }
                 index = count;
                 _count = count + 1;
@@ -716,7 +724,6 @@ namespace System.Collections.Generic
             Debug.Assert(_entries != null, "_entries should be non-null");
             Debug.Assert(newSize >= _entries.Length);
 
-            int[] buckets = new int[newSize];
             Entry[] entries = new Entry[newSize];
 
             int count = _count;
@@ -734,19 +741,23 @@ namespace System.Collections.Generic
                 }
             }
 
+            // Assign member variables after both arrays allocated to guard against corruption from OOM if second fails
+            _buckets = new int[newSize];
+#if BIT64
+            _fastModMultiplier = HashHelpers.GetFastModMultiplier((uint)newSize);
+#endif
             for (int i = 0; i < count; i++)
             {
                 if (entries[i].next >= -1)
                 {
-                    uint bucket = entries[i].hashCode % (uint)newSize;
+                    ref int bucket = ref GetBucket(entries[i].hashCode);
                     // Value in _buckets is 1-based
-                    entries[i].next = buckets[bucket] - 1;
+                    entries[i].next = bucket - 1;
                     // Value in _buckets is 1-based
-                    buckets[bucket] = i + 1;
+                    bucket = i + 1;
                 }
             }
 
-            _buckets = buckets;
             _entries = entries;
         }
 
@@ -760,17 +771,16 @@ namespace System.Collections.Generic
                 ThrowHelper.ThrowArgumentNullException(ExceptionArgument.key);
             }
 
-            int[]? buckets = _buckets;
-            Entry[]? entries = _entries;
-            if (buckets != null)
+            if (_buckets != null)
             {
-                Debug.Assert(entries != null, "entries should be non-null");
+                Debug.Assert(_entries != null, "entries should be non-null");
                 uint collisionCount = 0;
                 uint hashCode = (uint)(_comparer?.GetHashCode(key) ?? key.GetHashCode());
-                uint bucket = hashCode % (uint)buckets.Length;
+                ref int bucket = ref GetBucket(hashCode);
+                Entry[]? entries = _entries;
                 int last = -1;
                 // Value in buckets is 1-based
-                int i = buckets[bucket] - 1;
+                int i = bucket - 1;
                 while (i >= 0)
                 {
                     ref Entry entry = ref entries[i];
@@ -780,7 +790,7 @@ namespace System.Collections.Generic
                         if (last < 0)
                         {
                             // Value in buckets is 1-based
-                            buckets[bucket] = entry.next + 1;
+                            bucket = entry.next + 1;
                         }
                         else
                         {
@@ -829,17 +839,16 @@ namespace System.Collections.Generic
                 ThrowHelper.ThrowArgumentNullException(ExceptionArgument.key);
             }
 
-            int[]? buckets = _buckets;
-            Entry[]? entries = _entries;
-            if (buckets != null)
+            if (_buckets != null)
             {
-                Debug.Assert(entries != null, "entries should be non-null");
+                Debug.Assert(_entries != null, "entries should be non-null");
                 uint collisionCount = 0;
                 uint hashCode = (uint)(_comparer?.GetHashCode(key) ?? key.GetHashCode());
-                uint bucket = hashCode % (uint)buckets.Length;
+                ref int bucket = ref GetBucket(hashCode);
+                Entry[]? entries = _entries;
                 int last = -1;
                 // Value in buckets is 1-based
-                int i = buckets[bucket] - 1;
+                int i = bucket - 1;
                 while (i >= 0)
                 {
                     ref Entry entry = ref entries[i];
@@ -849,7 +858,7 @@ namespace System.Collections.Generic
                         if (last < 0)
                         {
                             // Value in buckets is 1-based
-                            buckets[bucket] = entry.next + 1;
+                            bucket = entry.next + 1;
                         }
                         else
                         {
@@ -982,6 +991,7 @@ namespace System.Collections.Generic
             _version++;
             if (_buckets == null)
                 return Initialize(capacity);
+
             int newSize = HashHelpers.GetPrime(capacity);
             Resize(newSize, forceNewHashCodes: false);
             return newSize;
@@ -1011,8 +1021,8 @@ namespace System.Collections.Generic
         {
             if (capacity < Count)
                 ThrowHelper.ThrowArgumentOutOfRangeException(ExceptionArgument.capacity);
-            int newSize = HashHelpers.GetPrime(capacity);
 
+            int newSize = HashHelpers.GetPrime(capacity);
             Entry[]? oldEntries = _entries;
             int currentCapacity = oldEntries == null ? 0 : oldEntries.Length;
             if (newSize >= currentCapacity)
@@ -1022,7 +1032,6 @@ namespace System.Collections.Generic
             _version++;
             Initialize(newSize);
             Entry[]? entries = _entries;
-            int[]? buckets = _buckets;
             int count = 0;
             for (int i = 0; i < oldCount; i++)
             {
@@ -1031,11 +1040,11 @@ namespace System.Collections.Generic
                 {
                     ref Entry entry = ref entries![count];
                     entry = oldEntries[i];
-                    uint bucket = hashCode % (uint)newSize;
+                    ref int bucket = ref GetBucket(hashCode);
                     // Value in _buckets is 1-based
-                    entry.next = buckets![bucket] - 1; // If we get here, we have entries, therefore buckets is not null.
+                    entry.next = bucket - 1;
                     // Value in _buckets is 1-based
-                    buckets[bucket] = count + 1;
+                    bucket = count + 1;
                     count++;
                 }
             }
@@ -1153,6 +1162,17 @@ namespace System.Collections.Generic
             }
         }
 
+        [MethodImpl(MethodImplOptions.AggressiveInlining)]
+        private ref int GetBucket(uint hashCode)
+        {
+            int[] buckets = _buckets!;
+#if BIT64
+            return ref buckets[HashHelpers.FastMod(hashCode, (uint)buckets.Length, _fastModMultiplier)];
+#else
+            return ref buckets[hashCode % (uint)buckets.Length];
+#endif
+        }
+
         public struct Enumerator : IEnumerator<KeyValuePair<TKey, TValue>>,
             IDictionaryEnumerator
         {
index 3b53609..b01b01d 100644 (file)
@@ -3,6 +3,7 @@
 // See the LICENSE file in the project root for more information.
 
 using System.Diagnostics;
+using System.Runtime.CompilerServices;
 
 namespace System.Collections
 {
@@ -28,12 +29,14 @@ namespace System.Collections
         // h1(key) + i*h2(key), 0 <= i < size.  h2 and the size must be relatively prime.
         // We prefer the low computation costs of higher prime numbers over the increased
         // memory allocation of a fixed prime number i.e. when right sizing a HashSet.
-        public static readonly int[] primes = {
+        private static readonly int[] s_primes =
+        {
             3, 7, 11, 17, 23, 29, 37, 47, 59, 71, 89, 107, 131, 163, 197, 239, 293, 353, 431, 521, 631, 761, 919,
             1103, 1327, 1597, 1931, 2333, 2801, 3371, 4049, 4861, 5839, 7013, 8419, 10103, 12143, 14591,
             17519, 21023, 25229, 30293, 36353, 43627, 52361, 62851, 75431, 90523, 108631, 130363, 156437,
             187751, 225307, 270371, 324449, 389357, 467237, 560689, 672827, 807403, 968897, 1162687, 1395263,
-            1674319, 2009191, 2411033, 2893249, 3471899, 4166287, 4999559, 5999471, 7199369 };
+            1674319, 2009191, 2411033, 2893249, 3471899, 4166287, 4999559, 5999471, 7199369
+        };
 
         public static bool IsPrime(int candidate)
         {
@@ -55,9 +58,8 @@ namespace System.Collections
             if (min < 0)
                 throw new ArgumentException(SR.Arg_HTCapacityOverflow);
 
-            for (int i = 0; i < primes.Length; i++)
+            foreach (int prime in s_primes)
             {
-                int prime = primes[i];
                 if (prime >= min)
                     return prime;
             }
@@ -86,5 +88,24 @@ namespace System.Collections
 
             return GetPrime(newSize);
         }
+
+#if BIT64
+        public static ulong GetFastModMultiplier(uint divisor)
+            => ulong.MaxValue / divisor + 1;
+
+        [MethodImpl(MethodImplOptions.AggressiveInlining)]
+        public static uint FastMod(uint value, uint divisor, ulong multiplier)
+        {
+            // Using fastmod from Daniel Lemire https://lemire.me/blog/2019/02/08/faster-remainders-when-the-divisor-is-a-constant-beating-compilers-and-libdivide/
+
+            ulong lowbits = multiplier * value;
+            // 64bit * 64bit => 128bit isn't currently supported by Math https://github.com/dotnet/corefx/issues/41822
+            // otherwise we'd want this to be (uint)Math.MultiplyHigh(lowbits, divisor)
+            uint high = (uint)((((ulong)(uint)lowbits * divisor >> 32) + (lowbits >> 32) * divisor) >> 32);
+
+            Debug.Assert(high == value % divisor);
+            return high;
+        }
+#endif
     }
 }