Fix async iterators to clear out state upon completion (#43522)
author: Stephen Toub <stoub@microsoft.com>
Sun, 18 Oct 2020 23:10:32 +0000 (19:10 -0400)
committer: GitHub <noreply@github.com>
Sun, 18 Oct 2020 23:10:32 +0000 (19:10 -0400)
AsyncIteratorMethodBuilder was only doing its cleanup for completion (e.g. zeroing out the state machine and context, removing the object from a debugger-incited tracking table) if the last call to the iterator was part of asynchronous completion; if the last MoveNextAsync completed synchronously, the code path taken could miss that cleanup work.

src/libraries/System.Private.CoreLib/src/System/Runtime/CompilerServices/AsyncIteratorMethodBuilder.cs
src/libraries/System.Private.CoreLib/src/System/Runtime/CompilerServices/AsyncTaskMethodBuilderT.cs
src/libraries/System.Private.CoreLib/src/System/Runtime/CompilerServices/AsyncValueTaskMethodBuilderT.cs
src/libraries/System.Private.CoreLib/src/System/Runtime/CompilerServices/IAsyncStateMachineBox.cs
src/libraries/System.Threading.Tasks/tests/System.Runtime.CompilerServices/AsyncIteratorMethodBuilderTests.cs [new file with mode: 0644]
src/libraries/System.Threading.Tasks/tests/System.Threading.Tasks.Tests.csproj

index 1f8f3d5..003047b 100644 (file)
@@ -3,6 +3,7 @@
 
 using System.Runtime.InteropServices;
 using System.Threading;
+using System.Threading.Tasks;
 
 namespace System.Runtime.CompilerServices
 {
@@ -10,24 +11,17 @@ namespace System.Runtime.CompilerServices
     [StructLayout(LayoutKind.Auto)]
     public struct AsyncIteratorMethodBuilder
     {
-        // AsyncIteratorMethodBuilder is used by the language compiler as part of generating
-        // async iterators. For now, the implementation just wraps AsyncTaskMethodBuilder, as
-        // most of the logic is shared.  However, in the future this could be changed and
-        // optimized.  For example, we do need to allocate an object (once) to flow state like
-        // ExecutionContext, which AsyncTaskMethodBuilder handles, but it handles it by
-        // allocating a Task-derived object.  We could optimize this further by removing
-        // the Task from the hierarchy, but in doing so we'd also lose a variety of optimizations
-        // related to it, so we'd need to replicate all of those optimizations (e.g. storing
-        // that box object directly into a Task's continuation field).
-
-        private AsyncTaskMethodBuilder _methodBuilder; // mutable struct; do not make it readonly
+        /// <summary>The lazily-initialized box/task object, created the first time the iterator awaits something not yet completed.</summary>
+        /// <remarks>
+        /// This will be the async state machine box created for the compiler-generated class (not struct) state machine
+        /// object for the async enumerator.  Even though it's not exposed as a Task property as on AsyncTaskMethodBuilder,
+        /// it needs to be stored if for no other reason than <see cref="Complete"/> needs to mark it completed in order to clean up.
+        /// </remarks>
+        private Task<VoidTaskResult>? m_task; // Debugger depends on the exact name of this field.
 
         /// <summary>Creates an instance of the <see cref="AsyncIteratorMethodBuilder"/> struct.</summary>
         /// <returns>The initialized instance.</returns>
-        public static AsyncIteratorMethodBuilder Create() =>
-            // _methodBuilder should be initialized to AsyncTaskMethodBuilder.Create(), but on coreclr
-            // that Create() is a nop, so we can just return the default here.
-            default;
+        public static AsyncIteratorMethodBuilder Create() => default;
 
         /// <summary>Invokes <see cref="IAsyncStateMachine.MoveNext"/> on the state machine while guarding the <see cref="ExecutionContext"/>.</summary>
         /// <typeparam name="TStateMachine">The type of the state machine.</typeparam>
@@ -44,7 +38,7 @@ namespace System.Runtime.CompilerServices
         public void AwaitOnCompleted<TAwaiter, TStateMachine>(ref TAwaiter awaiter, ref TStateMachine stateMachine)
             where TAwaiter : INotifyCompletion
             where TStateMachine : IAsyncStateMachine =>
-            _methodBuilder.AwaitOnCompleted(ref awaiter, ref stateMachine);
+            AsyncTaskMethodBuilder<VoidTaskResult>.AwaitOnCompleted(ref awaiter, ref stateMachine, ref m_task);
 
         /// <summary>Schedules the state machine to proceed to the next action when the specified awaiter completes.</summary>
         /// <typeparam name="TAwaiter">The type of the awaiter.</typeparam>
@@ -54,12 +48,41 @@ namespace System.Runtime.CompilerServices
         public void AwaitUnsafeOnCompleted<TAwaiter, TStateMachine>(ref TAwaiter awaiter, ref TStateMachine stateMachine)
             where TAwaiter : ICriticalNotifyCompletion
             where TStateMachine : IAsyncStateMachine =>
-            _methodBuilder.AwaitUnsafeOnCompleted(ref awaiter, ref stateMachine);
+            AsyncTaskMethodBuilder<VoidTaskResult>.AwaitUnsafeOnCompleted(ref awaiter, ref stateMachine, ref m_task);
 
         /// <summary>Marks iteration as being completed, whether successfully or otherwise.</summary>
-        public void Complete() => _methodBuilder.SetResult();
+        public void Complete()
+        {
+            // Nothing has created a box/task for this iterator yet, so there is no state to release;
+            // just record completion by storing the cached completed task.
+            if (m_task is null)
+            {
+                m_task = Task.s_cachedCompleted;
+                return;
+            }
+
+            // A box/task already exists, so mark it as completed.
+            AsyncTaskMethodBuilder<VoidTaskResult>.SetExistingTaskResult(m_task, default);
+
+            // Release whatever the box still references (for example, locals captured by the async
+            // enumerator) so that a lingering reference to the box doesn't keep them alive.  An async
+            // task gets this for free: the box/task's own MoveNext runs the completion logic after the
+            // state machine's MoveNext following the final await, or no box was ever created and there
+            // is nothing to clear.  An async iterator's task, however, spans the iterator's entire
+            // lifetime across any number of MoveNextAsync/DisposeAsync calls, and the compiler-inserted
+            // call to Complete in the generated MoveNext is the only signal that the whole thing has
+            // finished.  When the last MoveNextAsync/DisposeAsync completes asynchronously, that same
+            // box logic also clears the iterator; when it completes synchronously, that logic is
+            // skipped and the clearing has to happen here.  Because it is done here unconditionally,
+            // the state may be cleared twice, so the clearing logic must be idempotent.
+            if (m_task is IAsyncStateMachineBox stateMachineBox)
+            {
+                stateMachineBox.ClearStateUponCompletion();
+            }
+        }
 
         /// <summary>Gets an object that may be used to uniquely identify this builder to the debugger.</summary>
-        internal object ObjectIdForDebugger => _methodBuilder.ObjectIdForDebugger;
+        internal object ObjectIdForDebugger => m_task ??= AsyncTaskMethodBuilder<VoidTaskResult>.CreateWeaklyTypedStateMachineBox();
     }
 }
index 70ac433..2b027ce 100644 (file)
@@ -332,35 +332,41 @@ namespace System.Runtime.CompilerServices
                     }
                 }
 
-                if (IsCompleted)
+                if (loggingOn)
                 {
-                    // If async debugging is enabled, remove the task from tracking.
-                    if (System.Threading.Tasks.Task.s_asyncDebuggingEnabled)
-                    {
-                        System.Threading.Tasks.Task.RemoveFromActiveTasks(this);
-                    }
+                    TplEventSource.Log.TraceSynchronousWorkEnd(CausalitySynchronousWork.Execution);
+                }
+            }
 
-                    // Clear out state now that the async method has completed.
-                    // This avoids keeping arbitrary state referenced by lifted locals
-                    // if this Task / state machine box is held onto.
-                    StateMachine = default;
-                    Context = default;
+            /// <summary>Clears out all state associated with a completed box.</summary>
+            /// <remarks>
+            /// Must only be called once the box is completed (asserted in debug builds).  It removes the box from
+            /// active-task tracking when async debugging is enabled, resets <c>StateMachine</c> and <c>Context</c>
+            /// to their defaults, and, outside of CORERT builds, suppresses finalization when async method
+            /// completion tracking is on.
+            /// </remarks>
+            [MethodImpl(MethodImplOptions.AggressiveInlining)]
+            public void ClearStateUponCompletion()
+            {
+                Debug.Assert(IsCompleted);
 
-#if !CORERT
-                    // In case this is a state machine box with a finalizer, suppress its finalization
-                    // as it's now complete.  We only need the finalizer to run if the box is collected
-                    // without having been completed.
-                    if (AsyncMethodBuilderCore.TrackAsyncMethodCompletion)
-                    {
-                        GC.SuppressFinalize(this);
-                    }
-#endif
+                // This logic may be invoked multiple times on the same instance and needs to be robust against that.
+
+                // If async debugging is enabled, remove the task from tracking.
+                if (s_asyncDebuggingEnabled)
+                {
+                    RemoveFromActiveTasks(this);
                 }
 
-                if (loggingOn)
+                // Clear out state now that the async method has completed.
+                // This avoids keeping arbitrary state referenced by lifted locals
+                // if this Task / state machine box is held onto.
+                StateMachine = default;
+                Context = default;
+
+#if !CORERT
+                // In case this is a state machine box with a finalizer, suppress its finalization
+                // as it's now complete.  We only need the finalizer to run if the box is collected
+                // without having been completed.
+                if (AsyncMethodBuilderCore.TrackAsyncMethodCompletion)
                 {
-                    TplEventSource.Log.TraceSynchronousWorkEnd(CausalitySynchronousWork.Execution);
+                    GC.SuppressFinalize(this);
                 }
+#endif
             }
 
             /// <summary>Gets the state machine as a boxed object.  This should only be used for debugging purposes.</summary>
index 846c649..99656b7 100644 (file)
@@ -405,8 +405,7 @@ namespace System.Runtime.CompilerServices
                 // Clear out the state machine and associated context to avoid keeping arbitrary state referenced by
                 // lifted locals.  We want to do this regardless of whether we end up caching the box or not, in case
                 // the caller keeps the box alive for an arbitrary period of time.
-                StateMachine = default;
-                Context = default;
+                ClearStateUponCompletion();
 
                 // Reset the MRVTSC.  We can either do this here, in which case we may be paying the (small) overhead
                 // to reset the box even if we're going to drop it, or we could do it while holding the lock, in which
@@ -444,6 +443,16 @@ namespace System.Runtime.CompilerServices
             }
 
             /// <summary>
+            /// Clear out the state machine and associated context to avoid keeping arbitrary state referenced by lifted locals.
+            /// </summary>
+            /// <remarks>
+            /// Only resets <c>StateMachine</c> and <c>Context</c> to their default values; nothing else on the box is touched here.
+            /// </remarks>
+            [MethodImpl(MethodImplOptions.AggressiveInlining)]
+            public void ClearStateUponCompletion()
+            {
+                StateMachine = default;
+                Context = default;
+            }
+
+            /// <summary>
             /// Used to initialize s_callback above. We don't use a lambda for this on purpose: a lambda would
             /// introduce a new generic type behind the scenes that comes with a hefty size penalty in AOT builds.
             /// </summary>
index 67e9247..b424016 100644 (file)
@@ -19,5 +19,8 @@ namespace System.Runtime.CompilerServices
 
         /// <summary>Gets the state machine as a boxed object.  This should only be used for debugging purposes.</summary>
         IAsyncStateMachine GetStateMachineObject();
+
+        /// <summary>Clears the state of the box.</summary>
+        void ClearStateUponCompletion();
     }
 }
diff --git a/src/libraries/System.Threading.Tasks/tests/System.Runtime.CompilerServices/AsyncIteratorMethodBuilderTests.cs b/src/libraries/System.Threading.Tasks/tests/System.Runtime.CompilerServices/AsyncIteratorMethodBuilderTests.cs
new file mode 100644 (file)
index 0000000..80cac0b
--- /dev/null
@@ -0,0 +1,50 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using System.Collections.Concurrent;
+using System.Collections.Generic;
+using System.Diagnostics.Tracing;
+using System.Linq;
+using System.Reflection;
+using System.Runtime.CompilerServices;
+using System.Text;
+using Microsoft.DotNet.RemoteExecutor;
+using Xunit;
+using Xunit.Sdk;
+
+namespace System.Threading.Tasks.Tests
+{
+    /// <summary>Tests for <see cref="AsyncIteratorMethodBuilder"/>.</summary>
+    public class AsyncIteratorMethodBuilderTests
+    {
+        /// <summary>
+        /// Runs many async iterators to completion and checks that Task's active-task tracking table does not
+        /// accumulate an entry per iterator.
+        /// </summary>
+        /// <remarks>
+        /// Each iteration first awaits an incomplete task and then finishes on a second MoveNextAsync call that has
+        /// nothing left to await.  The body runs via RemoteExecutor, presumably so that flipping the static
+        /// s_asyncDebuggingEnabled flag does not affect other tests in this process.
+        /// </remarks>
+        [ConditionalFact(typeof(RemoteExecutor), nameof(RemoteExecutor.IsSupported))]
+        public void AsyncIteratorMethodBuilder_TaskCompleted()
+        {
+            RemoteExecutor.Invoke(() =>
+            {
+                // Iterator that awaits the supplied task once, yields a single value, and then ends.
+                static async IAsyncEnumerable<int> Func(TaskCompletionSource tcs)
+                {
+                    await tcs.Task;
+                    yield return 1;
+                }
+
+                // NOTE: This depends on private implementation details generally only used by the debugger.
+                // If those ever change, this test will need to be updated as well.
+
+                typeof(Task).GetField("s_asyncDebuggingEnabled", BindingFlags.NonPublic | BindingFlags.Static).SetValue(null, true);
+
+                for (int i = 0; i < 1000; i++)
+                {
+                    TaskCompletionSource tcs = new();
+                    IAsyncEnumerator<int> e = Func(tcs).GetAsyncEnumerator();
+                    // Start the iterator before completing the TCS so that the first MoveNextAsync has to await
+                    // an incomplete task.
+                    Task t = e.MoveNextAsync().AsTask();
+                    tcs.SetResult();
+                    t.Wait();
+                    // Nothing is awaited after the yield, so this call runs the iterator to its end.
+                    e.MoveNextAsync().AsTask().Wait();
+                }
+
+                // Read the number of entries in Task's private s_currentActiveTasks table via reflection.
+                int activeCount = ((dynamic)typeof(Task).GetField("s_currentActiveTasks", BindingFlags.NonPublic | BindingFlags.Static).GetValue(null)).Count;
+                Assert.InRange(activeCount, 0, 10); // some other tasks may be created by the runtime, so this is just using a reasonably small upper bound
+            }).Dispose();
+        }
+    }
+}
index 9b42159..5df7438 100644 (file)
@@ -46,6 +46,7 @@
     <!-- TaskScheduler -->
     <Compile Include="TaskScheduler\TaskSchedulerTests.cs" />
     <!-- System.Runtime.CompilerServices -->
+    <Compile Include="System.Runtime.CompilerServices\AsyncIteratorMethodBuilderTests.cs" />
     <Compile Include="System.Runtime.CompilerServices\AsyncTaskMethodBuilderTests.cs" />
     <Compile Include="System.Runtime.CompilerServices\ConfiguredAsyncDisposable.cs" />
     <Compile Include="System.Runtime.CompilerServices\ConfiguredCancelableAsyncEnumerableTests.cs" />