From 6667b84dcfd782007ffc3c59b5382030aa907517 Mon Sep 17 00:00:00 2001 From: Ben Adams Date: Fri, 2 Aug 2019 19:45:03 +0100 Subject: [PATCH] Streamline SslStream state validation (dotnet/corefx#39950) * Streamline SslStream state validation * Feedback Commit migrated from https://github.com/dotnet/corefx/commit/2f0fd25d59dd2bd77f4008039dab5ce9a9c45058 --- .../Net/Security/SslStream.Implementation.cs | 51 ++-------- .../src/System/Net/Security/SslStream.cs | 112 ++++++++++++++++----- 2 files changed, 97 insertions(+), 66 deletions(-) diff --git a/src/libraries/System.Net.Security/src/System/Net/Security/SslStream.Implementation.cs b/src/libraries/System.Net.Security/src/System/Net/Security/SslStream.Implementation.cs index 43fdf85..0332ca1 100644 --- a/src/libraries/System.Net.Security/src/System/Net/Security/SslStream.Implementation.cs +++ b/src/libraries/System.Net.Security/src/System/Net/Security/SslStream.Implementation.cs @@ -63,26 +63,6 @@ namespace System.Net.Security private int _lockReadState; private object _queuedReadStateRequest; - /// Set as the _exception when the instance is disposed. - private static readonly ExceptionDispatchInfo s_disposedSentinel = ExceptionDispatchInfo.Capture(new ObjectDisposedException(nameof(SslStream), (string)null)); - - private void ThrowIfExceptional() - { - ExceptionDispatchInfo e = _exception; - if (e != null) - { - // If the stored exception just indicates disposal, throw a new ODE rather than the stored one, - // so as to not continually build onto the shared exception's stack. - if (ReferenceEquals(e, s_disposedSentinel)) - { - throw new ObjectDisposedException(nameof(SslStream)); - } - - // Throw the stored exception. - e.Throw(); - } - } - private void ValidateCreateContext(SslClientAuthenticationOptions sslClientAuthenticationOptions, RemoteCertValidationCallback remoteCallback, LocalCertSelectionCallback localCallback) { ThrowIfExceptional(); @@ -163,21 +143,6 @@ namespace System.Net.Security _context?.Close(); } - private void CheckThrow(bool authSuccessCheck, bool shutdownCheck = false) - { - ThrowIfExceptional(); - - if (authSuccessCheck && !IsAuthenticated) - { - throw new InvalidOperationException(SR.net_auth_noauth); - } - - if (shutdownCheck && _shutdown) - { - throw new InvalidOperationException(SR.net_ssl_io_already_shutdown); - } - } - // // This is to not depend on GC&SafeHandle class if the context is not needed anymore. // @@ -211,13 +176,13 @@ namespace System.Net.Security private SecurityStatusPal EncryptData(ReadOnlyMemory buffer, ref byte[] outBuffer, out int outSize) { - CheckThrow(true); + ThrowIfExceptionalOrNotAuthenticated(); return _context.Encrypt(buffer, ref outBuffer, out outSize); } private SecurityStatusPal DecryptData() { - CheckThrow(true); + ThrowIfExceptionalOrNotAuthenticated(); return PrivateDecryptData(_internalBuffer, ref _decryptedBytesOffset, ref _decryptedBytesCount); } @@ -258,7 +223,7 @@ namespace System.Net.Security // private int CheckOldKeyDecryptedData(Memory buffer) { - CheckThrow(true); + ThrowIfExceptionalOrNotAuthenticated(); if (_queuedReadData != null) { // This is inefficient yet simple and should be a REALLY rare case. @@ -292,7 +257,7 @@ namespace System.Net.Security try { - CheckThrow(false); + ThrowIfExceptional(); AsyncProtocolRequest asyncRequest = null; if (lazyResult != null) { @@ -1039,7 +1004,7 @@ namespace System.Net.Security { if (_lockWriteState != LockHandshake) { - CheckThrow(authSuccessCheck: true); + ThrowIfExceptionalOrNotAuthenticated(); return Task.CompletedTask; } @@ -1067,7 +1032,7 @@ namespace System.Net.Security if (_lockWriteState != LockHandshake) { // Handshake has completed before we grabbed the lock. - CheckThrow(authSuccessCheck: true); + ThrowIfExceptionalOrNotAuthenticated(); return; } @@ -1079,7 +1044,7 @@ namespace System.Net.Security // Need to exit from lock before waiting. lazyResult.InternalWaitForCompletion(); - CheckThrow(authSuccessCheck: true); + ThrowIfExceptionalOrNotAuthenticated(); return; } @@ -1500,7 +1465,7 @@ namespace System.Net.Security private async Task WriteAsyncInternal(TWriteAdapter writeAdapter, ReadOnlyMemory buffer) where TWriteAdapter : struct, ISslWriteAdapter { - CheckThrow(authSuccessCheck: true, shutdownCheck: true); + ThrowIfExceptionalOrNotAuthenticatedOrShutdown(); if (buffer.Length == 0 && !SslStreamPal.CanEncryptEmptyMessage) { diff --git a/src/libraries/System.Net.Security/src/System/Net/Security/SslStream.cs b/src/libraries/System.Net.Security/src/System/Net/Security/SslStream.cs index bc5aee8..f9d2dc1 100644 --- a/src/libraries/System.Net.Security/src/System/Net/Security/SslStream.cs +++ b/src/libraries/System.Net.Security/src/System/Net/Security/SslStream.cs @@ -4,6 +4,7 @@ using System.Diagnostics; using System.IO; +using System.Runtime.CompilerServices; using System.Runtime.ExceptionServices; using System.Security.Authentication; using System.Security.Authentication.ExtendedProtection; @@ -40,6 +41,9 @@ namespace System.Net.Security public partial class SslStream : AuthenticatedStream { + /// Set as the _exception when the instance is disposed. + private static readonly ExceptionDispatchInfo s_disposedSentinel = ExceptionDispatchInfo.Capture(new ObjectDisposedException(nameof(SslStream), (string)null)); + private X509Certificate2 _remoteCertificate; private bool _remoteCertificateExposed; @@ -285,7 +289,7 @@ namespace System.Net.Security internal IAsyncResult BeginShutdown(AsyncCallback asyncCallback, object asyncState) { - CheckThrow(authSuccessCheck: true, shutdownCheck: true); + ThrowIfExceptionalOrNotAuthenticatedOrShutdown(); ProtocolToken message = _context.CreateShutdownToken(); return TaskToApm.Begin(InnerStream.WriteAsync(message.Payload, 0, message.Payload.Length), asyncCallback, asyncState); @@ -293,7 +297,7 @@ namespace System.Net.Security internal void EndShutdown(IAsyncResult asyncResult) { - CheckThrow(authSuccessCheck: true, shutdownCheck: true); + ThrowIfExceptionalOrNotAuthenticatedOrShutdown(); TaskToApm.End(asyncResult); _shutdown = true; @@ -471,7 +475,7 @@ namespace System.Net.Security { get { - CheckThrow(true); + ThrowIfExceptionalOrNotAuthenticated(); SslConnectionInfo info = _context.ConnectionInfo; if (info == null) { @@ -527,7 +531,7 @@ namespace System.Net.Security { get { - CheckThrow(true); + ThrowIfExceptionalOrNotAuthenticated(); return _context.IsServer ? _context.LocalServerCertificate : _context.LocalClientCertificate; } } @@ -536,7 +540,7 @@ namespace System.Net.Security { get { - CheckThrow(true); + ThrowIfExceptionalOrNotAuthenticated(); _remoteCertificateExposed = true; return _remoteCertificate; } @@ -547,7 +551,7 @@ namespace System.Net.Security { get { - CheckThrow(true); + ThrowIfExceptionalOrNotAuthenticated(); return _context.ConnectionInfo?.TlsCipherSuite ?? default(TlsCipherSuite); } } @@ -556,7 +560,7 @@ namespace System.Net.Security { get { - CheckThrow(true); + ThrowIfExceptionalOrNotAuthenticated(); SslConnectionInfo info = _context.ConnectionInfo; if (info == null) { @@ -570,7 +574,7 @@ namespace System.Net.Security { get { - CheckThrow(true); + ThrowIfExceptionalOrNotAuthenticated(); SslConnectionInfo info = _context.ConnectionInfo; if (info == null) { @@ -585,7 +589,7 @@ namespace System.Net.Security { get { - CheckThrow(true); + ThrowIfExceptionalOrNotAuthenticated(); SslConnectionInfo info = _context.ConnectionInfo; if (info == null) { @@ -599,7 +603,7 @@ namespace System.Net.Security { get { - CheckThrow(true); + ThrowIfExceptionalOrNotAuthenticated(); SslConnectionInfo info = _context.ConnectionInfo; if (info == null) { @@ -614,7 +618,7 @@ namespace System.Net.Security { get { - CheckThrow(true); + ThrowIfExceptionalOrNotAuthenticated(); SslConnectionInfo info = _context.ConnectionInfo; if (info == null) { @@ -629,7 +633,7 @@ namespace System.Net.Security { get { - CheckThrow(true); + ThrowIfExceptionalOrNotAuthenticated(); SslConnectionInfo info = _context.ConnectionInfo; if (info == null) { @@ -711,7 +715,7 @@ namespace System.Net.Security public override int ReadByte() { - CheckThrow(true); + ThrowIfExceptionalOrNotAuthenticated(); if (Interlocked.Exchange(ref _nestedRead, 1) == 1) { throw new NotSupportedException(SR.Format(SR.net_io_invalidnestedcall, "ReadByte", "read")); @@ -747,7 +751,7 @@ namespace System.Net.Security public override int Read(byte[] buffer, int offset, int count) { - CheckThrow(true); + ThrowIfExceptionalOrNotAuthenticated(); ValidateParameters(buffer, offset, count); SslReadSync reader = new SslReadSync(this); return ReadAsyncInternal(reader, new Memory(buffer, offset, count)).GetAwaiter().GetResult(); @@ -757,7 +761,7 @@ namespace System.Net.Security public override void Write(byte[] buffer, int offset, int count) { - CheckThrow(true); + ThrowIfExceptionalOrNotAuthenticated(); ValidateParameters(buffer, offset, count); SslWriteSync writeAdapter = new SslWriteSync(this); @@ -766,45 +770,45 @@ namespace System.Net.Security public override IAsyncResult BeginRead(byte[] buffer, int offset, int count, AsyncCallback asyncCallback, object asyncState) { - CheckThrow(true); + ThrowIfExceptionalOrNotAuthenticated(); return TaskToApm.Begin(ReadAsync(buffer, offset, count, CancellationToken.None), asyncCallback, asyncState); } public override int EndRead(IAsyncResult asyncResult) { - CheckThrow(true); + ThrowIfExceptionalOrNotAuthenticated(); return TaskToApm.End(asyncResult); } public override IAsyncResult BeginWrite(byte[] buffer, int offset, int count, AsyncCallback asyncCallback, object asyncState) { - CheckThrow(true); + ThrowIfExceptionalOrNotAuthenticated(); return TaskToApm.Begin(WriteAsync(buffer, offset, count, CancellationToken.None), asyncCallback, asyncState); } public override void EndWrite(IAsyncResult asyncResult) { - CheckThrow(true); + ThrowIfExceptionalOrNotAuthenticated(); TaskToApm.End(asyncResult); } public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) { - CheckThrow(true); + ThrowIfExceptionalOrNotAuthenticated(); ValidateParameters(buffer, offset, count); return WriteAsync(new ReadOnlyMemory(buffer, offset, count), cancellationToken).AsTask(); } public override ValueTask WriteAsync(ReadOnlyMemory buffer, CancellationToken cancellationToken) { - CheckThrow(true); + ThrowIfExceptionalOrNotAuthenticated(); SslWriteAsync writeAdapter = new SslWriteAsync(this, cancellationToken); return new ValueTask(WriteAsyncInternal(writeAdapter, buffer)); } public override Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) { - CheckThrow(true); + ThrowIfExceptionalOrNotAuthenticated(); ValidateParameters(buffer, offset, count); SslReadAsync read = new SslReadAsync(this, cancellationToken); return ReadAsyncInternal(read, new Memory(buffer, offset, count)).AsTask(); @@ -812,9 +816,71 @@ namespace System.Net.Security public override ValueTask ReadAsync(Memory buffer, CancellationToken cancellationToken = default) { - CheckThrow(true); + ThrowIfExceptionalOrNotAuthenticated(); SslReadAsync read = new SslReadAsync(this, cancellationToken); return ReadAsyncInternal(read, buffer); } + + private void ThrowIfExceptional() + { + ExceptionDispatchInfo e = _exception; + if (e != null) + { + ThrowExceptional(e); + } + + // Local function to make the check method more inline friendly. + static void ThrowExceptional(ExceptionDispatchInfo e) + { + // If the stored exception just indicates disposal, throw a new ODE rather than the stored one, + // so as to not continually build onto the shared exception's stack. + if (ReferenceEquals(e, s_disposedSentinel)) + { + throw new ObjectDisposedException(nameof(SslStream)); + } + + // Throw the stored exception. + e.Throw(); + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private void ThrowIfExceptionalOrNotAuthenticated() + { + ThrowIfExceptional(); + + if (!IsAuthenticated) + { + ThrowNotAuthenticated(); + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private void ThrowIfExceptionalOrNotAuthenticatedOrShutdown() + { + ThrowIfExceptional(); + + if (!IsAuthenticated) + { + ThrowNotAuthenticated(); + } + + if (_shutdown) + { + ThrowAlreadyShutdown(); + } + + // Local function to make the check method more inline friendly. + static void ThrowAlreadyShutdown() + { + throw new InvalidOperationException(SR.net_ssl_io_already_shutdown); + } + } + + // Static non-returning throw method to make the check methods more inline friendly. + private static void ThrowNotAuthenticated() + { + throw new InvalidOperationException(SR.net_auth_noauth); + } } } -- 2.7.4