From 09ba2200f06d5f4ab9963b56dd3d96aae0465779 Mon Sep 17 00:00:00 2001 From: Tomas Weinfurt Date: Mon, 16 Aug 2021 18:35:26 -0700 Subject: [PATCH] Fix handling of new connection in MsQuicListener (#57319) * update * feedback from review --- .../Implementations/MsQuic/MsQuicConnection.cs | 57 +++++++++++++++++- .../Quic/Implementations/MsQuic/MsQuicListener.cs | 14 +++-- .../tests/FunctionalTests/MsQuicTests.cs | 70 +++++++++++++++++++--- 3 files changed, 124 insertions(+), 17 deletions(-) diff --git a/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/MsQuicConnection.cs b/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/MsQuicConnection.cs index 1eba27e..ddd002d 100644 --- a/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/MsQuicConnection.cs +++ b/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/MsQuicConnection.cs @@ -46,6 +46,7 @@ namespace System.Net.Quic.Implementations.MsQuic // These exists to prevent GC of the MsQuicConnection in the middle of an async op (Connect or Shutdown). public MsQuicConnection? Connection; + public MsQuicListener.State? ListenerState; public TaskCompletionSource? ConnectTcs; // TODO: only allocate these when there is an outstanding shutdown. @@ -135,11 +136,10 @@ namespace System.Net.Quic.Implementations.MsQuic internal string TraceId() => _state.TraceId; // constructor for inbound connections - public MsQuicConnection(IPEndPoint localEndPoint, IPEndPoint remoteEndPoint, SafeMsQuicConnectionHandle handle, bool remoteCertificateRequired = false, X509RevocationMode revocationMode = X509RevocationMode.Offline, RemoteCertificateValidationCallback? remoteCertificateValidationCallback = null, ServerCertificateSelectionCallback? serverCertificateSelectionCallback = null) + public MsQuicConnection(IPEndPoint localEndPoint, IPEndPoint remoteEndPoint, MsQuicListener.State listenerState, SafeMsQuicConnectionHandle handle, bool remoteCertificateRequired = false, X509RevocationMode revocationMode = X509RevocationMode.Offline, RemoteCertificateValidationCallback? remoteCertificateValidationCallback = null, ServerCertificateSelectionCallback? serverCertificateSelectionCallback = null) { _state.Handle = handle; _state.StateGCHandle = GCHandle.Alloc(_state); - _state.Connected = true; _state.RemoteCertificateRequired = remoteCertificateRequired; _state.RevocationMode = revocationMode; _state.RemoteCertificateValidationCallback = remoteCertificateValidationCallback; @@ -161,6 +161,7 @@ namespace System.Net.Quic.Implementations.MsQuic throw; } + _state.ListenerState = listenerState; _state.TraceId = MsQuicTraceHelper.GetTraceId(_state.Handle); if (NetEventSource.Log.IsEnabled()) { @@ -223,7 +224,34 @@ namespace System.Net.Quic.Implementations.MsQuic private static uint HandleEventConnected(State state, ref ConnectionEvent connectionEvent) { - if (!state.Connected) + if (state.Connected) + { + return MsQuicStatusCodes.Success; + } + + if (state.IsServer) + { + state.Connected = true; + MsQuicListener.State? listenerState = state.ListenerState; + state.ListenerState = null; + + if (listenerState != null) + { + if (listenerState.PendingConnections.TryRemove(state.Handle.DangerousGetHandle(), out MsQuicConnection? connection)) + { + // Move connection from pending to Accept queue and hand it out. + if (listenerState.AcceptConnectionQueue.Writer.TryWrite(connection)) + { + return MsQuicStatusCodes.Success; + } + // Listener is closed + connection.Dispose(); + } + } + + return MsQuicStatusCodes.UserCanceled; + } + else { // Connected will already be true for connections accepted from a listener. Debug.Assert(!Monitor.IsEntered(state)); @@ -271,6 +299,18 @@ namespace System.Net.Quic.Implementations.MsQuic // This is the final event on the connection, so free the GCHandle used by the event callback. state.StateGCHandle.Free(); + if (state.ListenerState != null) + { + // This is inbound connection that never got connected - becasue of TLS validation or some other reason. + // Remove connection from pending queue and dispose it. + if (state.ListenerState.PendingConnections.TryRemove(state.Handle.DangerousGetHandle(), out MsQuicConnection? connection)) + { + connection.Dispose(); + } + + state.ListenerState = null; + } + state.Connection = null; state.ShutdownTcs.SetResult(MsQuicStatusCodes.Success); @@ -297,6 +337,7 @@ namespace System.Net.Quic.Implementations.MsQuic { bidirectionalTcs.SetException(ExceptionDispatchInfo.SetCurrentStackTrace(new QuicOperationAbortedException())); } + return MsQuicStatusCodes.Success; } @@ -418,6 +459,11 @@ namespace System.Net.Quic.Implementations.MsQuic if (!success) { + if (state.IsServer) + { + return MsQuicStatusCodes.UserCanceled; + } + throw new AuthenticationException(SR.net_quic_cert_custom_validation); } @@ -430,6 +476,11 @@ namespace System.Net.Quic.Implementations.MsQuic if (sslPolicyErrors != SslPolicyErrors.None) { + if (state.IsServer) + { + return MsQuicStatusCodes.HandshakeFailure; + } + throw new AuthenticationException(SR.Format(SR.net_quic_cert_chain_validation, sslPolicyErrors)); } diff --git a/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/MsQuicListener.cs b/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/MsQuicListener.cs index 707f822..7da9eb1 100644 --- a/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/MsQuicListener.cs +++ b/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/MsQuicListener.cs @@ -3,6 +3,7 @@ using System.Buffers; using System.Collections.Generic; +using System.Collections.Concurrent; using System.Diagnostics; using System.Net.Quic.Implementations.MsQuic.Internal; using System.Net.Security; @@ -25,7 +26,7 @@ namespace System.Net.Quic.Implementations.MsQuic private readonly IPEndPoint _listenEndPoint; - private sealed class State + internal sealed class State { // set immediately in ctor, but we need a GCHandle to State in order to create the handle. public SafeMsQuicListenerHandle Handle = null!; @@ -33,6 +34,7 @@ namespace System.Net.Quic.Implementations.MsQuic public readonly SafeMsQuicConfigurationHandle? ConnectionConfiguration; public readonly Channel AcceptConnectionQueue; + public readonly ConcurrentDictionary PendingConnections; public QuicOptions ConnectionOptions = new QuicOptions(); public SslServerAuthenticationOptions AuthenticationOptions = new SslServerAuthenticationOptions(); @@ -66,6 +68,7 @@ namespace System.Net.Quic.Implementations.MsQuic ConnectionConfiguration = SafeMsQuicConfigurationHandle.Create(options, options.ServerAuthenticationOptions); } + PendingConnections = new ConcurrentDictionary(); AcceptConnectionQueue = Channel.CreateBounded(new BoundedChannelOptions(options.ListenBacklog) { SingleReader = true, @@ -234,7 +237,6 @@ namespace System.Net.Quic.Implementations.MsQuic SafeMsQuicConnectionHandle? connectionHandle = null; MsQuicConnection? msQuicConnection = null; - try { ref NewConnectionInfo connectionInfo = ref *evt.Data.NewConnection.Info; @@ -278,13 +280,15 @@ namespace System.Net.Quic.Implementations.MsQuic uint status = MsQuicApi.Api.ConnectionSetConfigurationDelegate(connectionHandle, connectionConfiguration); if (MsQuicStatusHelper.SuccessfulStatusCode(status)) { - msQuicConnection = new MsQuicConnection(localEndPoint, remoteEndPoint, connectionHandle, state.AuthenticationOptions.ClientCertificateRequired, state.AuthenticationOptions.CertificateRevocationCheckMode, state.AuthenticationOptions.RemoteCertificateValidationCallback); + msQuicConnection = new MsQuicConnection(localEndPoint, remoteEndPoint, state, connectionHandle, state.AuthenticationOptions.ClientCertificateRequired, state.AuthenticationOptions.CertificateRevocationCheckMode, state.AuthenticationOptions.RemoteCertificateValidationCallback); msQuicConnection.SetNegotiatedAlpn(connectionInfo.NegotiatedAlpn, connectionInfo.NegotiatedAlpnLength); - if (state.AcceptConnectionQueue.Writer.TryWrite(msQuicConnection)) + if (!state.PendingConnections.TryAdd(connectionHandle.DangerousGetHandle(), msQuicConnection)) { - return MsQuicStatusCodes.Success; + msQuicConnection.Dispose(); } + + return MsQuicStatusCodes.Success; } // If we fall-through here something wrong happened. diff --git a/src/libraries/System.Net.Quic/tests/FunctionalTests/MsQuicTests.cs b/src/libraries/System.Net.Quic/tests/FunctionalTests/MsQuicTests.cs index 253707d..59833d8 100644 --- a/src/libraries/System.Net.Quic/tests/FunctionalTests/MsQuicTests.cs +++ b/src/libraries/System.Net.Quic/tests/FunctionalTests/MsQuicTests.cs @@ -103,10 +103,47 @@ namespace System.Net.Quic.Tests } [Fact] + [PlatformSpecific(TestPlatforms.Windows)] + public async Task UntrustedClientCertificateFails() + { + var listenerOptions = new QuicListenerOptions(); + listenerOptions.ListenEndPoint = new IPEndPoint(IPAddress.Loopback, 0); + listenerOptions.ServerAuthenticationOptions = GetSslServerAuthenticationOptions(); + listenerOptions.ServerAuthenticationOptions.ClientCertificateRequired = true; + listenerOptions.ServerAuthenticationOptions.RemoteCertificateValidationCallback = (sender, cert, chain, errors) => + { + return false; + }; + + using QuicListener listener = new QuicListener(QuicImplementationProviders.MsQuic, listenerOptions); + QuicClientConnectionOptions clientOptions = CreateQuicClientOptions(); + clientOptions.RemoteEndPoint = listener.ListenEndPoint; + clientOptions.ClientAuthenticationOptions.ClientCertificates = new X509CertificateCollection() { ClientCertificate }; + QuicConnection clientConnection = CreateQuicConnection(clientOptions); + + using CancellationTokenSource cts = new CancellationTokenSource(); + cts.CancelAfter(500); //Some delay to see if we would get failed connection. + Task serverTask = listener.AcceptConnectionAsync(cts.Token).AsTask(); + + ValueTask t = clientConnection.ConnectAsync(cts.Token); + + t.AsTask().Wait(PassingTestTimeout); + await Assert.ThrowsAsync(() => serverTask); + // The task will likely succed but we don't really care. + // It may fail if the server aborts quickly. + try + { + await t; + } + catch { }; + } + + [Fact] public async Task CertificateCallbackThrowPropagates() { using CancellationTokenSource cts = new CancellationTokenSource(PassingTestTimeout); X509Certificate? receivedCertificate = null; + bool validationResult = false; var listenerOptions = new QuicListenerOptions(); listenerOptions.ListenEndPoint = new IPEndPoint(Socket.OSSupportsIPv6 ? IPAddress.IPv6Loopback : IPAddress.Loopback, 0); @@ -118,18 +155,26 @@ namespace System.Net.Quic.Tests clientOptions.ClientAuthenticationOptions.RemoteCertificateValidationCallback = (sender, cert, chain, errors) => { receivedCertificate = cert; + if (validationResult) + { + return validationResult; + } + throw new ArithmeticException("foobar"); }; clientOptions.ClientAuthenticationOptions.TargetHost = "foobar1"; QuicConnection clientConnection = new QuicConnection(QuicImplementationProviders.MsQuic, clientOptions); - Task serverTask = listener.AcceptConnectionAsync(cts.Token).AsTask(); await Assert.ThrowsAsync(() => clientConnection.ConnectAsync(cts.Token).AsTask()); - QuicConnection serverConnection = await serverTask; Assert.Equal(listenerOptions.ServerAuthenticationOptions.ServerCertificate, receivedCertificate); + clientConnection.Dispose(); + // Make sure the listner is still usable and there is no lingering bad conenction + validationResult = true; + (clientConnection, QuicConnection serverConnection) = await CreateConnectedQuicConnection(listener); + await PingPong(clientConnection, serverConnection); clientConnection.Dispose(); serverConnection.Dispose(); } @@ -253,7 +298,6 @@ namespace System.Net.Quic.Tests using QuicConnection clientConnection = new QuicConnection(QuicImplementationProviders.MsQuic, clientOptions); ValueTask clientTask = clientConnection.ConnectAsync(); - using QuicConnection serverConnection = await listener.AcceptConnectionAsync(); await Assert.ThrowsAsync(async () => await clientTask); } @@ -284,9 +328,11 @@ namespace System.Net.Quic.Tests (QuicConnection clientConnection, QuicConnection serverConnection) = await CreateConnectedQuicConnection(clientOptions, listenerOptions); } - [Fact] + [Theory] [PlatformSpecific(TestPlatforms.Windows)] - public async Task ConnectWithClientCertificate() + [InlineData(true)] + // [InlineData(false)] [ActiveIssue("https://github.com/dotnet/runtime/issues/57308")] + public async Task ConnectWithClientCertificate(bool sendCerttificate) { bool clientCertificateOK = false; @@ -296,9 +342,12 @@ namespace System.Net.Quic.Tests listenerOptions.ServerAuthenticationOptions.ClientCertificateRequired = true; listenerOptions.ServerAuthenticationOptions.RemoteCertificateValidationCallback = (sender, cert, chain, errors) => { - _output.WriteLine("client certificate {0}", cert); - Assert.NotNull(cert); - Assert.Equal(ClientCertificate.Thumbprint, ((X509Certificate2)cert).Thumbprint); + if (sendCerttificate) + { + _output.WriteLine("client certificate {0}", cert); + Assert.NotNull(cert); + Assert.Equal(ClientCertificate.Thumbprint, ((X509Certificate2)cert).Thumbprint); + } clientCertificateOK = true; return true; @@ -306,7 +355,10 @@ namespace System.Net.Quic.Tests using QuicListener listener = new QuicListener(QuicImplementationProviders.MsQuic, listenerOptions); QuicClientConnectionOptions clientOptions = CreateQuicClientOptions(); - clientOptions.ClientAuthenticationOptions.ClientCertificates = new X509CertificateCollection() { ClientCertificate }; + if (sendCerttificate) + { + clientOptions.ClientAuthenticationOptions.ClientCertificates = new X509CertificateCollection() { ClientCertificate }; + } (QuicConnection clientConnection, QuicConnection serverConnection) = await CreateConnectedQuicConnection(clientOptions, listener); // Verify functionality of the connections. -- 2.7.4