From d2c2aba33e2f6d699fe49698afc411cdd607ad2e Mon Sep 17 00:00:00 2001 From: Stephen Toub Date: Mon, 15 Jul 2019 16:17:56 -0400 Subject: [PATCH] Ensure that CancellationToken is propagated to built-in request content types (dotnet/corefx#39474) * Ensure that CancellationToken is propagated to built-in request content types We don't currently expose the overload that would allow CancellationToken to be accessed by HttpContent-derived types in general, but we can at least ensure that when using the built-in content types, the CancellationToken provided to SendAsync is appropriately threaded through. We were doing this previously in only a few cases, where we knew that the previous overload wouldn't be overridden (namely internal types and sealed types), but we can enable the 90% scenario as well by doing a type check at run-time. * Disable cancellation test for UAP handler Commit migrated from https://github.com/dotnet/corefx/commit/450f49a1a80663529b31d3defafbd5e59822a16a --- .../src/System/Net/Http/WinHttpHandler.cs | 6 +- .../src/System/Net/Http/ByteArrayContent.cs | 17 ++-- .../src/System/Net/Http/FormUrlEncodedContent.cs | 11 ++- .../src/System/Net/Http/MultipartContent.cs | 23 ++++-- .../System/Net/Http/MultipartFormDataContent.cs | 9 ++ .../src/System/Net/Http/StreamContent.cs | 20 ++++- .../src/System/Net/Http/StreamToStreamCopy.cs | 50 ++++++------ .../src/System/Net/Http/StringContent.cs | 7 ++ .../HttpClientHandlerTest.Cancellation.cs | 95 +++++++++++++++++++++- 9 files changed, 193 insertions(+), 45 deletions(-) diff --git a/src/libraries/System.Net.Http.WinHttpHandler/src/System/Net/Http/WinHttpHandler.cs b/src/libraries/System.Net.Http.WinHttpHandler/src/System/Net/Http/WinHttpHandler.cs index aa69c6e..575fbc9 100644 --- a/src/libraries/System.Net.Http.WinHttpHandler/src/System/Net/Http/WinHttpHandler.cs +++ b/src/libraries/System.Net.Http.WinHttpHandler/src/System/Net/Http/WinHttpHandler.cs @@ -1394,7 +1394,11 @@ namespace System.Net.Http { await state.RequestMessage.Content.CopyToAsync( requestStream, - state.TransportContext).ConfigureAwait(false); + state.TransportContext +#if HTTP_DLL + , state.CancellationToken +#endif + ).ConfigureAwait(false); await requestStream.EndUploadAsync(state.CancellationToken).ConfigureAwait(false); } } diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/ByteArrayContent.cs b/src/libraries/System.Net.Http/src/System/Net/Http/ByteArrayContent.cs index 0002534..7458a64 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/ByteArrayContent.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/ByteArrayContent.cs @@ -4,6 +4,7 @@ using System.Diagnostics; using System.IO; +using System.Threading; using System.Threading.Tasks; namespace System.Net.Http @@ -50,11 +51,17 @@ namespace System.Net.Http SetBuffer(_content, _offset, _count); } - protected override Task SerializeToStreamAsync(Stream stream, TransportContext context) - { - Debug.Assert(stream != null); - return stream.WriteAsync(_content, _offset, _count); - } + protected override Task SerializeToStreamAsync(Stream stream, TransportContext context) => + SerializeToStreamAsyncCore(stream, default); + + internal override Task SerializeToStreamAsync(Stream stream, TransportContext context, CancellationToken cancellationToken) => + // Only skip the original protected virtual SerializeToStreamAsync if this + // isn't a derived type that may have overridden the behavior. + GetType() == typeof(ByteArrayContent) ? SerializeToStreamAsyncCore(stream, cancellationToken) : + base.SerializeToStreamAsync(stream, context, cancellationToken); + + private protected Task SerializeToStreamAsyncCore(Stream stream, CancellationToken cancellationToken) => + stream.WriteAsync(_content, _offset, _count, cancellationToken); protected internal override bool TryComputeLength(out long length) { diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/FormUrlEncodedContent.cs b/src/libraries/System.Net.Http/src/System/Net/Http/FormUrlEncodedContent.cs index 5938833..249b671 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/FormUrlEncodedContent.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/FormUrlEncodedContent.cs @@ -3,10 +3,11 @@ // See the LICENSE file in the project root for more information. using System.Collections.Generic; -using System.Text; -using System.Threading.Tasks; using System.IO; using System.Net.Http.Headers; +using System.Text; +using System.Threading; +using System.Threading.Tasks; namespace System.Net.Http { @@ -52,6 +53,12 @@ namespace System.Net.Http return Uri.EscapeDataString(data).Replace("%20", "+"); } + internal override Task SerializeToStreamAsync(Stream stream, TransportContext context, CancellationToken cancellationToken) => + // Only skip the original protected virtual SerializeToStreamAsync if this + // isn't a derived type that may have overridden the behavior. + GetType() == typeof(FormUrlEncodedContent) ? SerializeToStreamAsyncCore(stream, cancellationToken) : + base.SerializeToStreamAsync(stream, context, cancellationToken); + internal override Stream TryCreateContentReadStream() => GetType() == typeof(FormUrlEncodedContent) ? CreateMemoryStreamForByteArray() : // type check ensures we use possible derived type's CreateContentReadStreamAsync override null; diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/MultipartContent.cs b/src/libraries/System.Net.Http/src/System/Net/Http/MultipartContent.cs index 7f36bfa..16f903c 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/MultipartContent.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/MultipartContent.cs @@ -169,13 +169,22 @@ namespace System.Net.Http // write "--" + boundary + "--" // Can't be canceled directly by the user. If the overall request is canceled // then the stream will be closed an exception thrown. - protected override async Task SerializeToStreamAsync(Stream stream, TransportContext context) + protected override Task SerializeToStreamAsync(Stream stream, TransportContext context) => + SerializeToStreamAsyncCore(stream, context, default); + + internal override Task SerializeToStreamAsync(Stream stream, TransportContext context, CancellationToken cancellationToken) => + // Only skip the original protected virtual SerializeToStreamAsync if this + // isn't a derived type that may have overridden the behavior. + GetType() == typeof(MultipartContent) ? SerializeToStreamAsyncCore(stream, context, cancellationToken) : + base.SerializeToStreamAsync(stream, context, cancellationToken); + + private protected async Task SerializeToStreamAsyncCore(Stream stream, TransportContext context, CancellationToken cancellationToken) { Debug.Assert(stream != null); try { // Write start boundary. - await EncodeStringToStreamAsync(stream, "--" + _boundary + CrLf).ConfigureAwait(false); + await EncodeStringToStreamAsync(stream, "--" + _boundary + CrLf, cancellationToken).ConfigureAwait(false); // Write each nested content. var output = new StringBuilder(); @@ -183,12 +192,12 @@ namespace System.Net.Http { // Write divider, headers, and content. HttpContent content = _nestedContent[contentIndex]; - await EncodeStringToStreamAsync(stream, SerializeHeadersToString(output, contentIndex, content)).ConfigureAwait(false); - await content.CopyToAsync(stream).ConfigureAwait(false); + await EncodeStringToStreamAsync(stream, SerializeHeadersToString(output, contentIndex, content), cancellationToken).ConfigureAwait(false); + await content.CopyToAsync(stream, context, cancellationToken).ConfigureAwait(false); } // Write footer boundary. - await EncodeStringToStreamAsync(stream, CrLf + "--" + _boundary + "--" + CrLf).ConfigureAwait(false); + await EncodeStringToStreamAsync(stream, CrLf + "--" + _boundary + "--" + CrLf, cancellationToken).ConfigureAwait(false); } catch (Exception ex) { @@ -271,10 +280,10 @@ namespace System.Net.Http return scratch.ToString(); } - private static ValueTask EncodeStringToStreamAsync(Stream stream, string input) + private static ValueTask EncodeStringToStreamAsync(Stream stream, string input, CancellationToken cancellationToken) { byte[] buffer = HttpRuleParser.DefaultHttpEncoding.GetBytes(input); - return stream.WriteAsync(new ReadOnlyMemory(buffer)); + return stream.WriteAsync(new ReadOnlyMemory(buffer), cancellationToken); } private static Stream EncodeStringToNewStream(string input) diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/MultipartFormDataContent.cs b/src/libraries/System.Net.Http/src/System/Net/Http/MultipartFormDataContent.cs index 3d601d6..7f52f66 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/MultipartFormDataContent.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/MultipartFormDataContent.cs @@ -3,7 +3,10 @@ // See the LICENSE file in the project root for more information. using System.Diagnostics.CodeAnalysis; +using System.IO; using System.Net.Http.Headers; +using System.Threading; +using System.Threading.Tasks; namespace System.Net.Http { @@ -84,5 +87,11 @@ namespace System.Net.Http } base.Add(content); } + + internal override Task SerializeToStreamAsync(Stream stream, TransportContext context, CancellationToken cancellationToken) => + // Only skip the original protected virtual SerializeToStreamAsync if this + // isn't a derived type that may have overridden the behavior. + GetType() == typeof(MultipartFormDataContent) ? SerializeToStreamAsyncCore(stream, context, cancellationToken) : + base.SerializeToStreamAsync(stream, context, cancellationToken); } } diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/StreamContent.cs b/src/libraries/System.Net.Http/src/System/Net/Http/StreamContent.cs index 0920b30..8970cdb 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/StreamContent.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/StreamContent.cs @@ -52,13 +52,25 @@ namespace System.Net.Http if (NetEventSource.IsEnabled) NetEventSource.Associate(this, content); } - protected override Task SerializeToStreamAsync(Stream stream, TransportContext context) + protected override Task SerializeToStreamAsync(Stream stream, TransportContext context) => + SerializeToStreamAsyncCore(stream, default); + + internal override Task SerializeToStreamAsync(Stream stream, TransportContext context, CancellationToken cancellationToken) => + // Only skip the original protected virtual SerializeToStreamAsync if this + // isn't a derived type that may have overridden the behavior. + GetType() == typeof(StreamContent) ? SerializeToStreamAsyncCore(stream, cancellationToken) : + base.SerializeToStreamAsync(stream, context, cancellationToken); + + private Task SerializeToStreamAsyncCore(Stream stream, CancellationToken cancellationToken) { Debug.Assert(stream != null); - PrepareContent(); - // If the stream can't be re-read, make sure that it gets disposed once it is consumed. - return StreamToStreamCopy.CopyAsync(_content, stream, _bufferSize, !_content.CanSeek); + return StreamToStreamCopy.CopyAsync( + _content, + stream, + _bufferSize, + !_content.CanSeek, // If the stream can't be re-read, make sure that it gets disposed once it is consumed. + cancellationToken); } protected internal override bool TryComputeLength(out long length) diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/StreamToStreamCopy.cs b/src/libraries/System.Net.Http/src/System/Net/Http/StreamToStreamCopy.cs index a8363b7..5b3abea 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/StreamToStreamCopy.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/StreamToStreamCopy.cs @@ -33,9 +33,31 @@ namespace System.Net.Http Task copyTask = bufferSize == 0 ? source.CopyToAsync(destination, cancellationToken) : source.CopyToAsync(destination, bufferSize, cancellationToken); - return disposeSource ? - DisposeSourceWhenCompleteAsync(copyTask, source) : - copyTask; + + if (!disposeSource) + { + return copyTask; + } + + switch (copyTask.Status) + { + case TaskStatus.RanToCompletion: + DisposeSource(source); + return Task.CompletedTask; + + case TaskStatus.Faulted: + case TaskStatus.Canceled: + return copyTask; + + default: + return DisposeSourceAsync(copyTask, source); + + static async Task DisposeSourceAsync(Task copyTask, Stream source) + { + await copyTask.ConfigureAwait(false); + DisposeSource(source); + } + } } catch (Exception e) { @@ -45,28 +67,6 @@ namespace System.Net.Http } } - private static Task DisposeSourceWhenCompleteAsync(Task task, Stream source) - { - switch (task.Status) - { - case TaskStatus.RanToCompletion: - DisposeSource(source); - return Task.CompletedTask; - - case TaskStatus.Faulted: - case TaskStatus.Canceled: - return task; - - default: - return task.ContinueWith((completed, innerSource) => - { - completed.GetAwaiter().GetResult(); // propagate any exceptions - DisposeSource((Stream)innerSource); - }, - source, CancellationToken.None, TaskContinuationOptions.ExecuteSynchronously | TaskContinuationOptions.DenyChildAttach, TaskScheduler.Default); - } - } - /// Disposes the source stream. private static void DisposeSource(Stream source) { diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/StringContent.cs b/src/libraries/System.Net.Http/src/System/Net/Http/StringContent.cs index 6f2915e..8890a66 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/StringContent.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/StringContent.cs @@ -5,6 +5,7 @@ using System.IO; using System.Net.Http.Headers; using System.Text; +using System.Threading; using System.Threading.Tasks; namespace System.Net.Http @@ -53,6 +54,12 @@ namespace System.Net.Http return encoding.GetBytes(content); } + internal override Task SerializeToStreamAsync(Stream stream, TransportContext context, CancellationToken cancellationToken) => + // Only skip the original protected virtual SerializeToStreamAsync if this + // isn't a derived type that may have overridden the behavior. + GetType() == typeof(StringContent) ? SerializeToStreamAsyncCore(stream, cancellationToken) : + base.SerializeToStreamAsync(stream, context, cancellationToken); + internal override Stream TryCreateContentReadStream() => GetType() == typeof(StringContent) ? CreateMemoryStreamForByteArray() : // type check ensures we use possible derived type's CreateContentReadStreamAsync override null; diff --git a/src/libraries/System.Net.Http/tests/FunctionalTests/HttpClientHandlerTest.Cancellation.cs b/src/libraries/System.Net.Http/tests/FunctionalTests/HttpClientHandlerTest.Cancellation.cs index 0704d3f..5ecfc60 100644 --- a/src/libraries/System.Net.Http/tests/FunctionalTests/HttpClientHandlerTest.Cancellation.cs +++ b/src/libraries/System.Net.Http/tests/FunctionalTests/HttpClientHandlerTest.Cancellation.cs @@ -4,9 +4,10 @@ using System.Collections.Generic; using System.Diagnostics; -using System.Linq; using System.IO; +using System.Linq; using System.Net.Test.Common; +using System.Runtime.CompilerServices; using System.Threading; using System.Threading.Tasks; @@ -401,6 +402,98 @@ namespace System.Net.Http.Functional.Tests }); } + public static IEnumerable PostAsync_Cancel_CancellationTokenPassedToContent_MemberData() + { + foreach (CancellationToken expectedToken in new[] { CancellationToken.None, new CancellationTokenSource().Token }) + { + // StreamContent + { + var actualToken = new StrongBox(); + bool called = false; + var content = new StreamContent(new DelegateStream( + canReadFunc: () => true, + readAsyncFunc: (buffer, offset, count, cancellationToken) => + { + actualToken.Value = cancellationToken; + int result = called ? 0 : 1; + called = true; + return Task.FromResult(result); + } + )); + yield return new object[] { content, expectedToken, actualToken }; + } + + // MultipartContent + { + var actualToken = new StrongBox(); + bool called = false; + var content = new MultipartContent(); + content.Add(new StreamContent(new DelegateStream( + canReadFunc: () => true, + canSeekFunc: () => true, + lengthFunc: () => 1, + positionGetFunc: () => 0, + positionSetFunc: _ => {}, + readAsyncFunc: (buffer, offset, count, cancellationToken) => + { + actualToken.Value = cancellationToken; + int result = called ? 0 : 1; + called = true; + return Task.FromResult(result); + } + ))); + yield return new object[] { content, expectedToken, actualToken }; + } + + // MultipartFormDataContent + { + var actualToken = new StrongBox(); + bool called = false; + var content = new MultipartFormDataContent(); + content.Add(new StreamContent(new DelegateStream( + canReadFunc: () => true, + canSeekFunc: () => true, + lengthFunc: () => 1, + positionGetFunc: () => 0, + positionSetFunc: _ => {}, + readAsyncFunc: (buffer, offset, count, cancellationToken) => + { + actualToken.Value = cancellationToken; + int result = called ? 0 : 1; + called = true; + return Task.FromResult(result); + } + ))); + yield return new object[] { content, expectedToken, actualToken }; + } + } + } + + [Theory] + [MemberData(nameof(PostAsync_Cancel_CancellationTokenPassedToContent_MemberData))] + public async Task PostAsync_Cancel_CancellationTokenPassedToContent(HttpContent content, CancellationToken expectedToken, StrongBox actualToken) + { + if (IsUapHandler) + { + // HttpHandlerToFilter doesn't flow the token into the request body. + return; + } + + await LoopbackServerFactory.CreateClientAndServerAsync( + async uri => + { + using (var invoker = new HttpMessageInvoker(CreateHttpClientHandler())) + using (var req = new HttpRequestMessage(HttpMethod.Post, uri) { Content = content, Version = VersionFromUseHttp2 }) + using (HttpResponseMessage resp = await invoker.SendAsync(req, expectedToken)) + { + Assert.Equal("Hello World", await resp.Content.ReadAsStringAsync()); + } + }, + server => server.HandleRequestAsync(content: "Hello World")); + + Assert.Equal(expectedToken, actualToken.Value); + } + private async Task ValidateClientCancellationAsync(Func clientBodyAsync) { var stopwatch = Stopwatch.StartNew(); -- 2.7.4