Socket: don't assign right endpoint until the connect is successful. (#53581)
authorTom Deseyn <tom.deseyn@gmail.com>
Sat, 5 Jun 2021 02:09:30 +0000 (04:09 +0200)
committerGitHub <noreply@github.com>
Sat, 5 Jun 2021 02:09:30 +0000 (19:09 -0700)
* Socket: don't assign right endpoint until the connect is successful.

'Right endpoint' must match the address family of the Socket or
we can't serialize the LocalEndPoint and RemoteEndPoint.

When multiple connect attempts are made against a DualMode Socket with
both IPv4 and IPv6 addresses, a failed attempt must not set 'right
endpoint'.

* SocketTaskExtensionsTest.EnsureMethodsAreCallable: update expected exceptions

* PR feedback

* EnsureMethodsAreCallable: move ReceiveFromAsync before ConnectAsync to avoid wildcard bind on Windows that leads to a different exception

src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs
src/libraries/System.Net.Sockets/tests/FunctionalTests/Connect.cs
src/libraries/System.Net.Sockets/tests/FunctionalTests/SocketTaskExtensionsTest.cs

index 30b5c05..247b02a 100644 (file)
@@ -52,9 +52,9 @@ namespace System.Net.Sockets
         // to poll for the real state until we're done connecting.
         private bool _nonBlockingConnectInProgress;
 
-        // Keep track of the kind of endpoint used to do a non-blocking connect, so we can set
-        // it to _rightEndPoint when we discover we're connected.
-        private EndPoint? _nonBlockingConnectRightEndPoint;
+        // Keep track of the kind of endpoint used to do a connect, so we can set
+        // it to _rightEndPoint when we're connected.
+        private EndPoint? _pendingConnectRightEndPoint;
 
         // These are constants initialized by constructor.
         private AddressFamily _addressFamily;
@@ -285,11 +285,8 @@ namespace System.Net.Sockets
 
                 if (_nonBlockingConnectInProgress && Poll(0, SelectMode.SelectWrite))
                 {
-                    // Update the state if we've become connected after a non-blocking connect.
-                    _isConnected = true;
-                    _rightEndPoint ??= _nonBlockingConnectRightEndPoint;
-                    UpdateLocalEndPointOnConnect();
                     _nonBlockingConnectInProgress = false;
+                    SetToConnected();
                 }
 
                 if (_rightEndPoint == null)
@@ -332,11 +329,9 @@ namespace System.Net.Sockets
                 {
                     if (_nonBlockingConnectInProgress && Poll(0, SelectMode.SelectWrite))
                     {
-                        // Update the state if we've become connected after a non-blocking connect.
-                        _isConnected = true;
-                        _rightEndPoint ??= _nonBlockingConnectRightEndPoint;
-                        UpdateLocalEndPointOnConnect();
                         _nonBlockingConnectInProgress = false;
+                        // Update the state if we've become connected after a non-blocking connect.
+                        SetToConnected();
                     }
 
                     if (_rightEndPoint == null || !_isConnected)
@@ -439,11 +434,9 @@ namespace System.Net.Sockets
 
                 if (_nonBlockingConnectInProgress && Poll(0, SelectMode.SelectWrite))
                 {
-                    // Update the state if we've become connected after a non-blocking connect.
-                    _isConnected = true;
-                    _rightEndPoint ??= _nonBlockingConnectRightEndPoint;
-                    UpdateLocalEndPointOnConnect();
                     _nonBlockingConnectInProgress = false;
+                    // Update the state if we've become connected after a non-blocking connect.
+                    SetToConnected();
                 }
 
                 return _isConnected;
@@ -856,12 +849,8 @@ namespace System.Net.Sockets
             ValidateForMultiConnect(isMultiEndpoint: false);
 
             Internals.SocketAddress socketAddress = Serialize(ref remoteEP);
-
-            if (!Blocking)
-            {
-                _nonBlockingConnectRightEndPoint = remoteEP;
-                _nonBlockingConnectInProgress = true;
-            }
+            _pendingConnectRightEndPoint = remoteEP;
+            _nonBlockingConnectInProgress = !Blocking;
 
             DoConnect(remoteEP, socketAddress);
         }
@@ -2768,13 +2757,11 @@ namespace System.Net.Sockets
                 }
 
                 e._socketAddress = Serialize(ref endPointSnapshot);
+                _pendingConnectRightEndPoint = endPointSnapshot;
+                _nonBlockingConnectInProgress = false;
 
                 WildcardBindForConnectIfNecessary(endPointSnapshot.AddressFamily);
 
-                // Save the old RightEndPoint and prep new RightEndPoint.
-                EndPoint? oldEndPoint = _rightEndPoint;
-                _rightEndPoint ??= endPointSnapshot;
-
                 if (SocketsTelemetry.Log.IsEnabled())
                 {
                     SocketsTelemetry.Log.ConnectStart(e._socketAddress!);
@@ -2801,7 +2788,6 @@ namespace System.Net.Sockets
                         SocketsTelemetry.Log.AfterConnect(SocketError.NotSocket, ex.Message);
                     }
 
-                    _rightEndPoint = oldEndPoint;
                     _localEndPoint = null;
 
                     // Clear in-use flag on event args object.
@@ -3217,12 +3203,11 @@ namespace System.Net.Sockets
 
             if (SocketsTelemetry.Log.IsEnabled()) SocketsTelemetry.Log.AfterConnect(SocketError.Success);
 
-            // Save a copy of the EndPoint so we can use it for Create().
-            _rightEndPoint ??= endPointSnapshot;
-
             if (NetEventSource.Log.IsEnabled()) NetEventSource.Info(this, $"connection to:{endPointSnapshot}");
 
             // Update state and performance counters.
+            _pendingConnectRightEndPoint = endPointSnapshot;
+            _nonBlockingConnectInProgress = false;
             SetToConnected();
             if (NetEventSource.Log.IsEnabled()) NetEventSource.Connected(this, LocalEndPoint, RemoteEndPoint);
         }
@@ -3659,10 +3644,14 @@ namespace System.Net.Sockets
                 return;
             }
 
+            Debug.Assert(_nonBlockingConnectInProgress == false);
+
             // Update the status: this socket was indeed connected at
             // some point in time update the perf counter as well.
             _isConnected = true;
             _isDisconnected = false;
+            _rightEndPoint ??= _pendingConnectRightEndPoint;
+            _pendingConnectRightEndPoint = null;
             UpdateLocalEndPointOnConnect();
             if (NetEventSource.Log.IsEnabled()) NetEventSource.Info(this, "now connected");
         }
index d31cf2d..5f9730e 100644 (file)
@@ -84,6 +84,60 @@ namespace System.Net.Sockets.Tests
         }
 
         [Fact]
+        public async Task Connect_DualMode_MultiAddressFamilyConnect_RetrievedEndPoints_Success()
+        {
+            if (!SupportsMultiConnect)
+                return;
+
+            int port;
+            using (SocketTestServer.SocketTestServerFactory(SocketImplementationType.Async, IPAddress.Loopback, out port))
+            using (Socket client = new Socket(SocketType.Stream, ProtocolType.Tcp))
+            {
+                Assert.True(client.DualMode);
+
+                Task connectTask = MultiConnectAsync(client, new IPAddress[] { IPAddress.IPv6Loopback, IPAddress.Loopback }, port);
+                await connectTask;
+
+                var localEndPoint = client.LocalEndPoint as IPEndPoint;
+                Assert.NotNull(localEndPoint);
+                Assert.Equal(IPAddress.Loopback.MapToIPv6(), localEndPoint.Address);
+
+                var remoteEndPoint = client.RemoteEndPoint as IPEndPoint;
+                Assert.NotNull(remoteEndPoint);
+                Assert.Equal(IPAddress.Loopback.MapToIPv6(), remoteEndPoint.Address);
+            }
+        }
+
+        [Fact]
+        public async Task Connect_DualMode_DnsConnect_RetrievedEndPoints_Success()
+        {
+            var localhostAddresses = Dns.GetHostAddresses("localhost");
+            if (Array.IndexOf(localhostAddresses, IPAddress.Loopback) == -1 ||
+                Array.IndexOf(localhostAddresses, IPAddress.IPv6Loopback) == -1)
+            {
+                return;
+            }
+
+            int port;
+            using (SocketTestServer.SocketTestServerFactory(SocketImplementationType.Async, IPAddress.Loopback, out port))
+            using (Socket client = new Socket(SocketType.Stream, ProtocolType.Tcp))
+            {
+                Assert.True(client.DualMode);
+
+                Task connectTask = ConnectAsync(client, new DnsEndPoint("localhost", port));
+                await connectTask;
+
+                var localEndPoint = client.LocalEndPoint as IPEndPoint;
+                Assert.NotNull(localEndPoint);
+                Assert.Equal(IPAddress.Loopback.MapToIPv6(), localEndPoint.Address);
+
+                var remoteEndPoint = client.RemoteEndPoint as IPEndPoint;
+                Assert.NotNull(remoteEndPoint);
+                Assert.Equal(IPAddress.Loopback.MapToIPv6(), remoteEndPoint.Address);
+            }
+        }
+
+        [Fact]
         public async Task Connect_OnConnectedSocket_Fails()
         {
             int port;
index bdc7ec7..c53ab95 100644 (file)
@@ -23,6 +23,9 @@ namespace System.Net.Sockets.Tests
             await Assert.ThrowsAsync<InvalidOperationException>(async () => await SocketTaskExtensions.AcceptAsync(s));
             await Assert.ThrowsAsync<InvalidOperationException>(async () => await SocketTaskExtensions.AcceptAsync(s, null));
 
+            await Assert.ThrowsAsync<InvalidOperationException>(async () => await SocketTaskExtensions.ReceiveFromAsync(s, new ArraySegment<byte>(buffer), SocketFlags.None, badEndPoint));
+            await Assert.ThrowsAsync<InvalidOperationException>(async () => await SocketTaskExtensions.ReceiveMessageFromAsync(s, new ArraySegment<byte>(buffer), SocketFlags.None, badEndPoint));
+
             await Assert.ThrowsAsync<SocketException>(async () => await SocketTaskExtensions.ConnectAsync(s, badEndPoint));
             await Assert.ThrowsAsync<SocketException>(async () => await SocketTaskExtensions.ConnectAsync(s, badEndPoint, CancellationToken.None));
             await Assert.ThrowsAsync<SocketException>(async () => await SocketTaskExtensions.ConnectAsync(s, badEndPoint.Address, badEndPoint.Port));
@@ -35,8 +38,6 @@ namespace System.Net.Sockets.Tests
             await Assert.ThrowsAsync<SocketException>(async () => await SocketTaskExtensions.ReceiveAsync(s, new ArraySegment<byte>(buffer), SocketFlags.None));
             await Assert.ThrowsAsync<SocketException>(async () => await SocketTaskExtensions.ReceiveAsync(s, buffer.AsMemory(), SocketFlags.None));
             await Assert.ThrowsAsync<SocketException>(async () => await SocketTaskExtensions.ReceiveAsync(s, new ArraySegment<byte>[] { new ArraySegment<byte>(buffer) }, SocketFlags.None));
-            await Assert.ThrowsAsync<SocketException>(async () => await SocketTaskExtensions.ReceiveFromAsync(s, new ArraySegment<byte>(buffer), SocketFlags.None, badEndPoint));
-            await Assert.ThrowsAsync<SocketException>(async () => await SocketTaskExtensions.ReceiveMessageFromAsync(s, new ArraySegment<byte>(buffer), SocketFlags.None, badEndPoint));
 
             await Assert.ThrowsAsync<SocketException>(async () => await SocketTaskExtensions.SendAsync(s, new ArraySegment<byte>(buffer), SocketFlags.None));
             await Assert.ThrowsAsync<SocketException>(async () => await SocketTaskExtensions.SendAsync(s, buffer.AsMemory(), SocketFlags.None));