Avoid unobserved ReadAheadTask exceptions in HttpConnection (#80214)
authorMiha Zupan <mihazupan.zupan1@gmail.com>
Wed, 1 Feb 2023 16:12:58 +0000 (08:12 -0800)
committerGitHub <noreply@github.com>
Wed, 1 Feb 2023 16:12:58 +0000 (08:12 -0800)
* Rework ReadAheadTask impl in HttpConnection

* Handle sync exceptions in PrepareForReuse

* Split proactive EOF reactions out of the PR

* Add a few comments

* Remove the now-unnecessary ExecutionContext.IsFlowSuppressed dance

* PR feedback

src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/AuthenticationHelper.NtAuth.cs
src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpConnection.cs
src/libraries/System.Net.Http/tests/FunctionalTests/SocketsHttpHandlerTest.cs

index d6f36c4ed8f9943ca53c36a7526c9ff74c6295c7..4de834c4992d06fe278c029ca1d211eaeb5294f6 100644 (file)
@@ -50,7 +50,7 @@ namespace System.Net.Http
         private static Task<HttpResponseMessage> InnerSendAsync(HttpRequestMessage request, bool async, bool isProxyAuth, HttpConnectionPool pool, HttpConnection connection, CancellationToken cancellationToken)
         {
             return isProxyAuth ?
-                connection.SendAsyncCore(request, async, cancellationToken) :
+                connection.SendAsync(request, async, cancellationToken) :
                 pool.SendWithNtProxyAuthAsync(connection, request, async, cancellationToken);
         }
 
index eef8947f384d374aa8dec3da196d7e5258ea3227..fbfa024a647cbd8819b2b3b920d466ac65b703e4 100644 (file)
@@ -54,8 +54,11 @@ namespace System.Net.Http
         [ThreadStatic]
         private static string[]? t_headerValues;
 
-        private ValueTask<int>? _readAheadTask;
-        private int _readAheadTaskLock; // 0 == free, 1 == held
+        private const int ReadAheadTask_NotStarted = 0;
+        private const int ReadAheadTask_Started = 1;
+        private const int ReadAheadTask_CompletionReserved = 2;
+        private int _readAheadTaskStatus;
+        private ValueTask<int> _readAheadTask;
         private ArrayBuffer _readBuffer;
 
         private long _idleSinceTickCount;
@@ -125,19 +128,12 @@ namespace System.Net.Http
                 {
                     GC.SuppressFinalize(this);
                     _stream.Dispose();
-
-                    // Eat any exceptions from the read-ahead task.  We don't need to log, as we expect
-                    // failures from this task due to closing the connection while a read is in progress.
-                    ValueTask<int>? readAheadTask = ConsumeReadAheadTask();
-                    if (readAheadTask != null)
-                    {
-                        IgnoreExceptions(readAheadTask.GetValueOrDefault());
-                    }
                 }
             }
         }
 
-        /// <summary>Prepare an idle connection to be used for a new request.</summary>
+        /// <summary>Prepare an idle connection to be used for a new request.
+        /// The caller MUST call SendAsync afterwards if this method returns true.</summary>
         /// <param name="async">Indicates whether the coming request will be sync or async.</param>
         /// <returns>True if connection can be used, false if it is invalid due to a timeout or receiving EOF or unexpected data.</returns>
         public bool PrepareForReuse(bool async)
@@ -149,9 +145,9 @@ namespace System.Net.Http
 
             // We may already have a read-ahead task if we did a previous scavenge and haven't used the connection since.
             // If the read-ahead task is completed, then we've received either EOF or erroneous data the connection, so it's not usable.
-            if (_readAheadTask is not null)
+            if (ReadAheadTaskHasStarted)
             {
-                return !_readAheadTask.Value.IsCompleted;
+                return TryOwnReadAheadTaskCompletion();
             }
 
             // Check to see if we've received anything on the connection; if we have, that's
@@ -173,6 +169,9 @@ namespace System.Net.Http
             }
             else
             {
+                Debug.Assert(_readAheadTaskStatus == ReadAheadTask_NotStarted);
+                _readAheadTaskStatus = ReadAheadTask_CompletionReserved;
+
                 // Perform an async read on the stream, since we're going to need to read from it
                 // anyway, and in doing so we can avoid the extra syscall.
                 try
@@ -180,7 +179,8 @@ namespace System.Net.Http
 #pragma warning disable CA2012 // we're very careful to ensure the ValueTask is only consumed once, even though it's stored into a field
                     _readAheadTask = _stream.ReadAsync(_readBuffer.AvailableMemory);
 #pragma warning restore CA2012
-                    return !_readAheadTask.Value.IsCompleted;
+
+                    return !_readAheadTask.IsCompleted;
                 }
                 catch (Exception error)
                 {
@@ -201,26 +201,63 @@ namespace System.Net.Http
             }
 
             // We may already have a read-ahead task if we did a previous scavenge and haven't used the connection since.
-#pragma warning disable CA2012 // we're very careful to ensure the ValueTask is only consumed once, even though it's stored into a field
-            _readAheadTask ??= ReadAheadWithZeroByteReadAsync();
-#pragma warning restore CA2012
+            EnsureReadAheadTaskHasStarted();
 
             // If the read-ahead task is completed, then we've received either EOF or erroneous data the connection, so it's not usable.
-            return !_readAheadTask.Value.IsCompleted;
+            return !_readAheadTask.IsCompleted;
+        }
+
+        private bool ReadAheadTaskHasStarted =>
+            _readAheadTaskStatus != ReadAheadTask_NotStarted;
+
+        private bool TryOwnReadAheadTaskCompletion() =>
+            Interlocked.CompareExchange(ref _readAheadTaskStatus, ReadAheadTask_CompletionReserved, ReadAheadTask_Started) == ReadAheadTask_Started;
+
+        private void EnsureReadAheadTaskHasStarted()
+        {
+            if (_readAheadTaskStatus == ReadAheadTask_NotStarted)
+            {
+                Debug.Assert(_readAheadTask == default);
+
+                _readAheadTaskStatus = ReadAheadTask_Started;
+
+#pragma warning disable CA2012 // we're very careful to ensure the ValueTask is only consumed once, even though it's stored into a field
+                _readAheadTask = ReadAheadWithZeroByteReadAsync();
+#pragma warning restore CA2012
+            }
 
             async ValueTask<int> ReadAheadWithZeroByteReadAsync()
             {
-                Debug.Assert(_readAheadTask is null);
+                Debug.Assert(_readAheadTask == default);
                 Debug.Assert(_readBuffer.ActiveLength == 0);
 
-                // Issue a zero-byte read.
-                // If the underlying stream supports it, this will not complete until the stream has data available,
-                // which will avoid pinning the connection's read buffer (and possibly allow us to release it to the buffer pool in the future, if desired).
-                // If not, it will complete immediately.
-                await _stream.ReadAsync(Memory<byte>.Empty).ConfigureAwait(false);
+                try
+                {
+                    // Issue a zero-byte read.
+                    // If the underlying stream supports it, this will not complete until the stream has data available,
+                    // which will avoid pinning the connection's read buffer (and possibly allow us to release it to the buffer pool in the future, if desired).
+                    // If not, it will complete immediately.
+                    await _stream.ReadAsync(Memory<byte>.Empty).ConfigureAwait(false);
 
-                // We don't know for sure that the stream actually has data available, so we need to issue a real read now.
-                return await _stream.ReadAsync(_readBuffer.AvailableMemory).ConfigureAwait(false);
+                    // We don't know for sure that the stream actually has data available, so we need to issue a real read now.
+                    int read = await _stream.ReadAsync(_readBuffer.AvailableMemory).ConfigureAwait(false);
+
+                    // PrepareForReuse will check TryOwnReadAheadTaskCompletion before calling into SendAsync.
+                    // If we can own the completion from within the read-ahead task, it means that PrepareForReuse hasn't been called yet.
+                    // In that case we've received EOF/erroneous data before we sent the request headers, and the connection can't be reused.
+                    if (TryOwnReadAheadTaskCompletion())
+                    {
+                        if (NetEventSource.Log.IsEnabled()) Trace("Read-ahead task observed data before the request was sent.");
+                    }
+
+                    return read;
+                }
+                catch (Exception error) when (TryOwnReadAheadTaskCompletion())
+                {
+                    if (NetEventSource.Log.IsEnabled()) Trace($"Error performing read ahead: {error}");
+
+                    return 0;
+                }
             }
         }
 
@@ -233,21 +270,6 @@ namespace System.Net.Http
                 GetIdleTicks(Environment.TickCount64) >= _keepAliveTimeoutSeconds * 1000;
         }
 
-        private ValueTask<int>? ConsumeReadAheadTask()
-        {
-            if (Interlocked.CompareExchange(ref _readAheadTaskLock, 1, 0) == 0)
-            {
-                ValueTask<int>? t = _readAheadTask;
-                _readAheadTask = null;
-                Volatile.Write(ref _readAheadTaskLock, 0);
-                return t;
-            }
-
-            // We couldn't get the lock, which means it must already be held
-            // by someone else who will consume the task.
-            return null;
-        }
-
         public override long GetIdleTicks(long nowTicks) => nowTicks - _idleSinceTickCount;
 
         public TransportContext? TransportContext => _transportContext;
@@ -489,10 +511,11 @@ namespace System.Net.Http
                 throw new HttpRequestException(SR.net_http_request_invalid_char_encoding);
         }
 
-        public async Task<HttpResponseMessage> SendAsyncCore(HttpRequestMessage request, bool async, CancellationToken cancellationToken)
+        public async Task<HttpResponseMessage> SendAsync(HttpRequestMessage request, bool async, CancellationToken cancellationToken)
         {
             Debug.Assert(_currentRequest == null, $"Expected null {nameof(_currentRequest)}.");
             Debug.Assert(_readBuffer.ActiveLength == 0, "Unexpected data in read buffer");
+            Debug.Assert(_readAheadTaskStatus != ReadAheadTask_Started);
 
             TaskCompletionSource<bool>? allowExpect100ToContinue = null;
             Task? sendRequestContentTask = null;
@@ -559,15 +582,22 @@ namespace System.Net.Http
 
                 // When the connection was taken out of the pool, a pre-emptive read was performed
                 // into the read buffer. We need to consume that read prior to issuing another read.
-                ValueTask<int>? t = ConsumeReadAheadTask();
-                if (t != null)
+                if (ReadAheadTaskHasStarted)
                 {
+                    // If the read-ahead task completed synchronously, it would have claimed ownership of its completion,
+                    // meaning that PrepareForReuse would have failed, and we wouldn't have called SendAsync.
+                    // The task therefore shouldn't be 'default', as it's representing an async operation that had to yield at some point.
+                    Debug.Assert(_readAheadTask != default);
+                    Debug.Assert(_readAheadTaskStatus == ReadAheadTask_CompletionReserved);
+
                     // Handle the pre-emptive read.  For the async==false case, hopefully the read has
                     // already completed and this will be a nop, but if it hasn't, the caller will be forced to block
                     // waiting for the async operation to complete.  We will only hit this case for proxied HTTPS
                     // requests that use a pooled connection, as in that case we don't have a Socket we
                     // can poll and are forced to issue an async read.
-                    ValueTask<int> vt = t.GetValueOrDefault();
+                    ValueTask<int> vt = _readAheadTask;
+                    _readAheadTask = default;
+
                     int bytesRead;
                     if (vt.IsCompleted)
                     {
@@ -586,6 +616,8 @@ namespace System.Net.Http
                     _readBuffer.Commit(bytesRead);
 
                     if (NetEventSource.Log.IsEnabled()) Trace($"Received {bytesRead} bytes.");
+
+                    _readAheadTaskStatus = ReadAheadTask_NotStarted;
                 }
                 else
                 {
@@ -795,6 +827,13 @@ namespace System.Net.Http
                 // Make sure to complete the allowExpect100ToContinue task if it exists.
                 allowExpect100ToContinue?.TrySetResult(false);
 
+                if (_readAheadTask != default)
+                {
+                    Debug.Assert(_readAheadTaskStatus == ReadAheadTask_CompletionReserved);
+
+                    LogExceptions(_readAheadTask.AsTask());
+                }
+
                 if (NetEventSource.Log.IsEnabled()) Trace($"Error sending request: {error}");
 
                 // In the rare case where Expect: 100-continue was used and then processing
@@ -838,9 +877,6 @@ namespace System.Net.Http
             }
         }
 
-        public Task<HttpResponseMessage> SendAsync(HttpRequestMessage request, bool async, CancellationToken cancellationToken) =>
-            SendAsyncCore(request, async, cancellationToken);
-
         private bool MapSendException(Exception exception, CancellationToken cancellationToken, out Exception mappedException)
         {
             if (CancellationHelper.ShouldWrapInOperationCanceledException(exception, cancellationToken))
@@ -1557,7 +1593,7 @@ namespace System.Net.Http
         // Does not throw on EOF. Also assumes there is no buffered data.
         private async ValueTask InitialFillAsync(bool async)
         {
-            Debug.Assert(_readAheadTask == null);
+            Debug.Assert(!ReadAheadTaskHasStarted);
             Debug.Assert(_readBuffer.AvailableLength == _readBuffer.Capacity);
             Debug.Assert(_readBuffer.AvailableLength >= InitialReadBufferSize);
 
@@ -1573,7 +1609,7 @@ namespace System.Net.Http
         // Throws IOException on EOF.  This is only called when we expect more data.
         private async ValueTask FillAsync(bool async)
         {
-            Debug.Assert(_readAheadTask == null);
+            Debug.Assert(_readAheadTask == default);
 
             _readBuffer.EnsureAvailableSpace(1);
 
@@ -1700,7 +1736,7 @@ namespace System.Net.Http
 
             // No data in read buffer.
             // Do an unbuffered read directly against the underlying stream.
-            Debug.Assert(_readAheadTask == null, "Read ahead task should have been consumed as part of the headers.");
+            Debug.Assert(_readAheadTask == default, "Read ahead task should have been consumed as part of the headers.");
             int count = _stream.Read(destination);
             if (NetEventSource.Log.IsEnabled()) Trace($"Received {count} bytes.");
             return count;
@@ -1718,7 +1754,7 @@ namespace System.Net.Http
 
             // No data in read buffer.
             // Do an unbuffered read directly against the underlying stream.
-            Debug.Assert(_readAheadTask == null, "Read ahead task should have been consumed as part of the headers.");
+            Debug.Assert(_readAheadTask == default, "Read ahead task should have been consumed as part of the headers.");
             int count = await _stream.ReadAsync(destination).ConfigureAwait(false);
             if (NetEventSource.Log.IsEnabled()) Trace($"Received {count} bytes.");
             return count;
@@ -1731,7 +1767,7 @@ namespace System.Net.Http
             if (_readBuffer.ActiveLength == 0)
             {
                 // Do a buffered read directly against the underlying stream.
-                Debug.Assert(_readAheadTask == null, "Read ahead task should have been consumed as part of the headers.");
+                Debug.Assert(_readAheadTask == default, "Read ahead task should have been consumed as part of the headers.");
 
                 if (destination.Length == 0)
                 {
@@ -1768,7 +1804,7 @@ namespace System.Net.Http
             if (_readBuffer.ActiveLength == 0)
             {
                 // Do a buffered read directly against the underlying stream.
-                Debug.Assert(_readAheadTask == null, "Read ahead task should have been consumed as part of the headers.");
+                Debug.Assert(_readAheadTask == default, "Read ahead task should have been consumed as part of the headers.");
 
                 Debug.Assert(_readBuffer.AvailableLength == _readBuffer.Capacity);
                 int bytesRead = await _stream.ReadAsync(_readBuffer.AvailableMemory).ConfigureAwait(false);
@@ -2025,7 +2061,8 @@ namespace System.Net.Http
         private void ReturnConnectionToPool()
         {
             Debug.Assert(_currentRequest == null, "Connection should no longer be associated with a request.");
-            Debug.Assert(_readAheadTask == null, "Expected a previous initial read to already be consumed.");
+            Debug.Assert(_readAheadTask == default, "Expected a previous initial read to already be consumed.");
+            Debug.Assert(_readAheadTaskStatus == ReadAheadTask_NotStarted, "Expected SendAsync to reset the read-ahead task status.");
             Debug.Assert(_readBuffer.ActiveLength == 0, "Unexpected data in connection read buffer.");
 
             // If we decided not to reuse the connection (either because the server sent Connection: close,
index 4ad0546969f0c098d4c4afa8c4330ee3cef99439..f33667e2d72e08f4f0781e0b3de6760b259c4639 100644 (file)
@@ -33,6 +33,72 @@ namespace System.Net.Http.Functional.Tests
     {
         public SocketsHttpHandler_HttpClientHandler_Asynchrony_Test(ITestOutputHelper output) : base(output) { }
 
+        [OuterLoop("Relies on finalization")]
+        [Fact]
+        public async Task ReadAheadTaskOnScavenge_ExceptionsAreObserved()
+        {
+            bool seenUnobservedExceptions = false;
+
+            EventHandler<UnobservedTaskExceptionEventArgs> eventHandler = (_, e) =>
+            {
+                if (e.Exception.InnerException?.Message == nameof(ReadAheadTaskOnScavenge_ExceptionsAreObserved))
+                {
+                    seenUnobservedExceptions = true;
+                }
+            };
+
+            TaskScheduler.UnobservedTaskException += eventHandler;
+            try
+            {
+                for (int i = 0; i < 3; i++)
+                {
+                    await MakeARequestWithoutDisposingTheHandlerAsync();
+                    GC.Collect();
+                    GC.WaitForPendingFinalizers();
+                    await Task.Delay(1000);
+                }
+            }
+            finally
+            {
+                TaskScheduler.UnobservedTaskException -= eventHandler;
+            }
+
+            Assert.False(seenUnobservedExceptions);
+
+            static async Task MakeARequestWithoutDisposingTheHandlerAsync()
+            {
+                var cts = new CancellationTokenSource();
+                var requestCompleted = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
+
+                var handler = new SocketsHttpHandler();
+                handler.ConnectCallback = async (_, _) =>
+                {
+                    cts.Cancel();
+                    await requestCompleted.Task;
+
+                    Task completedWhenFinalized = new SetOnFinalized().CompletedWhenFinalized.Task;
+
+                    return new DelegateDelegatingStream(Stream.Null)
+                    {
+                        ReadAsyncMemoryFunc = async (_, _) =>
+                        {
+                            await completedWhenFinalized.WaitAsync(TestHelper.PassingTestTimeout);
+
+                            throw new Exception(nameof(ReadAheadTaskOnScavenge_ExceptionsAreObserved));
+                        }
+                    };
+                };
+
+                handler.PooledConnectionIdleTimeout = TimeSpan.FromSeconds(1);
+
+                var client = new HttpClient(handler);
+
+                await Assert.ThrowsAsync<TaskCanceledException>(() => client.GetStringAsync("http://foo", cts.Token));
+
+                requestCompleted.SetResult();
+            }
+        }
+
         [Fact]
         public async Task ExecutionContext_Suppressed_Success()
         {
@@ -91,10 +157,9 @@ namespace System.Net.Http.Functional.Tests
         [MethodImpl(MethodImplOptions.NoInlining)] // avoid JIT extending lifetime of the finalizable object
         private static (Task completedOnFinalized, Task getRequest) MakeHttpRequestWithTcsSetOnFinalizationInAsyncLocal(HttpClient client, Uri uri)
         {
-            var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
-
             // Put something in ExecutionContext, start the HTTP request, then undo the EC change.
-            var al = new AsyncLocal<object>() { Value = new SetOnFinalized() { _completedWhenFinalized = tcs } };
+            var al = new AsyncLocal<SetOnFinalized>() { Value = new SetOnFinalized() };
+            TaskCompletionSource tcs = al.Value.CompletedWhenFinalized;
             Task t = client.GetStringAsync(uri);
             al.Value = null;
 
@@ -108,8 +173,9 @@ namespace System.Net.Http.Functional.Tests
 
         private sealed class SetOnFinalized
         {
-            internal TaskCompletionSource _completedWhenFinalized;
-            ~SetOnFinalized() => _completedWhenFinalized.SetResult();
+            public readonly TaskCompletionSource CompletedWhenFinalized = new(TaskCreationOptions.RunContinuationsAsynchronously);
+
+            ~SetOnFinalized() => CompletedWhenFinalized.SetResult();
         }
     }