add NegotiateClientCertificateAsync support on Windows (#51905)
authorTomas Weinfurt <tweinfurt@yahoo.com>
Wed, 2 Jun 2021 17:10:35 +0000 (19:10 +0200)
committerGitHub <noreply@github.com>
Wed, 2 Jun 2021 17:10:35 +0000 (19:10 +0200)
* add NegotiateClientCertificateAsync support on Windows

* feedback from review and test update

* throw on data during renegotiation

* disable NegotiateClientCertificateAsync on Win7

* feedback from review

* use Interlocked.Exchang instead of CompareExchange

* add trace message

15 files changed:
src/libraries/Common/src/Interop/Windows/SChannel/Interop.SECURITY_STATUS.cs
src/libraries/Common/src/System/Net/SecurityStatusAdapterPal.Windows.cs
src/libraries/Common/src/System/Net/SecurityStatusPal.cs
src/libraries/System.Net.Security/ref/System.Net.Security.cs
src/libraries/System.Net.Security/src/Resources/Strings.resx
src/libraries/System.Net.Security/src/System/Net/Security/SecureChannel.cs
src/libraries/System.Net.Security/src/System/Net/Security/SslStream.Implementation.cs
src/libraries/System.Net.Security/src/System/Net/Security/SslStream.cs
src/libraries/System.Net.Security/src/System/Net/Security/SslStreamPal.Android.cs
src/libraries/System.Net.Security/src/System/Net/Security/SslStreamPal.OSX.cs
src/libraries/System.Net.Security/src/System/Net/Security/SslStreamPal.Unix.cs
src/libraries/System.Net.Security/src/System/Net/Security/SslStreamPal.Windows.cs
src/libraries/System.Net.Security/tests/FunctionalTests/SslStreamNetworkStreamTest.cs
src/libraries/System.Net.Security/tests/FunctionalTests/TestHelper.cs
src/libraries/System.Net.Security/tests/UnitTests/Fakes/FakeSslStream.Implementation.cs

index 05b201c..6befa44 100644 (file)
@@ -50,5 +50,6 @@ internal static partial class Interop
         BadBinding = unchecked((int)0x80090346),
         DowngradeDetected = unchecked((int)0x80090350),
         ApplicationProtocolMismatch = unchecked((int)0x80090367),
+        NoRenegotiation = unchecked((int)0x00090360),
     }
 }
index 3b01a15..0967f8e 100644 (file)
@@ -9,7 +9,7 @@ namespace System.Net
 {
     internal static class SecurityStatusAdapterPal
     {
-        private const int StatusDictionarySize = 42;
+        private const int StatusDictionarySize = 43;
 
 #if DEBUG
         static SecurityStatusAdapterPal()
@@ -61,7 +61,8 @@ namespace System.Net
             { Interop.SECURITY_STATUS.UnsupportedPreauth, SecurityStatusPalErrorCode.UnsupportedPreauth },
             { Interop.SECURITY_STATUS.Unsupported, SecurityStatusPalErrorCode.Unsupported },
             { Interop.SECURITY_STATUS.UntrustedRoot, SecurityStatusPalErrorCode.UntrustedRoot },
-            { Interop.SECURITY_STATUS.WrongPrincipal, SecurityStatusPalErrorCode.WrongPrincipal }
+            { Interop.SECURITY_STATUS.WrongPrincipal, SecurityStatusPalErrorCode.WrongPrincipal },
+            { Interop.SECURITY_STATUS.NoRenegotiation, SecurityStatusPalErrorCode.NoRenegotiation }
         };
 
         internal static SecurityStatusPal GetSecurityStatusPalFromNativeInt(int win32SecurityStatus)
index bc3407b..4b99938 100644 (file)
@@ -69,6 +69,7 @@ namespace System.Net
         UnsupportedPreauth,
         BadBinding,
         DowngradeDetected,
-        ApplicationProtocolMismatch
+        ApplicationProtocolMismatch,
+        NoRenegotiation
     }
 }
index 5fba93d..4e277ce 100644 (file)
@@ -215,6 +215,7 @@ namespace System.Net.Security
         public virtual System.Threading.Tasks.Task AuthenticateAsServerAsync(System.Security.Cryptography.X509Certificates.X509Certificate serverCertificate) { throw null; }
         public virtual System.Threading.Tasks.Task AuthenticateAsServerAsync(System.Security.Cryptography.X509Certificates.X509Certificate serverCertificate, bool clientCertificateRequired, bool checkCertificateRevocation) { throw null; }
         public virtual System.Threading.Tasks.Task AuthenticateAsServerAsync(System.Security.Cryptography.X509Certificates.X509Certificate serverCertificate, bool clientCertificateRequired, System.Security.Authentication.SslProtocols enabledSslProtocols, bool checkCertificateRevocation) { throw null; }
+        public System.Threading.Tasks.Task AuthenticateAsServerAsync(ServerOptionsSelectionCallback optionsCallback, object? state, System.Threading.CancellationToken cancellationToken = default) { throw null; }
         public virtual System.IAsyncResult BeginAuthenticateAsClient(string targetHost, System.AsyncCallback? asyncCallback, object? asyncState) { throw null; }
         public virtual System.IAsyncResult BeginAuthenticateAsClient(string targetHost, System.Security.Cryptography.X509Certificates.X509CertificateCollection? clientCertificates, bool checkCertificateRevocation, System.AsyncCallback? asyncCallback, object? asyncState) { throw null; }
         public virtual System.IAsyncResult BeginAuthenticateAsClient(string targetHost, System.Security.Cryptography.X509Certificates.X509CertificateCollection? clientCertificates, System.Security.Authentication.SslProtocols enabledSslProtocols, bool checkCertificateRevocation, System.AsyncCallback? asyncCallback, object? asyncState) { throw null; }
@@ -232,6 +233,7 @@ namespace System.Net.Security
         ~SslStream() { }
         public override void Flush() { }
         public override System.Threading.Tasks.Task FlushAsync(System.Threading.CancellationToken cancellationToken) { throw null; }
+        public virtual System.Threading.Tasks.Task NegotiateClientCertificateAsync(System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; }
         public override int Read(byte[] buffer, int offset, int count) { throw null; }
         public override System.Threading.Tasks.Task<int> ReadAsync(byte[] buffer, int offset, int count, System.Threading.CancellationToken cancellationToken) { throw null; }
         public override System.Threading.Tasks.ValueTask<int> ReadAsync(System.Memory<byte> buffer, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; }
@@ -243,7 +245,6 @@ namespace System.Net.Security
         public override void Write(byte[] buffer, int offset, int count) { }
         public override System.Threading.Tasks.Task WriteAsync(byte[] buffer, int offset, int count, System.Threading.CancellationToken cancellationToken) { throw null; }
         public override System.Threading.Tasks.ValueTask WriteAsync(System.ReadOnlyMemory<byte> buffer, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; }
-        public System.Threading.Tasks.Task AuthenticateAsServerAsync(ServerOptionsSelectionCallback optionsCallback, object? state, System.Threading.CancellationToken cancellationToken = default) { throw null; }
     }
     [System.CLSCompliantAttribute(false)]
     public enum TlsCipherSuite : ushort
index cf50b28..c002340 100644 (file)
   <data name="SystemNetSecurity_PlatformNotSupported" xml:space="preserve">
     <value>System.Net.Security is not supported on this platform.</value>
   </data>
+  <data name="net_ssl_certificate_exist" xml:space="preserve">
+    <value>Remote certificate is already available.</value>
+  </data>
+  <data name="net_ssl_renegotiate_data" xml:space="preserve">
+    <value>Received data during renegotiation.</value>
+  </data>
 </root>
index 0e287e0..5d234b5 100644 (file)
@@ -830,6 +830,15 @@ namespace System.Net.Security
             return status;
         }
 
+        internal SecurityStatusPal Renegotiate(out byte[]? output)
+        {
+            return SslStreamPal.Renegotiate(
+                                      ref _credentialsHandle!,
+                                      ref _securityContext,
+                                      _sslAuthenticationOptions,
+                                      out output);
+        }
+
         /*++
             ProcessHandshakeSuccess -
                Called on successful completion of Handshake -
index 3193e8f..b838db0 100644 (file)
@@ -26,7 +26,7 @@ namespace System.Net.Security
             BeforeSSL3,     // SSlv2
             SinceSSL3,      // SSlv3 & TLS
             Unified,        // Intermediate on first frame until response is processes.
-            Invalid         // Somthing is wrong.
+            Invalid         // Something is wrong.
         }
 
         // This is set on the first packet to figure out the framing style.
@@ -305,6 +305,59 @@ namespace System.Net.Security
             }
         }
 
+        // This will initiate renegotiation or PHA for Tls1.3
+        private async Task RenegotiateAsync(CancellationToken cancellationToken)
+        {
+            if (Interlocked.Exchange(ref _nestedAuth, 1) == 1)
+            {
+                throw new InvalidOperationException(SR.Format(SR.net_io_invalidnestedcall, "NegotiateClientCertificateAsync", "renegotiate"));
+            }
+
+            if (Interlocked.Exchange(ref _nestedRead, 1) == 1)
+            {
+                throw new NotSupportedException(SR.Format(SR.net_io_invalidnestedcall, nameof(SslStream.ReadAsync), "read"));
+            }
+
+            if (Interlocked.Exchange(ref _nestedWrite, 1) == 1)
+            {
+                _nestedRead = 0;
+                throw new NotSupportedException(SR.Format(SR.net_io_invalidnestedcall, nameof(WriteAsync), "write"));
+            }
+
+            _sslAuthenticationOptions!.RemoteCertRequired = true;
+            IReadWriteAdapter adapter = new AsyncReadWriteAdapter(InnerStream, cancellationToken);
+
+            try
+            {
+                SecurityStatusPal status = _context!.Renegotiate(out byte[]? nextmsg);
+                if (nextmsg?.Length > 0)
+                {
+                    await adapter.WriteAsync(nextmsg, 0, nextmsg.Length).ConfigureAwait(false);
+                    await adapter.FlushAsync().ConfigureAwait(false);
+                }
+
+                if (status.ErrorCode != SecurityStatusPalErrorCode.OK)
+                {
+                    if (status.ErrorCode == SecurityStatusPalErrorCode.NoRenegotiation)
+                    {
+                        // peer does not want to renegotiate. That should keep session usable.
+                        return;
+                    }
+
+                    throw SslStreamPal.GetException(status);
+                }
+
+                // Issue empty read to get renegotiation going.
+                await ReadAsyncInternal(adapter, Memory<byte>.Empty, renegotiation: true).ConfigureAwait(false);
+            }
+            finally
+            {
+                _nestedRead = 0;
+                _nestedWrite = 0;
+                // We will not release _nestedAuth at this point to prevent another renegotiation attempt.
+            }
+        }
+
         // reAuthenticationData is only used on Windows in case of renegotiation.
         private async Task ForceAuthenticationAsync<TIOAdapter>(TIOAdapter adapter, bool receiveFirst, byte[]? reAuthenticationData, bool isApm = false)
              where TIOAdapter : IReadWriteAdapter
@@ -612,6 +665,15 @@ namespace System.Net.Security
         {
             _context!.ProcessHandshakeSuccess();
 
+            if (_nestedAuth != 1)
+            {
+                if (NetEventSource.Log.IsEnabled()) NetEventSource.Error(this, $"Ignoring unsolicited renegotiated certificate.");
+                // ignore certificates received outside of handshake or requested renegotiation.
+                sslPolicyErrors = SslPolicyErrors.None;
+                chainStatus = X509ChainStatusFlags.NoError;
+                return true;
+            }
+
             if (!_context.VerifyRemoteCertificate(_sslAuthenticationOptions!.CertValidationDelegate, ref alertToken, out sslPolicyErrors, out chainStatus))
             {
                 _handshakeCompleted = false;
@@ -753,12 +815,15 @@ namespace System.Net.Security
             }
         }
 
-        private async ValueTask<int> ReadAsyncInternal<TIOAdapter>(TIOAdapter adapter, Memory<byte> buffer)
+        private async ValueTask<int> ReadAsyncInternal<TIOAdapter>(TIOAdapter adapter, Memory<byte> buffer, bool renegotiation = false)
             where TIOAdapter : IReadWriteAdapter
         {
-            if (Interlocked.Exchange(ref _nestedRead, 1) == 1)
+            if (!renegotiation)
             {
-                throw new NotSupportedException(SR.Format(SR.net_io_invalidnestedcall, nameof(SslStream.ReadAsync), "read"));
+                if (Interlocked.Exchange(ref _nestedRead, 1) == 1)
+                {
+                    throw new NotSupportedException(SR.Format(SR.net_io_invalidnestedcall, nameof(SslStream.ReadAsync), "read"));
+                }
             }
 
             Debug.Assert(_internalBuffer is null || _internalBufferCount > 0 || _decryptedBytesCount > 0, "_internalBuffer allocated when no data is buffered.");
@@ -769,10 +834,15 @@ namespace System.Net.Security
                 {
                     if (_decryptedBytesCount != 0)
                     {
+                        if (renegotiation)
+                        {
+                            throw new InvalidOperationException(SR.net_ssl_renegotiate_data);
+                        }
+
                         return CopyDecryptedData(buffer);
                     }
 
-                    if (buffer.Length == 0 && _internalBuffer is null)
+                    if (buffer.Length == 0 && _internalBuffer is null && !renegotiation)
                     {
                         // User requested a zero-byte read, and we have no data available in the buffer for processing.
                         // This zero-byte read indicates their desire to trade off the extra cost of a zero-byte read
@@ -844,7 +914,6 @@ namespace System.Net.Security
                             // If that happen before EncryptData() runs, _handshakeWaiter will be set to null
                             // and EncryptData() will work normally e.g. no waiting, just exclusion with DecryptData()
 
-
                             if (_sslAuthenticationOptions!.AllowRenegotiation || SslProtocol == SslProtocols.Tls13)
                             {
                                 // create TCS only if we plan to proceed. If not, we will throw in block bellow outside of the lock.
@@ -880,8 +949,12 @@ namespace System.Net.Security
                             {
                                 throw new IOException(SR.net_ssl_io_renego);
                             }
-
                             await ReplyOnReAuthenticationAsync(adapter, extraBuffer).ConfigureAwait(false);
+                            if (renegotiation)
+                            {
+                                // if we received data frame instead, we would not be here but we would decrypt data and hit check above.
+                                return 0;
+                            }
                             // Loop on read.
                             continue;
                         }
@@ -897,7 +970,7 @@ namespace System.Net.Security
             }
             catch (Exception e)
             {
-                if (e is IOException || (e is OperationCanceledException && adapter.CancellationToken.IsCancellationRequested))
+                if (e is IOException || (e is OperationCanceledException && adapter.CancellationToken.IsCancellationRequested) || renegotiation)
                 {
                     throw;
                 }
index 0acf683..75c2938 100644 (file)
@@ -689,6 +689,17 @@ namespace System.Net.Security
 
         public override Task FlushAsync(CancellationToken cancellationToken) => InnerStream.FlushAsync(cancellationToken);
 
+        public virtual Task NegotiateClientCertificateAsync(CancellationToken cancellationToken = default)
+        {
+            ThrowIfExceptionalOrNotAuthenticated();
+            if (RemoteCertificate != null)
+            {
+                throw new InvalidOperationException(SR.net_ssl_certificate_exist);
+            }
+
+            return RenegotiateAsync(cancellationToken);
+        }
+
         protected override void Dispose(bool disposing)
         {
             try
index ba40b31..24f1e8e 100644 (file)
@@ -45,6 +45,11 @@ namespace System.Net.Security
             return HandshakeInternal(credential, ref context, inputBuffer, ref outputBuffer, sslAuthenticationOptions);
         }
 
+        public static SecurityStatusPal Renegotiate(ref SafeFreeCredentials? credentialsHandle, ref SafeDeleteSslContext? context, SslAuthenticationOptions sslAuthenticationOptions, out byte[]? outputBuffer)
+        {
+            throw new PlatformNotSupportedException();
+        }
+
         public static SafeFreeCredentials AcquireCredentialsHandle(
             SslStreamCertificateContext? certificateContext,
             SslProtocols protocols,
index 9b24025..53c4d12 100644 (file)
@@ -52,6 +52,11 @@ namespace System.Net.Security
             return HandshakeInternal(credential, ref context, inputBuffer, ref outputBuffer, sslAuthenticationOptions);
         }
 
+        public static SecurityStatusPal Renegotiate(ref SafeFreeCredentials? credentialsHandle, ref SafeDeleteSslContext? context, SslAuthenticationOptions sslAuthenticationOptions, out byte[]? outputBuffer)
+        {
+            throw new PlatformNotSupportedException();
+        }
+
         public static SafeFreeCredentials AcquireCredentialsHandle(
             SslStreamCertificateContext? certificateContext,
             SslProtocols protocols,
index bd9d913..1416e39 100644 (file)
@@ -36,6 +36,11 @@ namespace System.Net.Security
             return HandshakeInternal(credential!, ref context, inputBuffer, ref outputBuffer, sslAuthenticationOptions);
         }
 
+        public static SecurityStatusPal Renegotiate(ref SafeFreeCredentials? credentialsHandle, ref SafeDeleteSslContext? context, SslAuthenticationOptions sslAuthenticationOptions, out byte[]? outputBuffer)
+        {
+            throw new PlatformNotSupportedException();
+        }
+
         public static SafeFreeCredentials AcquireCredentialsHandle(SslStreamCertificateContext? certificateContext,
             SslProtocols protocols, EncryptionPolicy policy, bool isServer)
         {
index b93e899..59a781b 100644 (file)
@@ -110,6 +110,14 @@ namespace System.Net.Security
             return SecurityStatusAdapterPal.GetSecurityStatusPalFromNativeInt(errorCode);
         }
 
+        public static SecurityStatusPal Renegotiate(ref SafeFreeCredentials? credentialsHandle, ref SafeDeleteSslContext? context, SslAuthenticationOptions sslAuthenticationOptions, out byte[]? outputBuffer )
+        {
+            byte[]? output = Array.Empty<byte>();
+            SecurityStatusPal status =  AcceptSecurityContext(ref credentialsHandle, ref context, Span<byte>.Empty, ref output, sslAuthenticationOptions);
+            outputBuffer = output;
+            return status;
+        }
+
         public static SafeFreeCredentials AcquireCredentialsHandle(SslStreamCertificateContext? certificateContext, SslProtocols protocols, EncryptionPolicy policy, bool isServer)
         {
             // New crypto API supports TLS1.3 but it does not allow to force NULL encryption.
index d4b2bd0..957a40d 100644 (file)
@@ -172,6 +172,291 @@ namespace System.Net.Security.Tests
             }
         }
 
+        [ConditionalTheory(typeof(PlatformDetection), nameof(PlatformDetection.IsNotWindows7))]
+        [InlineData(true)]
+        [InlineData(false)]
+        [PlatformSpecific(TestPlatforms.Windows)]
+        public async Task SslStream_NegotiateClientCertificateAsync_Succeeds(bool sendClientCertificate)
+        {
+            bool negotiateClientCertificateCalled = false;
+            using CancellationTokenSource cts = new CancellationTokenSource();
+            cts.CancelAfter(TestConfiguration.PassingTestTimeout);
+
+            (SslStream client, SslStream server) = TestHelper.GetConnectedSslStreams();
+            using (client)
+            using (server)
+            using (X509Certificate2 serverCertificate = Configuration.Certificates.GetServerCertificate())
+            using (X509Certificate2 clientCertificate = Configuration.Certificates.GetClientCertificate())
+            {
+                SslClientAuthenticationOptions clientOptions = new SslClientAuthenticationOptions()
+                {
+                    TargetHost = Guid.NewGuid().ToString("N"),
+                    EnabledSslProtocols = SslProtocols.Tls | SslProtocols.Tls11 | SslProtocols.Tls12,
+                };
+                clientOptions.RemoteCertificateValidationCallback = (sender, certificate, chain, sslPolicyErrors) => true;
+                clientOptions.LocalCertificateSelectionCallback = (sender, targetHost, localCertificates, remoteCertificate, acceptableIssuers) =>
+                {
+                    return sendClientCertificate ? clientCertificate : null;
+                };
+
+                SslServerAuthenticationOptions serverOptions = new SslServerAuthenticationOptions() { ServerCertificate = serverCertificate };
+                serverOptions.RemoteCertificateValidationCallback = (sender, certificate, chain, sslPolicyErrors) =>
+                {
+                    if (negotiateClientCertificateCalled && sendClientCertificate)
+                    {
+                        Assert.Equal(clientCertificate.GetCertHash(), certificate?.GetCertHash());
+                    }
+                    else
+                    {
+                        Assert.Null(certificate);
+                    }
+
+                    return true;
+                };
+
+                await TestConfiguration.WhenAllOrAnyFailedWithTimeout(
+                                client.AuthenticateAsClientAsync(clientOptions, cts.Token),
+                                server.AuthenticateAsServerAsync(serverOptions, cts.Token));
+
+                Assert.Null(server.RemoteCertificate);
+
+                // Client needs to be reading for renegotiation to happen.
+                byte[] buffer = new byte[TestHelper.s_ping.Length];
+                ValueTask<int> t = client.ReadAsync(buffer, cts.Token);
+
+                negotiateClientCertificateCalled = true;
+                await server.NegotiateClientCertificateAsync(cts.Token);
+                if (sendClientCertificate)
+                {
+                    Assert.NotNull(server.RemoteCertificate);
+                }
+                else
+                {
+                    Assert.Null(server.RemoteCertificate);
+                }
+                // Finish the client's read
+                await server.WriteAsync(TestHelper.s_ping, cts.Token);
+                await t;
+                // verify that the session is usable with or without client's certificate
+                await TestHelper.PingPong(client, server, cts.Token);
+                await TestHelper.PingPong(server, client, cts.Token);
+            }
+        }
+
+        [ConditionalTheory(typeof(PlatformDetection), nameof(PlatformDetection.SupportsTls13))]
+        [InlineData(true)]
+        [InlineData(false)]
+        [PlatformSpecific(TestPlatforms.Windows)]
+        public async Task SslStream_NegotiateClientCertificateAsyncTls13_Succeeds(bool sendClientCertificate)
+        {
+            bool negotiateClientCertificateCalled = false;
+            using CancellationTokenSource cts = new CancellationTokenSource();
+            cts.CancelAfter(TestConfiguration.PassingTestTimeout);
+
+            (SslStream client, SslStream server) = TestHelper.GetConnectedSslStreams();
+            using (client)
+            using (server)
+            using (X509Certificate2 serverCertificate = Configuration.Certificates.GetServerCertificate())
+            using (X509Certificate2 clientCertificate = Configuration.Certificates.GetClientCertificate())
+            {
+                SslClientAuthenticationOptions clientOptions = new SslClientAuthenticationOptions()
+                {
+                    TargetHost = Guid.NewGuid().ToString("N"),
+                    EnabledSslProtocols = SslProtocols.Tls13,
+                };
+                clientOptions.RemoteCertificateValidationCallback = (sender, certificate, chain, sslPolicyErrors) => true;
+                clientOptions.LocalCertificateSelectionCallback = (sender, targetHost, localCertificates, remoteCertificate, acceptableIssuers) =>
+                {
+                    return sendClientCertificate ? clientCertificate : null;
+                };
+
+                SslServerAuthenticationOptions serverOptions = new SslServerAuthenticationOptions() { ServerCertificate = serverCertificate };
+                serverOptions.RemoteCertificateValidationCallback = (sender, certificate, chain, sslPolicyErrors) =>
+                {
+                    if (negotiateClientCertificateCalled && sendClientCertificate)
+                    {
+                        Assert.Equal(clientCertificate.GetCertHash(), certificate?.GetCertHash());
+                    }
+                    else
+                    {
+                        Assert.Null(certificate);
+                    }
+
+                    return true;
+                };
+
+                await TestConfiguration.WhenAllOrAnyFailedWithTimeout(
+                                client.AuthenticateAsClientAsync(clientOptions, cts.Token),
+                                server.AuthenticateAsServerAsync(serverOptions, cts.Token));
+                // need this to complete TLS 1.3 handshake
+                await TestHelper.PingPong(client, server);
+                Assert.Null(server.RemoteCertificate);
+
+                // Client needs to be reading for renegotiation to happen.
+                byte[] buffer = new byte[TestHelper.s_ping.Length];
+                ValueTask<int> t = client.ReadAsync(buffer, cts.Token);
+
+                negotiateClientCertificateCalled = true;
+                await server.NegotiateClientCertificateAsync(cts.Token);
+                if (sendClientCertificate)
+                {
+                    Assert.NotNull(server.RemoteCertificate);
+                }
+                else
+                {
+                    Assert.Null(server.RemoteCertificate);
+                }
+                // Finish the client's read
+                await server.WriteAsync(TestHelper.s_ping, cts.Token);
+                await t;
+                // verify that the session is usable with or without client's certificate
+                await TestHelper.PingPong(client, server, cts.Token);
+                await TestHelper.PingPong(server, client, cts.Token);
+            }
+        }
+
+        [Theory]
+        [InlineData(true)]
+        [InlineData(false)]
+        [PlatformSpecific(TestPlatforms.Windows)]
+        public async Task SslStream_SecondNegotiateClientCertificateAsync_Throws(bool sendClientCertificate)
+        {
+            using CancellationTokenSource cts = new CancellationTokenSource();
+            cts.CancelAfter(TestConfiguration.PassingTestTimeout);
+
+            (SslStream client, SslStream server) = TestHelper.GetConnectedSslStreams();
+            using (client)
+            using (server)
+            using (X509Certificate2 serverCertificate = Configuration.Certificates.GetServerCertificate())
+            using (X509Certificate2 clientCertificate = Configuration.Certificates.GetClientCertificate())
+            {
+                SslClientAuthenticationOptions clientOptions = new SslClientAuthenticationOptions()
+                {
+                    TargetHost = Guid.NewGuid().ToString("N"),
+                    EnabledSslProtocols = SslProtocols.Tls | SslProtocols.Tls11 | SslProtocols.Tls12,
+                };
+                clientOptions.RemoteCertificateValidationCallback = (sender, certificate, chain, sslPolicyErrors) => true;
+                clientOptions.LocalCertificateSelectionCallback = (sender, targetHost, localCertificates, remoteCertificate, acceptableIssuers) =>
+                {
+                    return sendClientCertificate ? clientCertificate : null;
+                };
+
+                SslServerAuthenticationOptions serverOptions = new SslServerAuthenticationOptions() { ServerCertificate = serverCertificate };
+                serverOptions.RemoteCertificateValidationCallback = (sender, certificate, chain, sslPolicyErrors) => true;
+
+                await TestConfiguration.WhenAllOrAnyFailedWithTimeout(
+                                client.AuthenticateAsClientAsync(clientOptions, cts.Token),
+                                server.AuthenticateAsServerAsync(serverOptions, cts.Token));
+
+                await TestHelper.PingPong(client, server, cts.Token);
+                Assert.Null(server.RemoteCertificate);
+
+                // Client needs to be reading for renegotiation to happen.
+                byte[] buffer = new byte[TestHelper.s_ping.Length];
+                ValueTask<int> t = client.ReadAsync(buffer, cts.Token);
+
+                await server.NegotiateClientCertificateAsync(cts.Token);
+                if (sendClientCertificate)
+                {
+                    Assert.NotNull(server.RemoteCertificate);
+                }
+                else
+                {
+                    Assert.Null(server.RemoteCertificate);
+                }
+                // Finish the client's read
+                await server.WriteAsync(TestHelper.s_ping, cts.Token);
+                await t;
+
+                await Assert.ThrowsAsync<InvalidOperationException>(() => server.NegotiateClientCertificateAsync());
+            }
+        }
+
+        [Theory]
+        [InlineData(true)]
+        [InlineData(false)]
+        [PlatformSpecific(TestPlatforms.Windows)]
+        public async Task SslStream_NegotiateClientCertificateAsyncConcurrentIO_Throws(bool doRead)
+        {
+            using CancellationTokenSource cts = new CancellationTokenSource();
+            cts.CancelAfter(TestConfiguration.PassingTestTimeout);
+
+            (SslStream client, SslStream server) = TestHelper.GetConnectedSslStreams();
+            using (client)
+            using (server)
+            using (X509Certificate2 serverCertificate = Configuration.Certificates.GetServerCertificate())
+            using (X509Certificate2 clientCertificate = Configuration.Certificates.GetClientCertificate())
+            {
+                SslClientAuthenticationOptions clientOptions = new SslClientAuthenticationOptions()
+                {
+                    TargetHost = Guid.NewGuid().ToString("N"),
+                    ClientCertificates = new X509CertificateCollection(new X509Certificate2[] { clientCertificate })
+                };
+                clientOptions.RemoteCertificateValidationCallback = (sender, certificate, chain, sslPolicyErrors) => true;
+
+                SslServerAuthenticationOptions serverOptions = new SslServerAuthenticationOptions() { ServerCertificate = serverCertificate };
+                serverOptions.RemoteCertificateValidationCallback = (sender, certificate, chain, sslPolicyErrors) => true;
+
+                await TestConfiguration.WhenAllOrAnyFailedWithTimeout(
+                                client.AuthenticateAsClientAsync(clientOptions, cts.Token),
+                                server.AuthenticateAsServerAsync(serverOptions, cts.Token));
+
+                await TestHelper.PingPong(client, server, cts.Token);
+                Assert.Null(server.RemoteCertificate);
+
+                Task t = server.NegotiateClientCertificateAsync(cts.Token);
+                if (doRead)
+                {
+                    byte[] buffer = new byte[TestHelper.s_ping.Length];
+                    await Assert.ThrowsAsync<NotSupportedException>(() => server.ReadAsync(buffer).AsTask());
+                }
+                else
+                {
+                    await Assert.ThrowsAsync<NotSupportedException>(() => server.WriteAsync(TestHelper.s_ping).AsTask());
+                }
+            }
+        }
+
+        [Fact]
+        [PlatformSpecific(TestPlatforms.Windows)]
+        public async Task NegotiateClientCertificateAsync_PendingData_Throws()
+        {
+            using CancellationTokenSource cts = new CancellationTokenSource();
+            cts.CancelAfter(TestConfiguration.PassingTestTimeout);
+
+            (SslStream client, SslStream server) = TestHelper.GetConnectedSslStreams();
+            using (client)
+            using (server)
+            using (X509Certificate2 serverCertificate = Configuration.Certificates.GetServerCertificate())
+            using (X509Certificate2 clientCertificate = Configuration.Certificates.GetClientCertificate())
+            {
+                SslClientAuthenticationOptions clientOptions = new SslClientAuthenticationOptions()
+                {
+                    TargetHost = Guid.NewGuid().ToString("N"),
+                    ClientCertificates = new X509CertificateCollection(new X509Certificate2[] { clientCertificate })
+                };
+                clientOptions.RemoteCertificateValidationCallback = (sender, certificate, chain, sslPolicyErrors) => true;
+
+                SslServerAuthenticationOptions serverOptions = new SslServerAuthenticationOptions() { ServerCertificate = serverCertificate };
+                serverOptions.RemoteCertificateValidationCallback = (sender, certificate, chain, sslPolicyErrors) => true;
+
+                await TestConfiguration.WhenAllOrAnyFailedWithTimeout(
+                                client.AuthenticateAsClientAsync(clientOptions, cts.Token),
+                                server.AuthenticateAsServerAsync(serverOptions, cts.Token));
+
+                await TestHelper.PingPong(client, server, cts.Token);
+                Assert.Null(server.RemoteCertificate);
+
+                // This should go out in single TLS frame
+                await client.WriteAsync(new byte[200], cts.Token);
+                byte[] readBuffer = new byte[10];
+                // when we read part of the frame, remaining part should left decrypted
+                await server.ReadAsync(readBuffer, cts.Token);
+
+                await Assert.ThrowsAsync<InvalidOperationException>(() => server.NegotiateClientCertificateAsync(cts.Token));
+            }
+        }
+
         [Fact]
         public async Task SslStream_NestedAuth_Throws()
         {
index fabda4d..998e938 100644 (file)
@@ -11,6 +11,7 @@ using System.Security.Cryptography.X509Certificates;
 using System.Security.Cryptography.X509Certificates.Tests.Common;
 using System.Runtime.CompilerServices;
 using System.Text;
+using System.Threading;
 using System.Threading.Tasks;
 using Xunit;
 
@@ -42,8 +43,8 @@ namespace System.Net.Security.Tests
         private static readonly X509BasicConstraintsExtension s_eeConstraints =
             new X509BasicConstraintsExtension(false, false, 0, false);
 
-        private static readonly byte[] s_ping = Encoding.UTF8.GetBytes("PING");
-        private static readonly byte[] s_pong = Encoding.UTF8.GetBytes("PONG");
+        public static readonly byte[] s_ping = Encoding.UTF8.GetBytes("PING");
+        public static readonly byte[] s_pong = Encoding.UTF8.GetBytes("PONG");
 
         public static (SslStream ClientStream, SslStream ServerStream) GetConnectedSslStreams()
         {
@@ -195,26 +196,26 @@ namespace System.Net.Security.Tests
             return (endEntity, chain);
         }
 
-        internal static async Task PingPong(SslStream client, SslStream server)
+        internal static async Task PingPong(SslStream client, SslStream server, CancellationToken cancellationToken = default)
         {
             byte[] buffer = new byte[s_ping.Length];
-            ValueTask t = client.WriteAsync(s_ping);
+            ValueTask t = client.WriteAsync(s_ping, cancellationToken);
 
             int remains = s_ping.Length;
             while (remains > 0)
             {
-                int readLength = await server.ReadAsync(buffer, buffer.Length - remains, remains);
+                int readLength = await server.ReadAsync(buffer, buffer.Length - remains, remains, cancellationToken);
                 Assert.True(readLength > 0);
                 remains -= readLength;
             }
             Assert.Equal(s_ping, buffer);
             await t;
 
-            t = server.WriteAsync(s_pong);
+            t = server.WriteAsync(s_pong, cancellationToken);
             remains = s_pong.Length;
             while (remains > 0)
             {
-                int readLength = await client.ReadAsync(buffer, buffer.Length - remains, remains);
+                int readLength = await client.ReadAsync(buffer, buffer.Length - remains, remains, cancellationToken);
                 Assert.True(readLength > 0);
                 remains -= readLength;
             }
index 666ca16..4e14615 100644 (file)
@@ -59,6 +59,8 @@ namespace System.Net.Security
             return Task.Run(() => {});
         }
 
+        private Task RenegotiateAsync(CancellationToken cancellationToken) => throw new PlatformNotSupportedException();
+
         private void ReturnReadBufferIfEmpty()
         {
         }