From 844618c70cb99fdabbd748860f7fbef465884307 Mon Sep 17 00:00:00 2001 From: Tomas Weinfurt Date: Thu, 13 Feb 2020 15:07:49 -0800 Subject: [PATCH] use proper IO in ssl handshake (#32013) * use proper IO in ssl handshake * fix UnitTests/Fakes * feedback from review * rename adapters to SyncSslIOAdapter and AsyncSslIOAdapter * feedback from review --- .../Security/SslStream.Implementation.Adapters.cs | 64 ++--- .../Net/Security/SslStream.Implementation.cs | 290 +++++---------------- .../src/System/Net/Security/SslStream.cs | 10 +- .../FunctionalTests/SslStreamNetworkStreamTest.cs | 17 +- .../Fakes/FakeSslStream.Implementation.cs | 2 +- 5 files changed, 104 insertions(+), 279 deletions(-) diff --git a/src/libraries/System.Net.Security/src/System/Net/Security/SslStream.Implementation.Adapters.cs b/src/libraries/System.Net.Security/src/System/Net/Security/SslStream.Implementation.Adapters.cs index b5d7ecb..acc6bd6 100644 --- a/src/libraries/System.Net.Security/src/System/Net/Security/SslStream.Implementation.Adapters.cs +++ b/src/libraries/System.Net.Security/src/System/Net/Security/SslStream.Implementation.Adapters.cs @@ -9,26 +9,21 @@ namespace System.Net.Security // This contains adapters to allow a single code path for sync/async logic public partial class SslStream { - private interface ISslWriteAdapter - { - Task LockAsync(); - ValueTask WriteAsync(byte[] buffer, int offset, int count); - CancellationToken CancellationToken { get; } - } - - private interface ISslReadAdapter + private interface ISslIOAdapter { ValueTask ReadAsync(byte[] buffer, int offset, int count); - ValueTask LockAsync(Memory buffer); + ValueTask ReadLockAsync(Memory buffer); + Task WriteLockAsync(); + ValueTask WriteAsync(byte[] buffer, int offset, int count); CancellationToken CancellationToken { get; } } - private readonly struct SslReadAsync : ISslReadAdapter + private readonly struct AsyncSslIOAdapter : ISslIOAdapter { private readonly SslStream _sslStream; private readonly CancellationToken _cancellationToken; - public SslReadAsync(SslStream sslStream, CancellationToken cancellationToken) + public AsyncSslIOAdapter(SslStream sslStream, CancellationToken cancellationToken) { _cancellationToken = cancellationToken; _sslStream = sslStream; @@ -36,60 +31,37 @@ namespace System.Net.Security public ValueTask ReadAsync(byte[] buffer, int offset, int count) => _sslStream.InnerStream.ReadAsync(new Memory(buffer, offset, count), _cancellationToken); - public ValueTask LockAsync(Memory buffer) => _sslStream.CheckEnqueueReadAsync(buffer); + public ValueTask ReadLockAsync(Memory buffer) => _sslStream.CheckEnqueueReadAsync(buffer); + + public Task WriteLockAsync() => _sslStream.CheckEnqueueWriteAsync(); + + public ValueTask WriteAsync(byte[] buffer, int offset, int count) => _sslStream.InnerStream.WriteAsync(new ReadOnlyMemory(buffer, offset, count), _cancellationToken); public CancellationToken CancellationToken => _cancellationToken; } - private readonly struct SslReadSync : ISslReadAdapter + private readonly struct SyncSslIOAdapter : ISslIOAdapter { private readonly SslStream _sslStream; - public SslReadSync(SslStream sslStream) => _sslStream = sslStream; + public SyncSslIOAdapter(SslStream sslStream) => _sslStream = sslStream; public ValueTask ReadAsync(byte[] buffer, int offset, int count) => new ValueTask(_sslStream.InnerStream.Read(buffer, offset, count)); - public ValueTask LockAsync(Memory buffer) => new ValueTask(_sslStream.CheckEnqueueRead(buffer)); + public ValueTask ReadLockAsync(Memory buffer) => new ValueTask(_sslStream.CheckEnqueueRead(buffer)); - public CancellationToken CancellationToken => default; - } - - private readonly struct SslWriteAsync : ISslWriteAdapter - { - private readonly SslStream _sslStream; - private readonly CancellationToken _cancellationToken; - - public SslWriteAsync(SslStream sslStream, CancellationToken cancellationToken) + public ValueTask WriteAsync(byte[] buffer, int offset, int count) { - _sslStream = sslStream; - _cancellationToken = cancellationToken; + _sslStream.InnerStream.Write(buffer, offset, count); + return default; } - public Task LockAsync() => _sslStream.CheckEnqueueWriteAsync(); - - public ValueTask WriteAsync(byte[] buffer, int offset, int count) => _sslStream.InnerStream.WriteAsync(new ReadOnlyMemory(buffer, offset, count), _cancellationToken); - - public CancellationToken CancellationToken => _cancellationToken; - } - - private readonly struct SslWriteSync : ISslWriteAdapter - { - private readonly SslStream _sslStream; - - public SslWriteSync(SslStream sslStream) => _sslStream = sslStream; - - public Task LockAsync() + public Task WriteLockAsync() { _sslStream.CheckEnqueueWrite(); return Task.CompletedTask; } - public ValueTask WriteAsync(byte[] buffer, int offset, int count) - { - _sslStream.InnerStream.Write(buffer, offset, count); - return default; - } - public CancellationToken CancellationToken => default; } } diff --git a/src/libraries/System.Net.Security/src/System/Net/Security/SslStream.Implementation.cs b/src/libraries/System.Net.Security/src/System/Net/Security/SslStream.Implementation.cs index 97ed6dc..34e6542 100644 --- a/src/libraries/System.Net.Security/src/System/Net/Security/SslStream.Implementation.cs +++ b/src/libraries/System.Net.Security/src/System/Net/Security/SslStream.Implementation.cs @@ -22,17 +22,6 @@ namespace System.Net.Security private int _nestedAuth; - private SecurityStatusPal _securityStatus; - - private enum CachedSessionStatus : byte - { - Unknown = 0, - IsNotCached = 1, - IsCached = 2, - Renegotiated = 3 - } - private CachedSessionStatus _CachedSession; - private enum Framing { Unknown = 0, @@ -205,41 +194,16 @@ namespace System.Net.Security private Task ProcessAuthentication(bool isAsync = false, bool isApm = false, CancellationToken cancellationToken = default) { Task result = null; - if (Interlocked.Exchange(ref _nestedAuth, 1) == 1) - { - throw new InvalidOperationException(SR.Format(SR.net_io_invalidnestedcall, isApm ? "BeginAuthenticate" : "Authenticate", "authenticate")); - } - - try - { - ThrowIfExceptional(); - // A trick to discover and avoid cached sessions. - _CachedSession = CachedSessionStatus.Unknown; - - if (isAsync) - { - result = ForceAuthenticationAsync(_context.IsServer, null, cancellationToken); - } - else - { - ForceAuthentication(_context.IsServer, null); + ThrowIfExceptional(); - if (NetEventSource.IsEnabled) - NetEventSource.Log.SspiSelectedCipherSuite(nameof(ProcessAuthentication), - SslProtocol, - CipherAlgorithm, - CipherStrength, - HashAlgorithm, - HashStrength, - KeyExchangeAlgorithm, - KeyExchangeStrength); - } + if (isAsync) + { + result = ForceAuthenticationAsync(new AsyncSslIOAdapter(this, cancellationToken), _context.IsServer, null, isApm); } - finally + else { - // Operation has completed. - _nestedAuth = 0; + ForceAuthenticationAsync(new SyncSslIOAdapter(this), _context.IsServer, null).GetAwaiter().GetResult(); } return result; @@ -248,7 +212,8 @@ namespace System.Net.Security // // This is used to reply on re-handshake when received SEC_I_RENEGOTIATE on Read(). // - private async Task ReplyOnReAuthenticationAsync(byte[] buffer, CancellationToken cancellationToken) + private async Task ReplyOnReAuthenticationAsync(TIOAdapter adapter, byte[] buffer) + where TIOAdapter : ISslIOAdapter { lock (SyncLock) { @@ -256,93 +221,71 @@ namespace System.Net.Security _lockReadState = LockHandshake; } - await ForceAuthenticationAsync(false, buffer, cancellationToken).ConfigureAwait(false); + await ForceAuthenticationAsync(adapter, receiveFirst: false, buffer).ConfigureAwait(false); FinishHandshakeRead(LockNone); } - // - // This method attempts to start authentication. - // Incoming buffer is either null or is the result of "renegotiate" decrypted message - // If write is in progress the method will either wait or be put on hold - // - private void ForceAuthentication(bool receiveFirst, byte[] buffer) + // reAuthenticationData is only used on Windows in case of renegotiation. + private async Task ForceAuthenticationAsync(TIOAdapter adapter, bool receiveFirst, byte[] reAuthenticationData, bool isApm = false) + where TIOAdapter : ISslIOAdapter { - // This will tell that we don't know the framing yet (what SSL version is) _framing = Framing.Unknown; + ProtocolToken message; - try + if (reAuthenticationData == null) { - if (receiveFirst) - { - // Listen for a client blob. - ReceiveBlob(buffer); - } - else + // prevent nesting only when authentication functions are called explicitly. e.g. handle renegotiation tansparently. + if (Interlocked.Exchange(ref _nestedAuth, 1) == 1) { - // We start with the first blob. - SendBlob(buffer, (buffer == null ? 0 : buffer.Length)); + throw new InvalidOperationException(SR.Format(SR.net_io_invalidnestedcall, isApm ? "BeginAuthenticate" : "Authenticate", "authenticate")); } } - catch (Exception e) + + try { - // Failed auth, reset the framing if any. - _framing = Framing.Unknown; - _handshakeCompleted = false; - SetException(e); - if (_exception.SourceException != e) + if (!receiveFirst) { - ThrowIfExceptional(); + message = _context.NextMessage(reAuthenticationData, 0, (reAuthenticationData == null ? 0 : reAuthenticationData.Length)); + if (message.Size > 0) + { + await adapter.WriteAsync(message.Payload, 0, message.Size).ConfigureAwait(false); + } + + if (message.Failed) + { + // tracing done in NextMessage() + throw new AuthenticationException(SR.net_auth_SSPI, message.GetException()); + } } - throw; - } - finally - { - if (_exception != null) + + do { - // This a failed handshake. Release waiting IO if any. - FinishHandshake(null); - } - } - } + message = await ReceiveBlobAsync(adapter).ConfigureAwait(false); + if (message.Size > 0) + { + // If there is message send it out even if call failed. It may contain TLS Alert. + await adapter.WriteAsync(message.Payload, 0, message.Size).ConfigureAwait(false); + } - internal async Task ForceAuthenticationAsync(bool receiveFirst, byte[] buffer, CancellationToken cancellationToken) - { - _framing = Framing.Unknown; - ProtocolToken message; - SslReadAsync adapter = new SslReadAsync(this, cancellationToken); + if (message.Failed) + { + throw new AuthenticationException(SR.net_auth_SSPI, message.GetException()); + } + } while (message.Status.ErrorCode != SecurityStatusPalErrorCode.OK); - if (!receiveFirst) - { - message = _context.NextMessage(buffer, 0, (buffer == null ? 0 : buffer.Length)); - if (message.Failed) + ProtocolToken alertToken = null; + if (!CompleteHandshake(ref alertToken)) { - // tracing done in NextMessage() - throw new AuthenticationException(SR.net_auth_SSPI, message.GetException()); + SendAuthResetSignal(alertToken, ExceptionDispatchInfo.Capture(new AuthenticationException(SR.net_ssl_io_cert_validation, null))); } - - await InnerStream.WriteAsync(message.Payload, cancellationToken).ConfigureAwait(false); } - - do + finally { - message = await ReceiveBlobAsync(adapter, buffer, cancellationToken).ConfigureAwait(false); - if (message.Size > 0) - { - // If there is message send it out even if call failed. It may contain TLS Alert. - await InnerStream.WriteAsync(message.Payload, cancellationToken).ConfigureAwait(false); - } - - if (message.Failed) + if (reAuthenticationData == null) { - throw new AuthenticationException(SR.net_auth_SSPI, message.GetException()); + _nestedAuth = 0; } - } while (message.Status.ErrorCode != SecurityStatusPalErrorCode.OK); - - ProtocolToken alertToken = null; - if (!CompleteHandshake(ref alertToken)) - { - SendAuthResetSignal(alertToken, ExceptionDispatchInfo.Capture(new AuthenticationException(SR.net_ssl_io_cert_validation, null))); } if (NetEventSource.IsEnabled) @@ -357,107 +300,8 @@ namespace System.Net.Security } - // - // Client side starts here, but server also loops through this method. - // - private void SendBlob(byte[] incoming, int count) - { - ProtocolToken message = _context.NextMessage(incoming, 0, count); - _securityStatus = message.Status; - - if (message.Size != 0) - { - if (_context.IsServer && _CachedSession == CachedSessionStatus.Unknown) - { - // - //[Schannel] If the first call to ASC returns a token less than 200 bytes, - // then it's a reconnect (a handshake based on a cache entry). - // - _CachedSession = message.Size < 200 ? CachedSessionStatus.IsCached : CachedSessionStatus.IsNotCached; - } - - if (_framing == Framing.Unified) - { - _framing = DetectFraming(message.Payload, message.Payload.Length); - } - - InnerStream.Write(message.Payload, 0, message.Size); - } - - CheckCompletionBeforeNextReceive(message); - } - - // - // This will check and logically complete / fail the auth handshake. - // - private void CheckCompletionBeforeNextReceive(ProtocolToken message) - { - if (message.Failed) - { - SendAuthResetSignal(null, ExceptionDispatchInfo.Capture(new AuthenticationException(SR.net_auth_SSPI, message.GetException()))); - return; - } - else if (message.Done) - { - ProtocolToken alertToken = null; - - if (!CompleteHandshake(ref alertToken)) - { - SendAuthResetSignal(alertToken, ExceptionDispatchInfo.Capture(new AuthenticationException(SR.net_ssl_io_cert_validation, null))); - return; - } - - // Release waiting IO if any. Presumably it should not throw. - // Otherwise application may get not expected type of the exception. - FinishHandshake(null); - return; - } - - ReceiveBlob(message.Payload); - } - - // - // Server side starts here, but client also loops through this method. - // - private void ReceiveBlob(byte[] buffer) - { - //This is first server read. - buffer = EnsureBufferSize(buffer, 0, SecureChannel.ReadHeaderSize); - - int readBytes = FixedSizeReader.ReadPacket(_innerStream, buffer, 0, SecureChannel.ReadHeaderSize); - - if (readBytes == 0) - { - // EOF received - throw new IOException(SR.net_auth_eof); - } - - if (_framing == Framing.Unknown) - { - _framing = DetectFraming(buffer, readBytes); - } - - int restBytes = GetRemainingFrameSize(buffer, 0, readBytes); - - if (restBytes < 0) - { - throw new IOException(SR.net_ssl_io_frame); - } - - if (restBytes == 0) - { - // EOF received - throw new AuthenticationException(SR.net_auth_eof, null); - } - - buffer = EnsureBufferSize(buffer, readBytes, readBytes + restBytes); - - restBytes = FixedSizeReader.ReadPacket(_innerStream, buffer, readBytes, restBytes); - - SendBlob(buffer, readBytes + restBytes); - } - - private async ValueTask ReceiveBlobAsync(SslReadAsync adapter, byte[] buffer, CancellationToken cancellationToken) + private async ValueTask ReceiveBlobAsync(TIOAdapter adapter) + where TIOAdapter : ISslIOAdapter { ResetReadBuffer(); int readBytes = await FillBufferAsync(adapter, SecureChannel.ReadHeaderSize).ConfigureAwait(false); @@ -710,8 +554,8 @@ namespace System.Net.Security } } - private async ValueTask WriteAsyncChunked(TWriteAdapter writeAdapter, ReadOnlyMemory buffer) - where TWriteAdapter : struct, ISslWriteAdapter + private async ValueTask WriteAsyncChunked(TIOAdapter writeAdapter, ReadOnlyMemory buffer) + where TIOAdapter : struct, ISslIOAdapter { do { @@ -721,11 +565,11 @@ namespace System.Net.Security } while (buffer.Length != 0); } - private ValueTask WriteSingleChunk(TWriteAdapter writeAdapter, ReadOnlyMemory buffer) - where TWriteAdapter : struct, ISslWriteAdapter + private ValueTask WriteSingleChunk(TIOAdapter writeAdapter, ReadOnlyMemory buffer) + where TIOAdapter : struct, ISslIOAdapter { // Request a write IO slot. - Task ioSlot = writeAdapter.LockAsync(); + Task ioSlot = writeAdapter.WriteLockAsync(); if (!ioSlot.IsCompletedSuccessfully) { // Operation is async and has been queued, return. @@ -757,7 +601,7 @@ namespace System.Net.Security return CompleteAsync(t, rentedBuffer); } - async ValueTask WaitForWriteIOSlot(TWriteAdapter wAdapter, Task lockTask, ReadOnlyMemory buff) + async ValueTask WaitForWriteIOSlot(TIOAdapter wAdapter, Task lockTask, ReadOnlyMemory buff) { await lockTask.ConfigureAwait(false); await WriteSingleChunk(wAdapter, buff).ConfigureAwait(false); @@ -823,8 +667,8 @@ namespace System.Net.Security } } - private async ValueTask ReadAsyncInternal(TReadAdapter adapter, Memory buffer) - where TReadAdapter : ISslReadAdapter + private async ValueTask ReadAsyncInternal(TIOAdapter adapter, Memory buffer) + where TIOAdapter : ISslIOAdapter { if (Interlocked.Exchange(ref _nestedRead, 1) == 1) { @@ -843,7 +687,7 @@ namespace System.Net.Security return copyBytes; } - copyBytes = await adapter.LockAsync(buffer).ConfigureAwait(false); + copyBytes = await adapter.ReadLockAsync(buffer).ConfigureAwait(false); if (copyBytes > 0) { return copyBytes; @@ -903,7 +747,7 @@ namespace System.Net.Security throw new IOException(SR.net_ssl_io_renego); } - await ReplyOnReAuthenticationAsync(extraBuffer, adapter.CancellationToken).ConfigureAwait(false); + await ReplyOnReAuthenticationAsync(adapter, extraBuffer).ConfigureAwait(false); // Loop on read. continue; } @@ -932,8 +776,8 @@ namespace System.Net.Security } } - private ValueTask FillBufferAsync(TReadAdapter adapter, int minSize) - where TReadAdapter : ISslReadAdapter + private ValueTask FillBufferAsync(TIOAdapter adapter, int minSize) + where TIOAdapter : ISslIOAdapter { if (_internalBufferCount >= minSize) { @@ -965,7 +809,7 @@ namespace System.Net.Security return new ValueTask(minSize); - async ValueTask InternalFillBufferAsync(TReadAdapter adap, ValueTask task, int min, int initial) + async ValueTask InternalFillBufferAsync(TIOAdapter adap, ValueTask task, int min, int initial) { while (true) { @@ -991,8 +835,8 @@ namespace System.Net.Security } } - private async ValueTask WriteAsyncInternal(TWriteAdapter writeAdapter, ReadOnlyMemory buffer) - where TWriteAdapter : struct, ISslWriteAdapter + private async ValueTask WriteAsyncInternal(TIOAdapter writeAdapter, ReadOnlyMemory buffer) + where TIOAdapter : struct, ISslIOAdapter { ThrowIfExceptionalOrNotAuthenticatedOrShutdown(); diff --git a/src/libraries/System.Net.Security/src/System/Net/Security/SslStream.cs b/src/libraries/System.Net.Security/src/System/Net/Security/SslStream.cs index 8dc2e8b..d848210 100644 --- a/src/libraries/System.Net.Security/src/System/Net/Security/SslStream.cs +++ b/src/libraries/System.Net.Security/src/System/Net/Security/SslStream.cs @@ -741,7 +741,7 @@ namespace System.Net.Security { ThrowIfExceptionalOrNotAuthenticated(); ValidateParameters(buffer, offset, count); - SslReadSync reader = new SslReadSync(this); + SyncSslIOAdapter reader = new SyncSslIOAdapter(this); ValueTask vt = ReadAsyncInternal(reader, new Memory(buffer, offset, count)); Debug.Assert(vt.IsCompleted, "Sync operation must have completed synchronously"); return vt.GetAwaiter().GetResult(); @@ -754,7 +754,7 @@ namespace System.Net.Security ThrowIfExceptionalOrNotAuthenticated(); ValidateParameters(buffer, offset, count); - SslWriteSync writeAdapter = new SslWriteSync(this); + SyncSslIOAdapter writeAdapter = new SyncSslIOAdapter(this); ValueTask vt = WriteAsyncInternal(writeAdapter, new ReadOnlyMemory(buffer, offset, count)); Debug.Assert(vt.IsCompleted, "Sync operation must have completed synchronously"); vt.GetAwaiter().GetResult(); @@ -794,7 +794,7 @@ namespace System.Net.Security public override ValueTask WriteAsync(ReadOnlyMemory buffer, CancellationToken cancellationToken) { ThrowIfExceptionalOrNotAuthenticated(); - SslWriteAsync writeAdapter = new SslWriteAsync(this, cancellationToken); + AsyncSslIOAdapter writeAdapter = new AsyncSslIOAdapter(this, cancellationToken); return WriteAsyncInternal(writeAdapter, buffer); } @@ -802,14 +802,14 @@ namespace System.Net.Security { ThrowIfExceptionalOrNotAuthenticated(); ValidateParameters(buffer, offset, count); - SslReadAsync read = new SslReadAsync(this, cancellationToken); + AsyncSslIOAdapter read = new AsyncSslIOAdapter(this, cancellationToken); return ReadAsyncInternal(read, new Memory(buffer, offset, count)).AsTask(); } public override ValueTask ReadAsync(Memory buffer, CancellationToken cancellationToken = default) { ThrowIfExceptionalOrNotAuthenticated(); - SslReadAsync read = new SslReadAsync(this, cancellationToken); + AsyncSslIOAdapter read = new AsyncSslIOAdapter(this, cancellationToken); return ReadAsyncInternal(read, buffer); } diff --git a/src/libraries/System.Net.Security/tests/FunctionalTests/SslStreamNetworkStreamTest.cs b/src/libraries/System.Net.Security/tests/FunctionalTests/SslStreamNetworkStreamTest.cs index cf47a4c..8aaf543 100644 --- a/src/libraries/System.Net.Security/tests/FunctionalTests/SslStreamNetworkStreamTest.cs +++ b/src/libraries/System.Net.Security/tests/FunctionalTests/SslStreamNetworkStreamTest.cs @@ -129,9 +129,11 @@ namespace System.Net.Security.Tests listener.Stop(); } - [Fact] + [Theory] + [InlineData(false)] + [InlineData(true)] [OuterLoop] // Test hits external azure server. - public async Task SslStream_NetworkStream_Renegotiation_Succeeds() + public async Task SslStream_NetworkStream_Renegotiation_Succeeds(bool useSync) { int validationCount = 0; @@ -156,10 +158,17 @@ namespace System.Net.Security.Tests // Issue request that triggers regotiation from server. byte[] message = Encoding.UTF8.GetBytes("GET /EchoClientCertificate.ashx HTTP/1.1\r\nHost: corefx-net-tls.azurewebsites.net\r\n\r\n"); - await ssl.WriteAsync(message, 0, message.Length); + if (useSync) + { + ssl.Write(message, 0, message.Length); + } + else + { + await ssl.WriteAsync(message, 0, message.Length); + } // Initiate Read operation, that results in starting renegotiation as per server response to the above request. - int bytesRead = await ssl.ReadAsync(message, 0, message.Length); + int bytesRead = useSync ? ssl.Read(message, 0, message.Length) : await ssl.ReadAsync(message, 0, message.Length); // renegotiation will trigger validation callback again. Assert.InRange(validationCount, 2, int.MaxValue); diff --git a/src/libraries/System.Net.Security/tests/UnitTests/Fakes/FakeSslStream.Implementation.cs b/src/libraries/System.Net.Security/tests/UnitTests/Fakes/FakeSslStream.Implementation.cs index 4cd25bd..85eb337 100644 --- a/src/libraries/System.Net.Security/tests/UnitTests/Fakes/FakeSslStream.Implementation.cs +++ b/src/libraries/System.Net.Security/tests/UnitTests/Fakes/FakeSslStream.Implementation.cs @@ -38,7 +38,7 @@ namespace System.Net.Security } private ValueTask WriteAsyncInternal(TWriteAdapter writeAdapter, ReadOnlyMemory buffer) - where TWriteAdapter : struct, ISslWriteAdapter => default; + where TWriteAdapter : struct, ISslIOAdapter => default; private ValueTask ReadAsyncInternal(TReadAdapter adapter, Memory buffer) => default; -- 2.7.4