Improve Dictionary TryGetValue size/performance (dotnet/coreclr#27195)
author Ben Adams <thundercat@illyriad.co.uk>
Wed, 16 Oct 2019 15:14:17 +0000 (16:14 +0100)
committer Jan Kotas <jkotas@microsoft.com>
Wed, 16 Oct 2019 15:14:17 +0000 (08:14 -0700)
* Dictionary avoid second bounds check in Get methods

* Add NullRef methods to Unsafe

Commit migrated from https://github.com/dotnet/coreclr/commit/921b39ccb5d075cb61ee9b167d0d17fc4acfda59

src/coreclr/src/tools/crossgen2/Common/TypeSystem/IL/Stubs/UnsafeIntrinsics.cs
src/coreclr/src/vm/jitinterface.cpp
src/coreclr/src/vm/mscorlib.h
src/libraries/System.Private.CoreLib/src/Internal/Runtime/CompilerServices/Unsafe.cs
src/libraries/System.Private.CoreLib/src/System/Collections/Generic/Dictionary.cs
src/libraries/System.Private.CoreLib/src/System/Collections/HashHelpers.cs
src/libraries/System.Private.CoreLib/src/System/Memory.cs
src/libraries/System.Private.CoreLib/src/System/ReadOnlyMemory.cs
src/libraries/System.Private.CoreLib/src/System/ReadOnlySpan.Fast.cs
src/libraries/System.Private.CoreLib/src/System/Span.Fast.cs

index 7bc575a..d7c3ab5 100644 (file)
@@ -70,6 +70,18 @@ namespace Internal.IL.Stubs
                         (byte)ILOpcode.ldarg_1, (byte)ILOpcode.ldarg_0,
                         (byte)ILOpcode.sub,
                         (byte)ILOpcode.ret }, Array.Empty<LocalVariableDefinition>(), null);
+                case "NullRef":
+                    return new ILStubMethodIL(method, new byte[]
+                    {
+                        (byte)ILOpcode.ldc_i4_0, (byte)ILOpcode.conv_u,
+                        (byte)ILOpcode.ret }, Array.Empty<LocalVariableDefinition>(), null);
+                case "IsNullRef":
+                    return new ILStubMethodIL(method, new byte[]
+                    {
+                        (byte)ILOpcode.ldarg_0, 
+                        (byte)ILOpcode.ldc_i4_0, (byte)ILOpcode.conv_u,
+                        (byte)ILOpcode.prefix1, unchecked((byte)ILOpcode.ceq),
+                        (byte)ILOpcode.ret }, Array.Empty<LocalVariableDefinition>(), null);
             }
 
             return null;
index 12f621f..39959a8 100644 (file)
@@ -7254,6 +7254,29 @@ bool getILIntrinsicImplementationForUnsafe(MethodDesc * ftn,
         methInfo->options = (CorInfoOptions)0;
         return true;
     }
+    else if (tk == MscorlibBinder::GetMethod(METHOD__UNSAFE__BYREF_NULLREF)->GetMemberDef())
+    {
+        static const BYTE ilcode[] = { CEE_LDC_I4_0, CEE_CONV_U, CEE_RET };
+        methInfo->ILCode = const_cast<BYTE*>(ilcode);
+        methInfo->ILCodeSize = sizeof(ilcode);
+        methInfo->maxStack = 1;
+        methInfo->EHcount = 0;
+        methInfo->options = (CorInfoOptions)0;
+        return true;
+    }
+    else if (tk == MscorlibBinder::GetMethod(METHOD__UNSAFE__BYREF_IS_NULL)->GetMemberDef())
+    {
+        // 'ldnull' opcode would produce type o, and we can't compare & against o (ECMA-335, Table III.4).
+        // However, we can compare & against native int, so we'll use that instead.
+
+        static const BYTE ilcode[] = { CEE_LDARG_0, CEE_LDC_I4_0, CEE_CONV_U, CEE_PREFIX1, (CEE_CEQ & 0xFF), CEE_RET };
+        methInfo->ILCode = const_cast<BYTE*>(ilcode);
+        methInfo->ILCodeSize = sizeof(ilcode);
+        methInfo->maxStack = 2;
+        methInfo->EHcount = 0;
+        methInfo->options = (CorInfoOptions)0;
+        return true;
+    }
     else if (tk == MscorlibBinder::GetMethod(METHOD__UNSAFE__BYREF_INIT_BLOCK_UNALIGNED)->GetMemberDef())
     {
         static const BYTE ilcode[] = { CEE_LDARG_0, CEE_LDARG_1, CEE_LDARG_2, CEE_PREFIX1, (CEE_UNALIGNED & 0xFF), 0x01, CEE_PREFIX1, (CEE_INITBLK & 0xFF), CEE_RET };
index cfb62ea..63ccef8 100644 (file)
@@ -698,6 +698,8 @@ DEFINE_METHOD(JIT_HELPERS,          ENUM_COMPARE_TO,        EnumCompareTo, NoSig
 
 DEFINE_CLASS(UNSAFE,                InternalCompilerServices,       Unsafe)
 DEFINE_METHOD(UNSAFE,               AS_POINTER,             AsPointer, NoSig)
+DEFINE_METHOD(UNSAFE,               BYREF_IS_NULL,          IsNullRef, NoSig)
+DEFINE_METHOD(UNSAFE,               BYREF_NULLREF,          NullRef, NoSig)
 DEFINE_METHOD(UNSAFE,               AS_REF_IN,              AsRef, GM_RefT_RetRefT)
 DEFINE_METHOD(UNSAFE,               AS_REF_POINTER,         AsRef, GM_VoidPtr_RetRefT)
 DEFINE_METHOD(UNSAFE,               SIZEOF,                 SizeOf, NoSig)
index 4a93d98..485c157 100644 (file)
@@ -387,5 +387,40 @@ namespace Internal.Runtime.CompilerServices
         {
             throw new PlatformNotSupportedException();
         }
+
+        /// <summary>
+        /// Returns a by-ref to type <typeparamref name="T"/> that is a null reference.
+        /// </summary>
+        [Intrinsic]
+        [NonVersionable]
+        [MethodImpl(MethodImplOptions.AggressiveInlining)]
+        public static ref T NullRef<T>()
+        {
+            return ref Unsafe.AsRef<T>(null);
+
+            // ldc.i4.0
+            // conv.u
+            // ret
+        }
+
+        /// <summary>
+        /// Determines whether a given by-ref to type <typeparamref name="T"/> is a null reference.
+        /// </summary>
+        /// <remarks>
+        /// This check is conceptually similar to "(void*)(&amp;source) == nullptr".
+        /// </remarks>
+        [Intrinsic]
+        [NonVersionable]
+        [MethodImpl(MethodImplOptions.AggressiveInlining)]
+        public static bool IsNullRef<T>(ref T source)
+        {
+            return Unsafe.AsPointer(ref source) == null;
+
+            // ldarg.0
+            // ldc.i4.0
+            // conv.u
+            // ceq
+            // ret
+        }
     }
 }
index 9303f5b..5eca70b 100644 (file)
@@ -7,6 +7,8 @@ using System.Diagnostics.CodeAnalysis;
 using System.Runtime.CompilerServices;
 using System.Runtime.Serialization;
 
+using Internal.Runtime.CompilerServices;
+
 namespace System.Collections.Generic
 {
     /// <summary>
@@ -38,11 +40,11 @@ namespace System.Collections.Generic
     {
         private struct Entry
         {
+            public uint hashCode;
             // 0-based index of next entry in chain: -1 means end of chain
             // also encodes whether this entry _itself_ is part of the free list by changing sign and subtracting 3,
             // so -2 means end of free list, -3 means index 0 but on free list, -4 means index 1 but on free list, etc.
             public int next;
-            public uint hashCode;
             public TKey key;           // Key of entry
             public TValue value;         // Value of entry
         }
@@ -168,8 +170,11 @@ namespace System.Collections.Generic
         {
             get
             {
-                int i = FindEntry(key);
-                if (i >= 0) return _entries![i].value;
+                ref TValue value = ref FindValue(key);
+                if (!Unsafe.IsNullRef(ref value))
+                {
+                    return value;
+                }
                 ThrowHelper.ThrowKeyNotFoundException(key);
                 return default;
             }
@@ -191,8 +196,8 @@ namespace System.Collections.Generic
 
         bool ICollection<KeyValuePair<TKey, TValue>>.Contains(KeyValuePair<TKey, TValue> keyValuePair)
         {
-            int i = FindEntry(keyValuePair.Key);
-            if (i >= 0 && EqualityComparer<TValue>.Default.Equals(_entries![i].value, keyValuePair.Value))
+            ref TValue value = ref FindValue(keyValuePair.Key);
+            if (!Unsafe.IsNullRef(ref value) && EqualityComparer<TValue>.Default.Equals(value, keyValuePair.Value))
             {
                 return true;
             }
@@ -201,8 +206,8 @@ namespace System.Collections.Generic
 
         bool ICollection<KeyValuePair<TKey, TValue>>.Remove(KeyValuePair<TKey, TValue> keyValuePair)
         {
-            int i = FindEntry(keyValuePair.Key);
-            if (i >= 0 && EqualityComparer<TValue>.Default.Equals(_entries![i].value, keyValuePair.Value))
+            ref TValue value = ref FindValue(keyValuePair.Key);
+            if (!Unsafe.IsNullRef(ref value) && EqualityComparer<TValue>.Default.Equals(value, keyValuePair.Value))
             {
                 Remove(keyValuePair.Key);
                 return true;
@@ -228,7 +233,7 @@ namespace System.Collections.Generic
         }
 
         public bool ContainsKey(TKey key)
-            => FindEntry(key) >= 0;
+            => !Unsafe.IsNullRef(ref FindValue(key));
 
         public bool ContainsValue(TValue value)
         {
@@ -318,47 +323,53 @@ namespace System.Collections.Generic
             }
         }
 
-        private int FindEntry(TKey key)
+        private ref TValue FindValue(TKey key)
         {
             if (key == null)
             {
                 ThrowHelper.ThrowArgumentNullException(ExceptionArgument.key);
             }
 
-            int i = -1;
             int[]? buckets = _buckets;
-            Entry[]? entries = _entries;
-            int collisionCount = 0;
+            ref Entry entry = ref Unsafe.NullRef<Entry>();
             if (buckets != null)
             {
-                Debug.Assert(entries != null, "expected entries to be != null");
+                Debug.Assert(_entries != null, "expected entries to be != null");
                 IEqualityComparer<TKey>? comparer = _comparer;
                 if (comparer == null)
                 {
                     uint hashCode = (uint)key.GetHashCode();
-                    // Value in _buckets is 1-based
-                    i = buckets[hashCode % (uint)buckets.Length] - 1;
+                    int i = buckets[hashCode % (uint)buckets.Length];
+                    Entry[]? entries = _entries;
+                    uint collisionCount = 0;
                     if (default(TKey)! != null) // TODO-NULLABLE: default(T) == null warning (https://github.com/dotnet/roslyn/issues/34757)
                     {
                         // ValueType: Devirtualize with EqualityComparer<TKey>.Default intrinsic
-                        while (true)
+
+                        // Value in _buckets is 1-based; subtract 1 from i. We do it here so it fuses with the following conditional.
+                        i--;
+                        do
                         {
                             // Should be a while loop https://github.com/dotnet/coreclr/issues/15476
                             // Test in if to drop range check for following array access
-                            if ((uint)i >= (uint)entries.Length || (entries[i].hashCode == hashCode && EqualityComparer<TKey>.Default.Equals(entries[i].key, key)))
+                            if ((uint)i >= (uint)entries.Length)
                             {
-                                break;
+                                goto ReturnNotFound;
                             }
 
-                            i = entries[i].next;
-                            if (collisionCount >= entries.Length)
+                            entry = ref entries[i];
+                            if (entry.hashCode == hashCode && EqualityComparer<TKey>.Default.Equals(entry.key, key))
                             {
-                                // The chain of entries forms a loop; which means a concurrent update has happened.
-                                // Break out of the loop and throw, rather than looping forever.
-                                ThrowHelper.ThrowInvalidOperationException_ConcurrentOperationsNotSupported();
+                                goto ReturnFound;
                             }
+
+                            i = entry.next;
+
                             collisionCount++;
-                        }
+                        } while (collisionCount <= (uint)entries.Length);
+                        // The chain of entries forms a loop; which means a concurrent update has happened.
+                        // Break out of the loop and throw, rather than looping forever.
+                        goto ConcurrentOperation;
                     }
                     else
                     {
@@ -366,54 +377,77 @@ namespace System.Collections.Generic
                         // https://github.com/dotnet/coreclr/issues/17273
                         // So cache in a local rather than get EqualityComparer per loop iteration
                         EqualityComparer<TKey> defaultComparer = EqualityComparer<TKey>.Default;
-                        while (true)
+
+                        // Value in _buckets is 1-based; subtract 1 from i. We do it here so it fuses with the following conditional.
+                        i--;
+                        do
                         {
                             // Should be a while loop https://github.com/dotnet/coreclr/issues/15476
                             // Test in if to drop range check for following array access
-                            if ((uint)i >= (uint)entries.Length || (entries[i].hashCode == hashCode && defaultComparer.Equals(entries[i].key, key)))
+                            if ((uint)i >= (uint)entries.Length)
                             {
-                                break;
+                                goto ReturnNotFound;
                             }
 
-                            i = entries[i].next;
-                            if (collisionCount >= entries.Length)
+                            entry = ref entries[i];
+                            if (entry.hashCode == hashCode && defaultComparer.Equals(entry.key, key))
                             {
-                                // The chain of entries forms a loop; which means a concurrent update has happened.
-                                // Break out of the loop and throw, rather than looping forever.
-                                ThrowHelper.ThrowInvalidOperationException_ConcurrentOperationsNotSupported();
+                                goto ReturnFound;
                             }
+
+                            i = entry.next;
+
                             collisionCount++;
-                        }
+                        } while (collisionCount <= (uint)entries.Length);
+                        // The chain of entries forms a loop; which means a concurrent update has happened.
+                        // Break out of the loop and throw, rather than looping forever.
+                        goto ConcurrentOperation;
                     }
                 }
                 else
                 {
                     uint hashCode = (uint)comparer.GetHashCode(key);
-                    // Value in _buckets is 1-based
-                    i = buckets[hashCode % (uint)buckets.Length] - 1;
-                    while (true)
+                    int i = buckets[hashCode % (uint)buckets.Length];
+                    Entry[]? entries = _entries;
+                    uint collisionCount = 0;
+                    // Value in _buckets is 1-based; subtract 1 from i. We do it here so it fuses with the following conditional.
+                    i--;
+                    do
                     {
                         // Should be a while loop https://github.com/dotnet/coreclr/issues/15476
                         // Test in if to drop range check for following array access
-                        if ((uint)i >= (uint)entries.Length ||
-                            (entries[i].hashCode == hashCode && comparer.Equals(entries[i].key, key)))
+                        if ((uint)i >= (uint)entries.Length)
                         {
-                            break;
+                            goto ReturnNotFound;
                         }
 
-                        i = entries[i].next;
-                        if (collisionCount >= entries.Length)
+                        entry = ref entries[i];
+                        if (entry.hashCode == hashCode && comparer.Equals(entry.key, key))
                         {
-                            // The chain of entries forms a loop; which means a concurrent update has happened.
-                            // Break out of the loop and throw, rather than looping forever.
-                            ThrowHelper.ThrowInvalidOperationException_ConcurrentOperationsNotSupported();
+                            goto ReturnFound;
                         }
+
+                        i = entry.next;
+
                         collisionCount++;
-                    }
+                    } while (collisionCount <= (uint)entries.Length);
+                    // The chain of entries forms a loop; which means a concurrent update has happened.
+                    // Break out of the loop and throw, rather than looping forever.
+                    goto ConcurrentOperation;
                 }
             }
 
-            return i;
+            goto ReturnNotFound;
+
+        ConcurrentOperation:
+            ThrowHelper.ThrowInvalidOperationException_ConcurrentOperationsNotSupported();
+        ReturnFound:
+            ref TValue value = ref entry.value;
+        Return:
+            return ref value;
+        ReturnNotFound:
+            value = ref Unsafe.NullRef<TValue>();
+            goto Return;
         }
 
         private int Initialize(int capacity)
@@ -446,7 +480,7 @@ namespace System.Collections.Generic
             IEqualityComparer<TKey>? comparer = _comparer;
             uint hashCode = (uint)((comparer == null) ? key.GetHashCode() : comparer.GetHashCode(key));
 
-            int collisionCount = 0;
+            uint collisionCount = 0;
             ref int bucket = ref _buckets[hashCode % (uint)_buckets.Length];
             // Value in _buckets is 1-based
             int i = bucket - 1;
@@ -483,13 +517,14 @@ namespace System.Collections.Generic
                         }
 
                         i = entries[i].next;
-                        if (collisionCount >= entries.Length)
+
+                        collisionCount++;
+                        if (collisionCount > (uint)entries.Length)
                         {
                             // The chain of entries forms a loop; which means a concurrent update has happened.
                             // Break out of the loop and throw, rather than looping forever.
                             ThrowHelper.ThrowInvalidOperationException_ConcurrentOperationsNotSupported();
                         }
-                        collisionCount++;
                     }
                 }
                 else
@@ -525,13 +560,14 @@ namespace System.Collections.Generic
                         }
 
                         i = entries[i].next;
-                        if (collisionCount >= entries.Length)
+
+                        collisionCount++;
+                        if (collisionCount > (uint)entries.Length)
                         {
                             // The chain of entries forms a loop; which means a concurrent update has happened.
                             // Break out of the loop and throw, rather than looping forever.
                             ThrowHelper.ThrowInvalidOperationException_ConcurrentOperationsNotSupported();
                         }
-                        collisionCount++;
                     }
                 }
             }
@@ -564,13 +600,14 @@ namespace System.Collections.Generic
                     }
 
                     i = entries[i].next;
-                    if (collisionCount >= entries.Length)
+
+                    collisionCount++;
+                    if (collisionCount > (uint)entries.Length)
                     {
                         // The chain of entries forms a loop; which means a concurrent update has happened.
                         // Break out of the loop and throw, rather than looping forever.
                         ThrowHelper.ThrowInvalidOperationException_ConcurrentOperationsNotSupported();
                     }
-                    collisionCount++;
                 }
             }
 
@@ -725,10 +762,10 @@ namespace System.Collections.Generic
 
             int[]? buckets = _buckets;
             Entry[]? entries = _entries;
-            int collisionCount = 0;
             if (buckets != null)
             {
                 Debug.Assert(entries != null, "entries should be non-null");
+                uint collisionCount = 0;
                 uint hashCode = (uint)(_comparer?.GetHashCode(key) ?? key.GetHashCode());
                 uint bucket = hashCode % (uint)buckets.Length;
                 int last = -1;
@@ -769,13 +806,14 @@ namespace System.Collections.Generic
 
                     last = i;
                     i = entry.next;
-                    if (collisionCount >= entries.Length)
+
+                    collisionCount++;
+                    if (collisionCount > (uint)entries.Length)
                     {
                         // The chain of entries forms a loop; which means a concurrent update has happened.
                         // Break out of the loop and throw, rather than looping forever.
                         ThrowHelper.ThrowInvalidOperationException_ConcurrentOperationsNotSupported();
                     }
-                    collisionCount++;
                 }
             }
             return false;
@@ -793,10 +831,10 @@ namespace System.Collections.Generic
 
             int[]? buckets = _buckets;
             Entry[]? entries = _entries;
-            int collisionCount = 0;
             if (buckets != null)
             {
                 Debug.Assert(entries != null, "entries should be non-null");
+                uint collisionCount = 0;
                 uint hashCode = (uint)(_comparer?.GetHashCode(key) ?? key.GetHashCode());
                 uint bucket = hashCode % (uint)buckets.Length;
                 int last = -1;
@@ -839,13 +877,14 @@ namespace System.Collections.Generic
 
                     last = i;
                     i = entry.next;
-                    if (collisionCount >= entries.Length)
+
+                    collisionCount++;
+                    if (collisionCount > (uint)entries.Length)
                     {
                         // The chain of entries forms a loop; which means a concurrent update has happened.
                         // Break out of the loop and throw, rather than looping forever.
                         ThrowHelper.ThrowInvalidOperationException_ConcurrentOperationsNotSupported();
                     }
-                    collisionCount++;
                 }
             }
             value = default!;
@@ -854,10 +893,10 @@ namespace System.Collections.Generic
 
         public bool TryGetValue(TKey key, [MaybeNullWhen(false)] out TValue value)
         {
-            int i = FindEntry(key);
-            if (i >= 0)
+            ref TValue valRef = ref FindValue(key);
+            if (!Unsafe.IsNullRef(ref valRef))
             {
-                value = _entries![i].value;
+                value = valRef;
                 return true;
             }
             value = default!;
@@ -1022,10 +1061,10 @@ namespace System.Collections.Generic
             {
                 if (IsCompatibleKey(key))
                 {
-                    int i = FindEntry((TKey)key);
-                    if (i >= 0)
+                    ref TValue value = ref FindValue((TKey)key);
+                    if (!Unsafe.IsNullRef(ref value))
                     {
-                        return _entries![i].value;
+                        return value;
                     }
                 }
                 return null;
index 92714a2..3b53609 100644 (file)
@@ -8,7 +8,7 @@ namespace System.Collections
 {
     internal static partial class HashHelpers
     {
-        public const int HashCollisionThreshold = 100;
+        public const uint HashCollisionThreshold = 100;
 
         // This is the maximum prime smaller than Array.MaxArrayLength
         public const int MaxPrimeArrayLength = 0x7FEFFFFD;
index 25f5d8f..bc43764 100644 (file)
@@ -295,7 +295,7 @@ namespace System
                 // in which case that's the dangerous operation performed by the dev, and we're just following
                 // suit here to make it work as best as possible.
 
-                ref T refToReturn = ref Unsafe.AsRef<T>(null);
+                ref T refToReturn = ref Unsafe.NullRef<T>();
                 int lengthOfUnderlyingSpan = 0;
 
                 // Copy this field into a local so that it can't change out from under us mid-operation.
index 18843ad..2510e0f 100644 (file)
@@ -217,7 +217,7 @@ namespace System
             [MethodImpl(MethodImplOptions.AggressiveInlining)]
             get
             {
-                ref T refToReturn = ref Unsafe.AsRef<T>(null);
+                ref T refToReturn = ref Unsafe.NullRef<T>();
                 int lengthOfUnderlyingSpan = 0;
 
                 // Copy this field into a local so that it can't change out from under us mid-operation.
index e4514ef..238b90f 100644 (file)
@@ -148,10 +148,10 @@ namespace System
         /// It can be used for pinning and is required to support the use of span within a fixed statement.
         /// </summary>
         [EditorBrowsable(EditorBrowsableState.Never)]
-        public unsafe ref readonly T GetPinnableReference()
+        public ref readonly T GetPinnableReference()
         {
             // Ensure that the native code has just one forward branch that is predicted-not-taken.
-            ref T ret = ref Unsafe.AsRef<T>(null);
+            ref T ret = ref Unsafe.NullRef<T>();
             if (_length != 0) ret = ref _pointer.Value;
             return ref ret;
         }
index 2ecc2a0..2ecd1cc 100644 (file)
@@ -154,10 +154,10 @@ namespace System
         /// It can be used for pinning and is required to support the use of span within a fixed statement.
         /// </summary>
         [EditorBrowsable(EditorBrowsableState.Never)]
-        public unsafe ref T GetPinnableReference()
+        public ref T GetPinnableReference()
         {
             // Ensure that the native code has just one forward branch that is predicted-not-taken.
-            ref T ret = ref Unsafe.AsRef<T>(null);
+            ref T ret = ref Unsafe.NullRef<T>();
             if (_length != 0) ret = ref _pointer.Value;
             return ref ret;
         }