Use intrinsics for SequenceEqual<byte> vectorization to emit at R2R (#32371)
author Ben Adams <thundercat@illyriad.co.uk>
Mon, 27 Apr 2020 22:19:24 +0000 (23:19 +0100)
committer GitHub <noreply@github.com>
Mon, 27 Apr 2020 22:19:24 +0000 (15:19 -0700)
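
Replace the Vector<T>-based inner loop of SpanHelpers.SequenceEqual(ref byte, ref byte, nuint)
with direct Sse2/Avx2 intrinsics on x86/x64 (keeping the Vector<T> path as a fallback for other
architectures) so the vectorized code can be generated ahead of time by ReadyToRun (R2R)
compilation, and rework the scalar paths to use nuint arithmetic with overlapping loads.
Vector128<byte>.Count and Vector128<byte>.Equals (which calls .AsByte()) do not inline at R2R
time (https://github.com/dotnet/runtime/issues/32714), so the hot path uses Vector128.Size and
the raw CompareEqual/MoveMask intrinsics instead.
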
src/libraries/System.Private.CoreLib/src/System/SpanHelpers.Byte.cs

index 025d22e..a657ceb 100644
@@ -13,8 +13,10 @@ using Internal.Runtime.CompilerServices;
 #pragma warning disable SA1121 // explicitly using type aliases instead of built-in types
 #if TARGET_64BIT
 using nuint = System.UInt64;
+using nint = System.Int64;
 #else
 using nuint = System.UInt32;
+using nint = System.Int32;
 #endif // TARGET_64BIT
 
 namespace System
@@ -1309,85 +1311,210 @@ namespace System
 
         // Optimized byte-based SequenceEqual. The "length" parameter for this one is declared as a nuint rather than int as we also use it for types other than byte
         // where the length can exceed 2GB once scaled by sizeof(T).
-        [MethodImpl(MethodImplOptions.AggressiveOptimization)]
-        public static unsafe bool SequenceEqual(ref byte first, ref byte second, nuint length)
+        public static bool SequenceEqual(ref byte first, ref byte second, nuint length)
         {
-            IntPtr offset = (IntPtr)0; // Use IntPtr for arithmetic to avoid unnecessary 64->32->64 truncations
-            IntPtr lengthToExamine = (IntPtr)(void*)length;
+            bool result;
+            // Use nint for arithmetic to avoid unnecessary 64->32->64 truncations
+            if (length >= sizeof(nuint))
+            {
+                // Conditional jmp forward to favor shorter lengths. (See comment at the "Equal:" label.)
+                // Longer lengths can make back the time lost to a branch misprediction
+                // better than shorter lengths can.
+                goto Longer;
+            }
 
-            if ((byte*)lengthToExamine >= (byte*)sizeof(UIntPtr))
+#if TARGET_64BIT
+            // On 32-bit, this will always be true since sizeof(nuint) == 4
+            if (length < sizeof(uint))
+#endif
             {
-                // Only check that the ref is the same if buffers are large, and hence
-                // its worth avoiding doing unnecessary comparisons
-                if (Unsafe.AreSame(ref first, ref second))
-                    goto Equal;
+                uint differentBits = 0;
+                nuint offset = (length & 2);
+                if (offset != 0)
+                {
+                    differentBits = LoadUShort(ref first);
+                    differentBits -= LoadUShort(ref second);
+                }
+                if ((length & 1) != 0)
+                {
+                    differentBits |= (uint)Unsafe.AddByteOffset(ref first, offset) - (uint)Unsafe.AddByteOffset(ref second, offset);
+                }
+                result = (differentBits == 0);
+                goto Result;
+            }
+#if TARGET_64BIT
+            else
+            {
+                nuint offset = length - sizeof(uint);
+                uint differentBits = LoadUInt(ref first) - LoadUInt(ref second);
+                differentBits |= LoadUInt(ref first, offset) - LoadUInt(ref second, offset);
+                result = (differentBits == 0);
+                goto Result;
+            }
+#endif
+        Longer:
+            // Only check that the refs are the same if the buffers are large,
+            // and hence it's worth avoiding doing unnecessary comparisons
+            if (!Unsafe.AreSame(ref first, ref second))
+            {
+                // The C# compiler inverts this test, making the outer goto the conditional jmp.
+                goto Vector;
+            }
 
-                if (Vector.IsHardwareAccelerated && (byte*)lengthToExamine >= (byte*)Vector<byte>.Count)
+            // This becomes a conditional jmp forward so as not to favor it.
+            goto Equal;
+
+        Result:
+            return result;
+        // When the sequences are equal (the longest execution path), we want to determine that
+        // as fast as possible, so we do not want the early outs to be "predicted not taken" branches.
+        Equal:
+            return true;
+
+        Vector:
+            if (Sse2.IsSupported)
+            {
+                if (Avx2.IsSupported && length >= (nuint)Vector256<byte>.Count)
                 {
-                    lengthToExamine -= Vector<byte>.Count;
-                    while ((byte*)lengthToExamine > (byte*)offset)
+                    Vector256<byte> vecResult;
+                    nuint offset = 0;
+                    nuint lengthToExamine = length - (nuint)Vector256<byte>.Count;
+                    // Unsigned, so if it overflowed it would have wrapped to a value larger than length (rather than gone negative)
+                    Debug.Assert(lengthToExamine < length);
+                    if (lengthToExamine != 0)
                     {
-                        if (LoadVector(ref first, offset) != LoadVector(ref second, offset))
+                        do
                         {
-                            goto NotEqual;
-                        }
-                        offset += Vector<byte>.Count;
+                            vecResult = Avx2.CompareEqual(LoadVector256(ref first, offset), LoadVector256(ref second, offset));
+                            if (Avx2.MoveMask(vecResult) != -1)
+                            {
+                                goto NotEqual;
+                            }
+                            offset += (nuint)Vector256<byte>.Count;
+                        } while (lengthToExamine > offset);
                     }
-                    return LoadVector(ref first, lengthToExamine) == LoadVector(ref second, lengthToExamine);
-                }
 
-                Debug.Assert((byte*)lengthToExamine >= (byte*)sizeof(UIntPtr));
+                    // Do the final compare Vector256<byte>.Count from the end rather than from the start
+                    vecResult = Avx2.CompareEqual(LoadVector256(ref first, lengthToExamine), LoadVector256(ref second, lengthToExamine));
+                    if (Avx2.MoveMask(vecResult) == -1)
+                    {
+                        // The C# compiler inverts this test, making the outer goto the conditional jmp.
+                        goto Equal;
+                    }
 
-                lengthToExamine -= sizeof(UIntPtr);
-                while ((byte*)lengthToExamine > (byte*)offset)
+                    // This becomes a conditional jmp forward so as not to favor it.
+                    goto NotEqual;
+                }
+                // Use Vector128.Size as Vector128<byte>.Count doesn't inline at R2R time
+                // https://github.com/dotnet/runtime/issues/32714
+                else if (length >= Vector128.Size)
                 {
-                    if (LoadUIntPtr(ref first, offset) != LoadUIntPtr(ref second, offset))
+                    Vector128<byte> vecResult;
+                    nuint offset = 0;
+                    nuint lengthToExamine = length - Vector128.Size;
+                    // Unsigned, so if it overflowed it would have wrapped to a value larger than length (rather than gone negative)
+                    Debug.Assert(lengthToExamine < length);
+                    if (lengthToExamine != 0)
                     {
-                        goto NotEqual;
+                        do
+                        {
+                            // We use intrinsics directly because .Equals calls .AsByte(), which doesn't inline at R2R time
+                            // https://github.com/dotnet/runtime/issues/32714
+                            vecResult = Sse2.CompareEqual(LoadVector128(ref first, offset), LoadVector128(ref second, offset));
+                            if (Sse2.MoveMask(vecResult) != 0xFFFF)
+                            {
+                                goto NotEqual;
+                            }
+                            offset += Vector128.Size;
+                        } while (lengthToExamine > offset);
                     }
-                    offset += sizeof(UIntPtr);
-                }
-                return LoadUIntPtr(ref first, lengthToExamine) == LoadUIntPtr(ref second, lengthToExamine);
-            }
 
-            Debug.Assert((byte*)lengthToExamine < (byte*)sizeof(UIntPtr));
+                    // Do the final compare Vector128<byte>.Count from the end rather than from the start
+                    vecResult = Sse2.CompareEqual(LoadVector128(ref first, lengthToExamine), LoadVector128(ref second, lengthToExamine));
+                    if (Sse2.MoveMask(vecResult) == 0xFFFF)
+                    {
+                        // The C# compiler inverts this test, making the outer goto the conditional jmp.
+                        goto Equal;
+                    }
 
-            // On 32-bit, this will never be true since sizeof(UIntPtr) == 4
-#if TARGET_64BIT
-            if ((byte*)lengthToExamine >= (byte*)sizeof(int))
-            {
-                if (LoadInt(ref first, offset) != LoadInt(ref second, offset))
-                {
+                    // This becomes a conditional jmp forward so as not to favor it.
                     goto NotEqual;
                 }
-                offset += sizeof(int);
-                lengthToExamine -= sizeof(int);
             }
-#endif
-
-            if ((byte*)lengthToExamine >= (byte*)sizeof(short))
+            else if (Vector.IsHardwareAccelerated && length >= (nuint)Vector<byte>.Count)
             {
-                if (LoadShort(ref first, offset) != LoadShort(ref second, offset))
+                nuint offset = 0;
+                nuint lengthToExamine = length - (nuint)Vector<byte>.Count;
+                // Unsigned, so if it overflowed it would have wrapped to a value larger than length (rather than gone negative)
+                Debug.Assert(lengthToExamine < length);
+                if (lengthToExamine > 0)
                 {
-                    goto NotEqual;
+                    do
+                    {
+                        if (LoadVector(ref first, offset) != LoadVector(ref second, offset))
+                        {
+                            goto NotEqual;
+                        }
+                        offset += (nuint)Vector<byte>.Count;
+                    } while (lengthToExamine > offset);
                 }
-                offset += sizeof(short);
-                lengthToExamine -= sizeof(short);
+
+                // Do the final compare Vector<byte>.Count from the end rather than from the start
+                if (LoadVector(ref first, lengthToExamine) == LoadVector(ref second, lengthToExamine))
+                {
+                    // The C# compiler inverts this test, making the outer goto the conditional jmp.
+                    goto Equal;
+                }
+
+                // This becomes a conditional jmp forward so as not to favor it.
+                goto NotEqual;
             }
 
-            if (lengthToExamine != IntPtr.Zero)
+#if TARGET_64BIT
+            if (Sse2.IsSupported)
             {
-                Debug.Assert((int)lengthToExamine == 1);
+                Debug.Assert(length <= sizeof(nuint) * 2);
 
-                if (Unsafe.AddByteOffset(ref first, offset) != Unsafe.AddByteOffset(ref second, offset))
+                nuint offset = length - sizeof(nuint);
+                nuint differentBits = LoadNUInt(ref first) - LoadNUInt(ref second);
+                differentBits |= LoadNUInt(ref first, offset) - LoadNUInt(ref second, offset);
+                result = (differentBits == 0);
+                goto Result;
+            }
+            else
+#endif
+            {
+                Debug.Assert(length >= sizeof(nuint));
                 {
-                    goto NotEqual;
+                    nuint offset = 0;
+                    nuint lengthToExamine = length - sizeof(nuint);
+                    // Unsigned, so if it overflowed it would have wrapped to a value larger than length (rather than gone negative)
+                    Debug.Assert(lengthToExamine < length);
+                    if (lengthToExamine > 0)
+                    {
+                        do
+                        {
+                            // Compare unsigned so we don't emit a sign-extending mov on 64-bit
+                            if (LoadNUInt(ref first, offset) != LoadNUInt(ref second, offset))
+                            {
+                                goto NotEqual;
+                            }
+                            offset += sizeof(nuint);
+                        } while (lengthToExamine > offset);
+                    }
+
+                    // Do the final compare sizeof(nuint) from the end rather than from the start
+                    result = (LoadNUInt(ref first, lengthToExamine) == LoadNUInt(ref second, lengthToExamine));
+                    goto Result;
                 }
             }
 
-        Equal:
-            return true;
-        NotEqual: // Workaround for https://github.com/dotnet/runtime/issues/8795
+            // As there are so many true/false exit points, the JIT will coalesce them to one location.
+            // We want them at the end so the conditional early-exit jmps are all forward jmps, which the
+            // branch predictor in an uninitialized state will not take, e.g.
+            // - loops are conditional backward jmps and predicted taken
+            // - exceptions are conditional forward jmps and predicted not taken
+        NotEqual:
             return false;
         }
 
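The scalar and vector paths above share one tail-handling idea: instead of a byte-by-byte
remainder loop, do the final compare anchored at the end of the buffer, accepting that it may
overlap bytes already compared (re-comparing bytes is harmless for an equality test). A minimal
standalone sketch of the short-length variant in safe code; the name ShortEqual and the use of
BinaryPrimitives with XOR are illustrative, where the diff itself uses Unsafe.ReadUnaligned and
subtraction to the same effect:

    using System;
    using System.Buffers.Binary;
    using System.Diagnostics;

    static class OverlappingLoadSketch
    {
        // Equality for 4 <= length <= 8 using exactly two uint loads per buffer:
        // one at offset 0 and one ending at the last byte. The loads overlap when
        // length < 8; overlapped bytes are simply compared twice.
        public static bool ShortEqual(ReadOnlySpan<byte> a, ReadOnlySpan<byte> b)
        {
            Debug.Assert(a.Length == b.Length && a.Length >= 4 && a.Length <= 8);
            int offset = a.Length - sizeof(uint);
            uint differentBits = BinaryPrimitives.ReadUInt32LittleEndian(a)
                               ^ BinaryPrimitives.ReadUInt32LittleEndian(b);
            differentBits |= BinaryPrimitives.ReadUInt32LittleEndian(a.Slice(offset))
                           ^ BinaryPrimitives.ReadUInt32LittleEndian(b.Slice(offset));
            return differentBits == 0;
        }
    }
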
@@ -1644,27 +1771,55 @@ namespace System
                                                        0x01ul << 48) + 1;
 
         [MethodImpl(MethodImplOptions.AggressiveInlining)]
-        private static unsafe short LoadShort(ref byte start, IntPtr offset)
-            => Unsafe.ReadUnaligned<short>(ref Unsafe.AddByteOffset(ref start, offset));
+        private static ushort LoadUShort(ref byte start)
+            => Unsafe.ReadUnaligned<ushort>(ref start);
+
+        [MethodImpl(MethodImplOptions.AggressiveInlining)]
+        private static ushort LoadUShort(ref byte start, nuint offset)
+            => Unsafe.ReadUnaligned<ushort>(ref Unsafe.AddByteOffset(ref start, offset));
 
         [MethodImpl(MethodImplOptions.AggressiveInlining)]
-        private static unsafe int LoadInt(ref byte start, IntPtr offset)
-            => Unsafe.ReadUnaligned<int>(ref Unsafe.AddByteOffset(ref start, offset));
+        private static uint LoadUInt(ref byte start)
+            => Unsafe.ReadUnaligned<uint>(ref start);
 
         [MethodImpl(MethodImplOptions.AggressiveInlining)]
-        private static unsafe UIntPtr LoadUIntPtr(ref byte start, IntPtr offset)
+        private static uint LoadUInt(ref byte start, nuint offset)
+            => Unsafe.ReadUnaligned<uint>(ref Unsafe.AddByteOffset(ref start, offset));
+
+        [MethodImpl(MethodImplOptions.AggressiveInlining)]
+        private static nuint LoadNUInt(ref byte start)
+            => Unsafe.ReadUnaligned<nuint>(ref start);
+
+        [MethodImpl(MethodImplOptions.AggressiveInlining)]
+        private static nuint LoadNUInt(ref byte start, nuint offset)
+            => Unsafe.ReadUnaligned<nuint>(ref Unsafe.AddByteOffset(ref start, offset));
+
+        [MethodImpl(MethodImplOptions.AggressiveInlining)]
+        private static UIntPtr LoadUIntPtr(ref byte start, IntPtr offset)
             => Unsafe.ReadUnaligned<UIntPtr>(ref Unsafe.AddByteOffset(ref start, offset));
 
         [MethodImpl(MethodImplOptions.AggressiveInlining)]
-        private static unsafe Vector<byte> LoadVector(ref byte start, IntPtr offset)
+        private static Vector<byte> LoadVector(ref byte start, IntPtr offset)
+            => Unsafe.ReadUnaligned<Vector<byte>>(ref Unsafe.AddByteOffset(ref start, offset));
+
+        [MethodImpl(MethodImplOptions.AggressiveInlining)]
+        private static Vector<byte> LoadVector(ref byte start, nuint offset)
             => Unsafe.ReadUnaligned<Vector<byte>>(ref Unsafe.AddByteOffset(ref start, offset));
 
         [MethodImpl(MethodImplOptions.AggressiveInlining)]
-        private static unsafe Vector128<byte> LoadVector128(ref byte start, IntPtr offset)
+        private static Vector128<byte> LoadVector128(ref byte start, IntPtr offset)
             => Unsafe.ReadUnaligned<Vector128<byte>>(ref Unsafe.AddByteOffset(ref start, offset));
 
         [MethodImpl(MethodImplOptions.AggressiveInlining)]
-        private static unsafe Vector256<byte> LoadVector256(ref byte start, IntPtr offset)
+        private static Vector128<byte> LoadVector128(ref byte start, nuint offset)
+            => Unsafe.ReadUnaligned<Vector128<byte>>(ref Unsafe.AddByteOffset(ref start, offset));
+
+        [MethodImpl(MethodImplOptions.AggressiveInlining)]
+        private static Vector256<byte> LoadVector256(ref byte start, IntPtr offset)
+            => Unsafe.ReadUnaligned<Vector256<byte>>(ref Unsafe.AddByteOffset(ref start, offset));
+
+        [MethodImpl(MethodImplOptions.AggressiveInlining)]
+        private static Vector256<byte> LoadVector256(ref byte start, nuint offset)
             => Unsafe.ReadUnaligned<Vector256<byte>>(ref Unsafe.AddByteOffset(ref start, offset));
 
         [MethodImpl(MethodImplOptions.AggressiveInlining)]
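
For reference, the MoveMask checks above work because Sse2.CompareEqual sets a byte lane to 0xFF
when the lanes match and 0x00 when they differ, and Sse2.MoveMask packs the high bit of each of
the 16 lanes into the low 16 bits of an int; all-equal input therefore yields 0xFFFF, while the
32-lane Avx2.MoveMask yields -1 (all 32 bits set). A minimal standalone sketch; the name
Vector128Equal is illustrative:

    using System.Runtime.Intrinsics;
    using System.Runtime.Intrinsics.X86;

    static class MoveMaskSketch
    {
        // True when all 16 byte lanes compare equal: CompareEqual yields 0xFF per
        // matching lane, MoveMask gathers the 16 high bits, so all-equal == 0xFFFF.
        public static bool Vector128Equal(Vector128<byte> left, Vector128<byte> right)
            => Sse2.IsSupported && Sse2.MoveMask(Sse2.CompareEqual(left, right)) == 0xFFFF;
    }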