fix async logic in MultipleConnectAsync to avoid lock reentrancy issues (#42919)
authorGeoff Kizer <geoffrek@microsoft.com>
Tue, 6 Oct 2020 04:51:10 +0000 (21:51 -0700)
committerGitHub <noreply@github.com>
Tue, 6 Oct 2020 04:51:10 +0000 (21:51 -0700)
* fix async logic in MultipleConnectAsync to avoid lock reentrancy issues

* Update src/libraries/System.Net.Sockets/src/System/Net/Sockets/MultipleConnectAsync.cs

Co-authored-by: Stephen Toub <stoub@microsoft.com>
Co-authored-by: Geoffrey Kizer <geoffrek@windows.microsoft.com>
Co-authored-by: Stephen Toub <stoub@microsoft.com>
src/libraries/System.Net.Sockets/src/System/Net/Sockets/MultipleConnectAsync.cs

index b961d3e..33a54eb 100644 (file)
@@ -32,10 +32,13 @@ namespace System.Net.Sockets
 
         private readonly object _lockObject = new object();
 
-        // Called by Socket to kick off the ConnectAsync process.  We'll complete the user's SAEA
-        // when it's done.  Returns true if the operation will be asynchronous, false if it has failed synchronously
+        // Called by Socket to kick off the ConnectAsync process.  We'll complete the user's SAEA when it's done.
+        // Returns true if the operation is pending, false if it completed synchronously.
         public bool StartConnectAsync(SocketAsyncEventArgs args, DnsEndPoint endPoint)
         {
+            IAsyncResult result;
+
+            Debug.Assert(!Monitor.IsEntered(_lockObject));
             lock (_lockObject)
             {
                 if (endPoint.AddressFamily != AddressFamily.Unspecified &&
@@ -63,15 +66,16 @@ namespace System.Net.Sockets
 
                 _state = State.DnsQuery;
 
-                IAsyncResult result = Dns.BeginGetHostAddresses(endPoint.Host, new AsyncCallback(DnsCallback), null);
-                if (result.CompletedSynchronously)
-                {
-                    return DoDnsCallback(result, true);
-                }
-                else
-                {
-                    return true;
-                }
+                result = Dns.BeginGetHostAddresses(endPoint.Host, new AsyncCallback(DnsCallback), null);
+            }
+
+            if (result.CompletedSynchronously)
+            {
+                return DoDnsCallback(result, true);
+            }
+            else
+            {
+                return true;
             }
         }
 
@@ -85,12 +89,14 @@ namespace System.Net.Sockets
         }
 
         // Called when the DNS query completes (either synchronously or asynchronously).  Checks for failure and
-        // starts the first connection attempt if it succeeded.  Returns true if the operation will be asynchronous,
-        // false if it has failed synchronously.
+        // starts the first connection attempt if it succeeded.
+        // Returns true if the operation is pending, false if it completed synchronously.
         private bool DoDnsCallback(IAsyncResult result, bool sync)
         {
             Exception? exception = null;
+            bool pending = false;
 
+            Debug.Assert(!Monitor.IsEntered(_lockObject));
             lock (_lockObject)
             {
                 // If the connection attempt was canceled during the dns query, the user's callback has already been
@@ -128,7 +134,7 @@ namespace System.Net.Sockets
                     _internalArgs.Completed += InternalConnectCallback;
                     _internalArgs.CopyBufferFrom(_userArgs!);
 
-                    exception = AttemptConnection();
+                    (exception, pending) = AttemptConnection();
 
                     if (exception != null)
                     {
@@ -143,18 +149,29 @@ namespace System.Net.Sockets
             {
                 return Fail(sync, exception);
             }
+            else if (!pending)
+            {
+                return DoConnectCallback(_internalArgs!);
+            }
             else
             {
                 return true;
             }
         }
 
+        private void InternalConnectCallback(object? sender, SocketAsyncEventArgs args)
+        {
+            DoConnectCallback(args);
+        }
+
         // Callback which fires when an internal connection attempt completes.
         // If it failed and there are more addresses to try, do it.
-        private void InternalConnectCallback(object? sender, SocketAsyncEventArgs args)
+        // Returns true if the operation is pending, false if it completed synchronously.
+        private bool DoConnectCallback(SocketAsyncEventArgs args)
         {
             Exception? exception = null;
 
+            Debug.Assert(!Monitor.IsEntered(_lockObject));
             lock (_lockObject)
             {
                 if (_state == State.Canceled)
@@ -166,48 +183,61 @@ namespace System.Net.Sockets
                 }
                 else
                 {
-                    Debug.Assert(_state == State.ConnectAttempt);
-
-                    if (args.SocketError == SocketError.Success)
-                    {
-                        // The connection attempt succeeded; go to the completed state.
-                        // The callback will be called outside the lock.
-                        _state = State.Completed;
-                    }
-                    else if (args.SocketError == SocketError.OperationAborted)
+                    while (true)
                     {
-                        // The socket was closed while the connect was in progress.  This can happen if the user
-                        // closes the socket, and is equivalent to a call to CancelConnectAsync
-                        exception = new SocketException((int)SocketError.OperationAborted);
-                        _state = State.Canceled;
-                    }
-                    else
-                    {
-
-                        // Keep track of this because it will be overwritten by AttemptConnection
-                        SocketError currentFailure = args.SocketError;
-                        Exception? connectException = AttemptConnection();
+                        Debug.Assert(_state == State.ConnectAttempt);
 
-                        if (connectException == null)
+                        if (args.SocketError == SocketError.Success)
+                        {
+                            // The connection attempt succeeded; go to the completed state.
+                            // The callback will be called outside the lock.
+                            _state = State.Completed;
+                            break;
+                        }
+                        else if (args.SocketError == SocketError.OperationAborted)
                         {
-                            // don't call the callback, another connection attempt is successfully started
-                            return;
+                            // The socket was closed while the connect was in progress.  This can happen if the user
+                            // closes the socket, and is equivalent to a call to CancelConnectAsync
+                            exception = new SocketException((int)SocketError.OperationAborted);
+                            _state = State.Canceled;
+                            break;
                         }
                         else
                         {
-                            SocketException? socketException = connectException as SocketException;
-                            if (socketException != null && socketException.SocketErrorCode == SocketError.NoData)
+
+                            // Keep track of this because it will be overwritten by AttemptConnection
+                            SocketError currentFailure = args.SocketError;
+
+                            (Exception? connectException, bool pending) = AttemptConnection();
+
+                            if (connectException == null)
                             {
-                                // If the error is NoData, that means there are no more IPAddresses to attempt
-                                // a connection to.  Return the last error from an actual connection instead.
-                                exception = new SocketException((int)currentFailure);
+                                if (pending)
+                                {
+                                    // don't call the callback, another connection attempt is successfully started
+                                    return true;
+                                }
+
+                                // We have a sync completion from AttemptConnection.
+                                // Loop around and process its results.
                             }
                             else
                             {
-                                exception = connectException;
+                                SocketException? socketException = connectException as SocketException;
+                                if (socketException != null && socketException.SocketErrorCode == SocketError.NoData)
+                                {
+                                    // If the error is NoData, that means there are no more IPAddresses to attempt
+                                    // a connection to.  Return the last error from an actual connection instead.
+                                    exception = new SocketException((int)currentFailure);
+                                }
+                                else
+                                {
+                                    exception = connectException;
+                                }
+
+                                _state = State.Completed;
+                                break;
                             }
-
-                            _state = State.Completed;
                         }
                     }
                 }
@@ -221,38 +251,37 @@ namespace System.Net.Sockets
             {
                 AsyncFail(exception);
             }
+
+            return false;
         }
 
-        // Called to initiate a connection attempt to the next address in the list.  Returns an exception
-        // if the attempt failed synchronously, or null if it was successfully initiated.
-        private Exception? AttemptConnection()
+        // Called to initiate a connection attempt to the next address in the list.
+        // Returns (exception, false) if the attempt failed synchronously.
+        // Returns (null, true) if pending, or (null, false) if completed synchronously.
+        private (Exception? exception, bool pending) AttemptConnection()
         {
             try
             {
                 IPAddress? attemptAddress = GetNextAddress(out Socket? attemptSocket);
                 if (attemptAddress == null)
                 {
-                    return new SocketException((int)SocketError.NoData);
+                    return (new SocketException((int)SocketError.NoData), false);
                 }
                 Debug.Assert(attemptSocket != null);
 
                 SocketAsyncEventArgs args = _internalArgs!;
                 args.RemoteEndPoint = new IPEndPoint(attemptAddress, _endPoint!.Port);
-                if (!attemptSocket.ConnectAsync(args))
-                {
-                    InternalConnectCallback(null, args);
-                }
-
-                return null;
+                bool pending = attemptSocket.ConnectAsync(args);
+                return (null, pending);
             }
             catch (ObjectDisposedException)
             {
                 // This can happen if the user closes the socket and is equivalent to a call to CancelConnectAsync.
-                return new SocketException((int)SocketError.OperationAborted);
+                return (new SocketException((int)SocketError.OperationAborted), false);
             }
             catch (Exception e)
             {
-                return e;
+                return (e, false);
             }
         }
 
@@ -317,6 +346,7 @@ namespace System.Net.Sockets
         {
             bool callOnFail = false;
 
+            Debug.Assert(!Monitor.IsEntered(_lockObject));
             lock (_lockObject)
             {
                 switch (_state)