avoid ProtocolToken allocations in TLS handshake (#86163)
authorTomas Weinfurt <tweinfurt@yahoo.com>
Thu, 18 May 2023 05:42:35 +0000 (22:42 -0700)
committerGitHub <noreply@github.com>
Thu, 18 May 2023 05:42:35 +0000 (22:42 -0700)
* avoild ProtocolToken allocations in TLS handshake

* cleanup

* UnitTests

* feedback from review

src/libraries/System.Net.Security/src/System/Net/Security/SslStream.Android.cs
src/libraries/System.Net.Security/src/System/Net/Security/SslStream.IO.cs
src/libraries/System.Net.Security/src/System/Net/Security/SslStream.Protocol.cs
src/libraries/System.Net.Security/src/System/Net/Security/SslStream.cs
src/libraries/System.Net.Security/tests/UnitTests/Fakes/FakeSslStream.Implementation.cs

index f35aafe..172e585 100644 (file)
@@ -14,7 +14,7 @@ namespace System.Net.Security
     {
         private JavaProxy.RemoteCertificateValidationResult VerifyRemoteCertificate()
         {
-            ProtocolToken? alertToken = null;
+            ProtocolToken alertToken = default;
             var isValid = VerifyRemoteCertificate(
                 _sslAuthenticationOptions.CertValidationDelegate,
                 _sslAuthenticationOptions.CertificateContext?.Trust,
@@ -31,13 +31,13 @@ namespace System.Net.Security
             };
         }
 
-        private bool TryGetRemoteCertificateValidationResult(out SslPolicyErrors sslPolicyErrors, out X509ChainStatusFlags chainStatus, out ProtocolToken? alertToken, out bool isValid)
+        private bool TryGetRemoteCertificateValidationResult(out SslPolicyErrors sslPolicyErrors, out X509ChainStatusFlags chainStatus, ref ProtocolToken alertToken, out bool isValid)
         {
             JavaProxy.RemoteCertificateValidationResult? validationResult = _securityContext?.SslStreamProxy.ValidationResult;
             sslPolicyErrors = validationResult?.SslPolicyErrors ?? default;
             chainStatus = validationResult?.ChainStatus ?? default;
             isValid = validationResult?.IsValid ?? default;
-            alertToken = validationResult?.AlertToken;
+            alertToken = validationResult?.AlertToken ?? default;
             return validationResult is not null;
         }
 
index 0536517..7ea62aa 100644 (file)
@@ -216,7 +216,9 @@ namespace System.Net.Security
                 ProtocolToken message;
                 do
                 {
-                    message = await ReceiveBlobAsync<TIOAdapter>(cancellationToken).ConfigureAwait(false);
+                    int frameSize = await ReceiveTlsFrameAsync<TIOAdapter>(cancellationToken).ConfigureAwait(false);
+                    ProcessTlsFrame(frameSize, out message);
+
                     if (message.Size > 0)
                     {
                         await TIOAdapter.WriteAsync(InnerStream, new ReadOnlyMemory<byte>(message.Payload!, 0, message.Size), cancellationToken).ConfigureAwait(false);
@@ -245,7 +247,7 @@ namespace System.Net.Security
         private async Task ForceAuthenticationAsync<TIOAdapter>(bool receiveFirst, byte[]? reAuthenticationData, CancellationToken cancellationToken)
             where TIOAdapter : IReadWriteAdapter
         {
-            ProtocolToken message;
+            ProtocolToken message = default;
             bool handshakeCompleted = false;
 
             if (reAuthenticationData == null)
@@ -256,12 +258,12 @@ namespace System.Net.Security
                     throw new InvalidOperationException(SR.Format(SR.net_io_invalidnestedcall, "authenticate"));
                 }
             }
-
             try
             {
                 if (!receiveFirst)
                 {
-                    message = NextMessage(reAuthenticationData);
+                    NextMessage(reAuthenticationData, out message);
+
                     if (message.Size > 0)
                     {
                         await TIOAdapter.WriteAsync(InnerStream, new ReadOnlyMemory<byte>(message.Payload!, 0, message.Size), cancellationToken).ConfigureAwait(false);
@@ -289,7 +291,8 @@ namespace System.Net.Security
 
                 while (!handshakeCompleted)
                 {
-                    message = await ReceiveBlobAsync<TIOAdapter>(cancellationToken).ConfigureAwait(false);
+                    int frameSize = await ReceiveTlsFrameAsync<TIOAdapter>(cancellationToken).ConfigureAwait(false);
+                    ProcessTlsFrame(frameSize, out message);
 
                     ReadOnlyMemory<byte> payload = default;
                     if (message.Size > 0)
@@ -355,7 +358,8 @@ namespace System.Net.Security
 
         }
 
-        private async ValueTask<ProtocolToken> ReceiveBlobAsync<TIOAdapter>(CancellationToken cancellationToken)
+        // This method will make sure we have at least one full TLS frame buffered.
+        private async ValueTask<int> ReceiveTlsFrameAsync<TIOAdapter>(CancellationToken cancellationToken)
             where TIOAdapter : IReadWriteAdapter
         {
             int frameSize = await EnsureFullTlsFrameAsync<TIOAdapter>(cancellationToken).ConfigureAwait(false);
@@ -430,11 +434,11 @@ namespace System.Net.Security
 
             }
 
-            return ProcessBlob(frameSize);
+            return frameSize;
         }
 
         // Calls crypto on received data. No IO inside.
-        private ProtocolToken ProcessBlob(int frameSize)
+        private void ProcessTlsFrame(int frameSize, out ProtocolToken message)
         {
             int chunkSize = frameSize;
 
@@ -467,18 +471,18 @@ namespace System.Net.Security
                 _buffer.DiscardEncrypted(frameSize);
             }
 
-            return NextMessage(availableData.Slice(0, chunkSize));
+            NextMessage(availableData.Slice(0, chunkSize), out message);
         }
 
         //
         //  This is to reset auth state on remote side.
         //  If this write succeeds we will allow auth retrying.
         //
-        private void SendAuthResetSignal(ProtocolToken? message, ExceptionDispatchInfo exception)
+        private void SendAuthResetSignal(ReadOnlySpan<byte> alert, ExceptionDispatchInfo exception)
         {
             SetException(exception.SourceException);
 
-            if (message == null || message.Size == 0)
+            if (alert.Length == 0)
             {
                 //
                 // We don't have an alert to send so cannot retry and fail prematurely.
@@ -486,7 +490,7 @@ namespace System.Net.Security
                 exception.Throw();
             }
 
-            InnerStream.Write(message.Payload!, 0, message.Size);
+            InnerStream.Write(alert);
 
             exception.Throw();
         }
@@ -499,7 +503,7 @@ namespace System.Net.Security
         //
         // - Returns false if failed to verify the Remote Cert
         //
-        private bool CompleteHandshake(ref ProtocolToken? alertToken, out SslPolicyErrors sslPolicyErrors, out X509ChainStatusFlags chainStatus)
+        private bool CompleteHandshake(ref ProtocolToken alertToken, out SslPolicyErrors sslPolicyErrors, out X509ChainStatusFlags chainStatus)
         {
             ProcessHandshakeSuccess();
 
@@ -527,7 +531,7 @@ namespace System.Net.Security
             // The Java TrustManager callback is called only when the peer has a certificate. It's possible that
             // the peer didn't provide any certificate (for example when the peer is the client) and the validation
             // result hasn't been set. In that case we still need to run the verification at this point.
-            if (TryGetRemoteCertificateValidationResult(out sslPolicyErrors, out chainStatus, out alertToken, out bool isValid))
+            if (TryGetRemoteCertificateValidationResult(out sslPolicyErrors, out chainStatus, ref alertToken, out bool isValid))
             {
                 _handshakeCompleted = isValid;
                 return isValid;
@@ -546,23 +550,23 @@ namespace System.Net.Security
 
         private void CompleteHandshake(SslAuthenticationOptions sslAuthenticationOptions)
         {
-            ProtocolToken? alertToken = null;
+            ProtocolToken alertToken = default;
             if (!CompleteHandshake(ref alertToken, out SslPolicyErrors sslPolicyErrors, out X509ChainStatusFlags chainStatus))
             {
                 if (sslAuthenticationOptions!.CertValidationDelegate != null)
                 {
                     // there may be some chain errors but the decision was made by custom callback. Details should be tracing if enabled.
-                    SendAuthResetSignal(alertToken, ExceptionDispatchInfo.Capture(new AuthenticationException(SR.net_ssl_io_cert_custom_validation, null)));
+                    SendAuthResetSignal(new ReadOnlySpan<byte>(alertToken.Payload), ExceptionDispatchInfo.Capture(new AuthenticationException(SR.net_ssl_io_cert_custom_validation, null)));
                 }
                 else if (sslPolicyErrors == SslPolicyErrors.RemoteCertificateChainErrors && chainStatus != X509ChainStatusFlags.NoError)
                 {
                     // We failed only because of chain and we have some insight.
-                    SendAuthResetSignal(alertToken, ExceptionDispatchInfo.Capture(new AuthenticationException(SR.Format(SR.net_ssl_io_cert_chain_validation, chainStatus), null)));
+                    SendAuthResetSignal(new ReadOnlySpan<byte>(alertToken.Payload), ExceptionDispatchInfo.Capture(new AuthenticationException(SR.Format(SR.net_ssl_io_cert_chain_validation, chainStatus), null)));
                 }
                 else
                 {
                     // Simple add sslPolicyErrors as crude info.
-                    SendAuthResetSignal(alertToken, ExceptionDispatchInfo.Capture(new AuthenticationException(SR.Format(SR.net_ssl_io_cert_validation, sslPolicyErrors), null)));
+                    SendAuthResetSignal(new ReadOnlySpan<byte>(alertToken.Payload), ExceptionDispatchInfo.Capture(new AuthenticationException(SR.Format(SR.net_ssl_io_cert_validation, sslPolicyErrors), null)));
                 }
             }
         }
index 79e655c..1bd2d22 100644 (file)
@@ -751,20 +751,20 @@ namespace System.Net.Security
         }
 
         //
-        internal ProtocolToken NextMessage(ReadOnlySpan<byte> incomingBuffer)
+        internal void NextMessage(ReadOnlySpan<byte> incomingBuffer, out ProtocolToken token)
         {
             byte[]? nextmsg = null;
-            SecurityStatusPal status = GenerateToken(incomingBuffer, ref nextmsg);
-            ProtocolToken token = new ProtocolToken(nextmsg, status);
+            token.Status = GenerateToken(incomingBuffer, ref nextmsg);
+            token.Size = nextmsg?.Length ?? 0;
+            token.Payload = nextmsg;
 
             if (NetEventSource.Log.IsEnabled())
             {
                 if (token.Failed)
                 {
-                    NetEventSource.Error(this, $"Authentication failed. Status: {status}, Exception message: {token.GetException()!.Message}");
+                    NetEventSource.Error(this, $"Authentication failed. Status: {token.Status}, Exception message: {token.GetException()!.Message}");
                 }
             }
-            return token;
         }
 
         /*++
@@ -992,7 +992,7 @@ namespace System.Net.Security
         --*/
 
         //This method validates a remote certificate.
-        internal bool VerifyRemoteCertificate(RemoteCertificateValidationCallback? remoteCertValidationCallback, SslCertificateTrust? trust, ref ProtocolToken? alertToken, out SslPolicyErrors sslPolicyErrors, out X509ChainStatusFlags chainStatus)
+        internal bool VerifyRemoteCertificate(RemoteCertificateValidationCallback? remoteCertValidationCallback, SslCertificateTrust? trust, ref ProtocolToken alertToken, out SslPolicyErrors sslPolicyErrors, out X509ChainStatusFlags chainStatus)
         {
             sslPolicyErrors = SslPolicyErrors.None;
             chainStatus = X509ChainStatusFlags.NoError;
@@ -1085,7 +1085,7 @@ namespace System.Net.Security
 
                 if (!success)
                 {
-                    alertToken = CreateFatalHandshakeAlertToken(sslPolicyErrors, chain!);
+                    CreateFatalHandshakeAlertToken(sslPolicyErrors, chain!, ref alertToken);
                     if (chain != null)
                     {
                         foreach (X509ChainStatus status in chain.ChainStatus)
@@ -1115,7 +1115,7 @@ namespace System.Net.Security
             return success;
         }
 
-        private ProtocolToken? CreateFatalHandshakeAlertToken(SslPolicyErrors sslPolicyErrors, X509Chain chain)
+        private void CreateFatalHandshakeAlertToken(SslPolicyErrors sslPolicyErrors, X509Chain chain, ref ProtocolToken alertToken)
         {
             TlsAlertMessage alertMessage;
 
@@ -1148,15 +1148,14 @@ namespace System.Net.Security
                 {
                     ExceptionDispatchInfo.Throw(status.Exception);
                 }
-
-                return null;
             }
 
-            return GenerateAlertToken();
+            GenerateAlertToken(ref alertToken);
         }
 
-        private ProtocolToken? CreateShutdownToken()
+        private byte[]? CreateShutdownToken()
         {
+            byte[]? nextmsg = null;
             SecurityStatusPal status;
             status = SslStreamPal.ApplyShutdownToken(_securityContext!);
 
@@ -1173,17 +1172,21 @@ namespace System.Net.Security
                 return null;
             }
 
-            return GenerateAlertToken();
+            GenerateToken(default, ref nextmsg);
+
+            return nextmsg;
         }
 
-        private ProtocolToken GenerateAlertToken()
+        private void GenerateAlertToken(ref ProtocolToken alertToken)
         {
             byte[]? nextmsg = null;
 
             SecurityStatusPal status;
             status = GenerateToken(default, ref nextmsg);
 
-            return new ProtocolToken(nextmsg, status);
+            alertToken.Payload = nextmsg;
+            alertToken.Size = nextmsg?.Length ?? 0;
+            alertToken.Status = status;
         }
 
         private static TlsAlertMessage GetAlertMessageFromChain(X509Chain chain)
@@ -1286,7 +1289,7 @@ namespace System.Net.Security
     }
 
     // ProtocolToken - used to process and handle the return codes from the SSPI wrapper
-    internal sealed class ProtocolToken
+    internal struct ProtocolToken
     {
         internal SecurityStatusPal Status;
         internal byte[]? Payload;
index 9f4ac13..8af4ab0 100644 (file)
@@ -441,9 +441,14 @@ namespace System.Net.Security
         {
             ThrowIfExceptionalOrNotAuthenticatedOrShutdown();
 
-            ProtocolToken message = CreateShutdownToken()!;
+            byte[]? message = CreateShutdownToken();
             _shutdown = true;
-            return InnerStream.WriteAsync(message.Payload, default).AsTask();
+            if (message != null)
+            {
+                return InnerStream.WriteAsync(message, default).AsTask();
+            }
+
+            return Task.CompletedTask;
         }
         #endregion
 
index a34fcab..f293cf6 100644 (file)
@@ -92,7 +92,7 @@ namespace System.Net.Security
         {
         }
 
-        private ProtocolToken? CreateShutdownToken()
+        private byte[]? CreateShutdownToken()
         {
             return null;
         }