Use CancellationToken.UnsafeRegister in a few more places (dotnet/corefx#37551)
authorStephen Toub <stoub@microsoft.com>
Thu, 16 May 2019 18:08:10 +0000 (11:08 -0700)
committerGitHub <noreply@github.com>
Thu, 16 May 2019 18:08:10 +0000 (11:08 -0700)
CancellationToken.Register captures the current ExecutionContext and uses it to invoke the callback if/when it's invoked.  That's generally desirable and is the right default, but in cases where we know for certain the callback doesn't care about EC (e.g. we're not invoking any 3rd-party code), we can use UnsafeRegister instead (newly added in 3.0), which skips capturing the ExecutionContext, as if Capture returned null.  This helps few a couple of small costs:
- Avoids thread local lookups to capture the current EC.
- Avoids additional delegate invocations and thread local gets/sets to invoke the callback with the captured EC.
- Avoids holding on to the EC in case it's needed, which can potentially keep alive an unbounded amount of state due to AsyncLocals.

Commit migrated from https://github.com/dotnet/corefx/commit/b1a1bfa0997ce35e8540fbbe74276858e315807e

src/libraries/Common/src/System/Net/WebSockets/ManagedWebSocket.cs
src/libraries/System.IO.FileSystem.Watcher/src/System/IO/FileSystemWatcher.Linux.cs
src/libraries/System.IO.FileSystem.Watcher/src/System/IO/FileSystemWatcher.OSX.cs
src/libraries/System.IO.Pipes/src/System/IO/Pipes/PipeCompletionSource.cs
src/libraries/System.IO.Pipes/src/System/IO/Pipes/PipeStream.Unix.cs
src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/ConnectHelper.cs
src/libraries/System.Net.WebSockets.WebSocketProtocol/src/System.Net.WebSockets.WebSocketProtocol.csproj
src/libraries/System.Net.WebSockets.WebSocketProtocol/src/System/Net/WebSockets/ManagedWebSocketExtensions.cs
src/libraries/System.Net.WebSockets.WebSocketProtocol/src/System/Net/WebSockets/ManagedWebSocketExtensions.netstandard.cs [new file with mode: 0644]
src/libraries/System.Threading.Tasks.Parallel/src/System/Threading/Tasks/Parallel.cs

index f422f5e..73343be 100644 (file)
@@ -179,7 +179,9 @@ namespace System.Net.WebSockets
             _receiveBuffer = new byte[ReceiveBufferMinLength];
 
             // Set up the abort source so that if it's triggered, we transition the instance appropriately.
-            _abortSource.Token.Register(s =>
+            // There's no need to store the resulting CancellationTokenRegistration, as this instance owns
+            // the CancellationTokenSource, and the lifetime of that CTS matches the lifetime of the registration.
+            _abortSource.Token.UnsafeRegister(s =>
             {
                 var thisRef = (ManagedWebSocket)s;
 
index 448ee16..62cb2e1 100644 (file)
@@ -526,7 +526,7 @@ namespace System.IO
                 // When cancellation is requested, clear out all watches.  This should force any active or future reads 
                 // on the inotify handle to return 0 bytes read immediately, allowing us to wake up from the blocking call 
                 // and exit the processing loop and clean up.
-                var ctr = _cancellationToken.Register(obj => ((RunningInstance)obj).CancellationCallback(), this);
+                var ctr = _cancellationToken.UnsafeRegister(obj => ((RunningInstance)obj).CancellationCallback(), this);
                 try
                 {
                     // Previous event information
index b608879..d0be36b 100644 (file)
@@ -169,7 +169,7 @@ namespace System.IO
                 _includeChildren = includeChildren;
                 _filterFlags = filter;
                 _cancellationToken = cancelToken;
-                _cancellationToken.Register(obj => ((RunningInstance)obj).CancellationCallback(), this);
+                _cancellationToken.UnsafeRegister(obj => ((RunningInstance)obj).CancellationCallback(), this);
                 _stopping = false;
             }
 
index ab7fd35..3118965 100644 (file)
@@ -70,7 +70,7 @@ namespace System.IO.Pipes
                 if (state == NoResult)
                 {
                     // Register the cancellation
-                    _cancellationRegistration = cancellationToken.Register(thisRef => ((PipeCompletionSource<TResult>)thisRef).Cancel(), this);
+                    _cancellationRegistration = cancellationToken.UnsafeRegister(thisRef => ((PipeCompletionSource<TResult>)thisRef).Cancel(), this);
 
                     // Grab the state for case if IO completed while we were setting the registration.
                     state = Interlocked.Exchange(ref _state, NoResult);
index e1d098e..d516af8 100644 (file)
@@ -202,7 +202,7 @@ namespace System.IO.Pipes
                         if (!t.IsCompletedSuccessfully)
                         {
                             var cancelTcs = new TaskCompletionSource<bool>();
-                            using (cancellationToken.Register(s => ((TaskCompletionSource<bool>)s).TrySetResult(true), cancelTcs))
+                            using (cancellationToken.UnsafeRegister(s => ((TaskCompletionSource<bool>)s).TrySetResult(true), cancelTcs))
                             {
                                 if (t == await Task.WhenAny(t, cancelTcs.Task).ConfigureAwait(false))
                                 {
index b2e6f6b..2f7d99e 100644 (file)
@@ -60,7 +60,7 @@ namespace System.Net.Http
                 if (Socket.ConnectAsync(SocketType.Stream, ProtocolType.Tcp, saea))
                 {
                     // Connect completing asynchronously. Enable it to be canceled and wait for it.
-                    using (cancellationToken.Register(s => Socket.CancelConnectAsync((SocketAsyncEventArgs)s), saea))
+                    using (cancellationToken.UnsafeRegister(s => Socket.CancelConnectAsync((SocketAsyncEventArgs)s), saea))
                     {
                         await saea.Builder.Task.ConfigureAwait(false);
                     }
index fecb415..33d21d3 100644 (file)
       <Link>Common\System\Net\WebSockets\WebSocketValidate.cs</Link>
     </Compile>
     <Compile Include="System\Net\WebSockets\ManagedWebSocket.netstandard.cs" />
+    <Compile Include="System\Net\WebSockets\ManagedWebSocketExtensions.cs" />
     <Compile Include="System\Net\WebSockets\WebSocketProtocol.cs" />
   </ItemGroup>
   <ItemGroup Condition="'$(TargetGroup)'=='netstandard'">
-    <Compile Include="System\Net\WebSockets\ManagedWebSocketExtensions.cs" />
+    <Compile Include="System\Net\WebSockets\ManagedWebSocketExtensions.netstandard.cs" />
     <Reference Include="System.Runtime.CompilerServices.Unsafe" />
   </ItemGroup>
   <ItemGroup>
index 52b9e2f..7eb3997 100644 (file)
@@ -2,87 +2,13 @@
 // The .NET Foundation licenses this file to you under the MIT license.
 // See the LICENSE file in the project root for more information.
 
-using System.Buffers;
-using System.Diagnostics;
-using System.IO;
-using System.Runtime.CompilerServices;
-using System.Runtime.InteropServices;
-using System.Text;
 using System.Threading;
-using System.Threading.Tasks;
 
 namespace System.Net.WebSockets
 {
-    internal static class ManagedWebSocketExtensions
+    internal static partial class ManagedWebSocketExtensions
     {
-        internal static unsafe string GetString(this UTF8Encoding encoding, Span<byte> bytes)
-        {
-            fixed (byte* b = &MemoryMarshal.GetReference(bytes))
-            {
-                return encoding.GetString(b, bytes.Length);
-            }
-        }
-
-        internal static ValueTask<int> ReadAsync(this Stream stream, Memory<byte> destination, CancellationToken cancellationToken = default)
-        {
-            if (MemoryMarshal.TryGetArray(destination, out ArraySegment<byte> array))
-            {
-                return new ValueTask<int>(stream.ReadAsync(array.Array, array.Offset, array.Count, cancellationToken));
-            }
-            else
-            {
-                byte[] buffer = ArrayPool<byte>.Shared.Rent(destination.Length);
-                return new ValueTask<int>(FinishReadAsync(stream.ReadAsync(buffer, 0, destination.Length, cancellationToken), buffer, destination));
-
-                async Task<int> FinishReadAsync(Task<int> readTask, byte[] localBuffer, Memory<byte> localDestination)
-                {
-                    try
-                    {
-                        int result = await readTask.ConfigureAwait(false);
-                        new Span<byte>(localBuffer, 0, result).CopyTo(localDestination.Span);
-                        return result;
-                    }
-                    finally
-                    {
-                        ArrayPool<byte>.Shared.Return(localBuffer);
-                    }
-                }
-            }
-        }
-
-        internal static ValueTask WriteAsync(this Stream stream, ReadOnlyMemory<byte> source, CancellationToken cancellationToken = default)
-        {
-            if (MemoryMarshal.TryGetArray(source, out ArraySegment<byte> array))
-            {
-                return new ValueTask(stream.WriteAsync(array.Array, array.Offset, array.Count, cancellationToken));
-            }
-            else
-            {
-                byte[] buffer = ArrayPool<byte>.Shared.Rent(source.Length);
-                source.Span.CopyTo(buffer);
-                return new ValueTask(FinishWriteAsync(stream.WriteAsync(buffer, 0, source.Length, cancellationToken), buffer));
-
-                async Task FinishWriteAsync(Task writeTask, byte[] localBuffer)
-                {
-                    try
-                    {
-                        await writeTask.ConfigureAwait(false);
-                    }
-                    finally
-                    {
-                        ArrayPool<byte>.Shared.Return(localBuffer);
-                    }
-                }
-            }
-        }
-    }
-
-    internal static class BitConverter
-    {
-        internal static unsafe int ToInt32(ReadOnlySpan<byte> value)
-        {
-            Debug.Assert(value.Length >= sizeof(int));
-            return Unsafe.ReadUnaligned<int>(ref MemoryMarshal.GetReference(value));
-        }
+        internal static CancellationTokenRegistration UnsafeRegister(this CancellationToken cancellationToken, Action<object> callback, object state) =>
+            cancellationToken.Register(callback, state);
     }
 }
diff --git a/src/libraries/System.Net.WebSockets.WebSocketProtocol/src/System/Net/WebSockets/ManagedWebSocketExtensions.netstandard.cs b/src/libraries/System.Net.WebSockets.WebSocketProtocol/src/System/Net/WebSockets/ManagedWebSocketExtensions.netstandard.cs
new file mode 100644 (file)
index 0000000..89e2326
--- /dev/null
@@ -0,0 +1,88 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System.Buffers;
+using System.Diagnostics;
+using System.IO;
+using System.Runtime.CompilerServices;
+using System.Runtime.InteropServices;
+using System.Text;
+using System.Threading;
+using System.Threading.Tasks;
+
+namespace System.Net.WebSockets
+{
+    internal static partial class ManagedWebSocketExtensions
+    {
+        internal static unsafe string GetString(this UTF8Encoding encoding, Span<byte> bytes)
+        {
+            fixed (byte* b = &MemoryMarshal.GetReference(bytes))
+            {
+                return encoding.GetString(b, bytes.Length);
+            }
+        }
+
+        internal static ValueTask<int> ReadAsync(this Stream stream, Memory<byte> destination, CancellationToken cancellationToken = default)
+        {
+            if (MemoryMarshal.TryGetArray(destination, out ArraySegment<byte> array))
+            {
+                return new ValueTask<int>(stream.ReadAsync(array.Array, array.Offset, array.Count, cancellationToken));
+            }
+            else
+            {
+                byte[] buffer = ArrayPool<byte>.Shared.Rent(destination.Length);
+                return new ValueTask<int>(FinishReadAsync(stream.ReadAsync(buffer, 0, destination.Length, cancellationToken), buffer, destination));
+
+                async Task<int> FinishReadAsync(Task<int> readTask, byte[] localBuffer, Memory<byte> localDestination)
+                {
+                    try
+                    {
+                        int result = await readTask.ConfigureAwait(false);
+                        new Span<byte>(localBuffer, 0, result).CopyTo(localDestination.Span);
+                        return result;
+                    }
+                    finally
+                    {
+                        ArrayPool<byte>.Shared.Return(localBuffer);
+                    }
+                }
+            }
+        }
+
+        internal static ValueTask WriteAsync(this Stream stream, ReadOnlyMemory<byte> source, CancellationToken cancellationToken = default)
+        {
+            if (MemoryMarshal.TryGetArray(source, out ArraySegment<byte> array))
+            {
+                return new ValueTask(stream.WriteAsync(array.Array, array.Offset, array.Count, cancellationToken));
+            }
+            else
+            {
+                byte[] buffer = ArrayPool<byte>.Shared.Rent(source.Length);
+                source.Span.CopyTo(buffer);
+                return new ValueTask(FinishWriteAsync(stream.WriteAsync(buffer, 0, source.Length, cancellationToken), buffer));
+
+                async Task FinishWriteAsync(Task writeTask, byte[] localBuffer)
+                {
+                    try
+                    {
+                        await writeTask.ConfigureAwait(false);
+                    }
+                    finally
+                    {
+                        ArrayPool<byte>.Shared.Return(localBuffer);
+                    }
+                }
+            }
+        }
+    }
+
+    internal static class BitConverter
+    {
+        internal static unsafe int ToInt32(ReadOnlySpan<byte> value)
+        {
+            Debug.Assert(value.Length >= sizeof(int));
+            return Unsafe.ReadUnaligned<int>(ref MemoryMarshal.GetReference(value));
+        }
+    }
+}
index 3dc264a..2484140 100644 (file)
@@ -1060,13 +1060,13 @@ namespace System.Threading.Tasks
             // if cancellation is enabled, we need to register a callback to stop the loop when it gets signaled
             CancellationTokenRegistration ctr = (!parallelOptions.CancellationToken.CanBeCanceled)
                             ? default(CancellationTokenRegistration)
-                            : parallelOptions.CancellationToken.Register((o) =>
+                            : parallelOptions.CancellationToken.UnsafeRegister((o) =>
                             {
                                 // Record our cancellation before stopping processing
                                 oce = new OperationCanceledException(parallelOptions.CancellationToken);
                                 // Cause processing to stop
                                 sharedPStateFlags.Cancel();
-                            }, state: null, useSynchronizationContext: false);
+                            }, state: null);
 
             // ETW event for Parallel For begin
             int forkJoinContextID = 0;
@@ -1322,13 +1322,13 @@ namespace System.Threading.Tasks
             // if cancellation is enabled, we need to register a callback to stop the loop when it gets signaled
             CancellationTokenRegistration ctr = (!parallelOptions.CancellationToken.CanBeCanceled)
                             ? default(CancellationTokenRegistration)
-                            : parallelOptions.CancellationToken.Register((o) =>
+                            : parallelOptions.CancellationToken.UnsafeRegister((o) =>
                             {
                                 // Record our cancellation before stopping processing
                                 oce = new OperationCanceledException(parallelOptions.CancellationToken);
                                 // Cause processing to stop
                                 sharedPStateFlags.Cancel();
-                            }, state: null, useSynchronizationContext: false);
+                            }, state: null);
 
             // ETW event for Parallel For begin
             int forkJoinContextID = 0;
@@ -3121,13 +3121,13 @@ namespace System.Threading.Tasks
             // if cancellation is enabled, we need to register a callback to stop the loop when it gets signaled
             CancellationTokenRegistration ctr = (!parallelOptions.CancellationToken.CanBeCanceled)
                             ? default(CancellationTokenRegistration)
-                            : parallelOptions.CancellationToken.Register((o) =>
+                            : parallelOptions.CancellationToken.UnsafeRegister((o) =>
                             {
                                 // Record our cancellation before stopping processing
                                 oce = new OperationCanceledException(parallelOptions.CancellationToken);
                                 // Cause processing to stop
                                 sharedPStateFlags.Cancel();
-                            }, state: null, useSynchronizationContext: false);
+                            }, state: null);
 
             // Get our dynamic partitioner -- depends on whether source is castable to OrderablePartitioner
             // Also, do some error checking.