Speedup .SequenceCompareTo(byte, ...) (#22127)
authorBen Adams <thundercat@illyriad.co.uk>
Fri, 25 Jan 2019 02:48:43 +0000 (03:48 +0100)
committerJan Kotas <jkotas@microsoft.com>
Fri, 25 Jan 2019 02:48:43 +0000 (18:48 -0800)
* Speedup .SequenceCompareTo(byte, ...)

* Rename jump location

* Better annotations for clarity

src/System.Private.CoreLib/shared/System/SpanHelpers.Byte.cs

index 63a564f..3062a40 100644 (file)
@@ -276,13 +276,16 @@ namespace System
                         {
                             Vector256<byte> search = LoadVector256(ref searchSpace, offset);
                             int matches = Avx2.MoveMask(Avx2.CompareEqual(values, search));
+                            // Note that MoveMask has converted the equal vector elements into a set of bit flags,
+                            // So the bit position in 'matches' corresponds to the element offset.
                             if (matches == 0)
                             {
+                                // Zero flags set so no matches
                                 offset += Vector256<byte>.Count;
                                 continue;
                             }
 
-                            // Find offset of first match
+                            // Find bitflag offset of first match and add to current offset
                             return ((int)(byte*)offset) + BitOps.TrailingZeroCount(matches);
                         } while ((byte*)nLength > (byte*)offset);
                     }
@@ -293,14 +296,16 @@ namespace System
                         Vector128<byte> values = Vector128.Create(value);
                         Vector128<byte> search = LoadVector128(ref searchSpace, offset);
 
+                        // Same method as above
                         int matches = Sse2.MoveMask(Sse2.CompareEqual(values, search));
                         if (matches == 0)
                         {
+                            // Zero flags set so no matches
                             offset += Vector128<byte>.Count;
                         }
                         else
                         {
-                            // Find offset of first match
+                            // Find bitflag offset of first match and add to current offset
                             return ((int)(byte*)offset) + BitOps.TrailingZeroCount(matches);
                         }
                     }
@@ -323,14 +328,16 @@ namespace System
                     {
                         Vector128<byte> search = LoadVector128(ref searchSpace, offset);
 
+                        // Same method as above
                         int matches = Sse2.MoveMask(Sse2.CompareEqual(values, search));
                         if (matches == 0)
                         {
+                            // Zero flags set so no matches
                             offset += Vector128<byte>.Count;
                             continue;
                         }
 
-                        // Find offset of first match
+                        // Find bitflag offset of first match and add to current offset
                         return ((int)(byte*)offset) + BitOps.TrailingZeroCount(matches);
                     }
 
@@ -358,7 +365,7 @@ namespace System
                             continue;
                         }
 
-                        // Find offset of first match
+                        // Find offset of first match and add to current offset
                         return (int)(byte*)offset + LocateFirstFoundByte(matches);
                     }
 
@@ -499,7 +506,7 @@ namespace System
                         continue;
                     }
 
-                    // Find offset of first match
+                    // Find offset of first match and add to current offset
                     return (int)(offset) - Vector<byte>.Count + LocateLastFoundByte(matches);
                 }
                 if ((byte*)offset > (byte*)0)
@@ -628,15 +635,19 @@ namespace System
                         do
                         {
                             Vector256<byte> search = LoadVector256(ref searchSpace, offset);
+                            // Note that MoveMask has converted the equal vector elements into a set of bit flags,
+                            // So the bit position in 'matches' corresponds to the element offset.
                             int matches = Avx2.MoveMask(Avx2.CompareEqual(values0, search));
+                            // Bitwise Or to combine the flagged matches for the second value to our match flags
                             matches |= Avx2.MoveMask(Avx2.CompareEqual(values1, search));
                             if (matches == 0)
                             {
+                                // Zero flags set so no matches
                                 offset += Vector256<byte>.Count;
                                 continue;
                             }
 
-                            // Find offset of first match
+                            // Find bitflag offset of first match and add to current offset
                             return ((int)(byte*)offset) + BitOps.TrailingZeroCount(matches);
                         } while ((byte*)nLength > (byte*)offset);
                     }
@@ -648,15 +659,17 @@ namespace System
                         Vector128<byte> values1 = Vector128.Create(value1);
 
                         Vector128<byte> search = LoadVector128(ref searchSpace, offset);
+                        // Same method as above
                         int matches = Sse2.MoveMask(Sse2.CompareEqual(values0, search));
                         matches |= Sse2.MoveMask(Sse2.CompareEqual(values1, search));
                         if (matches == 0)
                         {
+                            // Zero flags set so no matches
                             offset += Vector128<byte>.Count;
                         }
                         else
                         {
-                            // Find offset of first match
+                            // Find bitflag offset of first match and add to current offset
                             return ((int)(byte*)offset) + BitOps.TrailingZeroCount(matches);
                         }
                     }
@@ -680,15 +693,17 @@ namespace System
                     while ((byte*)nLength > (byte*)offset)
                     {
                         Vector128<byte> search = LoadVector128(ref searchSpace, offset);
+                        // Same method as above
                         int matches = Sse2.MoveMask(Sse2.CompareEqual(values0, search));
                         matches |= Sse2.MoveMask(Sse2.CompareEqual(values1, search));
                         if (matches == 0)
                         {
+                            // Zero flags set so no matches
                             offset += Vector128<byte>.Count;
                             continue;
                         }
 
-                        // Find offset of first match
+                        // Find bitflag offset of first match and add to current offset
                         return ((int)(byte*)offset) + BitOps.TrailingZeroCount(matches);
                     }
 
@@ -720,7 +735,7 @@ namespace System
                             continue;
                         }
 
-                        // Find offset of first match
+                        // Find offset of first match and add to current offset
                         return (int)(byte*)offset + LocateFirstFoundByte(matches);
                     }
 
@@ -755,8 +770,8 @@ namespace System
             Debug.Assert(length >= 0);
 
             uint uValue0 = value0; // Use uint for comparisons to avoid unnecessary 8->32 extensions
-            uint uValue1 = value1; // Use uint for comparisons to avoid unnecessary 8->32 extensions
-            uint uValue2 = value2; // Use uint for comparisons to avoid unnecessary 8->32 extensions
+            uint uValue1 = value1;
+            uint uValue2 = value2;
             IntPtr offset = (IntPtr)0; // Use IntPtr for arithmetic to avoid unnecessary 64->32->64 truncations
             IntPtr nLength = (IntPtr)length;
 
@@ -853,16 +868,21 @@ namespace System
                         do
                         {
                             Vector256<byte> search = LoadVector256(ref searchSpace, offset);
+                            // Note that MoveMask has converted the equal vector elements into a set of bit flags,
+                            // So the bit position in 'matches' corresponds to the element offset.
                             int matches = Avx2.MoveMask(Avx2.CompareEqual(values0, search));
+                            // Bitwise Or to combine the flagged matches for the second value to our match flags
                             matches |= Avx2.MoveMask(Avx2.CompareEqual(values1, search));
+                            // Bitwise Or to combine the flagged matches for the third value to our match flags
                             matches |= Avx2.MoveMask(Avx2.CompareEqual(values2, search));
                             if (matches == 0)
                             {
+                                // Zero flags set so no matches
                                 offset += Vector256<byte>.Count;
                                 continue;
                             }
 
-                            // Find offset of first match
+                            // Find bitflag offset of first match and add to current offset
                             return ((int)(byte*)offset) + BitOps.TrailingZeroCount(matches);
                         } while ((byte*)nLength > (byte*)offset);
                     }
@@ -875,16 +895,18 @@ namespace System
                         Vector128<byte> values2 = Vector128.Create(value2);
 
                         Vector128<byte> search = LoadVector128(ref searchSpace, offset);
+                        // Same method as above
                         int matches = Sse2.MoveMask(Sse2.CompareEqual(values0, search));
                         matches |= Sse2.MoveMask(Sse2.CompareEqual(values1, search));
                         matches |= Sse2.MoveMask(Sse2.CompareEqual(values2, search));
                         if (matches == 0)
                         {
+                            // Zero flags set so no matches
                             offset += Vector128<byte>.Count;
                         }
                         else
                         {
-                            // Find offset of first match
+                            // Find bitflag offset of first match and add to current offset
                             return ((int)(byte*)offset) + BitOps.TrailingZeroCount(matches);
                         }
                     }
@@ -909,16 +931,18 @@ namespace System
                     while ((byte*)nLength > (byte*)offset)
                     {
                         Vector128<byte> search = LoadVector128(ref searchSpace, offset);
+                        // Same method as above
                         int matches = Sse2.MoveMask(Sse2.CompareEqual(values0, search));
                         matches |= Sse2.MoveMask(Sse2.CompareEqual(values1, search));
                         matches |= Sse2.MoveMask(Sse2.CompareEqual(values2, search));
                         if (matches == 0)
                         {
+                            // Zero flags set so no matches
                             offset += Vector128<byte>.Count;
                             continue;
                         }
 
-                        // Find offset of first match
+                        // Find bitflag offset of first match and add to current offset
                         return ((int)(byte*)offset) + BitOps.TrailingZeroCount(matches);
                     }
 
@@ -955,7 +979,7 @@ namespace System
                             continue;
                         }
 
-                        // Find offset of first match
+                        // Find offset of first match and add to current offset
                         return (int)(byte*)offset + LocateFirstFoundByte(matches);
                     }
 
@@ -990,7 +1014,7 @@ namespace System
             Debug.Assert(length >= 0);
 
             uint uValue0 = value0; // Use uint for comparisons to avoid unnecessary 8->32 extensions
-            uint uValue1 = value1; // Use uint for comparisons to avoid unnecessary 8->32 extensions
+            uint uValue1 = value1;
             IntPtr offset = (IntPtr)length; // Use IntPtr for arithmetic to avoid unnecessary 64->32->64 truncations
             IntPtr nLength = (IntPtr)length;
 
@@ -1080,7 +1104,7 @@ namespace System
                         continue;
                     }
 
-                    // Find offset of first match
+                    // Find offset of first match and add to current offset
                     return (int)(offset) - Vector<byte>.Count + LocateLastFoundByte(matches);
                 }
 
@@ -1114,8 +1138,8 @@ namespace System
             Debug.Assert(length >= 0);
 
             uint uValue0 = value0; // Use uint for comparisons to avoid unnecessary 8->32 extensions
-            uint uValue1 = value1; // Use uint for comparisons to avoid unnecessary 8->32 extensions
-            uint uValue2 = value2; // Use uint for comparisons to avoid unnecessary 8->32 extensions
+            uint uValue1 = value1;
+            uint uValue2 = value2;
             IntPtr offset = (IntPtr)length; // Use IntPtr for arithmetic to avoid unnecessary 64->32->64 truncations
             IntPtr nLength = (IntPtr)length;
 
@@ -1210,7 +1234,7 @@ namespace System
                         continue;
                     }
 
-                    // Find offset of first match
+                    // Find offset of first match and add to current offset
                     return (int)(offset) - Vector<byte>.Count + LocateLastFoundByte(matches);
                 }
 
@@ -1324,18 +1348,149 @@ namespace System
             IntPtr offset = (IntPtr)0; // Use IntPtr for arithmetic to avoid unnecessary 64->32->64 truncations
             IntPtr nLength = (IntPtr)(void*)minLength;
 
-            if (Vector.IsHardwareAccelerated && (byte*)nLength > (byte*)Vector<byte>.Count)
+            if (Avx2.IsSupported)
             {
-                nLength -= Vector<byte>.Count;
-                while ((byte*)nLength > (byte*)offset)
+                if ((byte*)nLength >= (byte*)Vector256<byte>.Count)
                 {
-                    if (LoadVector(ref first, offset) != LoadVector(ref second, offset))
+                    nLength -= Vector256<byte>.Count;
+                    uint matches;
+                    while ((byte*)nLength > (byte*)offset)
                     {
-                        goto NotEqual;
+                        matches = (uint)Avx2.MoveMask(Avx2.CompareEqual(LoadVector256(ref first, offset), LoadVector256(ref second, offset)));
+                        // Note that MoveMask has converted the equal vector elements into a set of bit flags,
+                        // So the bit position in 'matches' corresponds to the element offset.
+
+                        // 32 elements in Vector256<byte> so we compare to uint.MaxValue to check if everything matched
+                        if (matches == uint.MaxValue)
+                        {
+                            // All matched
+                            offset += Vector256<byte>.Count;
+                            continue;
+                        }
+
+                        goto Difference;
                     }
-                    offset += Vector<byte>.Count;
+                    // Move to Vector length from end for final compare
+                    offset = nLength;
+                    // Same as method as above
+                    matches = (uint)Avx2.MoveMask(Avx2.CompareEqual(LoadVector256(ref first, offset), LoadVector256(ref second, offset)));
+                    if (matches == uint.MaxValue)
+                    {
+                        // All matched
+                        goto Equal;
+                    }
+                Difference:
+                    // Invert matches to find differences
+                    uint differences = ~matches;
+                    // Find bitflag offset of first difference and add to current offset
+                    offset = (IntPtr)((int)(byte*)offset + BitOps.TrailingZeroCount((int)differences));
+
+                    int result = Unsafe.AddByteOffset(ref first, offset).CompareTo(Unsafe.AddByteOffset(ref second, offset));
+                    Debug.Assert(result != 0);
+
+                    return result;
+                }
+
+                if ((byte*)nLength >= (byte*)Vector128<byte>.Count)
+                {
+                    nLength -= Vector128<byte>.Count;
+                    uint matches;
+                    if ((byte*)nLength > (byte*)offset)
+                    {
+                        matches = (uint)Sse2.MoveMask(Sse2.CompareEqual(LoadVector128(ref first, offset), LoadVector128(ref second, offset)));
+                        // Note that MoveMask has converted the equal vector elements into a set of bit flags,
+                        // So the bit position in 'matches' corresponds to the element offset.
+
+                        // 16 elements in Vector128<byte> so we compare to ushort.MaxValue to check if everything matched
+                        if (matches == ushort.MaxValue)
+                        {
+                            // All matched
+                            offset += Vector128<byte>.Count;
+                        }
+                        else
+                        {
+                            goto Difference;
+                        }
+                    }
+                    // Move to Vector length from end for final compare
+                    offset = nLength;
+                    // Same as method as above
+                    matches = (uint)Sse2.MoveMask(Sse2.CompareEqual(LoadVector128(ref first, offset), LoadVector128(ref second, offset)));
+                    if (matches == ushort.MaxValue)
+                    {
+                        // All matched
+                        goto Equal;
+                    }
+                Difference:
+                    // Invert matches to find differences
+                    uint differences = ~matches;
+                    // Find bitflag offset of first difference and add to current offset
+                    offset = (IntPtr)((int)(byte*)offset + BitOps.TrailingZeroCount((int)differences));
+
+                    int result = Unsafe.AddByteOffset(ref first, offset).CompareTo(Unsafe.AddByteOffset(ref second, offset));
+                    Debug.Assert(result != 0);
+
+                    return result;
+                }
+            }
+            else if (Sse2.IsSupported)
+            {
+                if ((byte*)nLength >= (byte*)Vector128<byte>.Count)
+                {
+                    nLength -= Vector128<byte>.Count;
+                    uint matches;
+                    while ((byte*)nLength > (byte*)offset)
+                    {
+                        matches = (uint)Sse2.MoveMask(Sse2.CompareEqual(LoadVector128(ref first, offset), LoadVector128(ref second, offset)));
+                        // Note that MoveMask has converted the equal vector elements into a set of bit flags,
+                        // So the bit position in 'matches' corresponds to the element offset.
+
+                        // 16 elements in Vector128<byte> so we compare to ushort.MaxValue to check if everything matched
+                        if (matches == ushort.MaxValue)
+                        {
+                            // All matched
+                            offset += Vector128<byte>.Count;
+                            continue;
+                        }
+
+                        goto Difference;
+                    }
+                    // Move to Vector length from end for final compare
+                    offset = nLength;
+                    // Same as method as above
+                    matches = (uint)Sse2.MoveMask(Sse2.CompareEqual(LoadVector128(ref first, offset), LoadVector128(ref second, offset)));
+                    if (matches == ushort.MaxValue)
+                    {
+                        // All matched
+                        goto Equal;
+                    }
+                Difference:
+                    // Invert matches to find differences
+                    uint differences = ~matches;
+                    // Find bitflag offset of first difference and add to current offset
+                    offset = (IntPtr)((int)(byte*)offset + BitOps.TrailingZeroCount((int)differences));
+
+                    int result = Unsafe.AddByteOffset(ref first, offset).CompareTo(Unsafe.AddByteOffset(ref second, offset));
+                    Debug.Assert(result != 0);
+
+                    return result;
+                }
+            }
+            else if (Vector.IsHardwareAccelerated)
+            {
+                if ((byte*)nLength > (byte*)Vector<byte>.Count)
+                {
+                    nLength -= Vector<byte>.Count;
+                    while ((byte*)nLength > (byte*)offset)
+                    {
+                        if (LoadVector(ref first, offset) != LoadVector(ref second, offset))
+                        {
+                            goto BytewiseCheck;
+                        }
+                        offset += Vector<byte>.Count;
+                    }
+                    goto BytewiseCheck;
                 }
-                goto NotEqual;
             }
 
             if ((byte*)nLength > (byte*)sizeof(UIntPtr))
@@ -1345,13 +1500,13 @@ namespace System
                 {
                     if (LoadUIntPtr(ref first, offset) != LoadUIntPtr(ref second, offset))
                     {
-                        goto NotEqual;
+                        goto BytewiseCheck;
                     }
                     offset += sizeof(UIntPtr);
                 }
             }
 
-        NotEqual:  // Workaround for https://github.com/dotnet/coreclr/issues/13549
+        BytewiseCheck:  // Workaround for https://github.com/dotnet/coreclr/issues/13549
             while ((byte*)minLength > (byte*)offset)
             {
                 int result = Unsafe.AddByteOffset(ref first, offset).CompareTo(Unsafe.AddByteOffset(ref second, offset));