Reduce simple HTTP/2 post app allocation by ~40% (#32557)
authorStephen Toub <stoub@microsoft.com>
Thu, 20 Feb 2020 14:28:01 +0000 (09:28 -0500)
committerGitHub <noreply@github.com>
Thu, 20 Feb 2020 14:28:01 +0000 (09:28 -0500)
* Remove cancellation-related allocations in Http2Stream

We don't need to allocate a linked token source in SendRequestBodyAsync if the caller's token is the default

* Reduce size of SendDataAsync state machine

We're carrying around an extra 24-bytes for a `ReadOnlyMemory<byte>`, when we could instead just use the argument.

* Tweak HeaderField's ctor to use ROS.ToArray

If `value` happens to be empty, this will avoid an allocation.  But what actually led me to do this was just tightening up the code.

* Make HPackDecoder.State enum 1 instead of 4 bytes

* Remove spilled CancellationTokenSource field from SendDataAsync

* Add known-header values for access-control-* headers

* Reduce allocation in SslStream.ReadAsync

The current structure is that ReadAsync makes two calls to FillBufferAsync, one to ensure the frame header is read and another to ensure any additional payload is read.  This has two issues:
1. It ensures that in addition to allocating a state machine for FillBufferAsync (or, rather, a helper it uses) when it needs to yield, it'll also end up allocating for ReadAsync.
2. It complicates error handling, which needs to differentiate whether the first read can't get any bytes or whether a subsequent read can't, which necessitates storing state like how many bytes we initially had buffered so we can compare to that to see if we need to throw.

We can instead:
- Make FillBufferAsync into a simple "read until we get the requested number of bytes" loop and throw if it fails to do so.
- Do the initial read in ReadAsync, thereby allowing us to special-case the first read for both error handling and to minimize the chances that the helper call needs to yield.

This eliminates a bunch of FillBufferAsync state machines and also decreases the size of the state machines when they are needed.

* Replace CreditManager's waiter queue with a circular singly-linked list

This has a variety of benefits:
- We no longer need to allocate a `Queue<Waiter>` and its underlying `Waiter[]`.
- We no longer need to allocate a `TaskCompletionSource<int>` and its `Task<int>`, instead creating a single `IValueTaskSource<T>` implementation.
- For non-cancelable waiters, we can specialize to not need to carry around a meaningless CancellationToken field.
- For cancelable waiters (the common case), we can avoid an entire async method and its state machine by just storing the relevant state onto the waiter itself.

* Fix comment from previous change

* Manually inline and specialize EnsureIncomingBytesAsync

It's not that much more code to just manually inline EnsureIncomingBytesAsync into the three places it's used, and doing so has multiple benefits, both for size and for error messages.

* Update src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http2Connection.cs

Co-Authored-By: Cory Nelson <phrosty@gmail.com>
* Fix typo in online feedback

Co-authored-by: Cory Nelson <phrosty@gmail.com>
src/libraries/Common/src/System/Net/Http/aspnetcore/Http2/Hpack/HPackDecoder.cs
src/libraries/Common/src/System/Net/Http/aspnetcore/Http2/Hpack/HeaderField.cs
src/libraries/System.Net.Http/src/Resources/Strings.resx
src/libraries/System.Net.Http/src/System/Net/Http/Headers/KnownHeaders.cs
src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/CreditManager.cs
src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http2Connection.cs
src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http2Stream.cs
src/libraries/System.Net.Security/src/System/Net/Security/SslStream.Implementation.cs

index 98fb416..997047f 100644 (file)
@@ -12,7 +12,7 @@ namespace System.Net.Http.HPack
 {
     internal class HPackDecoder
     {
-        private enum State
+        private enum State : byte
         {
             Ready,
             HeaderFieldIndex,
index 1eba824..f876204 100644 (file)
@@ -21,11 +21,8 @@ namespace System.Net.Http.HPack
             // We should revisit our allocation strategy here so we don't need to allocate per entry
             // and we have a cap to how much allocation can happen per dynamic table
             // (without limiting the number of table entries a server can provide within the table size limit).
-            Name = new byte[name.Length];
-            name.CopyTo(Name);
-
-            Value = new byte[value.Length];
-            value.CopyTo(Value);
+            Name = name.ToArray();
+            Value = value.ToArray();
         }
 
         public byte[] Name { get; }
index 456e14c..c4e6717 100644 (file)
   <data name="net_http_invalid_response_premature_eof" xml:space="preserve">
     <value>The response ended prematurely.</value>
   </data>
+  <data name="net_http_invalid_response_missing_frame" xml:space="preserve">
+    <value>The response ended prematurely while waiting for the next frame from the server.</value>
+  </data>
   <data name="net_http_invalid_response_premature_eof_bytecount" xml:space="preserve">
     <value>The response ended prematurely, with at least {0} additional bytes expected.</value>
   </data>
index 5590d0a..889ae2a 100644 (file)
@@ -19,11 +19,11 @@ namespace System.Net.Http.Headers
         public static readonly KnownHeader AcceptLanguage = new KnownHeader("Accept-Language", HttpHeaderType.Request, GenericHeaderParser.MultipleValueStringWithQualityParser, null, H2StaticTable.AcceptLanguage, H3StaticTable.AcceptLanguage);
         public static readonly KnownHeader AcceptPatch = new KnownHeader("Accept-Patch");
         public static readonly KnownHeader AcceptRanges = new KnownHeader("Accept-Ranges", HttpHeaderType.Response, GenericHeaderParser.TokenListParser, null, H2StaticTable.AcceptRanges, H3StaticTable.AcceptRangesBytes);
-        public static readonly KnownHeader AccessControlAllowCredentials = new KnownHeader("Access-Control-Allow-Credentials", http3StaticTableIndex: H3StaticTable.AccessControlAllowCredentials);
-        public static readonly KnownHeader AccessControlAllowHeaders = new KnownHeader("Access-Control-Allow-Headers", http3StaticTableIndex: H3StaticTable.AccessControlAllowHeadersCacheControl);
-        public static readonly KnownHeader AccessControlAllowMethods = new KnownHeader("Access-Control-Allow-Methods", http3StaticTableIndex: H3StaticTable.AccessControlAllowMethodsGet);
-        public static readonly KnownHeader AccessControlAllowOrigin = new KnownHeader("Access-Control-Allow-Origin", H2StaticTable.AccessControlAllowOrigin, H3StaticTable.AccessControlAllowOriginAny);
-        public static readonly KnownHeader AccessControlExposeHeaders = new KnownHeader("Access-Control-Expose-Headers", H3StaticTable.AccessControlExposeHeadersContentLength);
+        public static readonly KnownHeader AccessControlAllowCredentials = new KnownHeader("Access-Control-Allow-Credentials", HttpHeaderType.Response, parser: null, new string[] { "true" }, http3StaticTableIndex: H3StaticTable.AccessControlAllowCredentials);
+        public static readonly KnownHeader AccessControlAllowHeaders = new KnownHeader("Access-Control-Allow-Headers", HttpHeaderType.Response, parser: null, new string[] { "*" }, http3StaticTableIndex: H3StaticTable.AccessControlAllowHeadersCacheControl);
+        public static readonly KnownHeader AccessControlAllowMethods = new KnownHeader("Access-Control-Allow-Methods", HttpHeaderType.Response, parser: null, new string[] { "*" }, http3StaticTableIndex: H3StaticTable.AccessControlAllowMethodsGet);
+        public static readonly KnownHeader AccessControlAllowOrigin = new KnownHeader("Access-Control-Allow-Origin", HttpHeaderType.Response, parser: null, new string[] { "*", "null" }, H2StaticTable.AccessControlAllowOrigin, H3StaticTable.AccessControlAllowOriginAny);
+        public static readonly KnownHeader AccessControlExposeHeaders = new KnownHeader("Access-Control-Expose-Headers", HttpHeaderType.Response, parser: null, new string[] { "*" }, H3StaticTable.AccessControlExposeHeadersContentLength);
         public static readonly KnownHeader AccessControlMaxAge = new KnownHeader("Access-Control-Max-Age");
         public static readonly KnownHeader Age = new KnownHeader("Age", HttpHeaderType.Response | HttpHeaderType.NonTrailing, TimeSpanHeaderParser.Parser, null, H2StaticTable.Age, H3StaticTable.Age0);
         public static readonly KnownHeader Allow = new KnownHeader("Allow", HttpHeaderType.Content, GenericHeaderParser.TokenListParser, null, H2StaticTable.Allow);
index f1bbf34..a1f08a2 100644 (file)
@@ -3,11 +3,10 @@
 // See the LICENSE file in the project root for more information.
 
 using System.Diagnostics;
-using System.Collections.Generic;
-using System.IO;
 using System.Runtime.ExceptionServices;
 using System.Threading;
 using System.Threading.Tasks;
+using System.Threading.Tasks.Sources;
 
 namespace System.Net.Http
 {
@@ -16,8 +15,10 @@ namespace System.Net.Http
         private readonly IHttpTrace _owner;
         private readonly string _name;
         private int _current;
-        private Queue<Waiter> _waiters;
         private bool _disposed;
+        /// <summary>Circular singly-linked list of active waiters.</summary>
+        /// <remarks>If null, the list is empty.  If non-null, this is the tail.  If the list has one item, its Next is itself.</remarks>
+        private Waiter _waitersTail;
 
         public CreditManager(IHttpTrace owner, string name, int initialCredit)
         {
@@ -46,12 +47,13 @@ namespace System.Net.Http
             {
                 if (_disposed)
                 {
-                    throw CreateObjectDisposedException(forActiveWaiter: false);
+                    throw new ObjectDisposedException($"{nameof(CreditManager)}:{_owner.GetType().Name}:{_name}");
                 }
 
+                // If we can satisfy the request with credit already available, do so synchronously.
                 if (_current > 0)
                 {
-                    Debug.Assert(_waiters == null || _waiters.Count == 0, "Shouldn't have waiters when credit is available");
+                    Debug.Assert(_waitersTail is null, "Shouldn't have waiters when credit is available");
 
                     int granted = Math.Min(amount, _current);
                     if (NetEventSource.IsEnabled) _owner.Trace($"{_name}. requested={amount}, current={_current}, granted={granted}");
@@ -61,12 +63,25 @@ namespace System.Net.Http
 
                 if (NetEventSource.IsEnabled) _owner.Trace($"{_name}. requested={amount}, no credit available.");
 
-                var waiter = new Waiter { Amount = amount };
-                (_waiters ??= new Queue<Waiter>()).Enqueue(waiter);
+                // Otherwise, create a new waiter.
+                Waiter waiter = cancellationToken.CanBeCanceled ?
+                    new CancelableWaiter(amount, SyncObject, cancellationToken) :
+                    new Waiter(amount);
 
-                return cancellationToken.CanBeCanceled ?
-                    waiter.WaitWithCancellationAsync(cancellationToken) :
-                    new ValueTask<int>(waiter.Task);
+                // Add the waiter at the tail of the queue.
+                if (_waitersTail is null)
+                {
+                    _waitersTail = waiter.Next = waiter;
+                }
+                else
+                {
+                    waiter.Next = _waitersTail.Next;
+                    _waitersTail.Next = waiter;
+                    _waitersTail = waiter;
+                }
+
+                // And return a ValueTask<int> for it.
+                return waiter.AsValueTask();
             }
         }
 
@@ -84,25 +99,35 @@ namespace System.Net.Http
                     return;
                 }
 
-                Debug.Assert(_current <= 0 || _waiters == null || _waiters.Count == 0, "Shouldn't have waiters when credit is available");
+                Debug.Assert(_current <= 0 || _waitersTail is null, "Shouldn't have waiters when credit is available");
 
-                checked
-                {
-                    _current += amount;
-                }
+                _current = checked(_current + amount);
 
-                if (_waiters != null)
+                while (_current > 0 && _waitersTail != null)
                 {
-                    while (_current > 0 && _waiters.TryDequeue(out Waiter waiter))
+                    // Get the waiter from the head of the queue.
+                    Waiter waiter = _waitersTail.Next;
+                    int granted = Math.Min(waiter.Amount, _current);
+
+                    // Remove the waiter from the list.
+                    if (waiter.Next == waiter)
+                    {
+                        Debug.Assert(_waitersTail == waiter);
+                        _waitersTail = null;
+                    }
+                    else
                     {
-                        int granted = Math.Min(waiter.Amount, _current);
+                        _waitersTail.Next = waiter.Next;
+                    }
+                    waiter.Next = null;
 
-                        // Ensure that we grant credit only if the task has not been canceled.
-                        if (waiter.TrySetResult(granted))
-                        {
-                            _current -= granted;
-                        }
+                    // Ensure that we grant credit only if the task has not been canceled.
+                    if (waiter.TrySetResult(granted))
+                    {
+                        _current -= granted;
                     }
+
+                    waiter.Dispose();
                 }
             }
         }
@@ -118,23 +143,104 @@ namespace System.Net.Http
 
                 _disposed = true;
 
-                if (_waiters != null)
+                Waiter waiter = _waitersTail;
+                if (waiter != null)
                 {
-                    while (_waiters.TryDequeue(out Waiter waiter))
+                    do
                     {
-                        waiter.TrySetException(ExceptionDispatchInfo.SetCurrentStackTrace(CreateObjectDisposedException(forActiveWaiter: true)));
+                        Waiter next = waiter.Next;
+                        waiter.Next = null;
+                        waiter.Dispose();
+                        waiter = next;
                     }
+                    while (waiter != _waitersTail);
+
+                    _waitersTail = null;
                 }
             }
         }
 
-        private ObjectDisposedException CreateObjectDisposedException(bool forActiveWaiter) => forActiveWaiter ?
-            new ObjectDisposedException($"{nameof(CreditManager)}:{_owner.GetType().Name}:{_name}", SR.net_http_disposed_while_in_use) :
-            new ObjectDisposedException($"{nameof(CreditManager)}:{_owner.GetType().Name}:{_name}");
+        /// <summary>Represents a waiter for credit.</summary>
+        /// <remarks>All of the public members on the instance must only be accessed while holding the CreditManager's lock.</remarks>
+        private class Waiter : IValueTaskSource<int>
+        {
+            public readonly int Amount;
+            public Waiter Next;
+            protected ManualResetValueTaskSourceCore<int> _source;
+
+            public Waiter(int amount)
+            {
+                Amount = amount;
+                _source.RunContinuationsAsynchronously = true;
+            }
+
+            public ValueTask<int> AsValueTask() => new ValueTask<int>(this, _source.Version);
+
+            public bool IsPending => _source.GetStatus(_source.Version) == ValueTaskSourceStatus.Pending;
 
-        private sealed class Waiter : TaskCompletionSourceWithCancellation<int>
+            public bool TrySetResult(int result)
+            {
+                if (IsPending)
+                {
+                    _source.SetResult(result);
+                    return true;
+                }
+
+                return false;
+            }
+
+            public virtual void Dispose()
+            {
+                if (IsPending)
+                {
+                    _source.SetException(ExceptionDispatchInfo.SetCurrentStackTrace(new ObjectDisposedException(nameof(CreditManager), SR.net_http_disposed_while_in_use)));
+                }
+            }
+
+            int IValueTaskSource<int>.GetResult(short token) =>
+                _source.GetResult(token);
+            ValueTaskSourceStatus IValueTaskSource<int>.GetStatus(short token) =>
+                _source.GetStatus(token);
+            void IValueTaskSource<int>.OnCompleted(Action<object> continuation, object state, short token, ValueTaskSourceOnCompletedFlags flags) =>
+                _source.OnCompleted(continuation, state, token, flags);
+        }
+
+        private sealed class CancelableWaiter : Waiter
         {
-            public int Amount;
+            private readonly object _syncObj;
+            private CancellationTokenRegistration _registration;
+
+            public CancelableWaiter(int amount, object syncObj, CancellationToken cancellationToken) : base(amount)
+            {
+                _syncObj = syncObj;
+                _registration = cancellationToken.UnsafeRegister(s =>
+                {
+                    CancelableWaiter thisRef = (CancelableWaiter)s!;
+                    lock (thisRef._syncObj)
+                    {
+                        if (thisRef.IsPending)
+                        {
+                            thisRef._source.SetException(ExceptionDispatchInfo.SetCurrentStackTrace(new OperationCanceledException(thisRef._registration.Token)));
+                            thisRef._registration = default; // benign race with setting in the ctor
+
+                            // We don't remove it from the list as we lack a prev pointer that would enable us to do so correctly,
+                            // and it's not worth adding a prev pointer for the rare case of cancellation.  We instead just
+                            // check when completing a waiter whether it's already been canceled.  As such, we also do not
+                            // dispose it here.
+                        }
+                    }
+                }, this);
+            }
+
+            public override void Dispose()
+            {
+                Monitor.IsEntered(_syncObj);
+
+                _registration.Dispose();
+                _registration = default;
+
+                base.Dispose();
+            }
         }
     }
 }
index be2323c..1e9f209 100644 (file)
@@ -159,31 +159,6 @@ namespace System.Net.Http
             _ = ProcessIncomingFramesAsync();
         }
 
-        private async ValueTask EnsureIncomingBytesAsync(int bytesNeeded)
-        {
-            Debug.Assert(bytesNeeded >= 0);
-            if (NetEventSource.IsEnabled) Trace($"{nameof(bytesNeeded)}={bytesNeeded}");
-
-            bytesNeeded -= _incomingBuffer.ActiveLength;
-            if (bytesNeeded > 0)
-            {
-                _incomingBuffer.EnsureAvailableSpace(bytesNeeded);
-                do
-                {
-                    int bytesRead = await _stream.ReadAsync(_incomingBuffer.AvailableMemory).ConfigureAwait(false);
-                    Debug.Assert(bytesRead >= 0);
-                    if (bytesRead == 0)
-                    {
-                        throw new IOException(SR.Format(SR.net_http_invalid_response_premature_eof_bytecount, bytesNeeded));
-                    }
-
-                    _incomingBuffer.Commit(bytesRead);
-                    bytesNeeded -= bytesRead;
-                }
-                while (bytesNeeded > 0);
-            }
-        }
-
         private async Task FlushOutgoingBytesAsync()
         {
             if (NetEventSource.IsEnabled) Trace($"{nameof(_outgoingBuffer.ActiveLength)}={_outgoingBuffer.ActiveLength}");
@@ -209,13 +184,21 @@ namespace System.Net.Http
         {
             if (NetEventSource.IsEnabled) Trace($"{nameof(initialFrame)}={initialFrame}");
 
-            // Read frame header
+            // Ensure we've read enough data for the frame header.
             if (_incomingBuffer.ActiveLength < FrameHeader.Size)
             {
-                await EnsureIncomingBytesAsync(FrameHeader.Size).ConfigureAwait(false);
+                _incomingBuffer.EnsureAvailableSpace(FrameHeader.Size - _incomingBuffer.ActiveLength);
+                do
+                {
+                    int bytesRead = await _stream.ReadAsync(_incomingBuffer.AvailableMemory).ConfigureAwait(false);
+                    _incomingBuffer.Commit(bytesRead);
+                    if (bytesRead == 0) ThrowPrematureEOF(FrameHeader.Size);
+                }
+                while (_incomingBuffer.ActiveLength < FrameHeader.Size);
             }
-            FrameHeader frameHeader = FrameHeader.ReadFrom(_incomingBuffer.ActiveSpan);
 
+            // Parse the frame header from our read buffer and validate it.
+            FrameHeader frameHeader = FrameHeader.ReadFrom(_incomingBuffer.ActiveSpan);
             if (frameHeader.Length > FrameHeader.MaxLength)
             {
                 if (initialFrame && NetEventSource.IsEnabled)
@@ -229,19 +212,31 @@ namespace System.Net.Http
             }
             _incomingBuffer.Discard(FrameHeader.Size);
 
-            // Read frame contents
+            // Ensure we've read the frame contents into our buffer.
             if (_incomingBuffer.ActiveLength < frameHeader.Length)
             {
-                await EnsureIncomingBytesAsync(frameHeader.Length).ConfigureAwait(false);
+                _incomingBuffer.EnsureAvailableSpace(frameHeader.Length - _incomingBuffer.ActiveLength);
+                do
+                {
+                    int bytesRead = await _stream.ReadAsync(_incomingBuffer.AvailableMemory).ConfigureAwait(false);
+                    _incomingBuffer.Commit(bytesRead);
+                    if (bytesRead == 0) ThrowPrematureEOF(frameHeader.Length);
+                }
+                while (_incomingBuffer.ActiveLength < frameHeader.Length);
             }
 
+            // Return the read frame header.
             return frameHeader;
+
+            void ThrowPrematureEOF(int requiredBytes) =>
+                throw new IOException(SR.Format(SR.net_http_invalid_response_premature_eof_bytecount, requiredBytes - _incomingBuffer.ActiveLength));
         }
 
         private async Task ProcessIncomingFramesAsync()
         {
             try
             {
+                // Read the initial settings frame.
                 FrameHeader frameHeader = await ReadFrameAsync(initialFrame: true).ConfigureAwait(false);
                 if (frameHeader.Type != FrameType.Settings || frameHeader.AckFlag)
                 {
@@ -255,10 +250,37 @@ namespace System.Net.Http
                 // Keep processing frames as they arrive.
                 for (long frameNum = 1; ; frameNum++)
                 {
-                    await EnsureIncomingBytesAsync(FrameHeader.Size).ConfigureAwait(false); // not functionally necessary, but often ReadFrameAsync yielding/allocating
+                    // We could just call ReadFrameAsync here, but we add this code before it for two reasons:
+                    // 1. To provide a better error message when we're unable to read another frame.  We otherwise
+                    //    generally output an error message that's relatively obscure.
+                    // 2. To avoid another state machine allocation in the relatively common case where we
+                    //    currently don't have enough data buffered and issuing a read for the frame header
+                    //    completes asynchronously, but that read ends up also reading enough data to fulfill
+                    //    the entire frame's needs (not just the header).
+                    if (_incomingBuffer.ActiveLength < FrameHeader.Size)
+                    {
+                        _incomingBuffer.EnsureAvailableSpace(FrameHeader.Size - _incomingBuffer.ActiveLength);
+                        do
+                        {
+                            int bytesRead = await _stream.ReadAsync(_incomingBuffer.AvailableMemory).ConfigureAwait(false);
+                            Debug.Assert(bytesRead >= 0);
+                            _incomingBuffer.Commit(bytesRead);
+                            if (bytesRead == 0)
+                            {
+                                string message = _incomingBuffer.ActiveLength == 0 ?
+                                    SR.net_http_invalid_response_missing_frame :
+                                    SR.Format(SR.net_http_invalid_response_premature_eof_bytecount, FrameHeader.Size - _incomingBuffer.ActiveLength);
+                                throw new IOException(message);
+                            }
+                        }
+                        while (_incomingBuffer.ActiveLength < FrameHeader.Size);
+                    }
+
+                    // Read the frame.
                     frameHeader = await ReadFrameAsync().ConfigureAwait(false);
                     if (NetEventSource.IsEnabled) Trace($"Frame {frameNum}: {frameHeader}.");
 
+                    // Process the frame.
                     switch (frameHeader.Type)
                     {
                         case FrameType.Headers:
index ccd636e..5746f09 100644 (file)
@@ -164,7 +164,9 @@ namespace System.Net.Http
 
                 // Create a linked cancellation token source so that we can cancel the request in the event of receiving RST_STREAM
                 // and similiar situations where we need to cancel the request body (see Cancel method).
-                _requestBodyCancellationToken = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken, _requestBodyCancellationSource.Token).Token;
+                _requestBodyCancellationToken = cancellationToken.CanBeCanceled ?
+                    CancellationTokenSource.CreateLinkedTokenSource(cancellationToken, _requestBodyCancellationSource.Token).Token :
+                    _requestBodyCancellationSource.Token;
 
                 try
                 {
@@ -1009,8 +1011,6 @@ namespace System.Net.Http
 
             private async ValueTask SendDataAsync(ReadOnlyMemory<byte> buffer, CancellationToken cancellationToken)
             {
-                ReadOnlyMemory<byte> remaining = buffer;
-
                 // Deal with [ActiveIssue("https://github.com/dotnet/runtime/issues/17492")]
                 // Custom HttpContent classes do not get passed the cancellationToken.
                 // So, inject the expected CancellationToken here, to ensure we can cancel the request body send if needed.
@@ -1027,18 +1027,22 @@ namespace System.Net.Http
                     cancellationToken = customCancellationSource.Token;
                 }
 
-                using (customCancellationSource)
+                try
                 {
-                    while (remaining.Length > 0)
+                    while (buffer.Length > 0)
                     {
-                        int sendSize = await _streamWindow.RequestCreditAsync(remaining.Length, cancellationToken).ConfigureAwait(false);
+                        int sendSize = await _streamWindow.RequestCreditAsync(buffer.Length, cancellationToken).ConfigureAwait(false);
 
                         ReadOnlyMemory<byte> current;
-                        (current, remaining) = SplitBuffer(remaining, sendSize);
+                        (current, buffer) = SplitBuffer(buffer, sendSize);
 
                         await _connection.SendStreamDataAsync(_streamId, current, cancellationToken).ConfigureAwait(false);
                     }
                 }
+                finally
+                {
+                    customCancellationSource?.Dispose();
+                }
             }
 
             private void CloseResponseBody()
index 81fc3bf..af703e0 100644 (file)
@@ -682,41 +682,54 @@ namespace System.Net.Security
             {
                 while (true)
                 {
-                    int copyBytes;
                     if (_decryptedBytesCount != 0)
                     {
-                        copyBytes = CopyDecryptedData(buffer);
-
-                        return copyBytes;
+                        return CopyDecryptedData(buffer);
                     }
 
-                    copyBytes = await adapter.ReadLockAsync(buffer).ConfigureAwait(false);
+                    int copyBytes = await adapter.ReadLockAsync(buffer).ConfigureAwait(false);
                     if (copyBytes > 0)
                     {
                         return copyBytes;
                     }
 
                     ResetReadBuffer();
-                    int readBytes = await FillBufferAsync(adapter, SecureChannel.ReadHeaderSize).ConfigureAwait(false);
-                    if (readBytes == 0)
+
+                    // Read the next frame header.
+                    if (_internalBufferCount < SecureChannel.ReadHeaderSize)
                     {
-                        return 0;
+                        // We don't have enough bytes buffered, so issue an initial read to try to get enough.  This is
+                        // done in this method both to better consolidate error handling logic (the first read is the special
+                        // case that needs to differentiate reading 0 from > 0, and everything else needs to throw if it
+                        // doesn't read enough), and to minimize the chances that in the common case the FillBufferAsync
+                        // helper needs to yield and allocate a state machine.
+                        int readBytes = await adapter.ReadAsync(_internalBuffer.AsMemory(_internalBufferCount)).ConfigureAwait(false);
+                        if (readBytes == 0)
+                        {
+                            return 0;
+                        }
+
+                        _internalBufferCount += readBytes;
+                        if (_internalBufferCount < SecureChannel.ReadHeaderSize)
+                        {
+                            await FillBufferAsync(adapter, SecureChannel.ReadHeaderSize).ConfigureAwait(false);
+                        }
                     }
+                    Debug.Assert(_internalBufferCount >= SecureChannel.ReadHeaderSize);
 
-                    int payloadBytes = GetFrameSize(new ReadOnlySpan<byte>(_internalBuffer, _internalOffset, readBytes));
+                    // Parse the frame header to determine the payload size (which includes the header size).
+                    int payloadBytes = GetFrameSize(_internalBuffer.AsSpan(_internalOffset));
                     if (payloadBytes < 0)
                     {
                         throw new IOException(SR.net_frame_read_size);
                     }
 
-                    readBytes = await FillBufferAsync(adapter, payloadBytes).ConfigureAwait(false);
-                    Debug.Assert(readBytes >= 0);
-                    if (readBytes == 0)
+                    // Read in the rest of the payload if we don't have it.
+                    if (_internalBufferCount < payloadBytes)
                     {
-                        throw new IOException(SR.net_io_eof);
+                        await FillBufferAsync(adapter, payloadBytes).ConfigureAwait(false);
                     }
 
-                    // At this point, readBytes contains the size of the header plus body.
                     // Set _decrytpedBytesOffset/Count to the current frame we have (including header)
                     // DecryptData will decrypt in-place and modify these to point to the actual decrypted data, which may be smaller.
                     _decryptedBytesOffset = _internalOffset;
@@ -725,7 +738,7 @@ namespace System.Net.Security
 
                     // Treat the bytes we just decrypted as consumed
                     // Note, we won't do another buffer read until the decrypted bytes are processed
-                    ConsumeBufferedBytes(readBytes);
+                    ConsumeBufferedBytes(payloadBytes);
 
                     if (status.ErrorCode != SecurityStatusPalErrorCode.OK)
                     {
@@ -826,68 +839,26 @@ namespace System.Net.Security
                         return minSize;
                     }
 
-                    int bytesNeeded = minSize - _handshakeBuffer.ActiveLength;
                     task = adap.ReadAsync(_handshakeBuffer.AvailableMemory);
                 }
             }
         }
 
-        private ValueTask<int> FillBufferAsync<TIOAdapter>(TIOAdapter adapter, int minSize)
+        private async ValueTask FillBufferAsync<TIOAdapter>(TIOAdapter adapter, int numBytesRequired)
             where TIOAdapter : ISslIOAdapter
         {
-            if (_internalBufferCount >= minSize)
-            {
-                return new ValueTask<int>(minSize);
-            }
+            Debug.Assert(_internalBufferCount > 0);
+            Debug.Assert(_internalBufferCount < numBytesRequired);
 
-            int initialCount = _internalBufferCount;
-            do
+            while (_internalBufferCount < numBytesRequired)
             {
-                ValueTask<int> t = adapter.ReadAsync(new Memory<byte>(_internalBuffer, _internalBufferCount, _internalBuffer.Length - _internalBufferCount));
-                if (!t.IsCompletedSuccessfully)
+                int bytesRead = await adapter.ReadAsync(_internalBuffer.AsMemory(_internalBufferCount)).ConfigureAwait(false);
+                if (bytesRead == 0)
                 {
-                    return InternalFillBufferAsync(adapter, t, minSize, initialCount);
+                    throw new IOException(SR.net_io_eof);
                 }
-                int bytes = t.Result;
-                if (bytes == 0)
-                {
-                    if (_internalBufferCount != initialCount)
-                    {
-                        // We read some bytes, but not as many as we expected, so throw.
-                        throw new IOException(SR.net_io_eof);
-                    }
 
-                    return new ValueTask<int>(0);
-                }
-
-                _internalBufferCount += bytes;
-            } while (_internalBufferCount < minSize);
-
-            return new ValueTask<int>(minSize);
-
-            async ValueTask<int> InternalFillBufferAsync(TIOAdapter adap, ValueTask<int> task, int min, int initial)
-            {
-                while (true)
-                {
-                    int b = await task.ConfigureAwait(false);
-                    if (b == 0)
-                    {
-                        if (_internalBufferCount != initial)
-                        {
-                            throw new IOException(SR.net_io_eof);
-                        }
-
-                        return 0;
-                    }
-
-                    _internalBufferCount += b;
-                    if (_internalBufferCount >= min)
-                    {
-                        return min;
-                    }
-
-                    task = adap.ReadAsync(new Memory<byte>(_internalBuffer, _internalBufferCount, _internalBuffer.Length - _internalBufferCount));
-                }
+                _internalBufferCount += bytesRead;
             }
         }
 
@@ -949,7 +920,7 @@ namespace System.Net.Security
             int copyBytes = Math.Min(_decryptedBytesCount, buffer.Length);
             if (copyBytes != 0)
             {
-                new Span<byte>(_internalBuffer, _decryptedBytesOffset, copyBytes).CopyTo(buffer.Span);
+                new ReadOnlySpan<byte>(_internalBuffer, _decryptedBytesOffset, copyBytes).CopyTo(buffer.Span);
 
                 _decryptedBytesOffset += copyBytes;
                 _decryptedBytesCount -= copyBytes;
@@ -1168,10 +1139,7 @@ namespace System.Net.Security
             return Framing.Unified; // Will use Ssl2 just for this frame.
         }
 
-        //
-        // This is called from SslStream class too.
         // Returns TLS Frame size.
-        //
         private int GetFrameSize(ReadOnlySpan<byte> buffer)
         {
             if (NetEventSource.IsEnabled)
@@ -1184,7 +1152,7 @@ namespace System.Net.Security
                 case Framing.BeforeSSL3:
                     if (buffer.Length < 2)
                     {
-                        throw new System.IO.IOException(SR.net_ssl_io_frame);
+                        throw new IOException(SR.net_ssl_io_frame);
                     }
                     // Note: Cannot detect version mismatch for <= SSL2
 
@@ -1203,7 +1171,7 @@ namespace System.Net.Security
                 case Framing.SinceSSL3:
                     if (buffer.Length < 5)
                     {
-                        throw new System.IO.IOException(SR.net_ssl_io_frame);
+                        throw new IOException(SR.net_ssl_io_frame);
                     }
 
                     payloadSize = ((buffer[3] << 8) | buffer[4]) + 5;