fix SNI handling in quic (#55468)
authorTomas Weinfurt <tweinfurt@yahoo.com>
Fri, 23 Jul 2021 08:29:29 +0000 (01:29 -0700)
committerGitHub <noreply@github.com>
Fri, 23 Jul 2021 08:29:29 +0000 (10:29 +0200)
* fix SNI handling in quic'

* cut ServerOptionsSelectionCallback

* feedback from review

* feedback from review

src/libraries/System.Net.Quic/src/Resources/Strings.resx
src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/Interop/SafeMsQuicConfigurationHandle.cs
src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/MsQuicConnection.cs
src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/MsQuicListener.cs
src/libraries/System.Net.Quic/tests/FunctionalTests/MsQuicTests.cs
src/libraries/System.Net.Quic/tests/FunctionalTests/QuicTestBase.cs

index a29352a..51c31a9 100644 (file)
   <data name="net_quic_writing_notallowed" xml:space="preserve">
     <value>Writing is not allowed on stream.</value>
   </data>
+  <data name="net_quic_ssl_option" xml:space="preserve">
+    <value>'{0}' is not supported by System.Net.Quic.</value>
+  </data>
+  <data name="net_quic_cert_custom_validation" xml:space="preserve">
+    <value>The remote certificate was rejected by the provided RemoteCertificateValidationCallback.</value>
+  </data>
+  <data name="net_quic_cert_chain_validation" xml:space="preserve">
+    <value>The remote certificate is invalid because of errors in the certificate chain: {0}</value>
+  </data>
 </root>
 
index df48e0d..aa4c589 100644 (file)
@@ -17,7 +17,7 @@ namespace System.Net.Quic.Implementations.MsQuic.Internal
     internal sealed class SafeMsQuicConfigurationHandle : SafeHandle
     {
         private static readonly FieldInfo _contextCertificate = typeof(SslStreamCertificateContext).GetField("Certificate", BindingFlags.NonPublic | BindingFlags.Instance)!;
-        private static readonly FieldInfo _contextChain= typeof(SslStreamCertificateContext).GetField("IntermediateCertificates", BindingFlags.NonPublic | BindingFlags.Instance)!;
+        private static readonly FieldInfo _contextChain = typeof(SslStreamCertificateContext).GetField("IntermediateCertificates", BindingFlags.NonPublic | BindingFlags.Instance)!;
 
         public override bool IsInvalid => handle == IntPtr.Zero;
 
@@ -33,7 +33,7 @@ namespace System.Net.Quic.Implementations.MsQuic.Internal
         }
 
         // TODO: consider moving the static code from here to keep all the handle classes small and simple.
-        public static unsafe SafeMsQuicConfigurationHandle Create(QuicClientConnectionOptions options)
+        public static SafeMsQuicConfigurationHandle Create(QuicClientConnectionOptions options)
         {
             X509Certificate? certificate = null;
             if (options.ClientAuthenticationOptions?.ClientCertificates != null)
@@ -56,15 +56,35 @@ namespace System.Net.Quic.Implementations.MsQuic.Internal
             return Create(options, QUIC_CREDENTIAL_FLAGS.CLIENT, certificate: certificate, certificateContext: null, options.ClientAuthenticationOptions?.ApplicationProtocols);
         }
 
-        public static unsafe SafeMsQuicConfigurationHandle Create(QuicListenerOptions options)
+        public static SafeMsQuicConfigurationHandle Create(QuicOptions options, SslServerAuthenticationOptions? serverAuthenticationOptions, string? targetHost = null)
         {
             QUIC_CREDENTIAL_FLAGS flags = QUIC_CREDENTIAL_FLAGS.NONE;
-            if (options.ServerAuthenticationOptions != null && options.ServerAuthenticationOptions.ClientCertificateRequired)
+            X509Certificate? certificate = serverAuthenticationOptions?.ServerCertificate;
+
+            if (serverAuthenticationOptions != null)
             {
-                flags |= QUIC_CREDENTIAL_FLAGS.REQUIRE_CLIENT_AUTHENTICATION | QUIC_CREDENTIAL_FLAGS.INDICATE_CERTIFICATE_RECEIVED | QUIC_CREDENTIAL_FLAGS.NO_CERTIFICATE_VALIDATION;
+                if (serverAuthenticationOptions.CipherSuitesPolicy != null)
+                {
+                    throw new PlatformNotSupportedException(SR.Format(SR.net_quic_ssl_option, nameof(serverAuthenticationOptions.CipherSuitesPolicy)));
+                }
+
+                if (serverAuthenticationOptions.EncryptionPolicy == EncryptionPolicy.NoEncryption)
+                {
+                    throw new PlatformNotSupportedException(SR.Format(SR.net_quic_ssl_option, nameof(serverAuthenticationOptions.EncryptionPolicy)));
+                }
+
+                if (serverAuthenticationOptions.ClientCertificateRequired)
+                {
+                    flags |= QUIC_CREDENTIAL_FLAGS.REQUIRE_CLIENT_AUTHENTICATION | QUIC_CREDENTIAL_FLAGS.INDICATE_CERTIFICATE_RECEIVED | QUIC_CREDENTIAL_FLAGS.NO_CERTIFICATE_VALIDATION;
+                }
+
+                if (certificate == null && serverAuthenticationOptions?.ServerCertificateSelectionCallback != null && targetHost != null)
+                {
+                    certificate = serverAuthenticationOptions.ServerCertificateSelectionCallback(options, targetHost);
+                }
             }
 
-            return Create(options, flags, options.ServerAuthenticationOptions?.ServerCertificate, options.ServerAuthenticationOptions?.ServerCertificateContext, options.ServerAuthenticationOptions?.ApplicationProtocols);
+            return Create(options, flags, certificate, serverAuthenticationOptions?.ServerCertificateContext, serverAuthenticationOptions?.ApplicationProtocols);
         }
 
         // TODO: this is called from MsQuicListener and when it fails it wreaks havoc in MsQuicListener finalizer.
index 6fb6588..f3dcbf2 100644 (file)
@@ -7,6 +7,7 @@ using System.Net.Security;
 using System.Net.Sockets;
 using System.Runtime.ExceptionServices;
 using System.Runtime.InteropServices;
+using System.Security.Authentication;
 using System.Security.Cryptography;
 using System.Security.Cryptography.X509Certificates;
 using System.Threading;
@@ -35,10 +36,6 @@ namespace System.Net.Quic.Implementations.MsQuic
         private IPEndPoint? _localEndPoint;
         private readonly EndPoint _remoteEndPoint;
         private SslApplicationProtocol _negotiatedAlpnProtocol;
-        private bool _isServer;
-        private bool _remoteCertificateRequired;
-        private X509RevocationMode _revocationMode = X509RevocationMode.Offline;
-        private RemoteCertificateValidationCallback? _remoteCertificateValidationCallback;
 
         internal sealed class State
         {
@@ -50,8 +47,8 @@ namespace System.Net.Quic.Implementations.MsQuic
             // These exists to prevent GC of the MsQuicConnection in the middle of an async op (Connect or Shutdown).
             public MsQuicConnection? Connection;
 
-            // TODO: only allocate these when there is an outstanding connect/shutdown.
-            public readonly TaskCompletionSource<uint> ConnectTcs = new TaskCompletionSource<uint>(TaskCreationOptions.RunContinuationsAsynchronously);
+            public TaskCompletionSource<uint>? ConnectTcs;
+            // TODO: only allocate these when there is an outstanding shutdown.
             public readonly TaskCompletionSource<uint> ShutdownTcs = new TaskCompletionSource<uint>(TaskCreationOptions.RunContinuationsAsynchronously);
 
             // Note that there's no such thing as resetable TCS, so we cannot reuse the same instance after we've set the result.
@@ -65,6 +62,13 @@ namespace System.Net.Quic.Implementations.MsQuic
             public int StreamCount;
             private bool _closing;
 
+            // Certificate validation properties
+            public bool RemoteCertificateRequired;
+            public X509RevocationMode RevocationMode = X509RevocationMode.Offline;
+            public RemoteCertificateValidationCallback? RemoteCertificateValidationCallback;
+            public bool IsServer;
+            public string? TargetHost;
+
             // Queue for accepted streams.
             // Backlog limit is managed by MsQuic so it can be unbounded here.
             public readonly Channel<MsQuicStream> AcceptQueue = Channel.CreateUnbounded<MsQuicStream>(new UnboundedChannelOptions()
@@ -131,26 +135,17 @@ namespace System.Net.Quic.Implementations.MsQuic
         internal string TraceId() => _state.TraceId;
 
         // constructor for inbound connections
-        public MsQuicConnection(IPEndPoint localEndPoint, IPEndPoint remoteEndPoint, SafeMsQuicConnectionHandle handle, bool remoteCertificateRequired = false, X509RevocationMode revocationMode = X509RevocationMode.Offline, RemoteCertificateValidationCallback? remoteCertificateValidationCallback = null)
+        public MsQuicConnection(IPEndPoint localEndPoint, IPEndPoint remoteEndPoint, SafeMsQuicConnectionHandle handle, bool remoteCertificateRequired = false, X509RevocationMode revocationMode = X509RevocationMode.Offline, RemoteCertificateValidationCallback? remoteCertificateValidationCallback = null, ServerCertificateSelectionCallback? serverCertificateSelectionCallback = null)
         {
             _state.Handle = handle;
             _state.StateGCHandle = GCHandle.Alloc(_state);
             _state.Connected = true;
-            _isServer = true;
+            _state.RemoteCertificateRequired = remoteCertificateRequired;
+            _state.RevocationMode = revocationMode;
+            _state.RemoteCertificateValidationCallback = remoteCertificateValidationCallback;
+            _state.IsServer = true;
             _localEndPoint = localEndPoint;
             _remoteEndPoint = remoteEndPoint;
-            _remoteCertificateRequired = remoteCertificateRequired;
-            _revocationMode = revocationMode;
-            _remoteCertificateValidationCallback = remoteCertificateValidationCallback;
-
-            if (_remoteCertificateRequired)
-            {
-                // We need to link connection for the validation callback.
-                // We need to be able to find the connection in HandleEventPeerCertificateReceived
-                // and dispatch it as sender to validation callback.
-                // After that Connection will be set back to null.
-                _state.Connection = this;
-            }
 
             try
             {
@@ -177,12 +172,12 @@ namespace System.Net.Quic.Implementations.MsQuic
         {
             _remoteEndPoint = options.RemoteEndPoint!;
             _configuration = SafeMsQuicConfigurationHandle.Create(options);
-            _isServer = false;
-            _remoteCertificateRequired = true;
+            _state.RemoteCertificateRequired = true;
             if (options.ClientAuthenticationOptions != null)
             {
-                _revocationMode = options.ClientAuthenticationOptions.CertificateRevocationCheckMode;
-                _remoteCertificateValidationCallback = options.ClientAuthenticationOptions.RemoteCertificateValidationCallback;
+                _state.RevocationMode = options.ClientAuthenticationOptions.CertificateRevocationCheckMode;
+                _state.RemoteCertificateValidationCallback = options.ClientAuthenticationOptions.RemoteCertificateValidationCallback;
+                _state.TargetHost = options.ClientAuthenticationOptions.TargetHost;
             }
 
             _state.StateGCHandle = GCHandle.Alloc(_state);
@@ -231,7 +226,7 @@ namespace System.Net.Quic.Implementations.MsQuic
                 state.Connection = null;
 
                 state.Connected = true;
-                state.ConnectTcs.SetResult(MsQuicStatusCodes.Success);
+                state.ConnectTcs!.SetResult(MsQuicStatusCodes.Success);
             }
 
             return MsQuicStatusCodes.Success;
@@ -239,14 +234,15 @@ namespace System.Net.Quic.Implementations.MsQuic
 
         private static uint HandleEventShutdownInitiatedByTransport(State state, ref ConnectionEvent connectionEvent)
         {
-            if (!state.Connected)
+            if (!state.Connected && state.ConnectTcs != null)
             {
                 Debug.Assert(state.Connection != null);
                 state.Connection = null;
 
                 uint hresult = connectionEvent.Data.ShutdownInitiatedByTransport.Status;
                 Exception ex = QuicExceptionHelpers.CreateExceptionForHResult(hresult, "Connection has been shutdown by transport.");
-                state.ConnectTcs.SetException(ExceptionDispatchInfo.SetCurrentStackTrace(ex));
+                state.ConnectTcs!.SetException(ExceptionDispatchInfo.SetCurrentStackTrace(ex));
+                state.ConnectTcs = null;
             }
 
             state.AcceptQueue.Writer.TryComplete();
@@ -345,17 +341,6 @@ namespace System.Net.Quic.Implementations.MsQuic
             X509Certificate2? certificate = null;
             X509Certificate2Collection? additionalCertificates = null;
 
-            MsQuicConnection? connection = state.Connection;
-            if (connection == null)
-            {
-                return MsQuicStatusCodes.InvalidState;
-            }
-
-            if (connection._isServer)
-            {
-                state.Connection = null;
-            }
-
             try
             {
                 if (connectionEvent.Data.PeerCertificateReceived.PlatformCertificateHandle != IntPtr.Zero)
@@ -386,15 +371,15 @@ namespace System.Net.Quic.Implementations.MsQuic
 
                 if (certificate == null)
                 {
-                    if (NetEventSource.Log.IsEnabled() && connection._remoteCertificateRequired) NetEventSource.Error(state, $"{state.TraceId} Remote certificate required, but no remote certificate received");
+                    if (NetEventSource.Log.IsEnabled() && state.RemoteCertificateRequired) NetEventSource.Error(state, $"{state.TraceId} Remote certificate required, but no remote certificate received");
                     sslPolicyErrors |= SslPolicyErrors.RemoteCertificateNotAvailable;
                 }
                 else
                 {
                     chain = new X509Chain();
-                    chain.ChainPolicy.RevocationMode = connection._revocationMode;
+                    chain.ChainPolicy.RevocationMode = state.RevocationMode;
                     chain.ChainPolicy.RevocationFlag = X509RevocationFlag.ExcludeRoot;
-                    chain.ChainPolicy.ApplicationPolicy.Add(connection._isServer ? s_clientAuthOid : s_serverAuthOid);
+                    chain.ChainPolicy.ApplicationPolicy.Add(state.IsServer ? s_clientAuthOid : s_serverAuthOid);
 
                     if (additionalCertificates != null && additionalCertificates.Count > 1)
                     {
@@ -407,34 +392,46 @@ namespace System.Net.Quic.Implementations.MsQuic
                     }
                 }
 
-                if (!connection._remoteCertificateRequired)
+                if (!state.RemoteCertificateRequired)
                 {
                     sslPolicyErrors &= ~SslPolicyErrors.RemoteCertificateNotAvailable;
                 }
 
-                if (connection._remoteCertificateValidationCallback != null)
+                if (state.RemoteCertificateValidationCallback != null)
                 {
-                    bool success = connection._remoteCertificateValidationCallback(connection, certificate, chain, sslPolicyErrors);
+                    bool success = state.RemoteCertificateValidationCallback(state, certificate, chain, sslPolicyErrors);
                     // Unset the callback to prevent multiple invocations of the callback per a single connection.
                     // Return the same value as the custom callback just did.
-                    connection._remoteCertificateValidationCallback = (_, _, _, _) => success;
+                    state.RemoteCertificateValidationCallback = (_, _, _, _) => success;
 
                     if (!success && NetEventSource.Log.IsEnabled())
                         NetEventSource.Error(state, $"{state.TraceId} Remote certificate rejected by verification callback");
-                    return success ? MsQuicStatusCodes.Success : MsQuicStatusCodes.HandshakeFailure;
+
+                    if (!success)
+                    {
+                        throw new AuthenticationException(SR.net_quic_cert_custom_validation);
+                    }
+
+                    return MsQuicStatusCodes.Success;
                 }
 
                 if (NetEventSource.Log.IsEnabled())
                     NetEventSource.Info(state, $"{state.TraceId} Certificate validation for '${certificate?.Subject}' finished with ${sslPolicyErrors}");
 
-                return (sslPolicyErrors == SslPolicyErrors.None) ? MsQuicStatusCodes.Success : MsQuicStatusCodes.HandshakeFailure;
+//                return (sslPolicyErrors == SslPolicyErrors.None) ? MsQuicStatusCodes.Success : MsQuicStatusCodes.HandshakeFailure;
+
+                if (sslPolicyErrors != SslPolicyErrors.None)
+                {
+                    throw new AuthenticationException(SR.Format(SR.net_quic_cert_chain_validation, sslPolicyErrors));
+                }
+
+                return MsQuicStatusCodes.Success;
             }
             catch (Exception ex)
             {
                 if (NetEventSource.Log.IsEnabled()) NetEventSource.Error(state, $"{state.TraceId} Certificate validation failed ${ex.Message}");
+                throw;
             }
-
-            return MsQuicStatusCodes.InternalError;
         }
 
         internal override async ValueTask<QuicStreamProvider> AcceptStreamAsync(CancellationToken cancellationToken = default)
@@ -544,13 +541,6 @@ namespace System.Net.Quic.Implementations.MsQuic
                 throw new Exception($"{nameof(ConnectAsync)} must not be called on a connection obtained from a listener.");
             }
 
-            (string address, int port) = _remoteEndPoint switch
-            {
-                DnsEndPoint dnsEp => (dnsEp.Host, dnsEp.Port),
-                IPEndPoint ipEp => (ipEp.Address.ToString(), ipEp.Port),
-                _ => throw new Exception($"Unsupported remote endpoint type '{_remoteEndPoint.GetType()}'.")
-            };
-
             QUIC_ADDRESS_FAMILY af = _remoteEndPoint.AddressFamily switch
             {
                 AddressFamily.Unspecified => QUIC_ADDRESS_FAMILY.UNSPEC,
@@ -562,13 +552,43 @@ namespace System.Net.Quic.Implementations.MsQuic
             Debug.Assert(_state.StateGCHandle.IsAllocated);
 
             _state.Connection = this;
+            uint status;
+            string targetHost;
+            int port;
+
+            if (_remoteEndPoint is IPEndPoint)
+            {
+                SOCKADDR_INET address = MsQuicAddressHelpers.IPEndPointToINet((IPEndPoint)_remoteEndPoint);
+                unsafe
+                {
+                    status = MsQuicApi.Api.SetParamDelegate(_state.Handle, QUIC_PARAM_LEVEL.CONNECTION, (uint)QUIC_PARAM_CONN.REMOTE_ADDRESS, (uint)sizeof(SOCKADDR_INET), (byte*)&address);
+                    QuicExceptionHelpers.ThrowIfFailed(status, "Failed to connect to peer.");
+                }
+
+                targetHost = _state.TargetHost ?? ((IPEndPoint)_remoteEndPoint).Address.ToString();
+                port = ((IPEndPoint)_remoteEndPoint).Port;
+
+            }
+            else if (_remoteEndPoint is DnsEndPoint)
+            {
+                // We don't have way how to set separate SNI and name for connection at this moment.
+                targetHost = ((DnsEndPoint)_remoteEndPoint).Host;
+                port = ((DnsEndPoint)_remoteEndPoint).Port;
+            }
+            else
+            {
+                throw new Exception($"Unsupported remote endpoint type '{_remoteEndPoint.GetType()}'.");
+            }
+
+            _state.ConnectTcs = new TaskCompletionSource<uint>(TaskCreationOptions.RunContinuationsAsynchronously);
+
             try
             {
-                uint status = MsQuicApi.Api.ConnectionStartDelegate(
+                status = MsQuicApi.Api.ConnectionStartDelegate(
                     _state.Handle,
                     _configuration,
                     af,
-                    address,
+                    targetHost,
                     (ushort)port);
 
                 QuicExceptionHelpers.ThrowIfFailed(status, "Failed to connect to peer.");
@@ -665,10 +685,18 @@ namespace System.Net.Quic.Implementations.MsQuic
                     NetEventSource.Error(state, $"{state.TraceId} Exception occurred during handling {connectionEvent.Type} connection callback: {ex}");
                 }
 
-                Debug.Fail($"{state.TraceId} Exception occurred during handling {connectionEvent.Type} connection callback: {ex}");
+                if (state.ConnectTcs != null)
+                {
+                    state.ConnectTcs.SetException(ex);
+                    state.ConnectTcs = null;
+                    state.Connection = null;
+                }
+                else
+                {
+                    Debug.Fail($"{state.TraceId} Exception occurred during handling {connectionEvent.Type} connection callback: {ex}");
+                }
 
                 // TODO: trigger an exception on any outstanding async calls.
-
                 return MsQuicStatusCodes.InternalError;
             }
         }
@@ -709,7 +737,7 @@ namespace System.Net.Quic.Implementations.MsQuic
                 return;
             }
 
-            if (NetEventSource.Log.IsEnabled()) NetEventSource.Info(_state, $"{TraceId()} Stream disposing {disposing}");
+            if (NetEventSource.Log.IsEnabled()) NetEventSource.Info(_state, $"{TraceId()} Connection disposing {disposing}");
 
             // If we haven't already shutdown gracefully (via a successful CloseAsync call), then force an abortive shutdown.
             if (_state.Handle != null)
index 391d26a..b5daea6 100644 (file)
@@ -12,6 +12,7 @@ using System.Threading;
 using System.Threading.Channels;
 using System.Threading.Tasks;
 using static System.Net.Quic.Implementations.MsQuic.Internal.MsQuicNativeMethods;
+using System.Security.Authentication;
 
 namespace System.Net.Quic.Implementations.MsQuic
 {
@@ -31,21 +32,39 @@ namespace System.Net.Quic.Implementations.MsQuic
             public SafeMsQuicListenerHandle Handle = null!;
             public string TraceId = null!; // set in ctor.
 
-            public readonly SafeMsQuicConfigurationHandle ConnectionConfiguration;
+            public readonly SafeMsQuicConfigurationHandle? ConnectionConfiguration;
             public readonly Channel<MsQuicConnection> AcceptConnectionQueue;
 
-            public bool RemoteCertificateRequired;
-            public X509RevocationMode RevocationMode = X509RevocationMode.Offline;
-            public RemoteCertificateValidationCallback? RemoteCertificateValidationCallback;
+            public QuicOptions ConnectionOptions = new QuicOptions();
+            public SslServerAuthenticationOptions AuthenticationOptions = new SslServerAuthenticationOptions();
 
             public State(QuicListenerOptions options)
             {
-                ConnectionConfiguration = SafeMsQuicConfigurationHandle.Create(options);
+                ConnectionOptions.IdleTimeout = options.IdleTimeout;
+                ConnectionOptions.MaxBidirectionalStreams = options.MaxBidirectionalStreams;
+                ConnectionOptions.MaxUnidirectionalStreams = options.MaxUnidirectionalStreams;
+
+                bool delayConfiguration = false;
+
                 if (options.ServerAuthenticationOptions != null)
                 {
-                    RemoteCertificateRequired = options.ServerAuthenticationOptions.ClientCertificateRequired;
-                    RevocationMode = options.ServerAuthenticationOptions.CertificateRevocationCheckMode;
-                    RemoteCertificateValidationCallback = options.ServerAuthenticationOptions.RemoteCertificateValidationCallback;
+                    AuthenticationOptions.ClientCertificateRequired = options.ServerAuthenticationOptions.ClientCertificateRequired;
+                    AuthenticationOptions.CertificateRevocationCheckMode = options.ServerAuthenticationOptions.CertificateRevocationCheckMode;
+                    AuthenticationOptions.RemoteCertificateValidationCallback = options.ServerAuthenticationOptions.RemoteCertificateValidationCallback;
+                    AuthenticationOptions.ServerCertificateSelectionCallback = options.ServerAuthenticationOptions.ServerCertificateSelectionCallback;
+                    AuthenticationOptions.ApplicationProtocols = options.ServerAuthenticationOptions.ApplicationProtocols;
+
+                    if (options.ServerAuthenticationOptions.ServerCertificate == null && options.ServerAuthenticationOptions.ServerCertificateContext == null &&
+                        options.ServerAuthenticationOptions.ServerCertificateSelectionCallback != null)
+                    {
+                        // We don't have any certificate but we have selection callback so we need to wait for SNI.
+                        delayConfiguration = true;
+                    }
+                }
+
+                if (!delayConfiguration)
+                {
+                    ConnectionConfiguration = SafeMsQuicConfigurationHandle.Create(options, options.ServerAuthenticationOptions);
                 }
 
                 AcceptConnectionQueue = Channel.CreateBounded<MsQuicConnection>(new BoundedChannelOptions(options.ListenBacklog)
@@ -211,6 +230,7 @@ namespace System.Net.Quic.Implementations.MsQuic
             var state = (State)gcHandle.Target;
 
             SafeMsQuicConnectionHandle? connectionHandle = null;
+            MsQuicConnection? msQuicConnection = null;
 
             try
             {
@@ -218,24 +238,53 @@ namespace System.Net.Quic.Implementations.MsQuic
 
                 IPEndPoint localEndPoint = MsQuicAddressHelpers.INetToIPEndPoint(ref *(SOCKADDR_INET*)connectionInfo.LocalAddress);
                 IPEndPoint remoteEndPoint = MsQuicAddressHelpers.INetToIPEndPoint(ref *(SOCKADDR_INET*)connectionInfo.RemoteAddress);
+                string targetHost = string.Empty;   // compat with SslStream
+                if (connectionInfo.ServerNameLength > 0 && connectionInfo.ServerName != IntPtr.Zero)
+                {
+                    // TBD We should figure out what to do with international names.
+                    targetHost = Marshal.PtrToStringAnsi(connectionInfo.ServerName, connectionInfo.ServerNameLength);
+                }
 
-                connectionHandle = new SafeMsQuicConnectionHandle(evt.Data.NewConnection.Connection);
+                SafeMsQuicConfigurationHandle? connectionConfiguration = state.ConnectionConfiguration;
 
-                uint status = MsQuicApi.Api.ConnectionSetConfigurationDelegate(connectionHandle, state.ConnectionConfiguration);
-                QuicExceptionHelpers.ThrowIfFailed(status, "ConnectionSetConfiguration failed.");
+                if (connectionConfiguration == null)
+                {
+                    Debug.Assert(state.AuthenticationOptions.ServerCertificateSelectionCallback != null);
+                    try
+                    {
+                        // ServerCertificateSelectionCallback is synchronous. We will call it as needed when building configuration
+                        connectionConfiguration = SafeMsQuicConfigurationHandle.Create(state.ConnectionOptions, state.AuthenticationOptions, targetHost);
+                    }
+                    catch (Exception ex)
+                    {
+                        if (NetEventSource.Log.IsEnabled())
+                        {
+                            NetEventSource.Error(state, $"[Listener#{state.GetHashCode()}] Exception occurred during creating configuration in connection callback: {ex}");
+                        }
+                    }
+
+                    if (connectionConfiguration == null)
+                    {
+                        // We don't have safe handle yet so MsQuic will cleanup new connection.
+                        return MsQuicStatusCodes.InternalError;
+                    }
+                }
 
-                var msQuicConnection = new MsQuicConnection(localEndPoint, remoteEndPoint, connectionHandle, state.RemoteCertificateRequired, state.RevocationMode, state.RemoteCertificateValidationCallback);
-                msQuicConnection.SetNegotiatedAlpn(connectionInfo.NegotiatedAlpn, connectionInfo.NegotiatedAlpnLength);
+                connectionHandle = new SafeMsQuicConnectionHandle(evt.Data.NewConnection.Connection);
 
-                if (!state.AcceptConnectionQueue.Writer.TryWrite(msQuicConnection))
+                uint status = MsQuicApi.Api.ConnectionSetConfigurationDelegate(connectionHandle, connectionConfiguration);
+                if (MsQuicStatusHelper.SuccessfulStatusCode(status))
                 {
-                    // This handle will be cleaned up by MsQuic.
-                    connectionHandle.SetHandleAsInvalid();
-                    msQuicConnection.Dispose();
-                    return MsQuicStatusCodes.InternalError;
+                    msQuicConnection = new MsQuicConnection(localEndPoint, remoteEndPoint, connectionHandle, state.AuthenticationOptions.ClientCertificateRequired, state.AuthenticationOptions.CertificateRevocationCheckMode, state.AuthenticationOptions.RemoteCertificateValidationCallback);
+                    msQuicConnection.SetNegotiatedAlpn(connectionInfo.NegotiatedAlpn, connectionInfo.NegotiatedAlpnLength);
+
+                    if (state.AcceptConnectionQueue.Writer.TryWrite(msQuicConnection))
+                    {
+                        return MsQuicStatusCodes.Success;
+                    }
                 }
 
-                return MsQuicStatusCodes.Success;
+                // If we fall-through here something wrong happened.
             }
             catch (Exception ex)
             {
@@ -243,14 +292,12 @@ namespace System.Net.Quic.Implementations.MsQuic
                 {
                     NetEventSource.Error(state, $"[Listener#{state.GetHashCode()}] Exception occurred during handling {(QUIC_LISTENER_EVENT)evt.Type} connection callback: {ex}");
                 }
-
-                Debug.Fail($"[Listener#{state.GetHashCode()}] Exception occurred during handling {(QUIC_LISTENER_EVENT)evt.Type} connection callback: {ex}");
-
-                // This handle will be cleaned up by MsQuic by returning InternalError.
-                connectionHandle?.SetHandleAsInvalid();
-                state.AcceptConnectionQueue.Writer.TryComplete(ex);
-                return MsQuicStatusCodes.InternalError;
             }
+
+            // This handle will be cleaned up by MsQuic by returning InternalError.
+            connectionHandle?.SetHandleAsInvalid();
+            msQuicConnection?.Dispose();
+            return MsQuicStatusCodes.InternalError;
         }
 
         private void ThrowIfDisposed()
index 83846ba..e4486c9 100644 (file)
@@ -6,8 +6,11 @@ using System.Collections.Generic;
 using System.Diagnostics;
 using System.Linq;
 using System.Net.Security;
+using System.Net.Sockets;
+using System.Security.Authentication;
 using System.Security.Cryptography.X509Certificates;
 using System.Text;
+using System.Threading;
 using System.Threading.Tasks;
 using Xunit;
 using Xunit.Abstractions;
@@ -116,6 +119,126 @@ namespace System.Net.Quic.Tests
         }
 
         [Fact]
+        public async Task CertificateCallbackThrowPropagates()
+        {
+            using CancellationTokenSource cts = new CancellationTokenSource(PassingTestTimeout);
+            X509Certificate? receivedCertificate = null;
+
+            var quicOptions = new QuicListenerOptions();
+            quicOptions.ListenEndPoint = new IPEndPoint( Socket.OSSupportsIPv6 ? IPAddress.IPv6Loopback : IPAddress.Loopback, 0);
+            quicOptions.ServerAuthenticationOptions = GetSslServerAuthenticationOptions();
+
+            using QuicListener listener = new QuicListener(QuicImplementationProviders.MsQuic, quicOptions);
+
+            QuicClientConnectionOptions options = new QuicClientConnectionOptions()
+            {
+                RemoteEndPoint = listener.ListenEndPoint,
+                ClientAuthenticationOptions = GetSslClientAuthenticationOptions(),
+            };
+
+            options.ClientAuthenticationOptions.RemoteCertificateValidationCallback = (sender, cert, chain, errors) =>
+            {
+                receivedCertificate = cert;
+                throw new ArithmeticException("foobar");
+            };
+
+            options.ClientAuthenticationOptions.TargetHost = "foobar1";
+
+            QuicConnection clientConnection = new QuicConnection(QuicImplementationProviders.MsQuic, options);
+
+            Task<QuicConnection> serverTask = listener.AcceptConnectionAsync(cts.Token).AsTask();
+            await Assert.ThrowsAsync<ArithmeticException>(() => clientConnection.ConnectAsync(cts.Token).AsTask());
+            QuicConnection serverConnection = await serverTask;
+
+            Assert.Equal(quicOptions.ServerAuthenticationOptions.ServerCertificate, receivedCertificate);
+
+            clientConnection.Dispose();
+            serverConnection.Dispose();
+        }
+
+        [Fact]
+        public async Task ConnectWithCertificateCallback()
+        {
+            X509Certificate2 c1 = System.Net.Test.Common.Configuration.Certificates.GetServerCertificate();
+            X509Certificate2 c2 = System.Net.Test.Common.Configuration.Certificates.GetClientCertificate(); // This 'wrong' certificate but should be sufficient
+            X509Certificate2 expectedCertificate = c1;
+
+            using CancellationTokenSource cts = new CancellationTokenSource();
+            cts.CancelAfter(PassingTestTimeout);
+            string? receivedHostName = null;
+            X509Certificate? receivedCertificate = null;
+
+            var quicOptions = new QuicListenerOptions();
+            quicOptions.ListenEndPoint = new IPEndPoint( Socket.OSSupportsIPv6 ? IPAddress.IPv6Loopback : IPAddress.Loopback, 0);
+            quicOptions.ServerAuthenticationOptions = GetSslServerAuthenticationOptions();
+            quicOptions.ServerAuthenticationOptions.ServerCertificate = null;
+            quicOptions.ServerAuthenticationOptions.ServerCertificateSelectionCallback = (sender, hostName) =>
+            {
+                receivedHostName = hostName;
+                if (hostName == "foobar1")
+                {
+                    return c1;
+                }
+                else if (hostName == "foobar2")
+                {
+                    return c2;
+                }
+
+                return null;
+            };
+
+            using QuicListener listener = new QuicListener(QuicImplementationProviders.MsQuic, quicOptions);
+
+            QuicClientConnectionOptions options = new QuicClientConnectionOptions()
+            {
+                RemoteEndPoint = listener.ListenEndPoint,
+                ClientAuthenticationOptions = GetSslClientAuthenticationOptions(),
+            };
+
+            options.ClientAuthenticationOptions.RemoteCertificateValidationCallback = (sender, cert, chain, errors) =>
+            {
+                receivedCertificate = cert;
+                return true;
+            };
+
+            options.ClientAuthenticationOptions.TargetHost = "foobar1";
+
+            QuicConnection clientConnection = new QuicConnection(QuicImplementationProviders.MsQuic, options);
+
+            Task<QuicConnection> serverTask = listener.AcceptConnectionAsync(cts.Token).AsTask();
+            await new Task[] { clientConnection.ConnectAsync().AsTask(), serverTask}.WhenAllOrAnyFailed(PassingTestTimeoutMilliseconds);
+            QuicConnection serverConnection = serverTask.Result;
+
+            Assert.Equal(options.ClientAuthenticationOptions.TargetHost, receivedHostName);
+            Assert.Equal(c1, receivedCertificate);
+            clientConnection.Dispose();
+            serverConnection.Dispose();
+
+            // This should fail when callback return null.
+            options.ClientAuthenticationOptions.TargetHost = "foobar3";
+            clientConnection = new QuicConnection(QuicImplementationProviders.MsQuic, options);
+            Task clientTask = clientConnection.ConnectAsync(cts.Token).AsTask();
+
+            await Assert.ThrowsAsync<QuicException>(() => clientTask);
+            Assert.Equal(options.ClientAuthenticationOptions.TargetHost, receivedHostName);
+            clientConnection.Dispose();
+
+            // Do this last to make sure Listener is still functional.
+            options.ClientAuthenticationOptions.TargetHost = "foobar2";
+            expectedCertificate = c2;
+
+            clientConnection = new QuicConnection(QuicImplementationProviders.MsQuic, options);
+            serverTask = listener.AcceptConnectionAsync(cts.Token).AsTask();
+            await new Task[] { clientConnection.ConnectAsync().AsTask(), serverTask}.WhenAllOrAnyFailed(PassingTestTimeoutMilliseconds);
+            serverConnection = serverTask.Result;
+
+            Assert.Equal(options.ClientAuthenticationOptions.TargetHost, receivedHostName);
+            Assert.Equal(c2, receivedCertificate);
+            clientConnection.Dispose();
+            serverConnection.Dispose();
+        }
+
+        [Fact]
         [PlatformSpecific(TestPlatforms.Windows)]
         [ActiveIssue("https://github.com/microsoft/msquic/pull/1728")]
         public async Task ConnectWithClientCertificate()
index e76591f..9a6e43a 100644 (file)
@@ -28,6 +28,9 @@ namespace System.Net.Quic.Tests
         public X509Certificate2 ServerCertificate = System.Net.Test.Common.Configuration.Certificates.GetServerCertificate();
         public X509Certificate2 ClientCertificate = System.Net.Test.Common.Configuration.Certificates.GetClientCertificate();
 
+        public const int PassingTestTimeoutMilliseconds = 4 * 60 * 1000;
+        public static TimeSpan PassingTestTimeout => TimeSpan.FromMilliseconds(PassingTestTimeoutMilliseconds);
+
         public bool RemoteCertificateValidationCallback(object sender, X509Certificate? certificate, X509Chain? chain, SslPolicyErrors sslPolicyErrors)
         {
             Assert.Equal(ServerCertificate.GetCertHash(), certificate?.GetCertHash());
@@ -48,7 +51,8 @@ namespace System.Net.Quic.Tests
             return new SslClientAuthenticationOptions()
             {
                 ApplicationProtocols = new List<SslApplicationProtocol>() { ApplicationProtocol },
-                RemoteCertificateValidationCallback = RemoteCertificateValidationCallback
+                RemoteCertificateValidationCallback = RemoteCertificateValidationCallback,
+                TargetHost = "localhost"
             };
         }