include more details in exception if remote certificate validation fails (#40110)
authorTomas Weinfurt <tweinfurt@yahoo.com>
Sat, 8 Aug 2020 17:36:52 +0000 (10:36 -0700)
committerGitHub <noreply@github.com>
Sat, 8 Aug 2020 17:36:52 +0000 (10:36 -0700)
* include more details in exception if remote certificate validation fails

* fix unit test linking

* feedback from review

* update exception message

src/libraries/System.Net.Security/src/Resources/Strings.resx
src/libraries/System.Net.Security/src/System/Net/Security/SecureChannel.cs
src/libraries/System.Net.Security/src/System/Net/Security/SslAuthenticationOptions.cs
src/libraries/System.Net.Security/src/System/Net/Security/SslStream.Implementation.cs
src/libraries/System.Net.Security/src/System/Net/Security/SslStream.cs
src/libraries/System.Net.Security/tests/FunctionalTests/ServerAsyncAuthenticateTest.cs
src/libraries/System.Net.Security/tests/FunctionalTests/SslStreamNetworkStreamTest.cs
src/libraries/System.Net.Security/tests/UnitTests/Fakes/FakeSslStream.Implementation.cs

index 1270f6b..e518415 100644 (file)
   <data name="net_io_eof" xml:space="preserve">
     <value> Received an unexpected EOF or 0 bytes from the transport stream.</value>
   </data>
-  <data name="net_io_async_result" xml:space="preserve">
-    <value>The parameter: {0} is not valid. Use the object returned from corresponding Begin async call.</value>
-  </data>
   <data name="net_ssl_io_frame" xml:space="preserve">
     <value>The handshake failed due to an unexpected packet format.</value>
   </data>
     <value>The remote party requested renegotiation when AllowRenegotiation was set to false.</value>
   </data>
   <data name="net_ssl_io_cert_validation" xml:space="preserve">
-    <value>The remote certificate is invalid according to the validation procedure.</value>
+    <value>The remote certificate is invalid according to the validation procedure: {0}</value>
+  </data>
+  <data name="net_ssl_io_cert_chain_validation" xml:space="preserve">
+    <value>The remote certificate is invalid because of errors in the certificate chain: {0}</value>
+  </data>
+  <data name="net_ssl_io_cert_custom_validation" xml:space="preserve">
+    <value>The remote certificate was rejected by the provided RemoteCertificateValidationCallback.</value>
   </data>
   <data name="net_ssl_io_no_server_cert" xml:space="preserve">
     <value>The server mode SSL must use a certificate with the associated private key.</value>
index b0af716..4d20cea 100644 (file)
@@ -25,7 +25,8 @@ namespace System.Net.Security
 
         private SslConnectionInfo? _connectionInfo;
         private X509Certificate? _selectedClientCertificate;
-        private bool _isRemoteCertificateAvailable;
+        private X509Certificate2? _remoteCertificate;
+        private bool _remoteCertificateExposed;
 
         // These are the MAX encrypt buffer output sizes, not the actual sizes.
         private int _headerSize = 5; //ATTN must be set to at least 5 by default
@@ -39,6 +40,7 @@ namespace System.Net.Security
 
         private static readonly Oid s_serverAuthOid = new Oid("1.3.6.1.5.5.7.3.1", "1.3.6.1.5.5.7.3.1");
         private static readonly Oid s_clientAuthOid = new Oid("1.3.6.1.5.5.7.3.2", "1.3.6.1.5.5.7.3.2");
+        private SslStream? _ssl;
 
         internal SecureChannel(SslAuthenticationOptions sslAuthenticationOptions, SslStream sslStream)
         {
@@ -54,6 +56,7 @@ namespace System.Net.Security
             _securityContext = null;
             _refreshCredentialNeeded = true;
             _sslAuthenticationOptions = sslAuthenticationOptions;
+            _ssl = sslStream;
         }
 
         //
@@ -85,7 +88,16 @@ namespace System.Net.Security
         {
             get
             {
-                return _isRemoteCertificateAvailable;
+                return _remoteCertificate != null;
+            }
+        }
+
+        internal X509Certificate? RemoteCertificate
+        {
+            get
+            {
+                _remoteCertificateExposed = true;
+                return _remoteCertificate;
             }
         }
 
@@ -164,8 +176,15 @@ namespace System.Net.Security
 
         internal void Close()
         {
+            if (!_remoteCertificateExposed)
+            {
+                  _remoteCertificate?.Dispose();
+                  _remoteCertificate = null;
+            }
+
             _securityContext?.Dispose();
             _credentialsHandle?.Dispose();
+            _ssl = null;
             GC.SuppressFinalize(this);
         }
 
@@ -585,7 +604,6 @@ namespace System.Net.Security
                 else
                 {
                     _credentialsHandle = SslStreamPal.AcquireCredentialsHandle(selectedCert!, _sslAuthenticationOptions.EnabledSslProtocols, _sslAuthenticationOptions.EncryptionPolicy, _sslAuthenticationOptions.IsServer);
-
                     thumbPrint = guessedThumbPrint; // Delay until here in case something above threw.
                     _selectedClientCertificate = clientCertificate;
                 }
@@ -911,22 +929,21 @@ namespace System.Net.Security
         --*/
 
         //This method validates a remote certificate.
-        internal bool VerifyRemoteCertificate(RemoteCertValidationCallback? remoteCertValidationCallback, ref ProtocolToken? alertToken)
+        internal bool VerifyRemoteCertificate(RemoteCertificateValidationCallback? remoteCertValidationCallback, ref ProtocolToken? alertToken, out SslPolicyErrors sslPolicyErrors, out X509ChainStatusFlags chainStatus)
         {
-            SslPolicyErrors sslPolicyErrors = SslPolicyErrors.None;
+            sslPolicyErrors = SslPolicyErrors.None;
+            chainStatus = X509ChainStatusFlags.NoError;
 
             // We don't catch exceptions in this method, so it's safe for "accepted" be initialized with true.
             bool success = false;
             X509Chain? chain = null;
-            X509Certificate2? remoteCertificateEx = null;
             X509Certificate2Collection? remoteCertificateStore = null;
 
             try
             {
-                remoteCertificateEx = CertificateValidationPal.GetRemoteCertificate(_securityContext, out remoteCertificateStore);
-                _isRemoteCertificateAvailable = remoteCertificateEx != null;
+                _remoteCertificate = CertificateValidationPal.GetRemoteCertificate(_securityContext, out remoteCertificateStore);
 
-                if (remoteCertificateEx == null)
+                if (_remoteCertificate == null)
                 {
                     if (NetEventSource.Log.IsEnabled() && RemoteCertRequired) NetEventSource.Error(this, $"Remote certificate required, but no remote certificate received");
                     sslPolicyErrors |= SslPolicyErrors.RemoteCertificateNotAvailable;
@@ -948,7 +965,7 @@ namespace System.Net.Security
                     sslPolicyErrors |= CertificateValidationPal.VerifyCertificateProperties(
                         _securityContext!,
                         chain,
-                        remoteCertificateEx,
+                        _remoteCertificate,
                         _sslAuthenticationOptions.CheckCertName,
                         _sslAuthenticationOptions.IsServer,
                         _sslAuthenticationOptions.TargetHost);
@@ -956,30 +973,40 @@ namespace System.Net.Security
 
                 if (remoteCertValidationCallback != null)
                 {
-                    success = remoteCertValidationCallback(_sslAuthenticationOptions.TargetHost, remoteCertificateEx, chain, sslPolicyErrors);
+                    object? sender = _ssl;
+                    if (sender == null)
+                    {
+                        throw new ObjectDisposedException(nameof(SslStream));
+                    }
+
+                    success = remoteCertValidationCallback(sender, _remoteCertificate, chain, sslPolicyErrors);
                 }
                 else
                 {
-                    if (sslPolicyErrors == SslPolicyErrors.RemoteCertificateNotAvailable && !_sslAuthenticationOptions.RemoteCertRequired)
-                    {
-                        success = true;
-                    }
-                    else
+                    if (!RemoteCertRequired)
                     {
-                        success = (sslPolicyErrors == SslPolicyErrors.None);
+                        sslPolicyErrors &= ~SslPolicyErrors.RemoteCertificateNotAvailable;
                     }
+
+                    success = (sslPolicyErrors == SslPolicyErrors.None);
                 }
 
                 if (NetEventSource.Log.IsEnabled())
                 {
                     LogCertificateValidation(remoteCertValidationCallback, sslPolicyErrors, success, chain!);
-                    if (NetEventSource.Log.IsEnabled())
-                        NetEventSource.Info(this, $"Cert validation, remote cert = {remoteCertificateEx}");
+                    NetEventSource.Info(this, $"Cert validation, remote cert = {_remoteCertificate}");
                 }
 
                 if (!success)
                 {
                     alertToken = CreateFatalHandshakeAlertToken(sslPolicyErrors, chain!);
+                    if (chain != null)
+                    {
+                        foreach (X509ChainStatus status in chain.ChainStatus)
+                        {
+                            chainStatus |= status.Status;
+                        }
+                    }
                 }
             }
             finally
@@ -1006,8 +1033,6 @@ namespace System.Net.Security
                         remoteCertificateStore[i].Dispose();
                     }
                 }
-
-                remoteCertificateEx?.Dispose();
             }
 
             return success;
@@ -1133,7 +1158,7 @@ namespace System.Net.Security
             return TlsAlertMessage.BadCertificate;
         }
 
-        private void LogCertificateValidation(RemoteCertValidationCallback? remoteCertValidationCallback, SslPolicyErrors sslPolicyErrors, bool success, X509Chain chain)
+        private void LogCertificateValidation(RemoteCertificateValidationCallback? remoteCertValidationCallback, SslPolicyErrors sslPolicyErrors, bool success, X509Chain chain)
         {
             if (!NetEventSource.Log.IsEnabled())
                 return;
index e229ddd..d3f99e6 100644 (file)
@@ -10,7 +10,7 @@ namespace System.Net.Security
 {
     internal class SslAuthenticationOptions
     {
-        internal SslAuthenticationOptions(SslClientAuthenticationOptions sslClientAuthenticationOptions, RemoteCertValidationCallback remoteCallback, LocalCertSelectionCallback? localCallback)
+        internal SslAuthenticationOptions(SslClientAuthenticationOptions sslClientAuthenticationOptions, RemoteCertificateValidationCallback? remoteCallback, LocalCertSelectionCallback? localCallback)
         {
             Debug.Assert(sslClientAuthenticationOptions.TargetHost != null);
 
@@ -78,15 +78,21 @@ namespace System.Net.Security
                     CertificateContext = SslStreamCertificateContext.Create(certificateWithKey);
                 }
             }
+
+            if (sslServerAuthenticationOptions.RemoteCertificateValidationCallback != null)
+            {
+                CertValidationDelegate = sslServerAuthenticationOptions.RemoteCertificateValidationCallback;
+            }
         }
 
-        internal SslAuthenticationOptions(ServerOptionsSelectionCallback optionCallback, object? state)
+        internal SslAuthenticationOptions(ServerOptionsSelectionCallback optionCallback, object? state, RemoteCertificateValidationCallback? remoteCallback)
         {
             CheckCertName = false;
             TargetHost = string.Empty;
             IsServer = true;
             UserState = state;
             ServerOptionDelegate = optionCallback;
+            CertValidationDelegate = remoteCallback;
         }
 
         internal void UpdateOptions(SslServerAuthenticationOptions sslServerAuthenticationOptions)
@@ -108,6 +114,11 @@ namespace System.Net.Security
                 // given cert is X509Certificate2 with key. We can use it directly.
                 CertificateContext = SslStreamCertificateContext.Create(certificateWithKey);
             }
+
+            if (sslServerAuthenticationOptions.RemoteCertificateValidationCallback != null)
+            {
+                CertValidationDelegate = sslServerAuthenticationOptions.RemoteCertificateValidationCallback;
+            }
         }
 
         private static SslProtocols FilterOutIncompatibleSslProtocols(SslProtocols protocols)
@@ -136,7 +147,7 @@ namespace System.Net.Security
         internal EncryptionPolicy EncryptionPolicy { get; set; }
         internal bool RemoteCertRequired { get; set; }
         internal bool CheckCertName { get; set; }
-        internal RemoteCertValidationCallback? CertValidationDelegate { get; set; }
+        internal RemoteCertificateValidationCallback? CertValidationDelegate { get; set; }
         internal LocalCertSelectionCallback? CertSelectionDelegate { get; set; }
         internal ServerCertSelectionCallback? ServerCertSelectionDelegate { get; set; }
         internal CipherSuitesPolicy? CipherSuitesPolicy { get; set; }
index 5f3c5cd..5af70ed 100644 (file)
@@ -4,10 +4,10 @@
 using System.Buffers;
 using System.ComponentModel;
 using System.Diagnostics;
-using System.Globalization;
 using System.IO;
 using System.Runtime.ExceptionServices;
 using System.Security.Authentication;
+using System.Security.Cryptography.X509Certificates;
 using System.Threading;
 using System.Threading.Tasks;
 
@@ -41,7 +41,7 @@ namespace System.Net.Security
         private const int InitialHandshakeBufferSize = 4096 + FrameOverhead; // try to fit at least 4K ServerCertificate
         private ArrayBuffer _handshakeBuffer;
 
-        private void ValidateCreateContext(SslClientAuthenticationOptions sslClientAuthenticationOptions, RemoteCertValidationCallback remoteCallback, LocalCertSelectionCallback? localCallback)
+        private void ValidateCreateContext(SslClientAuthenticationOptions sslClientAuthenticationOptions, RemoteCertificateValidationCallback? remoteCallback, LocalCertSelectionCallback? localCallback)
         {
             ThrowIfExceptional();
 
@@ -321,9 +321,23 @@ namespace System.Net.Security
                 }
 
                 ProtocolToken? alertToken = null;
-                if (!CompleteHandshake(ref alertToken))
+                if (!CompleteHandshake(ref alertToken, out SslPolicyErrors sslPolicyErrors, out X509ChainStatusFlags chainStatus))
                 {
-                    SendAuthResetSignal(alertToken, ExceptionDispatchInfo.Capture(new AuthenticationException(SR.net_ssl_io_cert_validation, null)));
+                    if (_sslAuthenticationOptions!.CertValidationDelegate != null)
+                    {
+                        // there may be some chain errors but the decision was made by custom callback. Details should be tracing if enabled.
+                        SendAuthResetSignal(alertToken, ExceptionDispatchInfo.Capture(new AuthenticationException(SR.net_ssl_io_cert_custom_validation, null)));
+                    }
+                    else if (sslPolicyErrors == SslPolicyErrors.RemoteCertificateChainErrors && chainStatus != X509ChainStatusFlags.NoError)
+                    {
+                        // We failed only because of chain and we have some insight.
+                        SendAuthResetSignal(alertToken, ExceptionDispatchInfo.Capture(new AuthenticationException(SR.Format(SR.net_ssl_io_cert_chain_validation, chainStatus), null)));
+                    }
+                    else
+                    {
+                        // Simple add sslPolicyErrors as crude info.
+                        SendAuthResetSignal(alertToken, ExceptionDispatchInfo.Capture(new AuthenticationException(SR.Format(SR.net_ssl_io_cert_validation, sslPolicyErrors), null)));
+                    }
                 }
             }
             finally
@@ -504,11 +518,11 @@ namespace System.Net.Security
         //
         // - Returns false if failed to verify the Remote Cert
         //
-        private bool CompleteHandshake(ref ProtocolToken? alertToken)
+        private bool CompleteHandshake(ref ProtocolToken? alertToken, out SslPolicyErrors sslPolicyErrors, out X509ChainStatusFlags chainStatus)
         {
             _context!.ProcessHandshakeSuccess();
 
-            if (!_context.VerifyRemoteCertificate(_sslAuthenticationOptions!.CertValidationDelegate, ref alertToken))
+            if (!_context.VerifyRemoteCertificate(_sslAuthenticationOptions!.CertValidationDelegate, ref alertToken, out sslPolicyErrors, out chainStatus))
             {
                 _handshakeCompleted = false;
                 return false;
index 523935e..f896c0e 100644 (file)
@@ -37,7 +37,6 @@ namespace System.Net.Security
     public delegate ValueTask<SslServerAuthenticationOptions> ServerOptionsSelectionCallback(SslStream stream, SslClientHelloInfo clientHelloInfo, object? state, CancellationToken cancellationToken);
 
     // Internal versions of the above delegates.
-    internal delegate bool RemoteCertValidationCallback(string? host, X509Certificate2? certificate, X509Chain? chain, SslPolicyErrors sslPolicyErrors);
     internal delegate X509Certificate LocalCertSelectionCallback(string targetHost, X509CertificateCollection localCertificates, X509Certificate2? remoteCertificate, string[] acceptableIssuers);
     internal delegate X509Certificate ServerCertSelectionCallback(string? hostName);
 
@@ -46,13 +45,9 @@ namespace System.Net.Security
         /// <summary>Set as the _exception when the instance is disposed.</summary>
         private static readonly ExceptionDispatchInfo s_disposedSentinel = ExceptionDispatchInfo.Capture(new ObjectDisposedException(nameof(SslStream), (string?)null));
 
-        private X509Certificate2? _remoteCertificate;
-        private bool _remoteCertificateExposed;
-
         internal RemoteCertificateValidationCallback? _userCertificateValidationCallback;
         internal LocalCertificateSelectionCallback? _userCertificateSelectionCallback;
         internal ServerCertificateSelectionCallback? _userServerCertificateSelectionCallback;
-        internal RemoteCertValidationCallback _certValidationDelegate;
         internal LocalCertSelectionCallback? _certSelectionDelegate;
         internal EncryptionPolicy _encryptionPolicy;
 
@@ -106,7 +101,6 @@ namespace System.Net.Security
             _userCertificateValidationCallback = userCertificateValidationCallback;
             _userCertificateSelectionCallback = userCertificateSelectionCallback;
             _encryptionPolicy = encryptionPolicy;
-            _certValidationDelegate = new RemoteCertValidationCallback(UserCertValidationCallbackWrapper);
             _certSelectionDelegate = userCertificateSelectionCallback == null ? null : new LocalCertSelectionCallback(UserCertSelectionCallbackWrapper);
 
             _innerStream = innerStream;
@@ -130,7 +124,6 @@ namespace System.Net.Security
             if (_userCertificateValidationCallback == null)
             {
                 _userCertificateValidationCallback = callback;
-                _certValidationDelegate = new RemoteCertValidationCallback(UserCertValidationCallbackWrapper);
             }
             else if (callback != null && _userCertificateValidationCallback != callback)
             {
@@ -151,24 +144,6 @@ namespace System.Net.Security
             }
         }
 
-        private bool UserCertValidationCallbackWrapper(string? hostName, X509Certificate2? certificate, X509Chain? chain, SslPolicyErrors sslPolicyErrors)
-        {
-            _remoteCertificate = certificate == null ? null : new X509Certificate2(certificate);
-            if (_userCertificateValidationCallback == null)
-            {
-                if (!RemoteCertRequired)
-                {
-                    sslPolicyErrors &= ~SslPolicyErrors.RemoteCertificateNotAvailable;
-                }
-
-                return (sslPolicyErrors == SslPolicyErrors.None);
-            }
-            else
-            {
-                return _userCertificateValidationCallback(this, certificate, chain, sslPolicyErrors);
-            }
-        }
-
         private X509Certificate UserCertSelectionCallbackWrapper(string targetHost, X509CertificateCollection localCertificates, X509Certificate? remoteCertificate, string[] acceptableIssuers)
         {
             return _userCertificateSelectionCallback!(this, targetHost, localCertificates, remoteCertificate, acceptableIssuers);
@@ -194,7 +169,7 @@ namespace System.Net.Security
             _userServerCertificateSelectionCallback = sslServerAuthenticationOptions.ServerCertificateSelectionCallback;
             authOptions.ServerCertSelectionDelegate = _userServerCertificateSelectionCallback == null ? null : new ServerCertSelectionCallback(ServerCertSelectionCallbackWrapper);
 
-            authOptions.CertValidationDelegate = _certValidationDelegate;
+            authOptions.CertValidationDelegate = _userCertificateValidationCallback;
             authOptions.CertSelectionDelegate = _certSelectionDelegate;
 
             return authOptions;
@@ -318,7 +293,7 @@ namespace System.Net.Security
             SetAndVerifyValidationCallback(sslClientAuthenticationOptions.RemoteCertificateValidationCallback);
             SetAndVerifySelectionCallback(sslClientAuthenticationOptions.LocalCertificateSelectionCallback);
 
-            ValidateCreateContext(sslClientAuthenticationOptions, _certValidationDelegate, _certSelectionDelegate);
+            ValidateCreateContext(sslClientAuthenticationOptions, _userCertificateValidationCallback, _certSelectionDelegate);
             ProcessAuthentication();
         }
 
@@ -389,7 +364,7 @@ namespace System.Net.Security
             SetAndVerifyValidationCallback(sslClientAuthenticationOptions.RemoteCertificateValidationCallback);
             SetAndVerifySelectionCallback(sslClientAuthenticationOptions.LocalCertificateSelectionCallback);
 
-            ValidateCreateContext(sslClientAuthenticationOptions, _certValidationDelegate, _certSelectionDelegate);
+            ValidateCreateContext(sslClientAuthenticationOptions, _userCertificateValidationCallback, _certSelectionDelegate);
 
             return ProcessAuthentication(true, false, cancellationToken)!;
         }
@@ -399,7 +374,7 @@ namespace System.Net.Security
             SetAndVerifyValidationCallback(sslClientAuthenticationOptions.RemoteCertificateValidationCallback);
             SetAndVerifySelectionCallback(sslClientAuthenticationOptions.LocalCertificateSelectionCallback);
 
-            ValidateCreateContext(sslClientAuthenticationOptions, _certValidationDelegate, _certSelectionDelegate);
+            ValidateCreateContext(sslClientAuthenticationOptions, _userCertificateValidationCallback, _certSelectionDelegate);
 
             return ProcessAuthentication(true, true, cancellationToken)!;
         }
@@ -457,7 +432,7 @@ namespace System.Net.Security
 
         public Task AuthenticateAsServerAsync(ServerOptionsSelectionCallback optionsCallback, object? state, CancellationToken cancellationToken = default)
         {
-            ValidateCreateContext(new SslAuthenticationOptions(optionsCallback, state));
+            ValidateCreateContext(new SslAuthenticationOptions(optionsCallback, state, _userCertificateValidationCallback));
             return ProcessAuthentication(isAsync: true, isApm: false, cancellationToken)!;
         }
 
@@ -560,8 +535,7 @@ namespace System.Net.Security
             get
             {
                 ThrowIfExceptionalOrNotAuthenticated();
-                _remoteCertificateExposed = true;
-                return _remoteCertificate;
+                return _context?.RemoteCertificate;
             }
         }
 
@@ -714,12 +688,6 @@ namespace System.Net.Security
         {
             try
             {
-                if (!_remoteCertificateExposed)
-                {
-                    _remoteCertificate?.Dispose();
-                    _remoteCertificate = null;
-                    _remoteCertificateExposed = false;
-                }
                 CloseInternal();
             }
             finally
index 6695fc9..4682ba3 100644 (file)
@@ -3,6 +3,7 @@
 
 using System.Collections.Generic;
 using System.ComponentModel;
+using System.IO;
 using System.Net.Sockets;
 using System.Net.Test.Common;
 using System.Security.Authentication;
@@ -193,6 +194,68 @@ namespace System.Net.Security.Tests
             }
         }
 
+        [Fact]
+        public async Task ServerAsyncAuthenticate_VerificationDelegate_Success()
+        {
+            bool validationCallbackCalled = false;
+            var serverOptions = new SslServerAuthenticationOptions() { ServerCertificate = _serverCertificate, ClientCertificateRequired = true, };
+            var clientOptions = new SslClientAuthenticationOptions() { TargetHost = _serverCertificate.GetNameInfo(X509NameType.SimpleName, false) };
+            clientOptions.RemoteCertificateValidationCallback = (sender, certificate, chain, sslPolicyErrors) => true;
+            serverOptions.RemoteCertificateValidationCallback = (sender, certificate, chain, sslPolicyErrors) =>
+            {
+                validationCallbackCalled = true;
+                return true;
+            };
+
+            (SslStream client, SslStream server) = TestHelper.GetConnectedSslStreams();
+            using (client)
+            using (server)
+            {
+                Task t1 = client.AuthenticateAsClientAsync(clientOptions, CancellationToken.None);
+                Task t2 = server.AuthenticateAsServerAsync(
+                    (stream, clientHelloInfo, userState, cancellationToken) =>
+                    {
+                        Assert.Equal(server, stream);
+                        Assert.Equal(clientOptions.TargetHost, clientHelloInfo.ServerName);
+                        return new ValueTask<SslServerAuthenticationOptions>(OptionsTask(serverOptions));
+                    },
+                    null, CancellationToken.None);
+
+                await TestConfiguration.WhenAllOrAnyFailedWithTimeout(t1, t2);
+                Assert.True(validationCallbackCalled);
+            }
+        }
+
+        [Fact]
+        public async Task ServerAsyncAuthenticate_ConstructorVerificationDelegate_Success()
+        {
+            bool validationCallbackCalled = false;
+            var serverOptions = new SslServerAuthenticationOptions() { ServerCertificate = _serverCertificate, ClientCertificateRequired = true, };
+            var clientOptions = new SslClientAuthenticationOptions() { TargetHost = _serverCertificate.GetNameInfo(X509NameType.SimpleName, false) };
+            clientOptions.RemoteCertificateValidationCallback = (sender, certificate, chain, sslPolicyErrors) => true;
+
+            (Stream clientStream, Stream serverStream) = TestHelper.GetConnectedStreams();
+            var client = new SslStream(clientStream);
+            var server = new SslStream(serverStream, false, (sender, certificate, chain, sslPolicyErrors) => { validationCallbackCalled = true; return true;});
+
+            using (client)
+            using (server)
+            {
+                Task t1 = client.AuthenticateAsClientAsync(clientOptions, CancellationToken.None);
+                Task t2 = server.AuthenticateAsServerAsync(
+                    (stream, clientHelloInfo, userState, cancellationToken) =>
+                    {
+                        Assert.Equal(server, stream);
+                        Assert.Equal(clientOptions.TargetHost, clientHelloInfo.ServerName);
+                        return new ValueTask<SslServerAuthenticationOptions>(OptionsTask(serverOptions));
+                    },
+                    null, CancellationToken.None);
+
+                await TestConfiguration.WhenAllOrAnyFailedWithTimeout(t1, t2);
+                Assert.True(validationCallbackCalled);
+            }
+        }
+
         [Theory]
         [InlineData(true)]
         [InlineData(false)]
index 46345ba..3b77a45 100644 (file)
@@ -277,22 +277,34 @@ namespace System.Net.Security.Tests
             }
         }
 
-        [Fact]
+        [Theory]
         [PlatformSpecific(TestPlatforms.AnyUnix)]
-        public async Task SslStream_UntrustedCaWithCustomCallback_Throws()
+        [InlineData(true)]
+        [InlineData(false)]
+        public async Task SslStream_UntrustedCaWithCustomCallback_Throws(bool customCallback)
         {
+            string errorMessage;
             var options = new  SslClientAuthenticationOptions() { TargetHost = "localhost" };
-            options.RemoteCertificateValidationCallback =
-                (sender, certificate, chain, sslPolicyErrors) =>
-                {
-                    chain.ChainPolicy.ExtraStore.AddRange(_serverChain);
-                    chain.ChainPolicy.CustomTrustStore.Add(_serverChain[_serverChain.Count -1]);
-                    chain.ChainPolicy.TrustMode = X509ChainTrustMode.CustomRootTrust;
-                    // This should work and we should be able to trust the chain.
-                    Assert.True(chain.Build((X509Certificate2)certificate));
-                    // Reject it in custom callback to simulate for example pinning.
-                    return false;
-                };
+            if (customCallback)
+            {
+                options.RemoteCertificateValidationCallback =
+                    (sender, certificate, chain, sslPolicyErrors) =>
+                    {
+                        chain.ChainPolicy.ExtraStore.AddRange(_serverChain);
+                        chain.ChainPolicy.CustomTrustStore.Add(_serverChain[_serverChain.Count -1]);
+                        chain.ChainPolicy.TrustMode = X509ChainTrustMode.CustomRootTrust;
+                        // This should work and we should be able to trust the chain.
+                        Assert.True(chain.Build((X509Certificate2)certificate));
+                        // Reject it in custom callback to simulate for example pinning.
+                        return false;
+                    };
+
+                errorMessage = "RemoteCertificateValidationCallback";
+            }
+            else
+            {
+                errorMessage = "PartialChain";
+            }
 
             (Stream clientStream, Stream serverStream) = TestHelper.GetConnectedStreams();
             using (clientStream)
@@ -303,7 +315,8 @@ namespace System.Net.Security.Tests
                 Task t1 = client.AuthenticateAsClientAsync(options, default);
                 Task t2 = server.AuthenticateAsServerAsync(_serverCert);
 
-                await Assert.ThrowsAsync<AuthenticationException>(() => t1);
+                var e = await Assert.ThrowsAsync<AuthenticationException>(() => t1);
+                Assert.Contains(errorMessage, e.Message);
                 // Server side should finish since we run custom callback after handshake is done.
                 await t2;
             }
index 66702b6..7139118 100644 (file)
@@ -17,7 +17,7 @@ namespace System.Net.Security
 
         private FakeOptions? _sslAuthenticationOptions;
 
-        private void ValidateCreateContext(SslClientAuthenticationOptions sslClientAuthenticationOptions, RemoteCertValidationCallback remoteCallback, LocalCertSelectionCallback? localCallback)
+        private void ValidateCreateContext(SslClientAuthenticationOptions sslClientAuthenticationOptions, RemoteCertificateValidationCallback? remoteCallback, LocalCertSelectionCallback? localCallback)
         {
             // Without setting (or using) these members you will get a build exception in the unit test project.
             // The code that normally uses these in the main solution is in the implementation of SslStream.
@@ -89,6 +89,7 @@ namespace System.Net.Security
         internal SslConnectionInfo ConnectionInfo => default;
         internal ChannelBinding GetChannelBinding(ChannelBindingKind kind) => default;
         internal X509Certificate LocalServerCertificate => default;
+        internal X509Certificate RemoteCertificate => default;
         internal bool IsRemoteCertificateAvailable => default;
         internal SslApplicationProtocol NegotiatedApplicationProtocol => default;
         internal X509Certificate LocalClientCertificate => default;