Support zero-byte reads on HTTP response streams (#61913)
authorMiha Zupan <mihazupan.zupan1@gmail.com>
Mon, 29 Nov 2021 23:11:38 +0000 (15:11 -0800)
committerGitHub <noreply@github.com>
Mon, 29 Nov 2021 23:11:38 +0000 (15:11 -0800)
* Allow zero byte reads on raw HTTP/1.1 response streams

* Allow zero byte reads on the rest of HTTP/1.1 response streams

* Allow zero byte reads on HTTP/2 response streams

* Allow zero byte reads on HTTP/3 response streams

* Enable sync zero-byte reads

* Add zero-byte read tests

* Fully enable zero-byte reads on HTTP/2 and 3

* Add zero-byte read tests for HTTP/2 and HTTP/3

* Remove unsafe-ish code from PeekChunkFromConnectionBuffer

* Add comments when we do extra zero-byte reads

* Update MockQuicStreamConformanceTests to allow zero-byte reads

* Update ConnectedStream tests to allow zero-byte reads

* Skip zero-byte read tests on Browser

* Update comment on explicit zero-byte reads in ChunkedEncodingReadStream

14 files changed:
src/libraries/Common/src/System/Net/StreamBuffer.cs
src/libraries/Common/tests/StreamConformanceTests/System/IO/StreamConformanceTests.cs
src/libraries/Common/tests/Tests/System/IO/ConnectedStreamsTests.cs
src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/ChunkedEncodingReadStream.cs
src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/ConnectionCloseReadStream.cs
src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/ContentLengthReadStream.cs
src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http2Stream.cs
src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http3RequestStream.cs
src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpConnection.cs
src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/RawConnectionStream.cs
src/libraries/System.Net.Http/tests/FunctionalTests/ResponseStreamConformanceTests.cs
src/libraries/System.Net.Http/tests/FunctionalTests/ResponseStreamZeroByteReadTests.cs [new file with mode: 0644]
src/libraries/System.Net.Http/tests/FunctionalTests/System.Net.Http.Functional.Tests.csproj
src/libraries/System.Net.Quic/tests/FunctionalTests/QuicStreamConnectedStreamConformanceTests.cs

index 6759fcd..32bc0f3 100644 (file)
@@ -192,8 +192,6 @@ namespace System.IO
 
         private (bool wait, int bytesRead) TryReadFromBuffer(Span<byte> buffer)
         {
-            Debug.Assert(buffer.Length > 0);
-
             Debug.Assert(!Monitor.IsEntered(SyncObject));
             lock (SyncObject)
             {
@@ -225,11 +223,6 @@ namespace System.IO
 
         public int Read(Span<byte> buffer)
         {
-            if (buffer.Length == 0)
-            {
-                return 0;
-            }
-
             (bool wait, int bytesRead) = TryReadFromBuffer(buffer);
             if (wait)
             {
@@ -246,11 +239,6 @@ namespace System.IO
         {
             cancellationToken.ThrowIfCancellationRequested();
 
-            if (buffer.Length == 0)
-            {
-                return 0;
-            }
-
             (bool wait, int bytesRead) = TryReadFromBuffer(buffer.Span);
             if (wait)
             {
index 43451d4..7dc9519 100644 (file)
@@ -114,7 +114,7 @@ namespace System.IO.Tests
             from mode in Enum.GetValues<SeekMode>()
             select new object[] { mode, value };
 
-        protected async Task<int> ReadAsync(ReadWriteMode mode, Stream stream, byte[] buffer, int offset, int count, CancellationToken cancellationToken = default)
+        public static async Task<int> ReadAsync(ReadWriteMode mode, Stream stream, byte[] buffer, int offset, int count, CancellationToken cancellationToken = default)
         {
             if (mode == ReadWriteMode.SyncByte)
             {
index 0f700bf..0d69a89 100644 (file)
@@ -9,6 +9,7 @@ namespace System.IO.Tests
     {
         protected override int BufferedSize => StreamBuffer.DefaultMaxBufferSize;
         protected override bool FlushRequiredToWriteData => false;
+        protected override bool BlocksOnZeroByteReads => true;
 
         protected override Task<StreamPair> CreateConnectedStreamsAsync() =>
             Task.FromResult<StreamPair>(ConnectedStreams.CreateUnidirectional());
@@ -18,6 +19,7 @@ namespace System.IO.Tests
     {
         protected override int BufferedSize => StreamBuffer.DefaultMaxBufferSize;
         protected override bool FlushRequiredToWriteData => false;
+        protected override bool BlocksOnZeroByteReads => true;
 
         protected override Task<StreamPair> CreateConnectedStreamsAsync() =>
             Task.FromResult<StreamPair>(ConnectedStreams.CreateBidirectional());
index ff87d5a..e4ecc12 100644 (file)
@@ -37,17 +37,27 @@ namespace System.Net.Http
 
             public override int Read(Span<byte> buffer)
             {
-                if (_connection == null || buffer.Length == 0)
+                if (_connection == null)
                 {
-                    // Response body fully consumed or the caller didn't ask for any data.
+                    // Response body fully consumed
                     return 0;
                 }
 
-                // Try to consume from data we already have in the buffer.
-                int bytesRead = ReadChunksFromConnectionBuffer(buffer, cancellationRegistration: default);
-                if (bytesRead > 0)
+                if (buffer.Length == 0)
+                {
+                    if (PeekChunkFromConnectionBuffer())
+                    {
+                        return 0;
+                    }
+                }
+                else
                 {
-                    return bytesRead;
+                    // Try to consume from data we already have in the buffer.
+                    int bytesRead = ReadChunksFromConnectionBuffer(buffer, cancellationRegistration: default);
+                    if (bytesRead > 0)
+                    {
+                        return bytesRead;
+                    }
                 }
 
                 // Nothing available to consume.  Fall back to I/O.
@@ -68,7 +78,8 @@ namespace System.Net.Http
                         // as the connection buffer.  That avoids an unnecessary copy while still reading
                         // the maximum amount we'd otherwise read at a time.
                         Debug.Assert(_connection.RemainingBuffer.Length == 0);
-                        bytesRead = _connection.Read(buffer.Slice(0, (int)Math.Min((ulong)buffer.Length, _chunkBytesRemaining)));
+                        Debug.Assert(buffer.Length != 0);
+                        int bytesRead = _connection.Read(buffer.Slice(0, (int)Math.Min((ulong)buffer.Length, _chunkBytesRemaining)));
                         if (bytesRead == 0)
                         {
                             throw new IOException(SR.Format(SR.net_http_invalid_response_premature_eof_bytecount, _chunkBytesRemaining));
@@ -81,15 +92,35 @@ namespace System.Net.Http
                         return bytesRead;
                     }
 
+                    if (buffer.Length == 0)
+                    {
+                        // User requested a zero-byte read, and we have no data available in the buffer for processing.
+                        // This zero-byte read indicates their desire to trade off the extra cost of a zero-byte read
+                        // for reduced memory consumption when data is not immediately available.
+                        // So, we will issue our own zero-byte read against the underlying stream to allow it to make use of
+                        // optimizations, such as deferring buffer allocation until data is actually available.
+                        _connection.Read(buffer);
+                    }
+
                     // We're only here if we need more data to make forward progress.
                     _connection.Fill();
 
                     // Now that we have more, see if we can get any response data, and if
                     // we can we're done.
-                    int bytesCopied = ReadChunksFromConnectionBuffer(buffer, cancellationRegistration: default);
-                    if (bytesCopied > 0)
+                    if (buffer.Length == 0)
                     {
-                        return bytesCopied;
+                        if (PeekChunkFromConnectionBuffer())
+                        {
+                            return 0;
+                        }
+                    }
+                    else
+                    {
+                        int bytesCopied = ReadChunksFromConnectionBuffer(buffer, cancellationRegistration: default);
+                        if (bytesCopied > 0)
+                        {
+                            return bytesCopied;
+                        }
                     }
                 }
             }
@@ -102,17 +133,27 @@ namespace System.Net.Http
                     return ValueTask.FromCanceled<int>(cancellationToken);
                 }
 
-                if (_connection == null || buffer.Length == 0)
+                if (_connection == null)
                 {
-                    // Response body fully consumed or the caller didn't ask for any data.
+                    // Response body fully consumed
                     return new ValueTask<int>(0);
                 }
 
-                // Try to consume from data we already have in the buffer.
-                int bytesRead = ReadChunksFromConnectionBuffer(buffer.Span, cancellationRegistration: default);
-                if (bytesRead > 0)
+                if (buffer.Length == 0)
                 {
-                    return new ValueTask<int>(bytesRead);
+                    if (PeekChunkFromConnectionBuffer())
+                    {
+                        return new ValueTask<int>(0);
+                    }
+                }
+                else
+                {
+                    // Try to consume from data we already have in the buffer.
+                    int bytesRead = ReadChunksFromConnectionBuffer(buffer.Span, cancellationRegistration: default);
+                    if (bytesRead > 0)
+                    {
+                        return new ValueTask<int>(bytesRead);
+                    }
                 }
 
                 // We may have just consumed the remainder of the response (with no actual data
@@ -132,7 +173,6 @@ namespace System.Net.Http
                 // Should only be called if ReadChunksFromConnectionBuffer returned 0.
 
                 Debug.Assert(_connection != null);
-                Debug.Assert(buffer.Length > 0);
 
                 CancellationTokenRegistration ctr = _connection.RegisterCancellation(cancellationToken);
                 try
@@ -154,6 +194,7 @@ namespace System.Net.Http
                             // as the connection buffer.  That avoids an unnecessary copy while still reading
                             // the maximum amount we'd otherwise read at a time.
                             Debug.Assert(_connection.RemainingBuffer.Length == 0);
+                            Debug.Assert(buffer.Length != 0);
                             int bytesRead = await _connection.ReadAsync(buffer.Slice(0, (int)Math.Min((ulong)buffer.Length, _chunkBytesRemaining))).ConfigureAwait(false);
                             if (bytesRead == 0)
                             {
@@ -167,15 +208,35 @@ namespace System.Net.Http
                             return bytesRead;
                         }
 
+                        if (buffer.Length == 0)
+                        {
+                            // User requested a zero-byte read, and we have no data available in the buffer for processing.
+                            // This zero-byte read indicates their desire to trade off the extra cost of a zero-byte read
+                            // for reduced memory consumption when data is not immediately available.
+                            // So, we will issue our own zero-byte read against the underlying stream to allow it to make use of
+                            // optimizations, such as deferring buffer allocation until data is actually available.
+                            await _connection.ReadAsync(buffer).ConfigureAwait(false);
+                        }
+
                         // We're only here if we need more data to make forward progress.
                         await _connection.FillAsync(async: true).ConfigureAwait(false);
 
                         // Now that we have more, see if we can get any response data, and if
                         // we can we're done.
-                        int bytesCopied = ReadChunksFromConnectionBuffer(buffer.Span, ctr);
-                        if (bytesCopied > 0)
+                        if (buffer.Length == 0)
                         {
-                            return bytesCopied;
+                            if (PeekChunkFromConnectionBuffer())
+                            {
+                                return 0;
+                            }
+                        }
+                        else
+                        {
+                            int bytesCopied = ReadChunksFromConnectionBuffer(buffer.Span, ctr);
+                            if (bytesCopied > 0)
+                            {
+                                return bytesCopied;
+                            }
                         }
                     }
                 }
@@ -208,8 +269,7 @@ namespace System.Net.Http
                     {
                         while (true)
                         {
-                            ReadOnlyMemory<byte> bytesRead = ReadChunkFromConnectionBuffer(int.MaxValue, ctr);
-                            if (bytesRead.Length == 0)
+                            if (ReadChunkFromConnectionBuffer(int.MaxValue, ctr) is not ReadOnlyMemory<byte> bytesRead || bytesRead.Length == 0)
                             {
                                 break;
                             }
@@ -235,18 +295,23 @@ namespace System.Net.Http
                 }
             }
 
+            private bool PeekChunkFromConnectionBuffer()
+            {
+                return ReadChunkFromConnectionBuffer(maxBytesToRead: 0, cancellationRegistration: default).HasValue;
+            }
+
             private int ReadChunksFromConnectionBuffer(Span<byte> buffer, CancellationTokenRegistration cancellationRegistration)
             {
+                Debug.Assert(buffer.Length > 0);
                 int totalBytesRead = 0;
                 while (buffer.Length > 0)
                 {
-                    ReadOnlyMemory<byte> bytesRead = ReadChunkFromConnectionBuffer(buffer.Length, cancellationRegistration);
-                    Debug.Assert(bytesRead.Length <= buffer.Length);
-                    if (bytesRead.Length == 0)
+                    if (ReadChunkFromConnectionBuffer(buffer.Length, cancellationRegistration) is not ReadOnlyMemory<byte> bytesRead || bytesRead.Length == 0)
                     {
                         break;
                     }
 
+                    Debug.Assert(bytesRead.Length <= buffer.Length);
                     totalBytesRead += bytesRead.Length;
                     bytesRead.Span.CopyTo(buffer);
                     buffer = buffer.Slice(bytesRead.Length);
@@ -254,9 +319,9 @@ namespace System.Net.Http
                 return totalBytesRead;
             }
 
-            private ReadOnlyMemory<byte> ReadChunkFromConnectionBuffer(int maxBytesToRead, CancellationTokenRegistration cancellationRegistration)
+            private ReadOnlyMemory<byte>? ReadChunkFromConnectionBuffer(int maxBytesToRead, CancellationTokenRegistration cancellationRegistration)
             {
-                Debug.Assert(maxBytesToRead > 0 && _connection != null);
+                Debug.Assert(_connection != null);
 
                 try
                 {
@@ -310,7 +375,7 @@ namespace System.Net.Http
                             }
 
                             int bytesToConsume = Math.Min(maxBytesToRead, (int)Math.Min((ulong)connectionBuffer.Length, _chunkBytesRemaining));
-                            Debug.Assert(bytesToConsume > 0);
+                            Debug.Assert(bytesToConsume > 0 || maxBytesToRead == 0);
 
                             _connection.ConsumeFromRemainingBuffer(bytesToConsume);
                             _chunkBytesRemaining -= (ulong)bytesToConsume;
@@ -441,8 +506,7 @@ namespace System.Net.Http
                         drainedBytes += _connection.RemainingBuffer.Length;
                         while (true)
                         {
-                            ReadOnlyMemory<byte> bytesRead = ReadChunkFromConnectionBuffer(int.MaxValue, ctr);
-                            if (bytesRead.Length == 0)
+                            if (ReadChunkFromConnectionBuffer(int.MaxValue, ctr) is not ReadOnlyMemory<byte> bytesRead || bytesRead.Length == 0)
                             {
                                 break;
                             }
index 0fd0110..7bddf39 100644 (file)
@@ -18,14 +18,14 @@ namespace System.Net.Http
             public override int Read(Span<byte> buffer)
             {
                 HttpConnection? connection = _connection;
-                if (connection == null || buffer.Length == 0)
+                if (connection == null)
                 {
-                    // Response body fully consumed or the caller didn't ask for any data
+                    // Response body fully consumed
                     return 0;
                 }
 
                 int bytesRead = connection.Read(buffer);
-                if (bytesRead == 0)
+                if (bytesRead == 0 && buffer.Length != 0)
                 {
                     // We cannot reuse this connection, so close it.
                     _connection = null;
@@ -40,9 +40,9 @@ namespace System.Net.Http
                 CancellationHelper.ThrowIfCancellationRequested(cancellationToken);
 
                 HttpConnection? connection = _connection;
-                if (connection == null || buffer.Length == 0)
+                if (connection == null)
                 {
-                    // Response body fully consumed or the caller didn't ask for any data
+                    // Response body fully consumed
                     return 0;
                 }
 
@@ -69,7 +69,7 @@ namespace System.Net.Http
                     }
                 }
 
-                if (bytesRead == 0)
+                if (bytesRead == 0 && buffer.Length != 0)
                 {
                     // If cancellation is requested and tears down the connection, it could cause the read
                     // to return 0, which would otherwise signal the end of the data, but that would lead
index 786f285..97f9ddf 100644 (file)
@@ -22,9 +22,9 @@ namespace System.Net.Http
 
             public override int Read(Span<byte> buffer)
             {
-                if (_connection == null || buffer.Length == 0)
+                if (_connection == null)
                 {
-                    // Response body fully consumed or the caller didn't ask for any data.
+                    // Response body fully consumed
                     return 0;
                 }
 
@@ -35,7 +35,7 @@ namespace System.Net.Http
                 }
 
                 int bytesRead = _connection.Read(buffer);
-                if (bytesRead <= 0)
+                if (bytesRead <= 0 && buffer.Length != 0)
                 {
                     // Unexpected end of response stream.
                     throw new IOException(SR.Format(SR.net_http_invalid_response_premature_eof_bytecount, _contentBytesRemaining));
@@ -58,9 +58,9 @@ namespace System.Net.Http
             {
                 CancellationHelper.ThrowIfCancellationRequested(cancellationToken);
 
-                if (_connection == null || buffer.Length == 0)
+                if (_connection == null)
                 {
-                    // Response body fully consumed or the caller didn't ask for any data
+                    // Response body fully consumed
                     return 0;
                 }
 
@@ -94,7 +94,7 @@ namespace System.Net.Http
                     }
                 }
 
-                if (bytesRead <= 0)
+                if (bytesRead == 0 && buffer.Length != 0)
                 {
                     // A cancellation request may have caused the EOF.
                     CancellationHelper.ThrowIfCancellationRequested(cancellationToken);
index 5c3fde8..b5d5ffa 100644 (file)
@@ -1040,8 +1040,6 @@ namespace System.Net.Http
 
             private (bool wait, int bytesRead) TryReadFromBuffer(Span<byte> buffer, bool partOfSyncRead = false)
             {
-                Debug.Assert(buffer.Length > 0);
-
                 Debug.Assert(!Monitor.IsEntered(SyncObject));
                 lock (SyncObject)
                 {
@@ -1073,11 +1071,6 @@ namespace System.Net.Http
 
             public int ReadData(Span<byte> buffer, HttpResponseMessage responseMessage)
             {
-                if (buffer.Length == 0)
-                {
-                    return 0;
-                }
-
                 (bool wait, int bytesRead) = TryReadFromBuffer(buffer, partOfSyncRead: true);
                 if (wait)
                 {
@@ -1092,7 +1085,7 @@ namespace System.Net.Http
                 {
                     _windowManager.AdjustWindow(bytesRead, this);
                 }
-                else
+                else if (buffer.Length != 0)
                 {
                     // We've hit EOF.  Pull in from the Http2Stream any trailers that were temporarily stored there.
                     MoveTrailersToResponseMessage(responseMessage);
@@ -1103,11 +1096,6 @@ namespace System.Net.Http
 
             public async ValueTask<int> ReadDataAsync(Memory<byte> buffer, HttpResponseMessage responseMessage, CancellationToken cancellationToken)
             {
-                if (buffer.Length == 0)
-                {
-                    return 0;
-                }
-
                 (bool wait, int bytesRead) = TryReadFromBuffer(buffer.Span);
                 if (wait)
                 {
@@ -1121,7 +1109,7 @@ namespace System.Net.Http
                 {
                     _windowManager.AdjustWindow(bytesRead, this);
                 }
-                else
+                else if (buffer.Length != 0)
                 {
                     // We've hit EOF.  Pull in from the Http2Stream any trailers that were temporarily stored there.
                     MoveTrailersToResponseMessage(responseMessage);
index 5b2c4fd..e14302e 100644 (file)
@@ -1051,7 +1051,7 @@ namespace System.Net.Http
             {
                 int totalBytesRead = 0;
 
-                while (buffer.Length != 0)
+                do
                 {
                     // Sync over async here -- QUIC implementation does it per-I/O already; this is at least more coarse-grained.
                     if (_responseDataPayloadRemaining <= 0 && !ReadNextDataFrameAsync(response, CancellationToken.None).AsTask().GetAwaiter().GetResult())
@@ -1087,7 +1087,7 @@ namespace System.Net.Http
                         int copyLen = (int)Math.Min(buffer.Length, _responseDataPayloadRemaining);
                         int bytesRead = _stream.Read(buffer.Slice(0, copyLen));
 
-                        if (bytesRead == 0)
+                        if (bytesRead == 0 && buffer.Length != 0)
                         {
                             throw new HttpRequestException(SR.Format(SR.net_http_invalid_response_premature_eof_bytecount, _responseDataPayloadRemaining));
                         }
@@ -1101,6 +1101,7 @@ namespace System.Net.Http
                         break;
                     }
                 }
+                while (buffer.Length != 0);
 
                 return totalBytesRead;
             }
@@ -1121,7 +1122,7 @@ namespace System.Net.Http
             {
                 int totalBytesRead = 0;
 
-                while (buffer.Length != 0)
+                do
                 {
                     if (_responseDataPayloadRemaining <= 0 && !await ReadNextDataFrameAsync(response, cancellationToken).ConfigureAwait(false))
                     {
@@ -1156,7 +1157,7 @@ namespace System.Net.Http
                         int copyLen = (int)Math.Min(buffer.Length, _responseDataPayloadRemaining);
                         int bytesRead = await _stream.ReadAsync(buffer.Slice(0, copyLen), cancellationToken).ConfigureAwait(false);
 
-                        if (bytesRead == 0)
+                        if (bytesRead == 0 && buffer.Length != 0)
                         {
                             throw new HttpRequestException(SR.Format(SR.net_http_invalid_response_premature_eof_bytecount, _responseDataPayloadRemaining));
                         }
@@ -1170,6 +1171,7 @@ namespace System.Net.Http
                         break;
                     }
                 }
+                while (buffer.Length != 0);
 
                 return totalBytesRead;
             }
index c781446..d75ed27 100644 (file)
@@ -1708,8 +1708,6 @@ namespace System.Net.Http
         private int ReadBuffered(Span<byte> destination)
         {
             // This is called when reading the response body.
-            Debug.Assert(destination.Length != 0);
-
             int remaining = _readLength - _readOffset;
             if (remaining > 0)
             {
@@ -1731,7 +1729,7 @@ namespace System.Net.Http
 
             // Do a buffered read directly against the underlying stream.
             Debug.Assert(_readAheadTask == null, "Read ahead task should have been consumed as part of the headers.");
-            int bytesRead = _stream.Read(_readBuffer, 0, _readBuffer.Length);
+            int bytesRead = _stream.Read(_readBuffer, 0, destination.Length == 0 ? 0 : _readBuffer.Length);
             if (NetEventSource.Log.IsEnabled()) Trace($"Received {bytesRead} bytes.");
             _readLength = bytesRead;
 
@@ -1747,7 +1745,9 @@ namespace System.Net.Http
             // If the caller provided buffer, and thus the amount of data desired to be read,
             // is larger than the internal buffer, there's no point going through the internal
             // buffer, so just do an unbuffered read.
-            return destination.Length >= _readBuffer.Length ?
+            // Also avoid avoid using the internal buffer if the user requested a zero-byte read to allow
+            // underlying streams to efficiently handle such a read (e.g. SslStream defering buffer allocation).
+            return destination.Length >= _readBuffer.Length || destination.Length == 0 ?
                 ReadAsync(destination) :
                 ReadBufferedAsyncCore(destination);
         }
index 44376b0..7a45db5 100644 (file)
@@ -23,14 +23,14 @@ namespace System.Net.Http
             public override int Read(Span<byte> buffer)
             {
                 HttpConnection? connection = _connection;
-                if (connection == null || buffer.Length == 0)
+                if (connection == null)
                 {
                     // Response body fully consumed or the caller didn't ask for any data
                     return 0;
                 }
 
                 int bytesRead = connection.ReadBuffered(buffer);
-                if (bytesRead == 0)
+                if (bytesRead == 0 && buffer.Length != 0)
                 {
                     // We cannot reuse this connection, so close it.
                     _connection = null;
@@ -45,9 +45,9 @@ namespace System.Net.Http
                 CancellationHelper.ThrowIfCancellationRequested(cancellationToken);
 
                 HttpConnection? connection = _connection;
-                if (connection == null || buffer.Length == 0)
+                if (connection == null)
                 {
-                    // Response body fully consumed or the caller didn't ask for any data
+                    // Response body fully consumed
                     return 0;
                 }
 
@@ -74,7 +74,7 @@ namespace System.Net.Http
                     }
                 }
 
-                if (bytesRead == 0)
+                if (bytesRead == 0 && buffer.Length != 0)
                 {
                     // A cancellation request may have caused the EOF.
                     CancellationHelper.ThrowIfCancellationRequested(cancellationToken);
index 37b9384..f7f28da 100644 (file)
@@ -80,6 +80,7 @@ namespace System.Net.Http.Functional.Tests
     {
         protected override Type UnsupportedConcurrentExceptionType => null;
         protected override bool UsableAfterCanceledReads => false;
+        protected override bool BlocksOnZeroByteReads => true;
 
         protected abstract string GetResponseHeaders();
 
diff --git a/src/libraries/System.Net.Http/tests/FunctionalTests/ResponseStreamZeroByteReadTests.cs b/src/libraries/System.Net.Http/tests/FunctionalTests/ResponseStreamZeroByteReadTests.cs
new file mode 100644 (file)
index 0000000..77acc77
--- /dev/null
@@ -0,0 +1,317 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using System.Collections.Generic;
+using System.IO;
+using System.IO.Tests;
+using System.Linq;
+using System.Net.Quic;
+using System.Net.Quic.Implementations;
+using System.Net.Security;
+using System.Net.Test.Common;
+using System.Security.Authentication;
+using System.Security.Cryptography.X509Certificates;
+using System.Text;
+using System.Threading;
+using System.Threading.Tasks;
+using Xunit;
+using Xunit.Abstractions;
+
+namespace System.Net.Http.Functional.Tests
+{
+    public sealed class Http1CloseResponseStreamZeroByteReadTest : Http1ResponseStreamZeroByteReadTestBase
+    {
+        protected override string GetResponseHeaders() => "HTTP/1.1 200 OK\r\nConnection: close\r\n\r\n";
+
+        protected override async Task WriteAsync(Stream stream, byte[] data) => await stream.WriteAsync(data);
+    }
+
+    public sealed class Http1RawResponseStreamZeroByteReadTest : Http1ResponseStreamZeroByteReadTestBase
+    {
+        protected override string GetResponseHeaders() => "HTTP/1.1 101 Switching Protocols\r\n\r\n";
+
+        protected override async Task WriteAsync(Stream stream, byte[] data) => await stream.WriteAsync(data);
+    }
+
+    public sealed class Http1ContentLengthResponseStreamZeroByteReadTest : Http1ResponseStreamZeroByteReadTestBase
+    {
+        protected override string GetResponseHeaders() => "HTTP/1.1 200 OK\r\nContent-Length: 5\r\n\r\n";
+
+        protected override async Task WriteAsync(Stream stream, byte[] data) => await stream.WriteAsync(data);
+    }
+
+    public sealed class Http1SingleChunkResponseStreamZeroByteReadTest : Http1ResponseStreamZeroByteReadTestBase
+    {
+        protected override string GetResponseHeaders() => "HTTP/1.1 200 OK\r\nTransfer-Encoding: chunked\r\n\r\n";
+
+        protected override async Task WriteAsync(Stream stream, byte[] data)
+        {
+            await stream.WriteAsync(Encoding.ASCII.GetBytes($"{data.Length:X}\r\n"));
+            await stream.WriteAsync(data);
+            await stream.WriteAsync(Encoding.ASCII.GetBytes("\r\n"));
+        }
+    }
+
+    public sealed class Http1MultiChunkResponseStreamZeroByteReadTest : Http1ResponseStreamZeroByteReadTestBase
+    {
+        protected override string GetResponseHeaders() => "HTTP/1.1 200 OK\r\nTransfer-Encoding: chunked\r\n\r\n";
+
+        protected override async Task WriteAsync(Stream stream, byte[] data)
+        {
+            for (int i = 0; i < data.Length; i++)
+            {
+                await stream.WriteAsync(Encoding.ASCII.GetBytes($"1\r\n"));
+                await stream.WriteAsync(data.AsMemory(i, 1));
+                await stream.WriteAsync(Encoding.ASCII.GetBytes("\r\n"));
+            }
+        }
+    }
+
+    [ConditionalClass(typeof(PlatformDetection), nameof(PlatformDetection.IsNotBrowser))]
+    public abstract class Http1ResponseStreamZeroByteReadTestBase
+    {
+        protected abstract string GetResponseHeaders();
+
+        protected abstract Task WriteAsync(Stream stream, byte[] data);
+
+        public static IEnumerable<object[]> ZeroByteRead_IssuesZeroByteReadOnUnderlyingStream_MemberData() =>
+            from readMode in Enum.GetValues<StreamConformanceTests.ReadWriteMode>()
+                .Where(mode => mode != StreamConformanceTests.ReadWriteMode.SyncByte) // Can't test zero-byte reads with ReadByte
+            from useSsl in new[] { true, false }
+            select new object[] { readMode, useSsl };
+
+        [Theory]
+        [MemberData(nameof(ZeroByteRead_IssuesZeroByteReadOnUnderlyingStream_MemberData))]
+        public async Task ZeroByteRead_IssuesZeroByteReadOnUnderlyingStream(StreamConformanceTests.ReadWriteMode readMode, bool useSsl)
+        {
+            (Stream httpConnection, Stream server) = ConnectedStreams.CreateBidirectional(4096, int.MaxValue);
+            try
+            {
+                var sawZeroByteRead = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
+
+                httpConnection = new ReadInterceptStream(httpConnection, read =>
+                {
+                    if (read == 0)
+                    {
+                        sawZeroByteRead.TrySetResult();
+                    }
+                });
+
+                using var handler = new SocketsHttpHandler
+                {
+                    ConnectCallback = delegate { return ValueTask.FromResult(httpConnection); }
+                };
+                handler.SslOptions.RemoteCertificateValidationCallback = delegate { return true; };
+
+                using var client = new HttpClient(handler);
+
+                Task<HttpResponseMessage> clientTask = client.GetAsync($"http{(useSsl ? "s" : "")}://doesntmatter", HttpCompletionOption.ResponseHeadersRead);
+
+                if (useSsl)
+                {
+                    var sslStream = new SslStream(server, false, delegate { return true; });
+                    server = sslStream;
+
+                    using (X509Certificate2 cert = Test.Common.Configuration.Certificates.GetServerCertificate())
+                    {
+                        await ((SslStream)server).AuthenticateAsServerAsync(
+                            cert,
+                            clientCertificateRequired: true,
+                            enabledSslProtocols: SslProtocols.Tls12,
+                            checkCertificateRevocation: false).WaitAsync(TimeSpan.FromSeconds(10));
+                    }
+                }
+
+                await ResponseConnectedStreamConformanceTests.ReadHeadersAsync(server).WaitAsync(TimeSpan.FromSeconds(10));
+                await server.WriteAsync(Encoding.ASCII.GetBytes(GetResponseHeaders()));
+
+                using HttpResponseMessage response = await clientTask.WaitAsync(TimeSpan.FromSeconds(10));
+                using Stream clientStream = response.Content.ReadAsStream();
+                Assert.False(sawZeroByteRead.Task.IsCompleted);
+
+                Task<int> zeroByteReadTask = Task.Run(() => StreamConformanceTests.ReadAsync(readMode, clientStream, Array.Empty<byte>(), 0, 0, CancellationToken.None) );
+                Assert.False(zeroByteReadTask.IsCompleted);
+
+                // The zero-byte read should block until data is actually available
+                await sawZeroByteRead.Task.WaitAsync(TimeSpan.FromSeconds(10));
+                Assert.False(zeroByteReadTask.IsCompleted);
+
+                byte[] data = Encoding.UTF8.GetBytes("Hello");
+                await WriteAsync(server, data);
+                await server.FlushAsync();
+
+                Assert.Equal(0, await zeroByteReadTask.WaitAsync(TimeSpan.FromSeconds(10)));
+
+                // Now that data is available, a zero-byte read should complete synchronously
+                zeroByteReadTask = StreamConformanceTests.ReadAsync(readMode, clientStream, Array.Empty<byte>(), 0, 0, CancellationToken.None);
+                Assert.True(zeroByteReadTask.IsCompleted);
+                Assert.Equal(0, await zeroByteReadTask);
+
+                var readBuffer = new byte[10];
+                int read = 0;
+                while (read < data.Length)
+                {
+                    read += await StreamConformanceTests.ReadAsync(readMode, clientStream, readBuffer, read, readBuffer.Length - read, CancellationToken.None).WaitAsync(TimeSpan.FromSeconds(10));
+                }
+
+                Assert.Equal(data.Length, read);
+                Assert.Equal(data, readBuffer.AsSpan(0, read).ToArray());
+            }
+            finally
+            {
+                httpConnection.Dispose();
+                server.Dispose();
+            }
+        }
+
+        private sealed class ReadInterceptStream : DelegatingStream
+        {
+            private readonly Action<int> _readCallback;
+
+            public ReadInterceptStream(Stream innerStream, Action<int> readCallback)
+                : base(innerStream)
+            {
+                _readCallback = readCallback;
+            }
+
+            public override int Read(Span<byte> buffer)
+            {
+                _readCallback(buffer.Length);
+                return base.Read(buffer);
+            }
+
+            public override int Read(byte[] buffer, int offset, int count)
+            {
+                _readCallback(count);
+                return base.Read(buffer, offset, count);
+            }
+
+            public override ValueTask<int> ReadAsync(Memory<byte> buffer, CancellationToken cancellationToken = default)
+            {
+                _readCallback(buffer.Length);
+                return base.ReadAsync(buffer, cancellationToken);
+            }
+        }
+    }
+
+    public sealed class Http1ResponseStreamZeroByteReadTest : ResponseStreamZeroByteReadTestBase
+    {
+        public Http1ResponseStreamZeroByteReadTest(ITestOutputHelper output) : base(output) { }
+
+        protected override Version UseVersion => HttpVersion.Version11;
+    }
+
+    [ConditionalClass(typeof(PlatformDetection), nameof(PlatformDetection.SupportsAlpn))]
+    public sealed class Http2ResponseStreamZeroByteReadTest : ResponseStreamZeroByteReadTestBase
+    {
+        public Http2ResponseStreamZeroByteReadTest(ITestOutputHelper output) : base(output) { }
+
+        protected override Version UseVersion => HttpVersion.Version20;
+    }
+
+    [ConditionalClass(typeof(HttpClientHandlerTestBase), nameof(IsMsQuicSupported))]
+    public sealed class Http3ResponseStreamZeroByteReadTest_MsQuic : ResponseStreamZeroByteReadTestBase
+    {
+        public Http3ResponseStreamZeroByteReadTest_MsQuic(ITestOutputHelper output) : base(output) { }
+
+        protected override Version UseVersion => HttpVersion.Version30;
+
+        protected override QuicImplementationProvider UseQuicImplementationProvider => QuicImplementationProviders.MsQuic;
+    }
+
+    [ConditionalClass(typeof(HttpClientHandlerTestBase), nameof(IsMockQuicSupported))]
+    public sealed class Http3ResponseStreamZeroByteReadTest_Mock : ResponseStreamZeroByteReadTestBase
+    {
+        public Http3ResponseStreamZeroByteReadTest_Mock(ITestOutputHelper output) : base(output) { }
+
+        protected override Version UseVersion => HttpVersion.Version30;
+
+        protected override QuicImplementationProvider UseQuicImplementationProvider => QuicImplementationProviders.Mock;
+    }
+
+    [ConditionalClass(typeof(PlatformDetection), nameof(PlatformDetection.IsNotBrowser))]
+    public abstract class ResponseStreamZeroByteReadTestBase : HttpClientHandlerTestBase
+    {
+        public ResponseStreamZeroByteReadTestBase(ITestOutputHelper output) : base(output) { }
+
+        [Theory]
+        [InlineData(true)]
+        [InlineData(false)]
+        public async Task ZeroByteRead_BlocksUntilDataIsAvailable(bool async)
+        {
+            var zeroByteReadIssued = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
+
+            await LoopbackServerFactory.CreateClientAndServerAsync(async uri =>
+            {
+                HttpRequestMessage request = CreateRequest(HttpMethod.Get, uri, UseVersion, exactVersion: true);
+
+                using HttpClient client = CreateHttpClient();
+                using HttpResponseMessage response = await client.SendAsync(request, HttpCompletionOption.ResponseHeadersRead);
+                using Stream responseStream = await response.Content.ReadAsStreamAsync();
+
+                var responseBuffer = new byte[1];
+                Assert.Equal(1, await ReadAsync(async, responseStream, responseBuffer));
+                Assert.Equal(42, responseBuffer[0]);
+
+                Task<int> zeroByteReadTask = ReadAsync(async, responseStream, Array.Empty<byte>());
+                Assert.False(zeroByteReadTask.IsCompleted);
+
+                zeroByteReadIssued.SetResult();
+                Assert.Equal(0, await zeroByteReadTask);
+                Assert.Equal(0, await ReadAsync(async, responseStream, Array.Empty<byte>()));
+
+                Assert.Equal(1, await ReadAsync(async, responseStream, responseBuffer));
+                Assert.Equal(1, responseBuffer[0]);
+
+                Assert.Equal(0, await ReadAsync(async, responseStream, Array.Empty<byte>()));
+
+                Assert.Equal(1, await ReadAsync(async, responseStream, responseBuffer));
+                Assert.Equal(2, responseBuffer[0]);
+
+                zeroByteReadTask = ReadAsync(async, responseStream, Array.Empty<byte>());
+                Assert.False(zeroByteReadTask.IsCompleted);
+
+                zeroByteReadIssued.SetResult();
+                Assert.Equal(0, await zeroByteReadTask);
+                Assert.Equal(0, await ReadAsync(async, responseStream, Array.Empty<byte>()));
+
+                Assert.Equal(1, await ReadAsync(async, responseStream, responseBuffer));
+                Assert.Equal(3, responseBuffer[0]);
+
+                Assert.Equal(0, await ReadAsync(async, responseStream, responseBuffer));
+            },
+            async server =>
+            {
+                await server.AcceptConnectionAsync(async connection =>
+                {
+                    await connection.ReadRequestDataAsync();
+
+                    await connection.SendResponseAsync(headers: new[] { new HttpHeaderData("Content-Length", "4") }, isFinal: false);
+
+                    await connection.SendResponseBodyAsync(new byte[] { 42 }, isFinal: false);
+
+                    await zeroByteReadIssued.Task;
+                    zeroByteReadIssued = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
+
+                    await connection.SendResponseBodyAsync(new byte[] { 1, 2 }, isFinal: false);
+
+                    await zeroByteReadIssued.Task;
+
+                    await connection.SendResponseBodyAsync(new byte[] { 3 }, isFinal: true);
+                });
+            });
+
+            static Task<int> ReadAsync(bool async, Stream stream, byte[] buffer)
+            {
+                if (async)
+                {
+                    return stream.ReadAsync(buffer).AsTask();
+                }
+                else
+                {
+                    return Task.Run(() => stream.Read(buffer));
+                }
+            }
+        }
+    }
+}
index 29bbd6f..343267d 100644 (file)
     <Compile Include="StreamContentTest.cs" />
     <Compile Include="StringContentTest.cs" />
     <Compile Include="ResponseStreamConformanceTests.cs" />
+    <Compile Include="ResponseStreamZeroByteReadTests.cs" />
     <Compile Include="$(CommonTestPath)System\Net\Http\SyncBlockingContent.cs"
              Link="Common\System\Net\Http\SyncBlockingContent.cs" />
     <Compile Include="$(CommonTestPath)System\Net\Http\DefaultCredentialsTest.cs"
index cb31cfd..271ca17 100644 (file)
@@ -18,6 +18,7 @@ namespace System.Net.Quic.Tests
     public sealed class MockQuicStreamConformanceTests : QuicStreamConformanceTests
     {
         protected override QuicImplementationProvider Provider => QuicImplementationProviders.Mock;
+        protected override bool BlocksOnZeroByteReads => true;
     }
 
     [ConditionalClass(typeof(QuicTestBase<MsQuicProviderFactory>), nameof(QuicTestBase<MsQuicProviderFactory>.IsSupported))]