Fix incorrect early exit in SortKey.Compare and seal type (#31779)
authorLevi Broderick <GrabYourPitchforks@users.noreply.github.com>
Thu, 6 Feb 2020 04:49:11 +0000 (20:49 -0800)
committerGitHub <noreply@github.com>
Thu, 6 Feb 2020 04:49:11 +0000 (20:49 -0800)
Also addresses some erroneous parameter checking in GetHashCode and fixes endianness issues in InvariantCreateSortKey

src/libraries/System.Globalization/tests/CompareInfo/CompareInfoTests.cs
src/libraries/System.Globalization/tests/Invariant/InvariantMode.cs
src/libraries/System.Private.CoreLib/src/System/Globalization/CompareInfo.Invariant.cs
src/libraries/System.Private.CoreLib/src/System/Globalization/CompareInfo.Unix.cs
src/libraries/System.Private.CoreLib/src/System/Globalization/CompareInfo.Windows.cs
src/libraries/System.Private.CoreLib/src/System/Globalization/CompareInfo.cs
src/libraries/System.Private.CoreLib/src/System/Globalization/SortKey.cs
src/libraries/System.Runtime/ref/System.Runtime.cs

index 88285e3..d319c69 100644 (file)
@@ -71,7 +71,7 @@ namespace System.Globalization.Tests
         {
             AssertExtensions.Throws<ArgumentNullException>("source", () => CultureInfo.InvariantCulture.CompareInfo.GetHashCode(null, CompareOptions.None));
 
-            AssertExtensions.Throws<ArgumentException>("options", () => CultureInfo.InvariantCulture.CompareInfo.GetHashCode("Test", CompareOptions.StringSort));
+            AssertExtensions.Throws<ArgumentException>("options", () => CultureInfo.InvariantCulture.CompareInfo.GetHashCode("Test", CompareOptions.OrdinalIgnoreCase | CompareOptions.IgnoreCase));
             AssertExtensions.Throws<ArgumentException>("options", () => CultureInfo.InvariantCulture.CompareInfo.GetHashCode("Test", CompareOptions.Ordinal | CompareOptions.IgnoreSymbols));
             AssertExtensions.Throws<ArgumentException>("options", () => CultureInfo.InvariantCulture.CompareInfo.GetHashCode("Test", (CompareOptions)(-1)));
         }
@@ -182,9 +182,15 @@ namespace System.Globalization.Tests
             yield return new object[] { s_invariantCompare, "\u30FC", "\u2010", ignoreKanaIgnoreWidthIgnoreCase, 1 };
 
             yield return new object[] { s_invariantCompare, "/", "\uFF0F", ignoreKanaIgnoreWidthIgnoreCase, 0 };
-            yield return new object[] { s_invariantCompare, "'", "\uFF07", ignoreKanaIgnoreWidthIgnoreCase, PlatformDetection.IsWindows7 ? -1 : 0};
             yield return new object[] { s_invariantCompare, "\"", "\uFF02", ignoreKanaIgnoreWidthIgnoreCase, 0 };
 
+            if (!PlatformDetection.IsWindows7)
+            {
+                // For the below string, LCMapStringEx and CompareStringEx on Windows 7 return inconsistent results.
+                // We'll only run this test case on Win8+ or on non-Windows machines.
+                yield return new object[] { s_invariantCompare, "'", "\uFF07", ignoreKanaIgnoreWidthIgnoreCase, 0 };
+            }
+
             yield return new object[] { s_invariantCompare, "\u3042", "\u30A1", CompareOptions.None, s_expectedHiraganaToKatakanaCompare };
             yield return new object[] { s_invariantCompare, "\u3042", "\u30A2", CompareOptions.None, s_expectedHiraganaToKatakanaCompare };
             yield return new object[] { s_invariantCompare, "\u3042", "\uFF71", CompareOptions.None, s_expectedHiraganaToKatakanaCompare };
@@ -349,12 +355,18 @@ namespace System.Globalization.Tests
 
         [Theory]
         [MemberData(nameof(SortKey_TestData))]
-        public void SortKeyTest(CompareInfo compareInfo, string string1, string string2, CompareOptions options, int expected)
+        public void SortKeyTest(CompareInfo compareInfo, string string1, string string2, CompareOptions options, int expectedSign)
         {
             SortKey sk1 = compareInfo.GetSortKey(string1, options);
             SortKey sk2 = compareInfo.GetSortKey(string2, options);
 
-            Assert.Equal(expected, SortKey.Compare(sk1, sk2));
+            Assert.Equal(expectedSign, Math.Sign(SortKey.Compare(sk1, sk2)));
+            Assert.Equal(expectedSign == 0, sk1.Equals(sk2));
+            Assert.Equal(Math.Sign(compareInfo.Compare(string1, string2, options)), Math.Sign(SortKey.Compare(sk1, sk2)));
+
+            Assert.Equal(compareInfo.GetHashCode(string1, options), sk1.GetHashCode());
+            Assert.Equal(compareInfo.GetHashCode(string2, options), sk2.GetHashCode());
+
             Assert.Equal(string1, sk1.OriginalString);
             Assert.Equal(string2, sk2.OriginalString);
         }
@@ -389,6 +401,9 @@ namespace System.Globalization.Tests
             Assert.Equal(sk4.GetHashCode(), sk5.GetHashCode());
             Assert.Equal(sk4.KeyData, sk5.KeyData);
 
+            Assert.False(sk1.Equals(null));
+            Assert.True(sk1.Equals(sk1));
+
             AssertExtensions.Throws<ArgumentNullException>("source", () => ci.GetSortKey(null));
             AssertExtensions.Throws<ArgumentException>("options", () => ci.GetSortKey(s1, CompareOptions.Ordinal));
         }
@@ -462,7 +477,7 @@ namespace System.Globalization.Tests
         [Fact]
         public void GetHashCode_Span_Invalid()
         {
-            AssertExtensions.Throws<ArgumentException>("options", () => CultureInfo.InvariantCulture.CompareInfo.GetHashCode("Test".AsSpan(), CompareOptions.StringSort));
+            AssertExtensions.Throws<ArgumentException>("options", () => CultureInfo.InvariantCulture.CompareInfo.GetHashCode("Test".AsSpan(), CompareOptions.OrdinalIgnoreCase | CompareOptions.IgnoreCase));
             AssertExtensions.Throws<ArgumentException>("options", () => CultureInfo.InvariantCulture.CompareInfo.GetHashCode("Test".AsSpan(), CompareOptions.Ordinal | CompareOptions.IgnoreSymbols));
             AssertExtensions.Throws<ArgumentException>("options", () => CultureInfo.InvariantCulture.CompareInfo.GetHashCode("Test".AsSpan(), (CompareOptions)(-1)));
         }
index 1a186bd..2089e84 100644 (file)
@@ -639,6 +639,25 @@ namespace System.Globalization.Tests
             Assert.NotEqual(0, SortKey.Compare(sortKeyForEmptyString, sortKeyForZeroWidthJoiner));
         }
 
+        [Theory]
+        [InlineData("", "", 0)]
+        [InlineData("", "not-empty", -1)]
+        [InlineData("not-empty", "", 1)]
+        [InlineData("hello", "hello", 0)]
+        [InlineData("prefix", "prefix-with-more-data", -1)]
+        [InlineData("prefix-with-more-data", "prefix", 1)]
+        [InlineData("e", "\u0115", -1)] // U+0115 = LATIN SMALL LETTER E WITH BREVE, tests endianness handling
+        public void TestSortKey_Compare_And_Equals(string value1, string value2, int expectedSign)
+        {
+            // These tests are in the "invariant" unit test project because we rely on Invariant mode
+            // copying the input data directly into the sort key.
+
+            SortKey sortKey1 = CultureInfo.InvariantCulture.CompareInfo.GetSortKey(value1);
+            SortKey sortKey2 = CultureInfo.InvariantCulture.CompareInfo.GetSortKey(value2);
+
+            Assert.Equal(expectedSign, Math.Sign(SortKey.Compare(sortKey1, sortKey2)));
+            Assert.Equal(expectedSign == 0, sortKey1.Equals(sortKey2));
+        }
 
         private static StringComparison GetStringComparison(CompareOptions options)
         {
index 92a962b..4b65d5f 100644 (file)
@@ -2,6 +2,7 @@
 // The .NET Foundation licenses this file to you under the MIT license.
 // See the LICENSE file in the project root for more information.
 
+using System.Buffers.Binary;
 using System.Diagnostics;
 using System.Runtime.InteropServices;
 
@@ -85,10 +86,10 @@ namespace System.Globalization
                 lastSourceStart = sourceCount - valueCount;
                 if (ignoreCase)
                 {
-                    char firstValueChar = InvariantToUpper(value[0]);
+                    char firstValueChar = InvariantCaseFold(value[0]);
                     for (ctrSource = 0; ctrSource <= lastSourceStart; ctrSource++)
                     {
-                        sourceChar = InvariantToUpper(source[ctrSource]);
+                        sourceChar = InvariantCaseFold(source[ctrSource]);
                         if (sourceChar != firstValueChar)
                         {
                             continue;
@@ -96,8 +97,8 @@ namespace System.Globalization
 
                         for (ctrValue = 1; ctrValue < valueCount; ctrValue++)
                         {
-                            sourceChar = InvariantToUpper(source[ctrSource + ctrValue]);
-                            valueChar = InvariantToUpper(value[ctrValue]);
+                            sourceChar = InvariantCaseFold(source[ctrSource + ctrValue]);
+                            valueChar = InvariantCaseFold(value[ctrValue]);
 
                             if (sourceChar != valueChar)
                             {
@@ -145,18 +146,18 @@ namespace System.Globalization
                 lastSourceStart = sourceCount - valueCount;
                 if (ignoreCase)
                 {
-                    char firstValueChar = InvariantToUpper(value[0]);
+                    char firstValueChar = InvariantCaseFold(value[0]);
                     for (ctrSource = lastSourceStart; ctrSource >= 0; ctrSource--)
                     {
-                        sourceChar = InvariantToUpper(source[ctrSource]);
+                        sourceChar = InvariantCaseFold(source[ctrSource]);
                         if (sourceChar != firstValueChar)
                         {
                             continue;
                         }
                         for (ctrValue = 1; ctrValue < valueCount; ctrValue++)
                         {
-                            sourceChar = InvariantToUpper(source[ctrSource + ctrValue]);
-                            valueChar = InvariantToUpper(value[ctrValue]);
+                            sourceChar = InvariantCaseFold(source[ctrSource + ctrValue]);
+                            valueChar = InvariantCaseFold(value[ctrValue]);
 
                             if (sourceChar != valueChar)
                             {
@@ -203,16 +204,21 @@ namespace System.Globalization
             return -1;
         }
 
-        private static char InvariantToUpper(char c)
+        private static char InvariantCaseFold(char c)
         {
+            // If we ever make Invariant mode support more than just simple ASCII-range case folding,
+            // then we should update this method to perform proper case folding instead of an
+            // uppercase conversion. For now it only understands the ASCII range and reflects all
+            // non-ASCII values unchanged.
+
             return (uint)(c - 'a') <= (uint)('z' - 'a') ? (char)(c - 0x20) : c;
         }
 
-        private unsafe SortKey InvariantCreateSortKey(string source, CompareOptions options)
+        private SortKey InvariantCreateSortKey(string source, CompareOptions options)
         {
             if (source == null) { throw new ArgumentNullException(nameof(source)); }
 
-            if ((options & ValidSortkeyCtorMaskOffFlags) != 0)
+            if ((options & ValidCompareMaskOffFlags) != 0)
             {
                 throw new ArgumentException(SR.Argument_InvalidFlag, nameof(options));
             }
@@ -227,23 +233,41 @@ namespace System.Globalization
                 // In the invariant mode, all string comparisons are done as ordinal so when generating the sort keys we generate it according to this fact
                 keyData = new byte[source.Length * sizeof(char)];
 
-                fixed (char* pChar = source) fixed (byte* pByte = keyData)
+                if ((options & (CompareOptions.IgnoreCase | CompareOptions.OrdinalIgnoreCase)) != 0)
                 {
-                    if ((options & (CompareOptions.IgnoreCase | CompareOptions.OrdinalIgnoreCase)) != 0)
-                    {
-                        short* pShort = (short*)pByte;
-                        for (int i = 0; i < source.Length; i++)
-                        {
-                            pShort[i] = (short)InvariantToUpper(source[i]);
-                        }
-                    }
-                    else
-                    {
-                        Buffer.MemoryCopy(pChar, pByte, keyData.Length, keyData.Length);
-                    }
+                    InvariantCreateSortKeyOrdinalIgnoreCase(source, keyData);
+                }
+                else
+                {
+                    InvariantCreateSortKeyOrdinal(source, keyData);
                 }
             }
+
             return new SortKey(Name, source, options, keyData);
         }
+
+        private static void InvariantCreateSortKeyOrdinal(ReadOnlySpan<char> source, Span<byte> sortKey)
+        {
+            Debug.Assert(sortKey.Length >= source.Length * sizeof(char));
+
+            for (int i = 0; i < source.Length; i++)
+            {
+                // convert machine-endian to big-endian
+                BinaryPrimitives.WriteUInt16BigEndian(sortKey, (ushort)source[i]);
+                sortKey = sortKey.Slice(sizeof(ushort));
+            }
+        }
+
+        private static void InvariantCreateSortKeyOrdinalIgnoreCase(ReadOnlySpan<char> source, Span<byte> sortKey)
+        {
+            Debug.Assert(sortKey.Length >= source.Length * sizeof(char));
+
+            for (int i = 0; i < source.Length; i++)
+            {
+                // convert machine-endian to big-endian
+                BinaryPrimitives.WriteUInt16BigEndian(sortKey, (ushort)InvariantCaseFold(source[i]));
+                sortKey = sortKey.Slice(sizeof(ushort));
+            }
+        }
     }
 }
index 8f7a6e4..c30a491 100644 (file)
@@ -816,7 +816,7 @@ namespace System.Globalization
 
             if (source==null) { throw new ArgumentNullException(nameof(source)); }
 
-            if ((options & ValidSortkeyCtorMaskOffFlags) != 0)
+            if ((options & ValidCompareMaskOffFlags) != 0)
             {
                 throw new ArgumentException(SR.Argument_InvalidFlag, nameof(options));
             }
index 6ae7a7d..9860810 100644 (file)
@@ -505,7 +505,7 @@ namespace System.Globalization
 
             if (source == null) { throw new ArgumentNullException(nameof(source)); }
 
-            if ((options & ValidSortkeyCtorMaskOffFlags) != 0)
+            if ((options & ValidCompareMaskOffFlags) != 0)
             {
                 throw new ArgumentException(SR.Argument_InvalidFlag, nameof(options));
             }
index 50336fe..3971e9c 100644 (file)
@@ -24,21 +24,11 @@ namespace System.Globalization
             ~(CompareOptions.IgnoreCase | CompareOptions.IgnoreSymbols | CompareOptions.IgnoreNonSpace |
               CompareOptions.IgnoreWidth | CompareOptions.IgnoreKanaType);
 
-        // Mask used to check if Compare() has the right flags.
+        // Mask used to check if Compare() / GetHashCode(string) / GetSortKey has the right flags.
         private const CompareOptions ValidCompareMaskOffFlags =
             ~(CompareOptions.IgnoreCase | CompareOptions.IgnoreSymbols | CompareOptions.IgnoreNonSpace |
               CompareOptions.IgnoreWidth | CompareOptions.IgnoreKanaType | CompareOptions.StringSort);
 
-        // Mask used to check if GetHashCode(string) has the right flags.
-        private const CompareOptions ValidHashCodeOfStringMaskOffFlags =
-            ~(CompareOptions.IgnoreCase | CompareOptions.IgnoreSymbols | CompareOptions.IgnoreNonSpace |
-              CompareOptions.IgnoreWidth | CompareOptions.IgnoreKanaType);
-
-        // Mask used to check if we have the right flags.
-        private const CompareOptions ValidSortkeyCtorMaskOffFlags =
-            ~(CompareOptions.IgnoreCase | CompareOptions.IgnoreSymbols | CompareOptions.IgnoreNonSpace |
-              CompareOptions.IgnoreWidth | CompareOptions.IgnoreKanaType | CompareOptions.StringSort);
-
         // Cache the invariant CompareInfo
         internal static readonly CompareInfo Invariant = CultureInfo.InvariantCulture.CompareInfo;
 
@@ -1399,7 +1389,7 @@ namespace System.Globalization
             {
                 throw new ArgumentNullException(nameof(source));
             }
-            if ((options & ValidHashCodeOfStringMaskOffFlags) == 0)
+            if ((options & ValidCompareMaskOffFlags) == 0)
             {
                 // No unsupported flags are set - continue on with the regular logic
                 if (GlobalizationMode.Invariant)
@@ -1428,7 +1418,7 @@ namespace System.Globalization
 
         public int GetHashCode(ReadOnlySpan<char> source, CompareOptions options)
         {
-            if ((options & ValidHashCodeOfStringMaskOffFlags) == 0)
+            if ((options & ValidCompareMaskOffFlags) == 0)
             {
                 // No unsupported flags are set - continue on with the regular logic
                 if (GlobalizationMode.Invariant)
index 89b12a1..a3690c1 100644 (file)
@@ -9,7 +9,7 @@ namespace System.Globalization
     /// <summary>
     /// This class implements a set of methods for retrieving
     /// </summary>
-    public partial class SortKey
+    public sealed partial class SortKey
     {
         private readonly string _localeName;
         private readonly CompareOptions _options;
@@ -32,13 +32,13 @@ namespace System.Globalization
         /// Returns the original string used to create the current instance
         /// of SortKey.
         /// </summary>
-        public virtual string OriginalString => _string;
+        public string OriginalString => _string;
 
         /// <summary>
         /// Returns a byte array representing the current instance of the
         /// sort key.
         /// </summary>
-        public virtual byte[] KeyData => (byte[])_keyData.Clone();
+        public byte[] KeyData => (byte[])_keyData.Clone();
 
         /// <summary>
         /// Compares the two sort keys.  Returns 0 if the two sort keys are
@@ -62,44 +62,21 @@ namespace System.Globalization
             Debug.Assert(key1Data != null, "key1Data != null");
             Debug.Assert(key2Data != null, "key2Data != null");
 
-            if (key1Data.Length == 0)
-            {
-                if (key2Data.Length == 0)
-                {
-                    return 0;
-                }
-
-                return -1;
-            }
-            if (key2Data.Length == 0)
-            {
-                return 1;
-            }
-
-            int compLen = (key1Data.Length < key2Data.Length) ? key1Data.Length : key2Data.Length;
-            for (int i = 0; i < compLen; i++)
-            {
-                if (key1Data[i] > key2Data[i])
-                {
-                    return 1;
-                }
-                if (key1Data[i] < key2Data[i])
-                {
-                    return -1;
-                }
-            }
+            // SortKey comparisons are done as an ordinal comparison by the raw sort key bytes.
 
-            return 0;
+            return new ReadOnlySpan<byte>(key1Data).SequenceCompareTo(key2Data);
         }
 
         public override bool Equals(object? value)
         {
-            return value is SortKey otherSortKey && Compare(this, otherSortKey) == 0;
+            return value is SortKey other
+                && new ReadOnlySpan<byte>(_keyData).SequenceEqual(other._keyData);
         }
 
         public override int GetHashCode()
         {
-            return CompareInfo.GetCompareInfo(_localeName).GetHashCode(_string, _options);
+            // keep this in sync with CompareInfo.GetHashCodeOfString
+            return Marvin.ComputeHash32(_keyData, Marvin.DefaultSeed);
         }
 
         public override string ToString()
index af0088d..f860375 100644 (file)
@@ -5086,11 +5086,11 @@ namespace System.Globalization
         public override int GetHashCode() { throw null; }
         public override string ToString() { throw null; }
     }
-    public partial class SortKey
+    public sealed partial class SortKey
     {
         internal SortKey() { }
-        public virtual byte[] KeyData { get { throw null; } }
-        public virtual string OriginalString { get { throw null; } }
+        public byte[] KeyData { get { throw null; } }
+        public string OriginalString { get { throw null; } }
         public static int Compare(System.Globalization.SortKey sortkey1, System.Globalization.SortKey sortkey2) { throw null; }
         public override bool Equals(object? value) { throw null; }
         public override int GetHashCode() { throw null; }