Add string.GetHashCode(ROS<char>) and related APIs (#20422)
authorLevi Broderick <GrabYourPitchforks@users.noreply.github.com>
Thu, 8 Nov 2018 07:54:11 +0000 (23:54 -0800)
committerGitHub <noreply@github.com>
Thu, 8 Nov 2018 07:54:11 +0000 (23:54 -0800)
src/System.Private.CoreLib/shared/Interop/Unix/System.Globalization.Native/Interop.Collation.cs
src/System.Private.CoreLib/shared/System/Globalization/CompareInfo.Unix.cs
src/System.Private.CoreLib/shared/System/Globalization/CompareInfo.Windows.cs
src/System.Private.CoreLib/shared/System/Globalization/CompareInfo.cs
src/System.Private.CoreLib/shared/System/String.Comparison.cs

index aeeb60f..aea7615 100644 (file)
@@ -48,7 +48,7 @@ internal static partial class Interop
         internal static extern unsafe bool EndsWith(SafeSortHandle sortHandle, string target, int cwTargetLength, string source, int cwSourceLength, CompareOptions options);
 
         [DllImport(Libraries.GlobalizationNative, CharSet = CharSet.Unicode, EntryPoint = "GlobalizationNative_GetSortKey")]
-        internal static extern unsafe int GetSortKey(SafeSortHandle sortHandle, string str, int strLength, byte* sortKey, int sortKeyLength, CompareOptions options);
+        internal static extern unsafe int GetSortKey(SafeSortHandle sortHandle, char* str, int strLength, byte* sortKey, int sortKeyLength, CompareOptions options);
 
         [DllImport(Libraries.GlobalizationNative, CharSet = CharSet.Unicode, EntryPoint = "GlobalizationNative_CompareStringOrdinalIgnoreCase")]
         internal static extern unsafe int CompareStringOrdinalIgnoreCase(char* lpStr1, int cwStr1Len, char* lpStr2, int cwStr2Len);
index f517540..966cc9d 100644 (file)
@@ -798,14 +798,17 @@ namespace System.Globalization
             }
             else
             {
-                int sortKeyLength = Interop.Globalization.GetSortKey(_sortHandle, source, source.Length, null, 0, options);
-                keyData = new byte[sortKeyLength];
-
-                fixed (byte* pSortKey = keyData)
+                fixed (char* pSource = source)
                 {
-                    if (Interop.Globalization.GetSortKey(_sortHandle, source, source.Length, pSortKey, sortKeyLength, options) != sortKeyLength)
+                    int sortKeyLength = Interop.Globalization.GetSortKey(_sortHandle, pSource, source.Length, null, 0, options);
+                    keyData = new byte[sortKeyLength];
+
+                    fixed (byte* pSortKey = keyData)
                     {
-                        throw new ArgumentException(SR.Arg_ExternalException);
+                        if (Interop.Globalization.GetSortKey(_sortHandle, pSource, source.Length, pSortKey, sortKeyLength, options) != sortKeyLength)
+                        {
+                            throw new ArgumentException(SR.Arg_ExternalException);
+                        }
                     }
                 }
             }
@@ -856,11 +859,9 @@ namespace System.Globalization
         // ---- PAL layer ends here ----
         // -----------------------------
 
-        internal unsafe int GetHashCodeOfStringCore(string source, CompareOptions options)
+        internal unsafe int GetHashCodeOfStringCore(ReadOnlySpan<char> source, CompareOptions options)
         {
             Debug.Assert(!_invariantMode);
-
-            Debug.Assert(source != null);
             Debug.Assert((options & (CompareOptions.Ordinal | CompareOptions.OrdinalIgnoreCase)) == 0);
 
             if (source.Length == 0)
@@ -868,30 +869,33 @@ namespace System.Globalization
                 return 0;
             }
 
-            int sortKeyLength = Interop.Globalization.GetSortKey(_sortHandle, source, source.Length, null, 0, options);
+            fixed (char* pSource = source)
+            {
+                int sortKeyLength = Interop.Globalization.GetSortKey(_sortHandle, pSource, source.Length, null, 0, options);
 
-            byte[] borrowedArr = null;
-            Span<byte> span = sortKeyLength <= 512 ?
-                stackalloc byte[512] :
-                (borrowedArr = ArrayPool<byte>.Shared.Rent(sortKeyLength));
+                byte[] borrowedArr = null;
+                Span<byte> span = sortKeyLength <= 512 ?
+                    stackalloc byte[512] :
+                    (borrowedArr = ArrayPool<byte>.Shared.Rent(sortKeyLength));
 
-            fixed (byte* pSortKey = &MemoryMarshal.GetReference(span))
-            {
-                if (Interop.Globalization.GetSortKey(_sortHandle, source, source.Length, pSortKey, sortKeyLength, options) != sortKeyLength)
+                fixed (byte* pSortKey = &MemoryMarshal.GetReference(span))
                 {
-                    throw new ArgumentException(SR.Arg_ExternalException);
+                    if (Interop.Globalization.GetSortKey(_sortHandle, pSource, source.Length, pSortKey, sortKeyLength, options) != sortKeyLength)
+                    {
+                        throw new ArgumentException(SR.Arg_ExternalException);
+                    }
                 }
-            }
 
-            int hash = Marvin.ComputeHash32(span.Slice(0, sortKeyLength), Marvin.DefaultSeed);
+                int hash = Marvin.ComputeHash32(span.Slice(0, sortKeyLength), Marvin.DefaultSeed);
 
-            // Return the borrowed array if necessary.
-            if (borrowedArr != null)
-            {
-                ArrayPool<byte>.Shared.Return(borrowedArr);
-            }
+                // Return the borrowed array if necessary.
+                if (borrowedArr != null)
+                {
+                    ArrayPool<byte>.Shared.Return(borrowedArr);
+                }
 
-            return hash;
+                return hash;
+            }
         }
 
         private static CompareOptions GetOrdinalCompareOptions(CompareOptions options)
index d1b12c6..f6fb690 100644 (file)
@@ -111,12 +111,10 @@ namespace System.Globalization
 
             return FindStringOrdinal(FIND_FROMEND, source, startIndex - count + 1, count, value, value.Length, ignoreCase);
         }
-
-        private unsafe int GetHashCodeOfStringCore(string source, CompareOptions options)
+        
+        private unsafe int GetHashCodeOfStringCore(ReadOnlySpan<char> source, CompareOptions options)
         {
             Debug.Assert(!_invariantMode);
-
-            Debug.Assert(source != null);
             Debug.Assert((options & (CompareOptions.Ordinal | CompareOptions.OrdinalIgnoreCase)) == 0);
 
             if (source.Length == 0)
@@ -130,7 +128,7 @@ namespace System.Globalization
             {
                 int sortKeyLength = Interop.Kernel32.LCMapStringEx(_sortHandle != IntPtr.Zero ? null : _sortName,
                                                   flags,
-                                                  pSource, source.Length,
+                                                  pSource, source.Length /* in chars */,
                                                   null, 0,
                                                   null, null, _sortHandle);
                 if (sortKeyLength == 0)
@@ -138,6 +136,11 @@ namespace System.Globalization
                     throw new ArgumentException(SR.Arg_ExternalException);
                 }
 
+                // Note in calls to LCMapStringEx below, the input buffer is specified in wchars (and wchar count),
+                // but the output buffer is specified in bytes (and byte count). This is because when generating
+                // sort keys, LCMapStringEx treats the output buffer as containing opaque binary data.
+                // See https://docs.microsoft.com/en-us/windows/desktop/api/winnls/nf-winnls-lcmapstringex.
+
                 byte[] borrowedArr = null;
                 Span<byte> span = sortKeyLength <= 512 ?
                     stackalloc byte[512] :
@@ -147,7 +150,7 @@ namespace System.Globalization
                 {
                     if (Interop.Kernel32.LCMapStringEx(_sortHandle != IntPtr.Zero ? null : _sortName,
                                                       flags,
-                                                      pSource, source.Length,
+                                                      pSource, source.Length /* in chars */,
                                                       pSortKey, sortKeyLength,
                                                       null, null, _sortHandle) != sortKeyLength)
                     {
index ab1eecf..bfe509d 100644 (file)
@@ -1420,43 +1420,71 @@ namespace System.Globalization
             {
                 throw new ArgumentNullException(nameof(source));
             }
+            if ((options & ValidHashCodeOfStringMaskOffFlags) == 0)
+            {
+                // No unsupported flags are set - continue on with the regular logic
+
+                if (_invariantMode)
+                {
+                    return ((options & CompareOptions.IgnoreCase) != 0) ? source.GetHashCodeOrdinalIgnoreCase() : source.GetHashCode();
+                }
 
-            if ((options & ValidHashCodeOfStringMaskOffFlags) != 0)
+                return GetHashCodeOfStringCore(source, options);
+            }
+            else if (options == CompareOptions.Ordinal)
             {
-                throw new ArgumentException(SR.Argument_InvalidFlag, nameof(options));
+                // We allow Ordinal in isolation
+                return source.GetHashCode();
             }
-
-            if (_invariantMode)
+            else if (options == CompareOptions.OrdinalIgnoreCase)
             {
-                return ((options & CompareOptions.IgnoreCase) != 0) ? source.GetHashCodeOrdinalIgnoreCase() : source.GetHashCode();
+                // We allow OrdinalIgnoreCase in isolation
+                return source.GetHashCodeOrdinalIgnoreCase();
+            }
+            else
+            {
+                // Unsupported combination of flags specified
+                throw new ArgumentException(SR.Argument_InvalidFlag, nameof(options));
             }
-
-            return GetHashCodeOfStringCore(source, options);
         }
 
         public virtual int GetHashCode(string source, CompareOptions options)
         {
-            if (source == null)
+            // virtual method delegates to non-virtual method
+            return GetHashCodeOfString(source, options);
+        }
+
+        public int GetHashCode(ReadOnlySpan<char> source, CompareOptions options)
+        {
+            //
+            //  Parameter validation
+            //
+            if ((options & ValidHashCodeOfStringMaskOffFlags) == 0)
             {
-                throw new ArgumentNullException(nameof(source));
-            }
+                // No unsupported flags are set - continue on with the regular logic
 
-            if (options == CompareOptions.Ordinal)
+                if (_invariantMode)
+                {
+                    return ((options & CompareOptions.IgnoreCase) != 0) ? string.GetHashCodeOrdinalIgnoreCase(source) : string.GetHashCode(source);
+                }
+
+                return GetHashCodeOfStringCore(source, options);
+            }
+            else if (options == CompareOptions.Ordinal)
             {
-                return source.GetHashCode();
+                // We allow Ordinal in isolation
+                return string.GetHashCode(source);
             }
-
-            if (options == CompareOptions.OrdinalIgnoreCase)
+            else if (options == CompareOptions.OrdinalIgnoreCase)
             {
-                return source.GetHashCodeOrdinalIgnoreCase();
+                // We allow OrdinalIgnoreCase in isolation
+                return string.GetHashCodeOrdinalIgnoreCase(source);
+            }
+            else
+            {
+                // Unsupported combination of flags specified
+                throw new ArgumentException(SR.Argument_InvalidFlag, nameof(options));
             }
-
-            //
-            // GetHashCodeOfString does more parameters validation. basically will throw when
-            // having Ordinal, OrdinalIgnoreCase and StringSort
-            //
-
-            return GetHashCodeOfString(source, options);
         }
 
         ////////////////////////////////////////////////////////////////////////
index f20f6c3..360167d 100644 (file)
@@ -748,7 +748,7 @@ namespace System
         public override int GetHashCode()
         {
             ulong seed = Marvin.DefaultSeed;
-            return Marvin.ComputeHash32(ref Unsafe.As<char, byte>(ref _firstChar), _stringLength * 2, (uint)seed, (uint)(seed >> 32));
+            return Marvin.ComputeHash32(ref Unsafe.As<char, byte>(ref _firstChar), _stringLength * 2 /* in bytes, not chars */, (uint)seed, (uint)(seed >> 32));
         }
 
         // Gets a hash code for this string and this comparison. If strings A and B and comparison C are such
@@ -759,7 +759,48 @@ namespace System
         internal int GetHashCodeOrdinalIgnoreCase()
         {
             ulong seed = Marvin.DefaultSeed;
-            return Marvin.ComputeHash32OrdinalIgnoreCase(ref _firstChar, _stringLength, (uint)seed, (uint)(seed >> 32));
+            return Marvin.ComputeHash32OrdinalIgnoreCase(ref _firstChar, _stringLength /* in chars, not bytes */, (uint)seed, (uint)(seed >> 32));
+        }
+
+        // A span-based equivalent of String.GetHashCode(). Computes an ordinal hash code.
+        [MethodImpl(MethodImplOptions.AggressiveInlining)]
+        public static int GetHashCode(ReadOnlySpan<char> value)
+        {
+            ulong seed = Marvin.DefaultSeed;
+            return Marvin.ComputeHash32(ref Unsafe.As<char, byte>(ref MemoryMarshal.GetReference(value)), value.Length * 2 /* in bytes, not chars */, (uint)seed, (uint)(seed >> 32));
+        }
+
+        // A span-based equivalent of String.GetHashCode(StringComparison). Uses the specified comparison type.
+        public static int GetHashCode(ReadOnlySpan<char> value, StringComparison comparisonType)
+        {
+            switch (comparisonType)
+            {
+                case StringComparison.CurrentCulture:
+                case StringComparison.CurrentCultureIgnoreCase:
+                    return CultureInfo.CurrentCulture.CompareInfo.GetHashCode(value, GetCaseCompareOfComparisonCulture(comparisonType));
+
+                case StringComparison.InvariantCulture:
+                case StringComparison.InvariantCultureIgnoreCase:
+                    return CultureInfo.InvariantCulture.CompareInfo.GetHashCode(value, GetCaseCompareOfComparisonCulture(comparisonType));
+
+                case StringComparison.Ordinal:
+                    return GetHashCode(value);
+
+                case StringComparison.OrdinalIgnoreCase:
+                    return GetHashCodeOrdinalIgnoreCase(value);
+
+                default:
+                    ThrowHelper.ThrowArgumentException(ExceptionResource.NotSupported_StringComparison, ExceptionArgument.comparisonType);
+                    Debug.Fail("Should not reach this point.");
+                    return default;
+            }
+        }
+
+        [MethodImpl(MethodImplOptions.AggressiveInlining)]
+        internal static int GetHashCodeOrdinalIgnoreCase(ReadOnlySpan<char> value)
+        {
+            ulong seed = Marvin.DefaultSeed;
+            return Marvin.ComputeHash32OrdinalIgnoreCase(ref MemoryMarshal.GetReference(value), value.Length /* in chars, not bytes */, (uint)seed, (uint)(seed >> 32));
         }
 
         // Use this if and only if 'Denial of Service' attacks are not a concern (i.e. never used for free-form user input),