Fix HttpClient.CancelAllPending/Timeout handling for GetString/ByteArrayAsync (#42346)
authorStephen Toub <stoub@microsoft.com>
Thu, 1 Oct 2020 11:45:12 +0000 (07:45 -0400)
committerGitHub <noreply@github.com>
Thu, 1 Oct 2020 11:45:12 +0000 (07:45 -0400)
* Fix HttpClient.CancelAllPending/Timeout handling for GetString/ByteArrayAsync

When GetStringAsync and GetByteArrayAsync are reading the response body, they're only paying attention to the provided CancellationToken; they're not paying attention to the HttpClient's CancelAllPending or Timeout.

* Address PR feedback

src/libraries/System.Net.Http/src/System/Net/Http/HttpClient.cs
src/libraries/System.Net.Http/src/System/Net/Http/NetEventSource.Http.cs
src/libraries/System.Net.Http/tests/FunctionalTests/HttpClientTest.cs
src/libraries/System.Net.Http/tests/FunctionalTests/TelemetryTest.cs

index bb03dee..b6fe5d7 100644 (file)
@@ -2,9 +2,9 @@
 // The .NET Foundation licenses this file to you under the MIT license.
 
 using System.Diagnostics;
+using System.Diagnostics.CodeAnalysis;
 using System.IO;
 using System.Net.Http.Headers;
-using System.Runtime.CompilerServices;
 using System.Threading;
 using System.Threading.Tasks;
 
@@ -175,78 +175,57 @@ namespace System.Net.Http
 
         private async Task<string> GetStringAsyncCore(HttpRequestMessage request, CancellationToken cancellationToken)
         {
-            bool telemetryStarted = false, responseContentTelemetryStarted = false;
-            if (HttpTelemetry.Log.IsEnabled() && request.RequestUri != null)
-            {
-                HttpTelemetry.Log.RequestStart(request);
-                telemetryStarted = true;
-            }
+            bool telemetryStarted = StartSend(request);
+            bool responseContentTelemetryStarted = false;
 
+            (CancellationTokenSource cts, bool disposeCts, long timeoutTime) = PrepareCancellationTokenSource(cancellationToken);
+            HttpResponseMessage? response = null;
             try
             {
-                // Wait for the response message.
-                using (HttpResponseMessage responseMessage = await SendAsyncCore(request, HttpCompletionOption.ResponseHeadersRead, async: true, emitTelemetryStartStop: false, cancellationToken).ConfigureAwait(false))
+                // Wait for the response message and make sure it completed successfully.
+                response = await base.SendAsync(request, cts.Token).ConfigureAwait(false);
+                ThrowForNullResponse(response);
+                response.EnsureSuccessStatusCode();
+
+                // Get the response content.
+                HttpContent c = response.Content;
+                if (HttpTelemetry.Log.IsEnabled() && telemetryStarted)
                 {
-                    // Make sure it completed successfully.
-                    responseMessage.EnsureSuccessStatusCode();
+                    HttpTelemetry.Log.ResponseContentStart();
+                    responseContentTelemetryStarted = true;
+                }
 
-                    // Get the response content.
-                    HttpContent? c = responseMessage.Content;
-                    if (c != null)
-                    {
-                        if (HttpTelemetry.Log.IsEnabled() && telemetryStarted)
-                        {
-                            HttpTelemetry.Log.ResponseContentStart();
-                            responseContentTelemetryStarted = true;
-                        }
-    #if NET46
-                        return await c.ReadAsStringAsync().ConfigureAwait(false);
-    #else
-                        HttpContentHeaders headers = c.Headers;
-
-                        // Since the underlying byte[] will never be exposed, we use an ArrayPool-backed
-                        // stream to which we copy all of the data from the response.
-                        using (Stream responseStream = c.TryReadAsStream() ?? await c.ReadAsStreamAsync(cancellationToken).ConfigureAwait(false))
-                        using (var buffer = new HttpContent.LimitArrayPoolWriteStream(_maxResponseContentBufferSize, (int)headers.ContentLength.GetValueOrDefault()))
-                        {
-                            try
-                            {
-                                await responseStream.CopyToAsync(buffer, cancellationToken).ConfigureAwait(false);
-                            }
-                            catch (Exception e) when (HttpContent.StreamCopyExceptionNeedsWrapping(e))
-                            {
-                                throw HttpContent.WrapStreamCopyException(e);
-                            }
-
-                            if (buffer.Length > 0)
-                            {
-                                // Decode and return the data from the buffer.
-                                return HttpContent.ReadBufferAsString(buffer.GetBuffer(), headers);
-                            }
-                        }
-    #endif
-                    }
+                // Since the underlying byte[] will never be exposed, we use an ArrayPool-backed
+                // stream to which we copy all of the data from the response.
+                using Stream responseStream = c.TryReadAsStream() ?? await c.ReadAsStreamAsync(cts.Token).ConfigureAwait(false);
+                using var buffer = new HttpContent.LimitArrayPoolWriteStream(_maxResponseContentBufferSize, (int)c.Headers.ContentLength.GetValueOrDefault());
+
+                try
+                {
+                    await responseStream.CopyToAsync(buffer, cts.Token).ConfigureAwait(false);
+                }
+                catch (Exception e) when (HttpContent.StreamCopyExceptionNeedsWrapping(e))
+                {
+                    throw HttpContent.WrapStreamCopyException(e);
+                }
 
-                    // No content to return.
-                    return string.Empty;
+                if (buffer.Length > 0)
+                {
+                    // Decode and return the data from the buffer.
+                    return HttpContent.ReadBufferAsString(buffer.GetBuffer(), c.Headers);
                 }
+
+                // No content to return.
+                return string.Empty;
             }
-            catch when (LogRequestFailed(telemetryStarted))
+            catch (Exception e)
             {
-                // Unreachable as LogRequestFailed will return false
+                HandleFailure(e, telemetryStarted, response, cts, cancellationToken, timeoutTime);
                 throw;
             }
             finally
             {
-                if (HttpTelemetry.Log.IsEnabled() && telemetryStarted)
-                {
-                    if (responseContentTelemetryStarted)
-                    {
-                        HttpTelemetry.Log.ResponseContentStop();
-                    }
-
-                    HttpTelemetry.Log.RequestStop();
-                }
+                FinishSend(cts, disposeCts, telemetryStarted, responseContentTelemetryStarted);
             }
         }
 
@@ -271,109 +250,61 @@ namespace System.Net.Http
 
         private async Task<byte[]> GetByteArrayAsyncCore(HttpRequestMessage request, CancellationToken cancellationToken)
         {
-            bool telemetryStarted = false, responseContentTelemetryStarted = false;
-            if (HttpTelemetry.Log.IsEnabled() && request.RequestUri != null)
-            {
-                HttpTelemetry.Log.RequestStart(request);
-                telemetryStarted = true;
-            }
+            bool telemetryStarted = StartSend(request);
+            bool responseContentTelemetryStarted = false;
 
+            (CancellationTokenSource cts, bool disposeCts, long timeoutTime) = PrepareCancellationTokenSource(cancellationToken);
+            HttpResponseMessage? response = null;
             try
             {
-                // Wait for the response message.
-                using (HttpResponseMessage responseMessage = await SendAsyncCore(request, HttpCompletionOption.ResponseHeadersRead, async: true, emitTelemetryStartStop: false, cancellationToken).ConfigureAwait(false))
-                {
-                    // Make sure it completed successfully.
-                    responseMessage.EnsureSuccessStatusCode();
+                // Wait for the response message and make sure it completed successfully.
+                response = await base.SendAsync(request, cts.Token).ConfigureAwait(false);
+                ThrowForNullResponse(response);
+                response.EnsureSuccessStatusCode();
 
-                    // Get the response content.
-                    HttpContent? c = responseMessage.Content;
-                    if (c != null)
-                    {
-                        if (HttpTelemetry.Log.IsEnabled() && telemetryStarted)
-                        {
-                            HttpTelemetry.Log.ResponseContentStart();
-                            responseContentTelemetryStarted = true;
-                        }
-    #if NET46
-                        return await c.ReadAsByteArrayAsync().ConfigureAwait(false);
-    #else
-                        HttpContentHeaders headers = c.Headers;
-                        using (Stream responseStream = c.TryReadAsStream() ?? await c.ReadAsStreamAsync(cancellationToken).ConfigureAwait(false))
-                        {
-                            long? contentLength = headers.ContentLength;
-                            Stream buffer; // declared here to share the state machine field across both if/else branches
-
-                            if (contentLength.HasValue)
-                            {
-                                // If we got a content length, then we assume that it's correct and create a MemoryStream
-                                // to which the content will be transferred.  That way, assuming we actually get the exact
-                                // amount we were expecting, we can simply return the MemoryStream's underlying buffer.
-                                buffer = new HttpContent.LimitMemoryStream(_maxResponseContentBufferSize, (int)contentLength.GetValueOrDefault());
-
-                                try
-                                {
-                                    await responseStream.CopyToAsync(buffer, cancellationToken).ConfigureAwait(false);
-                                }
-                                catch (Exception e) when (HttpContent.StreamCopyExceptionNeedsWrapping(e))
-                                {
-                                    throw HttpContent.WrapStreamCopyException(e);
-                                }
-
-                                if (buffer.Length > 0)
-                                {
-                                    return ((HttpContent.LimitMemoryStream)buffer).GetSizedBuffer();
-                                }
-                            }
-                            else
-                            {
-                                // If we didn't get a content length, then we assume we're going to have to grow
-                                // the buffer potentially several times and that it's unlikely the underlying buffer
-                                // at the end will be the exact size needed, in which case it's more beneficial to use
-                                // ArrayPool buffers and copy out to a new array at the end.
-                                buffer = new HttpContent.LimitArrayPoolWriteStream(_maxResponseContentBufferSize);
-                                try
-                                {
-                                    try
-                                    {
-                                        await responseStream.CopyToAsync(buffer, cancellationToken).ConfigureAwait(false);
-                                    }
-                                    catch (Exception e) when (HttpContent.StreamCopyExceptionNeedsWrapping(e))
-                                    {
-                                        throw HttpContent.WrapStreamCopyException(e);
-                                    }
-
-                                    if (buffer.Length > 0)
-                                    {
-                                        return ((HttpContent.LimitArrayPoolWriteStream)buffer).ToArray();
-                                    }
-                                }
-                                finally { buffer.Dispose(); }
-                            }
-                        }
-    #endif
-                    }
+                // Get the response content.
+                HttpContent c = response.Content;
+                if (HttpTelemetry.Log.IsEnabled() && telemetryStarted)
+                {
+                    HttpTelemetry.Log.ResponseContentStart();
+                    responseContentTelemetryStarted = true;
+                }
 
-                    // No content to return.
-                    return Array.Empty<byte>();
+                // If we got a content length, then we assume that it's correct and create a MemoryStream
+                // to which the content will be transferred.  That way, assuming we actually get the exact
+                // amount we were expecting, we can simply return the MemoryStream's underlying buffer.
+                // If we didn't get a content length, then we assume we're going to have to grow
+                // the buffer potentially several times and that it's unlikely the underlying buffer
+                // at the end will be the exact size needed, in which case it's more beneficial to use
+                // ArrayPool buffers and copy out to a new array at the end.
+                long? contentLength = c.Headers.ContentLength;
+                using Stream buffer = contentLength.HasValue ?
+                    new HttpContent.LimitMemoryStream(_maxResponseContentBufferSize, (int)contentLength.GetValueOrDefault()) :
+                    new HttpContent.LimitArrayPoolWriteStream(_maxResponseContentBufferSize);
+
+                using Stream responseStream = c.TryReadAsStream() ?? await c.ReadAsStreamAsync(cts.Token).ConfigureAwait(false);
+                try
+                {
+                    await responseStream.CopyToAsync(buffer, cts.Token).ConfigureAwait(false);
                 }
+                catch (Exception e) when (HttpContent.StreamCopyExceptionNeedsWrapping(e))
+                {
+                    throw HttpContent.WrapStreamCopyException(e);
+                }
+
+                return
+                    buffer.Length == 0 ? Array.Empty<byte>() :
+                    buffer is HttpContent.LimitMemoryStream lms ? lms.GetSizedBuffer() :
+                    ((HttpContent.LimitArrayPoolWriteStream)buffer).ToArray();
             }
-            catch when (LogRequestFailed(telemetryStarted))
+            catch (Exception e)
             {
-                // Unreachable as LogRequestFailed will return false
+                HandleFailure(e, telemetryStarted, response, cts, cancellationToken, timeoutTime);
                 throw;
             }
             finally
             {
-                if (HttpTelemetry.Log.IsEnabled() && telemetryStarted)
-                {
-                    if (responseContentTelemetryStarted)
-                    {
-                        HttpTelemetry.Log.ResponseContentStop();
-                    }
-
-                    HttpTelemetry.Log.RequestStop();
-                }
+                FinishSend(cts, disposeCts, telemetryStarted, responseContentTelemetryStarted);
             }
         }
 
@@ -398,46 +329,28 @@ namespace System.Net.Http
 
         private async Task<Stream> GetStreamAsyncCore(HttpRequestMessage request, CancellationToken cancellationToken)
         {
-            bool telemetryStarted = false, responseContentTelemetryStarted = false;
-            if (HttpTelemetry.Log.IsEnabled() && request.RequestUri != null)
-            {
-                HttpTelemetry.Log.RequestStart(request);
-                telemetryStarted = true;
-            }
+            bool telemetryStarted = StartSend(request);
 
+            (CancellationTokenSource cts, bool disposeCts, long timeoutTime) = PrepareCancellationTokenSource(cancellationToken);
+            HttpResponseMessage? response = null;
             try
             {
-                HttpResponseMessage response = await SendAsyncCore(request, HttpCompletionOption.ResponseHeadersRead, async: true, emitTelemetryStartStop: false, cancellationToken).ConfigureAwait(false);
+                // Wait for the response message and make sure it completed successfully.
+                response = await base.SendAsync(request, cts.Token).ConfigureAwait(false);
+                ThrowForNullResponse(response);
                 response.EnsureSuccessStatusCode();
-                HttpContent? c = response.Content;
-                if (c != null)
-                {
-                    if (HttpTelemetry.Log.IsEnabled() && telemetryStarted)
-                    {
-                        HttpTelemetry.Log.ResponseContentStart();
-                        responseContentTelemetryStarted = true;
-                    }
 
-                    return c.TryReadAsStream() ?? await c.ReadAsStreamAsync(cancellationToken).ConfigureAwait(false);
-                }
-                return Stream.Null;
+                HttpContent c = response.Content;
+                return c.TryReadAsStream() ?? await c.ReadAsStreamAsync(cancellationToken).ConfigureAwait(false);
             }
-            catch when (LogRequestFailed(telemetryStarted))
+            catch (Exception e)
             {
-                // Unreachable as LogRequestFailed will return false
+                HandleFailure(e, telemetryStarted, response, cts, cancellationToken, timeoutTime);
                 throw;
             }
             finally
             {
-                if (HttpTelemetry.Log.IsEnabled() && telemetryStarted)
-                {
-                    if (responseContentTelemetryStarted)
-                    {
-                        HttpTelemetry.Log.ResponseContentStop();
-                    }
-
-                    HttpTelemetry.Log.RequestStop();
-                }
+                FinishSend(cts, disposeCts, telemetryStarted, responseContentTelemetryStarted: false);
             }
         }
 
@@ -593,19 +506,18 @@ namespace System.Net.Http
             return Send(request, completionOption, cancellationToken: default);
         }
 
-        public override HttpResponseMessage Send(HttpRequestMessage request,
-            CancellationToken cancellationToken)
+        public override HttpResponseMessage Send(HttpRequestMessage request, CancellationToken cancellationToken)
         {
             return Send(request, defaultCompletionOption, cancellationToken);
         }
 
-        public HttpResponseMessage Send(HttpRequestMessage request, HttpCompletionOption completionOption,
-            CancellationToken cancellationToken)
+        public HttpResponseMessage Send(HttpRequestMessage request, HttpCompletionOption completionOption, CancellationToken cancellationToken)
         {
             // Called outside of async state machine to propagate certain exception even without awaiting the returned task.
             CheckRequestBeforeSend(request);
 
-            ValueTask<HttpResponseMessage> sendTask = SendAsyncCore(request, completionOption, async: false, emitTelemetryStartStop: true, cancellationToken);
+            (CancellationTokenSource cts, bool disposeCts, long timeoutTime) = PrepareCancellationTokenSource(cancellationToken);
+            ValueTask<HttpResponseMessage> sendTask = SendAsyncCore(request, completionOption, async: false, cts, disposeCts, timeoutTime, cancellationToken);
             Debug.Assert(sendTask.IsCompleted);
             return sendTask.GetAwaiter().GetResult();
         }
@@ -615,8 +527,7 @@ namespace System.Net.Http
             return SendAsync(request, defaultCompletionOption, CancellationToken.None);
         }
 
-        public override Task<HttpResponseMessage> SendAsync(HttpRequestMessage request,
-            CancellationToken cancellationToken)
+        public override Task<HttpResponseMessage> SendAsync(HttpRequestMessage request, CancellationToken cancellationToken)
         {
             return SendAsync(request, defaultCompletionOption, cancellationToken);
         }
@@ -626,13 +537,13 @@ namespace System.Net.Http
             return SendAsync(request, completionOption, CancellationToken.None);
         }
 
-        public Task<HttpResponseMessage> SendAsync(HttpRequestMessage request, HttpCompletionOption completionOption,
-            CancellationToken cancellationToken)
+        public Task<HttpResponseMessage> SendAsync(HttpRequestMessage request, HttpCompletionOption completionOption, CancellationToken cancellationToken)
         {
             // Called outside of async state machine to propagate certain exception even without awaiting the returned task.
             CheckRequestBeforeSend(request);
 
-            return SendAsyncCore(request, completionOption, async: true, emitTelemetryStartStop: true, cancellationToken).AsTask();
+            (CancellationTokenSource cts, bool disposeCts, long timeoutTime) = PrepareCancellationTokenSource(cancellationToken);
+            return SendAsyncCore(request, completionOption, async: true, cts, disposeCts, timeoutTime, cancellationToken).AsTask();
         }
 
         private void CheckRequestBeforeSend(HttpRequestMessage request)
@@ -651,25 +562,11 @@ namespace System.Net.Http
         }
 
         private async ValueTask<HttpResponseMessage> SendAsyncCore(HttpRequestMessage request, HttpCompletionOption completionOption,
-            bool async, bool emitTelemetryStartStop, CancellationToken cancellationToken)
+            bool async, CancellationTokenSource cts, bool disposeCts, long timeoutTime, CancellationToken originalCancellationToken)
         {
-            // Combines given cancellationToken with the global one and the timeout.
-            CancellationTokenSource cts = PrepareCancellationTokenSource(cancellationToken, out bool disposeCts, out long timeoutTime);
-
-            bool buffered = completionOption == HttpCompletionOption.ResponseContentRead &&
-                            !string.Equals(request.Method.Method, "HEAD", StringComparison.OrdinalIgnoreCase);
-
-            bool telemetryStarted = false, responseContentTelemetryStarted = false;
-            if (HttpTelemetry.Log.IsEnabled())
-            {
-                if (emitTelemetryStartStop && request.RequestUri != null)
-                {
-                    HttpTelemetry.Log.RequestStart(request);
-                    telemetryStarted = true;
-                }
-            }
+            bool telemetryStarted = StartSend(request);
+            bool responseContentTelemetryStarted = false;
 
-            // Initiate the send.
             HttpResponseMessage? response = null;
             try
             {
@@ -677,13 +574,11 @@ namespace System.Net.Http
                 response = async ?
                     await base.SendAsync(request, cts.Token).ConfigureAwait(false) :
                     base.Send(request, cts.Token);
-                if (response == null)
-                {
-                    throw new InvalidOperationException(SR.net_http_handler_noresponse);
-                }
+                ThrowForNullResponse(response);
 
-                // Buffer the response content if we've been asked to and we have a Content to buffer.
-                if (buffered && response.Content != null)
+                // Buffer the response content if we've been asked to.
+                if (completionOption == HttpCompletionOption.ResponseContentRead &&
+                    !string.Equals(request.Method.Method, "HEAD", StringComparison.OrdinalIgnoreCase))
                 {
                     if (HttpTelemetry.Log.IsEnabled() && telemetryStarted)
                     {
@@ -701,72 +596,80 @@ namespace System.Net.Http
                     }
                 }
 
-                if (NetEventSource.Log.IsEnabled()) NetEventSource.ClientSendCompleted(this, response, request);
                 return response;
             }
             catch (Exception e)
             {
-                LogRequestFailed(telemetryStarted);
-
-                response?.Dispose();
-
-                if (e is OperationCanceledException operationException && TimeoutFired(cancellationToken, timeoutTime))
-                {
-                    HandleSendTimeout(operationException);
-                    throw CreateTimeoutException(operationException);
-                }
-
-                HandleFinishSendAsyncError(e, cts);
+                HandleFailure(e, telemetryStarted, response, cts, originalCancellationToken, timeoutTime);
                 throw;
             }
             finally
             {
-                if (HttpTelemetry.Log.IsEnabled() && telemetryStarted)
-                {
-                    if (responseContentTelemetryStarted)
-                    {
-                        HttpTelemetry.Log.ResponseContentStop();
-                    }
-
-                    HttpTelemetry.Log.RequestStop();
-                }
-
-                HandleFinishSendCleanup(cts, disposeCts);
+                FinishSend(cts, disposeCts, telemetryStarted, responseContentTelemetryStarted);
             }
         }
 
-        private bool TimeoutFired(CancellationToken callerToken, long timeoutTime) => !callerToken.IsCancellationRequested && Environment.TickCount64 >= timeoutTime;
-
-        private TaskCanceledException CreateTimeoutException(OperationCanceledException originalException)
+        private static void ThrowForNullResponse([NotNull] HttpResponseMessage? response)
         {
-            return new TaskCanceledException(string.Format(SR.net_http_request_timedout, _timeout.TotalSeconds),
-                new TimeoutException(originalException.Message, originalException), originalException.CancellationToken);
+            if (response is null)
+            {
+                throw new InvalidOperationException(SR.net_http_handler_noresponse);
+            }
         }
 
-        private void HandleFinishSendAsyncError(Exception e, CancellationTokenSource cts)
+        private void HandleFailure(Exception e, bool telemetryStarted, HttpResponseMessage? response, CancellationTokenSource cts, CancellationToken cancellationToken, long timeoutTime)
         {
+            LogRequestFailed(telemetryStarted);
+
+            response?.Dispose();
+
+            Exception? toThrow = null;
+
+            if (e is OperationCanceledException oce && !cancellationToken.IsCancellationRequested && Environment.TickCount64 >= timeoutTime)
+            {
+                // If this exception is for cancellation, but cancellation wasn't requested and instead we find that we've passed a timeout end time,
+                // treat this instead as a timeout.
+                e = toThrow = new TaskCanceledException(string.Format(SR.net_http_request_timedout, _timeout.TotalSeconds), new TimeoutException(e.Message, e), oce.CancellationToken);
+            }
+            else if (cts.IsCancellationRequested && e is HttpRequestException) // if cancellationToken is canceled, cts will also be canceled
+            {
+                // If the cancellation token source was canceled, race conditions abound, and we consider the failure to be
+                // caused by the cancellation (e.g. WebException when reading from canceled response stream).
+                e = toThrow = new OperationCanceledException(cts.Token);
+            }
+
             if (NetEventSource.Log.IsEnabled()) NetEventSource.Error(this, e);
 
-            // If the cancellation token was canceled, we consider the exception to be caused by the
-            // cancellation (e.g. WebException when reading from canceled response stream).
-            if (cts.IsCancellationRequested && e is HttpRequestException)
+            if (toThrow != null)
             {
-                if (NetEventSource.Log.IsEnabled()) NetEventSource.Error(this, "Canceled");
-                throw new OperationCanceledException(cts.Token);
+                throw toThrow;
             }
         }
 
-        private void HandleSendTimeout(OperationCanceledException e)
+        private static bool StartSend(HttpRequestMessage request)
         {
-            if (NetEventSource.Log.IsEnabled())
+            if (HttpTelemetry.Log.IsEnabled() && request.RequestUri != null)
             {
-                NetEventSource.Error(this, e);
-                NetEventSource.Error(this, "Canceled due to timeout");
+                HttpTelemetry.Log.RequestStart(request);
+                return true;
             }
+
+            return false;
         }
 
-        private void HandleFinishSendCleanup(CancellationTokenSource cts, bool disposeCts)
+        private static void FinishSend(CancellationTokenSource cts, bool disposeCts, bool telemetryStarted, bool responseContentTelemetryStarted)
         {
+            // Log completion.
+            if (HttpTelemetry.Log.IsEnabled() && telemetryStarted)
+            {
+                if (responseContentTelemetryStarted)
+                {
+                    HttpTelemetry.Log.ResponseContentStop();
+                }
+
+                HttpTelemetry.Log.RequestStop();
+            }
+
             // Dispose of the CancellationTokenSource if it was created specially for this request
             // rather than being used across multiple requests.
             if (disposeCts)
@@ -777,11 +680,11 @@ namespace System.Net.Http
             // This method used to also dispose of the request content, e.g.:
             //     request.Content?.Dispose();
             // This has multiple problems:
-            // 1. It prevents code from reusing request content objects for subsequent requests,
-            //    as disposing of the object likely invalidates it for further use.
-            // 2. It prevents the possibility of partial or full duplex communication, even if supported
-            //    by the handler, as the request content may still be in use even if the response
-            //    (or response headers) has been received.
+            //   1. It prevents code from reusing request content objects for subsequent requests,
+            //      as disposing of the object likely invalidates it for further use.
+            //   2. It prevents the possibility of partial or full duplex communication, even if supported
+            //      by the handler, as the request content may still be in use even if the response
+            //      (or response headers) has been received.
             // By changing this to not dispose of the request content, disposal may end up being
             // left for the finalizer to handle, or the developer can explicitly dispose of the
             // content when they're done with it.  But it allows request content to be reused,
@@ -905,7 +808,7 @@ namespace System.Net.Http
             }
         }
 
-        private CancellationTokenSource PrepareCancellationTokenSource(CancellationToken cancellationToken, out bool disposeCts, out long timeoutTime)
+        private (CancellationTokenSource TokenSource, bool DisposeTokenSource, long TimeoutTime) PrepareCancellationTokenSource(CancellationToken cancellationToken)
         {
             // We need a CancellationTokenSource to use with the request.  We always have the global
             // _pendingRequestsCts to use, plus we may have a token provided by the caller, and we may
@@ -913,10 +816,9 @@ namespace System.Net.Http
             // CTS (we can't, for example, timeout the pending requests CTS, as that could cancel other
             // unrelated operations).  Otherwise, we can use the pending requests CTS directly.
             bool hasTimeout = _timeout != s_infiniteTimeout;
-            timeoutTime = long.MaxValue;
+            long timeoutTime = long.MaxValue;
             if (hasTimeout || cancellationToken.CanBeCanceled)
             {
-                disposeCts = true;
                 CancellationTokenSource cts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken, _pendingRequestsCts.Token);
                 if (hasTimeout)
                 {
@@ -924,11 +826,10 @@ namespace System.Net.Http
                     cts.CancelAfter(_timeout);
                 }
 
-                return cts;
+                return (cts, DisposeTokenSource: true, timeoutTime);
             }
 
-            disposeCts = false;
-            return _pendingRequestsCts;
+            return (_pendingRequestsCts, DisposeTokenSource: false, timeoutTime);
         }
 
         private static void CheckBaseAddress(Uri? baseAddress, string parameterName)
index 3213395..d30e6cc 100644 (file)
@@ -3,7 +3,6 @@
 
 using System.Diagnostics;
 using System.Diagnostics.Tracing;
-using System.Net.Http;
 
 namespace System.Net
 {
@@ -12,8 +11,7 @@ namespace System.Net
     {
         private const int UriBaseAddressId = NextAvailableEventId;
         private const int ContentNullId = UriBaseAddressId + 1;
-        private const int ClientSendCompletedId = ContentNullId + 1;
-        private const int HeadersInvalidValueId = ClientSendCompletedId + 1;
+        private const int HeadersInvalidValueId = ContentNullId + 1;
         private const int HandlerMessageId = HeadersInvalidValueId + 1;
         private const int AuthenticationInfoId = HandlerMessageId + 1;
         private const int AuthenticationErrorId = AuthenticationInfoId + 1;
@@ -41,17 +39,6 @@ namespace System.Net
         private void ContentNull(string objName, int objHash) =>
             WriteEvent(ContentNullId, objName, objHash);
 
-        [NonEvent]
-        public static void ClientSendCompleted(HttpClient httpClient, HttpResponseMessage response, HttpRequestMessage request)
-        {
-            Debug.Assert(Log.IsEnabled());
-            Log.ClientSendCompleted(response?.ToString(), GetHashCode(request), GetHashCode(response), GetHashCode(httpClient));
-        }
-
-        [Event(ClientSendCompletedId, Keywords = Keywords.Debug, Level = EventLevel.Verbose)]
-        private void ClientSendCompleted(string? responseString, int httpRequestMessageHash, int httpResponseMessageHash, int httpClientHash) =>
-            WriteEvent(ClientSendCompletedId, responseString, httpRequestMessageHash, httpResponseMessageHash, httpClientHash);
-
         [Event(HeadersInvalidValueId, Keywords = Keywords.Debug, Level = EventLevel.Error)]
         public void HeadersInvalidValue(string name, string rawValue) =>
             WriteEvent(HeadersInvalidValueId, name, rawValue);
index 03f5777..2ff22e9 100644 (file)
@@ -415,6 +415,93 @@ namespace System.Net.Http.Functional.Tests
                 });
         }
 
+        [OuterLoop("Incurs small timeout")]
+        [Theory]
+        [InlineData(0, 0)]
+        [InlineData(0, 1)]
+        [InlineData(0, 2)]
+        [InlineData(1, 0)]
+        [InlineData(1, 1)]
+        [InlineData(1, 2)]
+        public async Task GetAsync_ContentCanBeCanceled(int getMode, int cancelMode)
+        {
+            // cancelMode:
+            // 0: CancellationToken
+            // 1: CancelAllPending()
+            // 2: Timeout
+
+            var tcs = new TaskCompletionSource();
+            var cts = new CancellationTokenSource();
+            using HttpClient httpClient = CreateHttpClient();
+
+            // Give client time to read the headers.  There's a race condition here, but if it occurs and the client hasn't finished reading
+            // the headers by when we want it to, the test should still pass, it just won't be testing what we want it to.
+            // The same applies to the Task.Delay below.
+            httpClient.Timeout = cancelMode == 2 ?
+                TimeSpan.FromSeconds(1) :
+                Timeout.InfiniteTimeSpan;
+
+            await LoopbackServerFactory.CreateClientAndServerAsync(
+                async uri =>
+                {
+                    try
+                    {
+                        Exception e = await Assert.ThrowsAsync<TaskCanceledException>(async () =>
+                        {
+                            switch (getMode)
+                            {
+                                case 0:
+                                    await httpClient.GetStringAsync(uri, cts.Token);
+                                    break;
+
+                                case 1:
+                                    await httpClient.GetByteArrayAsync(uri, cts.Token);
+                                    break;
+                            }
+                        });
+
+                        if (cancelMode == 2)
+                        {
+                            Assert.IsType<TimeoutException>(e.InnerException);
+                        }
+                        else
+                        {
+                            Assert.IsNotType<TimeoutException>(e.InnerException);
+                        }
+                    }
+                    finally
+                    {
+                        tcs.SetResult();
+                    }
+                },
+                async server =>
+                {
+                    await server.AcceptConnectionAsync(async connection =>
+                    {
+                        await connection.ReadRequestDataAsync(readBody: false);
+                        await connection.SendResponseAsync(HttpStatusCode.OK, headers: new HttpHeaderData[] { new HttpHeaderData("Content-Length", "5") });
+                        await connection.SendResponseBodyAsync("he");
+
+                        switch (cancelMode)
+                        {
+                            case 0:
+                                await Task.Delay(100);
+                                cts.Cancel();
+                                break;
+
+                            case 1:
+                                await Task.Delay(100);
+                                httpClient.CancelPendingRequests();
+                                break;
+
+                            // case 2: timeout fires on its own
+                        }
+
+                        await tcs.Task;
+                    });
+                });
+        }
+
         [Fact]
         public async Task GetByteArrayAsync_Success()
         {
index 7822122..7446805 100644 (file)
@@ -36,12 +36,14 @@ namespace System.Net.Http.Functional.Tests
         {
             yield return new object[] { "GetAsync" };
             yield return new object[] { "SendAsync" };
+            yield return new object[] { "UnbufferedSendAsync" };
             yield return new object[] { "GetStringAsync" };
             yield return new object[] { "GetByteArrayAsync" };
             yield return new object[] { "GetStreamAsync" };
             yield return new object[] { "InvokerSendAsync" };
 
             yield return new object[] { "Send" };
+            yield return new object[] { "UnbufferedSend" };
             yield return new object[] { "InvokerSend" };
         }
 
@@ -63,6 +65,7 @@ namespace System.Net.Http.Functional.Tests
                 Version version = Version.Parse(useVersionString);
                 using var listener = new TestEventListener("System.Net.Http", EventLevel.Verbose, eventCounterInterval: 0.1d);
 
+                bool buffersResponse = false;
                 var events = new ConcurrentQueue<EventWrittenEventArgs>();
                 await listener.RunWithCallbackAsync(events.Enqueue, async () =>
                 {
@@ -81,38 +84,78 @@ namespace System.Net.Http.Functional.Tests
                             switch (testMethod)
                             {
                                 case "GetAsync":
-                                    await client.GetAsync(uri);
+                                    {
+                                        buffersResponse = true;
+                                        await client.GetAsync(uri);
+                                    }
                                     break;
 
                                 case "Send":
-                                    await Task.Run(() => client.Send(request));
+                                    {
+                                        buffersResponse = true;
+                                        await Task.Run(() => client.Send(request));
+                                    }
+                                    break;
+
+                                case "UnbufferedSend":
+                                    {
+                                        buffersResponse = false;
+                                        HttpResponseMessage response = await Task.Run(() => client.Send(request, HttpCompletionOption.ResponseHeadersRead));
+                                        response.Content.CopyTo(Stream.Null, null, default);
+                                    }
                                     break;
 
                                 case "SendAsync":
-                                    await client.SendAsync(request);
+                                    {
+                                        buffersResponse = true;
+                                        await client.SendAsync(request);
+                                    }
+                                    break;
+
+                                case "UnbufferedSendAsync":
+                                    {
+                                        buffersResponse = false;
+                                        HttpResponseMessage response = await client.SendAsync(request, HttpCompletionOption.ResponseHeadersRead);
+                                        await response.Content.CopyToAsync(Stream.Null);
+                                    }
                                     break;
 
                                 case "GetStringAsync":
-                                    await client.GetStringAsync(uri);
+                                    {
+                                        buffersResponse = true;
+                                        await client.GetStringAsync(uri);
+                                    }
                                     break;
 
                                 case "GetByteArrayAsync":
-                                    await client.GetByteArrayAsync(uri);
+                                    {
+                                        buffersResponse = true;
+                                        await client.GetByteArrayAsync(uri);
+                                    }
                                     break;
 
                                 case "GetStreamAsync":
-                                    Stream responseStream = await client.GetStreamAsync(uri);
-                                    await responseStream.CopyToAsync(Stream.Null);
+                                    {
+                                        buffersResponse = false;
+                                        Stream responseStream = await client.GetStreamAsync(uri);
+                                        await responseStream.CopyToAsync(Stream.Null);
+                                    }
                                     break;
 
                                 case "InvokerSend":
-                                    HttpResponseMessage syncResponse = await Task.Run(() => invoker.Send(request, cancellationToken: default));
-                                    await syncResponse.Content.CopyToAsync(Stream.Null);
+                                    {
+                                        buffersResponse = false;
+                                        HttpResponseMessage response = await Task.Run(() => invoker.Send(request, cancellationToken: default));
+                                        await response.Content.CopyToAsync(Stream.Null);
+                                    }
                                     break;
 
                                 case "InvokerSendAsync":
-                                    HttpResponseMessage asyncResponse = await invoker.SendAsync(request, cancellationToken: default);
-                                    await asyncResponse.Content.CopyToAsync(Stream.Null);
+                                    {
+                                        buffersResponse = false;
+                                        HttpResponseMessage response = await invoker.SendAsync(request, cancellationToken: default);
+                                        await response.Content.CopyToAsync(Stream.Null);
+                                    }
                                     break;
                             }
                         },
@@ -143,7 +186,7 @@ namespace System.Net.Http.Functional.Tests
                 ValidateRequestResponseStartStopEvents(
                     events,
                     requestContentLength: null,
-                    responseContentLength: testMethod.StartsWith("InvokerSend") ? null : ResponseContentLength,
+                    responseContentLength: buffersResponse ? ResponseContentLength : null,
                     count: 1);
 
                 VerifyEventCounters(events, requestCount: 1, shouldHaveFailures: false);
@@ -194,10 +237,18 @@ namespace System.Net.Http.Functional.Tests
                                     await Assert.ThrowsAsync<TaskCanceledException>(async () => await Task.Run(() => client.Send(request, cts.Token)));
                                     break;
 
+                                case "UnbufferedSend":
+                                    await Assert.ThrowsAsync<TaskCanceledException>(async () => await Task.Run(() => client.Send(request, HttpCompletionOption.ResponseHeadersRead, cts.Token)));
+                                    break;
+
                                 case "SendAsync":
                                     await Assert.ThrowsAsync<TaskCanceledException>(async () => await client.SendAsync(request, cts.Token));
                                     break;
 
+                                case "UnbufferedSendAsync":
+                                    await Assert.ThrowsAsync<TaskCanceledException>(async () => await client.SendAsync(request, HttpCompletionOption.ResponseHeadersRead, cts.Token));
+                                    break;
+
                                 case "GetStringAsync":
                                     await Assert.ThrowsAsync<TaskCanceledException>(async () => await client.GetStringAsync(uri, cts.Token));
                                     break;