From 12025658acf6c9747fcbb63569be4107b3e0f0f3 Mon Sep 17 00:00:00 2001 From: Stephen Toub Date: Thu, 8 Oct 2020 21:57:38 -0400 Subject: [PATCH] Add CancellationToken.Register with callback accepting token (#43114) * Add CancellationToken.Register with callback accepting token * Address PR feedback --- .../Tasks/TaskCompletionSourceWithCancellation.cs | 10 +----- .../src/System/Diagnostics/Process.cs | 9 +++-- .../Net/Http/SocketsHttpHandler/CreditWaiter.cs | 5 ++- .../Net/Http/SocketsHttpHandler/Http2Connection.cs | 13 +++---- .../Net/Http/SocketsHttpHandler/Http2Stream.cs | 12 ++----- .../src/System/Threading/CancellationToken.cs | 27 +++++++++++++- .../System/Threading/CancellationTokenSource.cs | 42 ++++++++++++++-------- .../src/System/Threading/Tasks/Task.cs | 22 ++++++------ src/libraries/System.Runtime/ref/System.Runtime.cs | 2 ++ .../tests/CancellationTokenTests.cs | 38 ++++++++++++++++---- 10 files changed, 113 insertions(+), 67 deletions(-) diff --git a/src/libraries/Common/src/System/Threading/Tasks/TaskCompletionSourceWithCancellation.cs b/src/libraries/Common/src/System/Threading/Tasks/TaskCompletionSourceWithCancellation.cs index 4c1e370..436333f 100644 --- a/src/libraries/Common/src/System/Threading/Tasks/TaskCompletionSourceWithCancellation.cs +++ b/src/libraries/Common/src/System/Threading/Tasks/TaskCompletionSourceWithCancellation.cs @@ -11,21 +11,13 @@ namespace System.Threading.Tasks /// internal class TaskCompletionSourceWithCancellation : TaskCompletionSource { - private CancellationToken _cancellationToken; - public TaskCompletionSourceWithCancellation() : base(TaskCreationOptions.RunContinuationsAsynchronously) { } - private void OnCancellation() - { - TrySetCanceled(_cancellationToken); - } - public async ValueTask WaitWithCancellationAsync(CancellationToken cancellationToken) { - _cancellationToken = cancellationToken; - using (cancellationToken.UnsafeRegister(static s => ((TaskCompletionSourceWithCancellation)s!).OnCancellation(), this)) + using (cancellationToken.UnsafeRegister(static (s, cancellationToken) => ((TaskCompletionSourceWithCancellation)s!).TrySetCanceled(cancellationToken), this)) { return await Task.ConfigureAwait(false); } diff --git a/src/libraries/System.Diagnostics.Process/src/System/Diagnostics/Process.cs b/src/libraries/System.Diagnostics.Process/src/System/Diagnostics/Process.cs index 99184a2..ab5e7be 100644 --- a/src/libraries/System.Diagnostics.Process/src/System/Diagnostics/Process.cs +++ b/src/libraries/System.Diagnostics.Process/src/System/Diagnostics/Process.cs @@ -1459,9 +1459,9 @@ namespace System.Diagnostics throw; } - var tcs = new TaskCompletionSourceWithCancellation(); + var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); - EventHandler handler = (_, _) => tcs.TrySetResult(true); + EventHandler handler = (_, _) => tcs.TrySetResult(); Exited += handler; try @@ -1473,7 +1473,10 @@ namespace System.Diagnostics else { // CASE 1.1 & CASE 3.1: Process exits or is canceled here - await tcs.WaitWithCancellationAsync(cancellationToken).ConfigureAwait(false); + using (cancellationToken.UnsafeRegister(static (s, cancellationToken) => ((TaskCompletionSource)s!).TrySetCanceled(cancellationToken), tcs)) + { + await tcs.Task.ConfigureAwait(false); + } } // Wait until output streams have been drained diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/CreditWaiter.cs b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/CreditWaiter.cs index 43a8d86..c6052fd 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/CreditWaiter.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/CreditWaiter.cs @@ -51,11 +51,10 @@ namespace System.Net.Http private void RegisterCancellation(CancellationToken cancellationToken) { _cancellationToken = cancellationToken; - _registration = cancellationToken.UnsafeRegister(static s => + _registration = cancellationToken.UnsafeRegister(static (s, cancellationToken) => { // The callback will only fire if cancellation owns the right to complete the instance. - var thisRef = (CreditWaiter)s!; - thisRef._source.SetException(ExceptionDispatchInfo.SetCurrentStackTrace(new OperationCanceledException(thisRef._cancellationToken))); + ((CreditWaiter)s!)._source.SetException(ExceptionDispatchInfo.SetCurrentStackTrace(new OperationCanceledException(cancellationToken))); }, this); } diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http2Connection.cs b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http2Connection.cs index be457ef..e966271 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http2Connection.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http2Connection.cs @@ -844,7 +844,6 @@ namespace System.Net.Http private abstract class WriteQueueEntry : TaskCompletionSource { - private readonly CancellationToken _cancellationToken; private readonly CancellationTokenRegistration _cancellationRegistration; public WriteQueueEntry(int writeBytes, CancellationToken cancellationToken) @@ -852,17 +851,15 @@ namespace System.Net.Http { WriteBytes = writeBytes; - _cancellationToken = cancellationToken; - _cancellationRegistration = cancellationToken.UnsafeRegister(static s => ((WriteQueueEntry)s!).OnCancellation(), this); + _cancellationRegistration = cancellationToken.UnsafeRegister(static (s, cancellationToken) => + { + bool canceled = ((WriteQueueEntry)s!).TrySetCanceled(cancellationToken); + Debug.Assert(canceled, "Callback should have been unregistered if the operation was completing successfully."); + }, this); } public int WriteBytes { get; } - private void OnCancellation() - { - SetCanceled(_cancellationToken); - } - public bool TryDisableCancellation() { _cancellationRegistration.Dispose(); diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http2Stream.cs b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http2Stream.cs index 36ed2d9..4b2d169 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http2Stream.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http2Stream.cs @@ -1326,7 +1326,7 @@ namespace System.Net.Http // However, this could still be non-cancelable if HttpMessageInvoker was used, at which point this will only be // cancelable if the caller's token was cancelable. - _waitSourceCancellation = cancellationToken.UnsafeRegister(static s => + _waitSourceCancellation = cancellationToken.UnsafeRegister(static (s, cancellationToken) => { var thisRef = (Http2Stream)s!; @@ -1342,18 +1342,10 @@ namespace System.Net.Http { // Wake up the wait. It will then immediately check whether cancellation was requested and throw if it was. thisRef._waitSource.SetException(ExceptionDispatchInfo.SetCurrentStackTrace( - CancellationHelper.CreateOperationCanceledException(null, thisRef._waitSourceCancellation.Token))); + CancellationHelper.CreateOperationCanceledException(null, cancellationToken))); } }, this); - // There's a race condition in UnsafeRegister above. If cancellation is requested prior to UnsafeRegister, - // the delegate may be invoked synchronously as part of the UnsafeRegister call. In that case, it will execute - // before _waitSourceCancellation has been set, which means UnsafeRegister will have set a cancellation - // exception into the wait source with a default token rather than the ideal one. To handle that, - // we check for cancellation again, and throw here with the right token. Worst case, if cancellation is - // requested prior to here, we end up allocating an extra OCE object. - CancellationHelper.ThrowIfCancellationRequested(cancellationToken); - return new ValueTask(this, _waitSource.Version); } diff --git a/src/libraries/System.Private.CoreLib/src/System/Threading/CancellationToken.cs b/src/libraries/System.Private.CoreLib/src/System/Threading/CancellationToken.cs index c7961ac..18d7194 100644 --- a/src/libraries/System.Private.CoreLib/src/System/Threading/CancellationToken.cs +++ b/src/libraries/System.Private.CoreLib/src/System/Threading/CancellationToken.cs @@ -195,6 +195,19 @@ namespace System.Threading public CancellationTokenRegistration Register(Action callback, object? state) => Register(callback, state, useSynchronizationContext: false, useExecutionContext: true); + /// Registers a delegate that will be called when this CancellationToken is canceled. + /// + /// If this token is already in the canceled state, the delegate will be run immediately and synchronously. Any exception the delegate + /// generates will be propagated out of this method call. The current ExecutionContext, if one exists, + /// will be captured along with the delegate and will be used when executing it. The current is not captured. + /// + /// The delegate to be executed when the CancellationToken is canceled. + /// The state to pass to the when the delegate is invoked. This may be null. + /// The instance that can be used to unregister the callback. + /// is null. + public CancellationTokenRegistration Register(Action callback, object? state) => + Register(callback, state, useSynchronizationContext: false, useExecutionContext: true); + /// /// Registers a delegate that will be called when this /// CancellationToken is canceled. @@ -245,6 +258,18 @@ namespace System.Threading public CancellationTokenRegistration UnsafeRegister(Action callback, object? state) => Register(callback, state, useSynchronizationContext: false, useExecutionContext: false); + /// Registers a delegate that will be called when this CancellationToken is canceled. + /// + /// If this token is already in the canceled state, the delegate will be run immediately and synchronously. Any exception the delegate + /// generates will be propagated out of this method call. is not captured nor flowed to the callback's invocation. + /// + /// The delegate to be executed when the CancellationToken is canceled. + /// The state to pass to the when the delegate is invoked. This may be null. + /// The instance that can be used to unregister the callback. + /// is null. + public CancellationTokenRegistration UnsafeRegister(Action callback, object? state) => + Register(callback, state, useSynchronizationContext: false, useExecutionContext: false); + /// /// Registers a delegate that will be called when this /// CancellationToken is canceled. @@ -267,7 +292,7 @@ namespace System.Threading /// is null. /// The associated CancellationTokenSource has been disposed. - private CancellationTokenRegistration Register(Action callback, object? state, bool useSynchronizationContext, bool useExecutionContext) + private CancellationTokenRegistration Register(Delegate callback, object? state, bool useSynchronizationContext, bool useExecutionContext) { if (callback == null) throw new ArgumentNullException(nameof(callback)); diff --git a/src/libraries/System.Private.CoreLib/src/System/Threading/CancellationTokenSource.cs b/src/libraries/System.Private.CoreLib/src/System/Threading/CancellationTokenSource.cs index d439fd3..c83625f 100644 --- a/src/libraries/System.Private.CoreLib/src/System/Threading/CancellationTokenSource.cs +++ b/src/libraries/System.Private.CoreLib/src/System/Threading/CancellationTokenSource.cs @@ -475,9 +475,10 @@ namespace System.Threading /// callback will have been run by the time this method returns. /// internal CancellationTokenRegistration InternalRegister( - Action callback, object? stateForCallback, SynchronizationContext? syncContext, ExecutionContext? executionContext) + Delegate callback, object? stateForCallback, SynchronizationContext? syncContext, ExecutionContext? executionContext) { Debug.Assert(this != s_neverCanceledSource, "This source should never be exposed via a CancellationToken."); + Debug.Assert(callback is Action || callback is Action); // If not canceled, register the handler; if canceled already, run the callback synchronously. // This also ensures that during ExecuteCallbackHandlers() there will be no mutation of the _callbackPartitions. @@ -571,7 +572,7 @@ namespace System.Threading } // Cancellation already occurred. Run the callback on this thread and return an empty registration. - callback(stateForCallback); + Invoke(callback, stateForCallback, this); return default; } @@ -1012,7 +1013,7 @@ namespace System.Threading public CallbackNode? Next; public long Id; - public Action? Callback; + public Delegate? Callback; // Action or Action public object? CallbackState; public ExecutionContext? ExecutionContext; public SynchronizationContext? SynchronizationContext; @@ -1026,23 +1027,36 @@ namespace System.Threading public void ExecuteCallback() { ExecutionContext? context = ExecutionContext; - if (context != null) + if (context is null) { - ExecutionContext.RunInternal(context, static s => - { - Debug.Assert(s is CallbackNode, $"Expected {typeof(CallbackNode)}, got {s}"); - CallbackNode n = (CallbackNode)s; - - Debug.Assert(n.Callback != null); - n.Callback(n.CallbackState); - }, this); + Debug.Assert(Callback != null); + Invoke(Callback, CallbackState, Partition.Source); } else { - Debug.Assert(Callback != null); - Callback(CallbackState); + ExecutionContext.RunInternal(context, static s => + { + var node = (CallbackNode)s!; + Debug.Assert(node.Callback != null); + Invoke(node.Callback, node.CallbackState, node.Partition.Source); + }, this); + } } } + + private static void Invoke(Delegate d, object? state, CancellationTokenSource source) + { + Debug.Assert(d is Action || d is Action); + + if (d is Action actionWithState) + { + actionWithState(state); + } + else + { + ((Action)d)(state, new CancellationToken(source)); + } + } } } diff --git a/src/libraries/System.Private.CoreLib/src/System/Threading/Tasks/Task.cs b/src/libraries/System.Private.CoreLib/src/System/Threading/Tasks/Task.cs index 13b4914..b82412d 100644 --- a/src/libraries/System.Private.CoreLib/src/System/Threading/Tasks/Task.cs +++ b/src/libraries/System.Private.CoreLib/src/System/Threading/Tasks/Task.cs @@ -5412,25 +5412,23 @@ namespace System.Threading.Tasks /// DelayPromise that also supports cancellation. private sealed class DelayPromiseWithCancellation : DelayPromise { - private readonly CancellationToken _token; private readonly CancellationTokenRegistration _registration; internal DelayPromiseWithCancellation(int millisecondsDelay, CancellationToken token) : base(millisecondsDelay) { Debug.Assert(token.CanBeCanceled); - _token = token; - _registration = token.UnsafeRegister(static state => ((DelayPromiseWithCancellation)state!).CompleteCanceled(), this); - } - - private void CompleteCanceled() - { - if (TrySetCanceled(_token)) + _registration = token.UnsafeRegister(static (state, cancellationToken) => { - Cleanup(); - // This path doesn't invoke RemoveFromActiveTasks or TraceOperationCompletion - // because that's strangely already handled inside of TrySetCanceled. - } + var thisRef = (DelayPromiseWithCancellation)state!; + if (thisRef.TrySetCanceled(cancellationToken)) + { + thisRef.Cleanup(); + // This path doesn't invoke RemoveFromActiveTasks or TraceOperationCompletion + // because that's strangely already handled inside of TrySetCanceled. + } + }, this); + } protected override void Cleanup() diff --git a/src/libraries/System.Runtime/ref/System.Runtime.cs b/src/libraries/System.Runtime/ref/System.Runtime.cs index e5143cc..e74776a 100644 --- a/src/libraries/System.Runtime/ref/System.Runtime.cs +++ b/src/libraries/System.Runtime/ref/System.Runtime.cs @@ -10873,9 +10873,11 @@ namespace System.Threading public System.Threading.CancellationTokenRegistration Register(System.Action callback) { throw null; } public System.Threading.CancellationTokenRegistration Register(System.Action callback, bool useSynchronizationContext) { throw null; } public System.Threading.CancellationTokenRegistration Register(System.Action callback, object? state) { throw null; } + public System.Threading.CancellationTokenRegistration Register(System.Action callback, object? state) { throw null; } public System.Threading.CancellationTokenRegistration Register(System.Action callback, object? state, bool useSynchronizationContext) { throw null; } public void ThrowIfCancellationRequested() { } public System.Threading.CancellationTokenRegistration UnsafeRegister(System.Action callback, object? state) { throw null; } + public System.Threading.CancellationTokenRegistration UnsafeRegister(System.Action callback, object? state) { throw null; } } public readonly partial struct CancellationTokenRegistration : System.IAsyncDisposable, System.IDisposable, System.IEquatable { diff --git a/src/libraries/System.Threading.Tasks/tests/CancellationTokenTests.cs b/src/libraries/System.Threading.Tasks/tests/CancellationTokenTests.cs index d354ad1..798cfc4 100644 --- a/src/libraries/System.Threading.Tasks/tests/CancellationTokenTests.cs +++ b/src/libraries/System.Threading.Tasks/tests/CancellationTokenTests.cs @@ -15,12 +15,18 @@ namespace System.Threading.Tasks.Tests [Fact] public static void CancellationTokenRegister_Exceptions() { - CancellationToken token = new CancellationToken(); - Assert.Throws(() => token.Register(null)); + CancellationToken token = default; - Assert.Throws(() => token.Register(null, false)); + AssertExtensions.Throws("callback", () => token.Register(null)); + AssertExtensions.Throws("callback", () => token.Register(null, false)); - Assert.Throws(() => token.Register(null, null)); + AssertExtensions.Throws("callback", () => token.Register((Action)null, null)); + AssertExtensions.Throws("callback", () => token.Register((Action)null, null, false)); + AssertExtensions.Throws(() => token.Register((Action)null, null, true)); + AssertExtensions.Throws(() => token.Register((Action)null, null)); + + AssertExtensions.Throws("callback", () => token.UnsafeRegister((Action)null, null)); + AssertExtensions.Throws("callback", () => token.UnsafeRegister((Action)null, null)); } [Fact] @@ -1505,8 +1511,10 @@ namespace System.Threading.Tasks.Tests // Validating that no exception is thrown. } - [Fact] - public static void Register_ExecutionContextFlowsIfExpected() + [Theory] + [InlineData(false)] + [InlineData(true)] + public static void Register_ExecutionContextFlowsIfExpected(bool callbackWithToken) { var cts = new CancellationTokenSource(); @@ -1526,10 +1534,26 @@ namespace System.Threading.Tasks.Tests }; CancellationToken ct = cts.Token; - if (flowExecutionContext) + if (flowExecutionContext && callbackWithToken) + { + ct.Register((s, t) => + { + Assert.Equal(ct, t); + callback(s); + }, i); + } + else if (flowExecutionContext) { ct.Register(callback, i); } + else if (callbackWithToken) + { + ct.UnsafeRegister((s, t) => + { + Assert.Equal(ct, t); + callback(s); + }, i); + } else { ct.UnsafeRegister(callback, i); -- 2.7.4