Streamline SslStream state validation (dotnet/corefx#39950)
authorBen Adams <thundercat@illyriad.co.uk>
Fri, 2 Aug 2019 18:45:03 +0000 (19:45 +0100)
committerDavid Shulman <david.shulman@microsoft.com>
Fri, 2 Aug 2019 18:45:02 +0000 (11:45 -0700)
* Streamline SslStream state validation

* Feedback

Commit migrated from https://github.com/dotnet/corefx/commit/2f0fd25d59dd2bd77f4008039dab5ce9a9c45058

src/libraries/System.Net.Security/src/System/Net/Security/SslStream.Implementation.cs
src/libraries/System.Net.Security/src/System/Net/Security/SslStream.cs

index 43fdf85..0332ca1 100644 (file)
@@ -63,26 +63,6 @@ namespace System.Net.Security
         private int _lockReadState;
         private object _queuedReadStateRequest;
 
-        /// <summary>Set as the _exception when the instance is disposed.</summary>
-        private static readonly ExceptionDispatchInfo s_disposedSentinel = ExceptionDispatchInfo.Capture(new ObjectDisposedException(nameof(SslStream), (string)null));
-
-        private void ThrowIfExceptional()
-        {
-            ExceptionDispatchInfo e = _exception;
-            if (e != null)
-            {
-                // If the stored exception just indicates disposal, throw a new ODE rather than the stored one,
-                // so as to not continually build onto the shared exception's stack.
-                if (ReferenceEquals(e, s_disposedSentinel))
-                {
-                    throw new ObjectDisposedException(nameof(SslStream));
-                }
-
-                // Throw the stored exception.
-                e.Throw();
-            }
-        }
-
         private void ValidateCreateContext(SslClientAuthenticationOptions sslClientAuthenticationOptions, RemoteCertValidationCallback remoteCallback, LocalCertSelectionCallback localCallback)
         {
             ThrowIfExceptional();
@@ -163,21 +143,6 @@ namespace System.Net.Security
             _context?.Close();
         }
 
-        private void CheckThrow(bool authSuccessCheck, bool shutdownCheck = false)
-        {
-            ThrowIfExceptional();
-
-            if (authSuccessCheck && !IsAuthenticated)
-            {
-                throw new InvalidOperationException(SR.net_auth_noauth);
-            }
-
-            if (shutdownCheck && _shutdown)
-            {
-                throw new InvalidOperationException(SR.net_ssl_io_already_shutdown);
-            }
-        }
-
         //
         // This is to not depend on GC&SafeHandle class if the context is not needed anymore.
         //
@@ -211,13 +176,13 @@ namespace System.Net.Security
 
         private SecurityStatusPal EncryptData(ReadOnlyMemory<byte> buffer, ref byte[] outBuffer, out int outSize)
         {
-            CheckThrow(true);
+            ThrowIfExceptionalOrNotAuthenticated();
             return _context.Encrypt(buffer, ref outBuffer, out outSize);
         }
 
         private SecurityStatusPal DecryptData()
         {
-            CheckThrow(true);
+            ThrowIfExceptionalOrNotAuthenticated();
             return PrivateDecryptData(_internalBuffer, ref _decryptedBytesOffset, ref _decryptedBytesCount);
         }
 
@@ -258,7 +223,7 @@ namespace System.Net.Security
         //
         private int CheckOldKeyDecryptedData(Memory<byte> buffer)
         {
-            CheckThrow(true);
+            ThrowIfExceptionalOrNotAuthenticated();
             if (_queuedReadData != null)
             {
                 // This is inefficient yet simple and should be a REALLY rare case.
@@ -292,7 +257,7 @@ namespace System.Net.Security
 
             try
             {
-                CheckThrow(false);
+                ThrowIfExceptional();
                 AsyncProtocolRequest asyncRequest = null;
                 if (lazyResult != null)
                 {
@@ -1039,7 +1004,7 @@ namespace System.Net.Security
             {
                 if (_lockWriteState != LockHandshake)
                 {
-                    CheckThrow(authSuccessCheck: true);
+                    ThrowIfExceptionalOrNotAuthenticated();
                     return Task.CompletedTask;
                 }
 
@@ -1067,7 +1032,7 @@ namespace System.Net.Security
                 if (_lockWriteState != LockHandshake)
                 {
                     // Handshake has completed before we grabbed the lock.
-                    CheckThrow(authSuccessCheck: true);
+                    ThrowIfExceptionalOrNotAuthenticated();
                     return;
                 }
 
@@ -1079,7 +1044,7 @@ namespace System.Net.Security
 
             // Need to exit from lock before waiting.
             lazyResult.InternalWaitForCompletion();
-            CheckThrow(authSuccessCheck: true);
+            ThrowIfExceptionalOrNotAuthenticated();
             return;
         }
 
@@ -1500,7 +1465,7 @@ namespace System.Net.Security
         private async Task WriteAsyncInternal<TWriteAdapter>(TWriteAdapter writeAdapter, ReadOnlyMemory<byte> buffer)
             where TWriteAdapter : struct, ISslWriteAdapter
         {
-            CheckThrow(authSuccessCheck: true, shutdownCheck: true);
+            ThrowIfExceptionalOrNotAuthenticatedOrShutdown();
 
             if (buffer.Length == 0 && !SslStreamPal.CanEncryptEmptyMessage)
             {
index bc5aee8..f9d2dc1 100644 (file)
@@ -4,6 +4,7 @@
 
 using System.Diagnostics;
 using System.IO;
+using System.Runtime.CompilerServices;
 using System.Runtime.ExceptionServices;
 using System.Security.Authentication;
 using System.Security.Authentication.ExtendedProtection;
@@ -40,6 +41,9 @@ namespace System.Net.Security
 
     public partial class SslStream : AuthenticatedStream
     {
+        /// <summary>Set as the _exception when the instance is disposed.</summary>
+        private static readonly ExceptionDispatchInfo s_disposedSentinel = ExceptionDispatchInfo.Capture(new ObjectDisposedException(nameof(SslStream), (string)null));
+
         private X509Certificate2 _remoteCertificate;
         private bool _remoteCertificateExposed;
 
@@ -285,7 +289,7 @@ namespace System.Net.Security
 
         internal IAsyncResult BeginShutdown(AsyncCallback asyncCallback, object asyncState)
         {
-            CheckThrow(authSuccessCheck: true, shutdownCheck: true);
+            ThrowIfExceptionalOrNotAuthenticatedOrShutdown();
 
             ProtocolToken message = _context.CreateShutdownToken();
             return TaskToApm.Begin(InnerStream.WriteAsync(message.Payload, 0, message.Payload.Length), asyncCallback, asyncState);
@@ -293,7 +297,7 @@ namespace System.Net.Security
 
         internal void EndShutdown(IAsyncResult asyncResult)
         {
-            CheckThrow(authSuccessCheck: true, shutdownCheck: true);
+            ThrowIfExceptionalOrNotAuthenticatedOrShutdown();
 
             TaskToApm.End(asyncResult);
             _shutdown = true;
@@ -471,7 +475,7 @@ namespace System.Net.Security
         {
             get
             {
-                CheckThrow(true);
+                ThrowIfExceptionalOrNotAuthenticated();
                 SslConnectionInfo info = _context.ConnectionInfo;
                 if (info == null)
                 {
@@ -527,7 +531,7 @@ namespace System.Net.Security
         {
             get
             {
-                CheckThrow(true);
+                ThrowIfExceptionalOrNotAuthenticated();
                 return _context.IsServer ? _context.LocalServerCertificate : _context.LocalClientCertificate;
             }
         }
@@ -536,7 +540,7 @@ namespace System.Net.Security
         {
             get
             {
-                CheckThrow(true);
+                ThrowIfExceptionalOrNotAuthenticated();
                 _remoteCertificateExposed = true;
                 return _remoteCertificate;
             }
@@ -547,7 +551,7 @@ namespace System.Net.Security
         {
             get
             {
-                CheckThrow(true);
+                ThrowIfExceptionalOrNotAuthenticated();
                 return _context.ConnectionInfo?.TlsCipherSuite ?? default(TlsCipherSuite);
             }
         }
@@ -556,7 +560,7 @@ namespace System.Net.Security
         {
             get
             {
-                CheckThrow(true);
+                ThrowIfExceptionalOrNotAuthenticated();
                 SslConnectionInfo info = _context.ConnectionInfo;
                 if (info == null)
                 {
@@ -570,7 +574,7 @@ namespace System.Net.Security
         {
             get
             {
-                CheckThrow(true);
+                ThrowIfExceptionalOrNotAuthenticated();
                 SslConnectionInfo info = _context.ConnectionInfo;
                 if (info == null)
                 {
@@ -585,7 +589,7 @@ namespace System.Net.Security
         {
             get
             {
-                CheckThrow(true);
+                ThrowIfExceptionalOrNotAuthenticated();
                 SslConnectionInfo info = _context.ConnectionInfo;
                 if (info == null)
                 {
@@ -599,7 +603,7 @@ namespace System.Net.Security
         {
             get
             {
-                CheckThrow(true);
+                ThrowIfExceptionalOrNotAuthenticated();
                 SslConnectionInfo info = _context.ConnectionInfo;
                 if (info == null)
                 {
@@ -614,7 +618,7 @@ namespace System.Net.Security
         {
             get
             {
-                CheckThrow(true);
+                ThrowIfExceptionalOrNotAuthenticated();
                 SslConnectionInfo info = _context.ConnectionInfo;
                 if (info == null)
                 {
@@ -629,7 +633,7 @@ namespace System.Net.Security
         {
             get
             {
-                CheckThrow(true);
+                ThrowIfExceptionalOrNotAuthenticated();
                 SslConnectionInfo info = _context.ConnectionInfo;
                 if (info == null)
                 {
@@ -711,7 +715,7 @@ namespace System.Net.Security
 
         public override int ReadByte()
         {
-            CheckThrow(true);
+            ThrowIfExceptionalOrNotAuthenticated();
             if (Interlocked.Exchange(ref _nestedRead, 1) == 1)
             {
                 throw new NotSupportedException(SR.Format(SR.net_io_invalidnestedcall, "ReadByte", "read"));
@@ -747,7 +751,7 @@ namespace System.Net.Security
 
         public override int Read(byte[] buffer, int offset, int count)
         {
-            CheckThrow(true);
+            ThrowIfExceptionalOrNotAuthenticated();
             ValidateParameters(buffer, offset, count);
             SslReadSync reader = new SslReadSync(this);
             return ReadAsyncInternal(reader, new Memory<byte>(buffer, offset, count)).GetAwaiter().GetResult();
@@ -757,7 +761,7 @@ namespace System.Net.Security
 
         public override void Write(byte[] buffer, int offset, int count)
         {
-            CheckThrow(true);
+            ThrowIfExceptionalOrNotAuthenticated();
             ValidateParameters(buffer, offset, count);
 
             SslWriteSync writeAdapter = new SslWriteSync(this);
@@ -766,45 +770,45 @@ namespace System.Net.Security
 
         public override IAsyncResult BeginRead(byte[] buffer, int offset, int count, AsyncCallback asyncCallback, object asyncState)
         {
-            CheckThrow(true);
+            ThrowIfExceptionalOrNotAuthenticated();
             return TaskToApm.Begin(ReadAsync(buffer, offset, count, CancellationToken.None), asyncCallback, asyncState);
         }
 
         public override int EndRead(IAsyncResult asyncResult)
         {
-            CheckThrow(true);
+            ThrowIfExceptionalOrNotAuthenticated();
             return TaskToApm.End<int>(asyncResult);
         }
 
         public override IAsyncResult BeginWrite(byte[] buffer, int offset, int count, AsyncCallback asyncCallback, object asyncState)
         {
-            CheckThrow(true);
+            ThrowIfExceptionalOrNotAuthenticated();
             return TaskToApm.Begin(WriteAsync(buffer, offset, count, CancellationToken.None), asyncCallback, asyncState);
         }
 
         public override void EndWrite(IAsyncResult asyncResult)
         {
-            CheckThrow(true);
+            ThrowIfExceptionalOrNotAuthenticated();
             TaskToApm.End(asyncResult);
         }
 
         public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
         {
-            CheckThrow(true);
+            ThrowIfExceptionalOrNotAuthenticated();
             ValidateParameters(buffer, offset, count);
             return WriteAsync(new ReadOnlyMemory<byte>(buffer, offset, count), cancellationToken).AsTask();
         }
 
         public override ValueTask WriteAsync(ReadOnlyMemory<byte> buffer, CancellationToken cancellationToken)
         {
-            CheckThrow(true);
+            ThrowIfExceptionalOrNotAuthenticated();
             SslWriteAsync writeAdapter = new SslWriteAsync(this, cancellationToken);
             return new ValueTask(WriteAsyncInternal(writeAdapter, buffer));
         }
 
         public override Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
         {
-            CheckThrow(true);
+            ThrowIfExceptionalOrNotAuthenticated();
             ValidateParameters(buffer, offset, count);
             SslReadAsync read = new SslReadAsync(this, cancellationToken);
             return ReadAsyncInternal(read, new Memory<byte>(buffer, offset, count)).AsTask();
@@ -812,9 +816,71 @@ namespace System.Net.Security
 
         public override ValueTask<int> ReadAsync(Memory<byte> buffer, CancellationToken cancellationToken = default)
         {
-            CheckThrow(true);
+            ThrowIfExceptionalOrNotAuthenticated();
             SslReadAsync read = new SslReadAsync(this, cancellationToken);
             return ReadAsyncInternal(read, buffer);
         }
+
+        private void ThrowIfExceptional()
+        {
+            ExceptionDispatchInfo e = _exception;
+            if (e != null)
+            {
+                ThrowExceptional(e);
+            }
+
+            // Local function to make the check method more inline friendly.
+            static void ThrowExceptional(ExceptionDispatchInfo e)
+            {
+                // If the stored exception just indicates disposal, throw a new ODE rather than the stored one,
+                // so as to not continually build onto the shared exception's stack.
+                if (ReferenceEquals(e, s_disposedSentinel))
+                {
+                    throw new ObjectDisposedException(nameof(SslStream));
+                }
+
+                // Throw the stored exception.
+                e.Throw();
+            }
+        }
+
+        [MethodImpl(MethodImplOptions.AggressiveInlining)]
+        private void ThrowIfExceptionalOrNotAuthenticated()
+        {
+            ThrowIfExceptional();
+
+            if (!IsAuthenticated)
+            {
+                ThrowNotAuthenticated();
+            }
+        }
+
+        [MethodImpl(MethodImplOptions.AggressiveInlining)]
+        private void ThrowIfExceptionalOrNotAuthenticatedOrShutdown()
+        {
+            ThrowIfExceptional();
+
+            if (!IsAuthenticated)
+            {
+                ThrowNotAuthenticated();
+            }
+
+            if (_shutdown)
+            {
+                ThrowAlreadyShutdown();
+            }
+
+            // Local function to make the check method more inline friendly.
+            static void ThrowAlreadyShutdown()
+            {
+                throw new InvalidOperationException(SR.net_ssl_io_already_shutdown);
+            }
+        }
+
+        // Static non-returning throw method to make the check methods more inline friendly.
+        private static void ThrowNotAuthenticated()
+        {
+            throw new InvalidOperationException(SR.net_auth_noauth);
+        }
     }
 }