[release/6.0-rc2] Fix corner-case handling of cancellation exception in ForEachAsync...
authorgithub-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Fri, 17 Sep 2021 13:27:49 +0000 (09:27 -0400)
committerGitHub <noreply@github.com>
Fri, 17 Sep 2021 13:27:49 +0000 (09:27 -0400)
* Fix corner-case handling of cancellation exception in ForEachAsync

If code in Parallel.ForEachAsync throws OperationCanceledExceptions containing the CancellationToken passed to the iteration and that token has _not_ had cancellation requested (so why are they throwing with it) and there are no other exceptions, the ForEachAsync will effectively hang after failing to complete the task returned from it.

The issue stems from how we treat cancellation.  If the user-supplied token hasn't been canceled but we have OperationCanceledExceptions for the token passed into the iteration (the "internal" token), it can only have been canceled because an exception occurred.  We filter out these cancellation exceptions, leaving just the exceptions that are deemed to have caused the failure in the first place.  But the code doesn't currently account for the possibility that the developer is (arguably erroneously) throwing such an OperationCanceledException with the internal cancellation token as that root failure. The fix is to only filter out these OCEs if there are other exceptions besides them.

* Stop filtering out cancellation exceptions in Parallel.ForEachAsync

Co-authored-by: Stephen Toub <stoub@microsoft.com>
src/libraries/System.Threading.Tasks.Parallel/src/System/Threading/Tasks/Parallel.ForEachAsync.cs
src/libraries/System.Threading.Tasks.Parallel/tests/ParallelForEachAsyncTests.cs

index ee56e0d1a18d157fe57487e584e00c638ff5e1da..3a7d40b7d0489b196d414d65041e92ee7ca8cd34 100644 (file)
@@ -480,18 +480,10 @@ namespace System.Threading.Tasks
                 }
                 else
                 {
-                    // Fault with all of the received exceptions, but filter out those due to inner cancellation,
-                    // as they're effectively an implementation detail and stem from the original exception.
-                    Debug.Assert(_exceptions.Count > 0, "If _exceptions was created, it should have also been populated.");
-                    for (int i = 0; i < _exceptions.Count; i++)
-                    {
-                        if (_exceptions[i] is OperationCanceledException oce && oce.CancellationToken == Cancellation.Token)
-                        {
-                            _exceptions[i] = null!;
-                        }
-                    }
-                    _exceptions.RemoveAll(e => e is null);
-                    Debug.Assert(_exceptions.Count > 0, "Since external cancellation wasn't requested, there should have been a non-cancellation exception that triggered internal cancellation.");
+                    // Fail the task with the resulting exceptions.  The first should be the initial
+                    // exception that triggered the operation to shut down.  The others, if any, may
+                    // include cancellation exceptions from other concurrent operations being canceled
+                    // in response to the primary exception.
                     taskSet = TrySetException(_exceptions);
                 }
 
index 91747d3828b622e0280549d2150fa1b3a29510ca..97ac99bc159600f3fdeb351f7aa0f7984042cc7f 100644 (file)
@@ -618,6 +618,64 @@ namespace System.Threading.Tasks.Tests
             Assert.True(t.IsCanceled);
         }
 
+        [ConditionalTheory(typeof(PlatformDetection), nameof(PlatformDetection.IsThreadingSupported))]
+        [InlineData(false)]
+        [InlineData(true)]
+        public async Task Cancellation_FaultsForOceForNonCancellation(bool internalToken)
+        {
+            static async IAsyncEnumerable<int> Iterate()
+            {
+                int counter = 0;
+                while (true)
+                {
+                    await Task.Yield();
+                    yield return counter++;
+                }
+            }
+
+            var cts = new CancellationTokenSource();
+
+            Task t = Parallel.ForEachAsync(Iterate(), new ParallelOptions { CancellationToken = cts.Token }, (item, cancellationToken) =>
+            {
+                throw new OperationCanceledException(internalToken ? cancellationToken : cts.Token);
+            });
+
+            await Assert.ThrowsAnyAsync<OperationCanceledException>(() => t);
+            Assert.True(t.IsFaulted);
+        }
+
+        [ConditionalTheory(typeof(PlatformDetection), nameof(PlatformDetection.IsThreadingSupported))]
+        [InlineData(0, 4)]
+        [InlineData(1, 4)]
+        [InlineData(2, 4)]
+        [InlineData(3, 4)]
+        [InlineData(4, 4)]
+        public async Task Cancellation_InternalCancellationExceptionsArentFilteredOut(int numThrowingNonCanceledOce, int total)
+        {
+            var cts = new CancellationTokenSource();
+
+            var barrier = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
+            int remainingCount = total;
+
+            Task t = Parallel.ForEachAsync(Enumerable.Range(0, total), new ParallelOptions { CancellationToken = cts.Token, MaxDegreeOfParallelism = total }, async (item, cancellationToken) =>
+            {
+                // Wait for all operations to be started
+                if (Interlocked.Decrement(ref remainingCount) == 0)
+                {
+                    barrier.SetResult();
+                }
+                await barrier.Task;
+
+                throw item < numThrowingNonCanceledOce ?
+                    new OperationCanceledException(cancellationToken) :
+                    throw new FormatException();
+            });
+
+            await Assert.ThrowsAnyAsync<Exception>(() => t);
+            Assert.Equal(total, t.Exception.InnerExceptions.Count);
+            Assert.Equal(numThrowingNonCanceledOce, t.Exception.InnerExceptions.Count(e => e is OperationCanceledException));
+        }
+
         [ConditionalFact(typeof(PlatformDetection), nameof(PlatformDetection.IsThreadingSupported))]
         public void Exception_FromGetEnumerator_Sync()
         {
@@ -672,7 +730,6 @@ namespace System.Threading.Tasks.Tests
             Task t = Parallel.ForEachAsync(Iterate(), (item, cancellationToken) => default);
             await Assert.ThrowsAsync<FormatException>(() => t);
             Assert.True(t.IsFaulted);
-            Assert.Equal(1, t.Exception.InnerExceptions.Count);
         }
 
         [ConditionalFact(typeof(PlatformDetection), nameof(PlatformDetection.IsThreadingSupported))]
@@ -694,7 +751,6 @@ namespace System.Threading.Tasks.Tests
             Task t = Parallel.ForEachAsync(Iterate(), (item, cancellationToken) => default);
             await Assert.ThrowsAsync<FormatException>(() => t);
             Assert.True(t.IsFaulted);
-            Assert.Equal(1, t.Exception.InnerExceptions.Count);
         }
 
         [ConditionalFact(typeof(PlatformDetection), nameof(PlatformDetection.IsThreadingSupported))]
@@ -771,7 +827,6 @@ namespace System.Threading.Tasks.Tests
             Task t = Parallel.ForEachAsync((IEnumerable<int>)new ThrowsExceptionFromDispose(), (item, cancellationToken) => default);
             await Assert.ThrowsAsync<FormatException>(() => t);
             Assert.True(t.IsFaulted);
-            Assert.Equal(1, t.Exception.InnerExceptions.Count);
         }
 
         [ConditionalFact(typeof(PlatformDetection), nameof(PlatformDetection.IsThreadingSupported))]
@@ -780,7 +835,6 @@ namespace System.Threading.Tasks.Tests
             Task t = Parallel.ForEachAsync((IAsyncEnumerable<int>)new ThrowsExceptionFromDispose(), (item, cancellationToken) => default);
             await Assert.ThrowsAsync<DivideByZeroException>(() => t);
             Assert.True(t.IsFaulted);
-            Assert.Equal(1, t.Exception.InnerExceptions.Count);
         }
 
         [ConditionalFact(typeof(PlatformDetection), nameof(PlatformDetection.IsThreadingSupported))]