Improve behavior of TcpClient Dispose concurrent with ConnectAsync
authorStephen Toub <stoub@microsoft.com>
Mon, 11 Jul 2016 01:52:53 +0000 (21:52 -0400)
committerStephen Toub <stoub@microsoft.com>
Mon, 11 Jul 2016 11:23:45 +0000 (07:23 -0400)
TcpClient.Dispose is not meant to be used concurrently with other operations on the instance, but some code does do so as a way to provide a cancellation mechanism.  There are two easily hit issues with this:
1. On Unix, the ConnectAsync operation doesn't publish the actual Socket on which a connection was made until after the connection is established, as it needs to use temporary sockets to try each potential target address, and publishing it before connecting could end up publishing a Socket that won't end up being the actual one used.  As such, if a Dispose occurs during the ConnectAsync operation, it won't end up disposing the socket being used to make the connection, such that the connection won't be canceled.
2. On all platforms, Dispose nulls out the client socket field.  When the connection then subsequently completes, it hits a NullReferenceException while trying to dereference that field.

This commit addresses both issues:
a. When the client is disposed, on Unix we cancel a CancellationTokenSource, and each Socket we create is registered with that source to dispose the socket.  That way, we dispose of each socket even if it hasn't been published onto the instance yet.
b. We grab the Socket from the field and check for null prior to dereferencing it.

Commit migrated from https://github.com/dotnet/corefx/commit/e3e7ca7630c486f35bf7eac330dedcf729c74e71

src/libraries/System.Net.Sockets/src/System/Net/Sockets/TCPClient.Unix.cs
src/libraries/System.Net.Sockets/src/System/Net/Sockets/TCPClient.Windows.cs
src/libraries/System.Net.Sockets/src/System/Net/Sockets/TCPClient.cs
src/libraries/System.Net.Sockets/tests/FunctionalTests/TcpClientTest.cs

index fc57485..838f4de 100644 (file)
@@ -43,6 +43,9 @@ namespace System.Net.Sockets
         // We use a separate bool field to store whether the value has been set.
         // We don't use nullables, due to one of the properties being a reference type.
 
+        
+        private static readonly CancellationTokenSource s_canceledSource = CreateCanceledSource();
+        private CancellationTokenSource _disposing;
         private ShadowOptions _shadowOptions; // shadow state used in public properties before the socket is created
         private int _connectRunning; // tracks whether a connect operation that could set _clientSocket is currently running
 
@@ -51,6 +54,23 @@ namespace System.Net.Sockets
             // Nop.  We want to lazily-allocate the socket.
         }
 
+        private void DisposeCore()
+        {
+            // In case there's a concurrent ConnectAsync operation, we need to signal to that
+            // operation that we're being disposed of, so that it can dispose of the current
+            // temporary socket that hasn't yet been published as the official one.  If there's
+            // already a cancellation source, just cancel it.  If there isn't, try to swap in
+            // an already-canceled source so that we don't have to artificially create a new one
+            // (since not all async connect operations require temporary sockets), but we may
+            // lose that race condition, in which case we still need to dispose of whatever is
+            // published.  It's fine to Cancel an already canceled cancellation source.
+            if (Volatile.Read(ref _disposing) == null)
+            {
+                Interlocked.CompareExchange(ref _disposing, s_canceledSource, null);
+            }
+            _disposing.Cancel();
+        }
+
         private Socket ClientCore
         {
             get
@@ -59,8 +79,9 @@ namespace System.Net.Sockets
                 try
                 {
                     // The Client socket is being explicitly accessed, so we're forced
-                    // to create it if it doesn't exist.
-                    if (_clientSocket == null)
+                    // to create it if it doesn't exist.  Only do so if we haven't been disposed of,
+                    // which nulls out the field.
+                    if (_clientSocket == null && (_disposing == null || !_disposing.IsCancellationRequested))
                     {
                         // Create the socket, and transfer to it any of our shadow properties.
                         _clientSocket = CreateSocket();
@@ -241,6 +262,14 @@ namespace System.Net.Sockets
         {
             try
             {
+                // Make sure we've created a disposing cancellation source so that we get alerted
+                // to a potentially concurrent disposal happening.
+                if (Volatile.Read(ref _disposing) != null && _disposing.IsCancellationRequested)
+                {
+                    throw new ObjectDisposedException(GetType().Name);
+                }
+                Interlocked.CompareExchange(ref _disposing, new CancellationTokenSource(), null);
+
                 // For each address, create a new socket (configured appropriately) and try to connect
                 // to the endpoint.  If we're successful, set the newly connected socket as the client
                 // socket, and we're done.  If we're unsuccessful, try the next address.
@@ -250,15 +279,30 @@ namespace System.Net.Sockets
                     Socket s = CreateSocket();
                     try
                     {
+                        // Configure the socket
                         ApplyInitializedOptionsToSocket(s);
-                        await s.ConnectAsync(address, port).ConfigureAwait(false);
 
+                        // Register to dispose of the socket when the TcpClient is Dispose'd of.
+                        // Some consumers use Dispose as a way to cancel a connect operation, as
+                        // TcpClient.Dispose calls Socket.Dispose on the stored socket... but we've
+                        // not stored the socket into the field yet, as doing so will publish it
+                        // to be seen via the Client property.  Instead, we register to be notified
+                        // when Dispose is called or has happened, and Dispose of the socket
+                        using (_disposing.Token.Register(o => ((Socket)o).Dispose(), s))
+                        {
+                            await s.ConnectAsync(address, port).ConfigureAwait(false);
+                        }
                         _clientSocket = s;
                         _active = true;
 
+                        if (_disposing.IsCancellationRequested)
+                        {
+                            s.Dispose();
+                            _clientSocket = null;
+                        }
                         return;
                     }
-                    catch (Exception exc)
+                    catch (Exception exc) when (!(exc is ObjectDisposedException))
                     {
                         s.Dispose();
                         lastException = ExceptionDispatchInfo.Capture(exc);
@@ -473,6 +517,13 @@ namespace System.Net.Sockets
             Volatile.Write(ref _connectRunning, 0);
         }
 
+        private static CancellationTokenSource CreateCanceledSource()
+        {
+            var cts = new CancellationTokenSource();
+            cts.Cancel();
+            return cts;
+        }
+
         private sealed class ShadowOptions
         {
             internal int _exclusiveAddressUse;
index 0f24c35..aa0b670 100644 (file)
@@ -13,6 +13,11 @@ namespace System.Net.Sockets
             Client = CreateSocket();
         }
 
+        private void DisposeCore()
+        {
+            // Nop.  No additional state that needs to be disposed of.
+        }
+
         // Used by the class to provide the underlying network socket.
         private Socket ClientCore
         {
index 7847fdd..e56a75b 100644 (file)
@@ -76,16 +76,8 @@ namespace System.Net.Sockets
         [DebuggerBrowsable(DebuggerBrowsableState.Never)] // TODO: Remove once https://github.com/dotnet/corefx/issues/5868 is addressed.
         public Socket Client
         {
-            get
-            {
-                Socket s = ClientCore;
-                Debug.Assert(s != null);
-                return s;
-            }
-            set
-            {
-                ClientCore = value;
-            }
+            get { return ClientCore; }
+            set { ClientCore = value; }
         }
 
         public bool Connected { get { return ConnectedCore; } }
@@ -139,7 +131,14 @@ namespace System.Net.Sockets
                 NetEventSource.Enter(NetEventSource.ComponentType.Socket, this, "EndConnect", asyncResult);
             }
 
-            Client.EndConnect(asyncResult);
+            Socket s = Client;
+            if (s == null)
+            {
+                // Dispose nulls out the client socket field.
+                throw new ObjectDisposedException(GetType().Name);
+            }
+            s.EndConnect(asyncResult);
+
             _active = true;
             if (NetEventSource.Log.IsEnabled())
             {
@@ -223,6 +222,8 @@ namespace System.Net.Sockets
                     }
                 }
 
+                DisposeCore(); // platform-specific disposal work
+
                 GC.SuppressFinalize(this);
             }
 
index dc49f66..8dd7394 100644 (file)
@@ -8,6 +8,8 @@ using Xunit.Abstractions;
 using System.Threading.Tasks;
 using System.Net.Test.Common;
 using System.Text;
+using System.Collections.Generic;
+using System.Diagnostics;
 
 namespace System.Net.Sockets.Tests
 {
@@ -217,5 +219,41 @@ namespace System.Net.Sockets.Tests
                 // minimums and maximums, silently capping to those amounts.
             }
         }
+
+        [Theory]
+        [InlineData(false)]
+        [InlineData(true)]
+        public async Task Dispose_CancelsConnectAsync(bool connectByName)
+        {
+            using (var server = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp))
+            {
+                // Set up a server socket to which to connect
+                server.Bind(new IPEndPoint(IPAddress.Loopback, 0));
+                server.Listen(1);
+                var endpoint = (IPEndPoint)server.LocalEndPoint;
+
+                // Connect asynchronously...
+                var client = new TcpClient();
+                Task connectTask = connectByName ?
+                    client.ConnectAsync("localhost", endpoint.Port) :
+                    client.ConnectAsync(endpoint.Address, endpoint.Port);
+
+                // ...and hopefully before it's completed connecting, dispose.
+                var sw = Stopwatch.StartNew();
+                client.Dispose();
+
+                // There is a race condition here.  If the connection succeeds before the
+                // disposal, then the task will complete successfully.  Otherwise, it should
+                // fail with an ObjectDisposedException.
+                try
+                {
+                    await connectTask;
+                }
+                catch (ObjectDisposedException) { }
+                sw.Stop();
+
+                Assert.Null(client.Client); // should be nulled out after Dispose
+            }
+        }
     }
 }