Fix corner-case handling of cancellation exception in ForEachAsync (#59065)
authorStephen Toub <stoub@microsoft.com>
Fri, 17 Sep 2021 05:17:35 +0000 (01:17 -0400)
committerGitHub <noreply@github.com>
Fri, 17 Sep 2021 05:17:35 +0000 (01:17 -0400)
* Fix corner-case handling of cancellation exception in ForEachAsync

If code in Parallel.ForEachAsync throws OperationCanceledExceptions containing the CancellationToken passed to the iteration and that token has _not_ had cancellation requested (so why are they throwing with it) and there are no other exceptions, the ForEachAsync will effectively hang after failing to complete the task returned from it.

The issue stems from how we treat cancellation.  If the user-supplied token hasn't been canceled but we have OperationCanceledExceptions for the token passed into the iteration (the "internal" token), it can only have been canceled because an exception occurred.  We filter out these cancellation exceptions, leaving just the exceptions that are deemed to have caused the failure in the first place.  But the code doesn't currently account for the possibility that the developer is (arguably erroneously) throwing such an OperationCanceledException with the internal cancellation token as that root failure. The fix is to only filter out these OCEs if there are other exceptions besides them.

* Stop filtering out cancellation exceptions in Parallel.ForEachAsync

src/libraries/System.Threading.Tasks.Parallel/src/System/Threading/Tasks/Parallel.ForEachAsync.cs
src/libraries/System.Threading.Tasks.Parallel/tests/ParallelForEachAsyncTests.cs

index ee56e0d..3a7d40b 100644 (file)
@@ -480,18 +480,10 @@ namespace System.Threading.Tasks
                 }
                 else
                 {
-                    // Fault with all of the received exceptions, but filter out those due to inner cancellation,
-                    // as they're effectively an implementation detail and stem from the original exception.
-                    Debug.Assert(_exceptions.Count > 0, "If _exceptions was created, it should have also been populated.");
-                    for (int i = 0; i < _exceptions.Count; i++)
-                    {
-                        if (_exceptions[i] is OperationCanceledException oce && oce.CancellationToken == Cancellation.Token)
-                        {
-                            _exceptions[i] = null!;
-                        }
-                    }
-                    _exceptions.RemoveAll(e => e is null);
-                    Debug.Assert(_exceptions.Count > 0, "Since external cancellation wasn't requested, there should have been a non-cancellation exception that triggered internal cancellation.");
+                    // Fail the task with the resulting exceptions.  The first should be the initial
+                    // exception that triggered the operation to shut down.  The others, if any, may
+                    // include cancellation exceptions from other concurrent operations being canceled
+                    // in response to the primary exception.
                     taskSet = TrySetException(_exceptions);
                 }
 
index 91747d3..97ac99b 100644 (file)
@@ -618,6 +618,64 @@ namespace System.Threading.Tasks.Tests
             Assert.True(t.IsCanceled);
         }
 
+        [ConditionalTheory(typeof(PlatformDetection), nameof(PlatformDetection.IsThreadingSupported))]
+        [InlineData(false)]
+        [InlineData(true)]
+        public async Task Cancellation_FaultsForOceForNonCancellation(bool internalToken)
+        {
+            static async IAsyncEnumerable<int> Iterate()
+            {
+                int counter = 0;
+                while (true)
+                {
+                    await Task.Yield();
+                    yield return counter++;
+                }
+            }
+
+            var cts = new CancellationTokenSource();
+
+            Task t = Parallel.ForEachAsync(Iterate(), new ParallelOptions { CancellationToken = cts.Token }, (item, cancellationToken) =>
+            {
+                throw new OperationCanceledException(internalToken ? cancellationToken : cts.Token);
+            });
+
+            await Assert.ThrowsAnyAsync<OperationCanceledException>(() => t);
+            Assert.True(t.IsFaulted);
+        }
+
+        [ConditionalTheory(typeof(PlatformDetection), nameof(PlatformDetection.IsThreadingSupported))]
+        [InlineData(0, 4)]
+        [InlineData(1, 4)]
+        [InlineData(2, 4)]
+        [InlineData(3, 4)]
+        [InlineData(4, 4)]
+        public async Task Cancellation_InternalCancellationExceptionsArentFilteredOut(int numThrowingNonCanceledOce, int total)
+        {
+            var cts = new CancellationTokenSource();
+
+            var barrier = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
+            int remainingCount = total;
+
+            Task t = Parallel.ForEachAsync(Enumerable.Range(0, total), new ParallelOptions { CancellationToken = cts.Token, MaxDegreeOfParallelism = total }, async (item, cancellationToken) =>
+            {
+                // Wait for all operations to be started
+                if (Interlocked.Decrement(ref remainingCount) == 0)
+                {
+                    barrier.SetResult();
+                }
+                await barrier.Task;
+
+                throw item < numThrowingNonCanceledOce ?
+                    new OperationCanceledException(cancellationToken) :
+                    throw new FormatException();
+            });
+
+            await Assert.ThrowsAnyAsync<Exception>(() => t);
+            Assert.Equal(total, t.Exception.InnerExceptions.Count);
+            Assert.Equal(numThrowingNonCanceledOce, t.Exception.InnerExceptions.Count(e => e is OperationCanceledException));
+        }
+
         [ConditionalFact(typeof(PlatformDetection), nameof(PlatformDetection.IsThreadingSupported))]
         public void Exception_FromGetEnumerator_Sync()
         {
@@ -672,7 +730,6 @@ namespace System.Threading.Tasks.Tests
             Task t = Parallel.ForEachAsync(Iterate(), (item, cancellationToken) => default);
             await Assert.ThrowsAsync<FormatException>(() => t);
             Assert.True(t.IsFaulted);
-            Assert.Equal(1, t.Exception.InnerExceptions.Count);
         }
 
         [ConditionalFact(typeof(PlatformDetection), nameof(PlatformDetection.IsThreadingSupported))]
@@ -694,7 +751,6 @@ namespace System.Threading.Tasks.Tests
             Task t = Parallel.ForEachAsync(Iterate(), (item, cancellationToken) => default);
             await Assert.ThrowsAsync<FormatException>(() => t);
             Assert.True(t.IsFaulted);
-            Assert.Equal(1, t.Exception.InnerExceptions.Count);
         }
 
         [ConditionalFact(typeof(PlatformDetection), nameof(PlatformDetection.IsThreadingSupported))]
@@ -771,7 +827,6 @@ namespace System.Threading.Tasks.Tests
             Task t = Parallel.ForEachAsync((IEnumerable<int>)new ThrowsExceptionFromDispose(), (item, cancellationToken) => default);
             await Assert.ThrowsAsync<FormatException>(() => t);
             Assert.True(t.IsFaulted);
-            Assert.Equal(1, t.Exception.InnerExceptions.Count);
         }
 
         [ConditionalFact(typeof(PlatformDetection), nameof(PlatformDetection.IsThreadingSupported))]
@@ -780,7 +835,6 @@ namespace System.Threading.Tasks.Tests
             Task t = Parallel.ForEachAsync((IAsyncEnumerable<int>)new ThrowsExceptionFromDispose(), (item, cancellationToken) => default);
             await Assert.ThrowsAsync<DivideByZeroException>(() => t);
             Assert.True(t.IsFaulted);
-            Assert.Equal(1, t.Exception.InnerExceptions.Count);
         }
 
         [ConditionalFact(typeof(PlatformDetection), nameof(PlatformDetection.IsThreadingSupported))]