Remove Http3RequestStream from _activeRequest when response is read/aborted. (#72670)
authorRadek Zikmund <32671551+rzikm@users.noreply.github.com>
Tue, 26 Jul 2022 14:43:30 +0000 (16:43 +0200)
committerGitHub <noreply@github.com>
Tue, 26 Jul 2022 14:43:30 +0000 (16:43 +0200)
* Remove Http3RequestStream from _activeRequest when response is read/aborted.

* Add time limit in test

* Wait for both sides to be completed

* Code review feedback

src/libraries/Common/tests/System/Net/Http/HttpClientHandlerTest.Cancellation.cs
src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http3Connection.cs
src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http3RequestStream.cs

index 9382e87..4afc8a7 100644 (file)
@@ -259,18 +259,14 @@ namespace System.Net.Http.Functional.Tests
                     Task<HttpResponseMessage> getResponse = client.SendAsync(TestAsync, req, HttpCompletionOption.ResponseHeadersRead, cts.Token);
                     await ValidateClientCancellationAsync(async () =>
                     {
-                        // This 'using' shouldn't be necessary in general. However, HTTP3 does not remove the request stream from the
-                        // active stream table until the user disposes the response (or it gets finalized).
-                        // This means the connection will fail to shut down promptly.
-                        // See https://github.com/dotnet/runtime/issues/58072
-                        using HttpResponseMessage resp = await getResponse;
+                        HttpResponseMessage resp = await getResponse;
                         Stream respStream = await resp.Content.ReadAsStreamAsync(TestAsync);
                         Task readTask = readOrCopyToAsync ?
                             respStream.ReadAsync(new byte[1], 0, 1, cts.Token) :
                             respStream.CopyToAsync(Stream.Null, 10, cts.Token);
                         cts.Cancel();
                         await readTask;
-                    });
+                    }).WaitAsync(TimeSpan.FromSeconds(30));
                     try
                     {
                         clientFinished.SetResult(true);
@@ -499,7 +495,7 @@ namespace System.Net.Http.Functional.Tests
                     canSeekFunc: () => true,
                     lengthFunc: () => 1,
                     positionGetFunc: () => 0,
-                    positionSetFunc: _ => {},
+                    positionSetFunc: _ => { },
                     readAsyncFunc: async (buffer, offset, count, cancellationToken) =>
                     {
                         int result = 1;
@@ -535,7 +531,7 @@ namespace System.Net.Http.Functional.Tests
                     canSeekFunc: () => true,
                     lengthFunc: () => 1,
                     positionGetFunc: () => 0,
-                    positionSetFunc: _ => {},
+                    positionSetFunc: _ => { },
                     readAsyncFunc: async (buffer, offset, count, cancellationToken) =>
                     {
                         int result = 1;
@@ -584,14 +580,14 @@ namespace System.Net.Http.Functional.Tests
                 {
                     using (var invoker = new HttpMessageInvoker(CreateHttpClientHandler()))
                     using (var req = new HttpRequestMessage(HttpMethod.Post, uri) { Content = content, Version = UseVersion })
-                    try
-                    {
-                        using (HttpResponseMessage resp = await invoker.SendAsync(TestAsync, req, cancellationTokenSource.Token))
+                        try
                         {
-                            Assert.Equal("Hello World", await resp.Content.ReadAsStringAsync());
+                            using (HttpResponseMessage resp = await invoker.SendAsync(TestAsync, req, cancellationTokenSource.Token))
+                            {
+                                Assert.Equal("Hello World", await resp.Content.ReadAsStringAsync());
+                            }
                         }
-                    }
-                    catch (OperationCanceledException) { }
+                        catch (OperationCanceledException) { }
                 },
                 async server =>
                 {
index bfebc04..f0403c4 100644 (file)
@@ -338,9 +338,8 @@ namespace System.Net.Http
             lock (SyncObj)
             {
                 bool removed = _activeRequests.Remove(stream);
-                Debug.Assert(removed);
 
-                if (ShuttingDown)
+                if (removed && ShuttingDown)
                 {
                     CheckForShutdown();
                 }
index d5a4e75..a11f8b8 100644 (file)
@@ -56,6 +56,9 @@ namespace System.Net.Http
         // For the precomputed length case, we need to add the DATA framing for the first write only.
         private bool _singleDataFrameWritten;
 
+        private bool _requestSendCompleted;
+        private bool _responseRecvCompleted;
+
         public long StreamId
         {
             get => Volatile.Read(ref _streamId);
@@ -74,6 +77,9 @@ namespace System.Net.Http
             _headerDecoder = new QPackDecoder(maxHeadersLength: (int)Math.Min(int.MaxValue, _headerBudgetRemaining));
 
             _requestBodyCancellationSource = new CancellationTokenSource();
+
+            _requestSendCompleted = _request.Content == null;
+            _responseRecvCompleted = false;
         }
 
         public void Dispose()
@@ -87,6 +93,14 @@ namespace System.Net.Http
             }
         }
 
+        private void RemoveFromConnectionIfDone()
+        {
+            if (_responseRecvCompleted && _requestSendCompleted)
+            {
+                _connection.RemoveStream(_stream);
+            }
+        }
+
         public async ValueTask DisposeAsync()
         {
             if (!_disposed)
@@ -338,74 +352,79 @@ namespace System.Net.Http
 
         private async Task SendContentAsync(HttpContent content, CancellationToken cancellationToken)
         {
-            // If we're using Expect 100 Continue, wait to send content
-            // until we get a response back or until our timeout elapses.
-            if (_expect100ContinueCompletionSource != null)
+            try
             {
-                Timer? timer = null;
-
-                try
+                // If we're using Expect 100 Continue, wait to send content
+                // until we get a response back or until our timeout elapses.
+                if (_expect100ContinueCompletionSource != null)
                 {
-                    if (_connection.Pool.Settings._expect100ContinueTimeout != Timeout.InfiniteTimeSpan)
-                    {
-                        timer = new Timer(static o => ((Http3RequestStream)o!)._expect100ContinueCompletionSource!.TrySetResult(true),
-                            this, _connection.Pool.Settings._expect100ContinueTimeout, Timeout.InfiniteTimeSpan);
-                    }
+                    Timer? timer = null;
 
-                    if (!await _expect100ContinueCompletionSource.Task.ConfigureAwait(false))
+                    try
                     {
-                        // We received an error response code, so the body should not be sent.
-                        return;
+                        if (_connection.Pool.Settings._expect100ContinueTimeout != Timeout.InfiniteTimeSpan)
+                        {
+                            timer = new Timer(static o => ((Http3RequestStream)o!)._expect100ContinueCompletionSource!.TrySetResult(true),
+                                this, _connection.Pool.Settings._expect100ContinueTimeout, Timeout.InfiniteTimeSpan);
+                        }
+
+                        if (!await _expect100ContinueCompletionSource.Task.ConfigureAwait(false))
+                        {
+                            // We received an error response code, so the body should not be sent.
+                            return;
+                        }
                     }
-                }
-                finally
-                {
-                    if (timer != null)
+                    finally
                     {
-                        await timer.DisposeAsync().ConfigureAwait(false);
+                        if (timer != null)
+                        {
+                            await timer.DisposeAsync().ConfigureAwait(false);
+                        }
                     }
                 }
-            }
 
-            if (HttpTelemetry.Log.IsEnabled()) HttpTelemetry.Log.RequestContentStart();
+                if (HttpTelemetry.Log.IsEnabled()) HttpTelemetry.Log.RequestContentStart();
 
-            // If we have a Content-Length, keep track of it so we don't over-send and so we can send in a single DATA frame.
-            _requestContentLengthRemaining = content.Headers.ContentLength ?? -1;
+                // If we have a Content-Length, keep track of it so we don't over-send and so we can send in a single DATA frame.
+                _requestContentLengthRemaining = content.Headers.ContentLength ?? -1;
 
-            var writeStream = new Http3WriteStream(this);
-            try
-            {
-                await content.CopyToAsync(writeStream, null, cancellationToken).ConfigureAwait(false);
-            }
-            finally
-            {
-                writeStream.Dispose();
-            }
+                long bytesWritten;
+                using (var writeStream = new Http3WriteStream(this))
+                {
+                    await content.CopyToAsync(writeStream, null, cancellationToken).ConfigureAwait(false);
+                    bytesWritten = writeStream.BytesWritten;
+                }
 
-            if (_requestContentLengthRemaining > 0)
-            {
-                // The number of bytes we actually sent doesn't match the advertised Content-Length
-                long contentLength = content.Headers.ContentLength.GetValueOrDefault();
-                long sent = contentLength - _requestContentLengthRemaining;
-                throw new HttpRequestException(SR.Format(SR.net_http_request_content_length_mismatch, sent, contentLength));
-            }
+                if (_requestContentLengthRemaining > 0)
+                {
+                    // The number of bytes we actually sent doesn't match the advertised Content-Length
+                    long contentLength = content.Headers.ContentLength.GetValueOrDefault();
+                    long sent = contentLength - _requestContentLengthRemaining;
+                    throw new HttpRequestException(SR.Format(SR.net_http_request_content_length_mismatch, sent, contentLength));
+                }
 
-            // Set to 0 to recognize that the whole request body has been sent and therefore there's no need to abort write side in case of a premature disposal.
-            _requestContentLengthRemaining = 0;
+                // Set to 0 to recognize that the whole request body has been sent and therefore there's no need to abort write side in case of a premature disposal.
+                _requestContentLengthRemaining = 0;
 
-            if (_sendBuffer.ActiveLength != 0)
-            {
-                // Our initial send buffer, which has our headers, is normally sent out on the first write to the Http3WriteStream.
-                // If we get here, it means the content didn't actually do any writing. Send out the headers now.
-                // Also send the FIN flag, since this is the last write. No need to call Shutdown separately.
-                await FlushSendBufferAsync(endStream: true, cancellationToken).ConfigureAwait(false);
+                if (_sendBuffer.ActiveLength != 0)
+                {
+                    // Our initial send buffer, which has our headers, is normally sent out on the first write to the Http3WriteStream.
+                    // If we get here, it means the content didn't actually do any writing. Send out the headers now.
+                    // Also send the FIN flag, since this is the last write. No need to call Shutdown separately.
+                    await FlushSendBufferAsync(endStream: true, cancellationToken).ConfigureAwait(false);
+                }
+                else
+                {
+                    _stream.CompleteWrites();
+                }
+
+                if (HttpTelemetry.Log.IsEnabled()) HttpTelemetry.Log.RequestContentStop(bytesWritten);
             }
-            else
+            finally
             {
-                _stream.CompleteWrites();
+                _requestSendCompleted = true;
+                RemoveFromConnectionIfDone();
             }
-
-            if (HttpTelemetry.Log.IsEnabled()) HttpTelemetry.Log.RequestContentStop(writeStream.BytesWritten);
         }
 
         private async ValueTask WriteRequestContentAsync(ReadOnlyMemory<byte> buffer, CancellationToken cancellationToken)
@@ -1064,6 +1083,8 @@ namespace System.Net.Http
                     if (_responseDataPayloadRemaining <= 0 && !ReadNextDataFrameAsync(response, CancellationToken.None).AsTask().GetAwaiter().GetResult())
                     {
                         // End of stream.
+                        _responseRecvCompleted = true;
+                        RemoveFromConnectionIfDone();
                         break;
                     }
 
@@ -1134,6 +1155,8 @@ namespace System.Net.Http
                     if (_responseDataPayloadRemaining <= 0 && !await ReadNextDataFrameAsync(response, cancellationToken).ConfigureAwait(false))
                     {
                         // End of stream.
+                        _responseRecvCompleted = true;
+                        RemoveFromConnectionIfDone();
                         break;
                     }
 
@@ -1192,6 +1215,10 @@ namespace System.Net.Http
         [DoesNotReturn]
         private void HandleReadResponseContentException(Exception ex, CancellationToken cancellationToken)
         {
+            // The stream is, or is going to be aborted
+            _responseRecvCompleted = true;
+            RemoveFromConnectionIfDone();
+
             switch (ex)
             {
                 case QuicException e when (e.QuicError == QuicError.StreamAborted):