Ensure that CancellationToken is propagated to built-in request content types (dotnet...
authorStephen Toub <stoub@microsoft.com>
Mon, 15 Jul 2019 20:17:56 +0000 (16:17 -0400)
committerGitHub <noreply@github.com>
Mon, 15 Jul 2019 20:17:56 +0000 (16:17 -0400)
* Ensure that CancellationToken is propagated to built-in request content types

We don't currently expose the overload that would allow CancellationToken to be accessed by HttpContent-derived types in general, but we can at least ensure that when using the built-in content types, the CancellationToken provided to SendAsync is appropriately threaded through.  We were doing this previously in only a few cases, where we knew that the previous overload wouldn't be overridden (namely internal types and sealed types), but we can enable the 90% scenario as well by doing a type check at run-time.

* Disable cancellation test for UAP handler

Commit migrated from https://github.com/dotnet/corefx/commit/450f49a1a80663529b31d3defafbd5e59822a16a

src/libraries/System.Net.Http.WinHttpHandler/src/System/Net/Http/WinHttpHandler.cs
src/libraries/System.Net.Http/src/System/Net/Http/ByteArrayContent.cs
src/libraries/System.Net.Http/src/System/Net/Http/FormUrlEncodedContent.cs
src/libraries/System.Net.Http/src/System/Net/Http/MultipartContent.cs
src/libraries/System.Net.Http/src/System/Net/Http/MultipartFormDataContent.cs
src/libraries/System.Net.Http/src/System/Net/Http/StreamContent.cs
src/libraries/System.Net.Http/src/System/Net/Http/StreamToStreamCopy.cs
src/libraries/System.Net.Http/src/System/Net/Http/StringContent.cs
src/libraries/System.Net.Http/tests/FunctionalTests/HttpClientHandlerTest.Cancellation.cs

index aa69c6e..575fbc9 100644 (file)
@@ -1394,7 +1394,11 @@ namespace System.Net.Http
             {
                 await state.RequestMessage.Content.CopyToAsync(
                     requestStream,
-                    state.TransportContext).ConfigureAwait(false);
+                    state.TransportContext
+#if HTTP_DLL
+                    , state.CancellationToken
+#endif
+                    ).ConfigureAwait(false);
                 await requestStream.EndUploadAsync(state.CancellationToken).ConfigureAwait(false);
             }
         }
index 0002534..7458a64 100644 (file)
@@ -4,6 +4,7 @@
 
 using System.Diagnostics;
 using System.IO;
+using System.Threading;
 using System.Threading.Tasks;
 
 namespace System.Net.Http
@@ -50,11 +51,17 @@ namespace System.Net.Http
             SetBuffer(_content, _offset, _count);
         }
 
-        protected override Task SerializeToStreamAsync(Stream stream, TransportContext context)
-        {
-            Debug.Assert(stream != null);
-            return stream.WriteAsync(_content, _offset, _count);
-        }
+        protected override Task SerializeToStreamAsync(Stream stream, TransportContext context) =>
+            SerializeToStreamAsyncCore(stream, default);
+
+        internal override Task SerializeToStreamAsync(Stream stream, TransportContext context, CancellationToken cancellationToken) =>
+            // Only skip the original protected virtual SerializeToStreamAsync if this
+            // isn't a derived type that may have overridden the behavior.
+            GetType() == typeof(ByteArrayContent) ? SerializeToStreamAsyncCore(stream, cancellationToken) :
+            base.SerializeToStreamAsync(stream, context, cancellationToken);
+
+        private protected Task SerializeToStreamAsyncCore(Stream stream, CancellationToken cancellationToken) =>
+            stream.WriteAsync(_content, _offset, _count, cancellationToken);
 
         protected internal override bool TryComputeLength(out long length)
         {
index 5938833..249b671 100644 (file)
@@ -3,10 +3,11 @@
 // See the LICENSE file in the project root for more information.
 
 using System.Collections.Generic;
-using System.Text;
-using System.Threading.Tasks;
 using System.IO;
 using System.Net.Http.Headers;
+using System.Text;
+using System.Threading;
+using System.Threading.Tasks;
 
 namespace System.Net.Http
 {
@@ -52,6 +53,12 @@ namespace System.Net.Http
             return Uri.EscapeDataString(data).Replace("%20", "+");
         }
 
+        internal override Task SerializeToStreamAsync(Stream stream, TransportContext context, CancellationToken cancellationToken) =>
+            // Only skip the original protected virtual SerializeToStreamAsync if this
+            // isn't a derived type that may have overridden the behavior.
+            GetType() == typeof(FormUrlEncodedContent) ? SerializeToStreamAsyncCore(stream, cancellationToken) :
+            base.SerializeToStreamAsync(stream, context, cancellationToken);
+
         internal override Stream TryCreateContentReadStream() =>
             GetType() == typeof(FormUrlEncodedContent) ? CreateMemoryStreamForByteArray() : // type check ensures we use possible derived type's CreateContentReadStreamAsync override
             null;
index 7f36bfa..16f903c 100644 (file)
@@ -169,13 +169,22 @@ namespace System.Net.Http
         // write "--" + boundary + "--"
         // Can't be canceled directly by the user.  If the overall request is canceled 
         // then the stream will be closed an exception thrown.
-        protected override async Task SerializeToStreamAsync(Stream stream, TransportContext context)
+        protected override Task SerializeToStreamAsync(Stream stream, TransportContext context) =>
+            SerializeToStreamAsyncCore(stream, context, default);
+
+        internal override Task SerializeToStreamAsync(Stream stream, TransportContext context, CancellationToken cancellationToken) =>
+            // Only skip the original protected virtual SerializeToStreamAsync if this
+            // isn't a derived type that may have overridden the behavior.
+            GetType() == typeof(MultipartContent) ? SerializeToStreamAsyncCore(stream, context, cancellationToken) :
+            base.SerializeToStreamAsync(stream, context, cancellationToken);
+
+        private protected async Task SerializeToStreamAsyncCore(Stream stream, TransportContext context, CancellationToken cancellationToken)
         {
             Debug.Assert(stream != null);
             try
             {
                 // Write start boundary.
-                await EncodeStringToStreamAsync(stream, "--" + _boundary + CrLf).ConfigureAwait(false);
+                await EncodeStringToStreamAsync(stream, "--" + _boundary + CrLf, cancellationToken).ConfigureAwait(false);
 
                 // Write each nested content.
                 var output = new StringBuilder();
@@ -183,12 +192,12 @@ namespace System.Net.Http
                 {
                     // Write divider, headers, and content.
                     HttpContent content = _nestedContent[contentIndex];
-                    await EncodeStringToStreamAsync(stream, SerializeHeadersToString(output, contentIndex, content)).ConfigureAwait(false);
-                    await content.CopyToAsync(stream).ConfigureAwait(false);
+                    await EncodeStringToStreamAsync(stream, SerializeHeadersToString(output, contentIndex, content), cancellationToken).ConfigureAwait(false);
+                    await content.CopyToAsync(stream, context, cancellationToken).ConfigureAwait(false);
                 }
 
                 // Write footer boundary.
-                await EncodeStringToStreamAsync(stream, CrLf + "--" + _boundary + "--" + CrLf).ConfigureAwait(false);
+                await EncodeStringToStreamAsync(stream, CrLf + "--" + _boundary + "--" + CrLf, cancellationToken).ConfigureAwait(false);
             }
             catch (Exception ex)
             {
@@ -271,10 +280,10 @@ namespace System.Net.Http
             return scratch.ToString();
         }
 
-        private static ValueTask EncodeStringToStreamAsync(Stream stream, string input)
+        private static ValueTask EncodeStringToStreamAsync(Stream stream, string input, CancellationToken cancellationToken)
         {
             byte[] buffer = HttpRuleParser.DefaultHttpEncoding.GetBytes(input);
-            return stream.WriteAsync(new ReadOnlyMemory<byte>(buffer));
+            return stream.WriteAsync(new ReadOnlyMemory<byte>(buffer), cancellationToken);
         }
 
         private static Stream EncodeStringToNewStream(string input)
index 3d601d6..7f52f66 100644 (file)
@@ -3,7 +3,10 @@
 // See the LICENSE file in the project root for more information.
 
 using System.Diagnostics.CodeAnalysis;
+using System.IO;
 using System.Net.Http.Headers;
+using System.Threading;
+using System.Threading.Tasks;
 
 namespace System.Net.Http
 {
@@ -84,5 +87,11 @@ namespace System.Net.Http
             }
             base.Add(content);
         }
+
+        internal override Task SerializeToStreamAsync(Stream stream, TransportContext context, CancellationToken cancellationToken) =>
+            // Only skip the original protected virtual SerializeToStreamAsync if this
+            // isn't a derived type that may have overridden the behavior.
+            GetType() == typeof(MultipartFormDataContent) ? SerializeToStreamAsyncCore(stream, context, cancellationToken) :
+            base.SerializeToStreamAsync(stream, context, cancellationToken);
     }
 }
index 0920b30..8970cdb 100644 (file)
@@ -52,13 +52,25 @@ namespace System.Net.Http
             if (NetEventSource.IsEnabled) NetEventSource.Associate(this, content);
         }
 
-        protected override Task SerializeToStreamAsync(Stream stream, TransportContext context)
+        protected override Task SerializeToStreamAsync(Stream stream, TransportContext context) =>
+            SerializeToStreamAsyncCore(stream, default);
+
+        internal override Task SerializeToStreamAsync(Stream stream, TransportContext context, CancellationToken cancellationToken) =>
+            // Only skip the original protected virtual SerializeToStreamAsync if this
+            // isn't a derived type that may have overridden the behavior.
+            GetType() == typeof(StreamContent) ? SerializeToStreamAsyncCore(stream, cancellationToken) :
+            base.SerializeToStreamAsync(stream, context, cancellationToken);
+
+        private Task SerializeToStreamAsyncCore(Stream stream, CancellationToken cancellationToken)
         {
             Debug.Assert(stream != null);
-
             PrepareContent();
-            // If the stream can't be re-read, make sure that it gets disposed once it is consumed.
-            return StreamToStreamCopy.CopyAsync(_content, stream, _bufferSize, !_content.CanSeek);
+            return StreamToStreamCopy.CopyAsync(
+                _content,
+                stream,
+                _bufferSize,
+                !_content.CanSeek, // If the stream can't be re-read, make sure that it gets disposed once it is consumed.
+                cancellationToken);
         }
 
         protected internal override bool TryComputeLength(out long length)
index a8363b7..5b3abea 100644 (file)
@@ -33,9 +33,31 @@ namespace System.Net.Http
                 Task copyTask = bufferSize == 0 ?
                     source.CopyToAsync(destination, cancellationToken) :
                     source.CopyToAsync(destination, bufferSize, cancellationToken);
-                return disposeSource ?
-                    DisposeSourceWhenCompleteAsync(copyTask, source) :
-                    copyTask;
+
+                if (!disposeSource)
+                {
+                    return copyTask;
+                }
+
+                switch (copyTask.Status)
+                {
+                    case TaskStatus.RanToCompletion:
+                        DisposeSource(source);
+                        return Task.CompletedTask;
+
+                    case TaskStatus.Faulted:
+                    case TaskStatus.Canceled:
+                        return copyTask;
+
+                    default:
+                        return DisposeSourceAsync(copyTask, source);
+
+                        static async Task DisposeSourceAsync(Task copyTask, Stream source)
+                        {
+                            await copyTask.ConfigureAwait(false);
+                            DisposeSource(source);
+                        }
+                }
             }
             catch (Exception e)
             {
@@ -45,28 +67,6 @@ namespace System.Net.Http
             }
         }
 
-        private static Task DisposeSourceWhenCompleteAsync(Task task, Stream source)
-        {
-            switch (task.Status)
-            {
-                case TaskStatus.RanToCompletion:
-                    DisposeSource(source);
-                    return Task.CompletedTask;
-
-                case TaskStatus.Faulted:
-                case TaskStatus.Canceled:
-                    return task;
-
-                default:
-                    return task.ContinueWith((completed, innerSource) =>
-                    {
-                        completed.GetAwaiter().GetResult(); // propagate any exceptions
-                        DisposeSource((Stream)innerSource);
-                    },
-                    source, CancellationToken.None, TaskContinuationOptions.ExecuteSynchronously | TaskContinuationOptions.DenyChildAttach, TaskScheduler.Default);
-            }
-        }
-
         /// <summary>Disposes the source stream.</summary>
         private static void DisposeSource(Stream source)
         {
index 6f2915e..8890a66 100644 (file)
@@ -5,6 +5,7 @@
 using System.IO;
 using System.Net.Http.Headers;
 using System.Text;
+using System.Threading;
 using System.Threading.Tasks;
 
 namespace System.Net.Http
@@ -53,6 +54,12 @@ namespace System.Net.Http
             return encoding.GetBytes(content);
         }
 
+        internal override Task SerializeToStreamAsync(Stream stream, TransportContext context, CancellationToken cancellationToken) =>
+            // Only skip the original protected virtual SerializeToStreamAsync if this
+            // isn't a derived type that may have overridden the behavior.
+            GetType() == typeof(StringContent) ? SerializeToStreamAsyncCore(stream, cancellationToken) :
+            base.SerializeToStreamAsync(stream, context, cancellationToken);
+
         internal override Stream TryCreateContentReadStream() =>
             GetType() == typeof(StringContent) ? CreateMemoryStreamForByteArray() : // type check ensures we use possible derived type's CreateContentReadStreamAsync override
             null;
index 0704d3f..5ecfc60 100644 (file)
@@ -4,9 +4,10 @@
 
 using System.Collections.Generic;
 using System.Diagnostics;
-using System.Linq;
 using System.IO;
+using System.Linq;
 using System.Net.Test.Common;
+using System.Runtime.CompilerServices;
 using System.Threading;
 using System.Threading.Tasks;
 
@@ -401,6 +402,98 @@ namespace System.Net.Http.Functional.Tests
                 });
         }
 
+        public static IEnumerable<object[]> PostAsync_Cancel_CancellationTokenPassedToContent_MemberData()
+        {
+            foreach (CancellationToken expectedToken in new[] { CancellationToken.None, new CancellationTokenSource().Token })
+            {
+                // StreamContent
+                {
+                    var actualToken = new StrongBox<CancellationToken>();
+                    bool called = false;
+                    var content = new StreamContent(new DelegateStream(
+                        canReadFunc: () => true,
+                        readAsyncFunc: (buffer, offset, count, cancellationToken) =>
+                        {
+                            actualToken.Value = cancellationToken;
+                            int result = called ? 0 : 1;
+                            called = true;
+                            return Task.FromResult(result);
+                        }
+                    ));
+                    yield return new object[] { content, expectedToken, actualToken };
+                }
+
+                // MultipartContent
+                {
+                    var actualToken = new StrongBox<CancellationToken>();
+                    bool called = false;
+                    var content = new MultipartContent();
+                    content.Add(new StreamContent(new DelegateStream(
+                        canReadFunc: () => true,
+                        canSeekFunc: () => true,
+                        lengthFunc: () => 1,
+                        positionGetFunc: () => 0,
+                        positionSetFunc: _ => {},
+                        readAsyncFunc: (buffer, offset, count, cancellationToken) =>
+                        {
+                            actualToken.Value = cancellationToken;
+                            int result = called ? 0 : 1;
+                            called = true;
+                            return Task.FromResult(result);
+                        }
+                    )));
+                    yield return new object[] { content, expectedToken, actualToken };
+                }
+
+                // MultipartFormDataContent
+                {
+                    var actualToken = new StrongBox<CancellationToken>();
+                    bool called = false;
+                    var content = new MultipartFormDataContent();
+                    content.Add(new StreamContent(new DelegateStream(
+                        canReadFunc: () => true,
+                        canSeekFunc: () => true,
+                        lengthFunc: () => 1,
+                        positionGetFunc: () => 0,
+                        positionSetFunc: _ => {},
+                        readAsyncFunc: (buffer, offset, count, cancellationToken) =>
+                        {
+                            actualToken.Value = cancellationToken;
+                            int result = called ? 0 : 1;
+                            called = true;
+                            return Task.FromResult(result);
+                        }
+                    )));
+                    yield return new object[] { content, expectedToken, actualToken };
+                }
+            }
+        }
+
+        [Theory]
+        [MemberData(nameof(PostAsync_Cancel_CancellationTokenPassedToContent_MemberData))]
+        public async Task PostAsync_Cancel_CancellationTokenPassedToContent(HttpContent content, CancellationToken expectedToken, StrongBox<CancellationToken> actualToken)
+        {
+            if (IsUapHandler)
+            {
+                // HttpHandlerToFilter doesn't flow the token into the request body.
+                return;
+            }
+
+            await LoopbackServerFactory.CreateClientAndServerAsync(
+                async uri =>
+                {
+                    using (var invoker = new HttpMessageInvoker(CreateHttpClientHandler()))
+                    using (var req = new HttpRequestMessage(HttpMethod.Post, uri) { Content = content, Version = VersionFromUseHttp2 })
+                    using (HttpResponseMessage resp = await invoker.SendAsync(req, expectedToken))
+                    {
+                        Assert.Equal("Hello World", await resp.Content.ReadAsStringAsync());
+                    }
+                },
+                server => server.HandleRequestAsync(content: "Hello World"));
+
+            Assert.Equal(expectedToken, actualToken.Value);
+        }
+
         private async Task ValidateClientCancellationAsync(Func<Task> clientBodyAsync)
         {
             var stopwatch = Stopwatch.StartNew();