Improve dictionary & hashset lookup perf for OrdinalIgnoreCase (#36252)
authorLevi Broderick <GrabYourPitchforks@users.noreply.github.com>
Thu, 6 Aug 2020 04:16:21 +0000 (21:16 -0700)
committerGitHub <noreply@github.com>
Thu, 6 Aug 2020 04:16:21 +0000 (21:16 -0700)
src/libraries/System.Collections/tests/Generic/Dictionary/Dictionary.Generic.Tests.cs
src/libraries/System.Collections/tests/Generic/Dictionary/HashCollisionScenarios/OutOfBoundsRegression.cs
src/libraries/System.Private.CoreLib/src/System.Private.CoreLib.Shared.projitems
src/libraries/System.Private.CoreLib/src/System/Collections/Generic/Dictionary.cs
src/libraries/System.Private.CoreLib/src/System/Collections/Generic/HashSet.cs
src/libraries/System.Private.CoreLib/src/System/Collections/Generic/IInternalStringEqualityComparer.cs [new file with mode: 0644]
src/libraries/System.Private.CoreLib/src/System/Collections/Generic/NonRandomizedStringEqualityComparer.cs
src/libraries/System.Private.CoreLib/src/System/Collections/Generic/RandomizedStringEqualityComparer.cs [new file with mode: 0644]
src/libraries/System.Private.CoreLib/src/System/String.Comparison.cs
src/libraries/System.Private.CoreLib/src/System/StringComparer.cs
src/libraries/System.Resources.Extensions/tests/TestData.resources

index 3de7252..32f277e 100644 (file)
@@ -3,7 +3,9 @@
 
 using Common.System;
 using System.Collections.Generic;
+using System.Globalization;
 using System.Linq;
+using System.Runtime.Serialization;
 using Xunit;
 
 namespace System.Collections.Tests
@@ -610,5 +612,37 @@ namespace System.Collections.Tests
         }
 
         #endregion
+
+        #region Non-randomized comparers
+        [Fact]
+        public void Dictionary_Comparer_NonRandomizedStringComparers()
+        {
+            RunTest(null);
+            RunTest(EqualityComparer<string>.Default);
+            RunTest(StringComparer.Ordinal);
+            RunTest(StringComparer.OrdinalIgnoreCase);
+            RunTest(StringComparer.InvariantCulture);
+            RunTest(StringComparer.InvariantCultureIgnoreCase);
+            RunTest(StringComparer.Create(CultureInfo.InvariantCulture, ignoreCase: false));
+            RunTest(StringComparer.Create(CultureInfo.InvariantCulture, ignoreCase: true));
+
+            void RunTest(IEqualityComparer<string> comparer)
+            {
+                // First, instantiate the dictionary and check its Comparer property
+
+                Dictionary<string, object> dict = new Dictionary<string, object>(comparer);
+                object expected = comparer ?? EqualityComparer<string>.Default;
+
+                Assert.Same(expected, dict.Comparer);
+
+                // Then pretend to serialize the dictionary and check the stored Comparer instance
+
+                SerializationInfo si = new SerializationInfo(typeof(Dictionary<string, object>), new FormatterConverter());
+                dict.GetObjectData(si, new StreamingContext(StreamingContextStates.All));
+
+                Assert.Same(expected, si.GetValue("Comparer", typeof(IEqualityComparer<string>)));
+            }
+        }
+        #endregion
     }
 }
index bc83088..934e57a 100644 (file)
@@ -2,6 +2,10 @@
 // The .NET Foundation licenses this file to you under the MIT license.
 
 using System.Collections.Generic;
+using System.Numerics;
+using System.Reflection;
+using System.Runtime.InteropServices;
+using System.Runtime.Serialization;
 using Xunit;
 
 namespace System.Collections.Tests
@@ -37,5 +41,238 @@ namespace System.Collections.Tests
                     dictionary.Remove(key);
             }
         }
+
+        [Fact]
+        public static void ComparerImplementations_Dictionary_WithWellKnownStringComparers()
+        {
+            Type nonRandomizedOrdinalComparerType = typeof(object).Assembly.GetType("System.Collections.Generic.NonRandomizedStringEqualityComparer+OrdinalComparer", throwOnError: true);
+            Type nonRandomizedOrdinalIgnoreCaseComparerType = typeof(object).Assembly.GetType("System.Collections.Generic.NonRandomizedStringEqualityComparer+OrdinalIgnoreCaseComparer", throwOnError: true);
+            Type randomizedOrdinalComparerType = typeof(object).Assembly.GetType("System.Collections.Generic.RandomizedStringEqualityComparer+OrdinalComparer", throwOnError: true);
+            Type randomizedOrdinalIgnoreCaseComparerType = typeof(object).Assembly.GetType("System.Collections.Generic.RandomizedStringEqualityComparer+OrdinalIgnoreCaseComparer", throwOnError: true);
+
+            // null comparer
+
+            RunDictionaryTest(
+                equalityComparer: null,
+                expectedInternalComparerBeforeCollisionThreshold: nonRandomizedOrdinalComparerType,
+                expectedPublicComparerBeforeCollisionThreshold: EqualityComparer<string>.Default.GetType(),
+                expectedComparerAfterCollisionThreshold: randomizedOrdinalComparerType);
+
+            // EqualityComparer<string>.Default comparer
+
+            RunDictionaryTest(
+                equalityComparer: EqualityComparer<string>.Default,
+                expectedInternalComparerBeforeCollisionThreshold: nonRandomizedOrdinalComparerType,
+                expectedPublicComparerBeforeCollisionThreshold: EqualityComparer<string>.Default.GetType(),
+                expectedComparerAfterCollisionThreshold: randomizedOrdinalComparerType);
+
+            // Ordinal comparer
+
+            RunDictionaryTest(
+                equalityComparer: StringComparer.Ordinal,
+                expectedInternalComparerBeforeCollisionThreshold: nonRandomizedOrdinalComparerType,
+                expectedPublicComparerBeforeCollisionThreshold: StringComparer.Ordinal.GetType(),
+                expectedComparerAfterCollisionThreshold: randomizedOrdinalComparerType);
+
+            // OrdinalIgnoreCase comparer
+
+            RunDictionaryTest(
+                equalityComparer: StringComparer.OrdinalIgnoreCase,
+                expectedInternalComparerBeforeCollisionThreshold: nonRandomizedOrdinalIgnoreCaseComparerType,
+                expectedPublicComparerBeforeCollisionThreshold: StringComparer.OrdinalIgnoreCase.GetType(),
+                expectedComparerAfterCollisionThreshold: randomizedOrdinalIgnoreCaseComparerType);
+
+            // linguistic comparer (not optimized)
+
+            RunDictionaryTest(
+                equalityComparer: StringComparer.InvariantCulture,
+                expectedInternalComparerBeforeCollisionThreshold: StringComparer.InvariantCulture.GetType(),
+                expectedPublicComparerBeforeCollisionThreshold: StringComparer.InvariantCulture.GetType(),
+                expectedComparerAfterCollisionThreshold: StringComparer.InvariantCulture.GetType());
+
+            static void RunDictionaryTest(
+                IEqualityComparer<string> equalityComparer,
+                Type expectedInternalComparerBeforeCollisionThreshold,
+                Type expectedPublicComparerBeforeCollisionThreshold,
+                Type expectedComparerAfterCollisionThreshold)
+            {
+                RunCollectionTestCommon(
+                    () => new Dictionary<string, object>(equalityComparer),
+                    (dictionary, key) => dictionary.Add(key, null),
+                    (dictionary, key) => dictionary.ContainsKey(key),
+                    dictionary => dictionary.Comparer,
+                    expectedInternalComparerBeforeCollisionThreshold,
+                    expectedPublicComparerBeforeCollisionThreshold,
+                    expectedComparerAfterCollisionThreshold);
+            }
+        }
+
+        [Fact]
+        public static void ComparerImplementations_HashSet_WithWellKnownStringComparers()
+        {
+            Type nonRandomizedOrdinalComparerType = typeof(object).Assembly.GetType("System.Collections.Generic.NonRandomizedStringEqualityComparer+OrdinalComparer", throwOnError: true);
+            Type nonRandomizedOrdinalIgnoreCaseComparerType = typeof(object).Assembly.GetType("System.Collections.Generic.NonRandomizedStringEqualityComparer+OrdinalIgnoreCaseComparer", throwOnError: true);
+            Type randomizedOrdinalComparerType = typeof(object).Assembly.GetType("System.Collections.Generic.RandomizedStringEqualityComparer+OrdinalComparer", throwOnError: true);
+            Type randomizedOrdinalIgnoreCaseComparerType = typeof(object).Assembly.GetType("System.Collections.Generic.RandomizedStringEqualityComparer+OrdinalIgnoreCaseComparer", throwOnError: true);
+
+            // null comparer
+
+            RunHashSetTest(
+                equalityComparer: null,
+                expectedInternalComparerBeforeCollisionThreshold: nonRandomizedOrdinalComparerType,
+                expectedPublicComparerBeforeCollisionThreshold: EqualityComparer<string>.Default.GetType(),
+                expectedComparerAfterCollisionThreshold: randomizedOrdinalComparerType);
+
+            // EqualityComparer<string>.Default comparer
+
+            RunHashSetTest(
+                equalityComparer: EqualityComparer<string>.Default,
+                expectedInternalComparerBeforeCollisionThreshold: nonRandomizedOrdinalComparerType,
+                expectedPublicComparerBeforeCollisionThreshold: EqualityComparer<string>.Default.GetType(),
+                expectedComparerAfterCollisionThreshold: randomizedOrdinalComparerType);
+
+            // Ordinal comparer
+
+            RunHashSetTest(
+                equalityComparer: StringComparer.Ordinal,
+                expectedInternalComparerBeforeCollisionThreshold: nonRandomizedOrdinalComparerType,
+                expectedPublicComparerBeforeCollisionThreshold: StringComparer.Ordinal.GetType(),
+                expectedComparerAfterCollisionThreshold: randomizedOrdinalComparerType);
+
+            // OrdinalIgnoreCase comparer
+
+            RunHashSetTest(
+                equalityComparer: StringComparer.OrdinalIgnoreCase,
+                expectedInternalComparerBeforeCollisionThreshold: nonRandomizedOrdinalIgnoreCaseComparerType,
+                expectedPublicComparerBeforeCollisionThreshold: StringComparer.OrdinalIgnoreCase.GetType(),
+                expectedComparerAfterCollisionThreshold: randomizedOrdinalIgnoreCaseComparerType);
+
+            // linguistic comparer (not optimized)
+
+            RunHashSetTest(
+                equalityComparer: StringComparer.InvariantCulture,
+                expectedInternalComparerBeforeCollisionThreshold: StringComparer.InvariantCulture.GetType(),
+                expectedPublicComparerBeforeCollisionThreshold: StringComparer.InvariantCulture.GetType(),
+                expectedComparerAfterCollisionThreshold: StringComparer.InvariantCulture.GetType());
+
+            static void RunHashSetTest(
+                IEqualityComparer<string> equalityComparer,
+                Type expectedInternalComparerBeforeCollisionThreshold,
+                Type expectedPublicComparerBeforeCollisionThreshold,
+                Type expectedComparerAfterCollisionThreshold)
+            {
+                RunCollectionTestCommon(
+                    () => new HashSet<string>(equalityComparer),
+                    (set, key) => Assert.True(set.Add(key)),
+                    (set, key) => set.Contains(key),
+                    set => set.Comparer,
+                    expectedInternalComparerBeforeCollisionThreshold,
+                    expectedPublicComparerBeforeCollisionThreshold,
+                    expectedComparerAfterCollisionThreshold);
+            }
+        }
+
+        private static void RunCollectionTestCommon<TCollection>(
+            Func<TCollection> collectionFactory,
+            Action<TCollection, string> addKeyCallback,
+            Func<TCollection, string, bool> containsKeyCallback,
+            Func<TCollection, IEqualityComparer<string>> getComparerCallback,
+            Type expectedInternalComparerBeforeCollisionThreshold,
+            Type expectedPublicComparerBeforeCollisionThreshold,
+            Type expectedComparerAfterCollisionThreshold)
+        {
+            TCollection collection = collectionFactory();
+            List<string> allKeys = new List<string>();
+
+            const int StartOfRange = 0xE020; // use the Unicode Private Use range to avoid accidentally creating strings that really do compare as equal OrdinalIgnoreCase
+            const int Stride = 0x40; // to ensure we don't accidentally reset the 0x20 bit of the seed, which is used to negate OrdinalIgnoreCase effects
+
+            // First, go right up to the collision threshold, but don't exceed it.
+
+            for (int i = 0; i < 100; i++)
+            {
+                string newKey = GenerateCollidingString(i * Stride + StartOfRange);
+                Assert.Equal(0, _lazyGetNonRandomizedHashCodeDel.Value(newKey)); // ensure has a zero hash code Ordinal
+                Assert.Equal(0x24716ca0, _lazyGetNonRandomizedOrdinalIgnoreCaseHashCodeDel.Value(newKey)); // ensure has a zero hash code OrdinalIgnoreCase
+
+                addKeyCallback(collection, newKey);
+                allKeys.Add(newKey);
+            }
+
+            FieldInfo internalComparerField = collection.GetType().GetField("_comparer", BindingFlags.NonPublic | BindingFlags.Instance);
+            Assert.NotNull(internalComparerField);
+
+            Assert.Equal(expectedInternalComparerBeforeCollisionThreshold, internalComparerField.GetValue(collection)?.GetType());
+            Assert.Equal(expectedPublicComparerBeforeCollisionThreshold, getComparerCallback(collection).GetType());
+
+            // Now exceed the collision threshold, which should rebucket entries.
+            // Continue adding a few more entries to ensure we didn't corrupt internal state.
+
+            for (int i = 100; i < 110; i++)
+            {
+                string newKey = GenerateCollidingString(i * Stride + StartOfRange);
+                Assert.Equal(0, _lazyGetNonRandomizedHashCodeDel.Value(newKey)); // ensure has a zero hash code Ordinal
+                Assert.Equal(0x24716ca0, _lazyGetNonRandomizedOrdinalIgnoreCaseHashCodeDel.Value(newKey)); // ensure has a zero hash code OrdinalIgnoreCase
+
+                addKeyCallback(collection, newKey);
+                allKeys.Add(newKey);
+            }
+
+            Assert.Equal(expectedComparerAfterCollisionThreshold, internalComparerField.GetValue(collection)?.GetType());
+            Assert.Equal(expectedPublicComparerBeforeCollisionThreshold, getComparerCallback(collection).GetType()); // shouldn't change this return value after collision threshold met
+
+            // And validate that all strings are present in the dictionary.
+
+            foreach (string key in allKeys)
+            {
+                Assert.True(containsKeyCallback(collection, key));
+            }
+
+            // Also make sure we didn't accidentally put the internal comparer in the serialized object data.
+
+            collection = collectionFactory();
+            SerializationInfo si = new SerializationInfo(collection.GetType(), new FormatterConverter());
+            ((ISerializable)collection).GetObjectData(si, new StreamingContext());
+
+            object serializedComparer = si.GetValue("Comparer", typeof(IEqualityComparer<string>));
+            Assert.Equal(expectedPublicComparerBeforeCollisionThreshold, serializedComparer.GetType());
+        }
+
+        private static Lazy<Func<string, int>> _lazyGetNonRandomizedHashCodeDel = new Lazy<Func<string, int>>(
+            () => GetStringHashCodeOpenDelegate("GetNonRandomizedHashCode"));
+
+        private static Lazy<Func<string, int>> _lazyGetNonRandomizedOrdinalIgnoreCaseHashCodeDel = new Lazy<Func<string, int>>(
+            () => GetStringHashCodeOpenDelegate("GetNonRandomizedHashCodeOrdinalIgnoreCase"));
+
+        // Generates a string with a well-known non-randomized hash code:
+        // - string.GetNonRandomizedHashCode returns 0.
+        // - string.GetNonRandomizedHashCodeOrdinalIgnoreCase returns 0x24716ca0.
+        // Provide a different seed to produce a different string.
+        private static string GenerateCollidingString(int seed)
+        {
+            return string.Create(8, seed, (span, seed) =>
+            {
+                Span<byte> asBytes = MemoryMarshal.AsBytes(span);
+
+                uint hash1 = (5381 << 16) + 5381;
+                uint hash2 = BitOperations.RotateLeft(hash1, 5) + hash1;
+
+                MemoryMarshal.Write(asBytes, ref seed);
+                MemoryMarshal.Write(asBytes.Slice(4), ref hash2); // set hash2 := 0 (for Ordinal)
+
+                hash1 = (BitOperations.RotateLeft(hash1, 5) + hash1) ^ (uint)seed;
+                hash1 = (BitOperations.RotateLeft(hash1, 5) + hash1);
+
+                MemoryMarshal.Write(asBytes.Slice(8), ref hash1); // set hash1 := 0 (for Ordinal)
+            });
+        }
+
+        private static Func<string, int> GetStringHashCodeOpenDelegate(string methodName)
+        {
+            MethodInfo method = typeof(string).GetMethod(methodName, BindingFlags.Instance | BindingFlags.NonPublic);
+            Assert.NotNull(method);
+
+            return method.CreateDelegate<Func<string, int>>(target: null); // create open delegate unbound to 'this'
+        }
     }
 }
index 1715d43..4278e25 100644 (file)
     <Compile Include="$(MSBuildThisFileDirectory)System\Collections\Generic\IEnumerable.cs" />
     <Compile Include="$(MSBuildThisFileDirectory)System\Collections\Generic\IEnumerator.cs" />
     <Compile Include="$(MSBuildThisFileDirectory)System\Collections\Generic\IEqualityComparer.cs" />
+    <Compile Include="$(MSBuildThisFileDirectory)System\Collections\Generic\IInternalStringEqualityComparer.cs" />
     <Compile Include="$(MSBuildThisFileDirectory)System\Collections\Generic\IList.cs" />
     <Compile Include="$(MSBuildThisFileDirectory)System\Collections\Generic\InsertionBehavior.cs" />
     <Compile Include="$(MSBuildThisFileDirectory)System\Collections\Generic\IReadOnlyCollection.cs" />
     <Compile Include="$(MSBuildThisFileDirectory)System\Collections\Generic\KeyNotFoundException.cs" />
     <Compile Include="$(MSBuildThisFileDirectory)System\Collections\Generic\KeyValuePair.cs" />
     <Compile Include="$(MSBuildThisFileDirectory)System\Collections\Generic\List.cs" />
+    <Compile Include="$(MSBuildThisFileDirectory)System\Collections\Generic\RandomizedStringEqualityComparer.cs" />
     <Compile Include="$(MSBuildThisFileDirectory)System\Collections\Generic\NonRandomizedStringEqualityComparer.cs" />
     <Compile Include="$(MSBuildThisFileDirectory)System\Collections\Generic\ValueListBuilder.cs" />
     <Compile Include="$(MSBuildThisFileDirectory)System\Collections\HashHelpers.cs" />
index ac4fd58..6c5682e 100644 (file)
@@ -59,10 +59,24 @@ namespace System.Collections.Generic
                 _comparer = comparer;
             }
 
-            if (typeof(TKey) == typeof(string) && _comparer == null)
+            // Special-case EqualityComparer<string>.Default, StringComparer.Ordinal, and StringComparer.OrdinalIgnoreCase.
+            // We use a non-randomized comparer for improved perf, falling back to a randomized comparer if the
+            // hash buckets become unbalanced.
+
+            if (typeof(TKey) == typeof(string))
             {
-                // To start, move off default comparer for string which is randomised
-                _comparer = (IEqualityComparer<TKey>)NonRandomizedStringEqualityComparer.Default;
+                if (_comparer is null)
+                {
+                    _comparer = (IEqualityComparer<TKey>)NonRandomizedStringEqualityComparer.WrappedAroundDefaultComparer;
+                }
+                else if (ReferenceEquals(_comparer, StringComparer.Ordinal))
+                {
+                    _comparer = (IEqualityComparer<TKey>)NonRandomizedStringEqualityComparer.WrappedAroundStringComparerOrdinal;
+                }
+                else if (ReferenceEquals(_comparer, StringComparer.OrdinalIgnoreCase))
+                {
+                    _comparer = (IEqualityComparer<TKey>)NonRandomizedStringEqualityComparer.WrappedAroundStringComparerOrdinalIgnoreCase;
+                }
             }
         }
 
@@ -125,10 +139,20 @@ namespace System.Collections.Generic
             HashHelpers.SerializationInfoTable.Add(this, info);
         }
 
-        public IEqualityComparer<TKey> Comparer =>
-            (_comparer == null || _comparer is NonRandomizedStringEqualityComparer) ?
-                EqualityComparer<TKey>.Default :
-                _comparer;
+        public IEqualityComparer<TKey> Comparer
+        {
+            get
+            {
+                if (typeof(TKey) == typeof(string))
+                {
+                    return (IEqualityComparer<TKey>)IInternalStringEqualityComparer.GetUnderlyingEqualityComparer((IEqualityComparer<string?>?)_comparer);
+                }
+                else
+                {
+                    return _comparer ?? EqualityComparer<TKey>.Default;
+                }
+            }
+        }
 
         public int Count => _count - _freeCount;
 
@@ -299,7 +323,7 @@ namespace System.Collections.Generic
             }
 
             info.AddValue(VersionName, _version);
-            info.AddValue(ComparerName, _comparer ?? EqualityComparer<TKey>.Default, typeof(IEqualityComparer<TKey>));
+            info.AddValue(ComparerName, Comparer, typeof(IEqualityComparer<TKey>));
             info.AddValue(HashSizeName, _buckets == null ? 0 : _buckets.Length); // This is the length of the bucket array
 
             if (_buckets != null)
@@ -633,7 +657,6 @@ namespace System.Collections.Generic
             {
                 // If we hit the collision threshold we'll need to switch to the comparer which is using randomized string hashing
                 // i.e. EqualityComparer<string>.Default.
-                _comparer = null;
                 Resize(entries.Length, true);
             }
 
@@ -702,14 +725,21 @@ namespace System.Collections.Generic
 
             if (!typeof(TKey).IsValueType && forceNewHashCodes)
             {
+                Debug.Assert(_comparer is NonRandomizedStringEqualityComparer);
+                _comparer = (IEqualityComparer<TKey>)((NonRandomizedStringEqualityComparer)_comparer).GetRandomizedEqualityComparer();
+
                 for (int i = 0; i < count; i++)
                 {
                     if (entries[i].next >= -1)
                     {
-                        Debug.Assert(_comparer == null);
-                        entries[i].hashCode = (uint)entries[i].key.GetHashCode();
+                        entries[i].hashCode = (uint)_comparer.GetHashCode(entries[i].key);
                     }
                 }
+
+                if (ReferenceEquals(_comparer, EqualityComparer<TKey>.Default))
+                {
+                    _comparer = null;
+                }
             }
 
             // Assign member variables after both arrays allocated to guard against corruption from OOM if second fails
index f86c16f..394b0ff 100644 (file)
@@ -60,10 +60,24 @@ namespace System.Collections.Generic
                 _comparer = comparer;
             }
 
-            if (typeof(T) == typeof(string) && _comparer == null)
+            // Special-case EqualityComparer<string>.Default, StringComparer.Ordinal, and StringComparer.OrdinalIgnoreCase.
+            // We use a non-randomized comparer for improved perf, falling back to a randomized comparer if the
+            // hash buckets become unbalanced.
+
+            if (typeof(T) == typeof(string))
             {
-                // To start, move off default comparer for string which is randomized.
-                _comparer = (IEqualityComparer<T>)NonRandomizedStringEqualityComparer.Default;
+                if (_comparer is null)
+                {
+                    _comparer = (IEqualityComparer<T>)NonRandomizedStringEqualityComparer.WrappedAroundDefaultComparer;
+                }
+                else if (ReferenceEquals(_comparer, StringComparer.Ordinal))
+                {
+                    _comparer = (IEqualityComparer<T>)NonRandomizedStringEqualityComparer.WrappedAroundStringComparerOrdinal;
+                }
+                else if (ReferenceEquals(_comparer, StringComparer.OrdinalIgnoreCase))
+                {
+                    _comparer = (IEqualityComparer<T>)NonRandomizedStringEqualityComparer.WrappedAroundStringComparerOrdinalIgnoreCase;
+                }
             }
         }
 
@@ -380,7 +394,7 @@ namespace System.Collections.Generic
             }
 
             info.AddValue(VersionName, _version); // need to serialize version to avoid problems with serializing while enumerating
-            info.AddValue(ComparerName, _comparer ?? EqualityComparer<T>.Default, typeof(IEqualityComparer<T>));
+            info.AddValue(ComparerName, Comparer, typeof(IEqualityComparer<T>));
             info.AddValue(CapacityName, _buckets == null ? 0 : _buckets.Length);
 
             if (_buckets != null)
@@ -912,10 +926,20 @@ namespace System.Collections.Generic
         }
 
         /// <summary>Gets the <see cref="IEqualityComparer"/> object that is used to determine equality for the values in the set.</summary>
-        public IEqualityComparer<T> Comparer =>
-            (_comparer == null || _comparer is NonRandomizedStringEqualityComparer) ?
-                EqualityComparer<T>.Default :
-                _comparer;
+        public IEqualityComparer<T> Comparer
+        {
+            get
+            {
+                if (typeof(T) == typeof(string))
+                {
+                    return (IEqualityComparer<T>)IInternalStringEqualityComparer.GetUnderlyingEqualityComparer((IEqualityComparer<string?>?)_comparer);
+                }
+                else
+                {
+                    return _comparer ?? EqualityComparer<T>.Default;
+                }
+            }
+        }
 
         /// <summary>Ensures that this hash set can hold the specified number of elements without growing.</summary>
         public int EnsureCapacity(int capacity)
@@ -957,15 +981,22 @@ namespace System.Collections.Generic
 
             if (!typeof(T).IsValueType && forceNewHashCodes)
             {
+                Debug.Assert(_comparer is NonRandomizedStringEqualityComparer);
+                _comparer = (IEqualityComparer<T>)((NonRandomizedStringEqualityComparer)_comparer).GetRandomizedEqualityComparer();
+
                 for (int i = 0; i < count; i++)
                 {
                     ref Entry entry = ref entries[i];
                     if (entry.Next >= -1)
                     {
-                        Debug.Assert(_comparer == null);
-                        entry.HashCode = entry.Value != null ? entry.Value!.GetHashCode() : 0;
+                        entry.HashCode = entry.Value != null ? _comparer!.GetHashCode(entry.Value) : 0;
                     }
                 }
+
+                if (ReferenceEquals(_comparer, EqualityComparer<T>.Default))
+                {
+                    _comparer = null;
+                }
             }
 
             // Assign member variables after both arrays allocated to guard against corruption from OOM if second fails
@@ -1185,7 +1216,6 @@ namespace System.Collections.Generic
             {
                 // If we hit the collision threshold we'll need to switch to the comparer which is using randomized string hashing
                 // i.e. EqualityComparer<string>.Default.
-                _comparer = null;
                 Resize(entries.Length, forceNewHashCodes: true);
                 location = FindItemIndex(value);
                 Debug.Assert(location >= 0);
diff --git a/src/libraries/System.Private.CoreLib/src/System/Collections/Generic/IInternalStringEqualityComparer.cs b/src/libraries/System.Private.CoreLib/src/System/Collections/Generic/IInternalStringEqualityComparer.cs
new file mode 100644 (file)
index 0000000..28c2ab9
--- /dev/null
@@ -0,0 +1,36 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+namespace System.Collections.Generic
+{
+    /// <summary>
+    /// Represents an <see cref="IEqualityComparer{String}"/> that's meant for internal
+    /// use only and isn't intended to be serialized or returned back to the user.
+    /// Use the <see cref="GetUnderlyingEqualityComparer"/> method to get the object
+    /// that should actually be returned to the caller.
+    /// </summary>
+    internal interface IInternalStringEqualityComparer : IEqualityComparer<string?>
+    {
+        IEqualityComparer<string?> GetUnderlyingEqualityComparer();
+
+        /// <summary>
+        /// Unwraps the internal equality comparer, if proxied.
+        /// Otherwise returns the equality comparer itself or its default equivalent.
+        /// </summary>
+        internal static IEqualityComparer<string?> GetUnderlyingEqualityComparer(IEqualityComparer<string?>? outerComparer)
+        {
+            if (outerComparer is null)
+            {
+                return EqualityComparer<string?>.Default;
+            }
+            else if (outerComparer is IInternalStringEqualityComparer internalComparer)
+            {
+                return internalComparer.GetUnderlyingEqualityComparer();
+            }
+            else
+            {
+                return outerComparer;
+            }
+        }
+    }
+}
index b75ebd5..0ec1a50 100644 (file)
@@ -1,7 +1,10 @@
 // Licensed to the .NET Foundation under one or more agreements.
 // The .NET Foundation licenses this file to you under the MIT license.
 
+using System.Diagnostics;
+using System.Globalization;
 using System.Runtime.Serialization;
+using Internal.Runtime.CompilerServices;
 
 namespace System.Collections.Generic
 {
@@ -11,23 +14,100 @@ namespace System.Collections.Generic
     // randomized string hashing.
     [Serializable] // Required for compatibility with .NET Core 2.0 as we exposed the NonRandomizedStringEqualityComparer inside the serialization blob
     // Needs to be public to support binary serialization compatibility
-    public sealed class NonRandomizedStringEqualityComparer : EqualityComparer<string?>, ISerializable
+    public class NonRandomizedStringEqualityComparer : IEqualityComparer<string?>, IInternalStringEqualityComparer, ISerializable
     {
-        internal static new IEqualityComparer<string?> Default { get; } = new NonRandomizedStringEqualityComparer();
+        // Dictionary<...>.Comparer and similar methods need to return the original IEqualityComparer
+        // that was passed in to the ctor. The caller chooses one of these singletons so that the
+        // GetUnderlyingEqualityComparer method can return the correct value.
 
-        private NonRandomizedStringEqualityComparer() { }
+        internal static readonly NonRandomizedStringEqualityComparer WrappedAroundDefaultComparer = new OrdinalComparer(EqualityComparer<string?>.Default);
+        internal static readonly NonRandomizedStringEqualityComparer WrappedAroundStringComparerOrdinal = new OrdinalComparer(StringComparer.Ordinal);
+        internal static readonly NonRandomizedStringEqualityComparer WrappedAroundStringComparerOrdinalIgnoreCase = new OrdinalIgnoreCaseComparer(StringComparer.OrdinalIgnoreCase);
+
+        private readonly IEqualityComparer<string?> _underlyingComparer;
+
+        private NonRandomizedStringEqualityComparer(IEqualityComparer<string?> underlyingComparer)
+        {
+            Debug.Assert(underlyingComparer != null);
+            _underlyingComparer = underlyingComparer;
+        }
 
         // This is used by the serialization engine.
-        private NonRandomizedStringEqualityComparer(SerializationInfo information, StreamingContext context) { }
+        protected NonRandomizedStringEqualityComparer(SerializationInfo information, StreamingContext context)
+            : this(EqualityComparer<string?>.Default)
+        {
+        }
+
+        public virtual bool Equals(string? x, string? y)
+        {
+            // This instance may have been deserialized into a class that doesn't guarantee
+            // these parameters are non-null. Can't short-circuit the null checks.
+
+            return string.Equals(x, y);
+        }
+
+        public virtual int GetHashCode(string? obj)
+        {
+            // This instance may have been deserialized into a class that doesn't guarantee
+            // these parameters are non-null. Can't short-circuit the null checks.
+
+            return obj?.GetNonRandomizedHashCode() ?? 0;
+        }
 
-        public sealed override bool Equals(string? x, string? y) => string.Equals(x, y);
+        internal virtual RandomizedStringEqualityComparer GetRandomizedEqualityComparer()
+        {
+            return RandomizedStringEqualityComparer.Create(_underlyingComparer, ignoreCase: false);
+        }
 
-        public sealed override int GetHashCode(string? obj) => obj?.GetNonRandomizedHashCode() ?? 0;
+        // Gets the comparer that should be returned back to the caller when querying the
+        // ICollection.Comparer property. Also used for serialization purposes.
+        public virtual IEqualityComparer<string?> GetUnderlyingEqualityComparer() => _underlyingComparer;
 
-        public void GetObjectData(SerializationInfo info, StreamingContext context)
+        void ISerializable.GetObjectData(SerializationInfo info, StreamingContext context)
         {
             // We are doing this to stay compatible with .NET Framework.
+            // Our own collection types will never call this (since this type is a wrapper),
+            // but perhaps third-party collection types could try serializing an instance
+            // of this.
             info.SetType(typeof(GenericEqualityComparer<string>));
         }
+
+        private sealed class OrdinalComparer : NonRandomizedStringEqualityComparer
+        {
+            internal OrdinalComparer(IEqualityComparer<string?> wrappedComparer)
+                : base(wrappedComparer)
+            {
+            }
+
+            public override bool Equals(string? x, string? y) => string.Equals(x, y);
+
+            public override int GetHashCode(string? obj)
+            {
+                Debug.Assert(obj != null, "This implementation is only called from first-party collection types that guarantee non-null parameters.");
+                return obj.GetNonRandomizedHashCode();
+            }
+
+        }
+
+        private sealed class OrdinalIgnoreCaseComparer : NonRandomizedStringEqualityComparer
+        {
+            internal OrdinalIgnoreCaseComparer(IEqualityComparer<string?> wrappedComparer)
+                : base(wrappedComparer)
+            {
+            }
+
+            public override bool Equals(string? x, string? y) => string.EqualsOrdinalIgnoreCase(x, y);
+
+            public override int GetHashCode(string? obj)
+            {
+                Debug.Assert(obj != null, "This implementation is only called from first-party collection types that guarantee non-null parameters.");
+                return obj.GetNonRandomizedHashCodeOrdinalIgnoreCase();
+            }
+
+            internal override RandomizedStringEqualityComparer GetRandomizedEqualityComparer()
+            {
+                return RandomizedStringEqualityComparer.Create(_underlyingComparer, ignoreCase: true);
+            }
+        }
     }
 }
diff --git a/src/libraries/System.Private.CoreLib/src/System/Collections/Generic/RandomizedStringEqualityComparer.cs b/src/libraries/System.Private.CoreLib/src/System/Collections/Generic/RandomizedStringEqualityComparer.cs
new file mode 100644 (file)
index 0000000..168959d
--- /dev/null
@@ -0,0 +1,124 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using Internal.Runtime.CompilerServices;
+
+namespace System.Collections.Generic
+{
+    /// <summary>
+    /// A randomized <see cref="EqualityComparer{String}"/> which uses a different seed on each
+    /// construction as a general good hygiene + defense-in-depth mechanism. This implementation
+    /// *does not* need to stay in sync with <see cref="string.GetHashCode"/>, which for stability
+    /// is required to use an app-global seed.
+    /// </summary>
+    internal abstract class RandomizedStringEqualityComparer : EqualityComparer<string?>, IInternalStringEqualityComparer
+    {
+        private readonly MarvinSeed _seed;
+        private readonly IEqualityComparer<string?> _underlyingComparer;
+
+        private unsafe RandomizedStringEqualityComparer(IEqualityComparer<string?> underlyingComparer)
+        {
+            _underlyingComparer = underlyingComparer;
+
+            fixed (MarvinSeed* seed = &_seed)
+            {
+                Interop.GetRandomBytes((byte*)seed, sizeof(MarvinSeed));
+            }
+        }
+
+        internal static RandomizedStringEqualityComparer Create(IEqualityComparer<string?> underlyingComparer, bool ignoreCase)
+        {
+            if (!ignoreCase)
+            {
+                return new OrdinalComparer(underlyingComparer);
+            }
+            else
+            {
+                return new OrdinalIgnoreCaseComparer(underlyingComparer);
+            }
+        }
+
+        public IEqualityComparer<string?> GetUnderlyingEqualityComparer() => _underlyingComparer;
+
+        private struct MarvinSeed
+        {
+            internal uint p0;
+            internal uint p1;
+        }
+
+        private sealed class OrdinalComparer : RandomizedStringEqualityComparer
+        {
+            internal OrdinalComparer(IEqualityComparer<string?> wrappedComparer)
+                : base(wrappedComparer)
+            {
+            }
+
+            public override bool Equals(string? x, string? y) => string.Equals(x, y);
+
+            public override int GetHashCode(string? obj)
+            {
+                if (obj is null)
+                {
+                    return 0;
+                }
+
+                // The Ordinal version of Marvin32 operates over bytes.
+                // The multiplication from # chars -> # bytes will never integer overflow.
+                return Marvin.ComputeHash32(
+                    ref Unsafe.As<char, byte>(ref obj.GetRawStringData()),
+                    (uint)obj.Length * 2,
+                    _seed.p0, _seed.p1);
+            }
+        }
+
+        private sealed class OrdinalIgnoreCaseComparer : RandomizedStringEqualityComparer
+        {
+            internal OrdinalIgnoreCaseComparer(IEqualityComparer<string?> wrappedComparer)
+                : base(wrappedComparer)
+            {
+            }
+
+            public override bool Equals(string? x, string? y) => string.EqualsOrdinalIgnoreCase(x, y);
+
+            public override int GetHashCode(string? obj)
+            {
+                if (obj is null)
+                {
+                    return 0;
+                }
+
+                // The Ordinal version of Marvin32 operates over bytes, so convert
+                // char count -> byte count. Guaranteed not to integer overflow.
+                return Marvin.ComputeHash32(
+                    ref Unsafe.As<char, byte>(ref obj.GetRawStringData()),
+                    (uint)obj.Length * sizeof(char),
+                    _seed.p0, _seed.p1);
+            }
+        }
+
+        private sealed class RandomizedOrdinalIgnoreCaseComparer : RandomizedStringEqualityComparer
+        {
+            internal RandomizedOrdinalIgnoreCaseComparer(IEqualityComparer<string?> underlyingComparer)
+                : base(underlyingComparer)
+            {
+            }
+
+            public override bool Equals(string? x, string? y) => string.EqualsOrdinalIgnoreCase(x, y);
+
+            public override int GetHashCode(string? obj)
+            {
+                if (obj is null)
+                {
+                    return 0;
+                }
+
+                // The OrdinalIgnoreCase version of Marvin32 operates over chars,
+                // so pass in the char count directly.
+                return Marvin.ComputeHash32OrdinalIgnoreCase(
+                    ref obj.GetRawStringData(),
+                    obj.Length,
+                    _seed.p0, _seed.p1);
+            }
+        }
+    }
+}
index d937011..40f6317 100644 (file)
@@ -42,7 +42,27 @@ namespace System
             return SpanHelpers.SequenceCompareTo(ref Unsafe.Add(ref strA.GetRawStringData(), indexA), countA, ref Unsafe.Add(ref strB.GetRawStringData(), indexB), countB);
         }
 
-        private static bool EqualsOrdinalIgnoreCase(string strA, string strB)
+        internal static bool EqualsOrdinalIgnoreCase(string? strA, string? strB)
+        {
+            if (ReferenceEquals(strA, strB))
+            {
+                return true;
+            }
+
+            if (strA is null || strB is null)
+            {
+                return false;
+            }
+
+            if (strA.Length != strB.Length)
+            {
+                return false;
+            }
+
+            return EqualsOrdinalIgnoreCaseNoLengthCheck(strA, strB);
+        }
+
+        private static bool EqualsOrdinalIgnoreCaseNoLengthCheck(string strA, string strB)
         {
             Debug.Assert(strA.Length == strB.Length);
 
@@ -645,7 +665,7 @@ namespace System
                     if (this.Length != value.Length)
                         return false;
 
-                    return EqualsOrdinalIgnoreCase(this, value);
+                    return EqualsOrdinalIgnoreCaseNoLengthCheck(this, value);
 
                 default:
                     throw new ArgumentException(SR.NotSupported_StringComparison, nameof(comparisonType));
@@ -701,7 +721,7 @@ namespace System
                     if (a.Length != b.Length)
                         return false;
 
-                    return EqualsOrdinalIgnoreCase(a, b);
+                    return EqualsOrdinalIgnoreCaseNoLengthCheck(a, b);
 
                 default:
                     throw new ArgumentException(SR.NotSupported_StringComparison, nameof(comparisonType));
@@ -811,6 +831,46 @@ namespace System
             }
         }
 
+        // Use this if and only if 'Denial of Service' attacks are not a concern (i.e. never used for free-form user input),
+        // or are otherwise mitigated
+        internal unsafe int GetNonRandomizedHashCodeOrdinalIgnoreCase()
+        {
+            fixed (char* src = &_firstChar)
+            {
+                Debug.Assert(src[this.Length] == '\0', "src[this.Length] == '\\0'");
+                Debug.Assert(((int)src) % 4 == 0, "Managed string should start at 4 bytes boundary");
+
+                uint hash1 = (5381 << 16) + 5381;
+                uint hash2 = hash1;
+
+                uint* ptr = (uint*)src;
+                int length = this.Length;
+
+                // We "normalize to lowercase" every char by ORing with 0x0020. This casts
+                // a very wide net because it will change, e.g., '^' to '~'. But that should
+                // be ok because we expect this to be very rare in practice.
+
+                const uint NormalizeToLowercase = 0x0020_0020u; // valid both for big-endian and for little-endian
+
+                while (length > 2)
+                {
+                    length -= 4;
+                    // Where length is 4n-1 (e.g. 3,7,11,15,19) this additionally consumes the null terminator
+                    hash1 = (BitOperations.RotateLeft(hash1, 5) + hash1) ^ (ptr[0] | NormalizeToLowercase);
+                    hash2 = (BitOperations.RotateLeft(hash2, 5) + hash2) ^ (ptr[1] | NormalizeToLowercase);
+                    ptr += 2;
+                }
+
+                if (length > 0)
+                {
+                    // Where length is 4n-3 (e.g. 1,5,9,13,17) this additionally consumes the null terminator
+                    hash2 = (BitOperations.RotateLeft(hash2, 5) + hash2) ^ (ptr[0] | NormalizeToLowercase);
+                }
+
+                return (int)(hash1 + (hash2 * 1566083941));
+            }
+        }
+
         // Determines whether a specified string is a prefix of the current instance
         //
         public bool StartsWith(string value)
index e935c3d..39791cf 100644 (file)
@@ -12,14 +12,9 @@ namespace System
     [System.Runtime.CompilerServices.TypeForwardedFrom("mscorlib, Version=4.0.0.0, Culture=neutral, PublicKeyToken=b77a5c561934e089")]
     public abstract class StringComparer : IComparer, IEqualityComparer, IComparer<string?>, IEqualityComparer<string?>
     {
-        private static readonly CultureAwareComparer s_invariantCulture = new CultureAwareComparer(CultureInfo.InvariantCulture, CompareOptions.None);
-        private static readonly CultureAwareComparer s_invariantCultureIgnoreCase = new CultureAwareComparer(CultureInfo.InvariantCulture, CompareOptions.IgnoreCase);
-        private static readonly OrdinalCaseSensitiveComparer s_ordinal = new OrdinalCaseSensitiveComparer();
-        private static readonly OrdinalIgnoreCaseComparer s_ordinalIgnoreCase = new OrdinalIgnoreCaseComparer();
+        public static StringComparer InvariantCulture => CultureAwareComparer.InvariantCaseSensitiveInstance;
 
-        public static StringComparer InvariantCulture => s_invariantCulture;
-
-        public static StringComparer InvariantCultureIgnoreCase => s_invariantCultureIgnoreCase;
+        public static StringComparer InvariantCultureIgnoreCase => CultureAwareComparer.InvariantIgnoreCaseInstance;
 
         public static StringComparer CurrentCulture =>
             new CultureAwareComparer(CultureInfo.CurrentCulture, CompareOptions.None);
@@ -27,9 +22,9 @@ namespace System
         public static StringComparer CurrentCultureIgnoreCase =>
             new CultureAwareComparer(CultureInfo.CurrentCulture, CompareOptions.IgnoreCase);
 
-        public static StringComparer Ordinal => s_ordinal;
+        public static StringComparer Ordinal => OrdinalCaseSensitiveComparer.Instance;
 
-        public static StringComparer OrdinalIgnoreCase => s_ordinalIgnoreCase;
+        public static StringComparer OrdinalIgnoreCase => OrdinalIgnoreCaseComparer.Instance;
 
         // Convert a StringComparison to a StringComparer
         public static StringComparer FromComparison(StringComparison comparisonType)
@@ -128,6 +123,9 @@ namespace System
     [System.Runtime.CompilerServices.TypeForwardedFrom("mscorlib, Version=4.0.0.0, Culture=neutral, PublicKeyToken=b77a5c561934e089")]
     public sealed class CultureAwareComparer : StringComparer, ISerializable
     {
+        internal static readonly CultureAwareComparer InvariantCaseSensitiveInstance = new CultureAwareComparer(CompareInfo.Invariant, CompareOptions.None);
+        internal static readonly CultureAwareComparer InvariantIgnoreCaseInstance = new CultureAwareComparer(CompareInfo.Invariant, CompareOptions.IgnoreCase);
+
         private const CompareOptions ValidCompareMaskOffFlags = ~(CompareOptions.IgnoreCase | CompareOptions.IgnoreSymbols | CompareOptions.IgnoreNonSpace | CompareOptions.IgnoreWidth | CompareOptions.IgnoreKanaType | CompareOptions.StringSort);
 
         private readonly CompareInfo _compareInfo; // Do not rename (binary serialization)
@@ -286,7 +284,9 @@ namespace System
     [Serializable]
     internal sealed class OrdinalCaseSensitiveComparer : OrdinalComparer, ISerializable
     {
-        public OrdinalCaseSensitiveComparer() : base(false)
+        internal static readonly OrdinalCaseSensitiveComparer Instance = new OrdinalCaseSensitiveComparer();
+
+        private OrdinalCaseSensitiveComparer() : base(false)
         {
         }
 
@@ -313,7 +313,9 @@ namespace System
     [Serializable]
     internal sealed class OrdinalIgnoreCaseComparer : OrdinalComparer, ISerializable
     {
-        public OrdinalIgnoreCaseComparer() : base(true)
+        internal static readonly OrdinalIgnoreCaseComparer Instance = new OrdinalIgnoreCaseComparer();
+
+        private OrdinalIgnoreCaseComparer() : base(true)
         {
         }
 
index b2f4b4c..b600677 100644 (file)
Binary files a/src/libraries/System.Resources.Extensions/tests/TestData.resources and b/src/libraries/System.Resources.Extensions/tests/TestData.resources differ