From: Levi Broderick Date: Thu, 6 Aug 2020 04:16:21 +0000 (-0700) Subject: Improve dictionary & hashset lookup perf for OrdinalIgnoreCase (#36252) X-Git-Tag: submit/tizen/20210909.063632~6189 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=189e1aa8f91632c196fe0e6cb1410a85bb7d2283;p=platform%2Fupstream%2Fdotnet%2Fruntime.git Improve dictionary & hashset lookup perf for OrdinalIgnoreCase (#36252) --- diff --git a/src/libraries/System.Collections/tests/Generic/Dictionary/Dictionary.Generic.Tests.cs b/src/libraries/System.Collections/tests/Generic/Dictionary/Dictionary.Generic.Tests.cs index 3de7252..32f277e 100644 --- a/src/libraries/System.Collections/tests/Generic/Dictionary/Dictionary.Generic.Tests.cs +++ b/src/libraries/System.Collections/tests/Generic/Dictionary/Dictionary.Generic.Tests.cs @@ -3,7 +3,9 @@ using Common.System; using System.Collections.Generic; +using System.Globalization; using System.Linq; +using System.Runtime.Serialization; using Xunit; namespace System.Collections.Tests @@ -610,5 +612,37 @@ namespace System.Collections.Tests } #endregion + + #region Non-randomized comparers + [Fact] + public void Dictionary_Comparer_NonRandomizedStringComparers() + { + RunTest(null); + RunTest(EqualityComparer.Default); + RunTest(StringComparer.Ordinal); + RunTest(StringComparer.OrdinalIgnoreCase); + RunTest(StringComparer.InvariantCulture); + RunTest(StringComparer.InvariantCultureIgnoreCase); + RunTest(StringComparer.Create(CultureInfo.InvariantCulture, ignoreCase: false)); + RunTest(StringComparer.Create(CultureInfo.InvariantCulture, ignoreCase: true)); + + void RunTest(IEqualityComparer comparer) + { + // First, instantiate the dictionary and check its Comparer property + + Dictionary dict = new Dictionary(comparer); + object expected = comparer ?? 
EqualityComparer.Default; + + Assert.Same(expected, dict.Comparer); + + // Then pretend to serialize the dictionary and check the stored Comparer instance + + SerializationInfo si = new SerializationInfo(typeof(Dictionary), new FormatterConverter()); + dict.GetObjectData(si, new StreamingContext(StreamingContextStates.All)); + + Assert.Same(expected, si.GetValue("Comparer", typeof(IEqualityComparer))); + } + } + #endregion } } diff --git a/src/libraries/System.Collections/tests/Generic/Dictionary/HashCollisionScenarios/OutOfBoundsRegression.cs b/src/libraries/System.Collections/tests/Generic/Dictionary/HashCollisionScenarios/OutOfBoundsRegression.cs index bc83088..934e57a 100644 --- a/src/libraries/System.Collections/tests/Generic/Dictionary/HashCollisionScenarios/OutOfBoundsRegression.cs +++ b/src/libraries/System.Collections/tests/Generic/Dictionary/HashCollisionScenarios/OutOfBoundsRegression.cs @@ -2,6 +2,10 @@ // The .NET Foundation licenses this file to you under the MIT license. using System.Collections.Generic; +using System.Numerics; +using System.Reflection; +using System.Runtime.InteropServices; +using System.Runtime.Serialization; using Xunit; namespace System.Collections.Tests @@ -37,5 +41,238 @@ namespace System.Collections.Tests dictionary.Remove(key); } } + + [Fact] + public static void ComparerImplementations_Dictionary_WithWellKnownStringComparers() + { + Type nonRandomizedOrdinalComparerType = typeof(object).Assembly.GetType("System.Collections.Generic.NonRandomizedStringEqualityComparer+OrdinalComparer", throwOnError: true); + Type nonRandomizedOrdinalIgnoreCaseComparerType = typeof(object).Assembly.GetType("System.Collections.Generic.NonRandomizedStringEqualityComparer+OrdinalIgnoreCaseComparer", throwOnError: true); + Type randomizedOrdinalComparerType = typeof(object).Assembly.GetType("System.Collections.Generic.RandomizedStringEqualityComparer+OrdinalComparer", throwOnError: true); + Type randomizedOrdinalIgnoreCaseComparerType = 
typeof(object).Assembly.GetType("System.Collections.Generic.RandomizedStringEqualityComparer+OrdinalIgnoreCaseComparer", throwOnError: true); + + // null comparer + + RunDictionaryTest( + equalityComparer: null, + expectedInternalComparerBeforeCollisionThreshold: nonRandomizedOrdinalComparerType, + expectedPublicComparerBeforeCollisionThreshold: EqualityComparer.Default.GetType(), + expectedComparerAfterCollisionThreshold: randomizedOrdinalComparerType); + + // EqualityComparer.Default comparer + + RunDictionaryTest( + equalityComparer: EqualityComparer.Default, + expectedInternalComparerBeforeCollisionThreshold: nonRandomizedOrdinalComparerType, + expectedPublicComparerBeforeCollisionThreshold: EqualityComparer.Default.GetType(), + expectedComparerAfterCollisionThreshold: randomizedOrdinalComparerType); + + // Ordinal comparer + + RunDictionaryTest( + equalityComparer: StringComparer.Ordinal, + expectedInternalComparerBeforeCollisionThreshold: nonRandomizedOrdinalComparerType, + expectedPublicComparerBeforeCollisionThreshold: StringComparer.Ordinal.GetType(), + expectedComparerAfterCollisionThreshold: randomizedOrdinalComparerType); + + // OrdinalIgnoreCase comparer + + RunDictionaryTest( + equalityComparer: StringComparer.OrdinalIgnoreCase, + expectedInternalComparerBeforeCollisionThreshold: nonRandomizedOrdinalIgnoreCaseComparerType, + expectedPublicComparerBeforeCollisionThreshold: StringComparer.OrdinalIgnoreCase.GetType(), + expectedComparerAfterCollisionThreshold: randomizedOrdinalIgnoreCaseComparerType); + + // linguistic comparer (not optimized) + + RunDictionaryTest( + equalityComparer: StringComparer.InvariantCulture, + expectedInternalComparerBeforeCollisionThreshold: StringComparer.InvariantCulture.GetType(), + expectedPublicComparerBeforeCollisionThreshold: StringComparer.InvariantCulture.GetType(), + expectedComparerAfterCollisionThreshold: StringComparer.InvariantCulture.GetType()); + + static void RunDictionaryTest( + IEqualityComparer 
equalityComparer, + Type expectedInternalComparerBeforeCollisionThreshold, + Type expectedPublicComparerBeforeCollisionThreshold, + Type expectedComparerAfterCollisionThreshold) + { + RunCollectionTestCommon( + () => new Dictionary(equalityComparer), + (dictionary, key) => dictionary.Add(key, null), + (dictionary, key) => dictionary.ContainsKey(key), + dictionary => dictionary.Comparer, + expectedInternalComparerBeforeCollisionThreshold, + expectedPublicComparerBeforeCollisionThreshold, + expectedComparerAfterCollisionThreshold); + } + } + + [Fact] + public static void ComparerImplementations_HashSet_WithWellKnownStringComparers() + { + Type nonRandomizedOrdinalComparerType = typeof(object).Assembly.GetType("System.Collections.Generic.NonRandomizedStringEqualityComparer+OrdinalComparer", throwOnError: true); + Type nonRandomizedOrdinalIgnoreCaseComparerType = typeof(object).Assembly.GetType("System.Collections.Generic.NonRandomizedStringEqualityComparer+OrdinalIgnoreCaseComparer", throwOnError: true); + Type randomizedOrdinalComparerType = typeof(object).Assembly.GetType("System.Collections.Generic.RandomizedStringEqualityComparer+OrdinalComparer", throwOnError: true); + Type randomizedOrdinalIgnoreCaseComparerType = typeof(object).Assembly.GetType("System.Collections.Generic.RandomizedStringEqualityComparer+OrdinalIgnoreCaseComparer", throwOnError: true); + + // null comparer + + RunHashSetTest( + equalityComparer: null, + expectedInternalComparerBeforeCollisionThreshold: nonRandomizedOrdinalComparerType, + expectedPublicComparerBeforeCollisionThreshold: EqualityComparer.Default.GetType(), + expectedComparerAfterCollisionThreshold: randomizedOrdinalComparerType); + + // EqualityComparer.Default comparer + + RunHashSetTest( + equalityComparer: EqualityComparer.Default, + expectedInternalComparerBeforeCollisionThreshold: nonRandomizedOrdinalComparerType, + expectedPublicComparerBeforeCollisionThreshold: EqualityComparer.Default.GetType(), + 
expectedComparerAfterCollisionThreshold: randomizedOrdinalComparerType); + + // Ordinal comparer + + RunHashSetTest( + equalityComparer: StringComparer.Ordinal, + expectedInternalComparerBeforeCollisionThreshold: nonRandomizedOrdinalComparerType, + expectedPublicComparerBeforeCollisionThreshold: StringComparer.Ordinal.GetType(), + expectedComparerAfterCollisionThreshold: randomizedOrdinalComparerType); + + // OrdinalIgnoreCase comparer + + RunHashSetTest( + equalityComparer: StringComparer.OrdinalIgnoreCase, + expectedInternalComparerBeforeCollisionThreshold: nonRandomizedOrdinalIgnoreCaseComparerType, + expectedPublicComparerBeforeCollisionThreshold: StringComparer.OrdinalIgnoreCase.GetType(), + expectedComparerAfterCollisionThreshold: randomizedOrdinalIgnoreCaseComparerType); + + // linguistic comparer (not optimized) + + RunHashSetTest( + equalityComparer: StringComparer.InvariantCulture, + expectedInternalComparerBeforeCollisionThreshold: StringComparer.InvariantCulture.GetType(), + expectedPublicComparerBeforeCollisionThreshold: StringComparer.InvariantCulture.GetType(), + expectedComparerAfterCollisionThreshold: StringComparer.InvariantCulture.GetType()); + + static void RunHashSetTest( + IEqualityComparer equalityComparer, + Type expectedInternalComparerBeforeCollisionThreshold, + Type expectedPublicComparerBeforeCollisionThreshold, + Type expectedComparerAfterCollisionThreshold) + { + RunCollectionTestCommon( + () => new HashSet(equalityComparer), + (set, key) => Assert.True(set.Add(key)), + (set, key) => set.Contains(key), + set => set.Comparer, + expectedInternalComparerBeforeCollisionThreshold, + expectedPublicComparerBeforeCollisionThreshold, + expectedComparerAfterCollisionThreshold); + } + } + + private static void RunCollectionTestCommon( + Func collectionFactory, + Action addKeyCallback, + Func containsKeyCallback, + Func> getComparerCallback, + Type expectedInternalComparerBeforeCollisionThreshold, + Type 
expectedPublicComparerBeforeCollisionThreshold, + Type expectedComparerAfterCollisionThreshold) + { + TCollection collection = collectionFactory(); + List allKeys = new List(); + + const int StartOfRange = 0xE020; // use the Unicode Private Use range to avoid accidentally creating strings that really do compare as equal OrdinalIgnoreCase + const int Stride = 0x40; // to ensure we don't accidentally reset the 0x20 bit of the seed, which is used to negate OrdinalIgnoreCase effects + + // First, go right up to the collision threshold, but don't exceed it. + + for (int i = 0; i < 100; i++) + { + string newKey = GenerateCollidingString(i * Stride + StartOfRange); + Assert.Equal(0, _lazyGetNonRandomizedHashCodeDel.Value(newKey)); // ensure has a zero hash code Ordinal + Assert.Equal(0x24716ca0, _lazyGetNonRandomizedOrdinalIgnoreCaseHashCodeDel.Value(newKey)); // ensure has a zero hash code OrdinalIgnoreCase + + addKeyCallback(collection, newKey); + allKeys.Add(newKey); + } + + FieldInfo internalComparerField = collection.GetType().GetField("_comparer", BindingFlags.NonPublic | BindingFlags.Instance); + Assert.NotNull(internalComparerField); + + Assert.Equal(expectedInternalComparerBeforeCollisionThreshold, internalComparerField.GetValue(collection)?.GetType()); + Assert.Equal(expectedPublicComparerBeforeCollisionThreshold, getComparerCallback(collection).GetType()); + + // Now exceed the collision threshold, which should rebucket entries. + // Continue adding a few more entries to ensure we didn't corrupt internal state. 
+ + for (int i = 100; i < 110; i++) + { + string newKey = GenerateCollidingString(i * Stride + StartOfRange); + Assert.Equal(0, _lazyGetNonRandomizedHashCodeDel.Value(newKey)); // ensure has a zero hash code Ordinal + Assert.Equal(0x24716ca0, _lazyGetNonRandomizedOrdinalIgnoreCaseHashCodeDel.Value(newKey)); // ensure has a zero hash code OrdinalIgnoreCase + + addKeyCallback(collection, newKey); + allKeys.Add(newKey); + } + + Assert.Equal(expectedComparerAfterCollisionThreshold, internalComparerField.GetValue(collection)?.GetType()); + Assert.Equal(expectedPublicComparerBeforeCollisionThreshold, getComparerCallback(collection).GetType()); // shouldn't change this return value after collision threshold met + + // And validate that all strings are present in the dictionary. + + foreach (string key in allKeys) + { + Assert.True(containsKeyCallback(collection, key)); + } + + // Also make sure we didn't accidentally put the internal comparer in the serialized object data. + + collection = collectionFactory(); + SerializationInfo si = new SerializationInfo(collection.GetType(), new FormatterConverter()); + ((ISerializable)collection).GetObjectData(si, new StreamingContext()); + + object serializedComparer = si.GetValue("Comparer", typeof(IEqualityComparer)); + Assert.Equal(expectedPublicComparerBeforeCollisionThreshold, serializedComparer.GetType()); + } + + private static Lazy> _lazyGetNonRandomizedHashCodeDel = new Lazy>( + () => GetStringHashCodeOpenDelegate("GetNonRandomizedHashCode")); + + private static Lazy> _lazyGetNonRandomizedOrdinalIgnoreCaseHashCodeDel = new Lazy>( + () => GetStringHashCodeOpenDelegate("GetNonRandomizedHashCodeOrdinalIgnoreCase")); + + // Generates a string with a well-known non-randomized hash code: + // - string.GetNonRandomizedHashCode returns 0. + // - string.GetNonRandomizedHashCodeOrdinalIgnoreCase returns 0x24716ca0. + // Provide a different seed to produce a different string. 
+ private static string GenerateCollidingString(int seed) + { + return string.Create(8, seed, (span, seed) => + { + Span asBytes = MemoryMarshal.AsBytes(span); + + uint hash1 = (5381 << 16) + 5381; + uint hash2 = BitOperations.RotateLeft(hash1, 5) + hash1; + + MemoryMarshal.Write(asBytes, ref seed); + MemoryMarshal.Write(asBytes.Slice(4), ref hash2); // set hash2 := 0 (for Ordinal) + + hash1 = (BitOperations.RotateLeft(hash1, 5) + hash1) ^ (uint)seed; + hash1 = (BitOperations.RotateLeft(hash1, 5) + hash1); + + MemoryMarshal.Write(asBytes.Slice(8), ref hash1); // set hash1 := 0 (for Ordinal) + }); + } + + private static Func GetStringHashCodeOpenDelegate(string methodName) + { + MethodInfo method = typeof(string).GetMethod(methodName, BindingFlags.Instance | BindingFlags.NonPublic); + Assert.NotNull(method); + + return method.CreateDelegate>(target: null); // create open delegate unbound to 'this' + } } } diff --git a/src/libraries/System.Private.CoreLib/src/System.Private.CoreLib.Shared.projitems b/src/libraries/System.Private.CoreLib/src/System.Private.CoreLib.Shared.projitems index 1715d43..4278e25 100644 --- a/src/libraries/System.Private.CoreLib/src/System.Private.CoreLib.Shared.projitems +++ b/src/libraries/System.Private.CoreLib/src/System.Private.CoreLib.Shared.projitems @@ -169,6 +169,7 @@ + @@ -179,6 +180,7 @@ + diff --git a/src/libraries/System.Private.CoreLib/src/System/Collections/Generic/Dictionary.cs b/src/libraries/System.Private.CoreLib/src/System/Collections/Generic/Dictionary.cs index ac4fd58..6c5682e 100644 --- a/src/libraries/System.Private.CoreLib/src/System/Collections/Generic/Dictionary.cs +++ b/src/libraries/System.Private.CoreLib/src/System/Collections/Generic/Dictionary.cs @@ -59,10 +59,24 @@ namespace System.Collections.Generic _comparer = comparer; } - if (typeof(TKey) == typeof(string) && _comparer == null) + // Special-case EqualityComparer.Default, StringComparer.Ordinal, and StringComparer.OrdinalIgnoreCase. 
+ // We use a non-randomized comparer for improved perf, falling back to a randomized comparer if the + // hash buckets become unbalanced. + + if (typeof(TKey) == typeof(string)) { - // To start, move off default comparer for string which is randomised - _comparer = (IEqualityComparer)NonRandomizedStringEqualityComparer.Default; + if (_comparer is null) + { + _comparer = (IEqualityComparer)NonRandomizedStringEqualityComparer.WrappedAroundDefaultComparer; + } + else if (ReferenceEquals(_comparer, StringComparer.Ordinal)) + { + _comparer = (IEqualityComparer)NonRandomizedStringEqualityComparer.WrappedAroundStringComparerOrdinal; + } + else if (ReferenceEquals(_comparer, StringComparer.OrdinalIgnoreCase)) + { + _comparer = (IEqualityComparer)NonRandomizedStringEqualityComparer.WrappedAroundStringComparerOrdinalIgnoreCase; + } } } @@ -125,10 +139,20 @@ namespace System.Collections.Generic HashHelpers.SerializationInfoTable.Add(this, info); } - public IEqualityComparer Comparer => - (_comparer == null || _comparer is NonRandomizedStringEqualityComparer) ? - EqualityComparer.Default : - _comparer; + public IEqualityComparer Comparer + { + get + { + if (typeof(TKey) == typeof(string)) + { + return (IEqualityComparer)IInternalStringEqualityComparer.GetUnderlyingEqualityComparer((IEqualityComparer?)_comparer); + } + else + { + return _comparer ?? EqualityComparer.Default; + } + } + } public int Count => _count - _freeCount; @@ -299,7 +323,7 @@ namespace System.Collections.Generic } info.AddValue(VersionName, _version); - info.AddValue(ComparerName, _comparer ?? EqualityComparer.Default, typeof(IEqualityComparer)); + info.AddValue(ComparerName, Comparer, typeof(IEqualityComparer)); info.AddValue(HashSizeName, _buckets == null ? 
0 : _buckets.Length); // This is the length of the bucket array if (_buckets != null) @@ -633,7 +657,6 @@ namespace System.Collections.Generic { // If we hit the collision threshold we'll need to switch to the comparer which is using randomized string hashing // i.e. EqualityComparer.Default. - _comparer = null; Resize(entries.Length, true); } @@ -702,14 +725,21 @@ namespace System.Collections.Generic if (!typeof(TKey).IsValueType && forceNewHashCodes) { + Debug.Assert(_comparer is NonRandomizedStringEqualityComparer); + _comparer = (IEqualityComparer)((NonRandomizedStringEqualityComparer)_comparer).GetRandomizedEqualityComparer(); + for (int i = 0; i < count; i++) { if (entries[i].next >= -1) { - Debug.Assert(_comparer == null); - entries[i].hashCode = (uint)entries[i].key.GetHashCode(); + entries[i].hashCode = (uint)_comparer.GetHashCode(entries[i].key); } } + + if (ReferenceEquals(_comparer, EqualityComparer.Default)) + { + _comparer = null; + } } // Assign member variables after both arrays allocated to guard against corruption from OOM if second fails diff --git a/src/libraries/System.Private.CoreLib/src/System/Collections/Generic/HashSet.cs b/src/libraries/System.Private.CoreLib/src/System/Collections/Generic/HashSet.cs index f86c16f..394b0ff 100644 --- a/src/libraries/System.Private.CoreLib/src/System/Collections/Generic/HashSet.cs +++ b/src/libraries/System.Private.CoreLib/src/System/Collections/Generic/HashSet.cs @@ -60,10 +60,24 @@ namespace System.Collections.Generic _comparer = comparer; } - if (typeof(T) == typeof(string) && _comparer == null) + // Special-case EqualityComparer.Default, StringComparer.Ordinal, and StringComparer.OrdinalIgnoreCase. + // We use a non-randomized comparer for improved perf, falling back to a randomized comparer if the + // hash buckets become unbalanced. + + if (typeof(T) == typeof(string)) { - // To start, move off default comparer for string which is randomized. 
- _comparer = (IEqualityComparer)NonRandomizedStringEqualityComparer.Default; + if (_comparer is null) + { + _comparer = (IEqualityComparer)NonRandomizedStringEqualityComparer.WrappedAroundDefaultComparer; + } + else if (ReferenceEquals(_comparer, StringComparer.Ordinal)) + { + _comparer = (IEqualityComparer)NonRandomizedStringEqualityComparer.WrappedAroundStringComparerOrdinal; + } + else if (ReferenceEquals(_comparer, StringComparer.OrdinalIgnoreCase)) + { + _comparer = (IEqualityComparer)NonRandomizedStringEqualityComparer.WrappedAroundStringComparerOrdinalIgnoreCase; + } } } @@ -380,7 +394,7 @@ namespace System.Collections.Generic } info.AddValue(VersionName, _version); // need to serialize version to avoid problems with serializing while enumerating - info.AddValue(ComparerName, _comparer ?? EqualityComparer.Default, typeof(IEqualityComparer)); + info.AddValue(ComparerName, Comparer, typeof(IEqualityComparer)); info.AddValue(CapacityName, _buckets == null ? 0 : _buckets.Length); if (_buckets != null) @@ -912,10 +926,20 @@ namespace System.Collections.Generic } /// Gets the object that is used to determine equality for the values in the set. - public IEqualityComparer Comparer => - (_comparer == null || _comparer is NonRandomizedStringEqualityComparer) ? - EqualityComparer.Default : - _comparer; + public IEqualityComparer Comparer + { + get + { + if (typeof(T) == typeof(string)) + { + return (IEqualityComparer)IInternalStringEqualityComparer.GetUnderlyingEqualityComparer((IEqualityComparer?)_comparer); + } + else + { + return _comparer ?? EqualityComparer.Default; + } + } + } /// Ensures that this hash set can hold the specified number of elements without growing. 
public int EnsureCapacity(int capacity) @@ -957,15 +981,22 @@ namespace System.Collections.Generic if (!typeof(T).IsValueType && forceNewHashCodes) { + Debug.Assert(_comparer is NonRandomizedStringEqualityComparer); + _comparer = (IEqualityComparer)((NonRandomizedStringEqualityComparer)_comparer).GetRandomizedEqualityComparer(); + for (int i = 0; i < count; i++) { ref Entry entry = ref entries[i]; if (entry.Next >= -1) { - Debug.Assert(_comparer == null); - entry.HashCode = entry.Value != null ? entry.Value!.GetHashCode() : 0; + entry.HashCode = entry.Value != null ? _comparer!.GetHashCode(entry.Value) : 0; } } + + if (ReferenceEquals(_comparer, EqualityComparer.Default)) + { + _comparer = null; + } } // Assign member variables after both arrays allocated to guard against corruption from OOM if second fails @@ -1185,7 +1216,6 @@ namespace System.Collections.Generic { // If we hit the collision threshold we'll need to switch to the comparer which is using randomized string hashing // i.e. EqualityComparer.Default. - _comparer = null; Resize(entries.Length, forceNewHashCodes: true); location = FindItemIndex(value); Debug.Assert(location >= 0); diff --git a/src/libraries/System.Private.CoreLib/src/System/Collections/Generic/IInternalStringEqualityComparer.cs b/src/libraries/System.Private.CoreLib/src/System/Collections/Generic/IInternalStringEqualityComparer.cs new file mode 100644 index 0000000..28c2ab9 --- /dev/null +++ b/src/libraries/System.Private.CoreLib/src/System/Collections/Generic/IInternalStringEqualityComparer.cs @@ -0,0 +1,36 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +namespace System.Collections.Generic +{ + /// + /// Represents an that's meant for internal + /// use only and isn't intended to be serialized or returned back to the user. + /// Use the method to get the object + /// that should actually be returned to the caller. 
+ /// + internal interface IInternalStringEqualityComparer : IEqualityComparer + { + IEqualityComparer GetUnderlyingEqualityComparer(); + + /// + /// Unwraps the internal equality comparer, if proxied. + /// Otherwise returns the equality comparer itself or its default equivalent. + /// + internal static IEqualityComparer GetUnderlyingEqualityComparer(IEqualityComparer? outerComparer) + { + if (outerComparer is null) + { + return EqualityComparer.Default; + } + else if (outerComparer is IInternalStringEqualityComparer internalComparer) + { + return internalComparer.GetUnderlyingEqualityComparer(); + } + else + { + return outerComparer; + } + } + } +} diff --git a/src/libraries/System.Private.CoreLib/src/System/Collections/Generic/NonRandomizedStringEqualityComparer.cs b/src/libraries/System.Private.CoreLib/src/System/Collections/Generic/NonRandomizedStringEqualityComparer.cs index b75ebd5..0ec1a50 100644 --- a/src/libraries/System.Private.CoreLib/src/System/Collections/Generic/NonRandomizedStringEqualityComparer.cs +++ b/src/libraries/System.Private.CoreLib/src/System/Collections/Generic/NonRandomizedStringEqualityComparer.cs @@ -1,7 +1,10 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +using System.Diagnostics; +using System.Globalization; using System.Runtime.Serialization; +using Internal.Runtime.CompilerServices; namespace System.Collections.Generic { @@ -11,23 +14,100 @@ namespace System.Collections.Generic // randomized string hashing. 
[Serializable] // Required for compatibility with .NET Core 2.0 as we exposed the NonRandomizedStringEqualityComparer inside the serialization blob // Needs to be public to support binary serialization compatibility - public sealed class NonRandomizedStringEqualityComparer : EqualityComparer, ISerializable + public class NonRandomizedStringEqualityComparer : IEqualityComparer, IInternalStringEqualityComparer, ISerializable { - internal static new IEqualityComparer Default { get; } = new NonRandomizedStringEqualityComparer(); + // Dictionary<...>.Comparer and similar methods need to return the original IEqualityComparer + // that was passed in to the ctor. The caller chooses one of these singletons so that the + // GetUnderlyingEqualityComparer method can return the correct value. - private NonRandomizedStringEqualityComparer() { } + internal static readonly NonRandomizedStringEqualityComparer WrappedAroundDefaultComparer = new OrdinalComparer(EqualityComparer.Default); + internal static readonly NonRandomizedStringEqualityComparer WrappedAroundStringComparerOrdinal = new OrdinalComparer(StringComparer.Ordinal); + internal static readonly NonRandomizedStringEqualityComparer WrappedAroundStringComparerOrdinalIgnoreCase = new OrdinalIgnoreCaseComparer(StringComparer.OrdinalIgnoreCase); + + private readonly IEqualityComparer _underlyingComparer; + + private NonRandomizedStringEqualityComparer(IEqualityComparer underlyingComparer) + { + Debug.Assert(underlyingComparer != null); + _underlyingComparer = underlyingComparer; + } // This is used by the serialization engine. - private NonRandomizedStringEqualityComparer(SerializationInfo information, StreamingContext context) { } + protected NonRandomizedStringEqualityComparer(SerializationInfo information, StreamingContext context) + : this(EqualityComparer.Default) + { + } + + public virtual bool Equals(string? x, string? 
y) + { + // This instance may have been deserialized into a class that doesn't guarantee + // these parameters are non-null. Can't short-circuit the null checks. + + return string.Equals(x, y); + } + + public virtual int GetHashCode(string? obj) + { + // This instance may have been deserialized into a class that doesn't guarantee + // these parameters are non-null. Can't short-circuit the null checks. + + return obj?.GetNonRandomizedHashCode() ?? 0; + } - public sealed override bool Equals(string? x, string? y) => string.Equals(x, y); + internal virtual RandomizedStringEqualityComparer GetRandomizedEqualityComparer() + { + return RandomizedStringEqualityComparer.Create(_underlyingComparer, ignoreCase: false); + } - public sealed override int GetHashCode(string? obj) => obj?.GetNonRandomizedHashCode() ?? 0; + // Gets the comparer that should be returned back to the caller when querying the + // ICollection.Comparer property. Also used for serialization purposes. + public virtual IEqualityComparer GetUnderlyingEqualityComparer() => _underlyingComparer; - public void GetObjectData(SerializationInfo info, StreamingContext context) + void ISerializable.GetObjectData(SerializationInfo info, StreamingContext context) { // We are doing this to stay compatible with .NET Framework. + // Our own collection types will never call this (since this type is a wrapper), + // but perhaps third-party collection types could try serializing an instance + // of this. info.SetType(typeof(GenericEqualityComparer)); } + + private sealed class OrdinalComparer : NonRandomizedStringEqualityComparer + { + internal OrdinalComparer(IEqualityComparer wrappedComparer) + : base(wrappedComparer) + { + } + + public override bool Equals(string? x, string? y) => string.Equals(x, y); + + public override int GetHashCode(string? 
obj) + { + Debug.Assert(obj != null, "This implementation is only called from first-party collection types that guarantee non-null parameters."); + return obj.GetNonRandomizedHashCode(); + } + + } + + private sealed class OrdinalIgnoreCaseComparer : NonRandomizedStringEqualityComparer + { + internal OrdinalIgnoreCaseComparer(IEqualityComparer wrappedComparer) + : base(wrappedComparer) + { + } + + public override bool Equals(string? x, string? y) => string.EqualsOrdinalIgnoreCase(x, y); + + public override int GetHashCode(string? obj) + { + Debug.Assert(obj != null, "This implementation is only called from first-party collection types that guarantee non-null parameters."); + return obj.GetNonRandomizedHashCodeOrdinalIgnoreCase(); + } + + internal override RandomizedStringEqualityComparer GetRandomizedEqualityComparer() + { + return RandomizedStringEqualityComparer.Create(_underlyingComparer, ignoreCase: true); + } + } } } diff --git a/src/libraries/System.Private.CoreLib/src/System/Collections/Generic/RandomizedStringEqualityComparer.cs b/src/libraries/System.Private.CoreLib/src/System/Collections/Generic/RandomizedStringEqualityComparer.cs new file mode 100644 index 0000000..168959d --- /dev/null +++ b/src/libraries/System.Private.CoreLib/src/System/Collections/Generic/RandomizedStringEqualityComparer.cs @@ -0,0 +1,124 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using Internal.Runtime.CompilerServices; + +namespace System.Collections.Generic +{ + /// + /// A randomized which uses a different seed on each + /// construction as a general good hygiene + defense-in-depth mechanism. This implementation + /// *does not* need to stay in sync with , which for stability + /// is required to use an app-global seed. 
+    ///
+    internal abstract class RandomizedStringEqualityComparer : EqualityComparer<string?>, IInternalStringEqualityComparer
+    {
+        // Per-instance Marvin seed; filled with random bytes by the constructor.
+        private readonly MarvinSeed _seed;
+
+        // The comparer this instance wraps; returned by GetUnderlyingEqualityComparer.
+        private readonly IEqualityComparer<string?> _underlyingComparer;
+
+        /// <summary>
+        /// Stores the wrapped comparer and seeds this instance with fresh random bytes.
+        /// </summary>
+        /// <param name="underlyingComparer">The comparer to report from <see cref="GetUnderlyingEqualityComparer"/>.</param>
+        private unsafe RandomizedStringEqualityComparer(IEqualityComparer<string?> underlyingComparer)
+        {
+            _underlyingComparer = underlyingComparer;
+
+            fixed (MarvinSeed* seed = &_seed)
+            {
+                Interop.GetRandomBytes((byte*)seed, sizeof(MarvinSeed));
+            }
+        }
+
+        /// <summary>
+        /// Creates a randomized comparer with either ordinal or ordinal-ignore-case semantics.
+        /// </summary>
+        /// <param name="underlyingComparer">The comparer the new instance wraps.</param>
+        /// <param name="ignoreCase"><c>true</c> for the ordinal-ignore-case implementation; <c>false</c> for the ordinal one.</param>
+        /// <returns>A new, independently seeded comparer.</returns>
+        internal static RandomizedStringEqualityComparer Create(IEqualityComparer<string?> underlyingComparer, bool ignoreCase)
+        {
+            if (!ignoreCase)
+            {
+                return new OrdinalComparer(underlyingComparer);
+            }
+            else
+            {
+                return new OrdinalIgnoreCaseComparer(underlyingComparer);
+            }
+        }
+
+        /// <summary>Returns the comparer that was passed to <see cref="Create"/>.</summary>
+        public IEqualityComparer<string?> GetUnderlyingEqualityComparer() => _underlyingComparer;
+
+        // Two 32-bit halves of the seed handed to the Marvin hash routines.
+        private struct MarvinSeed
+        {
+            internal uint p0;
+            internal uint p1;
+        }
+
+        // Case-sensitive implementation: ordinal equality paired with the byte-oriented Marvin hash.
+        private sealed class OrdinalComparer : RandomizedStringEqualityComparer
+        {
+            internal OrdinalComparer(IEqualityComparer<string?> wrappedComparer)
+                : base(wrappedComparer)
+            {
+            }
+
+            public override bool Equals(string? x, string? y) => string.Equals(x, y);
+
+            public override int GetHashCode(string? obj)
+            {
+                if (obj is null)
+                {
+                    return 0;
+                }
+
+                // The Ordinal version of Marvin32 operates over bytes.
+                // The multiplication from # chars -> # bytes will never integer overflow.
+                return Marvin.ComputeHash32(
+                    ref Unsafe.As<char, byte>(ref obj.GetRawStringData()),
+                    (uint)obj.Length * 2,
+                    _seed.p0, _seed.p1);
+            }
+        }
+
+        // Case-insensitive implementation. Equals and GetHashCode must agree: any two strings that
+        // EqualsOrdinalIgnoreCase treats as equal have to produce the same hash code, otherwise a
+        // collection that switched to this comparer could no longer find keys that differ only in case.
+        private sealed class OrdinalIgnoreCaseComparer : RandomizedStringEqualityComparer
+        {
+            internal OrdinalIgnoreCaseComparer(IEqualityComparer<string?> wrappedComparer)
+                : base(wrappedComparer)
+            {
+            }
+
+            public override bool Equals(string? x, string? y) => string.EqualsOrdinalIgnoreCase(x, y);
+
+            public override int GetHashCode(string? obj)
+            {
+                if (obj is null)
+                {
+                    return 0;
+                }
+
+                // The OrdinalIgnoreCase version of Marvin32 operates over chars,
+                // so pass in the char count directly.
+                return Marvin.ComputeHash32OrdinalIgnoreCase(
+                    ref obj.GetRawStringData(),
+                    obj.Length,
+                    _seed.p0, _seed.p1);
+            }
+        }
+    }
+}
diff --git a/src/libraries/System.Private.CoreLib/src/System/String.Comparison.cs b/src/libraries/System.Private.CoreLib/src/System/String.Comparison.cs
index d937011..40f6317 100644
--- a/src/libraries/System.Private.CoreLib/src/System/String.Comparison.cs
+++ b/src/libraries/System.Private.CoreLib/src/System/String.Comparison.cs
@@ -42,7 +42,27 @@ namespace System
             return SpanHelpers.SequenceCompareTo(ref Unsafe.Add(ref strA.GetRawStringData(), indexA), countA, ref Unsafe.Add(ref strB.GetRawStringData(), indexB), countB);
         }
-        private static bool EqualsOrdinalIgnoreCase(string strA, string strB)
+        internal static bool EqualsOrdinalIgnoreCase(string? strA, string? 
strB) + { + if (ReferenceEquals(strA, strB)) + { + return true; + } + + if (strA is null || strB is null) + { + return false; + } + + if (strA.Length != strB.Length) + { + return false; + } + + return EqualsOrdinalIgnoreCaseNoLengthCheck(strA, strB); + } + + private static bool EqualsOrdinalIgnoreCaseNoLengthCheck(string strA, string strB) { Debug.Assert(strA.Length == strB.Length); @@ -645,7 +665,7 @@ namespace System if (this.Length != value.Length) return false; - return EqualsOrdinalIgnoreCase(this, value); + return EqualsOrdinalIgnoreCaseNoLengthCheck(this, value); default: throw new ArgumentException(SR.NotSupported_StringComparison, nameof(comparisonType)); @@ -701,7 +721,7 @@ namespace System if (a.Length != b.Length) return false; - return EqualsOrdinalIgnoreCase(a, b); + return EqualsOrdinalIgnoreCaseNoLengthCheck(a, b); default: throw new ArgumentException(SR.NotSupported_StringComparison, nameof(comparisonType)); @@ -811,6 +831,46 @@ namespace System } } + // Use this if and only if 'Denial of Service' attacks are not a concern (i.e. never used for free-form user input), + // or are otherwise mitigated + internal unsafe int GetNonRandomizedHashCodeOrdinalIgnoreCase() + { + fixed (char* src = &_firstChar) + { + Debug.Assert(src[this.Length] == '\0', "src[this.Length] == '\\0'"); + Debug.Assert(((int)src) % 4 == 0, "Managed string should start at 4 bytes boundary"); + + uint hash1 = (5381 << 16) + 5381; + uint hash2 = hash1; + + uint* ptr = (uint*)src; + int length = this.Length; + + // We "normalize to lowercase" every char by ORing with 0x0020. This casts + // a very wide net because it will change, e.g., '^' to '~'. But that should + // be ok because we expect this to be very rare in practice. + + const uint NormalizeToLowercase = 0x0020_0020u; // valid both for big-endian and for little-endian + + while (length > 2) + { + length -= 4; + // Where length is 4n-1 (e.g. 
3,7,11,15,19) this additionally consumes the null terminator + hash1 = (BitOperations.RotateLeft(hash1, 5) + hash1) ^ (ptr[0] | NormalizeToLowercase); + hash2 = (BitOperations.RotateLeft(hash2, 5) + hash2) ^ (ptr[1] | NormalizeToLowercase); + ptr += 2; + } + + if (length > 0) + { + // Where length is 4n-3 (e.g. 1,5,9,13,17) this additionally consumes the null terminator + hash2 = (BitOperations.RotateLeft(hash2, 5) + hash2) ^ (ptr[0] | NormalizeToLowercase); + } + + return (int)(hash1 + (hash2 * 1566083941)); + } + } + // Determines whether a specified string is a prefix of the current instance // public bool StartsWith(string value) diff --git a/src/libraries/System.Private.CoreLib/src/System/StringComparer.cs b/src/libraries/System.Private.CoreLib/src/System/StringComparer.cs index e935c3d..39791cf 100644 --- a/src/libraries/System.Private.CoreLib/src/System/StringComparer.cs +++ b/src/libraries/System.Private.CoreLib/src/System/StringComparer.cs @@ -12,14 +12,9 @@ namespace System [System.Runtime.CompilerServices.TypeForwardedFrom("mscorlib, Version=4.0.0.0, Culture=neutral, PublicKeyToken=b77a5c561934e089")] public abstract class StringComparer : IComparer, IEqualityComparer, IComparer, IEqualityComparer { - private static readonly CultureAwareComparer s_invariantCulture = new CultureAwareComparer(CultureInfo.InvariantCulture, CompareOptions.None); - private static readonly CultureAwareComparer s_invariantCultureIgnoreCase = new CultureAwareComparer(CultureInfo.InvariantCulture, CompareOptions.IgnoreCase); - private static readonly OrdinalCaseSensitiveComparer s_ordinal = new OrdinalCaseSensitiveComparer(); - private static readonly OrdinalIgnoreCaseComparer s_ordinalIgnoreCase = new OrdinalIgnoreCaseComparer(); + public static StringComparer InvariantCulture => CultureAwareComparer.InvariantCaseSensitiveInstance; - public static StringComparer InvariantCulture => s_invariantCulture; - - public static StringComparer InvariantCultureIgnoreCase => 
s_invariantCultureIgnoreCase; + public static StringComparer InvariantCultureIgnoreCase => CultureAwareComparer.InvariantIgnoreCaseInstance; public static StringComparer CurrentCulture => new CultureAwareComparer(CultureInfo.CurrentCulture, CompareOptions.None); @@ -27,9 +22,9 @@ namespace System public static StringComparer CurrentCultureIgnoreCase => new CultureAwareComparer(CultureInfo.CurrentCulture, CompareOptions.IgnoreCase); - public static StringComparer Ordinal => s_ordinal; + public static StringComparer Ordinal => OrdinalCaseSensitiveComparer.Instance; - public static StringComparer OrdinalIgnoreCase => s_ordinalIgnoreCase; + public static StringComparer OrdinalIgnoreCase => OrdinalIgnoreCaseComparer.Instance; // Convert a StringComparison to a StringComparer public static StringComparer FromComparison(StringComparison comparisonType) @@ -128,6 +123,9 @@ namespace System [System.Runtime.CompilerServices.TypeForwardedFrom("mscorlib, Version=4.0.0.0, Culture=neutral, PublicKeyToken=b77a5c561934e089")] public sealed class CultureAwareComparer : StringComparer, ISerializable { + internal static readonly CultureAwareComparer InvariantCaseSensitiveInstance = new CultureAwareComparer(CompareInfo.Invariant, CompareOptions.None); + internal static readonly CultureAwareComparer InvariantIgnoreCaseInstance = new CultureAwareComparer(CompareInfo.Invariant, CompareOptions.IgnoreCase); + private const CompareOptions ValidCompareMaskOffFlags = ~(CompareOptions.IgnoreCase | CompareOptions.IgnoreSymbols | CompareOptions.IgnoreNonSpace | CompareOptions.IgnoreWidth | CompareOptions.IgnoreKanaType | CompareOptions.StringSort); private readonly CompareInfo _compareInfo; // Do not rename (binary serialization) @@ -286,7 +284,9 @@ namespace System [Serializable] internal sealed class OrdinalCaseSensitiveComparer : OrdinalComparer, ISerializable { - public OrdinalCaseSensitiveComparer() : base(false) + internal static readonly OrdinalCaseSensitiveComparer Instance = new 
OrdinalCaseSensitiveComparer(); + + private OrdinalCaseSensitiveComparer() : base(false) { } @@ -313,7 +313,9 @@ namespace System [Serializable] internal sealed class OrdinalIgnoreCaseComparer : OrdinalComparer, ISerializable { - public OrdinalIgnoreCaseComparer() : base(true) + internal static readonly OrdinalIgnoreCaseComparer Instance = new OrdinalIgnoreCaseComparer(); + + private OrdinalIgnoreCaseComparer() : base(true) { } diff --git a/src/libraries/System.Resources.Extensions/tests/TestData.resources b/src/libraries/System.Resources.Extensions/tests/TestData.resources index b2f4b4c..b600677 100644 Binary files a/src/libraries/System.Resources.Extensions/tests/TestData.resources and b/src/libraries/System.Resources.Extensions/tests/TestData.resources differ