Nullable: CancellationToken (#23609)
authorStephen Toub <stoub@microsoft.com>
Mon, 1 Apr 2019 23:35:08 +0000 (19:35 -0400)
committerGitHub <noreply@github.com>
Mon, 1 Apr 2019 23:35:08 +0000 (19:35 -0400)
src/System.Private.CoreLib/shared/System/Threading/CancellationToken.cs
src/System.Private.CoreLib/shared/System/Threading/CancellationTokenRegistration.cs
src/System.Private.CoreLib/shared/System/Threading/CancellationTokenSource.cs

index 5402749..98b3ba2 100644 (file)
@@ -2,8 +2,8 @@
 // The .NET Foundation licenses this file to you under the MIT license.
 // See the LICENSE file in the project root for more information.
 
+#nullable enable
 using System.Diagnostics;
-using System.Runtime.CompilerServices;
 
 namespace System.Threading
 {
@@ -33,10 +33,14 @@ namespace System.Threading
         // The backing TokenSource.  
         // if null, it implicitly represents the same thing as new CancellationToken(false).
         // When required, it will be instantiated to reflect this.
-        private readonly CancellationTokenSource _source;
+        private readonly CancellationTokenSource? _source;
         //!! warning. If more fields are added, the assumptions in CreateLinkedToken may no longer be valid
 
-        private readonly static Action<object> s_actionToActionObjShunt = obj => ((Action)obj)();
+        private readonly static Action<object?> s_actionToActionObjShunt = obj =>
+        {
+            Debug.Assert(obj is Action, $"Expected {typeof(Action)}, got {obj}");
+            ((Action)obj)();
+        };
 
         /// <summary>
         /// Returns an empty CancellationToken value.
@@ -96,7 +100,7 @@ namespace System.Threading
         /// <summary>
         /// Internal constructor only a CancellationTokenSource should create a CancellationToken
         /// </summary>
-        internal CancellationToken(CancellationTokenSource source) => _source = source;
+        internal CancellationToken(CancellationTokenSource? source) => _source = source;
 
         /// <summary>
         /// Initializes the <see cref="T:System.Threading.CancellationToken">CancellationToken</see>.
@@ -189,7 +193,7 @@ namespace System.Threading
         /// <returns>The <see cref="T:System.Threading.CancellationTokenRegistration"/> instance that can 
         /// be used to unregister the callback.</returns>
         /// <exception cref="T:System.ArgumentNullException"><paramref name="callback"/> is null.</exception>
-        public CancellationTokenRegistration Register(Action<object> callback, object state) =>
+        public CancellationTokenRegistration Register(Action<object?> callback, object? state) =>
             Register(callback, state, useSynchronizationContext: false, useExecutionContext: true);
 
         /// <summary>
@@ -217,7 +221,7 @@ namespace System.Threading
         /// <exception cref="T:System.ArgumentNullException"><paramref name="callback"/> is null.</exception>
         /// <exception cref="T:System.ObjectDisposedException">The associated <see
         /// cref="T:System.Threading.CancellationTokenSource">CancellationTokenSource</see> has been disposed.</exception>
-        public CancellationTokenRegistration Register(Action<object> callback, object state, bool useSynchronizationContext) =>
+        public CancellationTokenRegistration Register(Action<object?> callback, object? state, bool useSynchronizationContext) =>
             Register(callback, state, useSynchronizationContext, useExecutionContext: true);
 
         /// <summary>
@@ -239,7 +243,7 @@ namespace System.Threading
         /// <returns>The <see cref="T:System.Threading.CancellationTokenRegistration"/> instance that can 
         /// be used to unregister the callback.</returns>
         /// <exception cref="T:System.ArgumentNullException"><paramref name="callback"/> is null.</exception>
-        public CancellationTokenRegistration UnsafeRegister(Action<object> callback, object state) =>
+        public CancellationTokenRegistration UnsafeRegister(Action<object?> callback, object? state) =>
             Register(callback, state, useSynchronizationContext: false, useExecutionContext: false);
 
         /// <summary>
@@ -263,12 +267,12 @@ namespace System.Threading
         /// <exception cref="T:System.ArgumentNullException"><paramref name="callback"/> is null.</exception>
         /// <exception cref="T:System.ObjectDisposedException">The associated <see
         /// cref="T:System.Threading.CancellationTokenSource">CancellationTokenSource</see> has been disposed.</exception>
-        private CancellationTokenRegistration Register(Action<object> callback, object state, bool useSynchronizationContext, bool useExecutionContext)
+        private CancellationTokenRegistration Register(Action<object?> callback, object? state, bool useSynchronizationContext, bool useExecutionContext)
         {
             if (callback == null)
                 throw new ArgumentNullException(nameof(callback));
 
-            CancellationTokenSource source = _source;
+            CancellationTokenSource? source = _source;
             return source != null ?
                 source.InternalRegister(callback, state, useSynchronizationContext ? SynchronizationContext.Current : null, useExecutionContext ? ExecutionContext.Capture() : null) :
                 default; // Nothing to do for tokens than can never reach the canceled state. Give back a dummy registration.
@@ -296,7 +300,7 @@ namespace System.Threading
         /// from public CancellationToken constructors and their <see cref="IsCancellationRequested"/> values are equal.</returns>
         /// <exception cref="T:System.ObjectDisposedException">An associated <see
         /// cref="T:System.Threading.CancellationTokenSource">CancellationTokenSource</see> has been disposed.</exception>
-        public override bool Equals(object other) => other is CancellationToken && Equals((CancellationToken)other);
+        public override bool Equals(object? other) => other is CancellationToken && Equals((CancellationToken)other);
 
         /// <summary>
         /// Serves as a hash function for a <see cref="T:System.Threading.CancellationToken">CancellationToken</see>.
index bab2ce9..edb29e0 100644 (file)
@@ -2,6 +2,7 @@
 // The .NET Foundation licenses this file to you under the MIT license.
 // See the LICENSE file in the project root for more information.
 
+#nullable enable
 using System.Threading.Tasks;
 
 namespace System.Threading
@@ -147,7 +148,7 @@ namespace System.Threading
         /// they both refer to the output of a single call to the same Register method of a 
         /// <see cref="T:System.Threading.CancellationToken">CancellationToken</see>. 
         /// </returns>
-        public override bool Equals(object obj) => obj is CancellationTokenRegistration && Equals((CancellationTokenRegistration)obj);
+        public override bool Equals(object? obj) => obj is CancellationTokenRegistration && Equals((CancellationTokenRegistration)obj);
 
         /// <summary>
         /// Determines whether the current <see cref="T:System.Threading.CancellationToken">CancellationToken</see> instance is equal to the 
index 8c4dd01..998d688 100644 (file)
@@ -2,6 +2,7 @@
 // The .NET Foundation licenses this file to you under the MIT license.
 // See the LICENSE file in the project root for more information.
 
+#nullable enable
 using System.Collections.Generic;
 using System.Diagnostics;
 using System.Threading.Tasks;
@@ -30,7 +31,10 @@ namespace System.Threading
 
         /// <summary>Delegate used with <see cref="Timer"/> to trigger cancellation of a <see cref="CancellationTokenSource"/>.</summary>
         private static readonly TimerCallback s_timerCallback = obj =>
+        {
+            Debug.Assert(obj is CancellationTokenSource, $"Expected {typeof(CancellationTokenSource)}, got {obj}");
             ((CancellationTokenSource)obj).NotifyCancellation(throwOnFirstException: false); // skip ThrowIfDisposed() check in Cancel()
+        };
 
         /// <summary>The number of callback partitions to use in a <see cref="CancellationTokenSource"/>. Must be a power of 2.</summary>
         private static readonly int s_numPartitions = GetPartitionCount();
@@ -49,11 +53,11 @@ namespace System.Threading
         /// <summary>Tracks the running callback to assist ctr.Dispose() to wait for the target callback to complete.</summary>
         private long _executingCallbackId;
         /// <summary>Partitions of callbacks.  Split into multiple partitions to help with scalability of registering/unregistering; each is protected by its own lock.</summary>
-        private volatile CallbackPartition[] _callbackPartitions;
+        private volatile CallbackPartition?[]? _callbackPartitions;
         /// <summary>TimerQueueTimer used by CancelAfter and Timer-related ctors. Used instead of Timer to avoid extra allocations and because the rooted behavior is desired.</summary>
-        private volatile TimerQueueTimer _timer;
+        private volatile TimerQueueTimer? _timer;
         /// <summary><see cref="System.Threading.WaitHandle"/> lazily initialized and returned from <see cref="WaitHandle"/>.</summary>
-        private volatile ManualResetEvent _kernelEvent;
+        private volatile ManualResetEvent? _kernelEvent;
         /// <summary>Whether this <see cref="CancellationTokenSource"/> has been disposed.</summary>
         private bool _disposed;
 
@@ -129,10 +133,10 @@ namespace System.Threading
                 //   2. if IsCancellationRequested = false, then NotifyCancellation will see that the event exists, and will call Set().
                 if (IsCancellationRequested)
                 {
-                    _kernelEvent.Set();
+                    _kernelEvent!.Set(); // TODO-NULLABLE: The ! shouldn't be necessary due to CompareExchange initialization above.
                 }
 
-                return _kernelEvent;
+                return _kernelEvent!; // TODO-NULLABLE: The ! shouldn't be necessary due to CompareExchange initialization above.
             }
         }
 
@@ -355,7 +359,7 @@ namespace System.Threading
             // expired and Disposed itself).  But this would be considered bad behavior, as
             // Dispose() is not thread-safe and should not be called concurrently with CancelAfter().
 
-            TimerQueueTimer timer = _timer;
+            TimerQueueTimer? timer = _timer;
             if (timer == null)
             {
                 // Lazily initialize the timer in a thread-safe fashion.
@@ -363,7 +367,7 @@ namespace System.Threading
                 // chance on a timer "losing" the initialization and then
                 // cancelling the token before it (the timer) can be disposed.
                 timer = new TimerQueueTimer(s_timerCallback, this, Timeout.UnsignedInfinite, Timeout.UnsignedInfinite, flowExecutionContext: false);
-                TimerQueueTimer currentTimer = Interlocked.CompareExchange(ref _timer, timer, null);
+                TimerQueueTimer? currentTimer = Interlocked.CompareExchange(ref _timer, timer, null);
                 if (currentTimer != null)
                 {
                     // We did not initialize the timer.  Dispose the new timer.
@@ -426,7 +430,7 @@ namespace System.Threading
                 // internal source of cancellation, then Disposes of that linked source, which could
                 // happen at the same time the external entity is requesting cancellation).
 
-                TimerQueueTimer timer = _timer;
+                TimerQueueTimer? timer = _timer;
                 if (timer != null)
                 {
                     _timer = null;
@@ -442,7 +446,7 @@ namespace System.Threading
                 // transitioned to and while it's in the NotifyingState.
                 if (_kernelEvent != null)
                 {
-                    ManualResetEvent mre = Interlocked.Exchange(ref _kernelEvent, null);
+                    ManualResetEvent? mre = Interlocked.Exchange<ManualResetEvent?>(ref _kernelEvent!, null);
                     if (mre != null && _state != NotifyingState)
                     {
                         mre.Dispose();
@@ -471,7 +475,7 @@ namespace System.Threading
         /// callback will have been run by the time this method returns.
         /// </summary>
         internal CancellationTokenRegistration InternalRegister(
-            Action<object> callback, object stateForCallback, SynchronizationContext syncContext, ExecutionContext executionContext)
+            Action<object?> callback, object? stateForCallback, SynchronizationContext? syncContext, ExecutionContext? executionContext)
         {
             Debug.Assert(this != s_neverCanceledSource, "This source should never be exposed via a CancellationToken.");
 
@@ -493,7 +497,7 @@ namespace System.Threading
                 }
 
                 // Get the partitions...
-                CallbackPartition[] partitions = _callbackPartitions;
+                CallbackPartition?[]? partitions = _callbackPartitions;
                 if (partitions == null)
                 {
                     partitions = new CallbackPartition[s_numPartitions];
@@ -503,7 +507,7 @@ namespace System.Threading
                 // ...and determine which partition to use.
                 int partitionIndex = Environment.CurrentManagedThreadId & s_numPartitionsMask;
                 Debug.Assert(partitionIndex < partitions.Length, $"Expected {partitionIndex} to be less than {partitions.Length}");
-                CallbackPartition partition = partitions[partitionIndex];
+                CallbackPartition? partition = partitions[partitionIndex];
                 if (partition == null)
                 {
                     partition = new CallbackPartition(this);
@@ -512,7 +516,7 @@ namespace System.Threading
 
                 // Store the callback information into the callback arrays.
                 long id;
-                CallbackNode node;
+                CallbackNode? node;
                 bool lockTaken = false;
                 partition.Lock.Enter(ref lockTaken);
                 try
@@ -576,7 +580,7 @@ namespace System.Threading
             if (!IsCancellationRequested && Interlocked.CompareExchange(ref _state, NotifyingState, NotCanceledState) == NotCanceledState)
             {
                 // Dispose of the timer, if any.  Dispose may be running concurrently here, but TimerQueueTimer.Close is thread-safe.
-                TimerQueueTimer timer = _timer;
+                TimerQueueTimer? timer = _timer;
                 if (timer != null)
                 {
                     _timer = null;
@@ -609,20 +613,20 @@ namespace System.Threading
 
             // If there are no callbacks to run, we can safely exit.  Any race conditions to lazy initialize it
             // will see IsCancellationRequested and will then run the callback themselves.
-            CallbackPartition[] partitions = Interlocked.Exchange(ref _callbackPartitions, null);
+            CallbackPartition?[]? partitions = Interlocked.Exchange(ref _callbackPartitions, null);
             if (partitions == null)
             {
                 Interlocked.Exchange(ref _state, NotifyingCompleteState);
                 return;
             }
 
-            List<Exception> exceptionList = null;
+            List<Exception>? exceptionList = null;
             try
             {
                 // For each partition, and each callback in that partition, execute the associated handler.
                 // We call the delegates in LIFO order on each partition so that callbacks fire 'deepest first'.
                 // This is intended to help with nesting scenarios so that child enlisters cancel before their parents.
-                foreach (CallbackPartition partition in partitions)
+                foreach (CallbackPartition? partition in partitions)
                 {
                     if (partition == null)
                     {
@@ -635,7 +639,7 @@ namespace System.Threading
                     // to still be effective even as other registrations are being invoked.
                     while (true)
                     {
-                        CallbackNode node;
+                        CallbackNode? node;
                         bool lockTaken = false;
                         partition.Lock.Enter(ref lockTaken);
                         try
@@ -871,9 +875,12 @@ namespace System.Threading
 
         private sealed class LinkedNCancellationTokenSource : CancellationTokenSource
         {
-            internal static readonly Action<object> s_linkedTokenCancelDelegate =
-                s => ((CancellationTokenSource)s).NotifyCancellation(throwOnFirstException: false); // skip ThrowIfDisposed() check in Cancel()
-            private CancellationTokenRegistration[] _linkingRegistrations;
+            internal static readonly Action<object?> s_linkedTokenCancelDelegate = s =>
+            {
+                Debug.Assert(s is CancellationTokenSource, $"Expected {typeof(CancellationTokenSource)}, got {s}");
+                ((CancellationTokenSource)s).NotifyCancellation(throwOnFirstException: false); // skip ThrowIfDisposed() check in Cancel()
+            };
+            private CancellationTokenRegistration[]? _linkingRegistrations;
 
             internal LinkedNCancellationTokenSource(params CancellationToken[] tokens)
             {
@@ -898,7 +905,7 @@ namespace System.Threading
                     return;
                 }
 
-                CancellationTokenRegistration[] linkingRegistrations = _linkingRegistrations;
+                CancellationTokenRegistration[]? linkingRegistrations = _linkingRegistrations;
                 if (linkingRegistrations != null)
                 {
                     _linkingRegistrations = null; // release for GC once we're done enumerating
@@ -919,9 +926,9 @@ namespace System.Threading
             /// <summary>Lock that protects all state in the partition.</summary>
             public SpinLock Lock = new SpinLock(enableThreadOwnerTracking: false); // mutable struct; do not make this readonly
             /// <summary>Doubly-linked list of callbacks registered with the partition. Callbacks are removed during unregistration and as they're invoked.</summary>
-            public CallbackNode Callbacks;
+            public CallbackNode? Callbacks;
             /// <summary>Singly-linked list of free nodes that can be used for subsequent callback registrations.</summary>
-            public CallbackNode FreeNodeList;
+            public CallbackNode? FreeNodeList;
             /// <summary>Every callback is assigned a unique, never-reused ID.  This defines the next available ID.</summary>
             public long NextAvailableId = 1; // avoid using 0, as that's the default long value and used to represent an empty node
 
@@ -995,14 +1002,14 @@ namespace System.Threading
         internal sealed class CallbackNode
         {
             public readonly CallbackPartition Partition;
-            public CallbackNode Prev;
-            public CallbackNode Next;
+            public CallbackNode? Prev;
+            public CallbackNode? Next;
 
             public long Id;
-            public Action<object> Callback;
-            public object CallbackState;
-            public ExecutionContext ExecutionContext;
-            public SynchronizationContext SynchronizationContext;
+            public Action<object?>? Callback;
+            public object? CallbackState;
+            public ExecutionContext? ExecutionContext;
+            public SynchronizationContext? SynchronizationContext;
             
             public CallbackNode(CallbackPartition partition)
             {
@@ -1012,17 +1019,21 @@ namespace System.Threading
 
             public void ExecuteCallback()
             {
-                ExecutionContext context = ExecutionContext;
+                ExecutionContext? context = ExecutionContext;
                 if (context != null)
                 {
                     ExecutionContext.RunInternal(context, s =>
                     {
+                        Debug.Assert(s is CallbackNode, $"Expected {typeof(CallbackNode)}, got {s}");
                         CallbackNode n = (CallbackNode)s;
+
+                        Debug.Assert(n.Callback != null);
                         n.Callback(n.CallbackState);
                     }, this);
                 }
                 else
                 {
+                    Debug.Assert(Callback != null);
                     Callback(CallbackState);
                 }
             }