add zero byte read to SslStream (#87563)
authorTomas Weinfurt <tweinfurt@yahoo.com>
Mon, 19 Jun 2023 18:37:24 +0000 (20:37 +0200)
committerGitHub <noreply@github.com>
Mon, 19 Jun 2023 18:37:24 +0000 (20:37 +0200)
* add zero byte read to SslStream

* fix test

* Apply suggestions from code review

Co-authored-by: Stephen Toub <stoub@microsoft.com>
* feedback

* add back missing line

---------

Co-authored-by: Stephen Toub <stoub@microsoft.com>
src/libraries/Common/tests/StreamConformanceTests/System/IO/StreamConformanceTests.cs
src/libraries/System.Net.Http/tests/FunctionalTests/ResponseStreamZeroByteReadTests.cs
src/libraries/System.Net.Security/src/System/Net/Security/SslStream.IO.cs
src/libraries/System.Net.Security/tests/FunctionalTests/SslStreamConformanceTests.cs

index 5a2c897..bdc4027 100644 (file)
@@ -2739,6 +2739,8 @@ namespace System.IO.Tests
         /// </summary>
         protected virtual bool ZeroByteReadPerformsZeroByteReadOnUnderlyingStream => false;
 
+        protected virtual bool ExtraZeroByteReadsAllowed => false;
+
         [Theory]
         [InlineData(false)]
         [InlineData(true)]
@@ -2938,7 +2940,7 @@ namespace System.IO.Tests
             using StreamPair innerStreams = ConnectedStreams.CreateBidirectional();
             (Stream innerWriteable, Stream innerReadable) = GetReadWritePair(innerStreams);
 
-            var tracker = new ZeroByteReadTrackingStream(innerReadable);
+            var tracker = new ZeroByteReadTrackingStream(innerReadable, ExtraZeroByteReadsAllowed);
             using StreamPair streams = await CreateWrappedConnectedStreamsAsync((innerWriteable, tracker));
 
             (Stream writeable, Stream readable) = GetReadWritePair(streams);
@@ -2993,9 +2995,11 @@ namespace System.IO.Tests
         private sealed class ZeroByteReadTrackingStream : DelegatingStream
         {
             private TaskCompletionSource? _signal;
+            private bool _extraZeroByteReadsAllowed;
 
-            public ZeroByteReadTrackingStream(Stream innerStream) : base(innerStream)
+            public ZeroByteReadTrackingStream(Stream innerStream, bool extraZeroByteReadsAllowed = false) : base(innerStream)
             {
+                _extraZeroByteReadsAllowed = extraZeroByteReadsAllowed;
             }
 
             public Task WaitForZeroByteReadAsync()
@@ -3014,13 +3018,13 @@ namespace System.IO.Tests
                 if (bufferLength == 0)
                 {
                     var signal = _signal;
-                    if (signal is null)
+                    if (signal is null && !_extraZeroByteReadsAllowed)
                     {
                         throw new Exception("Unexpected zero byte read");
                     }
 
                     _signal = null;
-                    signal.SetResult();
+                    signal?.SetResult();
                 }
             }
 
index a5312b1..74ab208 100644 (file)
@@ -123,7 +123,12 @@ namespace System.Net.Http.Functional.Tests
 
                 using HttpResponseMessage response = await clientTask.WaitAsync(TestHelper.PassingTestTimeout);
                 using Stream clientStream = response.Content.ReadAsStream();
-                Assert.False(sawZeroByteRead.Task.IsCompleted);
+
+                if (!useSsl)
+                {
+                    // SslStream does zero byte reads under the covers
+                    Assert.False(sawZeroByteRead.Task.IsCompleted);
+                }
 
                 Task<int> zeroByteReadTask = Task.Run(() => StreamConformanceTests.ReadAsync(readMode, clientStream, Array.Empty<byte>(), 0, 0, CancellationToken.None));
                 Assert.False(zeroByteReadTask.IsCompleted);
index 7ea62aa..d7438c0 100644 (file)
@@ -27,6 +27,8 @@ namespace System.Net.Security
         private const int HandshakeTypeOffsetSsl2 = 2;                       // Offset of HelloType in Sslv2 and Unified frames
         private const int HandshakeTypeOffsetTls = 5;                        // Offset of HelloType in Sslv3 and TLS frames
 
+        private const int UnknownTlsFrameLength = int.MaxValue;              // frame too short to determine length
+
         private bool _receivedEOF;
 
         // Used by Telemetry to ensure we log connection close exactly once
@@ -211,12 +213,10 @@ namespace System.Net.Security
                     throw SslStreamPal.GetException(status);
                 }
 
-                _buffer.EnsureAvailableSpace(InitialHandshakeBufferSize);
-
                 ProtocolToken message;
                 do
                 {
-                    int frameSize = await ReceiveTlsFrameAsync<TIOAdapter>(cancellationToken).ConfigureAwait(false);
+                    int frameSize = await ReceiveHandshakeFrameAsync<TIOAdapter>(cancellationToken).ConfigureAwait(false);
                     ProcessTlsFrame(frameSize, out message);
 
                     if (message.Size > 0)
@@ -291,7 +291,7 @@ namespace System.Net.Security
 
                 while (!handshakeCompleted)
                 {
-                    int frameSize = await ReceiveTlsFrameAsync<TIOAdapter>(cancellationToken).ConfigureAwait(false);
+                    int frameSize = await ReceiveHandshakeFrameAsync<TIOAdapter>(cancellationToken).ConfigureAwait(false);
                     ProcessTlsFrame(frameSize, out message);
 
                     ReadOnlyMemory<byte> payload = default;
@@ -359,10 +359,10 @@ namespace System.Net.Security
         }
 
         // This method will make sure we have at least one full TLS frame buffered.
-        private async ValueTask<int> ReceiveTlsFrameAsync<TIOAdapter>(CancellationToken cancellationToken)
+        private async ValueTask<int> ReceiveHandshakeFrameAsync<TIOAdapter>(CancellationToken cancellationToken)
             where TIOAdapter : IReadWriteAdapter
         {
-            int frameSize = await EnsureFullTlsFrameAsync<TIOAdapter>(cancellationToken).ConfigureAwait(false);
+            int frameSize = await EnsureFullTlsFrameAsync<TIOAdapter>(cancellationToken, InitialHandshakeBufferSize).ConfigureAwait(false);
 
             if (frameSize == 0)
             {
@@ -699,38 +699,27 @@ namespace System.Net.Security
 
         private bool HaveFullTlsFrame(out int frameSize)
         {
-            if (_buffer.EncryptedLength < TlsFrameHelper.HeaderSize)
-            {
-                frameSize = int.MaxValue;
-                return false;
-            }
-
             frameSize = GetFrameSize(_buffer.EncryptedReadOnlySpan);
             return _buffer.EncryptedLength >= frameSize;
         }
 
         [AsyncMethodBuilder(typeof(PoolingAsyncValueTaskMethodBuilder<>))]
-        private async ValueTask<int> EnsureFullTlsFrameAsync<TIOAdapter>(CancellationToken cancellationToken)
+        private async ValueTask<int> EnsureFullTlsFrameAsync<TIOAdapter>(CancellationToken cancellationToken, int estimatedSize)
             where TIOAdapter : IReadWriteAdapter
         {
-            int frameSize;
-            if (HaveFullTlsFrame(out frameSize))
+            if (HaveFullTlsFrame(out int frameSize))
             {
                 return frameSize;
             }
 
-            if (frameSize != int.MaxValue)
-            {
-                // make sure we have space for the whole frame
-                _buffer.EnsureAvailableSpace(frameSize - _buffer.EncryptedLength);
-            }
-            else
-            {
-                // move existing data to the beginning of the buffer (they will
-                // be couple of bytes only, otherwise we would have entire
-                // header and know exact size)
-                _buffer.EnsureAvailableSpace(_buffer.Capacity - _buffer.EncryptedLength);
-            }
+            await TIOAdapter.ReadAsync(InnerStream, Memory<byte>.Empty, cancellationToken).ConfigureAwait(false);
+
+            // If we don't have enough data to determine the frame size, use the provided estimate
+            // (e.g. a full TLS frame for reads, and a somewhat shorter frame for handshake / renegotiation).
+            // If we do know the frame size, ensure we have space for the whole frame.
+            _buffer.EnsureAvailableSpace(frameSize == UnknownTlsFrameLength ?
+                estimatedSize :
+                frameSize - _buffer.EncryptedLength);
 
             while (_buffer.EncryptedLength < frameSize)
             {
@@ -806,6 +795,7 @@ namespace System.Net.Security
         private async ValueTask<int> ReadAsyncInternal<TIOAdapter>(Memory<byte> buffer, CancellationToken cancellationToken)
             where TIOAdapter : IReadWriteAdapter
         {
+
             // Throw first if we already have exception.
             // Check for disposal is not atomic so we will check again below.
             ThrowIfExceptionalOrNotAuthenticated();
@@ -819,11 +809,12 @@ namespace System.Net.Security
             try
             {
                 int processedLength = 0;
+                int nextTlsFrameLength = UnknownTlsFrameLength;
 
                 if (_buffer.DecryptedLength != 0)
                 {
                     processedLength = CopyDecryptedData(buffer);
-                    if (processedLength == buffer.Length || !HaveFullTlsFrame(out _))
+                    if (processedLength == buffer.Length || !HaveFullTlsFrame(out nextTlsFrameLength))
                     {
                         // We either filled whole buffer or used all buffered frames.
                         return processedLength;
@@ -832,32 +823,19 @@ namespace System.Net.Security
                     buffer = buffer.Slice(processedLength);
                 }
 
-                if (_receivedEOF)
+                if (_receivedEOF && nextTlsFrameLength == UnknownTlsFrameLength)
                 {
+                    // there should be no frames waiting for processing
                     Debug.Assert(_buffer.EncryptedLength == 0);
                     // We received EOF during previous read but had buffered data to return.
                     return 0;
                 }
 
-                if (buffer.Length == 0 && _buffer.ActiveLength == 0)
-                {
-                    // User requested a zero-byte read, and we have no data available in the buffer for processing.
-                    // This zero-byte read indicates their desire to trade off the extra cost of a zero-byte read
-                    // for reduced memory consumption when data is not immediately available.
-                    // So, we will issue our own zero-byte read against the underlying stream and defer buffer allocation
-                    // until data is actually available from the underlying stream.
-                    // Note that if the underlying stream does not supporting blocking on zero byte reads, then this will
-                    // complete immediately and won't save any memory, but will still function correctly.
-                    await TIOAdapter.ReadAsync(InnerStream, Memory<byte>.Empty, cancellationToken).ConfigureAwait(false);
-                }
-
                 Debug.Assert(_buffer.DecryptedLength == 0);
 
-                _buffer.EnsureAvailableSpace(ReadBufferSize - _buffer.ActiveLength);
-
                 while (true)
                 {
-                    int payloadBytes = await EnsureFullTlsFrameAsync<TIOAdapter>(cancellationToken).ConfigureAwait(false);
+                    int payloadBytes = await EnsureFullTlsFrameAsync<TIOAdapter>(cancellationToken, ReadBufferSize).ConfigureAwait(false);
                     if (payloadBytes == 0)
                     {
                         _receivedEOF = true;
@@ -1009,6 +987,11 @@ namespace System.Net.Security
         // Returns TLS Frame size including header size.
         private int GetFrameSize(ReadOnlySpan<byte> buffer)
         {
+            if (buffer.Length < TlsFrameHelper.HeaderSize)
+            {
+                return UnknownTlsFrameLength;
+            }
+
             if (!TlsFrameHelper.TryGetFrameHeader(buffer, ref _lastFrame.Header))
             {
                 throw new IOException(SR.net_ssl_io_frame);
index 6b47328..c89c480 100644 (file)
@@ -16,6 +16,7 @@ namespace System.Net.Security.Tests
         protected override bool BlocksOnZeroByteReads => true;
         protected override bool ZeroByteReadPerformsZeroByteReadOnUnderlyingStream => true;
         protected override Type UnsupportedConcurrentExceptionType => typeof(NotSupportedException);
+        protected override bool ExtraZeroByteReadsAllowed => true;
 
         protected virtual SslProtocols GetSslProtocols() => SslProtocols.None;