// to poll for the real state until we're done connecting.
private bool _nonBlockingConnectInProgress;
- // Keep track of the kind of endpoint used to do a non-blocking connect, so we can set
- // it to _rightEndPoint when we discover we're connected.
- private EndPoint? _nonBlockingConnectRightEndPoint;
+ // Keep track of the kind of endpoint used to do a connect, so we can set
+ // it to _rightEndPoint when we're connected.
+ private EndPoint? _pendingConnectRightEndPoint;
// These are constants initialized by constructor.
private AddressFamily _addressFamily;
if (_nonBlockingConnectInProgress && Poll(0, SelectMode.SelectWrite))
{
- // Update the state if we've become connected after a non-blocking connect.
- _isConnected = true;
- _rightEndPoint ??= _nonBlockingConnectRightEndPoint;
- UpdateLocalEndPointOnConnect();
_nonBlockingConnectInProgress = false;
+ SetToConnected();
}
if (_rightEndPoint == null)
{
if (_nonBlockingConnectInProgress && Poll(0, SelectMode.SelectWrite))
{
- // Update the state if we've become connected after a non-blocking connect.
- _isConnected = true;
- _rightEndPoint ??= _nonBlockingConnectRightEndPoint;
- UpdateLocalEndPointOnConnect();
_nonBlockingConnectInProgress = false;
+ // Update the state if we've become connected after a non-blocking connect.
+ SetToConnected();
}
if (_rightEndPoint == null || !_isConnected)
if (_nonBlockingConnectInProgress && Poll(0, SelectMode.SelectWrite))
{
- // Update the state if we've become connected after a non-blocking connect.
- _isConnected = true;
- _rightEndPoint ??= _nonBlockingConnectRightEndPoint;
- UpdateLocalEndPointOnConnect();
_nonBlockingConnectInProgress = false;
+ // Update the state if we've become connected after a non-blocking connect.
+ SetToConnected();
}
return _isConnected;
ValidateForMultiConnect(isMultiEndpoint: false);
Internals.SocketAddress socketAddress = Serialize(ref remoteEP);
-
- if (!Blocking)
- {
- _nonBlockingConnectRightEndPoint = remoteEP;
- _nonBlockingConnectInProgress = true;
- }
+ _pendingConnectRightEndPoint = remoteEP;
+ _nonBlockingConnectInProgress = !Blocking;
DoConnect(remoteEP, socketAddress);
}
}
e._socketAddress = Serialize(ref endPointSnapshot);
+ _pendingConnectRightEndPoint = endPointSnapshot;
+ _nonBlockingConnectInProgress = false;
WildcardBindForConnectIfNecessary(endPointSnapshot.AddressFamily);
- // Save the old RightEndPoint and prep new RightEndPoint.
- EndPoint? oldEndPoint = _rightEndPoint;
- _rightEndPoint ??= endPointSnapshot;
-
if (SocketsTelemetry.Log.IsEnabled())
{
SocketsTelemetry.Log.ConnectStart(e._socketAddress!);
SocketsTelemetry.Log.AfterConnect(SocketError.NotSocket, ex.Message);
}
- _rightEndPoint = oldEndPoint;
_localEndPoint = null;
// Clear in-use flag on event args object.
if (SocketsTelemetry.Log.IsEnabled()) SocketsTelemetry.Log.AfterConnect(SocketError.Success);
- // Save a copy of the EndPoint so we can use it for Create().
- _rightEndPoint ??= endPointSnapshot;
-
if (NetEventSource.Log.IsEnabled()) NetEventSource.Info(this, $"connection to:{endPointSnapshot}");
// Update state and performance counters.
+ _pendingConnectRightEndPoint = endPointSnapshot;
+ _nonBlockingConnectInProgress = false;
SetToConnected();
if (NetEventSource.Log.IsEnabled()) NetEventSource.Connected(this, LocalEndPoint, RemoteEndPoint);
}
return;
}
+ Debug.Assert(_nonBlockingConnectInProgress == false);
+
// Update the status: this socket was indeed connected at
// some point in time update the perf counter as well.
_isConnected = true;
_isDisconnected = false;
+ _rightEndPoint ??= _pendingConnectRightEndPoint;
+ _pendingConnectRightEndPoint = null;
UpdateLocalEndPointOnConnect();
if (NetEventSource.Log.IsEnabled()) NetEventSource.Info(this, "now connected");
}
}
[Fact]
+ public async Task Connect_DualMode_MultiAddressFamilyConnect_RetrievedEndPoints_Success()
+ {
+ if (!SupportsMultiConnect)
+ return;
+
+ int port;
+ using (SocketTestServer.SocketTestServerFactory(SocketImplementationType.Async, IPAddress.Loopback, out port))
+ using (Socket client = new Socket(SocketType.Stream, ProtocolType.Tcp))
+ {
+ Assert.True(client.DualMode);
+
+ Task connectTask = MultiConnectAsync(client, new IPAddress[] { IPAddress.IPv6Loopback, IPAddress.Loopback }, port);
+ await connectTask;
+
+ var localEndPoint = client.LocalEndPoint as IPEndPoint;
+ Assert.NotNull(localEndPoint);
+ Assert.Equal(IPAddress.Loopback.MapToIPv6(), localEndPoint.Address);
+
+ var remoteEndPoint = client.RemoteEndPoint as IPEndPoint;
+ Assert.NotNull(remoteEndPoint);
+ Assert.Equal(IPAddress.Loopback.MapToIPv6(), remoteEndPoint.Address);
+ }
+ }
+
+ [Fact]
+ public async Task Connect_DualMode_DnsConnect_RetrievedEndPoints_Success()
+ {
+ var localhostAddresses = Dns.GetHostAddresses("localhost");
+ if (Array.IndexOf(localhostAddresses, IPAddress.Loopback) == -1 ||
+ Array.IndexOf(localhostAddresses, IPAddress.IPv6Loopback) == -1)
+ {
+ return;
+ }
+
+ int port;
+ using (SocketTestServer.SocketTestServerFactory(SocketImplementationType.Async, IPAddress.Loopback, out port))
+ using (Socket client = new Socket(SocketType.Stream, ProtocolType.Tcp))
+ {
+ Assert.True(client.DualMode);
+
+ Task connectTask = ConnectAsync(client, new DnsEndPoint("localhost", port));
+ await connectTask;
+
+ var localEndPoint = client.LocalEndPoint as IPEndPoint;
+ Assert.NotNull(localEndPoint);
+ Assert.Equal(IPAddress.Loopback.MapToIPv6(), localEndPoint.Address);
+
+ var remoteEndPoint = client.RemoteEndPoint as IPEndPoint;
+ Assert.NotNull(remoteEndPoint);
+ Assert.Equal(IPAddress.Loopback.MapToIPv6(), remoteEndPoint.Address);
+ }
+ }
+
+ [Fact]
public async Task Connect_OnConnectedSocket_Fails()
{
int port;
await Assert.ThrowsAsync<InvalidOperationException>(async () => await SocketTaskExtensions.AcceptAsync(s));
await Assert.ThrowsAsync<InvalidOperationException>(async () => await SocketTaskExtensions.AcceptAsync(s, null));
+ await Assert.ThrowsAsync<InvalidOperationException>(async () => await SocketTaskExtensions.ReceiveFromAsync(s, new ArraySegment<byte>(buffer), SocketFlags.None, badEndPoint));
+ await Assert.ThrowsAsync<InvalidOperationException>(async () => await SocketTaskExtensions.ReceiveMessageFromAsync(s, new ArraySegment<byte>(buffer), SocketFlags.None, badEndPoint));
+
await Assert.ThrowsAsync<SocketException>(async () => await SocketTaskExtensions.ConnectAsync(s, badEndPoint));
await Assert.ThrowsAsync<SocketException>(async () => await SocketTaskExtensions.ConnectAsync(s, badEndPoint, CancellationToken.None));
await Assert.ThrowsAsync<SocketException>(async () => await SocketTaskExtensions.ConnectAsync(s, badEndPoint.Address, badEndPoint.Port));
await Assert.ThrowsAsync<SocketException>(async () => await SocketTaskExtensions.ReceiveAsync(s, new ArraySegment<byte>(buffer), SocketFlags.None));
await Assert.ThrowsAsync<SocketException>(async () => await SocketTaskExtensions.ReceiveAsync(s, buffer.AsMemory(), SocketFlags.None));
await Assert.ThrowsAsync<SocketException>(async () => await SocketTaskExtensions.ReceiveAsync(s, new ArraySegment<byte>[] { new ArraySegment<byte>(buffer) }, SocketFlags.None));
- await Assert.ThrowsAsync<SocketException>(async () => await SocketTaskExtensions.ReceiveFromAsync(s, new ArraySegment<byte>(buffer), SocketFlags.None, badEndPoint));
- await Assert.ThrowsAsync<SocketException>(async () => await SocketTaskExtensions.ReceiveMessageFromAsync(s, new ArraySegment<byte>(buffer), SocketFlags.None, badEndPoint));
await Assert.ThrowsAsync<SocketException>(async () => await SocketTaskExtensions.SendAsync(s, new ArraySegment<byte>(buffer), SocketFlags.None));
await Assert.ThrowsAsync<SocketException>(async () => await SocketTaskExtensions.SendAsync(s, buffer.AsMemory(), SocketFlags.None));