From: Eirik Tsarpalis Date: Tue, 17 Aug 2021 10:16:27 +0000 (+0100) Subject: Fix System.Text.Json IAsyncEnumerator disposal on cancellation (#57505) X-Git-Tag: accepted/tizen/unified/20220110.054933~344 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=61075fbe0d25668b4fa98aa80c2d6c004cf70afd;p=platform%2Fupstream%2Fdotnet%2Fruntime.git Fix System.Text.Json IAsyncEnumerator disposal on cancellation (#57505) * Ensure WriteStack.Pending task is awaited on exception. Ensure IAsyncDisposable instances are disposed exactly once. Fixes #57360. * Update src/libraries/System.Text.Json/tests/Common/CollectionTests/CollectionTests.AsyncEnumerable.cs --- diff --git a/src/libraries/System.Text.Json/src/System/Text/Json/Serialization/Converters/Collection/IAsyncEnumerableOfTConverter.cs b/src/libraries/System.Text.Json/src/System/Text/Json/Serialization/Converters/Collection/IAsyncEnumerableOfTConverter.cs index ca3def9..20dcedf 100644 --- a/src/libraries/System.Text.Json/src/System/Text/Json/Serialization/Converters/Collection/IAsyncEnumerableOfTConverter.cs +++ b/src/libraries/System.Text.Json/src/System/Text/Json/Serialization/Converters/Collection/IAsyncEnumerableOfTConverter.cs @@ -48,18 +48,22 @@ namespace System.Text.Json.Serialization.Converters IAsyncEnumerator enumerator; ValueTask moveNextTask; - if (state.Current.AsyncEnumerator is null) + if (state.Current.AsyncDisposable is null) { enumerator = value.GetAsyncEnumerator(state.CancellationToken); + // async enumerators can only be disposed asynchronously; + // store in the WriteStack for future disposal + // by the root async serialization context. + state.Current.AsyncDisposable = enumerator; + // enumerator.MoveNextAsync() calls can throw, + // ensure the enumerator already is stored + // in the WriteStack for proper disposal. moveNextTask = enumerator.MoveNextAsync(); - // we always need to attach the enumerator to the stack - // since it will need to be disposed asynchronously. - state.Current.AsyncEnumerator = enumerator; } else { - Debug.Assert(state.Current.AsyncEnumerator is IAsyncEnumerator); - enumerator = (IAsyncEnumerator)state.Current.AsyncEnumerator; + Debug.Assert(state.Current.AsyncDisposable is IAsyncEnumerator); + enumerator = (IAsyncEnumerator)state.Current.AsyncDisposable; if (state.Current.AsyncEnumeratorIsPendingCompletion) { @@ -84,6 +88,10 @@ namespace System.Text.Json.Serialization.Converters { if (!moveNextTask.Result) { + // we have completed serialization for the enumerator, + // clear from the stack and schedule for async disposal. + state.Current.AsyncDisposable = null; + state.AddCompletedAsyncDisposable(enumerator); return true; } diff --git a/src/libraries/System.Text.Json/src/System/Text/Json/Serialization/JsonSerializer.Write.Stream.cs b/src/libraries/System.Text.Json/src/System/Text/Json/Serialization/JsonSerializer.Write.Stream.cs index bd04e35..04b7b10 100644 --- a/src/libraries/System.Text.Json/src/System/Text/Json/Serialization/JsonSerializer.Write.Stream.cs +++ b/src/libraries/System.Text.Json/src/System/Text/Json/Serialization/JsonSerializer.Write.Stream.cs @@ -325,28 +325,30 @@ namespace System.Text.Json try { isFinalBlock = WriteCore(converter, writer, value, options, ref state); + await bufferWriter.WriteToStreamAsync(utf8Json, cancellationToken).ConfigureAwait(false); + bufferWriter.Clear(); } finally { - if (state.PendingAsyncDisposables?.Count > 0) + // Await any pending resumable converter tasks (currently these can only be IAsyncEnumerator.MoveNextAsync() tasks). + // Note that pending tasks are always awaited, even if an exception has been thrown or the cancellation token has fired. + if (state.PendingTask is not null) { - await state.DisposePendingAsyncDisposables().ConfigureAwait(false); + try + { + await state.PendingTask.ConfigureAwait(false); + } + catch + { + // Exceptions should only be propagated by the resuming converter + // TODO https://github.com/dotnet/runtime/issues/22144 + } } - } - - await bufferWriter.WriteToStreamAsync(utf8Json, cancellationToken).ConfigureAwait(false); - bufferWriter.Clear(); - if (state.PendingTask is not null) - { - try - { - await state.PendingTask.ConfigureAwait(false); - } - catch + // Dispose any pending async disposables (currently these can only be completed IAsyncEnumerators). + if (state.CompletedAsyncDisposables?.Count > 0) { - // Exceptions will be propagated elsewhere - // TODO https://github.com/dotnet/runtime/issues/22144 + await state.DisposeCompletedAsyncDisposables().ConfigureAwait(false); } } @@ -354,6 +356,7 @@ namespace System.Text.Json } catch { + // On exception, walk the WriteStack for any orphaned disposables and try to dispose them. await state.DisposePendingDisposablesOnExceptionAsync().ConfigureAwait(false); throw; } diff --git a/src/libraries/System.Text.Json/src/System/Text/Json/Serialization/WriteStack.cs b/src/libraries/System.Text.Json/src/System/Text/Json/Serialization/WriteStack.cs index bc313da..7312f6f 100644 --- a/src/libraries/System.Text.Json/src/System/Text/Json/Serialization/WriteStack.cs +++ b/src/libraries/System.Text.Json/src/System/Text/Json/Serialization/WriteStack.cs @@ -47,9 +47,9 @@ namespace System.Text.Json public Task? PendingTask; /// - /// List of IAsyncDisposables that have been scheduled for disposal by converters. + /// List of completed IAsyncDisposables that have been scheduled for disposal by converters. /// - public List? PendingAsyncDisposables; + public List? CompletedAsyncDisposables; /// /// The amount of bytes to write before the underlying Stream should be flushed and the @@ -196,14 +196,6 @@ namespace System.Text.Json { Debug.Assert(_continuationCount == 0); - if (Current.AsyncEnumerator is not null) - { - // we have completed serialization of an AsyncEnumerator, - // pop from the stack and schedule for async disposal. - PendingAsyncDisposables ??= new List(); - PendingAsyncDisposables.Add(Current.AsyncEnumerator); - } - if (--_count > 0) { Current = _stack[_count - 1]; @@ -211,13 +203,16 @@ namespace System.Text.Json } } + public void AddCompletedAsyncDisposable(IAsyncDisposable asyncDisposable) + => (CompletedAsyncDisposables ??= new List()).Add(asyncDisposable); + // Asynchronously dispose of any AsyncDisposables that have been scheduled for disposal - public async ValueTask DisposePendingAsyncDisposables() + public async ValueTask DisposeCompletedAsyncDisposables() { - Debug.Assert(PendingAsyncDisposables?.Count > 0); + Debug.Assert(CompletedAsyncDisposables?.Count > 0); Exception? exception = null; - foreach (IAsyncDisposable asyncDisposable in PendingAsyncDisposables) + foreach (IAsyncDisposable asyncDisposable in CompletedAsyncDisposables) { try { @@ -234,7 +229,7 @@ namespace System.Text.Json ExceptionDispatchInfo.Capture(exception).Throw(); } - PendingAsyncDisposables.Clear(); + CompletedAsyncDisposables.Clear(); } /// @@ -245,13 +240,13 @@ namespace System.Text.Json { Exception? exception = null; - Debug.Assert(Current.AsyncEnumerator is null); + Debug.Assert(Current.AsyncDisposable is null); DisposeFrame(Current.CollectionEnumerator, ref exception); int stackSize = Math.Max(_count, _continuationCount); for (int i = 0; i < stackSize - 1; i++) { - Debug.Assert(_stack[i].AsyncEnumerator is null); + Debug.Assert(_stack[i].AsyncDisposable is null); DisposeFrame(_stack[i].CollectionEnumerator, ref exception); } @@ -284,12 +279,12 @@ namespace System.Text.Json { Exception? exception = null; - exception = await DisposeFrame(Current.CollectionEnumerator, Current.AsyncEnumerator, exception).ConfigureAwait(false); + exception = await DisposeFrame(Current.CollectionEnumerator, Current.AsyncDisposable, exception).ConfigureAwait(false); int stackSize = Math.Max(_count, _continuationCount); for (int i = 0; i < stackSize - 1; i++) { - exception = await DisposeFrame(_stack[i].CollectionEnumerator, _stack[i].AsyncEnumerator, exception).ConfigureAwait(false); + exception = await DisposeFrame(_stack[i].CollectionEnumerator, _stack[i].AsyncDisposable, exception).ConfigureAwait(false); } if (exception is not null) diff --git a/src/libraries/System.Text.Json/src/System/Text/Json/Serialization/WriteStackFrame.cs b/src/libraries/System.Text.Json/src/System/Text/Json/Serialization/WriteStackFrame.cs index 97002fb..fdd4972 100644 --- a/src/libraries/System.Text.Json/src/System/Text/Json/Serialization/WriteStackFrame.cs +++ b/src/libraries/System.Text.Json/src/System/Text/Json/Serialization/WriteStackFrame.cs @@ -19,7 +19,7 @@ namespace System.Text.Json /// /// The enumerator for resumable async disposables. /// - public IAsyncDisposable? AsyncEnumerator; + public IAsyncDisposable? AsyncDisposable; /// /// The current stackframe has suspended serialization due to a pending task, diff --git a/src/libraries/System.Text.Json/tests/Common/CollectionTests/CollectionTests.AsyncEnumerable.cs b/src/libraries/System.Text.Json/tests/Common/CollectionTests/CollectionTests.AsyncEnumerable.cs index 9e57524..ba855bd 100644 --- a/src/libraries/System.Text.Json/tests/Common/CollectionTests/CollectionTests.AsyncEnumerable.cs +++ b/src/libraries/System.Text.Json/tests/Common/CollectionTests/CollectionTests.AsyncEnumerable.cs @@ -74,16 +74,24 @@ namespace System.Text.Json.Serialization.Tests Assert.Equal(1, asyncEnumerable.TotalDisposedEnumerators); } - [Fact, OuterLoop] - public async Task WriteAsyncEnumerable_LongRunningEnumeration_Cancellation() + [Theory, OuterLoop] + [InlineData(5000, 1000, true)] + [InlineData(5000, 1000, false)] + [InlineData(1000, 10_000, true)] + [InlineData(1000, 10_000, false)] + public async Task WriteAsyncEnumerable_LongRunningEnumeration_Cancellation( + int cancellationTokenSourceDelayMilliseconds, + int enumeratorDelayMilliseconds, + bool passCancellationTokenToDelayTask) { var longRunningEnumerable = new MockedAsyncEnumerable( - source: Enumerable.Range(1, 100), + source: Enumerable.Range(1, 1000), delayInterval: 1, - delay: TimeSpan.FromMinutes(1)); + delay: TimeSpan.FromMilliseconds(enumeratorDelayMilliseconds), + passCancellationTokenToDelayTask); using var utf8Stream = new Utf8MemoryStream(); - using var cts = new CancellationTokenSource(delay: TimeSpan.FromSeconds(5)); + using var cts = new CancellationTokenSource(delay: TimeSpan.FromMilliseconds(cancellationTokenSourceDelayMilliseconds)); await Assert.ThrowsAsync(async () => await JsonSerializer.SerializeAsync(utf8Stream, longRunningEnumerable, cancellationToken: cts.Token)); @@ -225,21 +233,42 @@ namespace System.Text.Json.Serialization.Tests static object[] WrapArgs(IEnumerable source, int delayInterval, int bufferSize) => new object[]{ source, delayInterval, bufferSize }; } - private class MockedAsyncEnumerable : IAsyncEnumerable, IEnumerable + [Fact] + public async Task RegressionTest_DisposingEnumeratorOnPendingMoveNextAsyncOperation() + { + // Regression test for https://github.com/dotnet/runtime/issues/57360 + using var stream = new Utf8MemoryStream(); + using var cts = new CancellationTokenSource(millisecondsDelay: 1000); + await Assert.ThrowsAsync(async () => await JsonSerializer.SerializeAsync(stream, GetNumbersAsync(), cancellationToken: cts.Token)); + + static async IAsyncEnumerable GetNumbersAsync() + { + int i = 0; + while (true) + { + await Task.Delay(100); + yield return i++; + } + } + } + + public class MockedAsyncEnumerable : IAsyncEnumerable, IEnumerable { private readonly IEnumerable _source; private readonly TimeSpan _delay; private readonly int _delayInterval; + private readonly bool _passCancellationTokenToDelayTask; - internal int TotalCreatedEnumerators { get; private set; } - internal int TotalDisposedEnumerators { get; private set; } - internal int TotalEnumeratedElements { get; private set; } + public int TotalCreatedEnumerators { get; private set; } + public int TotalDisposedEnumerators { get; private set; } + public int TotalEnumeratedElements { get; private set; } - public MockedAsyncEnumerable(IEnumerable source, int delayInterval = 0, TimeSpan? delay = null) + public MockedAsyncEnumerable(IEnumerable source, int delayInterval = 0, TimeSpan? delay = null, bool passCancellationTokenToDelayTask = true) { _source = source; _delay = delay ?? TimeSpan.FromMilliseconds(20); _delayInterval = delayInterval; + _passCancellationTokenToDelayTask = passCancellationTokenToDelayTask; } public IAsyncEnumerator GetAsyncEnumerator(CancellationToken cancellationToken = default) @@ -277,7 +306,7 @@ namespace System.Text.Json.Serialization.Tests { if (i > 0 && _delayInterval > 0 && i % _delayInterval == 0) { - await Task.Delay(_delay, cancellationToken); + await Task.Delay(_delay, _passCancellationTokenToDelayTask ? cancellationToken : default); } if (cancellationToken.IsCancellationRequested)