Improve Http2Connection buffer management (#79484)
authorMiha Zupan <mihazupan.zupan1@gmail.com>
Tue, 31 Jan 2023 12:52:05 +0000 (04:52 -0800)
committerGitHub <noreply@github.com>
Tue, 31 Jan 2023 12:52:05 +0000 (04:52 -0800)
* Improve Http2Connection buffer management

* Add a test

* Add a few comments around buffer disposal

src/libraries/Common/src/System/Net/ArrayBuffer.cs
src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http2Connection.cs
src/libraries/System.Net.Http/tests/FunctionalTests/ResponseStreamZeroByteReadTests.cs
src/libraries/System.Net.Http/tests/UnitTests/HPack/HPackRoundtripTests.cs
src/libraries/System.Net.Security/src/System/Net/Security/SslStream.cs

index 71630c5..7499861 100644 (file)
@@ -3,6 +3,7 @@
 
 using System.Buffers;
 using System.Diagnostics;
+using System.Runtime.CompilerServices;
 using System.Runtime.InteropServices;
 
 namespace System.Net
@@ -30,8 +31,12 @@ namespace System.Net
 
         public ArrayBuffer(int initialSize, bool usePool = false)
         {
+            Debug.Assert(initialSize > 0 || usePool);
+
             _usePool = usePool;
-            _bytes = usePool ? ArrayPool<byte>.Shared.Rent(initialSize) : new byte[initialSize];
+            _bytes = initialSize == 0
+                ? Array.Empty<byte>()
+                : usePool ? ArrayPool<byte>.Shared.Rent(initialSize) : new byte[initialSize];
             _activeStart = 0;
             _availableStart = 0;
         }
@@ -54,12 +59,26 @@ namespace System.Net
             byte[] array = _bytes;
             _bytes = null!;
 
-            if (_usePool && array != null)
+            if (array is not null)
             {
-                ArrayPool<byte>.Shared.Return(array);
+                ReturnBufferIfPooled(array);
             }
         }
 
+        // This is different from Dispose as the instance remains usable afterwards (_bytes will not be null).
+        public void ClearAndReturnBuffer()
+        {
+            Debug.Assert(_usePool);
+            Debug.Assert(_bytes is not null);
+
+            _activeStart = 0;
+            _availableStart = 0;
+
+            byte[] bufferToReturn = _bytes;
+            _bytes = Array.Empty<byte>();
+            ReturnBufferIfPooled(bufferToReturn);
+        }
+
         public int ActiveLength => _availableStart - _activeStart;
         public Span<byte> ActiveSpan => new Span<byte>(_bytes, _activeStart, _availableStart - _activeStart);
         public ReadOnlySpan<byte> ActiveReadOnlySpan => new ReadOnlySpan<byte>(_bytes, _activeStart, _availableStart - _activeStart);
@@ -94,10 +113,23 @@ namespace System.Net
         }
 
         // Ensure at least [byteCount] bytes to write to.
+        [MethodImpl(MethodImplOptions.AggressiveInlining)]
         public void EnsureAvailableSpace(int byteCount)
         {
-            if (byteCount <= AvailableLength)
+            if (byteCount > AvailableLength)
             {
+                EnsureAvailableSpaceCore(byteCount);
+            }
+        }
+
+        private void EnsureAvailableSpaceCore(int byteCount)
+        {
+            Debug.Assert(AvailableLength < byteCount);
+
+            if (_bytes.Length == 0)
+            {
+                Debug.Assert(_usePool && _activeStart == 0 && _availableStart == 0);
+                _bytes = ArrayPool<byte>.Shared.Rent(byteCount);
                 return;
             }
 
@@ -134,72 +166,24 @@ namespace System.Net
             _activeStart = 0;
 
             _bytes = newBytes;
-            if (_usePool)
-            {
-                ArrayPool<byte>.Shared.Return(oldBytes);
-            }
+            ReturnBufferIfPooled(oldBytes);
 
             Debug.Assert(byteCount <= AvailableLength);
         }
 
-        // Ensure at least [byteCount] bytes to write to, up to the specified limit
-        public void TryEnsureAvailableSpaceUpToLimit(int byteCount, int limit)
+        public void Grow()
         {
-            if (byteCount <= AvailableLength)
-            {
-                return;
-            }
-
-            int totalFree = _activeStart + AvailableLength;
-            if (byteCount <= totalFree)
-            {
-                // We can free up enough space by just shifting the bytes down, so do so.
-                Buffer.BlockCopy(_bytes, _activeStart, _bytes, 0, ActiveLength);
-                _availableStart = ActiveLength;
-                _activeStart = 0;
-                Debug.Assert(byteCount <= AvailableLength);
-                return;
-            }
-
-            if (_bytes.Length >= limit)
-            {
-                // Already at limit, can't grow further.
-                return;
-            }
-
-            // Double the size of the buffer until we have enough space, or we hit the limit
-            int desiredSize = Math.Min(ActiveLength + byteCount, limit);
-            int newSize = _bytes.Length;
-            do
-            {
-                newSize = Math.Min(newSize * 2, limit);
-            } while (newSize < desiredSize);
-
-            byte[] newBytes = _usePool ?
-                ArrayPool<byte>.Shared.Rent(newSize) :
-                new byte[newSize];
-            byte[] oldBytes = _bytes;
-
-            if (ActiveLength != 0)
-            {
-                Buffer.BlockCopy(oldBytes, _activeStart, newBytes, 0, ActiveLength);
-            }
-
-            _availableStart = ActiveLength;
-            _activeStart = 0;
-
-            _bytes = newBytes;
-            if (_usePool)
-            {
-                ArrayPool<byte>.Shared.Return(oldBytes);
-            }
-
-            Debug.Assert(byteCount <= AvailableLength || desiredSize == limit);
+            EnsureAvailableSpaceCore(AvailableLength + 1);
         }
 
-        public void Grow()
+        [MethodImpl(MethodImplOptions.AggressiveInlining)]
+        private void ReturnBufferIfPooled(byte[] buffer)
         {
-            EnsureAvailableSpace(AvailableLength + 1);
+            // The buffer may be Array.Empty<byte>()
+            if (_usePool && buffer.Length > 0)
+            {
+                ArrayPool<byte>.Shared.Return(buffer);
+            }
         }
     }
 }
index 4bf5a58..cfc13bf 100644 (file)
@@ -29,6 +29,7 @@ namespace System.Net.Http
         private readonly Stream _stream;
 
         // NOTE: These are mutable structs; do not make these readonly.
+        // ProcessIncomingFramesAsync and ProcessOutgoingFramesAsync are responsible for disposing/returning their respective buffers.
         private ArrayBuffer _incomingBuffer;
         private ArrayBuffer _outgoingBuffer;
 
@@ -89,10 +90,12 @@ namespace System.Net.Http
 
 #if DEBUG
         // In debug builds, start with a very small buffer to induce buffer growing logic.
-        private const int InitialConnectionBufferSize = 4;
+        private const int InitialConnectionBufferSize = FrameHeader.Size;
 #else
-        private const int InitialConnectionBufferSize = 4096;
+        // Rent enough space to receive a full data frame in one read call.
+        private const int InitialConnectionBufferSize = FrameHeader.Size + FrameHeader.MaxPayloadLength;
 #endif
+
         // The default initial window size for streams and connections according to the RFC:
         // https://datatracker.ietf.org/doc/html/rfc7540#section-5.2.1
         // Unlike HttpHandlerDefaults.DefaultInitialHttp2StreamWindowSize, this value should never be changed.
@@ -139,8 +142,8 @@ namespace System.Net.Http
             _pool = pool;
             _stream = stream;
 
-            _incomingBuffer = new ArrayBuffer(InitialConnectionBufferSize);
-            _outgoingBuffer = new ArrayBuffer(InitialConnectionBufferSize);
+            _incomingBuffer = new ArrayBuffer(initialSize: 0, usePool: true);
+            _outgoingBuffer = new ArrayBuffer(initialSize: 0, usePool: true);
 
             _hpackDecoder = new HPackDecoder(maxHeadersLength: pool.Settings.MaxResponseHeadersByteLength);
 
@@ -239,11 +242,15 @@ namespace System.Net.Http
                 _ = ProcessIncomingFramesAsync();
                 await _stream.WriteAsync(_outgoingBuffer.ActiveMemory, cancellationToken).ConfigureAwait(false);
                 _rttEstimator.OnInitialSettingsSent();
-                _outgoingBuffer.Discard(_outgoingBuffer.ActiveLength);
-
+                _outgoingBuffer.ClearAndReturnBuffer();
             }
             catch (Exception e)
             {
+                // ProcessIncomingFramesAsync and ProcessOutgoingFramesAsync are responsible for disposing/returning their respective buffers.
+                // SetupAsync is the exception as it's responsible for starting the ProcessOutgoingFramesAsync loop.
+                // As we're about to throw and ProcessOutgoingFramesAsync will never be called, we must return the buffer here.
+                _outgoingBuffer.Dispose();
+
                 Dispose();
 
                 if (e is OperationCanceledException oce && oce.CancellationToken == cancellationToken)
@@ -428,9 +435,13 @@ namespace System.Net.Http
             // Ensure we've read enough data for the frame header.
             if (_incomingBuffer.ActiveLength < FrameHeader.Size)
             {
-                _incomingBuffer.EnsureAvailableSpace(FrameHeader.Size - _incomingBuffer.ActiveLength);
                 do
                 {
+                    // Issue a zero-byte read to avoid potentially pinning the buffer while waiting for more data.
+                    await _stream.ReadAsync(Memory<byte>.Empty).ConfigureAwait(false);
+
+                    _incomingBuffer.EnsureAvailableSpace(FrameHeader.Size);
+
                     int bytesRead = await _stream.ReadAsync(_incomingBuffer.AvailableMemory).ConfigureAwait(false);
                     _incomingBuffer.Commit(bytesRead);
                     if (bytesRead == 0)
@@ -469,6 +480,9 @@ namespace System.Net.Http
                 _incomingBuffer.EnsureAvailableSpace(frameHeader.PayloadLength - _incomingBuffer.ActiveLength);
                 do
                 {
+                    // Issue a zero-byte read to avoid potentially pinning the buffer while waiting for more data.
+                    await _stream.ReadAsync(Memory<byte>.Empty).ConfigureAwait(false);
+
                     int bytesRead = await _stream.ReadAsync(_incomingBuffer.AvailableMemory).ConfigureAwait(false);
                     _incomingBuffer.Commit(bytesRead);
                     if (bytesRead == 0) ThrowPrematureEOF(frameHeader.PayloadLength);
@@ -531,9 +545,21 @@ namespace System.Net.Http
                     // the entire frame's needs (not just the header).
                     if (_incomingBuffer.ActiveLength < FrameHeader.Size)
                     {
-                        _incomingBuffer.EnsureAvailableSpace(FrameHeader.Size - _incomingBuffer.ActiveLength);
                         do
                         {
+                            // Issue a zero-byte read to avoid potentially pinning the buffer while waiting for more data.
+                            ValueTask<int> zeroByteReadTask = _stream.ReadAsync(Memory<byte>.Empty);
+                            if (!zeroByteReadTask.IsCompletedSuccessfully && _incomingBuffer.ActiveLength == 0)
+                            {
+                                // No data is available yet. Return the receive buffer back to the pool while we wait.
+                                _incomingBuffer.ClearAndReturnBuffer();
+                            }
+                            await zeroByteReadTask.ConfigureAwait(false);
+
+                            // While we only need FrameHeader.Size bytes to complete this read, it's better if we rent more
+                            // to avoid multiple ReadAsync calls and resizes once we start copying the content.
+                            _incomingBuffer.EnsureAvailableSpace(InitialConnectionBufferSize);
+
                             int bytesRead = await _stream.ReadAsync(_incomingBuffer.AvailableMemory).ConfigureAwait(false);
                             Debug.Assert(bytesRead >= 0);
                             _incomingBuffer.Commit(bytesRead);
@@ -605,6 +631,10 @@ namespace System.Net.Http
 
                 Abort(e);
             }
+            finally
+            {
+                _incomingBuffer.Dispose();
+            }
         }
 
         // Note, this will return null for a streamId that's no longer in use.
@@ -1252,6 +1282,11 @@ namespace System.Net.Http
                     {
                         await FlushOutgoingBytesAsync().ConfigureAwait(false);
                     }
+
+                    if (_outgoingBuffer.ActiveLength == 0)
+                    {
+                        _outgoingBuffer.ClearAndReturnBuffer();
+                    }
                 }
             }
             catch (Exception e)
@@ -1260,6 +1295,10 @@ namespace System.Net.Http
 
                 Debug.Fail($"Unexpected exception in {nameof(ProcessOutgoingFramesAsync)}: {e}");
             }
+            finally
+            {
+                _outgoingBuffer.Dispose();
+            }
         }
 
         private Task SendSettingsAckAsync() =>
@@ -1330,7 +1369,7 @@ namespace System.Net.Http
             int bytesWritten;
             while (!HPackEncoder.EncodeIndexedHeaderField(index, headerBuffer.AvailableSpan, out bytesWritten))
             {
-                headerBuffer.EnsureAvailableSpace(headerBuffer.AvailableLength + 1);
+                headerBuffer.Grow();
             }
 
             headerBuffer.Commit(bytesWritten);
@@ -1343,7 +1382,7 @@ namespace System.Net.Http
             int bytesWritten;
             while (!HPackEncoder.EncodeLiteralHeaderFieldWithoutIndexing(index, value, valueEncoding: null, headerBuffer.AvailableSpan, out bytesWritten))
             {
-                headerBuffer.EnsureAvailableSpace(headerBuffer.AvailableLength + 1);
+                headerBuffer.Grow();
             }
 
             headerBuffer.Commit(bytesWritten);
@@ -1356,7 +1395,7 @@ namespace System.Net.Http
             int bytesWritten;
             while (!HPackEncoder.EncodeLiteralHeaderFieldWithoutIndexingNewName(name, values, HttpHeaderParser.DefaultSeparator, valueEncoding, headerBuffer.AvailableSpan, out bytesWritten))
             {
-                headerBuffer.EnsureAvailableSpace(headerBuffer.AvailableLength + 1);
+                headerBuffer.Grow();
             }
 
             headerBuffer.Commit(bytesWritten);
@@ -1369,7 +1408,7 @@ namespace System.Net.Http
             int bytesWritten;
             while (!HPackEncoder.EncodeStringLiterals(values, separator, valueEncoding, headerBuffer.AvailableSpan, out bytesWritten))
             {
-                headerBuffer.EnsureAvailableSpace(headerBuffer.AvailableLength + 1);
+                headerBuffer.Grow();
             }
 
             headerBuffer.Commit(bytesWritten);
@@ -1382,7 +1421,7 @@ namespace System.Net.Http
             int bytesWritten;
             while (!HPackEncoder.EncodeStringLiteral(value, valueEncoding, headerBuffer.AvailableSpan, out bytesWritten))
             {
-                headerBuffer.EnsureAvailableSpace(headerBuffer.AvailableLength + 1);
+                headerBuffer.Grow();
             }
 
             headerBuffer.Commit(bytesWritten);
@@ -1392,11 +1431,7 @@ namespace System.Net.Http
         {
             if (NetEventSource.Log.IsEnabled()) Trace($"{nameof(bytes.Length)}={bytes.Length}");
 
-            if (bytes.Length > headerBuffer.AvailableLength)
-            {
-                headerBuffer.EnsureAvailableSpace(bytes.Length);
-            }
-
+            headerBuffer.EnsureAvailableSpace(bytes.Length);
             bytes.CopyTo(headerBuffer.AvailableSpan);
             headerBuffer.Commit(bytes.Length);
         }
@@ -1855,6 +1890,10 @@ namespace System.Net.Http
             _connectionWindow.Dispose();
             _writeChannel.Writer.Complete();
 
+            // We're not disposing the _incomingBuffer and _outgoingBuffer here as they may still be in use by
+            // ProcessIncomingFramesAsync and ProcessOutgoingFramesAsync respectively, and those methods are
+            // responsible for returning the buffers.
+
             if (HttpTelemetry.Log.IsEnabled())
             {
                 if (Interlocked.Exchange(ref _markedByTelemetryStatus, TelemetryStatus_Closed) == TelemetryStatus_Opened)
index 6aa4fee..a5312b1 100644 (file)
@@ -159,35 +159,6 @@ namespace System.Net.Http.Functional.Tests
                 server.Dispose();
             }
         }
-
-        private sealed class ReadInterceptStream : DelegatingStream
-        {
-            private readonly Action<int> _readCallback;
-
-            public ReadInterceptStream(Stream innerStream, Action<int> readCallback)
-                : base(innerStream)
-            {
-                _readCallback = readCallback;
-            }
-
-            public override int Read(Span<byte> buffer)
-            {
-                _readCallback(buffer.Length);
-                return base.Read(buffer);
-            }
-
-            public override int Read(byte[] buffer, int offset, int count)
-            {
-                _readCallback(count);
-                return base.Read(buffer, offset, count);
-            }
-
-            public override ValueTask<int> ReadAsync(Memory<byte> buffer, CancellationToken cancellationToken = default)
-            {
-                _readCallback(buffer.Length);
-                return base.ReadAsync(buffer, cancellationToken);
-            }
-        }
     }
 
     public sealed class Http1ResponseStreamZeroByteReadTest : ResponseStreamZeroByteReadTestBase
@@ -299,4 +270,75 @@ namespace System.Net.Http.Functional.Tests
             }
         }
     }
+
+    [ConditionalClass(typeof(PlatformDetection), nameof(PlatformDetection.SupportsAlpn))]
+    public sealed class Http2ConnectionZeroByteReadTest : HttpClientHandlerTestBase
+    {
+        public Http2ConnectionZeroByteReadTest(ITestOutputHelper output) : base(output) { }
+
+        protected override Version UseVersion => HttpVersion.Version20;
+
+        [Theory]
+        [InlineData(true)]
+        [InlineData(false)]
+        public async Task ConnectionIssuesZeroByteReadsOnUnderlyingStream(bool useSsl)
+        {
+            await Http2LoopbackServer.CreateClientAndServerAsync(async uri =>
+            {
+                using HttpClientHandler handler = CreateHttpClientHandler();
+
+                int zeroByteReads = 0;
+                GetUnderlyingSocketsHttpHandler(handler).PlaintextStreamFilter = (context, _) =>
+                {
+                    return new ValueTask<Stream>(new ReadInterceptStream(context.PlaintextStream, read =>
+                    {
+                        if (read == 0)
+                        {
+                            zeroByteReads++;
+                        }
+                    }));
+                };
+
+                using HttpClient client = CreateHttpClient(handler);
+                client.DefaultVersionPolicy = HttpVersionPolicy.RequestVersionExact;
+
+                Assert.Equal("Foo", await client.GetStringAsync(uri));
+
+                Assert.NotEqual(0, zeroByteReads);
+            },
+            async server =>
+            {
+                await server.HandleRequestAsync(content: "Foo");
+            }, http2Options: new Http2Options { UseSsl = useSsl });
+        }
+    }
+
+    file sealed class ReadInterceptStream : DelegatingStream
+    {
+        private readonly Action<int> _readCallback;
+
+        public ReadInterceptStream(Stream innerStream, Action<int> readCallback)
+            : base(innerStream)
+        {
+            _readCallback = readCallback;
+        }
+
+        public override int Read(Span<byte> buffer)
+        {
+            _readCallback(buffer.Length);
+            return base.Read(buffer);
+        }
+
+        public override int Read(byte[] buffer, int offset, int count)
+        {
+            _readCallback(count);
+            return base.Read(buffer, offset, count);
+        }
+
+        public override ValueTask<int> ReadAsync(Memory<byte> buffer, CancellationToken cancellationToken = default)
+        {
+            _readCallback(buffer.Length);
+            return base.ReadAsync(buffer, cancellationToken);
+        }
+    }
 }
index d2f8130..541fcb7 100644 (file)
@@ -98,11 +98,8 @@ namespace System.Net.Http.Unit.Tests.HPack
 
             void WriteBytes(ReadOnlySpan<byte> bytes)
             {
-                if (bytes.Length > buffer.AvailableLength)
-                {
-                    buffer.EnsureAvailableSpace(bytes.Length);
-                    FillAvailableSpaceWithOnes(buffer);
-                }
+                buffer.EnsureAvailableSpace(bytes.Length);
+                FillAvailableSpaceWithOnes(buffer);
 
                 bytes.CopyTo(buffer.AvailableSpan);
                 buffer.Commit(bytes.Length);
@@ -113,7 +110,7 @@ namespace System.Net.Http.Unit.Tests.HPack
                 int bytesWritten;
                 while (!HPackEncoder.EncodeStringLiterals(values, separator, valueEncoding, buffer.AvailableSpan, out bytesWritten))
                 {
-                    buffer.EnsureAvailableSpace(buffer.AvailableLength + 1);
+                    buffer.Grow();
                     FillAvailableSpaceWithOnes(buffer);
                 }
 
@@ -125,7 +122,7 @@ namespace System.Net.Http.Unit.Tests.HPack
                 int bytesWritten;
                 while (!HPackEncoder.EncodeLiteralHeaderFieldWithoutIndexingNewName(name, values, HttpHeaderParser.DefaultSeparator, valueEncoding, buffer.AvailableSpan, out bytesWritten))
                 {
-                    buffer.EnsureAvailableSpace(buffer.AvailableLength + 1);
+                    buffer.Grow();
                     FillAvailableSpaceWithOnes(buffer);
                 }
 
index 334795d..a086fc9 100644 (file)
@@ -52,7 +52,7 @@ namespace System.Net.Security
         private const int InitialHandshakeBufferSize = 4096 + FrameOverhead; // try to fit at least 4K ServerCertificate
         private const int ReadBufferSize = 4096 * 4 + FrameOverhead;         // We read in 16K chunks + headers.
 
-        private SslBuffer _buffer;
+        private SslBuffer _buffer = new();
 
         // internal buffer for storing incoming data. Wrapper around ArrayBuffer which adds
         // separation between decrypted and still encrypted part of the active region.
@@ -66,14 +66,15 @@ namespace System.Net.Security
             // padding between decrypted part of the active memory and following undecrypted TLS frame.
             private int _decryptedPadding;
 
+            // Indicates whether the _buffer currently holds a rented buffer.
             private bool _isValid;
 
-            public SslBuffer(int initialSize)
+            public SslBuffer()
             {
-                _buffer = new ArrayBuffer(initialSize, true);
+                _buffer = new ArrayBuffer(initialSize: 0, usePool: true);
                 _decryptedLength = 0;
                 _decryptedPadding = 0;
-                _isValid = true;
+                _isValid = false;
             }
 
             public bool IsValid => _isValid;
@@ -106,15 +107,8 @@ namespace System.Net.Security
 
             public void EnsureAvailableSpace(int byteCount)
             {
-                if (_isValid)
-                {
-                    _buffer.EnsureAvailableSpace(byteCount);
-                }
-                else
-                {
-                    _isValid = true;
-                    _buffer = new ArrayBuffer(byteCount, true);
-                }
+                _isValid = true;
+                _buffer.EnsureAvailableSpace(byteCount);
             }
 
             public void Discard(int byteCount)
@@ -164,7 +158,7 @@ namespace System.Net.Security
 
             public void ReturnBuffer()
             {
-                _buffer.Dispose();
+                _buffer.ClearAndReturnBuffer();
                 _decryptedLength = 0;
                 _decryptedPadding = 0;
                 _isValid = false;