Improve vectorization of IndexOf(chars, StringComparison.OrdinalIgnoreCase) (#85437)
authorStephen Toub <stoub@microsoft.com>
Mon, 1 May 2023 14:36:52 +0000 (10:36 -0400)
committerGitHub <noreply@github.com>
Mon, 1 May 2023 14:36:52 +0000 (10:36 -0400)
* Improve vectorization of IndexOf(chars, StringComparison.OrdinalIgnoreCase)

Use the same general "Algorithm 1: Generic SIMD" that we do for StringComparison.Ordinal, adapter for OrdinalIgnoreCase.

* Fix duplicate local

src/libraries/System.Private.CoreLib/src/System/Globalization/Ordinal.cs

index 36fa86b..6a89be1 100644 (file)
@@ -2,10 +2,12 @@
 // The .NET Foundation licenses this file to you under the MIT license.
 
 using System.Diagnostics;
-using System.Text.Unicode;
+using System.Numerics;
 using System.Runtime.CompilerServices;
 using System.Runtime.InteropServices;
 using System.Runtime.Intrinsics;
+using System.Runtime.Intrinsics.X86;
+using System.Text.Unicode;
 
 namespace System.Globalization
 {
@@ -295,7 +297,6 @@ namespace System.Globalization
                 // A non-linguistic search compares chars directly against one another, so large
                 // target strings can never be found inside small search spaces. This check also
                 // handles empty 'source' spans.
-
                 return -1;
             }
 
@@ -309,25 +310,38 @@ namespace System.Globalization
                 return CompareInfo.NlsIndexOfOrdinalCore(source, value, ignoreCase: true, fromBeginning: true);
             }
 
-            // If value starts with an ASCII char, we can use a vectorized path
+            // If value doesn't start with ASCII, fall back to a non-vectorized non-ASCII friendly version.
             ref char valueRef = ref MemoryMarshal.GetReference(value);
             char valueChar = valueRef;
-
             if (!char.IsAscii(valueChar))
             {
-                // Fallback to a more non-ASCII friendly version
                 return OrdinalCasing.IndexOf(source, value);
             }
 
             // Hoist some expressions from the loop
             int valueTailLength = value.Length - 1;
-            int searchSpaceLength = source.Length - valueTailLength;
+            int searchSpaceMinusValueTailLength = source.Length - valueTailLength;
             ref char searchSpace = ref MemoryMarshal.GetReference(source);
             char valueCharU = default;
             char valueCharL = default;
             nint offset = 0;
             bool isLetter = false;
 
+            // If the input is long enough and the value ends with ASCII, we can take a special vectorized
+            // path that compares both the beginning and the end at the same time.
+            if (Vector128.IsHardwareAccelerated && searchSpaceMinusValueTailLength >= Vector128<ushort>.Count)
+            {
+                valueCharU = Unsafe.Add(ref valueRef, valueTailLength);
+                if (char.IsAscii(valueCharU))
+                {
+                    goto SearchTwoChars;
+                }
+            }
+
+            // We're searching for the first character and it's known to be ASCII. If it's not a letter,
+            // then IgnoreCase doesn't impact what it matches and we just need to do a normal search
+            // for that single character. If it is a letter, then we need to search for both its upper
+            // and lower-case variants.
             if (char.IsAsciiLetter(valueChar))
             {
                 valueCharU = (char)(valueChar & ~0x20);
@@ -340,16 +354,16 @@ namespace System.Globalization
                 // Do a quick search for the first element of "value".
                 int relativeIndex = isLetter ?
                     PackedSpanHelpers.PackedIndexOfIsSupported
-                        ? PackedSpanHelpers.IndexOfAny(ref Unsafe.Add(ref searchSpace, offset), valueCharU, valueCharL, searchSpaceLength)
-                        : SpanHelpers.IndexOfAnyChar(ref Unsafe.Add(ref searchSpace, offset), valueCharU, valueCharL, searchSpaceLength) :
-                    SpanHelpers.IndexOfChar(ref Unsafe.Add(ref searchSpace, offset), valueChar, searchSpaceLength);
+                        ? PackedSpanHelpers.IndexOfAny(ref Unsafe.Add(ref searchSpace, offset), valueCharU, valueCharL, searchSpaceMinusValueTailLength)
+                        : SpanHelpers.IndexOfAnyChar(ref Unsafe.Add(ref searchSpace, offset), valueCharU, valueCharL, searchSpaceMinusValueTailLength) :
+                    SpanHelpers.IndexOfChar(ref Unsafe.Add(ref searchSpace, offset), valueChar, searchSpaceMinusValueTailLength);
                 if (relativeIndex < 0)
                 {
                     break;
                 }
 
-                searchSpaceLength -= relativeIndex;
-                if (searchSpaceLength <= 0)
+                searchSpaceMinusValueTailLength -= relativeIndex;
+                if (searchSpaceMinusValueTailLength <= 0)
                 {
                     break;
                 }
@@ -364,12 +378,185 @@ namespace System.Globalization
                     return (int)offset;  // The tail matched. Return a successful find.
                 }
 
-                searchSpaceLength--;
+                searchSpaceMinusValueTailLength--;
                 offset++;
             }
-            while (searchSpaceLength > 0);
+            while (searchSpaceMinusValueTailLength > 0);
 
             return -1;
+
+        // Based on SpanHelpers.IndexOf(ref char, int, ref char, int), which was in turn based on
+        // http://0x80.pl/articles/simd-strfind.html#algorithm-1-generic-simd. This version has additional
+        // modifications to support case-insensitive searches.
+        SearchTwoChars:
+            // Both the first character in value (valueChar) and the last character in value (valueCharU) are ASCII. Get their lowercase variants.
+            valueChar = (char)(valueChar | 0x20);
+            valueCharU = (char)(valueCharU | 0x20);
+
+            // The search is more efficient if the two characters being searched for are different. As long as they are equal, walk backwards
+            // from the last character in the search value until we find a character that's different. Since we're dealing with IgnoreCase,
+            // we compare the lowercase variants, as that's what we'll be comparing against in the main loop.
+            nint ch1ch2Distance = valueTailLength;
+            while (valueCharU == valueChar && ch1ch2Distance > 1)
+            {
+                char tmp = Unsafe.Add(ref valueRef, ch1ch2Distance - 1);
+                if (!char.IsAscii(tmp))
+                {
+                    break;
+                }
+                --ch1ch2Distance;
+                valueCharU = (char)(tmp | 0x20);
+            }
+
+            // Use Vector256 if the input is long enough.
+            if (Vector256.IsHardwareAccelerated && searchSpaceMinusValueTailLength - Vector256<ushort>.Count >= 0)
+            {
+                // Create a vector for each of the lowercase ASCII characters we're searching for.
+                Vector256<ushort> ch1 = Vector256.Create((ushort)valueChar);
+                Vector256<ushort> ch2 = Vector256.Create((ushort)valueCharU);
+
+                nint searchSpaceMinusValueTailLengthAndVector = searchSpaceMinusValueTailLength - (nint)Vector256<ushort>.Count;
+                do
+                {
+                    // Make sure we don't go out of bounds.
+                    Debug.Assert(offset + ch1ch2Distance + Vector256<ushort>.Count <= source.Length);
+
+                    // Load a vector from the current search space offset and another from the offset plus the distance between the two characters.
+                    // For each, | with 0x20 so that letters are lowercased, then & those together to get a mask. If the mask is all zeros, there
+                    // was no match.  If it wasn't, we have to do more work to check for a match.
+                    Vector256<ushort> cmpCh2 = Vector256.Equals(ch2, Vector256.BitwiseOr(Vector256.LoadUnsafe(ref searchSpace, (nuint)(offset + ch1ch2Distance)), Vector256.Create((ushort)0x20)));
+                    Vector256<ushort> cmpCh1 = Vector256.Equals(ch1, Vector256.BitwiseOr(Vector256.LoadUnsafe(ref searchSpace, (nuint)offset), Vector256.Create((ushort)0x20)));
+                    Vector256<byte> cmpAnd = (cmpCh1 & cmpCh2).AsByte();
+                    if (cmpAnd != Vector256<byte>.Zero)
+                    {
+                        goto CandidateFound;
+                    }
+
+                LoopFooter:
+                    // No match. Advance to the next vector.
+                    offset += Vector256<ushort>.Count;
+
+                    // If we've reached the end of the search space, bail.
+                    if (offset == searchSpaceMinusValueTailLength)
+                    {
+                        return -1;
+                    }
+
+                    // If we're within a vector's length of the end of the search space, adjust the offset
+                    // to point to the last vector so that our next iteration will process it.
+                    if (offset > searchSpaceMinusValueTailLengthAndVector)
+                    {
+                        offset = searchSpaceMinusValueTailLengthAndVector;
+                    }
+
+                    continue;
+
+                CandidateFound:
+                    // Possible matches at the current location. Extract the bits for each element.
+                    // For each set bits, we'll check if it's a match at that location.
+                    uint mask = cmpAnd.ExtractMostSignificantBits();
+                    do
+                    {
+                        // Do a full IgnoreCase equality comparison. SpanHelpers.IndexOf skips comparing the two characters in some cases,
+                        // but we don't actually know that the two characters are equal, since we compared with | 0x20. So we just compare
+                        // the full string always.
+                        int bitPos = BitOperations.TrailingZeroCount(mask);
+                        nint charPos = (nint)((uint)bitPos / 2); // div by 2 (shr) because we work with 2-byte chars
+                        if (EqualsIgnoreCase(ref Unsafe.Add(ref searchSpace, offset + charPos), ref valueRef, value.Length))
+                        {
+                            // Match! Return the index.
+                            return (int)(offset + charPos);
+                        }
+
+                        // Clear the two lowest set bits in the mask. If there are no more set bits, we're done.
+                        // If any remain, we loop around to do the next comparison.
+                        if (Bmi1.IsSupported)
+                        {
+                            mask = Bmi1.ResetLowestSetBit(Bmi1.ResetLowestSetBit(mask));
+                        }
+                        else
+                        {
+                            mask &= ~(uint)(0b11 << bitPos);
+                        }
+                    } while (mask != 0);
+                    goto LoopFooter;
+
+                } while (true);
+            }
+            else // 128bit vector path (SSE2 or AdvSimd)
+            {
+                // Create a vector for each of the lowercase ASCII characters we're searching for.
+                Vector128<ushort> ch1 = Vector128.Create((ushort)valueChar);
+                Vector128<ushort> ch2 = Vector128.Create((ushort)valueCharU);
+
+                nint searchSpaceMinusValueTailLengthAndVector = searchSpaceMinusValueTailLength - (nint)Vector128<ushort>.Count;
+                do
+                {
+                    // Make sure we don't go out of bounds.
+                    Debug.Assert(offset + ch1ch2Distance + Vector128<ushort>.Count <= source.Length);
+
+                    // Load a vector from the current search space offset and another from the offset plus the distance between the two characters.
+                    // For each, | with 0x20 so that letters are lowercased, then & those together to get a mask. If the mask is all zeros, there
+                    // was no match.  If it wasn't, we have to do more work to check for a match.
+                    Vector128<ushort> cmpCh2 = Vector128.Equals(ch2, Vector128.BitwiseOr(Vector128.LoadUnsafe(ref searchSpace, (nuint)(offset + ch1ch2Distance)), Vector128.Create((ushort)0x20)));
+                    Vector128<ushort> cmpCh1 = Vector128.Equals(ch1, Vector128.BitwiseOr(Vector128.LoadUnsafe(ref searchSpace, (nuint)offset), Vector128.Create((ushort)0x20)));
+                    Vector128<byte> cmpAnd = (cmpCh1 & cmpCh2).AsByte();
+                    if (cmpAnd != Vector128<byte>.Zero)
+                    {
+                        goto CandidateFound;
+                    }
+
+                LoopFooter:
+                    // No match. Advance to the next vector.
+                    offset += Vector128<ushort>.Count;
+
+                    // If we've reached the end of the search space, bail.
+                    if (offset == searchSpaceMinusValueTailLength)
+                    {
+                        return -1;
+                    }
+
+                    // If we're within a vector's length of the end of the search space, adjust the offset
+                    // to point to the last vector so that our next iteration will process it.
+                    if (offset > searchSpaceMinusValueTailLengthAndVector)
+                    {
+                        offset = searchSpaceMinusValueTailLengthAndVector;
+                    }
+
+                    continue;
+
+                CandidateFound:
+                    // Possible matches at the current location. Extract the bits for each element.
+                    // For each set bits, we'll check if it's a match at that location.
+                    uint mask = cmpAnd.ExtractMostSignificantBits();
+                    do
+                    {
+                        // Do a full IgnoreCase equality comparison. SpanHelpers.IndexOf skips comparing the two characters in some cases,
+                        // but we don't actually know that the two characters are equal, since we compared with | 0x20. So we just compare
+                        // the full string always.
+                        int bitPos = BitOperations.TrailingZeroCount(mask);
+                        int charPos = (int)((uint)bitPos / 2); // div by 2 (shr) because we work with 2-byte chars
+                        if (EqualsIgnoreCase(ref Unsafe.Add(ref searchSpace, offset + charPos), ref valueRef, value.Length))
+                        {
+                            // Match! Return the index.
+                            return (int)(offset + charPos);
+                        }
+
+                        // Clear the two lowest set bits in the mask. If there are no more set bits, we're done.
+                        // If any remain, we loop around to do the next comparison.
+                        if (Bmi1.IsSupported)
+                        {
+                            mask = Bmi1.ResetLowestSetBit(Bmi1.ResetLowestSetBit(mask));
+                        }
+                        else
+                        {
+                            mask &= ~(uint)(0b11 << bitPos);
+                        }
+                    } while (mask != 0);
+                    goto LoopFooter;
+
+                } while (true);
+            }
         }
 
         internal static int LastIndexOf(string source, string value, int startIndex, int count)