using Common.System;
using System.Collections.Generic;
+using System.Globalization;
using System.Linq;
+using System.Runtime.Serialization;
using Xunit;
namespace System.Collections.Tests
}
#endregion
+
+ #region Non-randomized comparers
+ [Fact]
+ public void Dictionary_Comparer_NonRandomizedStringComparers()
+ {
+ RunTest(null);
+ RunTest(EqualityComparer<string>.Default);
+ RunTest(StringComparer.Ordinal);
+ RunTest(StringComparer.OrdinalIgnoreCase);
+ RunTest(StringComparer.InvariantCulture);
+ RunTest(StringComparer.InvariantCultureIgnoreCase);
+ RunTest(StringComparer.Create(CultureInfo.InvariantCulture, ignoreCase: false));
+ RunTest(StringComparer.Create(CultureInfo.InvariantCulture, ignoreCase: true));
+
+ void RunTest(IEqualityComparer<string> comparer)
+ {
+ // First, instantiate the dictionary and check its Comparer property
+
+ Dictionary<string, object> dict = new Dictionary<string, object>(comparer);
+ object expected = comparer ?? EqualityComparer<string>.Default;
+
+ Assert.Same(expected, dict.Comparer);
+
+ // Then pretend to serialize the dictionary and check the stored Comparer instance
+
+ SerializationInfo si = new SerializationInfo(typeof(Dictionary<string, object>), new FormatterConverter());
+ dict.GetObjectData(si, new StreamingContext(StreamingContextStates.All));
+
+ Assert.Same(expected, si.GetValue("Comparer", typeof(IEqualityComparer<string>)));
+ }
+ }
+ #endregion
}
}
// The .NET Foundation licenses this file to you under the MIT license.
using System.Collections.Generic;
+using System.Numerics;
+using System.Reflection;
+using System.Runtime.InteropServices;
+using System.Runtime.Serialization;
using Xunit;
namespace System.Collections.Tests
dictionary.Remove(key);
}
}
+
+ [Fact]
+ public static void ComparerImplementations_Dictionary_WithWellKnownStringComparers()
+ {
+ Type nonRandomizedOrdinalComparerType = typeof(object).Assembly.GetType("System.Collections.Generic.NonRandomizedStringEqualityComparer+OrdinalComparer", throwOnError: true);
+ Type nonRandomizedOrdinalIgnoreCaseComparerType = typeof(object).Assembly.GetType("System.Collections.Generic.NonRandomizedStringEqualityComparer+OrdinalIgnoreCaseComparer", throwOnError: true);
+ Type randomizedOrdinalComparerType = typeof(object).Assembly.GetType("System.Collections.Generic.RandomizedStringEqualityComparer+OrdinalComparer", throwOnError: true);
+ Type randomizedOrdinalIgnoreCaseComparerType = typeof(object).Assembly.GetType("System.Collections.Generic.RandomizedStringEqualityComparer+OrdinalIgnoreCaseComparer", throwOnError: true);
+
+ // null comparer
+
+ RunDictionaryTest(
+ equalityComparer: null,
+ expectedInternalComparerBeforeCollisionThreshold: nonRandomizedOrdinalComparerType,
+ expectedPublicComparerBeforeCollisionThreshold: EqualityComparer<string>.Default.GetType(),
+ expectedComparerAfterCollisionThreshold: randomizedOrdinalComparerType);
+
+ // EqualityComparer<string>.Default comparer
+
+ RunDictionaryTest(
+ equalityComparer: EqualityComparer<string>.Default,
+ expectedInternalComparerBeforeCollisionThreshold: nonRandomizedOrdinalComparerType,
+ expectedPublicComparerBeforeCollisionThreshold: EqualityComparer<string>.Default.GetType(),
+ expectedComparerAfterCollisionThreshold: randomizedOrdinalComparerType);
+
+ // Ordinal comparer
+
+ RunDictionaryTest(
+ equalityComparer: StringComparer.Ordinal,
+ expectedInternalComparerBeforeCollisionThreshold: nonRandomizedOrdinalComparerType,
+ expectedPublicComparerBeforeCollisionThreshold: StringComparer.Ordinal.GetType(),
+ expectedComparerAfterCollisionThreshold: randomizedOrdinalComparerType);
+
+ // OrdinalIgnoreCase comparer
+
+ RunDictionaryTest(
+ equalityComparer: StringComparer.OrdinalIgnoreCase,
+ expectedInternalComparerBeforeCollisionThreshold: nonRandomizedOrdinalIgnoreCaseComparerType,
+ expectedPublicComparerBeforeCollisionThreshold: StringComparer.OrdinalIgnoreCase.GetType(),
+ expectedComparerAfterCollisionThreshold: randomizedOrdinalIgnoreCaseComparerType);
+
+ // linguistic comparer (not optimized)
+
+ RunDictionaryTest(
+ equalityComparer: StringComparer.InvariantCulture,
+ expectedInternalComparerBeforeCollisionThreshold: StringComparer.InvariantCulture.GetType(),
+ expectedPublicComparerBeforeCollisionThreshold: StringComparer.InvariantCulture.GetType(),
+ expectedComparerAfterCollisionThreshold: StringComparer.InvariantCulture.GetType());
+
+ static void RunDictionaryTest(
+ IEqualityComparer<string> equalityComparer,
+ Type expectedInternalComparerBeforeCollisionThreshold,
+ Type expectedPublicComparerBeforeCollisionThreshold,
+ Type expectedComparerAfterCollisionThreshold)
+ {
+ RunCollectionTestCommon(
+ () => new Dictionary<string, object>(equalityComparer),
+ (dictionary, key) => dictionary.Add(key, null),
+ (dictionary, key) => dictionary.ContainsKey(key),
+ dictionary => dictionary.Comparer,
+ expectedInternalComparerBeforeCollisionThreshold,
+ expectedPublicComparerBeforeCollisionThreshold,
+ expectedComparerAfterCollisionThreshold);
+ }
+ }
+
+ [Fact]
+ public static void ComparerImplementations_HashSet_WithWellKnownStringComparers()
+ {
+ Type nonRandomizedOrdinalComparerType = typeof(object).Assembly.GetType("System.Collections.Generic.NonRandomizedStringEqualityComparer+OrdinalComparer", throwOnError: true);
+ Type nonRandomizedOrdinalIgnoreCaseComparerType = typeof(object).Assembly.GetType("System.Collections.Generic.NonRandomizedStringEqualityComparer+OrdinalIgnoreCaseComparer", throwOnError: true);
+ Type randomizedOrdinalComparerType = typeof(object).Assembly.GetType("System.Collections.Generic.RandomizedStringEqualityComparer+OrdinalComparer", throwOnError: true);
+ Type randomizedOrdinalIgnoreCaseComparerType = typeof(object).Assembly.GetType("System.Collections.Generic.RandomizedStringEqualityComparer+OrdinalIgnoreCaseComparer", throwOnError: true);
+
+ // null comparer
+
+ RunHashSetTest(
+ equalityComparer: null,
+ expectedInternalComparerBeforeCollisionThreshold: nonRandomizedOrdinalComparerType,
+ expectedPublicComparerBeforeCollisionThreshold: EqualityComparer<string>.Default.GetType(),
+ expectedComparerAfterCollisionThreshold: randomizedOrdinalComparerType);
+
+ // EqualityComparer<string>.Default comparer
+
+ RunHashSetTest(
+ equalityComparer: EqualityComparer<string>.Default,
+ expectedInternalComparerBeforeCollisionThreshold: nonRandomizedOrdinalComparerType,
+ expectedPublicComparerBeforeCollisionThreshold: EqualityComparer<string>.Default.GetType(),
+ expectedComparerAfterCollisionThreshold: randomizedOrdinalComparerType);
+
+ // Ordinal comparer
+
+ RunHashSetTest(
+ equalityComparer: StringComparer.Ordinal,
+ expectedInternalComparerBeforeCollisionThreshold: nonRandomizedOrdinalComparerType,
+ expectedPublicComparerBeforeCollisionThreshold: StringComparer.Ordinal.GetType(),
+ expectedComparerAfterCollisionThreshold: randomizedOrdinalComparerType);
+
+ // OrdinalIgnoreCase comparer
+
+ RunHashSetTest(
+ equalityComparer: StringComparer.OrdinalIgnoreCase,
+ expectedInternalComparerBeforeCollisionThreshold: nonRandomizedOrdinalIgnoreCaseComparerType,
+ expectedPublicComparerBeforeCollisionThreshold: StringComparer.OrdinalIgnoreCase.GetType(),
+ expectedComparerAfterCollisionThreshold: randomizedOrdinalIgnoreCaseComparerType);
+
+ // linguistic comparer (not optimized)
+
+ RunHashSetTest(
+ equalityComparer: StringComparer.InvariantCulture,
+ expectedInternalComparerBeforeCollisionThreshold: StringComparer.InvariantCulture.GetType(),
+ expectedPublicComparerBeforeCollisionThreshold: StringComparer.InvariantCulture.GetType(),
+ expectedComparerAfterCollisionThreshold: StringComparer.InvariantCulture.GetType());
+
+ static void RunHashSetTest(
+ IEqualityComparer<string> equalityComparer,
+ Type expectedInternalComparerBeforeCollisionThreshold,
+ Type expectedPublicComparerBeforeCollisionThreshold,
+ Type expectedComparerAfterCollisionThreshold)
+ {
+ RunCollectionTestCommon(
+ () => new HashSet<string>(equalityComparer),
+ (set, key) => Assert.True(set.Add(key)),
+ (set, key) => set.Contains(key),
+ set => set.Comparer,
+ expectedInternalComparerBeforeCollisionThreshold,
+ expectedPublicComparerBeforeCollisionThreshold,
+ expectedComparerAfterCollisionThreshold);
+ }
+ }
+
+ private static void RunCollectionTestCommon<TCollection>(
+ Func<TCollection> collectionFactory,
+ Action<TCollection, string> addKeyCallback,
+ Func<TCollection, string, bool> containsKeyCallback,
+ Func<TCollection, IEqualityComparer<string>> getComparerCallback,
+ Type expectedInternalComparerBeforeCollisionThreshold,
+ Type expectedPublicComparerBeforeCollisionThreshold,
+ Type expectedComparerAfterCollisionThreshold)
+ {
+ TCollection collection = collectionFactory();
+ List<string> allKeys = new List<string>();
+
+ const int StartOfRange = 0xE020; // use the Unicode Private Use range to avoid accidentally creating strings that really do compare as equal OrdinalIgnoreCase
+ const int Stride = 0x40; // to ensure we don't accidentally reset the 0x20 bit of the seed, which is used to negate OrdinalIgnoreCase effects
+
+ // First, go right up to the collision threshold, but don't exceed it.
+
+ for (int i = 0; i < 100; i++)
+ {
+ string newKey = GenerateCollidingString(i * Stride + StartOfRange);
+ Assert.Equal(0, _lazyGetNonRandomizedHashCodeDel.Value(newKey)); // ensure has a zero hash code Ordinal
+ Assert.Equal(0x24716ca0, _lazyGetNonRandomizedOrdinalIgnoreCaseHashCodeDel.Value(newKey)); // ensure has a zero hash code OrdinalIgnoreCase
+
+ addKeyCallback(collection, newKey);
+ allKeys.Add(newKey);
+ }
+
+ FieldInfo internalComparerField = collection.GetType().GetField("_comparer", BindingFlags.NonPublic | BindingFlags.Instance);
+ Assert.NotNull(internalComparerField);
+
+ Assert.Equal(expectedInternalComparerBeforeCollisionThreshold, internalComparerField.GetValue(collection)?.GetType());
+ Assert.Equal(expectedPublicComparerBeforeCollisionThreshold, getComparerCallback(collection).GetType());
+
+ // Now exceed the collision threshold, which should rebucket entries.
+ // Continue adding a few more entries to ensure we didn't corrupt internal state.
+
+ for (int i = 100; i < 110; i++)
+ {
+ string newKey = GenerateCollidingString(i * Stride + StartOfRange);
+ Assert.Equal(0, _lazyGetNonRandomizedHashCodeDel.Value(newKey)); // ensure has a zero hash code Ordinal
+ Assert.Equal(0x24716ca0, _lazyGetNonRandomizedOrdinalIgnoreCaseHashCodeDel.Value(newKey)); // ensure has a zero hash code OrdinalIgnoreCase
+
+ addKeyCallback(collection, newKey);
+ allKeys.Add(newKey);
+ }
+
+ Assert.Equal(expectedComparerAfterCollisionThreshold, internalComparerField.GetValue(collection)?.GetType());
+ Assert.Equal(expectedPublicComparerBeforeCollisionThreshold, getComparerCallback(collection).GetType()); // shouldn't change this return value after collision threshold met
+
+ // And validate that all strings are present in the dictionary.
+
+ foreach (string key in allKeys)
+ {
+ Assert.True(containsKeyCallback(collection, key));
+ }
+
+ // Also make sure we didn't accidentally put the internal comparer in the serialized object data.
+
+ collection = collectionFactory();
+ SerializationInfo si = new SerializationInfo(collection.GetType(), new FormatterConverter());
+ ((ISerializable)collection).GetObjectData(si, new StreamingContext());
+
+ object serializedComparer = si.GetValue("Comparer", typeof(IEqualityComparer<string>));
+ Assert.Equal(expectedPublicComparerBeforeCollisionThreshold, serializedComparer.GetType());
+ }
+
+ private static Lazy<Func<string, int>> _lazyGetNonRandomizedHashCodeDel = new Lazy<Func<string, int>>(
+ () => GetStringHashCodeOpenDelegate("GetNonRandomizedHashCode"));
+
+ private static Lazy<Func<string, int>> _lazyGetNonRandomizedOrdinalIgnoreCaseHashCodeDel = new Lazy<Func<string, int>>(
+ () => GetStringHashCodeOpenDelegate("GetNonRandomizedHashCodeOrdinalIgnoreCase"));
+
+ // Generates a string with a well-known non-randomized hash code:
+ // - string.GetNonRandomizedHashCode returns 0.
+ // - string.GetNonRandomizedHashCodeOrdinalIgnoreCase returns 0x24716ca0.
+ // Provide a different seed to produce a different string.
+ private static string GenerateCollidingString(int seed)
+ {
+ return string.Create(8, seed, (span, seed) =>
+ {
+ Span<byte> asBytes = MemoryMarshal.AsBytes(span);
+
+ uint hash1 = (5381 << 16) + 5381;
+ uint hash2 = BitOperations.RotateLeft(hash1, 5) + hash1;
+
+ MemoryMarshal.Write(asBytes, ref seed);
+ MemoryMarshal.Write(asBytes.Slice(4), ref hash2); // set hash2 := 0 (for Ordinal)
+
+ hash1 = (BitOperations.RotateLeft(hash1, 5) + hash1) ^ (uint)seed;
+ hash1 = (BitOperations.RotateLeft(hash1, 5) + hash1);
+
+ MemoryMarshal.Write(asBytes.Slice(8), ref hash1); // set hash1 := 0 (for Ordinal)
+ });
+ }
+
+ private static Func<string, int> GetStringHashCodeOpenDelegate(string methodName)
+ {
+ MethodInfo method = typeof(string).GetMethod(methodName, BindingFlags.Instance | BindingFlags.NonPublic);
+ Assert.NotNull(method);
+
+ return method.CreateDelegate<Func<string, int>>(target: null); // create open delegate unbound to 'this'
+ }
}
}
<Compile Include="$(MSBuildThisFileDirectory)System\Collections\Generic\IEnumerable.cs" />
<Compile Include="$(MSBuildThisFileDirectory)System\Collections\Generic\IEnumerator.cs" />
<Compile Include="$(MSBuildThisFileDirectory)System\Collections\Generic\IEqualityComparer.cs" />
+ <Compile Include="$(MSBuildThisFileDirectory)System\Collections\Generic\IInternalStringEqualityComparer.cs" />
<Compile Include="$(MSBuildThisFileDirectory)System\Collections\Generic\IList.cs" />
<Compile Include="$(MSBuildThisFileDirectory)System\Collections\Generic\InsertionBehavior.cs" />
<Compile Include="$(MSBuildThisFileDirectory)System\Collections\Generic\IReadOnlyCollection.cs" />
<Compile Include="$(MSBuildThisFileDirectory)System\Collections\Generic\KeyNotFoundException.cs" />
<Compile Include="$(MSBuildThisFileDirectory)System\Collections\Generic\KeyValuePair.cs" />
<Compile Include="$(MSBuildThisFileDirectory)System\Collections\Generic\List.cs" />
+ <Compile Include="$(MSBuildThisFileDirectory)System\Collections\Generic\RandomizedStringEqualityComparer.cs" />
<Compile Include="$(MSBuildThisFileDirectory)System\Collections\Generic\NonRandomizedStringEqualityComparer.cs" />
<Compile Include="$(MSBuildThisFileDirectory)System\Collections\Generic\ValueListBuilder.cs" />
<Compile Include="$(MSBuildThisFileDirectory)System\Collections\HashHelpers.cs" />
_comparer = comparer;
}
- if (typeof(TKey) == typeof(string) && _comparer == null)
+ // Special-case EqualityComparer<string>.Default, StringComparer.Ordinal, and StringComparer.OrdinalIgnoreCase.
+ // We use a non-randomized comparer for improved perf, falling back to a randomized comparer if the
+ // hash buckets become unbalanced.
+
+ if (typeof(TKey) == typeof(string))
{
- // To start, move off default comparer for string which is randomised
- _comparer = (IEqualityComparer<TKey>)NonRandomizedStringEqualityComparer.Default;
+ if (_comparer is null)
+ {
+ _comparer = (IEqualityComparer<TKey>)NonRandomizedStringEqualityComparer.WrappedAroundDefaultComparer;
+ }
+ else if (ReferenceEquals(_comparer, StringComparer.Ordinal))
+ {
+ _comparer = (IEqualityComparer<TKey>)NonRandomizedStringEqualityComparer.WrappedAroundStringComparerOrdinal;
+ }
+ else if (ReferenceEquals(_comparer, StringComparer.OrdinalIgnoreCase))
+ {
+ _comparer = (IEqualityComparer<TKey>)NonRandomizedStringEqualityComparer.WrappedAroundStringComparerOrdinalIgnoreCase;
+ }
}
}
HashHelpers.SerializationInfoTable.Add(this, info);
}
- public IEqualityComparer<TKey> Comparer =>
- (_comparer == null || _comparer is NonRandomizedStringEqualityComparer) ?
- EqualityComparer<TKey>.Default :
- _comparer;
+ public IEqualityComparer<TKey> Comparer
+ {
+ get
+ {
+ if (typeof(TKey) == typeof(string))
+ {
+ return (IEqualityComparer<TKey>)IInternalStringEqualityComparer.GetUnderlyingEqualityComparer((IEqualityComparer<string?>?)_comparer);
+ }
+ else
+ {
+ return _comparer ?? EqualityComparer<TKey>.Default;
+ }
+ }
+ }
public int Count => _count - _freeCount;
}
info.AddValue(VersionName, _version);
- info.AddValue(ComparerName, _comparer ?? EqualityComparer<TKey>.Default, typeof(IEqualityComparer<TKey>));
+ info.AddValue(ComparerName, Comparer, typeof(IEqualityComparer<TKey>));
info.AddValue(HashSizeName, _buckets == null ? 0 : _buckets.Length); // This is the length of the bucket array
if (_buckets != null)
{
// If we hit the collision threshold we'll need to switch to the comparer which is using randomized string hashing
// i.e. EqualityComparer<string>.Default.
- _comparer = null;
Resize(entries.Length, true);
}
if (!typeof(TKey).IsValueType && forceNewHashCodes)
{
+ Debug.Assert(_comparer is NonRandomizedStringEqualityComparer);
+ _comparer = (IEqualityComparer<TKey>)((NonRandomizedStringEqualityComparer)_comparer).GetRandomizedEqualityComparer();
+
for (int i = 0; i < count; i++)
{
if (entries[i].next >= -1)
{
- Debug.Assert(_comparer == null);
- entries[i].hashCode = (uint)entries[i].key.GetHashCode();
+ entries[i].hashCode = (uint)_comparer.GetHashCode(entries[i].key);
}
}
+
+ if (ReferenceEquals(_comparer, EqualityComparer<TKey>.Default))
+ {
+ _comparer = null;
+ }
}
// Assign member variables after both arrays allocated to guard against corruption from OOM if second fails
_comparer = comparer;
}
- if (typeof(T) == typeof(string) && _comparer == null)
+ // Special-case EqualityComparer<string>.Default, StringComparer.Ordinal, and StringComparer.OrdinalIgnoreCase.
+ // We use a non-randomized comparer for improved perf, falling back to a randomized comparer if the
+ // hash buckets become unbalanced.
+
+ if (typeof(T) == typeof(string))
{
- // To start, move off default comparer for string which is randomized.
- _comparer = (IEqualityComparer<T>)NonRandomizedStringEqualityComparer.Default;
+ if (_comparer is null)
+ {
+ _comparer = (IEqualityComparer<T>)NonRandomizedStringEqualityComparer.WrappedAroundDefaultComparer;
+ }
+ else if (ReferenceEquals(_comparer, StringComparer.Ordinal))
+ {
+ _comparer = (IEqualityComparer<T>)NonRandomizedStringEqualityComparer.WrappedAroundStringComparerOrdinal;
+ }
+ else if (ReferenceEquals(_comparer, StringComparer.OrdinalIgnoreCase))
+ {
+ _comparer = (IEqualityComparer<T>)NonRandomizedStringEqualityComparer.WrappedAroundStringComparerOrdinalIgnoreCase;
+ }
}
}
}
info.AddValue(VersionName, _version); // need to serialize version to avoid problems with serializing while enumerating
- info.AddValue(ComparerName, _comparer ?? EqualityComparer<T>.Default, typeof(IEqualityComparer<T>));
+ info.AddValue(ComparerName, Comparer, typeof(IEqualityComparer<T>));
info.AddValue(CapacityName, _buckets == null ? 0 : _buckets.Length);
if (_buckets != null)
}
/// <summary>Gets the <see cref="IEqualityComparer"/> object that is used to determine equality for the values in the set.</summary>
- public IEqualityComparer<T> Comparer =>
- (_comparer == null || _comparer is NonRandomizedStringEqualityComparer) ?
- EqualityComparer<T>.Default :
- _comparer;
+ public IEqualityComparer<T> Comparer
+ {
+ get
+ {
+ if (typeof(T) == typeof(string))
+ {
+ return (IEqualityComparer<T>)IInternalStringEqualityComparer.GetUnderlyingEqualityComparer((IEqualityComparer<string?>?)_comparer);
+ }
+ else
+ {
+ return _comparer ?? EqualityComparer<T>.Default;
+ }
+ }
+ }
/// <summary>Ensures that this hash set can hold the specified number of elements without growing.</summary>
public int EnsureCapacity(int capacity)
if (!typeof(T).IsValueType && forceNewHashCodes)
{
+ Debug.Assert(_comparer is NonRandomizedStringEqualityComparer);
+ _comparer = (IEqualityComparer<T>)((NonRandomizedStringEqualityComparer)_comparer).GetRandomizedEqualityComparer();
+
for (int i = 0; i < count; i++)
{
ref Entry entry = ref entries[i];
if (entry.Next >= -1)
{
- Debug.Assert(_comparer == null);
- entry.HashCode = entry.Value != null ? entry.Value!.GetHashCode() : 0;
+ entry.HashCode = entry.Value != null ? _comparer!.GetHashCode(entry.Value) : 0;
}
}
+
+ if (ReferenceEquals(_comparer, EqualityComparer<T>.Default))
+ {
+ _comparer = null;
+ }
}
// Assign member variables after both arrays allocated to guard against corruption from OOM if second fails
{
// If we hit the collision threshold we'll need to switch to the comparer which is using randomized string hashing
// i.e. EqualityComparer<string>.Default.
- _comparer = null;
Resize(entries.Length, forceNewHashCodes: true);
location = FindItemIndex(value);
Debug.Assert(location >= 0);
--- /dev/null
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+namespace System.Collections.Generic
+{
+ /// <summary>
+ /// Represents an <see cref="IEqualityComparer{String}"/> that's meant for internal
+ /// use only and isn't intended to be serialized or returned back to the user.
+ /// Use the <see cref="GetUnderlyingEqualityComparer"/> method to get the object
+ /// that should actually be returned to the caller.
+ /// </summary>
+ internal interface IInternalStringEqualityComparer : IEqualityComparer<string?>
+ {
+ IEqualityComparer<string?> GetUnderlyingEqualityComparer();
+
+ /// <summary>
+ /// Unwraps the internal equality comparer, if proxied.
+ /// Otherwise returns the equality comparer itself or its default equivalent.
+ /// </summary>
+ internal static IEqualityComparer<string?> GetUnderlyingEqualityComparer(IEqualityComparer<string?>? outerComparer)
+ {
+ if (outerComparer is null)
+ {
+ return EqualityComparer<string?>.Default;
+ }
+ else if (outerComparer is IInternalStringEqualityComparer internalComparer)
+ {
+ return internalComparer.GetUnderlyingEqualityComparer();
+ }
+ else
+ {
+ return outerComparer;
+ }
+ }
+ }
+}
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
+using System.Diagnostics;
+using System.Globalization;
using System.Runtime.Serialization;
+using Internal.Runtime.CompilerServices;
namespace System.Collections.Generic
{
// randomized string hashing.
[Serializable] // Required for compatibility with .NET Core 2.0 as we exposed the NonRandomizedStringEqualityComparer inside the serialization blob
// Needs to be public to support binary serialization compatibility
- public sealed class NonRandomizedStringEqualityComparer : EqualityComparer<string?>, ISerializable
+ public class NonRandomizedStringEqualityComparer : IEqualityComparer<string?>, IInternalStringEqualityComparer, ISerializable
{
- internal static new IEqualityComparer<string?> Default { get; } = new NonRandomizedStringEqualityComparer();
+ // Dictionary<...>.Comparer and similar methods need to return the original IEqualityComparer
+ // that was passed in to the ctor. The caller chooses one of these singletons so that the
+ // GetUnderlyingEqualityComparer method can return the correct value.
- private NonRandomizedStringEqualityComparer() { }
+ internal static readonly NonRandomizedStringEqualityComparer WrappedAroundDefaultComparer = new OrdinalComparer(EqualityComparer<string?>.Default);
+ internal static readonly NonRandomizedStringEqualityComparer WrappedAroundStringComparerOrdinal = new OrdinalComparer(StringComparer.Ordinal);
+ internal static readonly NonRandomizedStringEqualityComparer WrappedAroundStringComparerOrdinalIgnoreCase = new OrdinalIgnoreCaseComparer(StringComparer.OrdinalIgnoreCase);
+
+ private readonly IEqualityComparer<string?> _underlyingComparer;
+
+ private NonRandomizedStringEqualityComparer(IEqualityComparer<string?> underlyingComparer)
+ {
+ Debug.Assert(underlyingComparer != null);
+ _underlyingComparer = underlyingComparer;
+ }
// This is used by the serialization engine.
- private NonRandomizedStringEqualityComparer(SerializationInfo information, StreamingContext context) { }
+ protected NonRandomizedStringEqualityComparer(SerializationInfo information, StreamingContext context)
+ : this(EqualityComparer<string?>.Default)
+ {
+ }
+
+ public virtual bool Equals(string? x, string? y)
+ {
+ // This instance may have been deserialized into a class that doesn't guarantee
+ // these parameters are non-null. Can't short-circuit the null checks.
+
+ return string.Equals(x, y);
+ }
+
+ public virtual int GetHashCode(string? obj)
+ {
+ // This instance may have been deserialized into a class that doesn't guarantee
+ // these parameters are non-null. Can't short-circuit the null checks.
+
+ return obj?.GetNonRandomizedHashCode() ?? 0;
+ }
- public sealed override bool Equals(string? x, string? y) => string.Equals(x, y);
+ internal virtual RandomizedStringEqualityComparer GetRandomizedEqualityComparer()
+ {
+ return RandomizedStringEqualityComparer.Create(_underlyingComparer, ignoreCase: false);
+ }
- public sealed override int GetHashCode(string? obj) => obj?.GetNonRandomizedHashCode() ?? 0;
+ // Gets the comparer that should be returned back to the caller when querying the
+ // ICollection.Comparer property. Also used for serialization purposes.
+ public virtual IEqualityComparer<string?> GetUnderlyingEqualityComparer() => _underlyingComparer;
- public void GetObjectData(SerializationInfo info, StreamingContext context)
+ void ISerializable.GetObjectData(SerializationInfo info, StreamingContext context)
{
// We are doing this to stay compatible with .NET Framework.
+ // Our own collection types will never call this (since this type is a wrapper),
+ // but perhaps third-party collection types could try serializing an instance
+ // of this.
info.SetType(typeof(GenericEqualityComparer<string>));
}
+
+ private sealed class OrdinalComparer : NonRandomizedStringEqualityComparer
+ {
+ internal OrdinalComparer(IEqualityComparer<string?> wrappedComparer)
+ : base(wrappedComparer)
+ {
+ }
+
+ public override bool Equals(string? x, string? y) => string.Equals(x, y);
+
+ public override int GetHashCode(string? obj)
+ {
+ Debug.Assert(obj != null, "This implementation is only called from first-party collection types that guarantee non-null parameters.");
+ return obj.GetNonRandomizedHashCode();
+ }
+
+ }
+
+ private sealed class OrdinalIgnoreCaseComparer : NonRandomizedStringEqualityComparer
+ {
+ internal OrdinalIgnoreCaseComparer(IEqualityComparer<string?> wrappedComparer)
+ : base(wrappedComparer)
+ {
+ }
+
+ public override bool Equals(string? x, string? y) => string.EqualsOrdinalIgnoreCase(x, y);
+
+ public override int GetHashCode(string? obj)
+ {
+ Debug.Assert(obj != null, "This implementation is only called from first-party collection types that guarantee non-null parameters.");
+ return obj.GetNonRandomizedHashCodeOrdinalIgnoreCase();
+ }
+
+ internal override RandomizedStringEqualityComparer GetRandomizedEqualityComparer()
+ {
+ return RandomizedStringEqualityComparer.Create(_underlyingComparer, ignoreCase: true);
+ }
+ }
}
}
--- /dev/null
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using Internal.Runtime.CompilerServices;
+
+namespace System.Collections.Generic
+{
+ /// <summary>
+ /// A randomized <see cref="EqualityComparer{String}"/> which uses a different seed on each
+ /// construction as a general good hygiene + defense-in-depth mechanism. This implementation
+ /// *does not* need to stay in sync with <see cref="string.GetHashCode"/>, which for stability
+ /// is required to use an app-global seed.
+ /// </summary>
+ internal abstract class RandomizedStringEqualityComparer : EqualityComparer<string?>, IInternalStringEqualityComparer
+ {
+ private readonly MarvinSeed _seed;
+ private readonly IEqualityComparer<string?> _underlyingComparer;
+
+ private unsafe RandomizedStringEqualityComparer(IEqualityComparer<string?> underlyingComparer)
+ {
+ _underlyingComparer = underlyingComparer;
+
+ fixed (MarvinSeed* seed = &_seed)
+ {
+ Interop.GetRandomBytes((byte*)seed, sizeof(MarvinSeed));
+ }
+ }
+
+ internal static RandomizedStringEqualityComparer Create(IEqualityComparer<string?> underlyingComparer, bool ignoreCase)
+ {
+ if (!ignoreCase)
+ {
+ return new OrdinalComparer(underlyingComparer);
+ }
+ else
+ {
+ return new OrdinalIgnoreCaseComparer(underlyingComparer);
+ }
+ }
+
+ public IEqualityComparer<string?> GetUnderlyingEqualityComparer() => _underlyingComparer;
+
+ private struct MarvinSeed
+ {
+ internal uint p0;
+ internal uint p1;
+ }
+
+ private sealed class OrdinalComparer : RandomizedStringEqualityComparer
+ {
+ internal OrdinalComparer(IEqualityComparer<string?> wrappedComparer)
+ : base(wrappedComparer)
+ {
+ }
+
+ public override bool Equals(string? x, string? y) => string.Equals(x, y);
+
+ public override int GetHashCode(string? obj)
+ {
+ if (obj is null)
+ {
+ return 0;
+ }
+
+ // The Ordinal version of Marvin32 operates over bytes.
+ // The multiplication from # chars -> # bytes will never integer overflow.
+ return Marvin.ComputeHash32(
+ ref Unsafe.As<char, byte>(ref obj.GetRawStringData()),
+ (uint)obj.Length * 2,
+ _seed.p0, _seed.p1);
+ }
+ }
+
+ private sealed class OrdinalIgnoreCaseComparer : RandomizedStringEqualityComparer
+ {
+ internal OrdinalIgnoreCaseComparer(IEqualityComparer<string?> wrappedComparer)
+ : base(wrappedComparer)
+ {
+ }
+
+ public override bool Equals(string? x, string? y) => string.EqualsOrdinalIgnoreCase(x, y);
+
+ public override int GetHashCode(string? obj)
+ {
+ if (obj is null)
+ {
+ return 0;
+ }
+
+ // The Ordinal version of Marvin32 operates over bytes, so convert
+ // char count -> byte count. Guaranteed not to integer overflow.
+ return Marvin.ComputeHash32(
+ ref Unsafe.As<char, byte>(ref obj.GetRawStringData()),
+ (uint)obj.Length * sizeof(char),
+ _seed.p0, _seed.p1);
+ }
+ }
+
+ private sealed class RandomizedOrdinalIgnoreCaseComparer : RandomizedStringEqualityComparer
+ {
+ internal RandomizedOrdinalIgnoreCaseComparer(IEqualityComparer<string?> underlyingComparer)
+ : base(underlyingComparer)
+ {
+ }
+
+ public override bool Equals(string? x, string? y) => string.EqualsOrdinalIgnoreCase(x, y);
+
+ public override int GetHashCode(string? obj)
+ {
+ if (obj is null)
+ {
+ return 0;
+ }
+
+ // The OrdinalIgnoreCase version of Marvin32 operates over chars,
+ // so pass in the char count directly.
+ return Marvin.ComputeHash32OrdinalIgnoreCase(
+ ref obj.GetRawStringData(),
+ obj.Length,
+ _seed.p0, _seed.p1);
+ }
+ }
+ }
+}
return SpanHelpers.SequenceCompareTo(ref Unsafe.Add(ref strA.GetRawStringData(), indexA), countA, ref Unsafe.Add(ref strB.GetRawStringData(), indexB), countB);
}
- private static bool EqualsOrdinalIgnoreCase(string strA, string strB)
+ internal static bool EqualsOrdinalIgnoreCase(string? strA, string? strB)
+ {
+ if (ReferenceEquals(strA, strB))
+ {
+ return true;
+ }
+
+ if (strA is null || strB is null)
+ {
+ return false;
+ }
+
+ if (strA.Length != strB.Length)
+ {
+ return false;
+ }
+
+ return EqualsOrdinalIgnoreCaseNoLengthCheck(strA, strB);
+ }
+
+ private static bool EqualsOrdinalIgnoreCaseNoLengthCheck(string strA, string strB)
{
Debug.Assert(strA.Length == strB.Length);
if (this.Length != value.Length)
return false;
- return EqualsOrdinalIgnoreCase(this, value);
+ return EqualsOrdinalIgnoreCaseNoLengthCheck(this, value);
default:
throw new ArgumentException(SR.NotSupported_StringComparison, nameof(comparisonType));
if (a.Length != b.Length)
return false;
- return EqualsOrdinalIgnoreCase(a, b);
+ return EqualsOrdinalIgnoreCaseNoLengthCheck(a, b);
default:
throw new ArgumentException(SR.NotSupported_StringComparison, nameof(comparisonType));
}
}
+ // Use this if and only if 'Denial of Service' attacks are not a concern (i.e. never used for free-form user input),
+ // or are otherwise mitigated
+ internal unsafe int GetNonRandomizedHashCodeOrdinalIgnoreCase()
+ {
+ fixed (char* src = &_firstChar)
+ {
+ Debug.Assert(src[this.Length] == '\0', "src[this.Length] == '\\0'");
+ Debug.Assert(((int)src) % 4 == 0, "Managed string should start at 4 bytes boundary");
+
+ uint hash1 = (5381 << 16) + 5381;
+ uint hash2 = hash1;
+
+ uint* ptr = (uint*)src;
+ int length = this.Length;
+
+ // We "normalize to lowercase" every char by ORing with 0x0020. This casts
+ // a very wide net because it will change, e.g., '^' to '~'. But that should
+ // be ok because we expect this to be very rare in practice.
+
+ const uint NormalizeToLowercase = 0x0020_0020u; // valid both for big-endian and for little-endian
+
+ while (length > 2)
+ {
+ length -= 4;
+ // Where length is 4n-1 (e.g. 3,7,11,15,19) this additionally consumes the null terminator
+ hash1 = (BitOperations.RotateLeft(hash1, 5) + hash1) ^ (ptr[0] | NormalizeToLowercase);
+ hash2 = (BitOperations.RotateLeft(hash2, 5) + hash2) ^ (ptr[1] | NormalizeToLowercase);
+ ptr += 2;
+ }
+
+ if (length > 0)
+ {
+ // Where length is 4n-3 (e.g. 1,5,9,13,17) this additionally consumes the null terminator
+ hash2 = (BitOperations.RotateLeft(hash2, 5) + hash2) ^ (ptr[0] | NormalizeToLowercase);
+ }
+
+ return (int)(hash1 + (hash2 * 1566083941));
+ }
+ }
+
// Determines whether a specified string is a prefix of the current instance
//
public bool StartsWith(string value)
[System.Runtime.CompilerServices.TypeForwardedFrom("mscorlib, Version=4.0.0.0, Culture=neutral, PublicKeyToken=b77a5c561934e089")]
public abstract class StringComparer : IComparer, IEqualityComparer, IComparer<string?>, IEqualityComparer<string?>
{
- private static readonly CultureAwareComparer s_invariantCulture = new CultureAwareComparer(CultureInfo.InvariantCulture, CompareOptions.None);
- private static readonly CultureAwareComparer s_invariantCultureIgnoreCase = new CultureAwareComparer(CultureInfo.InvariantCulture, CompareOptions.IgnoreCase);
- private static readonly OrdinalCaseSensitiveComparer s_ordinal = new OrdinalCaseSensitiveComparer();
- private static readonly OrdinalIgnoreCaseComparer s_ordinalIgnoreCase = new OrdinalIgnoreCaseComparer();
+ public static StringComparer InvariantCulture => CultureAwareComparer.InvariantCaseSensitiveInstance;
- public static StringComparer InvariantCulture => s_invariantCulture;
-
- public static StringComparer InvariantCultureIgnoreCase => s_invariantCultureIgnoreCase;
+ public static StringComparer InvariantCultureIgnoreCase => CultureAwareComparer.InvariantIgnoreCaseInstance;
public static StringComparer CurrentCulture =>
new CultureAwareComparer(CultureInfo.CurrentCulture, CompareOptions.None);
public static StringComparer CurrentCultureIgnoreCase =>
new CultureAwareComparer(CultureInfo.CurrentCulture, CompareOptions.IgnoreCase);
- public static StringComparer Ordinal => s_ordinal;
+ public static StringComparer Ordinal => OrdinalCaseSensitiveComparer.Instance;
- public static StringComparer OrdinalIgnoreCase => s_ordinalIgnoreCase;
+ public static StringComparer OrdinalIgnoreCase => OrdinalIgnoreCaseComparer.Instance;
// Convert a StringComparison to a StringComparer
public static StringComparer FromComparison(StringComparison comparisonType)
[System.Runtime.CompilerServices.TypeForwardedFrom("mscorlib, Version=4.0.0.0, Culture=neutral, PublicKeyToken=b77a5c561934e089")]
public sealed class CultureAwareComparer : StringComparer, ISerializable
{
+ internal static readonly CultureAwareComparer InvariantCaseSensitiveInstance = new CultureAwareComparer(CompareInfo.Invariant, CompareOptions.None);
+ internal static readonly CultureAwareComparer InvariantIgnoreCaseInstance = new CultureAwareComparer(CompareInfo.Invariant, CompareOptions.IgnoreCase);
+
private const CompareOptions ValidCompareMaskOffFlags = ~(CompareOptions.IgnoreCase | CompareOptions.IgnoreSymbols | CompareOptions.IgnoreNonSpace | CompareOptions.IgnoreWidth | CompareOptions.IgnoreKanaType | CompareOptions.StringSort);
private readonly CompareInfo _compareInfo; // Do not rename (binary serialization)
[Serializable]
internal sealed class OrdinalCaseSensitiveComparer : OrdinalComparer, ISerializable
{
- public OrdinalCaseSensitiveComparer() : base(false)
+ internal static readonly OrdinalCaseSensitiveComparer Instance = new OrdinalCaseSensitiveComparer();
+
+ private OrdinalCaseSensitiveComparer() : base(false)
{
}
[Serializable]
internal sealed class OrdinalIgnoreCaseComparer : OrdinalComparer, ISerializable
{
- public OrdinalIgnoreCaseComparer() : base(true)
+ internal static readonly OrdinalIgnoreCaseComparer Instance = new OrdinalIgnoreCaseComparer();
+
+ private OrdinalIgnoreCaseComparer() : base(true)
{
}