Fix handling of new connection in MsQuicListener (#57319)
authorTomas Weinfurt <tweinfurt@yahoo.com>
Tue, 17 Aug 2021 01:35:26 +0000 (18:35 -0700)
committerGitHub <noreply@github.com>
Tue, 17 Aug 2021 01:35:26 +0000 (18:35 -0700)
* update

* feedback from review

src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/MsQuicConnection.cs
src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/MsQuicListener.cs
src/libraries/System.Net.Quic/tests/FunctionalTests/MsQuicTests.cs

index 1eba27e..ddd002d 100644 (file)
@@ -46,6 +46,7 @@ namespace System.Net.Quic.Implementations.MsQuic
 
             // These exists to prevent GC of the MsQuicConnection in the middle of an async op (Connect or Shutdown).
             public MsQuicConnection? Connection;
+            public MsQuicListener.State? ListenerState;
 
             public TaskCompletionSource<uint>? ConnectTcs;
             // TODO: only allocate these when there is an outstanding shutdown.
@@ -135,11 +136,10 @@ namespace System.Net.Quic.Implementations.MsQuic
         internal string TraceId() => _state.TraceId;
 
         // constructor for inbound connections
-        public MsQuicConnection(IPEndPoint localEndPoint, IPEndPoint remoteEndPoint, SafeMsQuicConnectionHandle handle, bool remoteCertificateRequired = false, X509RevocationMode revocationMode = X509RevocationMode.Offline, RemoteCertificateValidationCallback? remoteCertificateValidationCallback = null, ServerCertificateSelectionCallback? serverCertificateSelectionCallback = null)
+        public MsQuicConnection(IPEndPoint localEndPoint, IPEndPoint remoteEndPoint, MsQuicListener.State listenerState, SafeMsQuicConnectionHandle handle, bool remoteCertificateRequired = false, X509RevocationMode revocationMode = X509RevocationMode.Offline, RemoteCertificateValidationCallback? remoteCertificateValidationCallback = null, ServerCertificateSelectionCallback? serverCertificateSelectionCallback = null)
         {
             _state.Handle = handle;
             _state.StateGCHandle = GCHandle.Alloc(_state);
-            _state.Connected = true;
             _state.RemoteCertificateRequired = remoteCertificateRequired;
             _state.RevocationMode = revocationMode;
             _state.RemoteCertificateValidationCallback = remoteCertificateValidationCallback;
@@ -161,6 +161,7 @@ namespace System.Net.Quic.Implementations.MsQuic
                 throw;
             }
 
+            _state.ListenerState = listenerState;
             _state.TraceId = MsQuicTraceHelper.GetTraceId(_state.Handle);
             if (NetEventSource.Log.IsEnabled())
             {
@@ -223,7 +224,34 @@ namespace System.Net.Quic.Implementations.MsQuic
 
         private static uint HandleEventConnected(State state, ref ConnectionEvent connectionEvent)
         {
-            if (!state.Connected)
+            if (state.Connected)
+            {
+                return MsQuicStatusCodes.Success;
+            }
+
+            if (state.IsServer)
+            {
+                state.Connected = true;
+                MsQuicListener.State? listenerState = state.ListenerState;
+                state.ListenerState = null;
+
+                if (listenerState != null)
+                {
+                    if (listenerState.PendingConnections.TryRemove(state.Handle.DangerousGetHandle(), out MsQuicConnection? connection))
+                    {
+                        // Move connection from pending to Accept queue and hand it out.
+                        if (listenerState.AcceptConnectionQueue.Writer.TryWrite(connection))
+                        {
+                            return MsQuicStatusCodes.Success;
+                        }
+                        // Listener is closed
+                        connection.Dispose();
+                    }
+                }
+
+                return MsQuicStatusCodes.UserCanceled;
+            }
+            else
             {
                 // Connected will already be true for connections accepted from a listener.
                 Debug.Assert(!Monitor.IsEntered(state));
@@ -271,6 +299,18 @@ namespace System.Net.Quic.Implementations.MsQuic
             // This is the final event on the connection, so free the GCHandle used by the event callback.
             state.StateGCHandle.Free();
 
+            if (state.ListenerState != null)
+            {
+                // This is inbound connection that never got connected - becasue of TLS validation or some other reason.
+                // Remove connection from pending queue and dispose it.
+                if (state.ListenerState.PendingConnections.TryRemove(state.Handle.DangerousGetHandle(), out MsQuicConnection? connection))
+                {
+                    connection.Dispose();
+                }
+
+                state.ListenerState = null;
+            }
+
             state.Connection = null;
 
             state.ShutdownTcs.SetResult(MsQuicStatusCodes.Success);
@@ -297,6 +337,7 @@ namespace System.Net.Quic.Implementations.MsQuic
             {
                 bidirectionalTcs.SetException(ExceptionDispatchInfo.SetCurrentStackTrace(new QuicOperationAbortedException()));
             }
+
             return MsQuicStatusCodes.Success;
         }
 
@@ -418,6 +459,11 @@ namespace System.Net.Quic.Implementations.MsQuic
 
                     if (!success)
                     {
+                        if (state.IsServer)
+                        {
+                            return MsQuicStatusCodes.UserCanceled;
+                        }
+
                         throw new AuthenticationException(SR.net_quic_cert_custom_validation);
                     }
 
@@ -430,6 +476,11 @@ namespace System.Net.Quic.Implementations.MsQuic
 
                 if (sslPolicyErrors != SslPolicyErrors.None)
                 {
+                    if (state.IsServer)
+                    {
+                        return MsQuicStatusCodes.HandshakeFailure;
+                    }
+
                     throw new AuthenticationException(SR.Format(SR.net_quic_cert_chain_validation, sslPolicyErrors));
                 }
 
index 707f822..7da9eb1 100644 (file)
@@ -3,6 +3,7 @@
 
 using System.Buffers;
 using System.Collections.Generic;
+using System.Collections.Concurrent;
 using System.Diagnostics;
 using System.Net.Quic.Implementations.MsQuic.Internal;
 using System.Net.Security;
@@ -25,7 +26,7 @@ namespace System.Net.Quic.Implementations.MsQuic
 
         private readonly IPEndPoint _listenEndPoint;
 
-        private sealed class State
+        internal sealed class State
         {
             // set immediately in ctor, but we need a GCHandle to State in order to create the handle.
             public SafeMsQuicListenerHandle Handle = null!;
@@ -33,6 +34,7 @@ namespace System.Net.Quic.Implementations.MsQuic
 
             public readonly SafeMsQuicConfigurationHandle? ConnectionConfiguration;
             public readonly Channel<MsQuicConnection> AcceptConnectionQueue;
+            public readonly ConcurrentDictionary<IntPtr, MsQuicConnection> PendingConnections;
 
             public QuicOptions ConnectionOptions = new QuicOptions();
             public SslServerAuthenticationOptions AuthenticationOptions = new SslServerAuthenticationOptions();
@@ -66,6 +68,7 @@ namespace System.Net.Quic.Implementations.MsQuic
                     ConnectionConfiguration = SafeMsQuicConfigurationHandle.Create(options, options.ServerAuthenticationOptions);
                 }
 
+                PendingConnections = new ConcurrentDictionary<IntPtr, MsQuicConnection>();
                 AcceptConnectionQueue = Channel.CreateBounded<MsQuicConnection>(new BoundedChannelOptions(options.ListenBacklog)
                 {
                     SingleReader = true,
@@ -234,7 +237,6 @@ namespace System.Net.Quic.Implementations.MsQuic
 
             SafeMsQuicConnectionHandle? connectionHandle = null;
             MsQuicConnection? msQuicConnection = null;
-
             try
             {
                 ref NewConnectionInfo connectionInfo = ref *evt.Data.NewConnection.Info;
@@ -278,13 +280,15 @@ namespace System.Net.Quic.Implementations.MsQuic
                 uint status = MsQuicApi.Api.ConnectionSetConfigurationDelegate(connectionHandle, connectionConfiguration);
                 if (MsQuicStatusHelper.SuccessfulStatusCode(status))
                 {
-                    msQuicConnection = new MsQuicConnection(localEndPoint, remoteEndPoint, connectionHandle, state.AuthenticationOptions.ClientCertificateRequired, state.AuthenticationOptions.CertificateRevocationCheckMode, state.AuthenticationOptions.RemoteCertificateValidationCallback);
+                    msQuicConnection = new MsQuicConnection(localEndPoint, remoteEndPoint, state, connectionHandle, state.AuthenticationOptions.ClientCertificateRequired, state.AuthenticationOptions.CertificateRevocationCheckMode, state.AuthenticationOptions.RemoteCertificateValidationCallback);
                     msQuicConnection.SetNegotiatedAlpn(connectionInfo.NegotiatedAlpn, connectionInfo.NegotiatedAlpnLength);
 
-                    if (state.AcceptConnectionQueue.Writer.TryWrite(msQuicConnection))
+                    if (!state.PendingConnections.TryAdd(connectionHandle.DangerousGetHandle(), msQuicConnection))
                     {
-                        return MsQuicStatusCodes.Success;
+                        msQuicConnection.Dispose();
                     }
+
+                    return MsQuicStatusCodes.Success;
                 }
 
                 // If we fall-through here something wrong happened.
index 253707d..59833d8 100644 (file)
@@ -103,10 +103,47 @@ namespace System.Net.Quic.Tests
         }
 
         [Fact]
+        [PlatformSpecific(TestPlatforms.Windows)]
+        public async Task UntrustedClientCertificateFails()
+        {
+            var listenerOptions = new QuicListenerOptions();
+            listenerOptions.ListenEndPoint = new IPEndPoint(IPAddress.Loopback, 0);
+            listenerOptions.ServerAuthenticationOptions = GetSslServerAuthenticationOptions();
+            listenerOptions.ServerAuthenticationOptions.ClientCertificateRequired = true;
+            listenerOptions.ServerAuthenticationOptions.RemoteCertificateValidationCallback = (sender, cert, chain, errors) =>
+            {
+                return false;
+            };
+
+            using QuicListener listener = new QuicListener(QuicImplementationProviders.MsQuic, listenerOptions);
+            QuicClientConnectionOptions clientOptions = CreateQuicClientOptions();
+            clientOptions.RemoteEndPoint = listener.ListenEndPoint;
+            clientOptions.ClientAuthenticationOptions.ClientCertificates = new X509CertificateCollection() { ClientCertificate };
+            QuicConnection clientConnection = CreateQuicConnection(clientOptions);
+
+            using CancellationTokenSource cts = new CancellationTokenSource();
+            cts.CancelAfter(500); //Some delay to see if we would get failed connection.
+            Task<QuicConnection> serverTask = listener.AcceptConnectionAsync(cts.Token).AsTask();
+
+            ValueTask t = clientConnection.ConnectAsync(cts.Token);
+
+            t.AsTask().Wait(PassingTestTimeout);
+            await Assert.ThrowsAsync<OperationCanceledException>(() => serverTask);
+            // The task will likely succed but we don't really care.
+            // It may fail if the server aborts quickly.
+            try
+            {
+                await t;
+            }
+            catch { };
+        }
+
+        [Fact]
         public async Task CertificateCallbackThrowPropagates()
         {
             using CancellationTokenSource cts = new CancellationTokenSource(PassingTestTimeout);
             X509Certificate? receivedCertificate = null;
+            bool validationResult = false;
 
             var listenerOptions = new QuicListenerOptions();
             listenerOptions.ListenEndPoint = new IPEndPoint(Socket.OSSupportsIPv6 ? IPAddress.IPv6Loopback : IPAddress.Loopback, 0);
@@ -118,18 +155,26 @@ namespace System.Net.Quic.Tests
             clientOptions.ClientAuthenticationOptions.RemoteCertificateValidationCallback = (sender, cert, chain, errors) =>
             {
                 receivedCertificate = cert;
+                if (validationResult)
+                {
+                    return validationResult;
+                }
+
                 throw new ArithmeticException("foobar");
             };
 
             clientOptions.ClientAuthenticationOptions.TargetHost = "foobar1";
             QuicConnection clientConnection = new QuicConnection(QuicImplementationProviders.MsQuic, clientOptions);
 
-            Task<QuicConnection> serverTask = listener.AcceptConnectionAsync(cts.Token).AsTask();
             await Assert.ThrowsAsync<ArithmeticException>(() => clientConnection.ConnectAsync(cts.Token).AsTask());
-            QuicConnection serverConnection = await serverTask;
 
             Assert.Equal(listenerOptions.ServerAuthenticationOptions.ServerCertificate, receivedCertificate);
+            clientConnection.Dispose();
 
+            // Make sure the listner is still usable and there is no lingering bad conenction
+            validationResult = true;
+            (clientConnection, QuicConnection serverConnection) = await CreateConnectedQuicConnection(listener);
+            await PingPong(clientConnection, serverConnection);
             clientConnection.Dispose();
             serverConnection.Dispose();
         }
@@ -253,7 +298,6 @@ namespace System.Net.Quic.Tests
             using QuicConnection clientConnection = new QuicConnection(QuicImplementationProviders.MsQuic, clientOptions);
             ValueTask clientTask = clientConnection.ConnectAsync();
 
-            using QuicConnection serverConnection = await listener.AcceptConnectionAsync();
             await Assert.ThrowsAsync<AuthenticationException>(async () => await clientTask);
         }
 
@@ -284,9 +328,11 @@ namespace System.Net.Quic.Tests
             (QuicConnection clientConnection, QuicConnection serverConnection) = await CreateConnectedQuicConnection(clientOptions, listenerOptions);
         }
 
-        [Fact]
+        [Theory]
         [PlatformSpecific(TestPlatforms.Windows)]
-        public async Task ConnectWithClientCertificate()
+        [InlineData(true)]
+        // [InlineData(false)] [ActiveIssue("https://github.com/dotnet/runtime/issues/57308")]
+        public async Task ConnectWithClientCertificate(bool sendCerttificate)
         {
             bool clientCertificateOK = false;
 
@@ -296,9 +342,12 @@ namespace System.Net.Quic.Tests
             listenerOptions.ServerAuthenticationOptions.ClientCertificateRequired = true;
             listenerOptions.ServerAuthenticationOptions.RemoteCertificateValidationCallback = (sender, cert, chain, errors) =>
             {
-                _output.WriteLine("client certificate {0}", cert);
-                Assert.NotNull(cert);
-                Assert.Equal(ClientCertificate.Thumbprint, ((X509Certificate2)cert).Thumbprint);
+                if (sendCerttificate)
+                {
+                    _output.WriteLine("client certificate {0}", cert);
+                    Assert.NotNull(cert);
+                    Assert.Equal(ClientCertificate.Thumbprint, ((X509Certificate2)cert).Thumbprint);
+                }
 
                 clientCertificateOK = true;
                 return true;
@@ -306,7 +355,10 @@ namespace System.Net.Quic.Tests
 
             using QuicListener listener = new QuicListener(QuicImplementationProviders.MsQuic, listenerOptions);
             QuicClientConnectionOptions clientOptions = CreateQuicClientOptions();
-            clientOptions.ClientAuthenticationOptions.ClientCertificates = new X509CertificateCollection() { ClientCertificate };
+            if (sendCerttificate)
+            {
+                clientOptions.ClientAuthenticationOptions.ClientCertificates = new X509CertificateCollection() { ClientCertificate };
+            }
             (QuicConnection clientConnection, QuicConnection serverConnection) = await CreateConnectedQuicConnection(clientOptions, listener);
 
             // Verify functionality of the connections.