Enable SocketsHttpHandler cancellation support (dotnet/corefx#27029)
authorStephen Toub <stoub@microsoft.com>
Tue, 13 Feb 2018 21:30:34 +0000 (16:30 -0500)
committerGitHub <noreply@github.com>
Tue, 13 Feb 2018 21:30:34 +0000 (16:30 -0500)
* Enable SocketsHttpHandler cancellation support

This change significantly improves the cancellation support in SocketsHttpHandler.  Previously we were passing the CancellationToken around to every method, eventually bottoming out in calls to the underlying Stream which then ends up passing them down to the underlying Socket.  But today Socket's support for cancellation is minimal, only doing up-front checks; if cancellation is requested during the socket operation rather than before, the request will be ignored.  Since HttpClient implements features like timeouts on top of cancellation support, it's important to do better than this.

The change implements cancellation by registering with the CancellationToken to dispose of the connection.  This will cause any reads/writes to wake up.  We then translate resulting exceptions into cancellation exceptions.  When in the main SendAsync method, we register once for the whole body of the operation until the point that we're returning the response message.  For individual operations on the response content stream, we register per operation; however, when feasible we try to avoid the registration costs by only registering if operations don't complete synchronously.  We also account for the case that on Unix, closing the connection may result in read operations waking up not with an exception but rather with EOF, which we also need to translate into cancellation when appropriate.

Along the way I cleaned up a few minor issues as well.

I also added a bunch of cancellation-related tests:
- Test cancellation occurring while sending request content
- Test cancellation occurring while receiving response headers
- Test cancellation occurring while receiving response body and using a buffered operation
- Test that all of the above are triggerable with CancellationTokenSource.Cancel, HttpClient.CancelPendingRequests, and HttpClient.Dispose
- Test cancellation occurring while receiving response body and using an unbuffered operation, either a ReadAsync or CopyToAsync on the response stream
- Test that a CancelPendingRequests doesn't affect unbuffered operations on the response stream

There are deficiencies here in the existing handlers, and tests have been selectively disabled accordingly (I also fixed a couple cases that naturally fell out of the changes I was making for SocketsHttpHandler).  SocketsHttpHandler passes now for all of them.

* Add test that Dispose doesn't cancel response stream

Commit migrated from https://github.com/dotnet/corefx/commit/53be85c2fe473fdea8c001e3d9fd81dd478b858e

27 files changed:
src/libraries/Common/src/System/Net/Http/NoWriteNoSeekStreamContent.cs
src/libraries/Common/src/System/Net/Logging/NetEventSource.Common.cs
src/libraries/System.Net.Http.WinHttpHandler/src/System/Net/Http/WinHttpResponseParser.cs
src/libraries/System.Net.Http/src/System.Net.Http.csproj
src/libraries/System.Net.Http/src/System/Net/Http/CurlHandler/CurlHandler.CurlResponseMessage.cs
src/libraries/System.Net.Http/src/System/Net/Http/HttpClient.cs
src/libraries/System.Net.Http/src/System/Net/Http/HttpContent.cs
src/libraries/System.Net.Http/src/System/Net/Http/NetEventSource.Http.cs
src/libraries/System.Net.Http/src/System/Net/Http/ReadOnlyMemoryContent.cs
src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/ChunkedEncodingReadStream.cs
src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/ChunkedEncodingWriteStream.cs
src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/ConnectionCloseReadStream.cs
src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/ContentLengthReadStream.cs
src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/ContentLengthWriteStream.cs
src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/DecompressionHandler.cs
src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/EmptyReadStream.cs
src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpConnection.cs
src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpConnectionResponseContent.cs [moved from src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpConnectionContent.cs with 70% similarity]
src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpContentDuplexStream.cs
src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpContentReadStream.cs
src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpContentStream.cs
src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpContentWriteStream.cs
src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/RawConnectionStream.cs
src/libraries/System.Net.Http/tests/FunctionalTests/CancellationTest.cs [deleted file]
src/libraries/System.Net.Http/tests/FunctionalTests/HttpClientHandlerTest.Cancellation.cs [new file with mode: 0644]
src/libraries/System.Net.Http/tests/FunctionalTests/SocketsHttpHandlerTest.cs
src/libraries/System.Net.Http/tests/FunctionalTests/System.Net.Http.Functional.Tests.csproj

index 2d55520..202782c 100644 (file)
@@ -14,10 +14,9 @@ namespace System.Net.Http
     internal sealed class NoWriteNoSeekStreamContent : HttpContent
     {
         private readonly Stream _content;
-        private readonly CancellationToken _cancellationToken;
         private bool _contentConsumed;
 
-        internal NoWriteNoSeekStreamContent(Stream content, CancellationToken cancellationToken)
+        internal NoWriteNoSeekStreamContent(Stream content)
         {
             Debug.Assert(content != null);
             Debug.Assert(content.CanRead);
@@ -25,10 +24,16 @@ namespace System.Net.Http
             Debug.Assert(!content.CanSeek);
 
             _content = content;
-            _cancellationToken = cancellationToken;
         }
 
-        protected override Task SerializeToStreamAsync(Stream stream, TransportContext context)
+        protected override Task SerializeToStreamAsync(Stream stream, TransportContext context) =>
+            SerializeToStreamAsync(stream, context, CancellationToken.None);
+
+        internal
+#if HTTP_DLL
+            override
+#endif
+            Task SerializeToStreamAsync(Stream stream, TransportContext context, CancellationToken cancellationToken)
         {
             Debug.Assert(stream != null);
 
@@ -39,7 +44,7 @@ namespace System.Net.Http
             _contentConsumed = true;
 
             const int BufferSize = 8192;
-            Task copyTask = _content.CopyToAsync(stream, BufferSize, _cancellationToken);
+            Task copyTask = _content.CopyToAsync(stream, BufferSize, cancellationToken);
             if (copyTask.IsCompleted)
             {
                 try { _content.Dispose(); } catch { } // same as StreamToStreamCopy behavior
@@ -75,6 +80,10 @@ namespace System.Net.Http
             base.Dispose(disposing);
         }
 
-        protected override Task<Stream> CreateContentReadStreamAsync() => Task.FromResult<Stream>(_content);
+        protected override Task<Stream> CreateContentReadStreamAsync() => Task.FromResult(_content);
+
+#if HTTP_DLL
+        internal override Stream TryCreateContentReadStream() => _content;
+#endif
     }
 }
index 685a8fe..0948ec7 100644 (file)
@@ -395,7 +395,9 @@ namespace System.Net
             Debug.Assert(IsEnabled || arg == null, $"Should not be formatting FormattableString \"{arg}\" if tracing isn't enabled");
         }
 
-        public static new bool IsEnabled => Log.IsEnabled();
+        public static new bool IsEnabled =>
+            Log.IsEnabled();
+            //true; // uncomment for debugging only
 
         [NonEvent]
         public static string IdOf(object value) => value != null ? value.GetType().Name + "#" + GetHashCode(value) : NullInstance;
index 16af2ad..bc6ff3f 100644 (file)
@@ -92,7 +92,7 @@ namespace System.Net.Http
                     }
                 }
 
-                response.Content = new NoWriteNoSeekStreamContent(decompressedStream, state.CancellationToken);
+                response.Content = new NoWriteNoSeekStreamContent(decompressedStream);
                 response.RequestMessage = request;
 
                 // Parse raw response headers and place them into response message.
index 2104094..ef88d16 100644 (file)
@@ -1,4 +1,4 @@
-<?xml version="1.0" encoding="utf-8"?>
+<?xml version="1.0" encoding="utf-8"?>
 <Project ToolsVersion="14.0" DefaultTargets="Build" xmlns="http://schemas.microsoft.com/developer/msbuild/2003">
   <Import Project="$([MSBuild]::GetDirectoryNameOfFileAbove($(MSBuildThisFileDirectory), dir.props))\dir.props" />
   <PropertyGroup>
     <Compile Include="System\Net\Http\SocketsHttpHandler\DecompressionHandler.cs" />
     <Compile Include="System\Net\Http\SocketsHttpHandler\EmptyReadStream.cs" />
     <Compile Include="System\Net\Http\SocketsHttpHandler\HttpConnection.cs" />
-    <Compile Include="System\Net\Http\SocketsHttpHandler\HttpConnectionContent.cs" />
     <Compile Include="System\Net\Http\SocketsHttpHandler\HttpConnectionHandler.cs" />
     <Compile Include="System\Net\Http\SocketsHttpHandler\HttpConnectionKey.cs" />
     <Compile Include="System\Net\Http\SocketsHttpHandler\HttpConnectionPool.cs" />
     <Compile Include="System\Net\Http\SocketsHttpHandler\HttpConnectionPools.cs" />
+    <Compile Include="System\Net\Http\SocketsHttpHandler\HttpConnectionResponseContent.cs" />
     <Compile Include="System\Net\Http\SocketsHttpHandler\HttpConnectionSettings.cs" />
     <Compile Include="System\Net\Http\SocketsHttpHandler\HttpContentDuplexStream.cs" />
     <Compile Include="System\Net\Http\SocketsHttpHandler\HttpContentReadStream.cs" />
     <Reference Include="System.Security.Cryptography.Primitives" />
   </ItemGroup>
   <Import Project="$([MSBuild]::GetDirectoryNameOfFileAbove($(MSBuildThisFileDirectory), dir.targets))\dir.targets" />
-</Project>
\ No newline at end of file
+</Project>
index 1fafd96..55d2228 100644 (file)
@@ -22,7 +22,7 @@ namespace System.Net.Http
                 Debug.Assert(easy != null, "Expected non-null EasyRequest");
                 RequestMessage = easy._requestMessage;
                 ResponseStream = new CurlResponseStream(easy);
-                Content = new NoWriteNoSeekStreamContent(ResponseStream, CancellationToken.None);
+                Content = new NoWriteNoSeekStreamContent(ResponseStream);
 
                 // On Windows, we pass the equivalent of the easy._cancellationToken
                 // in to StreamContent's ctor.  This in turn passes that token through
index 1bcac19..3c8eca3 100644 (file)
@@ -475,7 +475,7 @@ namespace System.Net.Http
                 // Buffer the response content if we've been asked to and we have a Content to buffer.
                 if (response.Content != null)
                 {
-                    await response.Content.LoadIntoBufferAsync(_maxResponseContentBufferSize).ConfigureAwait(false);
+                    await response.Content.LoadIntoBufferAsync(_maxResponseContentBufferSize, cts.Token).ConfigureAwait(false);
                 }
 
                 if (NetEventSource.IsEnabled) NetEventSource.ClientSendCompleted(this, response, request);
index 191c5f4..f1c15e3 100644 (file)
@@ -299,7 +299,18 @@ namespace System.Net.Http
 
         protected abstract Task SerializeToStreamAsync(Stream stream, TransportContext context);
 
-        public Task CopyToAsync(Stream stream, TransportContext context)
+        // TODO #9071: Expose this publicly.  Until it's public, only sealed or internal types should override it, and then change
+        // their SerializeToStreamAsync implementation to delegate to this one.  They need to be sealed as otherwise an external
+        // type could derive from it and override SerializeToStreamAsync(stream, context) further, at which point when
+        // HttpClient calls SerializeToStreamAsync(stream, context, cancellationToken), their custom override will be skipped.
+        internal virtual Task SerializeToStreamAsync(Stream stream, TransportContext context, CancellationToken cancellationToken) =>
+            SerializeToStreamAsync(stream, context);
+
+        public Task CopyToAsync(Stream stream, TransportContext context) =>
+            CopyToAsync(stream, context, CancellationToken.None);
+
+        // TODO #9071: Expose this publicly.
+        internal Task CopyToAsync(Stream stream, TransportContext context, CancellationToken cancellationToken)
         {
             CheckDisposed();
             if (stream == null)
@@ -313,11 +324,11 @@ namespace System.Net.Http
                 ArraySegment<byte> buffer;
                 if (TryGetBuffer(out buffer))
                 {
-                    task = stream.WriteAsync(buffer.Array, buffer.Offset, buffer.Count);
+                    task = stream.WriteAsync(buffer.Array, buffer.Offset, buffer.Count, cancellationToken);
                 }
                 else
                 {
-                    task = SerializeToStreamAsync(stream, context);
+                    task = SerializeToStreamAsync(stream, context, cancellationToken);
                     CheckTaskNotNull(task);
                 }
 
@@ -354,7 +365,10 @@ namespace System.Net.Http
         // No "CancellationToken" parameter needed since canceling the CTS will close the connection, resulting
         // in an exception being thrown while we're buffering.
         // If buffering is used without a connection, it is supposed to be fast, thus no cancellation required.
-        public Task LoadIntoBufferAsync(long maxBufferSize)
+        public Task LoadIntoBufferAsync(long maxBufferSize) =>
+            LoadIntoBufferAsync(maxBufferSize, CancellationToken.None);
+
+        internal Task LoadIntoBufferAsync(long maxBufferSize, CancellationToken cancellationToken)
         {
             CheckDisposed();
             if (maxBufferSize > HttpContent.MaxBufferSize)
@@ -382,7 +396,7 @@ namespace System.Net.Http
 
             try
             {
-                Task task = SerializeToStreamAsync(tempBuffer, null);
+                Task task = SerializeToStreamAsync(tempBuffer, null, cancellationToken);
                 CheckTaskNotNull(task);
                 return LoadIntoBufferAsyncCore(task, tempBuffer);
             }
index 1924dc6..8735970 100644 (file)
@@ -63,6 +63,7 @@ namespace System.Net
         [Event(HandlerMessageId, Keywords = Keywords.Debug, Level = EventLevel.Verbose)]
         public void HandlerMessage(int handlerId, int workerId, int requestId, string memberName, string message) =>
             WriteEvent(HandlerMessageId, handlerId, workerId, requestId, memberName, message);
+            //Console.WriteLine($"{handlerId}/{workerId}/{requestId}: ({memberName}): {message}");  // uncomment for debugging only
 
         [NonEvent]
         private unsafe void WriteEvent(int eventId, int arg1, int arg2, int arg3, string arg4, string arg5)
index 0a9588d..11b9877 100644 (file)
@@ -4,6 +4,7 @@
 
 using System.IO;
 using System.Runtime.InteropServices;
+using System.Threading;
 using System.Threading.Tasks;
 
 namespace System.Net.Http
@@ -26,6 +27,9 @@ namespace System.Net.Http
         protected override Task SerializeToStreamAsync(Stream stream, TransportContext context) =>
             stream.WriteAsync(_content);
 
+        internal override Task SerializeToStreamAsync(Stream stream, TransportContext context, CancellationToken cancellationToken) =>
+            stream.WriteAsync(_content, cancellationToken);
+
         protected internal override bool TryComputeLength(out long length)
         {
             length = _content.Length;
index ab8865e..32bd40b 100644 (file)
@@ -30,13 +30,13 @@ namespace System.Net.Http
             {
             }
 
-            private async Task<bool> TryGetNextChunkAsync(CancellationToken cancellationToken)
+            private async Task<bool> TryGetNextChunkAsync()
             {
                 Debug.Assert(_chunkBytesRemaining == 0);
 
                 // Read the start of the chunk line.
                 _connection._allowedReadLineBytes = MaxChunkBytesAllowed;
-                ArraySegment<byte> line = await _connection.ReadNextLineAsync(cancellationToken).ConfigureAwait(false);
+                ArraySegment<byte> line = await _connection.ReadNextLineAsync().ConfigureAwait(false);
 
                 // Parse the hex value.
                 if (!Utf8Parser.TryParse(line.AsReadOnlySpan(), out ulong chunkSize, out int bytesConsumed, 'X'))
@@ -73,7 +73,7 @@ namespace System.Net.Http
                 while (true)
                 {
                     _connection._allowedReadLineBytes = MaxTrailingHeaderLength;
-                    if (LineIsEmpty(await _connection.ReadNextLineAsync(cancellationToken).ConfigureAwait(false)))
+                    if (LineIsEmpty(await _connection.ReadNextLineAsync().ConfigureAwait(false)))
                     {
                         break;
                     }
@@ -84,59 +84,77 @@ namespace System.Net.Http
                 return false;
             }
 
-            private async Task ConsumeChunkBytesAsync(ulong bytesConsumed, CancellationToken cancellationToken)
+            private Task ConsumeChunkBytesAsync(ulong bytesConsumed)
             {
                 Debug.Assert(bytesConsumed <= _chunkBytesRemaining);
                 _chunkBytesRemaining -= bytesConsumed;
-                if (_chunkBytesRemaining == 0)
+                return _chunkBytesRemaining != 0 ?
+                    Task.CompletedTask :
+                    ReadNextLineAndThrowIfNotEmptyAsync();
+            }
+
+            private async Task ReadNextLineAndThrowIfNotEmptyAsync()
+            {
+                _connection._allowedReadLineBytes = 2; // \r\n
+                if (!LineIsEmpty(await _connection.ReadNextLineAsync().ConfigureAwait(false)))
                 {
-                    _connection._allowedReadLineBytes = 2; // \r\n
-                    if (!LineIsEmpty(await _connection.ReadNextLineAsync(cancellationToken).ConfigureAwait(false)))
-                    {
-                        ThrowInvalidHttpResponse();
-                    }
+                    ThrowInvalidHttpResponse();
                 }
             }
 
             public override Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
             {
                 ValidateBufferArgs(buffer, offset, count);
-                return ReadAsync(new Memory<byte>(buffer, offset, count)).AsTask();
+                return ReadAsync(new Memory<byte>(buffer, offset, count), cancellationToken).AsTask();
             }
 
-            public override async ValueTask<int> ReadAsync(Memory<byte> destination, CancellationToken cancellationToken = default)
+            public override async ValueTask<int> ReadAsync(Memory<byte> destination, CancellationToken cancellationToken)
             {
+                cancellationToken.ThrowIfCancellationRequested();
+
                 if (_connection == null || destination.Length == 0)
                 {
                     // Response body fully consumed or the caller didn't ask for any data
                     return 0;
                 }
 
-                if (_chunkBytesRemaining == 0)
+                CancellationTokenRegistration ctr = _connection.RegisterCancellation(cancellationToken);
+                try
                 {
-                    if (!await TryGetNextChunkAsync(cancellationToken).ConfigureAwait(false))
+                    if (_chunkBytesRemaining == 0)
                     {
-                        // End of response body
-                        return 0;
+                        if (!await TryGetNextChunkAsync().ConfigureAwait(false))
+                        {
+                            // End of response body
+                            return 0;
+                        }
                     }
-                }
 
-                if (_chunkBytesRemaining < (ulong)destination.Length)
-                {
-                    destination = destination.Slice(0, (int)_chunkBytesRemaining);
-                }
+                    if (_chunkBytesRemaining < (ulong)destination.Length)
+                    {
+                        destination = destination.Slice(0, (int)_chunkBytesRemaining);
+                    }
 
-                int bytesRead = await _connection.ReadAsync(destination, cancellationToken).ConfigureAwait(false);
+                    int bytesRead = await _connection.ReadAsync(destination).ConfigureAwait(false);
 
-                if (bytesRead <= 0)
-                {
-                    // Unexpected end of response stream
-                    throw new IOException(SR.net_http_invalid_response);
-                }
+                    if (bytesRead <= 0)
+                    {
+                        // Unexpected end of response stream
+                        throw new IOException(SR.net_http_invalid_response);
+                    }
 
-                await ConsumeChunkBytesAsync((ulong)bytesRead, cancellationToken).ConfigureAwait(false);
+                    await ConsumeChunkBytesAsync((ulong)bytesRead).ConfigureAwait(false);
 
-                return bytesRead;
+                    return bytesRead;
+                }
+                catch (Exception exc) when (ShouldWrapInOperationCanceledException(exc, cancellationToken))
+                {
+                    throw new OperationCanceledException(s_cancellationMessage, exc, cancellationToken);
+                }
+                finally
+                {
+                    ctr.Dispose();
+                }
             }
 
             public override async Task CopyToAsync(Stream destination, int bufferSize, CancellationToken cancellationToken)
@@ -145,6 +163,12 @@ namespace System.Net.Http
                 {
                     throw new ArgumentNullException(nameof(destination));
                 }
+                if (bufferSize <= 0)
+                {
+                    throw new ArgumentOutOfRangeException(nameof(bufferSize));
+                }
+
+                cancellationToken.ThrowIfCancellationRequested();
 
                 if (_connection == null)
                 {
@@ -152,16 +176,28 @@ namespace System.Net.Http
                     return;
                 }
 
-                if (_chunkBytesRemaining > 0)
+                CancellationTokenRegistration ctr = _connection.RegisterCancellation(cancellationToken);
+                try
                 {
-                    await _connection.CopyToAsync(destination, _chunkBytesRemaining, cancellationToken).ConfigureAwait(false);
-                    await ConsumeChunkBytesAsync(_chunkBytesRemaining, cancellationToken).ConfigureAwait(false);
-                }
+                    if (_chunkBytesRemaining > 0)
+                    {
+                        await _connection.CopyToAsync(destination, _chunkBytesRemaining).ConfigureAwait(false);
+                        await ConsumeChunkBytesAsync(_chunkBytesRemaining).ConfigureAwait(false);
+                    }
 
-                while (await TryGetNextChunkAsync(cancellationToken).ConfigureAwait(false))
+                    while (await TryGetNextChunkAsync().ConfigureAwait(false))
+                    {
+                        await _connection.CopyToAsync(destination, _chunkBytesRemaining).ConfigureAwait(false);
+                        await ConsumeChunkBytesAsync(_chunkBytesRemaining).ConfigureAwait(false);
+                    }
+                }
+                catch (Exception exc) when (ShouldWrapInOperationCanceledException(exc, cancellationToken))
+                {
+                    throw CreateOperationCanceledException(exc, cancellationToken);
+                }
+                finally
                 {
-                    await _connection.CopyToAsync(destination, _chunkBytesRemaining, cancellationToken).ConfigureAwait(false);
-                    await ConsumeChunkBytesAsync(_chunkBytesRemaining, cancellationToken).ConfigureAwait(false);
+                    ctr.Dispose();
                 }
             }
         }
index 04bcdc1..82682ac 100644 (file)
@@ -2,6 +2,7 @@
 // The .NET Foundation licenses this file to you under the MIT license.
 // See the LICENSE file in the project root for more information.
 
+using System.Diagnostics;
 using System.Threading;
 using System.Threading.Tasks;
 
@@ -13,8 +14,7 @@ namespace System.Net.Http
         {
             private static readonly byte[] s_finalChunkBytes = { (byte)'0', (byte)'\r', (byte)'\n', (byte)'\r', (byte)'\n' };
 
-            public ChunkedEncodingWriteStream(HttpConnection connection, CancellationToken cancellationToken) :
-                base(connection, cancellationToken)
+            public ChunkedEncodingWriteStream(HttpConnection connection) : base(connection)
             {
             }
 
@@ -24,13 +24,17 @@ namespace System.Net.Http
                 return WriteAsync(new Memory<byte>(buffer, offset, count), ignored);
             }
 
-            public override Task WriteAsync(ReadOnlyMemory<byte> source, CancellationToken cancellationToken = default)
+            public override Task WriteAsync(ReadOnlyMemory<byte> source, CancellationToken ignored)
             {
+                // The token is ignored because it's coming from SendAsync and the only operations
+                // here are those that are already covered by the token having been registered with
+                // to close the connection.
+
                 if (source.Length == 0)
                 {
                     // Don't write if nothing was given, especially since we don't want to accidentally send a 0 chunk,
                     // which would indicate end of body.  Instead, just ensure no content is stuck in the buffer.
-                    return _connection.FlushAsync(RequestCancellationToken);
+                    return _connection.FlushAsync();
                 }
 
                 if (_connection._currentRequest == null)
@@ -54,17 +58,17 @@ namespace System.Net.Http
                     int digit = (source.Length & mask) >> shift;
                     if (digitWritten || digit != 0)
                     {
-                        await _connection.WriteByteAsync((byte)(digit < 10 ? '0' + digit : 'A' + digit - 10), RequestCancellationToken).ConfigureAwait(false);
+                        await _connection.WriteByteAsync((byte)(digit < 10 ? '0' + digit : 'A' + digit - 10)).ConfigureAwait(false);
                         digitWritten = true;
                     }
                 }
 
                 // End chunk length
-                await _connection.WriteTwoBytesAsync((byte)'\r', (byte)'\n', RequestCancellationToken).ConfigureAwait(false);
+                await _connection.WriteTwoBytesAsync((byte)'\r', (byte)'\n').ConfigureAwait(false);
 
                 // Write chunk contents
-                await _connection.WriteAsync(source, RequestCancellationToken).ConfigureAwait(false);
-                await _connection.WriteTwoBytesAsync((byte)'\r', (byte)'\n', RequestCancellationToken).ConfigureAwait(false);
+                await _connection.WriteAsync(source).ConfigureAwait(false);
+                await _connection.WriteTwoBytesAsync((byte)'\r', (byte)'\n').ConfigureAwait(false);
 
                 // Flush the chunk.  This is reasonable from the standpoint of having just written a standalone piece
                 // of data, but is also necessary to support duplex communication, where a CopyToAsync is taking the
@@ -72,18 +76,16 @@ namespace System.Net.Http
                 // source was empty, and it might be kept open to enable subsequent communication.  And it's necessary
                 // in general for at least the first write, as we need to ensure if it's the entirety of the content
                 // and if all of the headers and content fit in the write buffer that we've actually sent the request.
-                await _connection.FlushAsync(RequestCancellationToken).ConfigureAwait(false);
+                await _connection.FlushAsync().ConfigureAwait(false);
             }
 
-            public override Task FlushAsync(CancellationToken ignored)
-            {
-                return _connection.FlushAsync(RequestCancellationToken);
-            }
-            
+            public override Task FlushAsync(CancellationToken ignored) => // see comment on WriteAsync about "ignored"
+                _connection.FlushAsync();
+
             public override async Task FinishAsync()
             {
                 // Send 0 byte chunk to indicate end, then final CrLf
-                await _connection.WriteBytesAsync(s_finalChunkBytes, RequestCancellationToken).ConfigureAwait(false);
+                await _connection.WriteBytesAsync(s_finalChunkBytes).ConfigureAwait(false);
                 _connection = null;
             }
         }
index ab0a28b..2c000f6 100644 (file)
@@ -22,17 +22,49 @@ namespace System.Net.Http
                 return ReadAsync(new Memory<byte>(buffer, offset, count), cancellationToken).AsTask();
             }
 
-            public override async ValueTask<int> ReadAsync(Memory<byte> destination, CancellationToken cancellationToken = default)
+            public override async ValueTask<int> ReadAsync(Memory<byte> destination, CancellationToken cancellationToken)
             {
+                cancellationToken.ThrowIfCancellationRequested();
+
                 if (_connection == null || destination.Length == 0)
                 {
                     // Response body fully consumed or the caller didn't ask for any data
                     return 0;
                 }
 
-                int bytesRead = await _connection.ReadAsync(destination, cancellationToken).ConfigureAwait(false);
+                ValueTask<int> readTask = _connection.ReadAsync(destination);
+                int bytesRead;
+                if (readTask.IsCompletedSuccessfully)
+                {
+                    bytesRead = readTask.Result;
+                }
+                else
+                {
+                    CancellationTokenRegistration ctr = _connection.RegisterCancellation(cancellationToken);
+                    try
+                    {
+                        bytesRead = await readTask.ConfigureAwait(false);
+                    }
+                    catch (Exception exc) when (ShouldWrapInOperationCanceledException(exc, cancellationToken))
+                    {
+                        throw CreateOperationCanceledException(exc, cancellationToken);
+                    }
+                    finally
+                    {
+                        ctr.Dispose();
+                    }
+                }
+
                 if (bytesRead == 0)
                 {
+                    // If cancellation is requested and tears down the connection, it could cause the read
+                    // to return 0, which would otherwise signal the end of the data, but that would lead
+                    // the caller to think that it actually received all of the data, rather than it ending
+                    // early due to cancellation.  So we prioritize cancellation in this race condition, and
+                    // if we read 0 bytes and then find that cancellation has requested, we assume cancellation
+                    // was the cause and throw.
+                    cancellationToken.ThrowIfCancellationRequested();
+
                     // We cannot reuse this connection, so close it.
                     _connection.Dispose();
                     _connection = null;
@@ -48,15 +80,46 @@ namespace System.Net.Http
                 {
                     throw new ArgumentNullException(nameof(destination));
                 }
+                if (bufferSize <= 0)
+                {
+                    throw new ArgumentOutOfRangeException(nameof(bufferSize));
+                }
 
-                if (_connection != null) // null if response body fully consumed
+                cancellationToken.ThrowIfCancellationRequested();
+
+                if (_connection == null)
                 {
-                    await _connection.CopyToAsync(destination, cancellationToken).ConfigureAwait(false);
+                    // Response body fully consumed
+                    return;
+                }
 
-                    // We cannot reuse this connection, so close it.
-                    _connection.Dispose();
-                    _connection = null;
+                Task copyTask = _connection.CopyToAsync(destination);
+                if (!copyTask.IsCompletedSuccessfully)
+                {
+                    CancellationTokenRegistration ctr = _connection.RegisterCancellation(cancellationToken);
+                    try
+                    {
+                        await copyTask.ConfigureAwait(false);
+                    }
+                    catch (Exception exc) when (ShouldWrapInOperationCanceledException(exc, cancellationToken))
+                    {
+                        throw CreateOperationCanceledException(exc, cancellationToken);
+                    }
+                    finally
+                    {
+                        ctr.Dispose();
+                    }
                 }
+
+                // If cancellation is requested and tears down the connection, it could cause the copy
+                // to end early but think it ended successfully. So we prioritize cancellation in this
+                // race condition, and if we find after the copy has completed that cancellation has
+                // been requested, we assume the copy completed due to cancellation and throw.
+                cancellationToken.ThrowIfCancellationRequested();
+
+                // We cannot reuse this connection, so close it.
+                _connection.Dispose();
+                _connection = null;
             }
         }
     }
index 65fd114..e6daf28 100644 (file)
@@ -15,8 +15,7 @@ namespace System.Net.Http
         {
             private ulong _contentBytesRemaining;
 
-            public ContentLengthReadStream(HttpConnection connection, ulong contentLength)
-                : base(connection)
+            public ContentLengthReadStream(HttpConnection connection, ulong contentLength) : base(connection)
             {
                 Debug.Assert(contentLength > 0, "Caller should have checked for 0.");
                 _contentBytesRemaining = contentLength;
@@ -28,8 +27,10 @@ namespace System.Net.Http
                 return ReadAsync(new Memory<byte>(buffer, offset, count), cancellationToken).AsTask();
             }
 
-            public override async ValueTask<int> ReadAsync(Memory<byte> destination, CancellationToken cancellationToken = default)
+            public override async ValueTask<int> ReadAsync(Memory<byte> destination, CancellationToken cancellationToken)
             {
+                cancellationToken.ThrowIfCancellationRequested();
+
                 if (_connection == null || destination.Length == 0)
                 {
                     // Response body fully consumed or the caller didn't ask for any data
@@ -43,11 +44,35 @@ namespace System.Net.Http
                     destination = destination.Slice(0, (int)_contentBytesRemaining);
                 }
 
-                int bytesRead = await _connection.ReadAsync(destination, cancellationToken).ConfigureAwait(false);
+                ValueTask<int> readTask = _connection.ReadAsync(destination);
+                int bytesRead;
+                if (readTask.IsCompletedSuccessfully)
+                {
+                    bytesRead = readTask.Result;
+                }
+                else
+                {
+                    CancellationTokenRegistration ctr = _connection.RegisterCancellation(cancellationToken);
+                    try
+                    {
+                        bytesRead = await readTask.ConfigureAwait(false);
+                    }
+                    catch (Exception exc) when (ShouldWrapInOperationCanceledException(exc, cancellationToken))
+                    {
+                        throw CreateOperationCanceledException(exc, cancellationToken);
+                    }
+                    finally
+                    {
+                        ctr.Dispose();
+                    }
+                }
 
                 if (bytesRead <= 0)
                 {
-                    // Unexpected end of response stream
+                    // A cancellation request may have caused the EOF.
+                    cancellationToken.ThrowIfCancellationRequested();
+
+                    // Unexpected end of response stream.
                     throw new IOException(SR.net_http_invalid_response);
                 }
 
@@ -70,6 +95,12 @@ namespace System.Net.Http
                 {
                     throw new ArgumentNullException(nameof(destination));
                 }
+                if (bufferSize <= 0)
+                {
+                    throw new ArgumentOutOfRangeException(nameof(bufferSize));
+                }
+
+                cancellationToken.ThrowIfCancellationRequested();
 
                 if (_connection == null)
                 {
@@ -77,7 +108,23 @@ namespace System.Net.Http
                     return;
                 }
 
-                await _connection.CopyToAsync(destination, _contentBytesRemaining, cancellationToken).ConfigureAwait(false);
+                Task copyTask = _connection.CopyToAsync(destination, _contentBytesRemaining);
+                if (!copyTask.IsCompletedSuccessfully)
+                {
+                    CancellationTokenRegistration ctr = _connection.RegisterCancellation(cancellationToken);
+                    try
+                    {
+                        await copyTask.ConfigureAwait(false);
+                    }
+                    catch (Exception exc) when (ShouldWrapInOperationCanceledException(exc, cancellationToken))
+                    {
+                        throw CreateOperationCanceledException(exc, cancellationToken);
+                    }
+                    finally
+                    {
+                        ctr.Dispose();
+                    }
+                }
 
                 _contentBytesRemaining = 0;
                 _connection.ReturnConnectionToPool();
index 0329f43..507f883 100644 (file)
@@ -11,18 +11,17 @@ namespace System.Net.Http
     {
         private sealed class ContentLengthWriteStream : HttpContentWriteStream
         {
-            public ContentLengthWriteStream(HttpConnection connection, CancellationToken cancellationToken) :
-                base(connection, cancellationToken)
+            public ContentLengthWriteStream(HttpConnection connection) : base(connection)
             {
             }
 
-            public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken ignored)
+            public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken ignored) // token ignored as it comes from SendAsync
             {
                 ValidateBufferArgs(buffer, offset, count);
                 return WriteAsync(new ReadOnlyMemory<byte>(buffer, offset, count), ignored);
             }
 
-            public override Task WriteAsync(ReadOnlyMemory<byte> source, CancellationToken cancellationToken = default)
+            public override Task WriteAsync(ReadOnlyMemory<byte> source, CancellationToken ignored) // token ignored as it comes from SendAsync
             {
                 if (_connection._currentRequest == null)
                 {
@@ -34,13 +33,11 @@ namespace System.Net.Http
                 // Have the connection write the data, skipping the buffer. Importantly, this will
                 // force a flush of anything already in the buffer, i.e. any remaining request headers
                 // that are still buffered.
-                return _connection.WriteWithoutBufferingAsync(source, RequestCancellationToken);
+                return _connection.WriteWithoutBufferingAsync(source);
             }
 
-            public override Task FlushAsync(CancellationToken ignored)
-            {
-                return _connection.FlushAsync(RequestCancellationToken);
-            }
+            public override Task FlushAsync(CancellationToken ignored) => // token ignored as it comes from SendAsync
+                _connection.FlushAsync();
 
             public override Task FinishAsync()
             {
index 261ae26..41843de 100644 (file)
@@ -108,11 +108,14 @@ namespace System.Net.Http
 
             protected abstract Stream GetDecompressedStream(Stream originalStream);
 
-            protected override async Task SerializeToStreamAsync(Stream stream, TransportContext context)
+            protected override Task SerializeToStreamAsync(Stream stream, TransportContext context) =>
+                SerializeToStreamAsync(stream, context, CancellationToken.None);
+
+            internal override async Task SerializeToStreamAsync(Stream stream, TransportContext context, CancellationToken cancellationToken)
             {
                 using (Stream decompressedStream = await CreateContentReadStreamAsync().ConfigureAwait(false))
                 {
-                    await decompressedStream.CopyToAsync(stream).ConfigureAwait(false);
+                    await decompressedStream.CopyToAsync(stream, cancellationToken).ConfigureAwait(false);
                 }
             }
 
index bcc4e60..6320af3 100644 (file)
@@ -23,10 +23,13 @@ namespace System.Net.Http
             public override Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
             {
                 ValidateBufferArgs(buffer, offset, count);
-                return s_zeroTask;
+                return cancellationToken.IsCancellationRequested ?
+                    Task.FromCanceled<int>(cancellationToken) :
+                    s_zeroTask;
             }
 
-            public override ValueTask<int> ReadAsync(Memory<byte> destination, CancellationToken cancellationToken = default) =>
+            public override ValueTask<int> ReadAsync(Memory<byte> destination, CancellationToken cancellationToken) =>
+                cancellationToken.IsCancellationRequested ? new ValueTask<int>(Task.FromCanceled<int>(cancellationToken)) :
                 new ValueTask<int>(0);
         }
     }
index dfbedcb..35dd90b 100644 (file)
@@ -43,12 +43,14 @@ namespace System.Net.Http
         private static readonly byte[] s_spaceHttp11NewlineAsciiBytes = Encoding.ASCII.GetBytes(" HTTP/1.1\r\n");
         private static readonly byte[] s_hostKeyAndSeparator = Encoding.ASCII.GetBytes(HttpKnownHeaderNames.Host + ": ");
         private static readonly byte[] s_httpSchemeAndDelimiter = Encoding.ASCII.GetBytes(Uri.UriSchemeHttp + Uri.SchemeDelimiter);
+        private static readonly string s_cancellationMessage = new OperationCanceledException().Message; // use same message as the default ctor
 
         private readonly HttpConnectionPool _pool;
         private readonly Stream _stream;
         private readonly TransportContext _transportContext;
         private readonly bool _usingProxy;
         private readonly byte[] _idnHostAsciiBytes;
+        private readonly WeakReference<HttpConnection> _weakThisRef;
 
         private HttpRequestMessage _currentRequest;
         private Task _sendRequestContentTask;
@@ -83,6 +85,8 @@ namespace System.Net.Http
             _writeBuffer = new byte[InitialWriteBufferSize];
             _readBuffer = new byte[InitialReadBufferSize];
 
+            _weakThisRef = new WeakReference<HttpConnection>(this);
+
             if (NetEventSource.IsEnabled)
             {
                 if (_stream is SslStream sslStream)
@@ -152,64 +156,64 @@ namespace System.Net.Http
 
         public DateTimeOffset CreationTime { get; } = DateTimeOffset.UtcNow;
 
-        private async Task WriteHeadersAsync(HttpHeaders headers, string cookiesFromContainer, CancellationToken cancellationToken)
+        private async Task WriteHeadersAsync(HttpHeaders headers, string cookiesFromContainer)
         {
             foreach (KeyValuePair<string, IEnumerable<string>> header in headers)
             {
-                await WriteAsciiStringAsync(header.Key, cancellationToken).ConfigureAwait(false);
-                await WriteTwoBytesAsync((byte)':', (byte)' ', cancellationToken).ConfigureAwait(false);
+                await WriteAsciiStringAsync(header.Key).ConfigureAwait(false);
+                await WriteTwoBytesAsync((byte)':', (byte)' ').ConfigureAwait(false);
 
                 var values = (string[])header.Value; // typed as IEnumerable<string>, but always a string[]
                 Debug.Assert(values.Length > 0, "No values for header??");
                 if (values.Length > 0)
                 {
-                    await WriteStringAsync(values[0], cancellationToken).ConfigureAwait(false);
+                    await WriteStringAsync(values[0]).ConfigureAwait(false);
 
                     if (cookiesFromContainer != null && header.Key == HttpKnownHeaderNames.Cookie)
                     {
-                        await WriteTwoBytesAsync((byte)';', (byte)' ', cancellationToken).ConfigureAwait(false);
-                        await WriteStringAsync(cookiesFromContainer, cancellationToken).ConfigureAwait(false);
+                        await WriteTwoBytesAsync((byte)';', (byte)' ').ConfigureAwait(false);
+                        await WriteStringAsync(cookiesFromContainer).ConfigureAwait(false);
 
                         cookiesFromContainer = null;
                     }
 
                     for (int i = 1; i < values.Length; i++)
                     {
-                        await WriteTwoBytesAsync((byte)',', (byte)' ', cancellationToken).ConfigureAwait(false);
-                        await WriteStringAsync(values[i], cancellationToken).ConfigureAwait(false);
+                        await WriteTwoBytesAsync((byte)',', (byte)' ').ConfigureAwait(false);
+                        await WriteStringAsync(values[i]).ConfigureAwait(false);
                     }
                 }
 
-                await WriteTwoBytesAsync((byte)'\r', (byte)'\n', cancellationToken).ConfigureAwait(false);
+                await WriteTwoBytesAsync((byte)'\r', (byte)'\n').ConfigureAwait(false);
             }
 
             if (cookiesFromContainer != null)
             {
-                await WriteAsciiStringAsync(HttpKnownHeaderNames.Cookie, cancellationToken).ConfigureAwait(false);
-                await WriteTwoBytesAsync((byte)':', (byte)' ', cancellationToken).ConfigureAwait(false);
-                await WriteAsciiStringAsync(cookiesFromContainer, cancellationToken).ConfigureAwait(false);
-                await WriteTwoBytesAsync((byte)'\r', (byte)'\n', cancellationToken).ConfigureAwait(false);
+                await WriteAsciiStringAsync(HttpKnownHeaderNames.Cookie).ConfigureAwait(false);
+                await WriteTwoBytesAsync((byte)':', (byte)' ').ConfigureAwait(false);
+                await WriteAsciiStringAsync(cookiesFromContainer).ConfigureAwait(false);
+                await WriteTwoBytesAsync((byte)'\r', (byte)'\n').ConfigureAwait(false);
             }
         }
 
-        private async Task WriteHostHeaderAsync(Uri uri, CancellationToken cancellationToken)
+        private async Task WriteHostHeaderAsync(Uri uri)
         {
-            await WriteBytesAsync(s_hostKeyAndSeparator, cancellationToken).ConfigureAwait(false);
+            await WriteBytesAsync(s_hostKeyAndSeparator).ConfigureAwait(false);
 
             await (_idnHostAsciiBytes != null ?
-                WriteBytesAsync(_idnHostAsciiBytes, cancellationToken) :
-                WriteAsciiStringAsync(uri.IdnHost, cancellationToken)).ConfigureAwait(false);
+                WriteBytesAsync(_idnHostAsciiBytes) :
+                WriteAsciiStringAsync(uri.IdnHost)).ConfigureAwait(false);
 
             if (!uri.IsDefaultPort)
             {
-                await WriteByteAsync((byte)':', cancellationToken).ConfigureAwait(false);
-                await WriteFormattedInt32Async(uri.Port, cancellationToken).ConfigureAwait(false);
+                await WriteByteAsync((byte)':').ConfigureAwait(false);
+                await WriteFormattedInt32Async(uri.Port).ConfigureAwait(false);
             }
 
-            await WriteTwoBytesAsync((byte)'\r', (byte)'\n', cancellationToken).ConfigureAwait(false);
+            await WriteTwoBytesAsync((byte)'\r', (byte)'\n').ConfigureAwait(false);
         }
 
-        private Task WriteFormattedInt32Async(int value, CancellationToken cancellationToken)
+        private Task WriteFormattedInt32Async(int value)
         {
             // Try to format into our output buffer directly.
             if (Utf8Formatter.TryFormat(value, new Span<byte>(_writeBuffer, _writeOffset, _writeBuffer.Length - _writeOffset), out int bytesWritten))
@@ -219,7 +223,7 @@ namespace System.Net.Http
             }
 
             // If we don't have enough room, do it the slow way.
-            return WriteAsciiStringAsync(value.ToString(CultureInfo.InvariantCulture), cancellationToken);
+            return WriteAsciiStringAsync(value.ToString(CultureInfo.InvariantCulture));
         }
 
         public async Task<HttpResponseMessage> SendAsync(HttpRequestMessage request, CancellationToken cancellationToken)
@@ -233,43 +237,43 @@ namespace System.Net.Http
 
             // Send the request.
             if (NetEventSource.IsEnabled) Trace($"Sending request: {request}");
+            CancellationTokenRegistration cancellationRegistration = RegisterCancellation(cancellationToken);
             try
             {
                 // Write request line
-                await WriteStringAsync(request.Method.Method, cancellationToken).ConfigureAwait(false);
-                await WriteByteAsync((byte)' ', cancellationToken).ConfigureAwait(false);
+                await WriteStringAsync(request.Method.Method).ConfigureAwait(false);
+                await WriteByteAsync((byte)' ').ConfigureAwait(false);
 
                 if (_usingProxy)
                 {
                     // Proxied requests contain full URL
                     Debug.Assert(request.RequestUri.Scheme == Uri.UriSchemeHttp);
-                    await WriteBytesAsync(s_httpSchemeAndDelimiter, cancellationToken).ConfigureAwait(false);
-                    await WriteAsciiStringAsync(request.RequestUri.IdnHost, cancellationToken).ConfigureAwait(false);
+                    await WriteBytesAsync(s_httpSchemeAndDelimiter).ConfigureAwait(false);
+                    await WriteAsciiStringAsync(request.RequestUri.IdnHost).ConfigureAwait(false);
                 }
 
-                await WriteStringAsync(request.RequestUri.PathAndQuery, cancellationToken).ConfigureAwait(false);
+                await WriteStringAsync(request.RequestUri.PathAndQuery).ConfigureAwait(false);
 
                 // Fall back to 1.1 for all versions other than 1.0
                 Debug.Assert(request.Version.Major >= 0 && request.Version.Minor >= 0); // guaranteed by Version class
                 bool isHttp10 = request.Version.Minor == 0 && request.Version.Major == 1;
-                await WriteBytesAsync(isHttp10 ? s_spaceHttp10NewlineAsciiBytes : s_spaceHttp11NewlineAsciiBytes,
-                                      cancellationToken).ConfigureAwait(false);
+                await WriteBytesAsync(isHttp10 ? s_spaceHttp10NewlineAsciiBytes : s_spaceHttp11NewlineAsciiBytes).ConfigureAwait(false);
 
                 // Determine cookies to send
-                string cookies = null;
+                string cookiesFromContainer = null;
                 if (_pool.Pools.Settings._useCookies)
                 {
-                    cookies = _pool.Pools.Settings._cookieContainer.GetCookieHeader(request.RequestUri);
-                    if (cookies == "")
+                    cookiesFromContainer = _pool.Pools.Settings._cookieContainer.GetCookieHeader(request.RequestUri);
+                    if (cookiesFromContainer == "")
                     {
-                        cookies = null;
+                        cookiesFromContainer = null;
                     }
                 }
 
                 // Write request headers
-                if (request.HasHeaders || cookies != null)
+                if (request.HasHeaders || cookiesFromContainer != null)
                 {
-                    await WriteHeadersAsync(request.Headers, cookies, cancellationToken).ConfigureAwait(false);
+                    await WriteHeadersAsync(request.Headers, cookiesFromContainer).ConfigureAwait(false);
                 }
 
                 if (request.Content == null)
@@ -278,30 +282,30 @@ namespace System.Net.Http
                     // unless this is a method that never has a body.
                     if (request.Method != HttpMethod.Get && request.Method != HttpMethod.Head)
                     {
-                        await WriteBytesAsync(s_contentLength0NewlineAsciiBytes, cancellationToken).ConfigureAwait(false);
+                        await WriteBytesAsync(s_contentLength0NewlineAsciiBytes).ConfigureAwait(false);
                     }
                 }
                 else
                 {
                     // Write content headers
-                    await WriteHeadersAsync(request.Content.Headers, null, cancellationToken).ConfigureAwait(false);
+                    await WriteHeadersAsync(request.Content.Headers, cookiesFromContainer: null).ConfigureAwait(false);
                 }
 
                 // Write special additional headers.  If a host isn't in the headers list, then a Host header
                 // wasn't sent, so as it's required by HTTP 1.1 spec, send one based on the Request Uri.
                 if (!request.HasHeaders || request.Headers.Host == null)
                 {
-                    await WriteHostHeaderAsync(request.RequestUri, cancellationToken).ConfigureAwait(false);
+                    await WriteHostHeaderAsync(request.RequestUri).ConfigureAwait(false);
                 }
 
                 // CRLF for end of headers.
-                await WriteTwoBytesAsync((byte)'\r', (byte)'\n', cancellationToken).ConfigureAwait(false);
+                await WriteTwoBytesAsync((byte)'\r', (byte)'\n').ConfigureAwait(false);
 
                 Debug.Assert(_sendRequestContentTask == null);
                 if (request.Content == null)
                 {
                     // We have nothing more to send, so flush out any headers we haven't yet sent.
-                    await FlushAsync(cancellationToken).ConfigureAwait(false);
+                    await FlushAsync().ConfigureAwait(false);
                 }
                 else
                 {
@@ -310,13 +314,18 @@ namespace System.Net.Http
                     // to ensure the headers and content are sent.
                     bool transferEncodingChunked = request.HasHeaders && request.Headers.TransferEncodingChunked == true;
                     HttpContentWriteStream stream = transferEncodingChunked ? (HttpContentWriteStream)
-                        new ChunkedEncodingWriteStream(this, cancellationToken) :
-                        new ContentLengthWriteStream(this, cancellationToken);
+                        new ChunkedEncodingWriteStream(this) :
+                        new ContentLengthWriteStream(this);
 
                     if (!request.HasHeaders || request.Headers.ExpectContinue != true)
                     {
-                        // Send the request content asynchronously.
-                        Task sendTask = _sendRequestContentTask = SendRequestContentAsync(request, stream);
+                        // Send the request content asynchronously.  Note that elsewhere in SendAsync we don't pass
+                        // the cancellation token around, as we simply register with it for the duration of the
+                        // method in order to dispose of this connection and wake up any operations.  But SendRequestContentAsync
+                        // is special in that it ends up dealing with an external entity, the request HttpContent provided
+                        // by the caller to this handler, and we could end up blocking as part of getting that content,
+                        // which won't be affected by disposing this connection. Thus, we do pass the token in here.
+                        Task sendTask = _sendRequestContentTask = SendRequestContentAsync(request, stream, cancellationToken);
                         if (sendTask.IsFaulted)
                         {
                             // Technically this isn't necessary: if the task failed, it will have stored the exception
@@ -333,7 +342,7 @@ namespace System.Net.Http
                         // We're sending an Expect: 100-continue header. We need to flush headers so that the server receives
                         // all of them, and we need to do so before initiating the send, as once we do that, it effectively
                         // owns the right to write, and we don't want to concurrently be accessing the write buffer.
-                        await FlushAsync(cancellationToken).ConfigureAwait(false);
+                        await FlushAsync().ConfigureAwait(false);
 
                         // Create a TCS we'll use to block the request content from being sent, and create a timer that's used
                         // as a fail-safe to unblock the request content if we don't hear back from the server in a timely manner.
@@ -342,7 +351,8 @@ namespace System.Net.Http
                         var expect100Timer = new Timer(
                             s => ((TaskCompletionSource<bool>)s).TrySetResult(true),
                             allowExpect100ToContinue, TimeSpan.FromMilliseconds(Expect100TimeoutMilliseconds), Timeout.InfiniteTimeSpan);
-                        _sendRequestContentTask = SendRequestContentWithExpect100ContinueAsync(request, allowExpect100ToContinue.Task, stream, expect100Timer);
+                        _sendRequestContentTask = SendRequestContentWithExpect100ContinueAsync(
+                            request, allowExpect100ToContinue.Task, stream, expect100Timer, cancellationToken);
                     }
                 }
 
@@ -380,9 +390,9 @@ namespace System.Net.Http
                 _canRetry = false;
 
                 // Parse the response status line.
-                var response = new HttpResponseMessage() { RequestMessage = request, Content = new HttpConnectionContent(CancellationToken.None) };
-                ParseStatusLine(await ReadNextLineAsync(cancellationToken).ConfigureAwait(false), response);
-                
+                var response = new HttpResponseMessage() { RequestMessage = request, Content = new HttpConnectionResponseContent() };
+                ParseStatusLine(await ReadNextLineAsync().ConfigureAwait(false), response);
+
                 // If we sent an Expect: 100-continue header, handle the response accordingly.
                 if (allowExpect100ToContinue != null)
                 {
@@ -409,12 +419,12 @@ namespace System.Net.Http
                         if (response.StatusCode == HttpStatusCode.Continue)
                         {
                             // We got our continue header.  Read the subsequent empty line and parse the additional status line.
-                            if (!LineIsEmpty(await ReadNextLineAsync(cancellationToken).ConfigureAwait(false)))
+                            if (!LineIsEmpty(await ReadNextLineAsync().ConfigureAwait(false)))
                             {
                                 ThrowInvalidHttpResponse();
                             }
 
-                            ParseStatusLine(await ReadNextLineAsync(cancellationToken).ConfigureAwait(false), response);
+                            ParseStatusLine(await ReadNextLineAsync().ConfigureAwait(false), response);
                         }
                     }
                 }
@@ -422,7 +432,7 @@ namespace System.Net.Http
                 // Parse the response headers.
                 while (true)
                 {
-                    ArraySegment<byte> line = await ReadNextLineAsync(cancellationToken).ConfigureAwait(false);
+                    ArraySegment<byte> line = await ReadNextLineAsync().ConfigureAwait(false);
                     if (LineIsEmpty(line))
                     {
                         break;
@@ -447,6 +457,13 @@ namespace System.Net.Http
                     _sendRequestContentTask = null;
                 }
 
+                // We're about to create the response stream, at which point responsibility for canceling
+                // the remainder of the response lies with the stream.  Thus we dispose of our registration
+                // here (if an exception has occurred or does occur while creating/returning the stream,
+                // we'll still dispose of it in the catch below as part of Dispose'ing the connection).
+                cancellationRegistration.Dispose();
+                cancellationToken.ThrowIfCancellationRequested(); // in case cancellation may have disposed of the stream
+
                 // Create the response stream.
                 HttpContentStream responseStream;
                 if (request.Method == HttpMethod.Head || (int)response.StatusCode == 204 || (int)response.StatusCode == 304)
@@ -479,7 +496,7 @@ namespace System.Net.Http
                 {
                     responseStream = new ConnectionCloseReadStream(this);
                 }
-                ((HttpConnectionContent)response.Content).SetStream(responseStream);
+                ((HttpConnectionResponseContent)response.Content).SetStream(responseStream);
 
                 if (NetEventSource.IsEnabled) Trace($"Received response: {response}");
 
@@ -493,33 +510,82 @@ namespace System.Net.Http
             }
             catch (Exception error)
             {
+                // Clean up the cancellation registration in case we're still registered.
+                cancellationRegistration.Dispose();
+
                 // Make sure to complete the allowExpect100ToContinue task if it exists.
                 allowExpect100ToContinue?.TrySetResult(false);
 
                 if (NetEventSource.IsEnabled) Trace($"Error sending request: {error}");
                 Dispose();
 
-                if (_pendingException != null)
+                // At this point, we're going to throw an exception; we just need to
+                // determine which exception to throw.
+
+                if (ShouldWrapInOperationCanceledException(error, cancellationToken))
+                {
+                    // Cancellation was requested, so assume that the failure is due to
+                    // the cancellation request. This is a bit unorthodox, as usually we'd
+                    // prioritize a non-OperationCanceledException over a cancellation
+                    // request to avoid losing potentially pertinent information.  But given
+                    // the cancellation design where we tear down the underlying connection upon
+                    // a cancellation request, which can then result in a myriad of different
+                    // exceptions (argument exceptions, object disposed exceptions, socket exceptions,
+                    // etc.), as a middle ground we treat it as cancellation, but still propagate the
+                    // original information as the inner exception, for diagnostic purposes.
+                    throw CreateOperationCanceledException(_pendingException ?? error, cancellationToken);
+                }
+                else if (_pendingException != null)
                 {
                     // If we incurred an exception in non-linear control flow such that
                     // the exception didn't bubble up here (e.g. concurrent sending of
                     // the request content), use that error instead.
                     throw new HttpRequestException(SR.net_http_client_execution_error, _pendingException);
                 }
-
-                // Otherwise, propagate this exception, wrapping it if necessary to
-                // match exception type expectations.
-                if (error is InvalidOperationException || error is IOException)
+                else if (error is InvalidOperationException || error is IOException)
                 {
+                    // If it's an InvalidOperationException or an IOException, for consistency
+                    // with other handlers we wrap the exception in an HttpRequestException.
                     throw new HttpRequestException(SR.net_http_client_execution_error, error);
                 }
-                throw;
+                else
+                {
+                    // Otherwise, just allow the original exception to propagate.
+                    throw;
+                }
             }
         }
 
+        private CancellationTokenRegistration RegisterCancellation(CancellationToken cancellationToken)
+        {
+            // Cancellation design:
+            // - We register with the SendAsync CancellationToken for the duration of the SendAsync operation.
+            // - We register with the Read/Write/CopyToAsync methods on the response stream for each such individual operation.
+            // - The registration disposes of the connection, tearing it down and causing any pending operations to wake up.
+            // - Because such a tear down can result in a variety of different exception types, we check for a cancellation
+            //   request and prioritize that over other exceptions, wrapping the actual exception as an inner of an OCE.
+            // - A weak reference to this HttpConnection is stored in the cancellation token, to prevent the token from
+            //   artificially keeping this connection alive.
+            return cancellationToken.Register(s =>
+            {
+                var weakThisRef = (WeakReference<HttpConnection>)s;
+                if (weakThisRef.TryGetTarget(out HttpConnection strongThisRef))
+                {
+                    if (NetEventSource.IsEnabled) strongThisRef.Trace("Cancellation requested. Disposing of the connection.");
+                    strongThisRef.Dispose();
+                }
+            }, _weakThisRef);
+        }
+
+        private static bool ShouldWrapInOperationCanceledException(Exception error, CancellationToken cancellationToken) =>
+            !(error is OperationCanceledException) && cancellationToken.IsCancellationRequested;
+
+        private static Exception CreateOperationCanceledException(Exception error, CancellationToken cancellationToken) =>
+            new OperationCanceledException(s_cancellationMessage, error, cancellationToken);
+
         private static bool LineIsEmpty(ArraySegment<byte> line) => line.Count == 0;
 
-        private async Task SendRequestContentAsync(HttpRequestMessage request, HttpContentWriteStream stream)
+        private async Task SendRequestContentAsync(HttpRequestMessage request, HttpContentWriteStream stream, CancellationToken cancellationToken)
         {
             // Now that we're sending content, prohibit retries on this connection.
             _canRetry = false;
@@ -527,13 +593,13 @@ namespace System.Net.Http
             try
             {
                 // Copy all of the data to the server.
-                await request.Content.CopyToAsync(stream, _transportContext).ConfigureAwait(false);
+                await request.Content.CopyToAsync(stream, _transportContext, cancellationToken).ConfigureAwait(false);
 
                 // Finish the content; with a chunked upload, this includes writing the terminating chunk.
                 await stream.FinishAsync().ConfigureAwait(false);
 
                 // Flush any content that might still be buffered.
-                await FlushAsync(stream.RequestCancellationToken).ConfigureAwait(false);
+                await FlushAsync().ConfigureAwait(false);
             }
             catch (Exception e)
             {
@@ -545,7 +611,7 @@ namespace System.Net.Http
         }
 
         private async Task SendRequestContentWithExpect100ContinueAsync(
-            HttpRequestMessage request, Task<bool> allowExpect100ToContinueTask, HttpContentWriteStream stream, Timer expect100Timer)
+            HttpRequestMessage request, Task<bool> allowExpect100ToContinueTask, HttpContentWriteStream stream, Timer expect100Timer, CancellationToken cancellationToken)
         {
             // Wait until we receive a trigger notification that it's ok to continue sending content.
             // This will come either when the timer fires or when we receive a response status line from the server.
@@ -558,7 +624,7 @@ namespace System.Net.Http
             if (sendRequestContent)
             {
                 if (NetEventSource.IsEnabled) Trace($"Sending request content for Expect: 100-continue.");
-                await SendRequestContentAsync(request, stream).ConfigureAwait(false);
+                await SendRequestContentAsync(request, stream, cancellationToken).ConfigureAwait(false);
             }
             else
             {
@@ -708,7 +774,7 @@ namespace System.Net.Http
             _writeOffset += source.Length;
         }
 
-        private async Task WriteAsync(ReadOnlyMemory<byte> source, CancellationToken cancellationToken)
+        private async Task WriteAsync(ReadOnlyMemory<byte> source)
         {
             int remaining = _writeBuffer.Length - _writeOffset;
 
@@ -724,14 +790,14 @@ namespace System.Net.Http
                 // Fit what we can in the current write buffer and flush it.
                 WriteToBuffer(source.Slice(0, remaining));
                 source = source.Slice(remaining);
-                await FlushAsync(cancellationToken).ConfigureAwait(false);
+                await FlushAsync().ConfigureAwait(false);
             }
 
             if (source.Length >= _writeBuffer.Length)
             {
                 // Large write.  No sense buffering this.  Write directly to stream.
                 // CONSIDER: May want to be a bit smarter here?  Think about how large writes should work...
-                await WriteToStreamAsync(source, cancellationToken).ConfigureAwait(false);
+                await WriteToStreamAsync(source).ConfigureAwait(false);
             }
             else
             {
@@ -740,13 +806,13 @@ namespace System.Net.Http
             }
         }
 
-        private Task WriteWithoutBufferingAsync(ReadOnlyMemory<byte> source, CancellationToken cancellationToken)
+        private Task WriteWithoutBufferingAsync(ReadOnlyMemory<byte> source)
         {
             if (_writeOffset == 0)
             {
                 // There's nothing in the write buffer we need to flush.
                 // Just write the supplied data out to the stream.
-                return WriteToStreamAsync(source, cancellationToken);
+                return WriteToStreamAsync(source);
             }
 
             int remaining = _writeBuffer.Length - _writeOffset;
@@ -757,40 +823,40 @@ namespace System.Net.Http
                 // the content to the write buffer and then flush it, so that we
                 // can do a single send rather than two.
                 WriteToBuffer(source);
-                return FlushAsync(cancellationToken);
+                return FlushAsync();
             }
 
             // There's data in the write buffer and the data we're writing doesn't fit after it.
             // Do two writes, one to flush the buffer and then another to write the supplied content.
-            return FlushThenWriteWithoutBufferingAsync(source, cancellationToken);
+            return FlushThenWriteWithoutBufferingAsync(source);
         }
 
-        private async Task FlushThenWriteWithoutBufferingAsync(ReadOnlyMemory<byte> source, CancellationToken cancellationToken)
+        private async Task FlushThenWriteWithoutBufferingAsync(ReadOnlyMemory<byte> source)
         {
-            await FlushAsync(cancellationToken).ConfigureAwait(false);
-            await WriteToStreamAsync(source, cancellationToken).ConfigureAwait(false);
+            await FlushAsync().ConfigureAwait(false);
+            await WriteToStreamAsync(source).ConfigureAwait(false);
         }
 
-        private Task WriteByteAsync(byte b, CancellationToken cancellationToken)
+        private Task WriteByteAsync(byte b)
         {
             if (_writeOffset < _writeBuffer.Length)
             {
                 _writeBuffer[_writeOffset++] = b;
                 return Task.CompletedTask;
             }
-            return WriteByteSlowAsync(b, cancellationToken);
+            return WriteByteSlowAsync(b);
         }
 
-        private async Task WriteByteSlowAsync(byte b, CancellationToken cancellationToken)
+        private async Task WriteByteSlowAsync(byte b)
         {
             Debug.Assert(_writeOffset == _writeBuffer.Length);
-            await WriteToStreamAsync(_writeBuffer, cancellationToken).ConfigureAwait(false);
+            await WriteToStreamAsync(_writeBuffer).ConfigureAwait(false);
 
             _writeBuffer[0] = b;
             _writeOffset = 1;
         }
 
-        private Task WriteTwoBytesAsync(byte b1, byte b2, CancellationToken cancellationToken)
+        private Task WriteTwoBytesAsync(byte b1, byte b2)
         {
             if (_writeOffset <= _writeBuffer.Length - 2)
             {
@@ -799,16 +865,16 @@ namespace System.Net.Http
                 buffer[_writeOffset++] = b2;
                 return Task.CompletedTask;
             }
-            return WriteTwoBytesSlowAsync(b1, b2, cancellationToken);
+            return WriteTwoBytesSlowAsync(b1, b2);
         }
 
-        private async Task WriteTwoBytesSlowAsync(byte b1, byte b2, CancellationToken cancellationToken)
+        private async Task WriteTwoBytesSlowAsync(byte b1, byte b2)
         {
-            await WriteByteAsync(b1, cancellationToken).ConfigureAwait(false);
-            await WriteByteAsync(b2, cancellationToken).ConfigureAwait(false);
+            await WriteByteAsync(b1).ConfigureAwait(false);
+            await WriteByteAsync(b2).ConfigureAwait(false);
         }
 
-        private Task WriteBytesAsync(byte[] bytes, CancellationToken cancellationToken)
+        private Task WriteBytesAsync(byte[] bytes)
         {
             if (_writeOffset <= _writeBuffer.Length - bytes.Length)
             {
@@ -816,10 +882,10 @@ namespace System.Net.Http
                 _writeOffset += bytes.Length;
                 return Task.CompletedTask;
             }
-            return WriteBytesSlowAsync(bytes, cancellationToken);
+            return WriteBytesSlowAsync(bytes);
         }
 
-        private async Task WriteBytesSlowAsync(byte[] bytes, CancellationToken cancellationToken)
+        private async Task WriteBytesSlowAsync(byte[] bytes)
         {
             int offset = 0;
             while (true)
@@ -838,13 +904,13 @@ namespace System.Net.Http
                 }
                 else if (_writeOffset == _writeBuffer.Length)
                 {
-                    await WriteToStreamAsync(_writeBuffer, cancellationToken).ConfigureAwait(false);
+                    await WriteToStreamAsync(_writeBuffer).ConfigureAwait(false);
                     _writeOffset = 0;
                 }
             }
         }
 
-        private Task WriteStringAsync(string s, CancellationToken cancellationToken)
+        private Task WriteStringAsync(string s)
         {
             // If there's enough space in the buffer to just copy all of the string's bytes, do so.
             // Unlike WriteAsciiStringAsync, validate each char along the way.
@@ -866,10 +932,10 @@ namespace System.Net.Http
 
             // Otherwise, fall back to doing a normal slow string write; we could optimize away
             // the extra checks later, but the case where we cross a buffer boundary should be rare.
-            return WriteStringAsyncSlow(s, cancellationToken);
+            return WriteStringAsyncSlow(s);
         }
 
-        private Task WriteAsciiStringAsync(string s, CancellationToken cancellationToken)
+        private Task WriteAsciiStringAsync(string s)
         {
             // If there's enough space in the buffer to just copy all of the string's bytes, do so.
             int offset = _writeOffset;
@@ -886,10 +952,10 @@ namespace System.Net.Http
 
             // Otherwise, fall back to doing a normal slow string write; we could optimize away
             // the extra checks later, but the case where we cross a buffer boundary should be rare.
-            return WriteStringAsyncSlow(s, cancellationToken);
+            return WriteStringAsyncSlow(s);
         }
 
-        private async Task WriteStringAsyncSlow(string s, CancellationToken cancellationToken)
+        private async Task WriteStringAsyncSlow(string s)
         {
             for (int i = 0; i < s.Length; i++)
             {
@@ -898,28 +964,28 @@ namespace System.Net.Http
                 {
                     throw new HttpRequestException(SR.net_http_request_invalid_char_encoding);
                 }
-                await WriteByteAsync((byte)c, cancellationToken).ConfigureAwait(false);
+                await WriteByteAsync((byte)c).ConfigureAwait(false);
             }
         }
 
-        private Task FlushAsync(CancellationToken cancellationToken)
+        private Task FlushAsync()
         {
             if (_writeOffset > 0)
             {
-                Task t = WriteToStreamAsync(new ReadOnlyMemory<byte>(_writeBuffer, 0, _writeOffset), cancellationToken);
+                Task t = WriteToStreamAsync(new ReadOnlyMemory<byte>(_writeBuffer, 0, _writeOffset));
                 _writeOffset = 0;
                 return t;
             }
             return Task.CompletedTask;
         }
 
-        private Task WriteToStreamAsync(ReadOnlyMemory<byte> source, CancellationToken cancellationToken)
+        private Task WriteToStreamAsync(ReadOnlyMemory<byte> source)
         {
             if (NetEventSource.IsEnabled) Trace($"Writing {source.Length} bytes.");
-            return _stream.WriteAsync(source, cancellationToken);
+            return _stream.WriteAsync(source);
         }
 
-        private async ValueTask<ArraySegment<byte>> ReadNextLineAsync(CancellationToken cancellationToken)
+        private async ValueTask<ArraySegment<byte>> ReadNextLineAsync()
         {
             int previouslyScannedBytes = 0;
             while (true)
@@ -954,12 +1020,12 @@ namespace System.Net.Http
                 {
                     ThrowInvalidHttpResponse();
                 }
-                await FillAsync(cancellationToken).ConfigureAwait(false);
+                await FillAsync().ConfigureAwait(false);
             }
         }
 
         // Throws IOException on EOF.  This is only called when we expect more data.
-        private async Task FillAsync(CancellationToken cancellationToken)
+        private async Task FillAsync()
         {
             Debug.Assert(_readAheadTask == null);
 
@@ -994,7 +1060,7 @@ namespace System.Net.Http
                 _readLength = remaining;
             }
 
-            int bytesRead = await _stream.ReadAsync(new Memory<byte>(_readBuffer, _readLength, _readBuffer.Length - _readLength), cancellationToken).ConfigureAwait(false);
+            int bytesRead = await _stream.ReadAsync(new Memory<byte>(_readBuffer, _readLength, _readBuffer.Length - _readLength)).ConfigureAwait(false);
 
             if (NetEventSource.IsEnabled) Trace($"Received {bytesRead} bytes.");
             if (bytesRead == 0)
@@ -1013,7 +1079,7 @@ namespace System.Net.Http
             _readOffset += buffer.Length;
         }
 
-        private async ValueTask<int> ReadAsync(Memory<byte> destination, CancellationToken cancellationToken)
+        private async ValueTask<int> ReadAsync(Memory<byte> destination)
         {
             // This is called when reading the response body
 
@@ -1036,28 +1102,28 @@ namespace System.Net.Http
             // No data in read buffer. 
             // Do an unbuffered read directly against the underlying stream.
             Debug.Assert(_readAheadTask == null, "Read ahead task should have been consumed as part of the headers.");
-            int count = await _stream.ReadAsync(destination, cancellationToken).ConfigureAwait(false);
+            int count = await _stream.ReadAsync(destination).ConfigureAwait(false);
             if (NetEventSource.IsEnabled) Trace($"Received {count} bytes.");
             return count;
         }
 
-        private async Task CopyFromBufferAsync(Stream destination, int count, CancellationToken cancellationToken)
+        private async Task CopyFromBufferAsync(Stream destination, int count)
         {
             Debug.Assert(count <= _readLength - _readOffset);
 
             if (NetEventSource.IsEnabled) Trace($"Copying {count} bytes to stream.");
-            await destination.WriteAsync(_readBuffer, _readOffset, count, cancellationToken).ConfigureAwait(false);
+            await destination.WriteAsync(_readBuffer, _readOffset, count).ConfigureAwait(false);
             _readOffset += count;
         }
 
-        private async Task CopyToAsync(Stream destination, CancellationToken cancellationToken)
+        private async Task CopyToAsync(Stream destination)
         {
             Debug.Assert(destination != null);
 
             int remaining = _readLength - _readOffset;
             if (remaining > 0)
             {
-                await CopyFromBufferAsync(destination, remaining, cancellationToken).ConfigureAwait(false);
+                await CopyFromBufferAsync(destination, remaining).ConfigureAwait(false);
             }
 
             while (true)
@@ -1066,19 +1132,19 @@ namespace System.Net.Http
 
                 // Don't use FillAsync here as it will throw on EOF.
                 Debug.Assert(_readAheadTask == null);
-                _readLength = await _stream.ReadAsync(_readBuffer, cancellationToken).ConfigureAwait(false);
+                _readLength = await _stream.ReadAsync(_readBuffer).ConfigureAwait(false);
                 if (_readLength == 0)
                 {
                     // End of stream
                     break;
                 }
 
-                await CopyFromBufferAsync(destination, _readLength, cancellationToken).ConfigureAwait(false);
+                await CopyFromBufferAsync(destination, _readLength).ConfigureAwait(false);
             }
         }
 
         // Copy *exactly* [length] bytes into destination; throws on end of stream.
-        private async Task CopyToAsync(Stream destination, ulong length, CancellationToken cancellationToken)
+        private async Task CopyToAsync(Stream destination, ulong length)
         {
             Debug.Assert(destination != null);
             Debug.Assert(length > 0);
@@ -1090,7 +1156,7 @@ namespace System.Net.Http
                 {
                     remaining = (int)length;
                 }
-                await CopyFromBufferAsync(destination, remaining, cancellationToken).ConfigureAwait(false);
+                await CopyFromBufferAsync(destination, remaining).ConfigureAwait(false);
 
                 length -= (ulong)remaining;
                 if (length == 0)
@@ -1101,10 +1167,10 @@ namespace System.Net.Http
 
             while (true)
             {
-                await FillAsync(cancellationToken).ConfigureAwait(false);
+                await FillAsync().ConfigureAwait(false);
 
                 remaining = (ulong)_readLength < length ? _readLength : (int)length;
-                await CopyFromBufferAsync(destination, remaining, cancellationToken).ConfigureAwait(false);
+                await CopyFromBufferAsync(destination, remaining).ConfigureAwait(false);
 
                 length -= (ulong)remaining;
                 if (length == 0)
@@ -1177,8 +1243,7 @@ namespace System.Net.Http
             {
                 try
                 {
-                    // Null out the associated request before the connection is potentially reused by another.
-                    _currentRequest = null;
+                    // Any remaining request content has completed successfully.  Drop it.
                     _sendRequestContentTask = null;
 
                     // When putting a connection back into the pool, we initiate a pre-emptive
@@ -11,16 +11,10 @@ namespace System.Net.Http
 {
     internal partial class HttpConnection : IDisposable
     {
-        private sealed class HttpConnectionContent : HttpContent
+        private sealed class HttpConnectionResponseContent : HttpContent
         {
-            private readonly CancellationToken _cancellationToken;
             private HttpContentStream _stream;
 
-            public HttpConnectionContent(CancellationToken cancellationToken)
-            {
-                _cancellationToken = cancellationToken;
-            }
-
             public void SetStream(HttpContentStream stream)
             {
                 Debug.Assert(stream != null);
@@ -41,30 +35,33 @@ namespace System.Net.Http
                 return stream;
             }
 
-            protected override async Task SerializeToStreamAsync(Stream stream, TransportContext context)
+            protected sealed override Task SerializeToStreamAsync(Stream stream, TransportContext context) =>
+                SerializeToStreamAsync(stream, context, CancellationToken.None);
+
+            internal sealed override async Task SerializeToStreamAsync(Stream stream, TransportContext context, CancellationToken cancellationToken)
             {
                 Debug.Assert(stream != null);
 
                 using (HttpContentStream contentStream = ConsumeStream())
                 {
                     const int BufferSize = 8192;
-                    await contentStream.CopyToAsync(stream, BufferSize, _cancellationToken).ConfigureAwait(false);
+                    await contentStream.CopyToAsync(stream, BufferSize, cancellationToken).ConfigureAwait(false);
                 }
             }
 
-            protected internal override bool TryComputeLength(out long length)
+            protected internal sealed override bool TryComputeLength(out long length)
             {
                 length = 0;
                 return false;
             }
 
-            protected override Task<Stream> CreateContentReadStreamAsync() =>
+            protected sealed override Task<Stream> CreateContentReadStreamAsync() =>
                 Task.FromResult<Stream>(ConsumeStream());
 
-            internal override Stream TryCreateContentReadStream() =>
+            internal sealed override Stream TryCreateContentReadStream() =>
                 ConsumeStream();
 
-            protected override void Dispose(bool disposing)
+            protected sealed override void Dispose(bool disposing)
             {
                 if (disposing)
                 {
index 022e876..73e04ed 100644 (file)
@@ -2,7 +2,6 @@
 // The .NET Foundation licenses this file to you under the MIT license.
 // See the LICENSE file in the project root for more information.
 
-using System.Diagnostics;
 using System.IO;
 using System.Threading;
 
@@ -14,21 +13,21 @@ namespace System.Net.Http
         {
         }
 
-        public override bool CanRead => true;
-        public override bool CanWrite => true;
+        public sealed override bool CanRead => true;
+        public sealed override bool CanWrite => true;
 
-        public override void Flush() => FlushAsync().GetAwaiter().GetResult();
+        public sealed override void Flush() => FlushAsync().GetAwaiter().GetResult();
 
-        public override int Read(byte[] buffer, int offset, int count)
+        public sealed override int Read(byte[] buffer, int offset, int count)
         {
             ValidateBufferArgs(buffer, offset, count);
             return ReadAsync(new Memory<byte>(buffer, offset, count), CancellationToken.None).GetAwaiter().GetResult();
         }
 
-        public override void Write(byte[] buffer, int offset, int count) =>
+        public sealed override void Write(byte[] buffer, int offset, int count) =>
             WriteAsync(buffer, offset, count, CancellationToken.None).GetAwaiter().GetResult();
 
-        public override void CopyTo(Stream destination, int bufferSize) =>
+        public sealed override void CopyTo(Stream destination, int bufferSize) =>
             CopyToAsync(destination, bufferSize, CancellationToken.None).GetAwaiter().GetResult();
     }
 }
index ed5cce8..fdad589 100644 (file)
@@ -13,20 +13,20 @@ namespace System.Net.Http
         {
         }
 
-        public override bool CanRead => true;
-        public override bool CanWrite => false;
+        public sealed override bool CanRead => true;
+        public sealed override bool CanWrite => false;
 
-        public override void Flush() { }
+        public sealed override void Flush() { }
 
-        public override void Write(byte[] buffer, int offset, int count) => throw new NotSupportedException();
+        public sealed override void Write(byte[] buffer, int offset, int count) => throw new NotSupportedException();
 
-        public override int Read(byte[] buffer, int offset, int count)
+        public sealed override int Read(byte[] buffer, int offset, int count)
         {
             ValidateBufferArgs(buffer, offset, count);
             return ReadAsync(new Memory<byte>(buffer, offset, count), CancellationToken.None).GetAwaiter().GetResult();
         }
 
-        public override void CopyTo(Stream destination, int bufferSize) =>
+        public sealed override void CopyTo(Stream destination, int bufferSize) =>
             CopyToAsync(destination, bufferSize, CancellationToken.None).GetAwaiter().GetResult();
     }
 }
index b54f6c4..ad329ff 100644 (file)
@@ -31,27 +31,27 @@ namespace System.Net.Http
             base.Dispose(disposing);
         }
 
-        public override bool CanSeek => false;
+        public sealed override bool CanSeek => false;
 
-        public override IAsyncResult BeginRead(byte[] buffer, int offset, int count, AsyncCallback callback, object state) =>
+        public sealed override IAsyncResult BeginRead(byte[] buffer, int offset, int count, AsyncCallback callback, object state) =>
             TaskToApm.Begin(ReadAsync(buffer, offset, count, default(CancellationToken)), callback, state);
 
-        public override int EndRead(IAsyncResult asyncResult) =>
+        public sealed override int EndRead(IAsyncResult asyncResult) =>
             TaskToApm.End<int>(asyncResult);
 
-        public override IAsyncResult BeginWrite(byte[] buffer, int offset, int count, AsyncCallback callback, object state) =>
+        public sealed override IAsyncResult BeginWrite(byte[] buffer, int offset, int count, AsyncCallback callback, object state) =>
             TaskToApm.Begin(WriteAsync(buffer, offset, count, default(CancellationToken)), callback, state);
 
-        public override void EndWrite(IAsyncResult asyncResult) =>
+        public sealed override void EndWrite(IAsyncResult asyncResult) =>
             TaskToApm.End(asyncResult);
 
-        public override long Seek(long offset, SeekOrigin origin) => throw new NotSupportedException();
+        public sealed override long Seek(long offset, SeekOrigin origin) => throw new NotSupportedException();
 
-        public override void SetLength(long value) => throw new NotSupportedException();
+        public sealed override void SetLength(long value) => throw new NotSupportedException();
 
-        public override long Length => throw new NotSupportedException();
+        public sealed override long Length => throw new NotSupportedException();
 
-        public override long Position
+        public sealed override long Position
         {
             get { throw new NotSupportedException(); }
             set { throw new NotSupportedException(); }
index 54c68df..b302e89 100644 (file)
@@ -3,7 +3,6 @@
 // See the LICENSE file in the project root for more information.
 
 using System.Diagnostics;
-using System.IO;
 using System.Threading;
 using System.Threading.Tasks;
 
@@ -11,28 +10,17 @@ namespace System.Net.Http
 {
     internal abstract class HttpContentWriteStream : HttpContentStream
     {
-        public HttpContentWriteStream(HttpConnection connection, CancellationToken cancellationToken) : base(connection)
-        {
+        public HttpContentWriteStream(HttpConnection connection) : base(connection) =>
             Debug.Assert(connection != null);
-            RequestCancellationToken = cancellationToken;
-        }
 
-        /// <summary>Cancellation token associated with the send operation.</summary>
-        /// <remarks>
-        /// Because of how this write stream is used, the CancellationToken passed into the individual
-        /// stream operations will be the default non-cancelable token and can be ignored.  Instead,
-        /// this token is used.
-        /// </remarks>
-        internal CancellationToken RequestCancellationToken { get; }
+        public sealed override bool CanRead => false;
+        public sealed override bool CanWrite => true;
 
-        public override bool CanRead => false;
-        public override bool CanWrite => true;
+        public sealed override void Flush() => FlushAsync().GetAwaiter().GetResult();
 
-        public override void Flush() => FlushAsync().GetAwaiter().GetResult();
+        public sealed override int Read(byte[] buffer, int offset, int count) => throw new NotSupportedException();
 
-        public override int Read(byte[] buffer, int offset, int count) => throw new NotSupportedException();
-
-        public override void Write(byte[] buffer, int offset, int count) =>
+        public sealed override void Write(byte[] buffer, int offset, int count) =>
             WriteAsync(buffer, offset, count, CancellationToken.None).GetAwaiter().GetResult();
 
         public abstract Task FinishAsync();
index 7073afc..e30f76d 100644 (file)
@@ -22,17 +22,44 @@ namespace System.Net.Http
                 return ReadAsync(new Memory<byte>(buffer, offset, count), cancellationToken).AsTask();
             }
 
-            public override async ValueTask<int> ReadAsync(Memory<byte> destination, CancellationToken cancellationToken = default)
+            public override async ValueTask<int> ReadAsync(Memory<byte> destination, CancellationToken cancellationToken)
             {
+                cancellationToken.ThrowIfCancellationRequested();
+
                 if (_connection == null || destination.Length == 0)
                 {
                     // Response body fully consumed or the caller didn't ask for any data
                     return 0;
                 }
 
-                int bytesRead = await _connection.ReadAsync(destination, cancellationToken).ConfigureAwait(false);
+                ValueTask<int> readTask = _connection.ReadAsync(destination);
+                int bytesRead;
+                if (readTask.IsCompletedSuccessfully)
+                {
+                    bytesRead = readTask.Result;
+                }
+                else
+                {
+                    CancellationTokenRegistration ctr = _connection.RegisterCancellation(cancellationToken);
+                    try
+                    {
+                        bytesRead = await readTask.ConfigureAwait(false);
+                    }
+                    catch (Exception exc) when (ShouldWrapInOperationCanceledException(exc, cancellationToken))
+                    {
+                        throw CreateOperationCanceledException(exc, cancellationToken);
+                    }
+                    finally
+                    {
+                        ctr.Dispose();
+                    }
+                }
+
                 if (bytesRead == 0)
                 {
+                    // A cancellation request may have caused the EOF.
+                    cancellationToken.ThrowIfCancellationRequested();
+
                     // We cannot reuse this connection, so close it.
                     _connection.Dispose();
                     _connection = null;
@@ -48,15 +75,40 @@ namespace System.Net.Http
                 {
                     throw new ArgumentNullException(nameof(destination));
                 }
+                if (bufferSize <= 0)
+                {
+                    throw new ArgumentOutOfRangeException(nameof(bufferSize));
+                }
 
-                if (_connection != null) // null if response body fully consumed
+                cancellationToken.ThrowIfCancellationRequested();
+
+                if (_connection == null)
                 {
-                    await _connection.CopyToAsync(destination, cancellationToken).ConfigureAwait(false);
+                    // Response body fully consumed
+                    return;
+                }
 
-                    // We cannot reuse this connection, so close it.
-                    _connection.Dispose();
-                    _connection = null;
+                Task copyTask = _connection.CopyToAsync(destination);
+                if (!copyTask.IsCompletedSuccessfully)
+                {
+                    CancellationTokenRegistration ctr = _connection.RegisterCancellation(cancellationToken);
+                    try
+                    {
+                        await copyTask.ConfigureAwait(false);
+                    }
+                    catch (Exception exc) when (ShouldWrapInOperationCanceledException(exc, cancellationToken))
+                    {
+                        throw CreateOperationCanceledException(exc, cancellationToken);
+                    }
+                    finally
+                    {
+                        ctr.Dispose();
+                    }
                 }
+
+                // We cannot reuse this connection, so close it.
+                _connection.Dispose();
+                _connection = null;
             }
 
             public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
@@ -65,14 +117,63 @@ namespace System.Net.Http
                 return WriteAsync(new ReadOnlyMemory<byte>(buffer, offset, count), cancellationToken);
             }
 
-            public override Task WriteAsync(ReadOnlyMemory<byte> source, CancellationToken cancellationToken = default) =>
-                _connection == null ? Task.FromException(new IOException(SR.net_http_io_write)) :
-                source.Length > 0 ? _connection.WriteWithoutBufferingAsync(source, cancellationToken) :
-                Task.CompletedTask;
+            public override Task WriteAsync(ReadOnlyMemory<byte> source, CancellationToken cancellationToken)
+            {
+                if (cancellationToken.IsCancellationRequested)
+                {
+                    return Task.FromCanceled(cancellationToken);
+                }
+
+                if (_connection == null)
+                {
+                    return Task.FromException(new IOException(SR.net_http_io_write));
+                }
+
+                if (source.Length == 0)
+                {
+                    return Task.CompletedTask;
+                }
+
+                Task writeTask = _connection.WriteWithoutBufferingAsync(source);
+                return writeTask.IsCompleted ?
+                    writeTask :
+                    WaitWithConnectionCancellationAsync(writeTask, cancellationToken);
+            }
+
+            public override Task FlushAsync(CancellationToken cancellationToken)
+            {
+                if (cancellationToken.IsCancellationRequested)
+                {
+                    return Task.FromCanceled(cancellationToken);
+                }
+
+                if (_connection == null)
+                {
+                    return Task.CompletedTask;
+                }
+
+                Task flushTask = _connection.FlushAsync();
+                return flushTask.IsCompleted ?
+                    flushTask :
+                    WaitWithConnectionCancellationAsync(flushTask, cancellationToken);
+            }
 
-            public override Task FlushAsync(CancellationToken cancellationToken) =>
-                _connection != null ? _connection.FlushAsync(cancellationToken) :
-                Task.CompletedTask;
+            private async Task WaitWithConnectionCancellationAsync(Task task, CancellationToken cancellationToken)
+            {
+                CancellationTokenRegistration ctr = _connection.RegisterCancellation(cancellationToken);
+                try
+                {
+                    await task.ConfigureAwait(false);
+                }
+                catch (Exception exc) when (ShouldWrapInOperationCanceledException(exc, cancellationToken))
+                {
+                    throw CreateOperationCanceledException(exc, cancellationToken);
+                }
+                finally
+                {
+                    ctr.Dispose();
+                }
+            }
         }
     }
 }
diff --git a/src/libraries/System.Net.Http/tests/FunctionalTests/CancellationTest.cs b/src/libraries/System.Net.Http/tests/FunctionalTests/CancellationTest.cs
deleted file mode 100644 (file)
index f29922b..0000000
+++ /dev/null
@@ -1,162 +0,0 @@
-// Licensed to the .NET Foundation under one or more agreements.
-// The .NET Foundation licenses this file to you under the MIT license.
-// See the LICENSE file in the project root for more information.
-
-using System.Diagnostics;
-using System.IO;
-using System.Net.Test.Common;
-using System.Threading;
-using System.Threading.Tasks;
-
-using Xunit;
-using Xunit.Abstractions;
-
-namespace System.Net.Http.Functional.Tests
-{
-    public class CancellationTest : HttpClientTestBase
-    {
-        private readonly ITestOutputHelper _output;
-
-        public CancellationTest(ITestOutputHelper output)
-        {
-            _output = output;
-        }
-
-        [OuterLoop] // includes seconds of delay
-        [Theory]
-        [InlineData(false, false)]
-        [InlineData(false, true)]
-        [InlineData(true, false)]
-        [InlineData(true, true)]
-        [ActiveIssue("dotnet/corefx #20010", TargetFrameworkMonikers.Uap)]
-        [ActiveIssue("dotnet/corefx #19038", TargetFrameworkMonikers.NetFramework)]
-        public async Task GetAsync_ResponseContentRead_CancelUsingTimeoutOrToken_TaskCanceledQuickly(
-            bool useTimeout, bool startResponseBody)
-        {
-            var cts = new CancellationTokenSource(); // ignored if useTimeout==true
-            TimeSpan timeout = useTimeout ? new TimeSpan(0, 0, 1) : Timeout.InfiniteTimeSpan;
-            CancellationToken cancellationToken = useTimeout ? CancellationToken.None : cts.Token;
-
-            using (HttpClient client = CreateHttpClient())
-            {
-                client.Timeout = timeout;
-
-                var triggerResponseWrite = new TaskCompletionSource<bool>(TaskCreationOptions.RunContinuationsAsynchronously);
-                var triggerRequestCancel = new TaskCompletionSource<bool>(TaskCreationOptions.RunContinuationsAsynchronously);
-
-                await LoopbackServer.CreateServerAsync(async (server, url) =>
-                {
-                    Task serverTask = LoopbackServer.AcceptSocketAsync(server, async (socket, stream, reader, writer) =>
-                    {
-                        while (!string.IsNullOrEmpty(await reader.ReadLineAsync())) ;
-                        await writer.WriteAsync(
-                            "HTTP/1.1 200 OK\r\n" +
-                            $"Date: {DateTimeOffset.UtcNow:R}\r\n" +
-                            "Content-Length: 16000\r\n" +
-                            "\r\n" +
-                            (startResponseBody ? "less than 16000 bytes" : ""));
-
-                        await Task.Delay(1000);
-                        triggerRequestCancel.SetResult(true); // allow request to cancel
-                        await triggerResponseWrite.Task; // pause until we're released
-                        
-                        return null;
-                    });
-
-                    var stopwatch = Stopwatch.StartNew();
-                    if (PlatformDetection.IsFullFramework)
-                    {
-                        // .NET Framework throws WebException instead of OperationCanceledException.
-                        await Assert.ThrowsAnyAsync<WebException>(async () =>
-                        {
-                            Task<HttpResponseMessage> getResponse = client.GetAsync(url, HttpCompletionOption.ResponseContentRead, cancellationToken);
-                            await triggerRequestCancel.Task;
-                            cts.Cancel();
-                            await getResponse;
-                        });
-                    }
-                    else
-                    {
-                        await Assert.ThrowsAnyAsync<OperationCanceledException>(async () =>
-                        {
-                            Task<HttpResponseMessage> getResponse = client.GetAsync(url, HttpCompletionOption.ResponseContentRead, cancellationToken);
-                            await triggerRequestCancel.Task;
-                            cts.Cancel();
-                            await getResponse;
-                        });
-                    }
-                    stopwatch.Stop();
-                    _output.WriteLine("GetAsync() completed at: {0}", stopwatch.Elapsed.ToString());
-
-                    triggerResponseWrite.SetResult(true);
-                    Assert.True(stopwatch.Elapsed < new TimeSpan(0, 0, 30), $"Elapsed time {stopwatch.Elapsed} should be less than 30 seconds, was {stopwatch.Elapsed.TotalSeconds}");
-                });
-            }
-        }
-
-        [SkipOnTargetFramework(TargetFrameworkMonikers.NetFramework, "dotnet/corefx #18864")] // Hangs on NETFX
-        [ActiveIssue(9075, TestPlatforms.AnyUnix)] // recombine this test into the subsequent one when issue is fixed
-        [OuterLoop] // includes seconds of delay
-        [Fact]
-        public Task ReadAsStreamAsync_ReadAsync_Cancel_BodyNeverStarted_TaskCanceledQuickly()
-        {
-            return ReadAsStreamAsync_ReadAsync_Cancel_TaskCanceledQuickly(false);
-        }
-
-        [SkipOnTargetFramework(TargetFrameworkMonikers.NetFramework, "dotnet/corefx #18864")] // Hangs on NETFX
-        [OuterLoop] // includes seconds of delay
-        [Theory]
-        [InlineData(true)]
-        public async Task ReadAsStreamAsync_ReadAsync_Cancel_TaskCanceledQuickly(bool startResponseBody)
-        {
-            using (HttpClient client = CreateHttpClient())
-            {
-                await LoopbackServer.CreateServerAsync(async (server, url) =>
-                {
-                    var triggerResponseWrite = new TaskCompletionSource<bool>(TaskCreationOptions.RunContinuationsAsynchronously);
-
-                    Task serverTask = LoopbackServer.AcceptSocketAsync(server, async (socket, stream, reader, writer) =>
-                    {
-                        while (!string.IsNullOrEmpty(await reader.ReadLineAsync())) ;
-                        await writer.WriteAsync(
-                            "HTTP/1.1 200 OK\r\n" +
-                            $"Date: {DateTimeOffset.UtcNow:R}\r\n" +
-                            "Content-Length: 16000\r\n" +
-                            "\r\n" +
-                            (startResponseBody ? "20 bytes of the body" : ""));
-
-                        await triggerResponseWrite.Task; // pause until we're released
-                        
-                        return null;
-                    });
-
-                    using (HttpResponseMessage response = await client.GetAsync(url, HttpCompletionOption.ResponseHeadersRead))
-                    using (Stream responseStream = await response.Content.ReadAsStreamAsync())
-                    {
-                        // Read all expected content
-                        byte[] buffer = new byte[20];
-                        if (startResponseBody)
-                        {
-                            int totalRead = 0;
-                            int bytesRead;
-                            while (totalRead < 20 && (bytesRead = await responseStream.ReadAsync(buffer, 0, buffer.Length)) > 0)
-                            {
-                                totalRead += bytesRead;
-                            }
-                        }
-
-                        // Now do a read that'll need to be canceled
-                        var stopwatch = Stopwatch.StartNew();
-                        await Assert.ThrowsAnyAsync<OperationCanceledException>(
-                            () => responseStream.ReadAsync(buffer, 0, buffer.Length, new CancellationTokenSource(1000).Token));
-                        stopwatch.Stop();
-
-                        triggerResponseWrite.SetResult(true);
-                        _output.WriteLine("ReadAsync() completed at: {0}", stopwatch.Elapsed.ToString());
-                        Assert.True(stopwatch.Elapsed < new TimeSpan(0, 0, 30), $"Elapsed time {stopwatch.Elapsed} should be less than 30 seconds, was {stopwatch.Elapsed.TotalSeconds}");
-                    }
-                });
-            }
-        }
-    }
-}
diff --git a/src/libraries/System.Net.Http/tests/FunctionalTests/HttpClientHandlerTest.Cancellation.cs b/src/libraries/System.Net.Http/tests/FunctionalTests/HttpClientHandlerTest.Cancellation.cs
new file mode 100644 (file)
index 0000000..db3f4a8
--- /dev/null
@@ -0,0 +1,478 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System.Collections.Generic;
+using System.Diagnostics;
+using System.Linq;
+using System.IO;
+using System.Net.Test.Common;
+using System.Threading;
+using System.Threading.Tasks;
+using Xunit;
+
+namespace System.Net.Http.Functional.Tests
+{
+    public class HttpClientHandler_Cancellation_Test : HttpClientTestBase
+    {
+        [Theory]
+        [MemberData(nameof(TwoBoolsAndCancellationMode))]
+        public async Task PostAsync_CancelDuringRequestContentSend_TaskCanceledQuickly(bool chunkedTransfer, bool connectionClose, CancellationMode mode)
+        {
+            if (IsWinHttpHandler || IsNetfxHandler)
+            {
+                // Issue #27063: hangs / doesn't cancel
+                return;
+            }
+
+            using (HttpClient client = CreateHttpClient())
+            {
+                client.Timeout = Timeout.InfiniteTimeSpan;
+                var cts = new CancellationTokenSource();
+
+                await LoopbackServer.CreateServerAsync(async (server, url) =>
+                {
+                    Task serverTask = LoopbackServer.AcceptSocketAsync(server, async (socket, stream, reader, writer) =>
+                    {
+                        // Since we won't receive all of the request, just read everything we do get
+                        byte[] ignored = new byte[100];
+                        while (await stream.ReadAsync(ignored, 0, ignored.Length) > 0);
+                        return null;
+                    });
+
+                    var preContentSent = new TaskCompletionSource<bool>();
+                    var sendPostContent = new TaskCompletionSource<bool>();
+
+                    await ValidateClientCancellationAsync(async () =>
+                    {
+                        var req = new HttpRequestMessage(HttpMethod.Post, url);
+                        req.Content = new DelayedByteContent(2000, 3000, preContentSent, sendPostContent.Task);
+                        req.Headers.TransferEncodingChunked = chunkedTransfer;
+                        req.Headers.ConnectionClose = connectionClose;
+
+                        Task<HttpResponseMessage> postResponse = client.SendAsync(req, HttpCompletionOption.ResponseHeadersRead, cts.Token);
+                        await preContentSent.Task;
+                        Cancel(mode, client, cts);
+                        await postResponse;
+                    });
+
+                    try
+                    {
+                        sendPostContent.SetResult(true);
+                        await serverTask;
+                    } catch { }
+                });
+            }
+        }
+        
+        [Theory]
+        [MemberData(nameof(TwoBoolsAndCancellationMode))]
+        public async Task GetAsync_CancelDuringResponseHeadersReceived_TaskCanceledQuickly(bool chunkedTransfer, bool connectionClose, CancellationMode mode)
+        {
+            using (HttpClient client = CreateHttpClient())
+            {
+                client.Timeout = Timeout.InfiniteTimeSpan;
+                var cts = new CancellationTokenSource();
+
+                await LoopbackServer.CreateServerAsync(async (server, url) =>
+                {
+                    var partialResponseHeadersSent = new TaskCompletionSource<bool>();
+                    var clientFinished = new TaskCompletionSource<bool>();
+
+                    Task serverTask = LoopbackServer.AcceptSocketAsync(server, async (socket, stream, reader, writer) =>
+                    {
+                        while (!string.IsNullOrEmpty(await reader.ReadLineAsync()));
+
+                        await writer.WriteAsync($"HTTP/1.1 200 OK\r\nDate: {DateTimeOffset.UtcNow:R}\r\n"); // missing final \r\n so headers don't complete
+
+                        partialResponseHeadersSent.TrySetResult(true);
+                        await clientFinished.Task;
+
+                        return null;
+                    });
+
+                    await ValidateClientCancellationAsync(async () =>
+                    {
+                        var req = new HttpRequestMessage(HttpMethod.Get, url);
+                        req.Headers.ConnectionClose = connectionClose;
+
+                        Task<HttpResponseMessage> getResponse = client.SendAsync(req, HttpCompletionOption.ResponseHeadersRead, cts.Token);
+                        await partialResponseHeadersSent.Task;
+                        Cancel(mode, client, cts);
+                        await getResponse;
+                    });
+
+                    try
+                    {
+                        clientFinished.SetResult(true);
+                        await serverTask;
+                    } catch { }
+                });
+            }
+        }
+
+        [Theory]
+        [MemberData(nameof(TwoBoolsAndCancellationMode))]
+        public async Task GetAsync_CancelDuringResponseBodyReceived_Buffered_TaskCanceledQuickly(bool chunkedTransfer, bool connectionClose, CancellationMode mode)
+        {
+            using (HttpClient client = CreateHttpClient())
+            {
+                client.Timeout = Timeout.InfiniteTimeSpan;
+                var cts = new CancellationTokenSource();
+
+                await LoopbackServer.CreateServerAsync(async (server, url) =>
+                {
+                    var responseHeadersSent = new TaskCompletionSource<bool>();
+                    var clientFinished = new TaskCompletionSource<bool>();
+
+                    Task serverTask = LoopbackServer.AcceptSocketAsync(server, async (socket, stream, reader, writer) =>
+                    {
+                        while (!string.IsNullOrEmpty(await reader.ReadLineAsync()));
+
+                        await writer.WriteAsync(
+                            $"HTTP/1.1 200 OK\r\n" +
+                            $"Date: {DateTimeOffset.UtcNow:R}\r\n" +
+                            (!chunkedTransfer ? "Content-Length: 20\r\n" : "") +
+                            (connectionClose ? "Connection: close\r\n" : "") +
+                            $"\r\n123"); // "123" is part of body and could either be chunked size or part of content-length bytes, both incomplete
+
+                        responseHeadersSent.TrySetResult(true);
+                        await clientFinished.Task;
+
+                        return null;
+                    });
+
+                    await ValidateClientCancellationAsync(async () =>
+                    {
+                        var req = new HttpRequestMessage(HttpMethod.Get, url);
+                        req.Headers.ConnectionClose = connectionClose;
+
+                        Task<HttpResponseMessage> getResponse = client.SendAsync(req, HttpCompletionOption.ResponseContentRead, cts.Token);
+                        await responseHeadersSent.Task;
+                        await Task.Delay(1); // make it more likely that client will have started processing response body
+                        Cancel(mode, client, cts);
+                        await getResponse;
+                    });
+
+                    try
+                    {
+                        clientFinished.SetResult(true);
+                        await serverTask;
+                    } catch { }
+                });
+            }
+        }
+
+        [Theory]
+        [MemberData(nameof(ThreeBools))]
+        public async Task GetAsync_CancelDuringResponseBodyReceived_Unbuffered_TaskCanceledQuickly(bool chunkedTransfer, bool connectionClose, bool readOrCopyToAsync)
+        {
+            if (IsNetfxHandler || IsCurlHandler)
+            {
+                // doesn't cancel
+                return;
+            }
+
+            using (HttpClient client = CreateHttpClient())
+            {
+                client.Timeout = Timeout.InfiniteTimeSpan;
+                var cts = new CancellationTokenSource();
+
+                await LoopbackServer.CreateServerAsync(async (server, url) =>
+                {
+                    var clientFinished = new TaskCompletionSource<bool>();
+
+                    Task serverTask = LoopbackServer.AcceptSocketAsync(server, async (socket, stream, reader, writer) =>
+                    {
+                        while (!string.IsNullOrEmpty(await reader.ReadLineAsync()));
+
+                        await writer.WriteAsync(
+                            $"HTTP/1.1 200 OK\r\n" +
+                            $"Date: {DateTimeOffset.UtcNow:R}\r\n" +
+                            (!chunkedTransfer ? "Content-Length: 20\r\n" : "") +
+                            (connectionClose ? "Connection: close\r\n" : "") +
+                            $"\r\n");
+
+                        await clientFinished.Task;
+
+                        return null;
+                    });
+
+                    var req = new HttpRequestMessage(HttpMethod.Get, url);
+                    req.Headers.ConnectionClose = connectionClose;
+                    Task<HttpResponseMessage> getResponse = client.SendAsync(req, HttpCompletionOption.ResponseHeadersRead, cts.Token);
+                    await ValidateClientCancellationAsync(async () =>
+                    {
+                        HttpResponseMessage resp = await getResponse;
+                        Stream respStream = await resp.Content.ReadAsStreamAsync();
+                        Task readTask = readOrCopyToAsync ?
+                            respStream.ReadAsync(new byte[1], 0, 1, cts.Token) :
+                            respStream.CopyToAsync(Stream.Null, 10, cts.Token);
+                        cts.Cancel();
+                        await readTask;
+                    });
+
+                    try
+                    {
+                        clientFinished.SetResult(true);
+                        await serverTask;
+                    } catch { }
+                });
+            }
+        }
+
+        [Theory]
+        [InlineData(CancellationMode.CancelPendingRequests, false)]
+        [InlineData(CancellationMode.DisposeHttpClient, true)]
+        [InlineData(CancellationMode.CancelPendingRequests, false)]
+        [InlineData(CancellationMode.DisposeHttpClient, true)]
+        public async Task GetAsync_CancelPendingRequests_DoesntCancelReadAsyncOnResponseStream(CancellationMode mode, bool copyToAsync)
+        {
+            if (IsNetfxHandler)
+            {
+                // throws ObjectDisposedException as part of Stream.CopyToAsync/ReadAsync
+                return;
+            }
+            if (IsCurlHandler)
+            {
+                // Issue #27065
+                // throws OperationCanceledException from Stream.CopyToAsync/ReadAsync
+                return;
+            }
+
+            using (HttpClient client = CreateHttpClient())
+            {
+                client.Timeout = Timeout.InfiniteTimeSpan;
+
+                await LoopbackServer.CreateServerAsync(async (server, url) =>
+                {
+                    var clientReadSomeBody = new TaskCompletionSource<bool>();
+                    var clientFinished = new TaskCompletionSource<bool>();
+
+                    var responseContentSegment = new string('s', 3000);
+                    int responseSegments = 4;
+                    int contentLength = responseContentSegment.Length * responseSegments;
+
+                    Task serverTask = LoopbackServer.AcceptSocketAsync(server, async (socket, stream, reader, writer) =>
+                    {
+                        while (!string.IsNullOrEmpty(await reader.ReadLineAsync()));
+
+                        await writer.WriteAsync(
+                            $"HTTP/1.1 200 OK\r\n" +
+                            $"Date: {DateTimeOffset.UtcNow:R}\r\n" +
+                            $"Content-Length: {contentLength}\r\n" +
+                            $"\r\n");
+
+                        for (int i = 0; i < responseSegments; i++)
+                        {
+                            await writer.WriteAsync(responseContentSegment);
+                            if (i == 0)
+                            {
+                                await clientReadSomeBody.Task;
+                            }
+                        }
+
+                        await clientFinished.Task;
+
+                        return null;
+                    });
+
+
+                    using (HttpResponseMessage resp = await client.GetAsync(url, HttpCompletionOption.ResponseHeadersRead))
+                    using (Stream respStream = await resp.Content.ReadAsStreamAsync())
+                    {
+                        var result = new MemoryStream();
+                        int b = respStream.ReadByte();
+                        Assert.NotEqual(-1, b);
+                        result.WriteByte((byte)b);
+
+                        Cancel(mode, client, null); // should not cancel the operation, as using ResponseHeadersRead
+                        clientReadSomeBody.SetResult(true);
+
+                        if (copyToAsync)
+                        {
+                            await respStream.CopyToAsync(result, 10, new CancellationTokenSource().Token);
+                        }
+                        else
+                        {
+                            byte[] buffer = new byte[10];
+                            int bytesRead;
+                            while ((bytesRead = await respStream.ReadAsync(buffer, 0, buffer.Length)) > 0)
+                            {
+                                result.Write(buffer, 0, bytesRead);
+                            }
+                        }
+
+                        Assert.Equal(contentLength, result.Length);
+                    }
+
+                    clientFinished.SetResult(true);
+                    await serverTask;
+                });
+            }
+        }
+
+        [Fact]
+        public async Task MaxConnectionsPerServer_WaitingConnectionsAreCancelable()
+        {
+            if (IsWinHttpHandler)
+            {
+                // Issue #27064:
+                // Throws WinHttpException ("The server returned an invalid or unrecognized response")
+                // while parsing headers.
+                return;
+            }
+            if (IsNetfxHandler)
+            {
+                // Throws HttpRequestException wrapping a WebException for the canceled request
+                // instead of throwing an OperationCanceledException or a canceled WebException directly.
+                return;
+            }
+
+            using (HttpClientHandler handler = CreateHttpClientHandler())
+            using (HttpClient client = new HttpClient(handler))
+            {
+                handler.MaxConnectionsPerServer = 1;
+                client.Timeout = Timeout.InfiniteTimeSpan;
+
+                await LoopbackServer.CreateServerAsync(async (server, url) =>
+                {
+                    var serverAboutToBlock = new TaskCompletionSource<bool>();
+                    var blockServerResponse = new TaskCompletionSource<bool>();
+
+                    Task serverTask1 = LoopbackServer.AcceptSocketAsync(server, async (socket1, stream1, reader1, writer1) =>
+                    {
+                        while (!string.IsNullOrEmpty(await reader1.ReadLineAsync()));
+                        await writer1.WriteAsync($"HTTP/1.1 200 OK\r\nDate: {DateTimeOffset.UtcNow:R}\r\n");
+                        serverAboutToBlock.SetResult(true);
+                        await blockServerResponse.Task;
+                        await writer1.WriteAsync("Content-Length: 5\r\n\r\nhello");
+                        return null;
+                    });
+
+                    Task get1 = client.GetAsync(url);
+                    await serverAboutToBlock.Task;
+
+                    var cts = new CancellationTokenSource();
+                    Task get2 = ValidateClientCancellationAsync(() => client.GetAsync(url, cts.Token));
+                    Task get3 = ValidateClientCancellationAsync(() => client.GetAsync(url, cts.Token));
+
+                    Task get4 = client.GetAsync(url);
+
+                    cts.Cancel();
+                    await get2;
+                    await get3;
+
+                    blockServerResponse.SetResult(true);
+                    await new[] { get1, serverTask1 }.WhenAllOrAnyFailed();
+
+                    Task serverTask4 = LoopbackServer.AcceptSocketAsync(server, async (socket2, stream2, reader2, writer2) =>
+                    {
+                        while (!string.IsNullOrEmpty(await reader2.ReadLineAsync()));
+                        await writer2.WriteAsync($"HTTP/1.1 200 OK\r\nDate: {DateTimeOffset.UtcNow:R}\r\nContent-Length: 0\r\n\r\n");
+                        return null;
+                    });
+
+                    await new[] { get4, serverTask4 }.WhenAllOrAnyFailed();
+                });
+            }
+        }
+
+        private async Task ValidateClientCancellationAsync(Func<Task> clientBodyAsync)
+        {
+            var stopwatch = Stopwatch.StartNew();
+            Exception error = await Record.ExceptionAsync(clientBodyAsync);
+            stopwatch.Stop();
+
+            Assert.NotNull(error);
+
+            if (IsNetfxHandler)
+            {
+                Assert.True(
+                    error is WebException we && we.Status == WebExceptionStatus.RequestCanceled ||
+                    error is OperationCanceledException,
+                    "Expected cancellation exception, got:" + Environment.NewLine + error);
+            }
+            else
+            {
+                Assert.True(
+                    error is OperationCanceledException,
+                    "Expected cancellation exception, got:" + Environment.NewLine + error);
+            }
+
+            Assert.True(stopwatch.Elapsed < new TimeSpan(0, 0, 30), $"Elapsed time {stopwatch.Elapsed} should be less than 30 seconds, was {stopwatch.Elapsed.TotalSeconds}");
+        }
+
+        private static void Cancel(CancellationMode mode, HttpClient client, CancellationTokenSource cts)
+        {
+            if ((mode & CancellationMode.Token) != 0)
+            {
+                cts?.Cancel();
+            }
+
+            if ((mode & CancellationMode.CancelPendingRequests) != 0)
+            {
+                client?.CancelPendingRequests();
+            }
+
+            if ((mode & CancellationMode.DisposeHttpClient) != 0)
+            {
+                client?.Dispose();
+            }
+        }
+
+        [Flags]
+        public enum CancellationMode
+        {
+            Token = 0x1,
+            CancelPendingRequests = 0x2,
+            DisposeHttpClient = 0x4
+        }
+
+        private static readonly bool[] s_bools = new[] { true, false };
+
+        public static IEnumerable<object[]> TwoBoolsAndCancellationMode() =>
+            from first in s_bools
+            from second in s_bools
+            from mode in new[] { CancellationMode.Token, CancellationMode.CancelPendingRequests, CancellationMode.DisposeHttpClient, CancellationMode.Token | CancellationMode.CancelPendingRequests }
+            select new object[] { first, second, mode };
+
+        public static IEnumerable<object[]> ThreeBools() =>
+            from first in s_bools
+            from second in s_bools
+            from third in s_bools
+            select new object[] { first, second, third };
+
+        private sealed class DelayedByteContent : HttpContent
+        {
+            private readonly TaskCompletionSource<bool> _preContentSent;
+            private readonly Task _waitToSendPostContent;
+
+            public DelayedByteContent(int preTriggerLength, int postTriggerLength, TaskCompletionSource<bool> preContentSent, Task waitToSendPostContent)
+            {
+                PreTriggerLength = preTriggerLength;
+                _preContentSent = preContentSent;
+                _waitToSendPostContent = waitToSendPostContent;
+                Content = new byte[preTriggerLength + postTriggerLength];
+                new Random().NextBytes(Content);
+            }
+
+            public byte[] Content { get; }
+            public int PreTriggerLength { get; }
+
+            protected override async Task SerializeToStreamAsync(Stream stream, TransportContext context)
+            {
+                await stream.WriteAsync(Content, 0, PreTriggerLength);
+                _preContentSent.TrySetResult(true);
+                await _waitToSendPostContent;
+                await stream.WriteAsync(Content, PreTriggerLength, Content.Length - PreTriggerLength);
+            }
+
+            protected override bool TryComputeLength(out long length)
+            {
+                length = Content.Length;
+                return true;
+            }
+        }
+    }
+}
index fa8b3a1..eafb1b5 100644 (file)
@@ -125,16 +125,10 @@ namespace System.Net.Http.Functional.Tests
         protected override bool UseSocketsHttpHandler => true;
     }
 
-    // TODO #23141: Socket's don't support canceling individual operations, so ReadStream on NetworkStream
-    // isn't cancelable once the operation has started.  We either need to wrap the operation with one that's
-    // "cancelable", meaning that the underlying operation will still be running even though we've returned "canceled",
-    // or we need to just recognize that cancellation in such situations can be left up to the caller to do the
-    // same thing if it's really important.
-    //public sealed class SocketsHttpHandler_CancellationTest : CancellationTest
-    //{
-    //    public SocketsHttpHandler_CancellationTest(ITestOutputHelper output) : base(output) { }
-    //    protected override bool UseSocketsHttpHandler => true;
-    //}
+    public sealed class SocketsHttpHandler_HttpClientHandler_Cancellation_Test : HttpClientHandler_Cancellation_Test
+    {
+        protected override bool UseSocketsHttpHandler => true;
+    }
 
     public sealed class SocketsHttpHandler_HttpClientHandler_MaxResponseHeadersLength_Test : HttpClientHandler_MaxResponseHeadersLength_Test
     {
index 78bd665..578e3fa 100644 (file)
       <Link>Common\System\Threading\Tasks\TaskTimeoutExtensions.cs</Link>
     </Compile>
     <Compile Include="ByteArrayContentTest.cs" />
-    <Compile Include="CancellationTest.cs" />
     <Compile Include="ChannelBindingAwareContent.cs" />
     <Compile Include="CustomContent.cs" />
     <Compile Include="DelegatingHandlerTest.cs" />
     <Compile Include="FakeDiagnosticSourceListenerObserver.cs" />
     <Compile Include="FormUrlEncodedContentTest.cs" />
     <Compile Include="HttpClientHandlerTest.cs" />
+    <Compile Include="HttpClientHandlerTest.Cancellation.cs" />
     <Compile Include="HttpClientHandlerTest.ClientCertificates.cs" />
     <Compile Include="HttpClientHandlerTest.DefaultProxyCredentials.cs" />
     <Compile Include="HttpClientHandlerTest.MaxConnectionsPerServer.cs" />
     <TestCommandLines Include="ulimit -Sn 4096" />
   </ItemGroup>
   <Import Project="$([MSBuild]::GetDirectoryNameOfFileAbove($(MSBuildThisFileDirectory), dir.targets))\dir.targets" />
-</Project>
\ No newline at end of file
+</Project>