Simplify Http2Connection shutdown/dispose logic (#90094)
authorMiha Zupan <mihazupan.zupan1@gmail.com>
Tue, 8 Aug 2023 17:44:03 +0000 (19:44 +0200)
committerGitHub <noreply@github.com>
Tue, 8 Aug 2023 17:44:03 +0000 (10:44 -0700)
* Simplify Http2Connection shutdown/dispose logic

* Add a test for SocketsHttpHandler disposal mid request

src/libraries/Common/tests/System/IO/DelegateDelegatingStream.cs
src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http2Connection.cs
src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpConnectionPool.cs
src/libraries/System.Net.Http/tests/FunctionalTests/SocketsHttpHandlerTest.Cancellation.cs

index 53de599..e801eb9 100644 (file)
@@ -37,6 +37,8 @@ namespace System.IO
         public Action<byte[], int, int> WriteFunc { get; set; }
         public Func<byte[], int, int, CancellationToken, Task> WriteAsyncArrayFunc { get; set; }
         public Func<ReadOnlyMemory<byte>, CancellationToken, ValueTask> WriteAsyncMemoryFunc { get; set; }
+        public Action<Stream, int> CopyToFunc { get; set; }
+        public Func<Stream, int, CancellationToken, Task> CopyToAsyncFunc { get; set; }
         public Action<bool> DisposeFunc { get; set; }
         public Func<ValueTask> DisposeAsyncFunc { get; set; }
 
@@ -62,6 +64,9 @@ namespace System.IO
         public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) => WriteAsyncArrayFunc != null ? WriteAsyncArrayFunc(buffer, offset, count, cancellationToken) : base.WriteAsync(buffer, offset, count, cancellationToken);
         public override ValueTask WriteAsync(ReadOnlyMemory<byte> buffer, CancellationToken cancellationToken = default) => WriteAsyncMemoryFunc != null ? WriteAsyncMemoryFunc(buffer, cancellationToken) : base.WriteAsync(buffer, cancellationToken);
 
+        public override void CopyTo(Stream destination, int bufferSize) { if (CopyToFunc != null) CopyToFunc(destination, bufferSize); else base.CopyTo(destination, bufferSize); }
+        public override Task CopyToAsync(Stream destination, int bufferSize, CancellationToken cancellationToken) => CopyToAsyncFunc != null ? CopyToAsyncFunc(destination, bufferSize, cancellationToken) : base.CopyToAsync(destination, bufferSize, cancellationToken);
+
         protected override void Dispose(bool disposing) { if (DisposeFunc != null) DisposeFunc(disposing); else base.Dispose(disposing); }
         public override ValueTask DisposeAsync() => DisposeAsyncFunc != null ? DisposeAsyncFunc() : base.DisposeAsync();
     }
index 2b2a1d0..f96ce45 100644 (file)
@@ -64,18 +64,15 @@ namespace System.Net.Http
         // (1) We received a GOAWAY frame from the server
         // (2) We have exhaustead StreamIds (i.e. _nextStream == MaxStreamId)
         // (3) A connection-level error occurred, in which case _abortException below is set.
+        // (4) The connection is being disposed.
+        // Requests currently in flight will continue to be processed.
+        // When all requests have completed, the connection will be torn down.
         private bool _shutdown;
-        private TaskCompletionSource? _shutdownWaiter;
 
         // If this is set, the connection is aborting due to an IO failure (IOException) or a protocol violation (Http2ProtocolException).
         // _shutdown above is true, and requests in flight have been (or are being) failed.
         private Exception? _abortException;
 
-        // This means that the user (i.e. the connection pool) has disposed us and will not submit further requests.
-        // Requests currently in flight will continue to be processed.
-        // When all requests have completed, the connection will be torn down.
-        private bool _disposed;
-
         private const int MaxStreamId = int.MaxValue;
 
         // Temporary workaround for request burst handling on connection start.
@@ -255,51 +252,23 @@ namespace System.Net.Http
             _ = ProcessOutgoingFramesAsync();
         }
 
-        // This will complete when the connection begins to shut down and cannot be used anymore, or if it is disposed.
-        public ValueTask WaitForShutdownAsync()
-        {
-            lock (SyncObject)
-            {
-                Debug.Assert(!_disposed, "As currently used, we don't expect to call this after disposing and we don't handle the ODE");
-                ObjectDisposedException.ThrowIf(_disposed, this);
-
-                if (_shutdown)
-                {
-                    Debug.Assert(_shutdownWaiter is null);
-                    return default;
-                }
-
-                _shutdownWaiter ??= new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
-
-                return new ValueTask(_shutdownWaiter.Task);
-            }
-        }
-
         private void Shutdown()
         {
             if (NetEventSource.Log.IsEnabled()) Trace($"{nameof(_shutdown)}={_shutdown}, {nameof(_abortException)}={_abortException}");
 
             Debug.Assert(Monitor.IsEntered(SyncObject));
 
-            SignalAvailableStreamsWaiter(false);
-            SignalShutdownWaiter();
-
-            // Note _shutdown could already be set, but that's fine.
-            _shutdown = true;
-        }
-
-        private void SignalShutdownWaiter()
-        {
-            if (NetEventSource.Log.IsEnabled()) Trace($"{nameof(_shutdownWaiter)}?={_shutdownWaiter is not null}");
+            if (!_shutdown)
+            {
+                _pool.InvalidateHttp2Connection(this);
+                SignalAvailableStreamsWaiter(false);
 
-            Debug.Assert(Monitor.IsEntered(SyncObject));
+                _shutdown = true;
 
-            if (_shutdownWaiter is not null)
-            {
-                Debug.Assert(!_disposed);
-                Debug.Assert(!_shutdown);
-                _shutdownWaiter.SetResult();
-                _shutdownWaiter = null;
+                if (_streamsInUse == 0)
+                {
+                    FinalTeardown();
+                }
             }
         }
 
@@ -307,9 +276,6 @@ namespace System.Net.Http
         {
             lock (SyncObject)
             {
-                Debug.Assert(!_disposed, "As currently used, we don't expect to call this after disposing and we don't handle the ODE");
-                ObjectDisposedException.ThrowIf(_disposed, this);
-
                 if (_shutdown)
                 {
                     return false;
@@ -353,7 +319,7 @@ namespace System.Net.Http
                 {
                     MarkConnectionAsIdle();
 
-                    if (_disposed)
+                    if (_shutdown)
                     {
                         FinalTeardown();
                     }
@@ -367,9 +333,6 @@ namespace System.Net.Http
         {
             lock (SyncObject)
             {
-                Debug.Assert(!_disposed, "As currently used, we don't expect to call this after disposing and we don't handle the ODE");
-                ObjectDisposedException.ThrowIf(_disposed, this);
-
                 Debug.Assert(_availableStreamsWaiter is null, "As used currently, shouldn't already have a waiter");
 
                 if (_shutdown)
@@ -396,7 +359,6 @@ namespace System.Net.Http
 
             if (_availableStreamsWaiter is not null)
             {
-                Debug.Assert(!_disposed);
                 Debug.Assert(!_shutdown);
                 _availableStreamsWaiter.SetResult(result);
                 _availableStreamsWaiter = null;
@@ -1213,7 +1175,7 @@ namespace System.Net.Http
                 // We must be trying to send something asynchronously (like RST_STREAM or a PING or a SETTINGS ACK) and it has raced with the connection tear down.
                 // As such, it should not matter that we were not able to actually send the frame.
                 // But just in case, throw ObjectDisposedException. Asynchronous callers will ignore the failure.
-                Debug.Assert(_disposed && _streamsInUse == 0);
+                Debug.Assert(_shutdown && _streamsInUse == 0);
                 return Task.FromException(new ObjectDisposedException(nameof(Http2Connection)));
             }
 
@@ -1342,7 +1304,7 @@ namespace System.Net.Http
 
         internal void HeartBeat()
         {
-            if (_disposed)
+            if (_shutdown)
                 return;
 
             try
@@ -1880,7 +1842,7 @@ namespace System.Net.Http
         {
             if (NetEventSource.Log.IsEnabled()) Trace("");
 
-            Debug.Assert(_disposed);
+            Debug.Assert(_shutdown);
             Debug.Assert(_streamsInUse == 0);
 
             GC.SuppressFinalize(this);
@@ -1901,20 +1863,7 @@ namespace System.Net.Http
         {
             lock (SyncObject)
             {
-                if (NetEventSource.Log.IsEnabled()) Trace($"{nameof(_disposed)}={_disposed}, {nameof(_streamsInUse)}={_streamsInUse}");
-
-                if (!_disposed)
-                {
-                    SignalAvailableStreamsWaiter(false);
-                    SignalShutdownWaiter();
-
-                    _disposed = true;
-
-                    if (_streamsInUse == 0)
-                    {
-                        FinalTeardown();
-                    }
-                }
+                Shutdown();
             }
         }
 
index cf40ee0..163a949 100644 (file)
@@ -742,17 +742,8 @@ namespace System.Net.Http
 
             if (connection is not null)
             {
-                // Register for shutdown notification.
-                // Do this before we return the connection to the pool, because that may result in it being disposed.
-                ValueTask shutdownTask = connection.WaitForShutdownAsync();
-
                 // Add the new connection to the pool.
                 ReturnHttp2Connection(connection, isNewConnection: true, queueItem.Waiter);
-
-                // Wait for connection shutdown.
-                await shutdownTask.ConfigureAwait(false);
-
-                InvalidateHttp2Connection(connection);
             }
             else
             {
index 6adf4a5..c793a1d 100644 (file)
@@ -412,5 +412,67 @@ namespace System.Net.Http.Functional.Tests
                 return base.SerializeToStreamAsync(stream, context);
             }
         }
+
+        [Theory]
+        [InlineData(true)]
+        [InlineData(false)]
+        public async Task RequestSent_HandlerDisposed_RequestIsUnaffected(bool post)
+        {
+            byte[] postContent = "Hello world"u8.ToArray();
+
+            TaskCompletionSource serverReceivedRequest = new(TaskCreationOptions.RunContinuationsAsynchronously);
+
+            await LoopbackServerFactory.CreateClientAndServerAsync(async uri =>
+            {
+                using HttpClientHandler handler = CreateHttpClientHandler();
+                using HttpClient client = CreateHttpClient(handler);
+
+                using HttpRequestMessage request = CreateRequest(post ? HttpMethod.Post : HttpMethod.Get, uri, UseVersion);
+
+                if (post)
+                {
+                    request.Content = new StreamContent(new DelegateDelegatingStream(new MemoryStream())
+                    {
+                        CanSeekFunc = () => false,
+                        CopyToFunc = (destination, _) =>
+                        {
+                            destination.Flush();
+                            Assert.True(serverReceivedRequest.Task.Wait(TestHelper.PassingTestTimeout));
+                            destination.Write(postContent);
+                        },
+                        CopyToAsyncFunc = async (destination, _, ct) =>
+                        {
+                            await destination.FlushAsync(ct);
+                            await serverReceivedRequest.Task.WaitAsync(ct);
+                            await destination.WriteAsync(postContent, ct);
+                        }
+                    });
+                }
+
+                Task<HttpResponseMessage> clientTask = client.SendAsync(TestAsync, request);
+                await serverReceivedRequest.Task.WaitAsync(TestHelper.PassingTestTimeout);
+
+                handler.Dispose();
+                await Task.Delay(1); // Give any potential disposal/cancellation some time to propagate
+
+                await clientTask;
+            },
+            async server =>
+            {
+                await server.AcceptConnectionAsync(async connection =>
+                {
+                    await connection.ReadRequestDataAsync(readBody: false);
+                    serverReceivedRequest.SetResult();
+
+                    if (post)
+                    {
+                        byte[] received = await connection.ReadRequestBodyAsync();
+                        Assert.Equal(postContent, received);
+                    }
+
+                    await connection.SendResponseAsync();
+                });
+            });
+        }
     }
 }