Improve performance of String.Equals(..., OrdinalIgnoreCase) (dotnet/coreclr#20734)
authorLevi Broderick <GrabYourPitchforks@users.noreply.github.com>
Sat, 3 Nov 2018 00:15:47 +0000 (17:15 -0700)
committerGitHub <noreply@github.com>
Sat, 3 Nov 2018 00:15:47 +0000 (17:15 -0700)
- Tries to consume multiple chars in parallel when possible
- Didn't vectorize because inputs to this function are generally fairly small
- Moved static GlobalizationMode lookup out of hot path
- Removed indirection so that StringComparer now calls directly into workhorse routine

Commit migrated from https://github.com/dotnet/coreclr/commit/dd6c69022fbbe4111551f76bb9f9804538cf1e2e

src/libraries/System.Private.CoreLib/src/System/Globalization/CompareInfo.cs
src/libraries/System.Private.CoreLib/src/System/StringComparer.cs
src/libraries/System.Private.CoreLib/src/System/Text/Utf16Utility.cs

index a23de3a..ab1eecf 100644 (file)
@@ -16,7 +16,7 @@ using System.Reflection;
 using System.Diagnostics;
 using System.Runtime.InteropServices;
 using System.Runtime.Serialization;
-using System.Buffers;
+using System.Text;
 using Internal.Runtime.CompilerServices;
 
 namespace System.Globalization
@@ -595,37 +595,143 @@ namespace System.Globalization
         }
 
 
-        internal static bool EqualsOrdinalIgnoreCase(ref char strA, ref char strB, int length)
+        internal static bool EqualsOrdinalIgnoreCase(ref char charA, ref char charB, int length)
         {
-            ref char charA = ref strA;
-            ref char charB = ref strB;
+            IntPtr byteOffset = IntPtr.Zero;
 
-            // in InvariantMode we support all range and not only the ascii characters.
-            char maxChar = (GlobalizationMode.Invariant ? (char)0xFFFF : (char)0x7F);
+#if BIT64
+            // Read 4 chars (64 bits) at a time from each string
+            while ((uint)length >= 4)
+            {
+                ulong valueA = Unsafe.ReadUnaligned<ulong>(ref Unsafe.As<char, byte>(ref Unsafe.AddByteOffset(ref charA, byteOffset)));
+                ulong valueB = Unsafe.ReadUnaligned<ulong>(ref Unsafe.As<char, byte>(ref Unsafe.AddByteOffset(ref charB, byteOffset)));
 
-            while (length != 0 && charA <= maxChar && charB <= maxChar)
+                // A 32-bit test - even with the bit-twiddling here - is more efficient than a 64-bit test.
+                ulong temp = valueA | valueB;
+                if (!Utf16Utility.AllCharsInUInt32AreAscii((uint)temp | (uint)(temp >> 32)))
+                {
+                    goto NonAscii; // one of the inputs contains non-ASCII data
+                }
+
+                // Generally, the caller has likely performed a first-pass check that the input strings
+                // are likely equal. Consider a dictionary which computes the hash code of its key before
+                // performing a proper deep equality check of the string contents. We want to optimize for
+                // the case where the equality check is likely to succeed, which means that we want to avoid
+                // branching within this loop unless we're about to exit the loop, either due to failure or
+                // due to us running out of input data.
+
+                if (!Utf16Utility.UInt64OrdinalIgnoreCaseAscii(valueA, valueB))
+                {
+                    return false;
+                }
+
+                byteOffset += 8;
+                length -= 4;
+            }
+#endif
+
+            // Read 2 chars (32 bits) at a time from each string
+#if BIT64
+            if ((uint)length >= 2)
+#else
+            while ((uint)length >= 2)
+#endif
             {
-                // Ordinal equals or lowercase equals if the result ends up in the a-z range 
-                if (charA == charB ||
-                    ((charA | 0x20) == (charB | 0x20) &&
-                        (uint)((charA | 0x20) - 'a') <= (uint)('z' - 'a')))
+                uint valueA = Unsafe.ReadUnaligned<uint>(ref Unsafe.As<char, byte>(ref Unsafe.AddByteOffset(ref charA, byteOffset)));
+                uint valueB = Unsafe.ReadUnaligned<uint>(ref Unsafe.As<char, byte>(ref Unsafe.AddByteOffset(ref charB, byteOffset)));
+
+                if (!Utf16Utility.AllCharsInUInt32AreAscii(valueA | valueB))
                 {
-                    length--;
-                    charA = ref Unsafe.Add(ref charA, 1);
-                    charB = ref Unsafe.Add(ref charB, 1);
+                    goto NonAscii; // one of the inputs contains non-ASCII data
                 }
-                else
+
+                // Generally, the caller has likely performed a first-pass check that the input strings
+                // are likely equal. Consider a dictionary which computes the hash code of its key before
+                // performing a proper deep equality check of the string contents. We want to optimize for
+                // the case where the equality check is likely to succeed, which means that we want to avoid
+                // branching within this loop unless we're about to exit the loop, either due to failure or
+                // due to us running out of input data.
+
+                if (!Utf16Utility.UInt32OrdinalIgnoreCaseAscii(valueA, valueB))
                 {
                     return false;
                 }
+
+                byteOffset += 4;
+                length -= 2;
             }
 
-            if (length == 0)
-                return true;
+            if (length != 0)
+            {
+                Debug.Assert(length == 1);
 
-            Debug.Assert(!GlobalizationMode.Invariant);
+                uint valueA = Unsafe.AddByteOffset(ref charA, byteOffset);
+                uint valueB = Unsafe.AddByteOffset(ref charB, byteOffset);
 
-            return CompareStringOrdinalIgnoreCase(ref charA, length, ref charB, length) == 0;
+                if ((valueA | valueB) > 0x7Fu)
+                {
+                    goto NonAscii; // one of the inputs contains non-ASCII data
+                }
+
+                if (valueA == valueB)
+                {
+                    return true; // exact match
+                }
+
+                valueA |= 0x20u;
+                if ((uint)(valueA - 'a') > (uint)('z' - 'a'))
+                {
+                    return false; // not exact match, and first input isn't in [A-Za-z]
+                }
+
+                // The ternary operator below seems redundant but helps RyuJIT generate more optimal code.
+                // See https://github.com/dotnet/coreclr/issues/914.
+                return (valueA == (valueB | 0x20u)) ? true : false;
+            }
+
+            Debug.Assert(length == 0);
+            return true;
+
+        NonAscii:
+            // The non-ASCII case is factored out into its own helper method so that the JIT
+            // doesn't need to emit a complex prolog for its caller (this method).
+            return EqualsOrdinalIgnoreCaseNonAscii(ref Unsafe.AddByteOffset(ref charA, byteOffset), ref Unsafe.AddByteOffset(ref charB, byteOffset), length);
+        }
+
+        private static bool EqualsOrdinalIgnoreCaseNonAscii(ref char charA, ref char charB, int length)
+        {
+            if (!GlobalizationMode.Invariant)
+            {
+                return CompareStringOrdinalIgnoreCase(ref charA, length, ref charB, length) == 0;
+            }
+            else
+            {
+                // If we don't have localization tables to consult, we'll still perform a case-insensitive
+                // check for ASCII characters, but if we see anything outside the ASCII range we'll immediately
+                // fail if it doesn't have true bitwise equality.
+
+                IntPtr byteOffset = IntPtr.Zero;
+                while (length != 0)
+                {
+                    // Ordinal equals or lowercase equals if the result ends up in the a-z range 
+                    uint valueA = Unsafe.AddByteOffset(ref charA, byteOffset);
+                    uint valueB = Unsafe.AddByteOffset(ref charB, byteOffset);
+
+                    if (valueA == valueB ||
+                        ((valueA | 0x20) == (valueB | 0x20) &&
+                            (uint)((valueA | 0x20) - 'a') <= (uint)('z' - 'a')))
+                    {
+                        byteOffset += 2;
+                        length--;
+                    }
+                    else
+                    {
+                        return false;
+                    }
+                }
+
+                return true;
+            }
         }
 
         ////////////////////////////////////////////////////////////////////////
index cec4b89..e4378c5 100644 (file)
@@ -293,7 +293,7 @@ namespace System
                 {
                     return false;
                 }
-                return string.Equals(x, y, StringComparison.OrdinalIgnoreCase);
+                return CompareInfo.EqualsOrdinalIgnoreCase(ref x.GetRawStringData(), ref y.GetRawStringData(), x.Length);
             }
             return x.Equals(y);
         }
@@ -367,7 +367,25 @@ namespace System
 
         public override int Compare(string x, string y) => string.Compare(x, y, StringComparison.OrdinalIgnoreCase);
 
-        public override bool Equals(string x, string y) => string.Equals(x, y, StringComparison.OrdinalIgnoreCase);
+        public override bool Equals(string x, string y)
+        {
+            if (ReferenceEquals(x, y))
+            {
+                return true;
+            }
+
+            if (x is null || y is null)
+            {
+                return false;
+            }
+
+            if (x.Length != y.Length)
+            {
+                return false;
+            }
+
+            return CompareInfo.EqualsOrdinalIgnoreCase(ref x.GetRawStringData(), ref y.GetRawStringData(), x.Length);
+        }
 
         public override int GetHashCode(string obj)
         {
index 821528d..bed3905 100644 (file)
@@ -15,7 +15,16 @@ namespace System.Text
         [MethodImpl(MethodImplOptions.AggressiveInlining)]
         internal static bool AllCharsInUInt32AreAscii(uint value)
         {
-            return (value & ~0x007F007Fu) == 0;
+            return (value & ~0x007F_007Fu) == 0;
+        }
+
+        /// <summary>
+        /// Returns true iff the UInt64 represents four ASCII UTF-16 characters in machine endianness.
+        /// </summary>
+        [MethodImpl(MethodImplOptions.AggressiveInlining)]
+        internal static bool AllCharsInUInt64AreAscii(ulong value)
+        {
+            return (value & ~0x007F_007F_007F_007Ful) == 0;
         }
 
         /// <summary>
@@ -33,16 +42,16 @@ namespace System.Text
             Debug.Assert(AllCharsInUInt32AreAscii(value));
 
             // the 0x80 bit of each word of 'lowerIndicator' will be set iff the word has value >= 'A'
-            uint lowerIndicator = value + 0x00800080u - 0x00410041u;
+            uint lowerIndicator = value + 0x0080_0080u - 0x0041_0041u;
 
             // the 0x80 bit of each word of 'upperIndicator' will be set iff the word has value > 'Z'
-            uint upperIndicator = value + 0x00800080u - 0x005B005Bu;
+            uint upperIndicator = value + 0x0080_0080u - 0x005B_005Bu;
 
             // the 0x80 bit of each word of 'combinedIndicator' will be set iff the word has value >= 'A' and <= 'Z'
             uint combinedIndicator = (lowerIndicator ^ upperIndicator);
 
             // the 0x20 bit of each word of 'mask' will be set iff the word has value >= 'A' and <= 'Z'
-            uint mask = (combinedIndicator & 0x00800080u) >> 2;
+            uint mask = (combinedIndicator & 0x0080_0080u) >> 2;
 
             return value ^ mask; // bit flip uppercase letters [A-Z] => [a-z]
         }
@@ -62,16 +71,16 @@ namespace System.Text
             Debug.Assert(AllCharsInUInt32AreAscii(value));
 
             // the 0x80 bit of each word of 'lowerIndicator' will be set iff the word has value >= 'a'
-            uint lowerIndicator = value + 0x00800080u - 0x00610061u;
+            uint lowerIndicator = value + 0x0080_0080u - 0x0061_0061u;
 
             // the 0x80 bit of each word of 'upperIndicator' will be set iff the word has value > 'z'
-            uint upperIndicator = value + 0x00800080u - 0x007B007Bu;
+            uint upperIndicator = value + 0x0080_0080u - 0x007B_007Bu;
 
             // the 0x80 bit of each word of 'combinedIndicator' will be set iff the word has value >= 'a' and <= 'z'
             uint combinedIndicator = (lowerIndicator ^ upperIndicator);
 
             // the 0x20 bit of each word of 'mask' will be set iff the word has value >= 'a' and <= 'z'
-            uint mask = (combinedIndicator & 0x00800080u) >> 2;
+            uint mask = (combinedIndicator & 0x0080_0080u) >> 2;
 
             return value ^ mask; // bit flip lowercase letters [a-z] => [A-Z]
         }
@@ -90,15 +99,15 @@ namespace System.Text
             Debug.Assert(AllCharsInUInt32AreAscii(value));
 
             // the 0x80 bit of each word of 'lowerIndicator' will be set iff the word has value >= 'a'
-            uint lowerIndicator = value + 0x00800080u - 0x00610061u;
+            uint lowerIndicator = value + 0x0080_0080u - 0x0061_0061u;
 
             // the 0x80 bit of each word of 'upperIndicator' will be set iff the word has value > 'z'
-            uint upperIndicator = value + 0x00800080u - 0x007B007Bu;
+            uint upperIndicator = value + 0x0080_0080u - 0x007B_007Bu;
 
             // the 0x80 bit of each word of 'combinedIndicator' will be set iff the word has value >= 'a' and <= 'z'
             uint combinedIndicator = (lowerIndicator ^ upperIndicator);
 
-            return (combinedIndicator & 0x00800080u) != 0;
+            return (combinedIndicator & 0x0080_0080u) != 0;
         }
 
         /// <summary>
@@ -115,15 +124,92 @@ namespace System.Text
             Debug.Assert(AllCharsInUInt32AreAscii(value));
 
             // the 0x80 bit of each word of 'lowerIndicator' will be set iff the word has value >= 'A'
-            uint lowerIndicator = value + 0x00800080u - 0x00410041u;
+            uint lowerIndicator = value + 0x0080_0080u - 0x0041_0041u;
 
             // the 0x80 bit of each word of 'upperIndicator' will be set iff the word has value > 'Z'
-            uint upperIndicator = value + 0x00800080u - 0x005B005Bu;
+            uint upperIndicator = value + 0x0080_0080u - 0x005B_005Bu;
 
             // the 0x80 bit of each word of 'combinedIndicator' will be set iff the word has value >= 'A' and <= 'Z'
             uint combinedIndicator = (lowerIndicator ^ upperIndicator);
 
-            return (combinedIndicator & 0x00800080u) != 0;
+            return (combinedIndicator & 0x0080_0080u) != 0;
+        }
+
+        /// <summary>
+        /// Given two UInt32s that represent two ASCII UTF-16 characters each, returns true iff
+        /// the two inputs are equal using an ordinal case-insensitive comparison.
+        /// </summary>
+        /// <remarks>
+        /// This is a branchless implementation.
+        /// </remarks>
+        [MethodImpl(MethodImplOptions.AggressiveInlining)]
+        internal static bool UInt32OrdinalIgnoreCaseAscii(uint valueA, uint valueB)
+        {
+            // ASSUMPTION: Caller has validated that input values are ASCII.
+            Debug.Assert(AllCharsInUInt32AreAscii(valueA));
+            Debug.Assert(AllCharsInUInt32AreAscii(valueB));
+
+            // a mask of all bits which are different between A and B
+            uint differentBits = valueA ^ valueB;
+
+            // the 0x80 bit of each word of 'lowerIndicator' will be set iff the word has value < 'A'
+            uint lowerIndicator = valueA + 0x0100_0100u - 0x0041_0041u;
+
+            // the 0x80 bit of each word of 'upperIndicator' will be set iff (word | 0x20) has value > 'z'
+            uint upperIndicator = (valueA | 0x0020_0020u) + 0x0080_0080u - 0x007B_007Bu;
+
+            // the 0x80 bit of each word of 'combinedIndicator' will be set iff the word is *not* [A-Za-z]
+            uint combinedIndicator = lowerIndicator | upperIndicator;
+
+            // Shift all the 0x80 bits of 'combinedIndicator' into the 0x20 positions, then set all bits
+            // aside from 0x20. This creates a mask where all bits are set *except* for the 0x20 bits
+            // which correspond to alpha chars (either lower or upper). For these alpha chars only, the
+            // 0x20 bit is allowed to differ between the two input values. Every other char must be an
+            // exact bitwise match between the two input values. In other words, (valueA & mask) will
+            // convert valueA to uppercase, so (valueA & mask) == (valueB & mask) answers "is the uppercase
+            // form of valueA equal to the uppercase form of valueB?" (Technically if valueA has an alpha
+            // char in the same position as a non-alpha char in valueB, or vice versa, this operation will
+            // result in nonsense, but it'll still compute as inequal regardless, which is what we want ultimately.)
+            // The line below is a more efficient way of doing the same check taking advantage of the XOR
+            // computation we performed at the beginning of the method.
+
+            return (((combinedIndicator >> 2) | ~0x0020_0020u) & differentBits) == 0;
+        }
+
+        /// <summary>
+        /// Given two UInt64s that represent four ASCII UTF-16 characters each, returns true iff
+        /// the two inputs are equal using an ordinal case-insensitive comparison.
+        /// </summary>
+        /// <remarks>
+        /// This is a branchless implementation.
+        /// </remarks>
+        [MethodImpl(MethodImplOptions.AggressiveInlining)]
+        internal static bool UInt64OrdinalIgnoreCaseAscii(ulong valueA, ulong valueB)
+        {
+            // ASSUMPTION: Caller has validated that input values are ASCII.
+            Debug.Assert(AllCharsInUInt64AreAscii(valueA));
+            Debug.Assert(AllCharsInUInt64AreAscii(valueB));
+
+            // the 0x80 bit of each word of 'lowerIndicator' will be set iff the word has value >= 'A'
+            ulong lowerIndicator = valueA + 0x0080_0080_0080_0080ul - 0x0041_0041_0041_0041ul;
+
+            // the 0x80 bit of each word of 'upperIndicator' will be set iff (word | 0x20) has value <= 'z'
+            ulong upperIndicator = (valueA | 0x0020_0020_0020_0020ul) + 0x0100_0100_0100_0100ul - 0x007B_007B_007B_007Bul;
+
+            // the 0x20 bit of each word of 'combinedIndicator' will be set iff the word is [A-Za-z]
+            ulong combinedIndicator = (0x0080_0080_0080_0080ul & lowerIndicator & upperIndicator) >> 2;
+
+            // Convert both values to lowercase (using the combined indicator from the first value)
+            // and compare for equality. It's possible that the first value will contain an alpha character
+            // where the second value doesn't (or vice versa), and applying the combined indicator will
+            // create nonsensical data, but the comparison would have failed anyway in this case so it's
+            // a safe operation to perform.
+            //
+            // This 64-bit method is similar to the 32-bit method, but it performs the equivalent of convert-to-
+            // lowercase-then-compare rather than convert-to-uppercase-and-compare. This particular operation
+            // happens to be faster on x64.
+
+            return (valueA | combinedIndicator) == (valueB | combinedIndicator);
         }
     }
 }