From 377c37c140b3fbedeb2ed1e1694783f65cb71733 Mon Sep 17 00:00:00 2001 From: Stephen Toub Date: Thu, 16 May 2019 11:08:10 -0700 Subject: [PATCH] Use CancellationToken.UnsafeRegister in a few more places (dotnet/corefx#37551) CancellationToken.Register captures the current ExecutionContext and uses it to invoke the callback if/when it's invoked. That's generally desirable and is the right default, but in cases where we know for certain the callback doesn't care about EC (e.g. we're not invoking any 3rd-party code), we can use UnsafeRegister instead (newly added in 3.0), which skips capturing the ExecutionContext, as if Capture returned null. This helps few a couple of small costs: - Avoids thread local lookups to capture the current EC. - Avoids additional delegate invocations and thread local gets/sets to invoke the callback with the captured EC. - Avoids holding on to the EC in case it's needed, which can potentially keep alive an unbounded amount of state due to AsyncLocals. Commit migrated from https://github.com/dotnet/corefx/commit/b1a1bfa0997ce35e8540fbbe74276858e315807e --- .../src/System/Net/WebSockets/ManagedWebSocket.cs | 4 +- .../src/System/IO/FileSystemWatcher.Linux.cs | 2 +- .../src/System/IO/FileSystemWatcher.OSX.cs | 2 +- .../src/System/IO/Pipes/PipeCompletionSource.cs | 2 +- .../src/System/IO/Pipes/PipeStream.Unix.cs | 2 +- .../Net/Http/SocketsHttpHandler/ConnectHelper.cs | 2 +- .../System.Net.WebSockets.WebSocketProtocol.csproj | 3 +- .../Net/WebSockets/ManagedWebSocketExtensions.cs | 80 +------------------- .../ManagedWebSocketExtensions.netstandard.cs | 88 ++++++++++++++++++++++ .../src/System/Threading/Tasks/Parallel.cs | 12 +-- 10 files changed, 107 insertions(+), 90 deletions(-) create mode 100644 src/libraries/System.Net.WebSockets.WebSocketProtocol/src/System/Net/WebSockets/ManagedWebSocketExtensions.netstandard.cs diff --git a/src/libraries/Common/src/System/Net/WebSockets/ManagedWebSocket.cs b/src/libraries/Common/src/System/Net/WebSockets/ManagedWebSocket.cs index f422f5e..73343be 100644 --- a/src/libraries/Common/src/System/Net/WebSockets/ManagedWebSocket.cs +++ b/src/libraries/Common/src/System/Net/WebSockets/ManagedWebSocket.cs @@ -179,7 +179,9 @@ namespace System.Net.WebSockets _receiveBuffer = new byte[ReceiveBufferMinLength]; // Set up the abort source so that if it's triggered, we transition the instance appropriately. - _abortSource.Token.Register(s => + // There's no need to store the resulting CancellationTokenRegistration, as this instance owns + // the CancellationTokenSource, and the lifetime of that CTS matches the lifetime of the registration. + _abortSource.Token.UnsafeRegister(s => { var thisRef = (ManagedWebSocket)s; diff --git a/src/libraries/System.IO.FileSystem.Watcher/src/System/IO/FileSystemWatcher.Linux.cs b/src/libraries/System.IO.FileSystem.Watcher/src/System/IO/FileSystemWatcher.Linux.cs index 448ee16..62cb2e1 100644 --- a/src/libraries/System.IO.FileSystem.Watcher/src/System/IO/FileSystemWatcher.Linux.cs +++ b/src/libraries/System.IO.FileSystem.Watcher/src/System/IO/FileSystemWatcher.Linux.cs @@ -526,7 +526,7 @@ namespace System.IO // When cancellation is requested, clear out all watches. This should force any active or future reads // on the inotify handle to return 0 bytes read immediately, allowing us to wake up from the blocking call // and exit the processing loop and clean up. - var ctr = _cancellationToken.Register(obj => ((RunningInstance)obj).CancellationCallback(), this); + var ctr = _cancellationToken.UnsafeRegister(obj => ((RunningInstance)obj).CancellationCallback(), this); try { // Previous event information diff --git a/src/libraries/System.IO.FileSystem.Watcher/src/System/IO/FileSystemWatcher.OSX.cs b/src/libraries/System.IO.FileSystem.Watcher/src/System/IO/FileSystemWatcher.OSX.cs index b608879..d0be36b 100644 --- a/src/libraries/System.IO.FileSystem.Watcher/src/System/IO/FileSystemWatcher.OSX.cs +++ b/src/libraries/System.IO.FileSystem.Watcher/src/System/IO/FileSystemWatcher.OSX.cs @@ -169,7 +169,7 @@ namespace System.IO _includeChildren = includeChildren; _filterFlags = filter; _cancellationToken = cancelToken; - _cancellationToken.Register(obj => ((RunningInstance)obj).CancellationCallback(), this); + _cancellationToken.UnsafeRegister(obj => ((RunningInstance)obj).CancellationCallback(), this); _stopping = false; } diff --git a/src/libraries/System.IO.Pipes/src/System/IO/Pipes/PipeCompletionSource.cs b/src/libraries/System.IO.Pipes/src/System/IO/Pipes/PipeCompletionSource.cs index ab7fd35..3118965 100644 --- a/src/libraries/System.IO.Pipes/src/System/IO/Pipes/PipeCompletionSource.cs +++ b/src/libraries/System.IO.Pipes/src/System/IO/Pipes/PipeCompletionSource.cs @@ -70,7 +70,7 @@ namespace System.IO.Pipes if (state == NoResult) { // Register the cancellation - _cancellationRegistration = cancellationToken.Register(thisRef => ((PipeCompletionSource)thisRef).Cancel(), this); + _cancellationRegistration = cancellationToken.UnsafeRegister(thisRef => ((PipeCompletionSource)thisRef).Cancel(), this); // Grab the state for case if IO completed while we were setting the registration. state = Interlocked.Exchange(ref _state, NoResult); diff --git a/src/libraries/System.IO.Pipes/src/System/IO/Pipes/PipeStream.Unix.cs b/src/libraries/System.IO.Pipes/src/System/IO/Pipes/PipeStream.Unix.cs index e1d098e..d516af8 100644 --- a/src/libraries/System.IO.Pipes/src/System/IO/Pipes/PipeStream.Unix.cs +++ b/src/libraries/System.IO.Pipes/src/System/IO/Pipes/PipeStream.Unix.cs @@ -202,7 +202,7 @@ namespace System.IO.Pipes if (!t.IsCompletedSuccessfully) { var cancelTcs = new TaskCompletionSource(); - using (cancellationToken.Register(s => ((TaskCompletionSource)s).TrySetResult(true), cancelTcs)) + using (cancellationToken.UnsafeRegister(s => ((TaskCompletionSource)s).TrySetResult(true), cancelTcs)) { if (t == await Task.WhenAny(t, cancelTcs.Task).ConfigureAwait(false)) { diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/ConnectHelper.cs b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/ConnectHelper.cs index b2e6f6b..2f7d99e 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/ConnectHelper.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/ConnectHelper.cs @@ -60,7 +60,7 @@ namespace System.Net.Http if (Socket.ConnectAsync(SocketType.Stream, ProtocolType.Tcp, saea)) { // Connect completing asynchronously. Enable it to be canceled and wait for it. - using (cancellationToken.Register(s => Socket.CancelConnectAsync((SocketAsyncEventArgs)s), saea)) + using (cancellationToken.UnsafeRegister(s => Socket.CancelConnectAsync((SocketAsyncEventArgs)s), saea)) { await saea.Builder.Task.ConfigureAwait(false); } diff --git a/src/libraries/System.Net.WebSockets.WebSocketProtocol/src/System.Net.WebSockets.WebSocketProtocol.csproj b/src/libraries/System.Net.WebSockets.WebSocketProtocol/src/System.Net.WebSockets.WebSocketProtocol.csproj index fecb415..33d21d3 100644 --- a/src/libraries/System.Net.WebSockets.WebSocketProtocol/src/System.Net.WebSockets.WebSocketProtocol.csproj +++ b/src/libraries/System.Net.WebSockets.WebSocketProtocol/src/System.Net.WebSockets.WebSocketProtocol.csproj @@ -14,10 +14,11 @@ Common\System\Net\WebSockets\WebSocketValidate.cs + - + diff --git a/src/libraries/System.Net.WebSockets.WebSocketProtocol/src/System/Net/WebSockets/ManagedWebSocketExtensions.cs b/src/libraries/System.Net.WebSockets.WebSocketProtocol/src/System/Net/WebSockets/ManagedWebSocketExtensions.cs index 52b9e2f..7eb3997 100644 --- a/src/libraries/System.Net.WebSockets.WebSocketProtocol/src/System/Net/WebSockets/ManagedWebSocketExtensions.cs +++ b/src/libraries/System.Net.WebSockets.WebSocketProtocol/src/System/Net/WebSockets/ManagedWebSocketExtensions.cs @@ -2,87 +2,13 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using System.Buffers; -using System.Diagnostics; -using System.IO; -using System.Runtime.CompilerServices; -using System.Runtime.InteropServices; -using System.Text; using System.Threading; -using System.Threading.Tasks; namespace System.Net.WebSockets { - internal static class ManagedWebSocketExtensions + internal static partial class ManagedWebSocketExtensions { - internal static unsafe string GetString(this UTF8Encoding encoding, Span bytes) - { - fixed (byte* b = &MemoryMarshal.GetReference(bytes)) - { - return encoding.GetString(b, bytes.Length); - } - } - - internal static ValueTask ReadAsync(this Stream stream, Memory destination, CancellationToken cancellationToken = default) - { - if (MemoryMarshal.TryGetArray(destination, out ArraySegment array)) - { - return new ValueTask(stream.ReadAsync(array.Array, array.Offset, array.Count, cancellationToken)); - } - else - { - byte[] buffer = ArrayPool.Shared.Rent(destination.Length); - return new ValueTask(FinishReadAsync(stream.ReadAsync(buffer, 0, destination.Length, cancellationToken), buffer, destination)); - - async Task FinishReadAsync(Task readTask, byte[] localBuffer, Memory localDestination) - { - try - { - int result = await readTask.ConfigureAwait(false); - new Span(localBuffer, 0, result).CopyTo(localDestination.Span); - return result; - } - finally - { - ArrayPool.Shared.Return(localBuffer); - } - } - } - } - - internal static ValueTask WriteAsync(this Stream stream, ReadOnlyMemory source, CancellationToken cancellationToken = default) - { - if (MemoryMarshal.TryGetArray(source, out ArraySegment array)) - { - return new ValueTask(stream.WriteAsync(array.Array, array.Offset, array.Count, cancellationToken)); - } - else - { - byte[] buffer = ArrayPool.Shared.Rent(source.Length); - source.Span.CopyTo(buffer); - return new ValueTask(FinishWriteAsync(stream.WriteAsync(buffer, 0, source.Length, cancellationToken), buffer)); - - async Task FinishWriteAsync(Task writeTask, byte[] localBuffer) - { - try - { - await writeTask.ConfigureAwait(false); - } - finally - { - ArrayPool.Shared.Return(localBuffer); - } - } - } - } - } - - internal static class BitConverter - { - internal static unsafe int ToInt32(ReadOnlySpan value) - { - Debug.Assert(value.Length >= sizeof(int)); - return Unsafe.ReadUnaligned(ref MemoryMarshal.GetReference(value)); - } + internal static CancellationTokenRegistration UnsafeRegister(this CancellationToken cancellationToken, Action callback, object state) => + cancellationToken.Register(callback, state); } } diff --git a/src/libraries/System.Net.WebSockets.WebSocketProtocol/src/System/Net/WebSockets/ManagedWebSocketExtensions.netstandard.cs b/src/libraries/System.Net.WebSockets.WebSocketProtocol/src/System/Net/WebSockets/ManagedWebSocketExtensions.netstandard.cs new file mode 100644 index 0000000..89e2326 --- /dev/null +++ b/src/libraries/System.Net.WebSockets.WebSocketProtocol/src/System/Net/WebSockets/ManagedWebSocketExtensions.netstandard.cs @@ -0,0 +1,88 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.Buffers; +using System.Diagnostics; +using System.IO; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; +using System.Text; +using System.Threading; +using System.Threading.Tasks; + +namespace System.Net.WebSockets +{ + internal static partial class ManagedWebSocketExtensions + { + internal static unsafe string GetString(this UTF8Encoding encoding, Span bytes) + { + fixed (byte* b = &MemoryMarshal.GetReference(bytes)) + { + return encoding.GetString(b, bytes.Length); + } + } + + internal static ValueTask ReadAsync(this Stream stream, Memory destination, CancellationToken cancellationToken = default) + { + if (MemoryMarshal.TryGetArray(destination, out ArraySegment array)) + { + return new ValueTask(stream.ReadAsync(array.Array, array.Offset, array.Count, cancellationToken)); + } + else + { + byte[] buffer = ArrayPool.Shared.Rent(destination.Length); + return new ValueTask(FinishReadAsync(stream.ReadAsync(buffer, 0, destination.Length, cancellationToken), buffer, destination)); + + async Task FinishReadAsync(Task readTask, byte[] localBuffer, Memory localDestination) + { + try + { + int result = await readTask.ConfigureAwait(false); + new Span(localBuffer, 0, result).CopyTo(localDestination.Span); + return result; + } + finally + { + ArrayPool.Shared.Return(localBuffer); + } + } + } + } + + internal static ValueTask WriteAsync(this Stream stream, ReadOnlyMemory source, CancellationToken cancellationToken = default) + { + if (MemoryMarshal.TryGetArray(source, out ArraySegment array)) + { + return new ValueTask(stream.WriteAsync(array.Array, array.Offset, array.Count, cancellationToken)); + } + else + { + byte[] buffer = ArrayPool.Shared.Rent(source.Length); + source.Span.CopyTo(buffer); + return new ValueTask(FinishWriteAsync(stream.WriteAsync(buffer, 0, source.Length, cancellationToken), buffer)); + + async Task FinishWriteAsync(Task writeTask, byte[] localBuffer) + { + try + { + await writeTask.ConfigureAwait(false); + } + finally + { + ArrayPool.Shared.Return(localBuffer); + } + } + } + } + } + + internal static class BitConverter + { + internal static unsafe int ToInt32(ReadOnlySpan value) + { + Debug.Assert(value.Length >= sizeof(int)); + return Unsafe.ReadUnaligned(ref MemoryMarshal.GetReference(value)); + } + } +} diff --git a/src/libraries/System.Threading.Tasks.Parallel/src/System/Threading/Tasks/Parallel.cs b/src/libraries/System.Threading.Tasks.Parallel/src/System/Threading/Tasks/Parallel.cs index 3dc264a..2484140 100644 --- a/src/libraries/System.Threading.Tasks.Parallel/src/System/Threading/Tasks/Parallel.cs +++ b/src/libraries/System.Threading.Tasks.Parallel/src/System/Threading/Tasks/Parallel.cs @@ -1060,13 +1060,13 @@ namespace System.Threading.Tasks // if cancellation is enabled, we need to register a callback to stop the loop when it gets signaled CancellationTokenRegistration ctr = (!parallelOptions.CancellationToken.CanBeCanceled) ? default(CancellationTokenRegistration) - : parallelOptions.CancellationToken.Register((o) => + : parallelOptions.CancellationToken.UnsafeRegister((o) => { // Record our cancellation before stopping processing oce = new OperationCanceledException(parallelOptions.CancellationToken); // Cause processing to stop sharedPStateFlags.Cancel(); - }, state: null, useSynchronizationContext: false); + }, state: null); // ETW event for Parallel For begin int forkJoinContextID = 0; @@ -1322,13 +1322,13 @@ namespace System.Threading.Tasks // if cancellation is enabled, we need to register a callback to stop the loop when it gets signaled CancellationTokenRegistration ctr = (!parallelOptions.CancellationToken.CanBeCanceled) ? default(CancellationTokenRegistration) - : parallelOptions.CancellationToken.Register((o) => + : parallelOptions.CancellationToken.UnsafeRegister((o) => { // Record our cancellation before stopping processing oce = new OperationCanceledException(parallelOptions.CancellationToken); // Cause processing to stop sharedPStateFlags.Cancel(); - }, state: null, useSynchronizationContext: false); + }, state: null); // ETW event for Parallel For begin int forkJoinContextID = 0; @@ -3121,13 +3121,13 @@ namespace System.Threading.Tasks // if cancellation is enabled, we need to register a callback to stop the loop when it gets signaled CancellationTokenRegistration ctr = (!parallelOptions.CancellationToken.CanBeCanceled) ? default(CancellationTokenRegistration) - : parallelOptions.CancellationToken.Register((o) => + : parallelOptions.CancellationToken.UnsafeRegister((o) => { // Record our cancellation before stopping processing oce = new OperationCanceledException(parallelOptions.CancellationToken); // Cause processing to stop sharedPStateFlags.Cancel(); - }, state: null, useSynchronizationContext: false); + }, state: null); // Get our dynamic partitioner -- depends on whether source is castable to OrderablePartitioner // Also, do some error checking. -- 2.7.4