Fix SocketAsyncEventArgs' handling of ExecutionContext (dotnet/corefx#30712)
authorStephen Toub <stoub@microsoft.com>
Tue, 3 Jul 2018 17:21:18 +0000 (13:21 -0400)
committerGitHub <noreply@github.com>
Tue, 3 Jul 2018 17:21:18 +0000 (13:21 -0400)
SocketAsyncEventArgs has a few issues with ExecutionContext, presumably stemming from the fact that capturing ExecutionContext in .NET Framework is not a cheap operation.  As a result, when this code was written, it was optimized for avoiding calls to ExecutionContext.Capture.  The SAEA tries to hold onto a captured ExecutionContext for as long as possible, only re-capturing when either the SAEA is used with a different socket instance or when an event handler is changed.  That has several problems, though.  First, it largely violates the purpose of ExecutionContext, which is to flow information from the point where the async operation begins to the continuation/callback, but if the context is only being captured when the Socket or handler is changed, then the context isn't actually tied to the location where the async operation begins, and that means that data like that in an AsyncLocal doesn't properly flow across the async point.  Second, it means that the SocketAsyncEventArgs (the whole purpose of which is to cache it) can end up keeping state in an ExecutionContext alive well beyond when it should be kept alive, because the SocketAsyncEventArgs is holding onto the ExecutionContext instance until either the Socket or handler is changed.

This commit fixes this behavior.  Since ExecutionContext.Capture in .NET Core is relatively cheap (no allocation, primarily just a ThreadStatic access), we now just always capture the context when starting an operation, and then clear it out when completing the operation.

Commit migrated from https://github.com/dotnet/corefx/commit/851a53bee21639cd65e109b72fd5f9ef90cc0f39

src/libraries/Common/tests/System/Threading/Tasks/TaskTimeoutExtensions.cs
src/libraries/System.Net.Http/tests/FunctionalTests/SocketsHttpHandlerTest.cs
src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncEventArgs.cs
src/libraries/System.Net.Sockets/tests/FunctionalTests/ExecutionContextFlowTest.netcoreapp.cs

index de3463e..cc8c1d6 100644 (file)
@@ -23,7 +23,7 @@ namespace System.Threading.Tasks
             }
             else
             {
-                throw new TimeoutException($"Task timed out after {millisecondsTimeout}");
+                throw new TimeoutException($"Task timed out after {millisecondsTimeout}ms");
             }
         }
 
@@ -38,7 +38,7 @@ namespace System.Threading.Tasks
             }
             else
             {
-                throw new TimeoutException($"Task timed out after {millisecondsTimeout}");
+                throw new TimeoutException($"Task timed out after {millisecondsTimeout}ms");
             }
         }
 
@@ -53,7 +53,7 @@ namespace System.Threading.Tasks
             }
             else
             {
-                throw new TimeoutException($"{nameof(WhenAllOrAnyFailed)} timed out after {millisecondsTimeout}");
+                throw new TimeoutException($"{nameof(WhenAllOrAnyFailed)} timed out after {millisecondsTimeout}ms");
             }
         }
 
index 2651237..03f45dd 100644 (file)
@@ -9,6 +9,7 @@ using System.Linq;
 using System.Net.Security;
 using System.Net.Sockets;
 using System.Net.Test.Common;
+using System.Runtime.CompilerServices;
 using System.Security.Authentication;
 using System.Security.Cryptography.X509Certificates;
 using System.Text;
@@ -22,6 +23,67 @@ namespace System.Net.Http.Functional.Tests
     public sealed class SocketsHttpHandler_HttpClientHandler_Asynchrony_Test : HttpClientHandler_Asynchrony_Test
     {
         protected override bool UseSocketsHttpHandler => true;
+
+        [OuterLoop("Relies on finalization")]
+        [Fact]
+        public async Task ExecutionContext_HttpConnectionLifetimeDoesntKeepContextAlive()
+        {
+            var clientCompleted = new TaskCompletionSource<bool>(TaskCreationOptions.RunContinuationsAsynchronously);
+            await LoopbackServer.CreateClientAndServerAsync(async uri =>
+            {
+                try
+                {
+                    using (HttpClient client = CreateHttpClient())
+                    {
+                        (Task completedWhenFinalized, Task getRequest) = MakeHttpRequestWithTcsSetOnFinalizationInAsyncLocal(client, uri);
+                        await getRequest;
+
+                        for (int i = 0; i < 3; i++)
+                        {
+                            GC.Collect();
+                            GC.WaitForPendingFinalizers();
+                        }
+
+                        await completedWhenFinalized.TimeoutAfter(TestHelper.PassingTestTimeoutMilliseconds);
+                    }
+                }
+                finally
+                {
+                    clientCompleted.SetResult(true);
+                }
+            }, async server =>
+            {
+                await server.AcceptConnectionAsync(async connection =>
+                {
+                    await connection.ReadRequestHeaderAndSendResponseAsync();
+                    await clientCompleted.Task;
+                });
+            });
+        }
+
+        [MethodImpl(MethodImplOptions.NoInlining)] // avoid JIT extending lifetime of the finalizable object
+        private static (Task completedOnFinalized, Task getRequest) MakeHttpRequestWithTcsSetOnFinalizationInAsyncLocal(HttpClient client, Uri uri)
+        {
+            var tcs = new TaskCompletionSource<bool>(TaskCreationOptions.RunContinuationsAsynchronously);
+
+            // Put something in ExecutionContext, start the HTTP request, then undo the EC change.
+            var al = new AsyncLocal<object>() { Value = new SetOnFinalized() { _completedWhenFinalized = tcs } };
+            Task t = client.GetStringAsync(uri);
+            al.Value = null;
+
+            // Return a task that will complete when the SetOnFinalized is finalized,
+            // as well as a task to wait on for the get request; for the get request,
+            // we return a continuation to avoid any test-altering issues related to
+            // the state machine holding onto stuff.
+            t = t.ContinueWith(p => p.GetAwaiter().GetResult());
+            return (tcs.Task, t);
+        }
+
+        private sealed class SetOnFinalized
+        {
+            internal TaskCompletionSource<bool> _completedWhenFinalized;
+            ~SetOnFinalized() => _completedWhenFinalized.SetResult(true);
+        }
     }
 
     public sealed class SocketsHttpHandler_HttpProtocolTests : HttpProtocolTests
index 6e9016f..84ec3aa 100644 (file)
@@ -28,10 +28,6 @@ namespace System.Net.Sockets
         // BytesTransferred property variables.
         private int _bytesTransferred;
 
-        // Completed event property variables.
-        private event EventHandler<SocketAsyncEventArgs> _completed;
-        private bool _completedChanged;
-
         // DisconnectReuseSocket propery variables.
         private bool _disconnectReuseSocket;
 
@@ -200,23 +196,11 @@ namespace System.Net.Sockets
             get { return _bytesTransferred; }
         }
 
-        public event EventHandler<SocketAsyncEventArgs> Completed
-        {
-            add
-            {
-                _completed += value;
-                _completedChanged = true;
-            }
-            remove
-            {
-                _completed -= value;
-                _completedChanged = true;
-            }
-        }
+        public event EventHandler<SocketAsyncEventArgs> Completed;
 
         protected virtual void OnCompleted(SocketAsyncEventArgs e)
         {
-            _completed?.Invoke(e._currentSocket, e);
+            Completed?.Invoke(e._currentSocket, e);
         }
 
         // DisconnectResuseSocket property.
@@ -445,6 +429,9 @@ namespace System.Net.Sockets
         {
             CompleteCore();
 
+            // Clear any ExecutionContext that may have been captured.
+            _context = null;
+
             // Mark as not in-use.
             _operating = Free;
 
@@ -519,21 +506,12 @@ namespace System.Net.Sockets
                 ThrowForNonFreeStatus(status);
             }
 
-            // Set the operation type.
+            // Set the operation type and store the socket as current.
             _completedOperation = operation;
+            _currentSocket = socket;
 
-            // Prepare execution context for callback.
-            // If event delegates have changed or socket has changed
-            // then discard any existing context.
-            if (_completedChanged || socket != _currentSocket)
-            {
-                _completedChanged = false;
-                _currentSocket = socket;
-                _context = null;
-            }
-
-            // Capture execution context if necessary.
-            if (_flowExecutionContext && _context == null)
+            // Capture execution context if needed (it is unless explicitly disabled).
+            if (_flowExecutionContext)
             {
                 _context = ExecutionContext.Capture();
             }
@@ -635,29 +613,33 @@ namespace System.Net.Sockets
 
         internal void FinishOperationAsyncFailure(SocketError socketError, int bytesTransferred, SocketFlags flags)
         {
+            ExecutionContext context = _context; // store context before it's cleared as part of finishing the operation
+
             FinishOperationSyncFailure(socketError, bytesTransferred, flags);
 
-            if (_context == null)
+            if (context == null)
             {
                 OnCompleted(this);
             }
             else
             {
-                ExecutionContext.Run(_context, s_executionCallback, this);
+                ExecutionContext.Run(context, s_executionCallback, this);
             }
         }
 
         internal void FinishConnectByNameAsyncFailure(Exception exception, int bytesTransferred, SocketFlags flags)
         {
+            ExecutionContext context = _context; // store context before it's cleared as part of finishing the operation
+
             FinishConnectByNameSyncFailure(exception, bytesTransferred, flags);
 
-            if (_context == null)
+            if (context == null)
             {
                 OnCompleted(this);
             }
             else
             {
-                ExecutionContext.Run(_context, s_executionCallback, this);
+                ExecutionContext.Run(context, s_executionCallback, this);
             }
         }
 
@@ -668,14 +650,15 @@ namespace System.Net.Sockets
             _connectSocket = connectSocket;
 
             // Complete the operation and raise the event.
+            ExecutionContext context = _context; // store context before it's cleared as part of completing the operation
             Complete();
-            if (_context == null)
+            if (context == null)
             {
                 OnCompleted(this);
             }
             else
             {
-                ExecutionContext.Run(_context, s_executionCallback, this);
+                ExecutionContext.Run(context, s_executionCallback, this);
             }
         }
 
@@ -777,16 +760,18 @@ namespace System.Net.Sockets
 
         internal void FinishOperationAsyncSuccess(int bytesTransferred, SocketFlags flags)
         {
+            ExecutionContext context = _context; // store context before it's cleared as part of finishing the operation
+
             FinishOperationSyncSuccess(bytesTransferred, flags);
 
             // Raise completion event.
-            if (_context == null)
+            if (context == null)
             {
                 OnCompleted(this);
             }
             else
             {
-                ExecutionContext.Run(_context, s_executionCallback, this);
+                ExecutionContext.Run(context, s_executionCallback, this);
             }
         }
     }
index d978eca..67743d1 100644 (file)
@@ -13,6 +13,53 @@ namespace System.Net.Sockets.Tests
 {
     public partial class ExecutionContextFlowTest : FileCleanupTestBase
     {
+        [OuterLoop("Relies on finalization")]
+        [Fact]
+        public void ExecutionContext_NotCachedInSocketAsyncEventArgs()
+        {
+            using (var listener = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp))
+            using (var client = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp))
+            {
+                listener.Bind(new IPEndPoint(IPAddress.Loopback, 0));
+                listener.Listen(1);
+
+                client.Connect(listener.LocalEndPoint);
+                using (Socket server = listener.Accept())
+                using (var saea = new SocketAsyncEventArgs())
+                {
+                    var receiveCompleted = new ManualResetEventSlim();
+                    saea.Completed += (_, __) => receiveCompleted.Set();
+                    saea.SetBuffer(new byte[1]);
+
+                    var ecDropped = new ManualResetEventSlim();
+                    var al = CreateAsyncLocalWithSetWhenFinalized(ecDropped);
+                    Assert.True(client.ReceiveAsync(saea));
+                    al.Value = null;
+
+                    server.Send(new byte[1]);
+                    Assert.True(receiveCompleted.Wait(TestSettings.PassingTestTimeout));
+
+                    for (int i = 0; i < 3; i++)
+                    {
+                        GC.Collect();
+                        GC.WaitForPendingFinalizers();
+                    }
+
+                    Assert.True(ecDropped.Wait(TestSettings.PassingTestTimeout));
+                }
+            }
+        }
+
+        [MethodImpl(MethodImplOptions.NoInlining)]
+        private static AsyncLocal<object> CreateAsyncLocalWithSetWhenFinalized(ManualResetEventSlim ecDropped) =>
+            new AsyncLocal<object>() { Value = new SetOnFinalized { _setWhenFinalized = ecDropped } };
+
+        private sealed class SetOnFinalized
+        {
+            internal ManualResetEventSlim _setWhenFinalized;
+            ~SetOnFinalized() => _setWhenFinalized.Set();
+        }
+
         [Fact]
         public Task ExecutionContext_FlowsOnlyOnceAcrossAsyncOperations()
         {