// These exists to prevent GC of the MsQuicConnection in the middle of an async op (Connect or Shutdown).
public MsQuicConnection? Connection;
+ public MsQuicListener.State? ListenerState;
public TaskCompletionSource<uint>? ConnectTcs;
// TODO: only allocate these when there is an outstanding shutdown.
internal string TraceId() => _state.TraceId;
// constructor for inbound connections
- public MsQuicConnection(IPEndPoint localEndPoint, IPEndPoint remoteEndPoint, SafeMsQuicConnectionHandle handle, bool remoteCertificateRequired = false, X509RevocationMode revocationMode = X509RevocationMode.Offline, RemoteCertificateValidationCallback? remoteCertificateValidationCallback = null, ServerCertificateSelectionCallback? serverCertificateSelectionCallback = null)
+ public MsQuicConnection(IPEndPoint localEndPoint, IPEndPoint remoteEndPoint, MsQuicListener.State listenerState, SafeMsQuicConnectionHandle handle, bool remoteCertificateRequired = false, X509RevocationMode revocationMode = X509RevocationMode.Offline, RemoteCertificateValidationCallback? remoteCertificateValidationCallback = null, ServerCertificateSelectionCallback? serverCertificateSelectionCallback = null)
{
_state.Handle = handle;
_state.StateGCHandle = GCHandle.Alloc(_state);
- _state.Connected = true;
_state.RemoteCertificateRequired = remoteCertificateRequired;
_state.RevocationMode = revocationMode;
_state.RemoteCertificateValidationCallback = remoteCertificateValidationCallback;
throw;
}
+ _state.ListenerState = listenerState;
_state.TraceId = MsQuicTraceHelper.GetTraceId(_state.Handle);
if (NetEventSource.Log.IsEnabled())
{
private static uint HandleEventConnected(State state, ref ConnectionEvent connectionEvent)
{
- if (!state.Connected)
+ if (state.Connected)
+ {
+ return MsQuicStatusCodes.Success;
+ }
+
+ if (state.IsServer)
+ {
+ state.Connected = true;
+ MsQuicListener.State? listenerState = state.ListenerState;
+ state.ListenerState = null;
+
+ if (listenerState != null)
+ {
+ if (listenerState.PendingConnections.TryRemove(state.Handle.DangerousGetHandle(), out MsQuicConnection? connection))
+ {
+ // Move connection from pending to Accept queue and hand it out.
+ if (listenerState.AcceptConnectionQueue.Writer.TryWrite(connection))
+ {
+ return MsQuicStatusCodes.Success;
+ }
+ // Listener is closed
+ connection.Dispose();
+ }
+ }
+
+ return MsQuicStatusCodes.UserCanceled;
+ }
+ else
{
// Connected will already be true for connections accepted from a listener.
Debug.Assert(!Monitor.IsEntered(state));
// This is the final event on the connection, so free the GCHandle used by the event callback.
state.StateGCHandle.Free();
+ if (state.ListenerState != null)
+ {
+ // This is inbound connection that never got connected - becasue of TLS validation or some other reason.
+ // Remove connection from pending queue and dispose it.
+ if (state.ListenerState.PendingConnections.TryRemove(state.Handle.DangerousGetHandle(), out MsQuicConnection? connection))
+ {
+ connection.Dispose();
+ }
+
+ state.ListenerState = null;
+ }
+
state.Connection = null;
state.ShutdownTcs.SetResult(MsQuicStatusCodes.Success);
{
bidirectionalTcs.SetException(ExceptionDispatchInfo.SetCurrentStackTrace(new QuicOperationAbortedException()));
}
+
return MsQuicStatusCodes.Success;
}
if (!success)
{
+ if (state.IsServer)
+ {
+ return MsQuicStatusCodes.UserCanceled;
+ }
+
throw new AuthenticationException(SR.net_quic_cert_custom_validation);
}
if (sslPolicyErrors != SslPolicyErrors.None)
{
+ if (state.IsServer)
+ {
+ return MsQuicStatusCodes.HandshakeFailure;
+ }
+
throw new AuthenticationException(SR.Format(SR.net_quic_cert_chain_validation, sslPolicyErrors));
}
using System.Buffers;
using System.Collections.Generic;
+using System.Collections.Concurrent;
using System.Diagnostics;
using System.Net.Quic.Implementations.MsQuic.Internal;
using System.Net.Security;
private readonly IPEndPoint _listenEndPoint;
- private sealed class State
+ internal sealed class State
{
// set immediately in ctor, but we need a GCHandle to State in order to create the handle.
public SafeMsQuicListenerHandle Handle = null!;
public readonly SafeMsQuicConfigurationHandle? ConnectionConfiguration;
public readonly Channel<MsQuicConnection> AcceptConnectionQueue;
+ public readonly ConcurrentDictionary<IntPtr, MsQuicConnection> PendingConnections;
public QuicOptions ConnectionOptions = new QuicOptions();
public SslServerAuthenticationOptions AuthenticationOptions = new SslServerAuthenticationOptions();
ConnectionConfiguration = SafeMsQuicConfigurationHandle.Create(options, options.ServerAuthenticationOptions);
}
+ PendingConnections = new ConcurrentDictionary<IntPtr, MsQuicConnection>();
AcceptConnectionQueue = Channel.CreateBounded<MsQuicConnection>(new BoundedChannelOptions(options.ListenBacklog)
{
SingleReader = true,
SafeMsQuicConnectionHandle? connectionHandle = null;
MsQuicConnection? msQuicConnection = null;
-
try
{
ref NewConnectionInfo connectionInfo = ref *evt.Data.NewConnection.Info;
uint status = MsQuicApi.Api.ConnectionSetConfigurationDelegate(connectionHandle, connectionConfiguration);
if (MsQuicStatusHelper.SuccessfulStatusCode(status))
{
- msQuicConnection = new MsQuicConnection(localEndPoint, remoteEndPoint, connectionHandle, state.AuthenticationOptions.ClientCertificateRequired, state.AuthenticationOptions.CertificateRevocationCheckMode, state.AuthenticationOptions.RemoteCertificateValidationCallback);
+ msQuicConnection = new MsQuicConnection(localEndPoint, remoteEndPoint, state, connectionHandle, state.AuthenticationOptions.ClientCertificateRequired, state.AuthenticationOptions.CertificateRevocationCheckMode, state.AuthenticationOptions.RemoteCertificateValidationCallback);
msQuicConnection.SetNegotiatedAlpn(connectionInfo.NegotiatedAlpn, connectionInfo.NegotiatedAlpnLength);
- if (state.AcceptConnectionQueue.Writer.TryWrite(msQuicConnection))
+ if (!state.PendingConnections.TryAdd(connectionHandle.DangerousGetHandle(), msQuicConnection))
{
- return MsQuicStatusCodes.Success;
+ msQuicConnection.Dispose();
}
+
+ return MsQuicStatusCodes.Success;
}
// If we fall-through here something wrong happened.
}
[Fact]
+ [PlatformSpecific(TestPlatforms.Windows)]
+ public async Task UntrustedClientCertificateFails()
+ {
+ var listenerOptions = new QuicListenerOptions();
+ listenerOptions.ListenEndPoint = new IPEndPoint(IPAddress.Loopback, 0);
+ listenerOptions.ServerAuthenticationOptions = GetSslServerAuthenticationOptions();
+ listenerOptions.ServerAuthenticationOptions.ClientCertificateRequired = true;
+ listenerOptions.ServerAuthenticationOptions.RemoteCertificateValidationCallback = (sender, cert, chain, errors) =>
+ {
+ return false;
+ };
+
+ using QuicListener listener = new QuicListener(QuicImplementationProviders.MsQuic, listenerOptions);
+ QuicClientConnectionOptions clientOptions = CreateQuicClientOptions();
+ clientOptions.RemoteEndPoint = listener.ListenEndPoint;
+ clientOptions.ClientAuthenticationOptions.ClientCertificates = new X509CertificateCollection() { ClientCertificate };
+ QuicConnection clientConnection = CreateQuicConnection(clientOptions);
+
+ using CancellationTokenSource cts = new CancellationTokenSource();
+ cts.CancelAfter(500); //Some delay to see if we would get failed connection.
+ Task<QuicConnection> serverTask = listener.AcceptConnectionAsync(cts.Token).AsTask();
+
+ ValueTask t = clientConnection.ConnectAsync(cts.Token);
+
+ t.AsTask().Wait(PassingTestTimeout);
+ await Assert.ThrowsAsync<OperationCanceledException>(() => serverTask);
+ // The task will likely succed but we don't really care.
+ // It may fail if the server aborts quickly.
+ try
+ {
+ await t;
+ }
+ catch { };
+ }
+
+ [Fact]
public async Task CertificateCallbackThrowPropagates()
{
using CancellationTokenSource cts = new CancellationTokenSource(PassingTestTimeout);
X509Certificate? receivedCertificate = null;
+ bool validationResult = false;
var listenerOptions = new QuicListenerOptions();
listenerOptions.ListenEndPoint = new IPEndPoint(Socket.OSSupportsIPv6 ? IPAddress.IPv6Loopback : IPAddress.Loopback, 0);
clientOptions.ClientAuthenticationOptions.RemoteCertificateValidationCallback = (sender, cert, chain, errors) =>
{
receivedCertificate = cert;
+ if (validationResult)
+ {
+ return validationResult;
+ }
+
throw new ArithmeticException("foobar");
};
clientOptions.ClientAuthenticationOptions.TargetHost = "foobar1";
QuicConnection clientConnection = new QuicConnection(QuicImplementationProviders.MsQuic, clientOptions);
- Task<QuicConnection> serverTask = listener.AcceptConnectionAsync(cts.Token).AsTask();
await Assert.ThrowsAsync<ArithmeticException>(() => clientConnection.ConnectAsync(cts.Token).AsTask());
- QuicConnection serverConnection = await serverTask;
Assert.Equal(listenerOptions.ServerAuthenticationOptions.ServerCertificate, receivedCertificate);
+ clientConnection.Dispose();
+ // Make sure the listner is still usable and there is no lingering bad conenction
+ validationResult = true;
+ (clientConnection, QuicConnection serverConnection) = await CreateConnectedQuicConnection(listener);
+ await PingPong(clientConnection, serverConnection);
clientConnection.Dispose();
serverConnection.Dispose();
}
using QuicConnection clientConnection = new QuicConnection(QuicImplementationProviders.MsQuic, clientOptions);
ValueTask clientTask = clientConnection.ConnectAsync();
- using QuicConnection serverConnection = await listener.AcceptConnectionAsync();
await Assert.ThrowsAsync<AuthenticationException>(async () => await clientTask);
}
(QuicConnection clientConnection, QuicConnection serverConnection) = await CreateConnectedQuicConnection(clientOptions, listenerOptions);
}
- [Fact]
+ [Theory]
[PlatformSpecific(TestPlatforms.Windows)]
- public async Task ConnectWithClientCertificate()
+ [InlineData(true)]
+ // [InlineData(false)] [ActiveIssue("https://github.com/dotnet/runtime/issues/57308")]
+ public async Task ConnectWithClientCertificate(bool sendCerttificate)
{
bool clientCertificateOK = false;
listenerOptions.ServerAuthenticationOptions.ClientCertificateRequired = true;
listenerOptions.ServerAuthenticationOptions.RemoteCertificateValidationCallback = (sender, cert, chain, errors) =>
{
- _output.WriteLine("client certificate {0}", cert);
- Assert.NotNull(cert);
- Assert.Equal(ClientCertificate.Thumbprint, ((X509Certificate2)cert).Thumbprint);
+ if (sendCerttificate)
+ {
+ _output.WriteLine("client certificate {0}", cert);
+ Assert.NotNull(cert);
+ Assert.Equal(ClientCertificate.Thumbprint, ((X509Certificate2)cert).Thumbprint);
+ }
clientCertificateOK = true;
return true;
using QuicListener listener = new QuicListener(QuicImplementationProviders.MsQuic, listenerOptions);
QuicClientConnectionOptions clientOptions = CreateQuicClientOptions();
- clientOptions.ClientAuthenticationOptions.ClientCertificates = new X509CertificateCollection() { ClientCertificate };
+ if (sendCerttificate)
+ {
+ clientOptions.ClientAuthenticationOptions.ClientCertificates = new X509CertificateCollection() { ClientCertificate };
+ }
(QuicConnection clientConnection, QuicConnection serverConnection) = await CreateConnectedQuicConnection(clientOptions, listener);
// Verify functionality of the connections.