Allow default AsyncFlowControls rather than throwing (#82912)
authorStephen Toub <stoub@microsoft.com>
Fri, 3 Mar 2023 11:43:56 +0000 (06:43 -0500)
committerGitHub <noreply@github.com>
Fri, 3 Mar 2023 11:43:56 +0000 (06:43 -0500)
ExecutionContext.SuppressFlow currently throws an exception if flow is already suppressed.  This makes it complicated to use, as you need to check whether IsFlowSuppressed first and take two different paths based on the result.  If we instead just allow SuppressFlow to return a default AsyncFlowControl rather than throwing, and have AsyncFlowControl's Undo nop rather than throw if it doesn't contain a Thread, we can again make it simple to just always use SuppressFlow without any of the other complications.

12 files changed:
src/libraries/System.IO.FileSystem.Watcher/tests/FileSystemWatcher.MultipleWatchers.cs
src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpConnectionPool.cs
src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpConnectionPoolManager.cs
src/libraries/System.Net.NetworkInformation/src/System/Net/NetworkInformation/NetworkAddressChange.Unix.cs
src/libraries/System.Net.Sockets/tests/FunctionalTests/ExecutionContextFlowTest.cs
src/libraries/System.Net.Sockets/tests/FunctionalTests/SocketAsyncEventArgsTest.cs
src/libraries/System.Private.CoreLib/src/Resources/Strings.resx
src/libraries/System.Private.CoreLib/src/System/Threading/ExecutionContext.cs
src/libraries/System.Threading.Tasks.Parallel/tests/ParallelForEachAsyncTests.cs
src/libraries/System.Threading.Tasks/tests/Task/ExecutionContextFlowTest.cs
src/libraries/System.Threading/tests/AsyncLocalTests.cs
src/libraries/System.Threading/tests/ExecutionContextTests.cs

index b5ab5a8..9d15f22 100644 (file)
@@ -79,20 +79,15 @@ namespace System.IO.Tests
 
                 local.Value = 42;
 
-                ExecutionContext.SuppressFlow();
-                try
+                using (ExecutionContext.SuppressFlow())
                 {
                     watcher1.EnableRaisingEvents = true;
                 }
-                finally
-                {
-                    ExecutionContext.RestoreFlow();
-                }
 
-                    File.Create(fileName).Dispose();
-                    tcs1.Task.Wait(WaitForExpectedEventTimeout);
+                File.Create(fileName).Dispose();
+                tcs1.Task.Wait(WaitForExpectedEventTimeout);
 
-                    Assert.Equal(0, tcs1.Task.Result);
+                Assert.Equal(0, tcs1.Task.Result);
            }
         }
 
index 1d8c44a..df08584 100644 (file)
@@ -1291,15 +1291,8 @@ namespace System.Net.Http
                     {
                         var thisRef = new WeakReference<HttpConnectionPool>(this);
 
-                        bool restoreFlow = false;
-                        try
+                        using (ExecutionContext.SuppressFlow())
                         {
-                            if (!ExecutionContext.IsFlowSuppressed())
-                            {
-                                ExecutionContext.SuppressFlow();
-                                restoreFlow = true;
-                            }
-
                             _authorityExpireTimer = new Timer(static o =>
                             {
                                 var wr = (WeakReference<HttpConnectionPool>)o!;
@@ -1309,10 +1302,6 @@ namespace System.Net.Http
                                 }
                             }, thisRef, nextAuthorityMaxAge, Timeout.InfiniteTimeSpan);
                         }
-                        finally
-                        {
-                            if (restoreFlow) ExecutionContext.RestoreFlow();
-                        }
                     }
                     else
                     {
index fdfcd5f..47772a2 100644 (file)
@@ -91,16 +91,8 @@ namespace System.Net.Http
                     _cleanPoolTimeout = timerPeriod.TotalSeconds >= MinScavengeSeconds ? timerPeriod : TimeSpan.FromSeconds(MinScavengeSeconds);
                 }
 
-                bool restoreFlow = false;
-                try
+                using (ExecutionContext.SuppressFlow()) // Don't capture the current ExecutionContext and its AsyncLocals onto the timer causing them to live forever
                 {
-                    // Don't capture the current ExecutionContext and its AsyncLocals onto the timer causing them to live forever
-                    if (!ExecutionContext.IsFlowSuppressed())
-                    {
-                        ExecutionContext.SuppressFlow();
-                        restoreFlow = true;
-                    }
-
                     // Create the timer.  Ensure the Timer has a weak reference to this manager; otherwise, it
                     // can introduce a cycle that keeps the HttpConnectionPoolManager rooted by the Timer
                     // implementation until the handler is Disposed (or indefinitely if it's not).
@@ -131,14 +123,6 @@ namespace System.Net.Http
                         }, thisRef, heartBeatInterval, heartBeatInterval);
                     }
                 }
-                finally
-                {
-                    // Restore the current ExecutionContext
-                    if (restoreFlow)
-                    {
-                        ExecutionContext.RestoreFlow();
-                    }
-                }
             }
 
             // Figure out proxy stuff.
@@ -190,14 +174,7 @@ namespace System.Net.Http
                 return;
             }
 
-            if (!ExecutionContext.IsFlowSuppressed())
-            {
-                using (ExecutionContext.SuppressFlow())
-                {
-                    NetworkChange.NetworkAddressChanged += networkChangedDelegate;
-                }
-            }
-            else
+            using (ExecutionContext.SuppressFlow())
             {
                 NetworkChange.NetworkAddressChanged += networkChangedDelegate;
             }
index 7fdc445..b55a6fe 100644 (file)
@@ -104,23 +104,10 @@ namespace System.Net.NetworkInformation
                         if (s_availabilityTimer == null)
                         {
                             // Don't capture the current ExecutionContext and its AsyncLocals onto the timer causing them to live forever
-                            bool restoreFlow = false;
-                            try
+                            using (ExecutionContext.SuppressFlow())
                             {
-                                if (!ExecutionContext.IsFlowSuppressed())
-                                {
-                                    ExecutionContext.SuppressFlow();
-                                    restoreFlow = true;
-                                }
-
                                 s_availabilityTimer = new Timer(s_availabilityTimerFiredCallback, null, Timeout.Infinite, Timeout.Infinite);
                             }
-                            finally
-                            {
-                                // Restore the current ExecutionContext
-                                if (restoreFlow)
-                                    ExecutionContext.RestoreFlow();
-                            }
                         }
 
                         s_availabilityChangedSubscribers.TryAdd(value, ExecutionContext.Capture());
index 7b7addb..41f4738 100644 (file)
@@ -33,15 +33,10 @@ namespace System.Net.Sockets.Tests
                 };
 
                 asyncLocal.Value = 42;
-                if (suppressContext) ExecutionContext.SuppressFlow();
-                try
+                using (suppressContext ? ExecutionContext.SuppressFlow() : default)
                 {
                     Assert.True(listener.AcceptAsync(saea));
                 }
-                finally
-                {
-                    if (suppressContext) ExecutionContext.RestoreFlow();
-                }
                 asyncLocal.Value = 0;
 
                 client.Connect(listener.LocalEndPoint);
@@ -65,8 +60,7 @@ namespace System.Net.Sockets.Tests
                 var tcs = new TaskCompletionSource<int>();
 
                 asyncLocal.Value = 42;
-                if (suppressContext) ExecutionContext.SuppressFlow();
-                try
+                using (suppressContext ? ExecutionContext.SuppressFlow() : default)
                 {
                     listener.BeginAccept(iar =>
                     {
@@ -74,10 +68,6 @@ namespace System.Net.Sockets.Tests
                         tcs.SetResult(asyncLocal.Value);
                     }, null);
                 }
-                finally
-                {
-                    if (suppressContext) ExecutionContext.RestoreFlow();
-                }
                 asyncLocal.Value = 0;
 
                 client.Connect(listener.LocalEndPoint);
@@ -105,15 +95,10 @@ namespace System.Net.Sockets.Tests
 
                 bool pending;
                 asyncLocal.Value = 42;
-                if (suppressContext) ExecutionContext.SuppressFlow();
-                try
+                using (suppressContext ? ExecutionContext.SuppressFlow() : default)
                 {
                     pending = client.ConnectAsync(saea);
                 }
-                finally
-                {
-                    if (suppressContext) ExecutionContext.RestoreFlow();
-                }
                 asyncLocal.Value = 0;
 
                 if (pending)
@@ -139,8 +124,7 @@ namespace System.Net.Sockets.Tests
 
                 bool pending;
                 asyncLocal.Value = 42;
-                if (suppressContext) ExecutionContext.SuppressFlow();
-                try
+                using (suppressContext ? ExecutionContext.SuppressFlow() : default)
                 {
                     pending = !client.BeginConnect(listener.LocalEndPoint, iar =>
                     {
@@ -148,10 +132,6 @@ namespace System.Net.Sockets.Tests
                         tcs.SetResult(asyncLocal.Value);
                     }, null).CompletedSynchronously;
                 }
-                finally
-                {
-                    if (suppressContext) ExecutionContext.RestoreFlow();
-                }
                 asyncLocal.Value = 0;
 
                 if (pending)
@@ -182,15 +162,10 @@ namespace System.Net.Sockets.Tests
 
                     bool pending;
                     asyncLocal.Value = 42;
-                    if (suppressContext) ExecutionContext.SuppressFlow();
-                    try
+                    using (suppressContext ? ExecutionContext.SuppressFlow() : default)
                     {
                         pending = client.DisconnectAsync(saea);
                     }
-                    finally
-                    {
-                        if (suppressContext) ExecutionContext.RestoreFlow();
-                    }
                     asyncLocal.Value = 0;
 
                     if (pending)
@@ -220,8 +195,7 @@ namespace System.Net.Sockets.Tests
 
                     bool pending;
                     asyncLocal.Value = 42;
-                    if (suppressContext) ExecutionContext.SuppressFlow();
-                    try
+                    using (suppressContext ? ExecutionContext.SuppressFlow() : default)
                     {
                         pending = !client.BeginDisconnect(reuseSocket: false, iar =>
                         {
@@ -229,10 +203,6 @@ namespace System.Net.Sockets.Tests
                             tcs.SetResult(asyncLocal.Value);
                         }, null).CompletedSynchronously;
                     }
-                    finally
-                    {
-                        if (suppressContext) ExecutionContext.RestoreFlow();
-                    }
                     asyncLocal.Value = 0;
 
                     if (pending)
@@ -267,17 +237,12 @@ namespace System.Net.Sockets.Tests
                     saea.RemoteEndPoint = server.LocalEndPoint;
 
                     asyncLocal.Value = 42;
-                    if (suppressContext) ExecutionContext.SuppressFlow();
-                    try
+                    using (suppressContext ? ExecutionContext.SuppressFlow() : default)
                     {
                         Assert.True(receiveFrom ?
                             client.ReceiveFromAsync(saea) :
                             client.ReceiveAsync(saea));
                     }
-                    finally
-                    {
-                        if (suppressContext) ExecutionContext.RestoreFlow();
-                    }
                     asyncLocal.Value = 0;
 
                     server.Send(new byte[] { 18 });
@@ -306,8 +271,7 @@ namespace System.Net.Sockets.Tests
                     var tcs = new TaskCompletionSource<int>();
 
                     asyncLocal.Value = 42;
-                    if (suppressContext) ExecutionContext.SuppressFlow();
-                    try
+                    using (suppressContext ? ExecutionContext.SuppressFlow() : default)
                     {
                         EndPoint ep = server.LocalEndPoint;
                         Assert.False(receiveFrom ?
@@ -322,11 +286,6 @@ namespace System.Net.Sockets.Tests
                                 tcs.SetResult(asyncLocal.Value);
                             }, null).CompletedSynchronously);
                     }
-                    finally
-                    {
-                        if (suppressContext)
-                            ExecutionContext.RestoreFlow();
-                    }
                     asyncLocal.Value = 0;
 
                     server.Send(new byte[] { 18 });
@@ -365,18 +324,13 @@ namespace System.Net.Sockets.Tests
 
                     bool pending;
                     asyncLocal.Value = 42;
-                    if (suppressContext) ExecutionContext.SuppressFlow();
-                    try
+                    using (suppressContext ? ExecutionContext.SuppressFlow() : default)
                     {
                         pending =
                             sendMode == 0 ? client.SendAsync(saea) :
                             sendMode == 1 ? client.SendToAsync(saea) :
                             client.SendPacketsAsync(saea);
                     }
-                    finally
-                    {
-                        if (suppressContext) ExecutionContext.RestoreFlow();
-                    }
                     asyncLocal.Value = 0;
 
                     int totalReceived = 0;
@@ -416,8 +370,7 @@ namespace System.Net.Sockets.Tests
 
                     bool pending;
                     asyncLocal.Value = 42;
-                    if (suppressContext) ExecutionContext.SuppressFlow();
-                    try
+                    using (suppressContext ? ExecutionContext.SuppressFlow() : default)
                     {
                         pending = sendTo ?
                             !client.BeginSendTo(buffer, 0, buffer.Length, SocketFlags.None, server.LocalEndPoint, iar =>
@@ -431,10 +384,6 @@ namespace System.Net.Sockets.Tests
                                 tcs.SetResult(asyncLocal.Value);
                             }, null).CompletedSynchronously;
                     }
-                    finally
-                    {
-                        if (suppressContext) ExecutionContext.RestoreFlow();
-                    }
                     asyncLocal.Value = 0;
 
                     int totalReceived = 0;
@@ -477,8 +426,7 @@ namespace System.Net.Sockets.Tests
 
                     bool pending;
                     asyncLocal.Value = 42;
-                    if (suppressContext) ExecutionContext.SuppressFlow();
-                    try
+                    using (suppressContext ? ExecutionContext.SuppressFlow() : default)
                     {
                         pending = !client.BeginSendFile(filePath, iar =>
                         {
@@ -486,10 +434,6 @@ namespace System.Net.Sockets.Tests
                             tcs.SetResult(asyncLocal.Value);
                         }, null).CompletedSynchronously;
                     }
-                    finally
-                    {
-                        if (suppressContext) ExecutionContext.RestoreFlow();
-                    }
                     asyncLocal.Value = 0;
 
                     if (pending)
index 7865f20..bf7e8c5 100644 (file)
@@ -122,27 +122,25 @@ namespace System.Net.Sockets.Tests
                 using (Socket server = await acceptTask)
                 using (var receiveSaea = new SocketAsyncEventArgs())
                 {
-                    if (suppressed)
+                    using (suppressed ? ExecutionContext.SuppressFlow() : default)
                     {
-                        ExecutionContext.SuppressFlow();
-                    }
-
-                    var local = new AsyncLocal<int>();
-                    local.Value = 42;
-                    int threadId = Environment.CurrentManagedThreadId;
+                        var local = new AsyncLocal<int>();
+                        local.Value = 42;
+                        int threadId = Environment.CurrentManagedThreadId;
 
-                    var mres = new ManualResetEventSlim();
-                    receiveSaea.SetBuffer(new byte[1], 0, 1);
-                    receiveSaea.Completed += delegate
-                    {
-                        Assert.NotEqual(threadId, Environment.CurrentManagedThreadId);
-                        Assert.Equal(suppressed ? 0 : 42, local.Value);
-                        mres.Set();
-                    };
-
-                    Assert.True(client.ReceiveAsync(receiveSaea));
-                    server.Send(new byte[1]);
-                    mres.Wait();
+                        var mres = new ManualResetEventSlim();
+                        receiveSaea.SetBuffer(new byte[1], 0, 1);
+                        receiveSaea.Completed += delegate
+                        {
+                            Assert.NotEqual(threadId, Environment.CurrentManagedThreadId);
+                            Assert.Equal(suppressed ? 0 : 42, local.Value);
+                            mres.Set();
+                        };
+
+                        Assert.True(client.ReceiveAsync(receiveSaea));
+                        server.Send(new byte[1]);
+                        mres.Wait();
+                    }
                 }
             }
         }
index 3d9d7ae..0fd2f1e 100644 (file)
   <data name="InvalidOperation_CannotRestoreUnsuppressedFlow" xml:space="preserve">
     <value>Cannot restore context flow when it is not suppressed.</value>
   </data>
-  <data name="InvalidOperation_CannotSuppressFlowMultipleTimes" xml:space="preserve">
-    <value>Context flow is already suppressed.</value>
-  </data>
-  <data name="InvalidOperation_CannotUseAFCMultiple" xml:space="preserve">
-    <value>AsyncFlowControl object can be used only once to call Undo().</value>
-  </data>
   <data name="InvalidOperation_CannotUseAFCOtherThread" xml:space="preserve">
     <value>AsyncFlowControl object must be used on the thread where it was created.</value>
   </data>
index f9c966f..a08bf91 100644 (file)
@@ -102,15 +102,14 @@ namespace System.Threading
         {
             Thread currentThread = Thread.CurrentThread;
             ExecutionContext? executionContext = currentThread._executionContext ?? Default;
-            if (executionContext.m_isFlowSuppressed)
+
+            AsyncFlowControl asyncFlowControl = default;
+            if (!executionContext.m_isFlowSuppressed)
             {
-                throw new InvalidOperationException(SR.InvalidOperation_CannotSuppressFlowMultipleTimes);
+                currentThread._executionContext = executionContext.ShallowClone(isFlowSuppressed: true);
+                asyncFlowControl.Initialize(currentThread);
             }
 
-            executionContext = executionContext.ShallowClone(isFlowSuppressed: true);
-            AsyncFlowControl asyncFlowControl = default;
-            currentThread._executionContext = executionContext;
-            asyncFlowControl.Initialize(currentThread);
             return asyncFlowControl;
         }
 
@@ -563,10 +562,11 @@ namespace System.Threading
 
         public void Undo()
         {
-            if (_thread == null)
+            if (_thread is null)
             {
-                throw new InvalidOperationException(SR.InvalidOperation_CannotUseAFCMultiple);
+                return;
             }
+
             if (Thread.CurrentThread != _thread)
             {
                 throw new InvalidOperationException(SR.InvalidOperation_CannotUseAFCOtherThread);
index 2187590..b68b3d2 100644 (file)
@@ -1007,20 +1007,14 @@ namespace System.Threading.Tasks.Tests
             var al = new AsyncLocal<int>();
             al.Value = 42;
 
-            if (!flowContext)
+            Task t;
+            using (!flowContext ? ExecutionContext.SuppressFlow() : default)
             {
-                ExecutionContext.SuppressFlow();
-            }
-
-            Task t = Parallel.ForEachAsync(Iterate(), async (item, cancellationToken) =>
-            {
-                await Task.Yield();
-                Assert.Equal(flowContext ? 42 : 0, al.Value);
-            });
-
-            if (!flowContext)
-            {
-                ExecutionContext.RestoreFlow();
+                t = Parallel.ForEachAsync(Iterate(), async (item, cancellationToken) =>
+                {
+                    await Task.Yield();
+                    Assert.Equal(flowContext ? 42 : 0, al.Value);
+                });
             }
 
             await t;
index 5ba5ec4..bd151c6 100644 (file)
@@ -16,17 +16,12 @@ namespace System.Threading.Tasks.Tests
         public void SuppressFlow_TaskCapturesContextAccordingly(bool suppressFlow)
         {
             Assert.False(ExecutionContext.IsFlowSuppressed());
-            if (suppressFlow) ExecutionContext.SuppressFlow();
-            try
+            using (suppressFlow ? ExecutionContext.SuppressFlow() : default)
             {
                 var asyncLocal = new AsyncLocal<int>();
                 Task.Factory.StartNew(() => asyncLocal.Value = 42, CancellationToken.None, TaskCreationOptions.None, new InlineTaskScheduler()).Wait();
                 Assert.Equal(suppressFlow ? 42 : 0, asyncLocal.Value);
             }
-            finally
-            {
-                if (suppressFlow) ExecutionContext.RestoreFlow();
-            }
         }
 
         [ActiveIssue("https://github.com/dotnet/runtime/issues/57331")]
index c770890..dd7f3eb 100644 (file)
@@ -106,16 +106,11 @@ namespace System.Threading.Tests
         [Fact]
         public static async Task CaptureAndRunOnFlowSuppressedContext()
         {
-            ExecutionContext.SuppressFlow();
-            try
+            using (ExecutionContext.SuppressFlow())
             {
                 ExecutionContext ec = ExecutionContext.Capture();
                 Assert.Throws<InvalidOperationException>(() => ExecutionContext.Run(ec, _ => { }, null));
             }
-            finally
-            {
-                ExecutionContext.RestoreFlow();
-            }
         }
 
         [Fact]
@@ -598,16 +593,11 @@ namespace System.Threading.Tests
             // Check Running with the contexts captured when setting the locals
             TestCapturedExecutionContexts();
 
-            ExecutionContext.SuppressFlow();
-            try
+            using (ExecutionContext.SuppressFlow())
             {
                 // Re-check restoring, but starting with a suppressed flow
                 TestCapturedExecutionContexts();
             }
-            finally
-            {
-                ExecutionContext.RestoreFlow();
-            }
 
             // -- Local functions --
             void ValidateAsyncLocalsValuesNull()
index 2054fb3..178a6dc 100644 (file)
@@ -96,9 +96,18 @@ namespace System.Threading.Tests
                     () => ExecutionContext.SuppressFlow(),
                     () => ExecutionContext.RestoreFlow());
 
+                Assert.False(ExecutionContext.IsFlowSuppressed());
                 Assert.Throws<InvalidOperationException>(() => ExecutionContext.RestoreFlow());
+
+                Assert.False(ExecutionContext.IsFlowSuppressed());
                 asyncFlowControl = ExecutionContext.SuppressFlow();
-                Assert.Throws<InvalidOperationException>(() => ExecutionContext.SuppressFlow());
+                Assert.True(ExecutionContext.IsFlowSuppressed());
+
+                Assert.Equal(default, ExecutionContext.SuppressFlow());
+                Assert.True(ExecutionContext.IsFlowSuppressed());
+
+                ExecutionContext.SuppressFlow().Dispose();
+                Assert.True(ExecutionContext.IsFlowSuppressed());
 
                 ThreadTestHelpers.RunTestInBackgroundThread(() =>
                 {
@@ -109,8 +118,9 @@ namespace System.Threading.Tests
                 });
 
                 asyncFlowControl.Undo();
-                Assert.Throws<InvalidOperationException>(() => asyncFlowControl.Undo());
-                Assert.Throws<InvalidOperationException>(() => asyncFlowControl.Dispose());
+
+                asyncFlowControl.Undo();
+                asyncFlowControl.Dispose();
 
                 // Changing an async local value does not prevent undoing a flow-suppressed execution context. In .NET Core, the
                 // execution context is immutable, so changing an async local value changes the execution context instance,