Fix System.Text.Json IAsyncEnumerator disposal on cancellation (#57505)
authorEirik Tsarpalis <eirik.tsarpalis@gmail.com>
Tue, 17 Aug 2021 10:16:27 +0000 (11:16 +0100)
committerGitHub <noreply@github.com>
Tue, 17 Aug 2021 10:16:27 +0000 (11:16 +0100)
* Ensure WriteStack.Pending task is awaited on exception. Ensure IAsyncDisposable instances are disposed exactly once.

Fixes #57360.

* Update src/libraries/System.Text.Json/tests/Common/CollectionTests/CollectionTests.AsyncEnumerable.cs

src/libraries/System.Text.Json/src/System/Text/Json/Serialization/Converters/Collection/IAsyncEnumerableOfTConverter.cs
src/libraries/System.Text.Json/src/System/Text/Json/Serialization/JsonSerializer.Write.Stream.cs
src/libraries/System.Text.Json/src/System/Text/Json/Serialization/WriteStack.cs
src/libraries/System.Text.Json/src/System/Text/Json/Serialization/WriteStackFrame.cs
src/libraries/System.Text.Json/tests/Common/CollectionTests/CollectionTests.AsyncEnumerable.cs

index ca3def9..20dcedf 100644 (file)
@@ -48,18 +48,22 @@ namespace System.Text.Json.Serialization.Converters
             IAsyncEnumerator<TElement> enumerator;
             ValueTask<bool> moveNextTask;
 
-            if (state.Current.AsyncEnumerator is null)
+            if (state.Current.AsyncDisposable is null)
             {
                 enumerator = value.GetAsyncEnumerator(state.CancellationToken);
+                // async enumerators can only be disposed asynchronously;
+                // store in the WriteStack for future disposal
+                // by the root async serialization context.
+                state.Current.AsyncDisposable = enumerator;
+                // enumerator.MoveNextAsync() calls can throw,
+                // ensure the enumerator already is stored
+                // in the WriteStack for proper disposal.
                 moveNextTask = enumerator.MoveNextAsync();
-                // we always need to attach the enumerator to the stack
-                // since it will need to be disposed asynchronously.
-                state.Current.AsyncEnumerator = enumerator;
             }
             else
             {
-                Debug.Assert(state.Current.AsyncEnumerator is IAsyncEnumerator<TElement>);
-                enumerator = (IAsyncEnumerator<TElement>)state.Current.AsyncEnumerator;
+                Debug.Assert(state.Current.AsyncDisposable is IAsyncEnumerator<TElement>);
+                enumerator = (IAsyncEnumerator<TElement>)state.Current.AsyncDisposable;
 
                 if (state.Current.AsyncEnumeratorIsPendingCompletion)
                 {
@@ -84,6 +88,10 @@ namespace System.Text.Json.Serialization.Converters
             {
                 if (!moveNextTask.Result)
                 {
+                    // we have completed serialization for the enumerator,
+                    // clear from the stack and schedule for async disposal.
+                    state.Current.AsyncDisposable = null;
+                    state.AddCompletedAsyncDisposable(enumerator);
                     return true;
                 }
 
index bd04e35..04b7b10 100644 (file)
@@ -325,28 +325,30 @@ namespace System.Text.Json
                         try
                         {
                             isFinalBlock = WriteCore(converter, writer, value, options, ref state);
+                            await bufferWriter.WriteToStreamAsync(utf8Json, cancellationToken).ConfigureAwait(false);
+                            bufferWriter.Clear();
                         }
                         finally
                         {
-                            if (state.PendingAsyncDisposables?.Count > 0)
+                            // Await any pending resumable converter tasks (currently these can only be IAsyncEnumerator.MoveNextAsync() tasks).
+                            // Note that pending tasks are always awaited, even if an exception has been thrown or the cancellation token has fired.
+                            if (state.PendingTask is not null)
                             {
-                                await state.DisposePendingAsyncDisposables().ConfigureAwait(false);
+                                try
+                                {
+                                    await state.PendingTask.ConfigureAwait(false);
+                                }
+                                catch
+                                {
+                                    // Exceptions should only be propagated by the resuming converter
+                                    // TODO https://github.com/dotnet/runtime/issues/22144
+                                }
                             }
-                        }
-
-                        await bufferWriter.WriteToStreamAsync(utf8Json, cancellationToken).ConfigureAwait(false);
-                        bufferWriter.Clear();
 
-                        if (state.PendingTask is not null)
-                        {
-                            try
-                            {
-                                await state.PendingTask.ConfigureAwait(false);
-                            }
-                            catch
+                            // Dispose any pending async disposables (currently these can only be completed IAsyncEnumerators).
+                            if (state.CompletedAsyncDisposables?.Count > 0)
                             {
-                                // Exceptions will be propagated elsewhere
-                                // TODO https://github.com/dotnet/runtime/issues/22144
+                                await state.DisposeCompletedAsyncDisposables().ConfigureAwait(false);
                             }
                         }
 
@@ -354,6 +356,7 @@ namespace System.Text.Json
                 }
                 catch
                 {
+                    // On exception, walk the WriteStack for any orphaned disposables and try to dispose them.
                     await state.DisposePendingDisposablesOnExceptionAsync().ConfigureAwait(false);
                     throw;
                 }
index bc313da..7312f6f 100644 (file)
@@ -47,9 +47,9 @@ namespace System.Text.Json
         public Task? PendingTask;
 
         /// <summary>
-        /// List of IAsyncDisposables that have been scheduled for disposal by converters.
+        /// List of completed IAsyncDisposables that have been scheduled for disposal by converters.
         /// </summary>
-        public List<IAsyncDisposable>? PendingAsyncDisposables;
+        public List<IAsyncDisposable>? CompletedAsyncDisposables;
 
         /// <summary>
         /// The amount of bytes to write before the underlying Stream should be flushed and the
@@ -196,14 +196,6 @@ namespace System.Text.Json
             {
                 Debug.Assert(_continuationCount == 0);
 
-                if (Current.AsyncEnumerator is not null)
-                {
-                    // we have completed serialization of an AsyncEnumerator,
-                    // pop from the stack and schedule for async disposal.
-                    PendingAsyncDisposables ??= new List<IAsyncDisposable>();
-                    PendingAsyncDisposables.Add(Current.AsyncEnumerator);
-                }
-
                 if (--_count > 0)
                 {
                     Current = _stack[_count - 1];
@@ -211,13 +203,16 @@ namespace System.Text.Json
             }
         }
 
+        public void AddCompletedAsyncDisposable(IAsyncDisposable asyncDisposable)
+            => (CompletedAsyncDisposables ??= new List<IAsyncDisposable>()).Add(asyncDisposable);
+
         // Asynchronously dispose of any AsyncDisposables that have been scheduled for disposal
-        public async ValueTask DisposePendingAsyncDisposables()
+        public async ValueTask DisposeCompletedAsyncDisposables()
         {
-            Debug.Assert(PendingAsyncDisposables?.Count > 0);
+            Debug.Assert(CompletedAsyncDisposables?.Count > 0);
             Exception? exception = null;
 
-            foreach (IAsyncDisposable asyncDisposable in PendingAsyncDisposables)
+            foreach (IAsyncDisposable asyncDisposable in CompletedAsyncDisposables)
             {
                 try
                 {
@@ -234,7 +229,7 @@ namespace System.Text.Json
                 ExceptionDispatchInfo.Capture(exception).Throw();
             }
 
-            PendingAsyncDisposables.Clear();
+            CompletedAsyncDisposables.Clear();
         }
 
         /// <summary>
@@ -245,13 +240,13 @@ namespace System.Text.Json
         {
             Exception? exception = null;
 
-            Debug.Assert(Current.AsyncEnumerator is null);
+            Debug.Assert(Current.AsyncDisposable is null);
             DisposeFrame(Current.CollectionEnumerator, ref exception);
 
             int stackSize = Math.Max(_count, _continuationCount);
             for (int i = 0; i < stackSize - 1; i++)
             {
-                Debug.Assert(_stack[i].AsyncEnumerator is null);
+                Debug.Assert(_stack[i].AsyncDisposable is null);
                 DisposeFrame(_stack[i].CollectionEnumerator, ref exception);
             }
 
@@ -284,12 +279,12 @@ namespace System.Text.Json
         {
             Exception? exception = null;
 
-            exception = await DisposeFrame(Current.CollectionEnumerator, Current.AsyncEnumerator, exception).ConfigureAwait(false);
+            exception = await DisposeFrame(Current.CollectionEnumerator, Current.AsyncDisposable, exception).ConfigureAwait(false);
 
             int stackSize = Math.Max(_count, _continuationCount);
             for (int i = 0; i < stackSize - 1; i++)
             {
-                exception = await DisposeFrame(_stack[i].CollectionEnumerator, _stack[i].AsyncEnumerator, exception).ConfigureAwait(false);
+                exception = await DisposeFrame(_stack[i].CollectionEnumerator, _stack[i].AsyncDisposable, exception).ConfigureAwait(false);
             }
 
             if (exception is not null)
index 97002fb..fdd4972 100644 (file)
@@ -19,7 +19,7 @@ namespace System.Text.Json
         /// <summary>
         /// The enumerator for resumable async disposables.
         /// </summary>
-        public IAsyncDisposable? AsyncEnumerator;
+        public IAsyncDisposable? AsyncDisposable;
 
         /// <summary>
         /// The current stackframe has suspended serialization due to a pending task,
index 9e57524..ba855bd 100644 (file)
@@ -74,16 +74,24 @@ namespace System.Text.Json.Serialization.Tests
             Assert.Equal(1, asyncEnumerable.TotalDisposedEnumerators);
         }
 
-        [Fact, OuterLoop]
-        public async Task WriteAsyncEnumerable_LongRunningEnumeration_Cancellation()
+        [Theory, OuterLoop]
+        [InlineData(5000, 1000, true)]
+        [InlineData(5000, 1000, false)]
+        [InlineData(1000, 10_000, true)]
+        [InlineData(1000, 10_000, false)]
+        public async Task WriteAsyncEnumerable_LongRunningEnumeration_Cancellation(
+            int cancellationTokenSourceDelayMilliseconds,
+            int enumeratorDelayMilliseconds,
+            bool passCancellationTokenToDelayTask)
         {
             var longRunningEnumerable = new MockedAsyncEnumerable<int>(
-                source: Enumerable.Range(1, 100),
+                source: Enumerable.Range(1, 1000),
                 delayInterval: 1,
-                delay: TimeSpan.FromMinutes(1));
+                delay: TimeSpan.FromMilliseconds(enumeratorDelayMilliseconds),
+                passCancellationTokenToDelayTask);
 
             using var utf8Stream = new Utf8MemoryStream();
-            using var cts = new CancellationTokenSource(delay: TimeSpan.FromSeconds(5));
+            using var cts = new CancellationTokenSource(delay: TimeSpan.FromMilliseconds(cancellationTokenSourceDelayMilliseconds));
             await Assert.ThrowsAsync<TaskCanceledException>(async () =>
                 await JsonSerializer.SerializeAsync(utf8Stream, longRunningEnumerable, cancellationToken: cts.Token));
 
@@ -225,21 +233,42 @@ namespace System.Text.Json.Serialization.Tests
             static object[] WrapArgs<TSource>(IEnumerable<TSource> source, int delayInterval, int bufferSize) => new object[]{ source, delayInterval, bufferSize };
         }
 
-        private class MockedAsyncEnumerable<TElement> : IAsyncEnumerable<TElement>, IEnumerable<TElement>
+        [Fact]
+        public async Task RegressionTest_DisposingEnumeratorOnPendingMoveNextAsyncOperation()
+        {
+            // Regression test for https://github.com/dotnet/runtime/issues/57360
+            using var stream = new Utf8MemoryStream();
+            using var cts = new CancellationTokenSource(millisecondsDelay: 1000);
+            await Assert.ThrowsAsync<TaskCanceledException>(async () => await JsonSerializer.SerializeAsync(stream, GetNumbersAsync(), cancellationToken: cts.Token));
+
+            static async IAsyncEnumerable<int> GetNumbersAsync()
+            {
+                int i = 0;
+                while (true)
+                {
+                    await Task.Delay(100);
+                    yield return i++;
+                }
+            }
+        }
+
+        public class MockedAsyncEnumerable<TElement> : IAsyncEnumerable<TElement>, IEnumerable<TElement>
         {
             private readonly IEnumerable<TElement> _source;
             private readonly TimeSpan _delay;
             private readonly int _delayInterval;
+            private readonly bool _passCancellationTokenToDelayTask;
 
-            internal int TotalCreatedEnumerators { get; private set; }
-            internal int TotalDisposedEnumerators { get; private set; }
-            internal int TotalEnumeratedElements { get; private set; }
+            public int TotalCreatedEnumerators { get; private set; }
+            public int TotalDisposedEnumerators { get; private set; }
+            public int TotalEnumeratedElements { get; private set; }
 
-            public MockedAsyncEnumerable(IEnumerable<TElement> source, int delayInterval = 0, TimeSpan? delay = null)
+            public MockedAsyncEnumerable(IEnumerable<TElement> source, int delayInterval = 0, TimeSpan? delay = null, bool passCancellationTokenToDelayTask = true)
             {
                 _source = source;
                 _delay = delay ?? TimeSpan.FromMilliseconds(20);
                 _delayInterval = delayInterval;
+                _passCancellationTokenToDelayTask = passCancellationTokenToDelayTask;
             }
 
             public IAsyncEnumerator<TElement> GetAsyncEnumerator(CancellationToken cancellationToken = default)
@@ -277,7 +306,7 @@ namespace System.Text.Json.Serialization.Tests
                 {
                     if (i > 0 && _delayInterval > 0 && i % _delayInterval == 0)
                     {
-                        await Task.Delay(_delay, cancellationToken);
+                        await Task.Delay(_delay, _passCancellationTokenToDelayTask ? cancellationToken : default);
                     }
 
                     if (cancellationToken.IsCancellationRequested)