{
await state.RequestMessage.Content.CopyToAsync(
requestStream,
- state.TransportContext).ConfigureAwait(false);
+ state.TransportContext
+#if HTTP_DLL
+ , state.CancellationToken
+#endif
+ ).ConfigureAwait(false);
await requestStream.EndUploadAsync(state.CancellationToken).ConfigureAwait(false);
}
}
using System.Diagnostics;
using System.IO;
+using System.Threading;
using System.Threading.Tasks;
namespace System.Net.Http
SetBuffer(_content, _offset, _count);
}
- protected override Task SerializeToStreamAsync(Stream stream, TransportContext context)
- {
- Debug.Assert(stream != null);
- return stream.WriteAsync(_content, _offset, _count);
- }
+ protected override Task SerializeToStreamAsync(Stream stream, TransportContext context) =>
+ SerializeToStreamAsyncCore(stream, default);
+
+ internal override Task SerializeToStreamAsync(Stream stream, TransportContext context, CancellationToken cancellationToken) =>
+ // Only skip the original protected virtual SerializeToStreamAsync if this
+ // isn't a derived type that may have overridden the behavior.
+ GetType() == typeof(ByteArrayContent) ? SerializeToStreamAsyncCore(stream, cancellationToken) :
+ base.SerializeToStreamAsync(stream, context, cancellationToken);
+
+ private protected Task SerializeToStreamAsyncCore(Stream stream, CancellationToken cancellationToken) =>
+ stream.WriteAsync(_content, _offset, _count, cancellationToken);
protected internal override bool TryComputeLength(out long length)
{
// See the LICENSE file in the project root for more information.
using System.Collections.Generic;
-using System.Text;
-using System.Threading.Tasks;
using System.IO;
using System.Net.Http.Headers;
+using System.Text;
+using System.Threading;
+using System.Threading.Tasks;
namespace System.Net.Http
{
return Uri.EscapeDataString(data).Replace("%20", "+");
}
+ internal override Task SerializeToStreamAsync(Stream stream, TransportContext context, CancellationToken cancellationToken) =>
+ // Only skip the original protected virtual SerializeToStreamAsync if this
+ // isn't a derived type that may have overridden the behavior.
+ GetType() == typeof(FormUrlEncodedContent) ? SerializeToStreamAsyncCore(stream, cancellationToken) :
+ base.SerializeToStreamAsync(stream, context, cancellationToken);
+
internal override Stream TryCreateContentReadStream() =>
GetType() == typeof(FormUrlEncodedContent) ? CreateMemoryStreamForByteArray() : // type check ensures we use possible derived type's CreateContentReadStreamAsync override
null;
// write "--" + boundary + "--"
// Can't be canceled directly by the user. If the overall request is canceled
// then the stream will be closed an exception thrown.
- protected override async Task SerializeToStreamAsync(Stream stream, TransportContext context)
+ protected override Task SerializeToStreamAsync(Stream stream, TransportContext context) =>
+ SerializeToStreamAsyncCore(stream, context, default);
+
+ internal override Task SerializeToStreamAsync(Stream stream, TransportContext context, CancellationToken cancellationToken) =>
+ // Only skip the original protected virtual SerializeToStreamAsync if this
+ // isn't a derived type that may have overridden the behavior.
+ GetType() == typeof(MultipartContent) ? SerializeToStreamAsyncCore(stream, context, cancellationToken) :
+ base.SerializeToStreamAsync(stream, context, cancellationToken);
+
+ private protected async Task SerializeToStreamAsyncCore(Stream stream, TransportContext context, CancellationToken cancellationToken)
{
Debug.Assert(stream != null);
try
{
// Write start boundary.
- await EncodeStringToStreamAsync(stream, "--" + _boundary + CrLf).ConfigureAwait(false);
+ await EncodeStringToStreamAsync(stream, "--" + _boundary + CrLf, cancellationToken).ConfigureAwait(false);
// Write each nested content.
var output = new StringBuilder();
{
// Write divider, headers, and content.
HttpContent content = _nestedContent[contentIndex];
- await EncodeStringToStreamAsync(stream, SerializeHeadersToString(output, contentIndex, content)).ConfigureAwait(false);
- await content.CopyToAsync(stream).ConfigureAwait(false);
+ await EncodeStringToStreamAsync(stream, SerializeHeadersToString(output, contentIndex, content), cancellationToken).ConfigureAwait(false);
+ await content.CopyToAsync(stream, context, cancellationToken).ConfigureAwait(false);
}
// Write footer boundary.
- await EncodeStringToStreamAsync(stream, CrLf + "--" + _boundary + "--" + CrLf).ConfigureAwait(false);
+ await EncodeStringToStreamAsync(stream, CrLf + "--" + _boundary + "--" + CrLf, cancellationToken).ConfigureAwait(false);
}
catch (Exception ex)
{
return scratch.ToString();
}
- private static ValueTask EncodeStringToStreamAsync(Stream stream, string input)
+ private static ValueTask EncodeStringToStreamAsync(Stream stream, string input, CancellationToken cancellationToken)
{
byte[] buffer = HttpRuleParser.DefaultHttpEncoding.GetBytes(input);
- return stream.WriteAsync(new ReadOnlyMemory<byte>(buffer));
+ return stream.WriteAsync(new ReadOnlyMemory<byte>(buffer), cancellationToken);
}
private static Stream EncodeStringToNewStream(string input)
// See the LICENSE file in the project root for more information.
using System.Diagnostics.CodeAnalysis;
+using System.IO;
using System.Net.Http.Headers;
+using System.Threading;
+using System.Threading.Tasks;
namespace System.Net.Http
{
}
base.Add(content);
}
+
+ internal override Task SerializeToStreamAsync(Stream stream, TransportContext context, CancellationToken cancellationToken) =>
+ // Only skip the original protected virtual SerializeToStreamAsync if this
+ // isn't a derived type that may have overridden the behavior.
+ GetType() == typeof(MultipartFormDataContent) ? SerializeToStreamAsyncCore(stream, context, cancellationToken) :
+ base.SerializeToStreamAsync(stream, context, cancellationToken);
}
}
if (NetEventSource.IsEnabled) NetEventSource.Associate(this, content);
}
- protected override Task SerializeToStreamAsync(Stream stream, TransportContext context)
+ protected override Task SerializeToStreamAsync(Stream stream, TransportContext context) =>
+ SerializeToStreamAsyncCore(stream, default);
+
+ internal override Task SerializeToStreamAsync(Stream stream, TransportContext context, CancellationToken cancellationToken) =>
+ // Only skip the original protected virtual SerializeToStreamAsync if this
+ // isn't a derived type that may have overridden the behavior.
+ GetType() == typeof(StreamContent) ? SerializeToStreamAsyncCore(stream, cancellationToken) :
+ base.SerializeToStreamAsync(stream, context, cancellationToken);
+
+ private Task SerializeToStreamAsyncCore(Stream stream, CancellationToken cancellationToken)
{
Debug.Assert(stream != null);
-
PrepareContent();
- // If the stream can't be re-read, make sure that it gets disposed once it is consumed.
- return StreamToStreamCopy.CopyAsync(_content, stream, _bufferSize, !_content.CanSeek);
+ return StreamToStreamCopy.CopyAsync(
+ _content,
+ stream,
+ _bufferSize,
+ !_content.CanSeek, // If the stream can't be re-read, make sure that it gets disposed once it is consumed.
+ cancellationToken);
}
protected internal override bool TryComputeLength(out long length)
Task copyTask = bufferSize == 0 ?
source.CopyToAsync(destination, cancellationToken) :
source.CopyToAsync(destination, bufferSize, cancellationToken);
- return disposeSource ?
- DisposeSourceWhenCompleteAsync(copyTask, source) :
- copyTask;
+
+ if (!disposeSource)
+ {
+ return copyTask;
+ }
+
+ switch (copyTask.Status)
+ {
+ case TaskStatus.RanToCompletion:
+ DisposeSource(source);
+ return Task.CompletedTask;
+
+ case TaskStatus.Faulted:
+ case TaskStatus.Canceled:
+ return copyTask;
+
+ default:
+ return DisposeSourceAsync(copyTask, source);
+
+ static async Task DisposeSourceAsync(Task copyTask, Stream source)
+ {
+ await copyTask.ConfigureAwait(false);
+ DisposeSource(source);
+ }
+ }
}
catch (Exception e)
{
}
}
- private static Task DisposeSourceWhenCompleteAsync(Task task, Stream source)
- {
- switch (task.Status)
- {
- case TaskStatus.RanToCompletion:
- DisposeSource(source);
- return Task.CompletedTask;
-
- case TaskStatus.Faulted:
- case TaskStatus.Canceled:
- return task;
-
- default:
- return task.ContinueWith((completed, innerSource) =>
- {
- completed.GetAwaiter().GetResult(); // propagate any exceptions
- DisposeSource((Stream)innerSource);
- },
- source, CancellationToken.None, TaskContinuationOptions.ExecuteSynchronously | TaskContinuationOptions.DenyChildAttach, TaskScheduler.Default);
- }
- }
-
/// <summary>Disposes the source stream.</summary>
private static void DisposeSource(Stream source)
{
using System.IO;
using System.Net.Http.Headers;
using System.Text;
+using System.Threading;
using System.Threading.Tasks;
namespace System.Net.Http
return encoding.GetBytes(content);
}
+ internal override Task SerializeToStreamAsync(Stream stream, TransportContext context, CancellationToken cancellationToken) =>
+ // Only skip the original protected virtual SerializeToStreamAsync if this
+ // isn't a derived type that may have overridden the behavior.
+ GetType() == typeof(StringContent) ? SerializeToStreamAsyncCore(stream, cancellationToken) :
+ base.SerializeToStreamAsync(stream, context, cancellationToken);
+
internal override Stream TryCreateContentReadStream() =>
GetType() == typeof(StringContent) ? CreateMemoryStreamForByteArray() : // type check ensures we use possible derived type's CreateContentReadStreamAsync override
null;
using System.Collections.Generic;
using System.Diagnostics;
-using System.Linq;
using System.IO;
+using System.Linq;
using System.Net.Test.Common;
+using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;
});
}
+ public static IEnumerable<object[]> PostAsync_Cancel_CancellationTokenPassedToContent_MemberData()
+ {
+ foreach (CancellationToken expectedToken in new[] { CancellationToken.None, new CancellationTokenSource().Token })
+ {
+ // StreamContent
+ {
+ var actualToken = new StrongBox<CancellationToken>();
+ bool called = false;
+ var content = new StreamContent(new DelegateStream(
+ canReadFunc: () => true,
+ readAsyncFunc: (buffer, offset, count, cancellationToken) =>
+ {
+ actualToken.Value = cancellationToken;
+ int result = called ? 0 : 1;
+ called = true;
+ return Task.FromResult(result);
+ }
+ ));
+ yield return new object[] { content, expectedToken, actualToken };
+ }
+
+ // MultipartContent
+ {
+ var actualToken = new StrongBox<CancellationToken>();
+ bool called = false;
+ var content = new MultipartContent();
+ content.Add(new StreamContent(new DelegateStream(
+ canReadFunc: () => true,
+ canSeekFunc: () => true,
+ lengthFunc: () => 1,
+ positionGetFunc: () => 0,
+ positionSetFunc: _ => {},
+ readAsyncFunc: (buffer, offset, count, cancellationToken) =>
+ {
+ actualToken.Value = cancellationToken;
+ int result = called ? 0 : 1;
+ called = true;
+ return Task.FromResult(result);
+ }
+ )));
+ yield return new object[] { content, expectedToken, actualToken };
+ }
+
+ // MultipartFormDataContent
+ {
+ var actualToken = new StrongBox<CancellationToken>();
+ bool called = false;
+ var content = new MultipartFormDataContent();
+ content.Add(new StreamContent(new DelegateStream(
+ canReadFunc: () => true,
+ canSeekFunc: () => true,
+ lengthFunc: () => 1,
+ positionGetFunc: () => 0,
+ positionSetFunc: _ => {},
+ readAsyncFunc: (buffer, offset, count, cancellationToken) =>
+ {
+ actualToken.Value = cancellationToken;
+ int result = called ? 0 : 1;
+ called = true;
+ return Task.FromResult(result);
+ }
+ )));
+ yield return new object[] { content, expectedToken, actualToken };
+ }
+ }
+ }
+
+ [Theory]
+ [MemberData(nameof(PostAsync_Cancel_CancellationTokenPassedToContent_MemberData))]
+ public async Task PostAsync_Cancel_CancellationTokenPassedToContent(HttpContent content, CancellationToken expectedToken, StrongBox<CancellationToken> actualToken)
+ {
+ if (IsUapHandler)
+ {
+ // HttpHandlerToFilter doesn't flow the token into the request body.
+ return;
+ }
+
+ await LoopbackServerFactory.CreateClientAndServerAsync(
+ async uri =>
+ {
+ using (var invoker = new HttpMessageInvoker(CreateHttpClientHandler()))
+ using (var req = new HttpRequestMessage(HttpMethod.Post, uri) { Content = content, Version = VersionFromUseHttp2 })
+ using (HttpResponseMessage resp = await invoker.SendAsync(req, expectedToken))
+ {
+ Assert.Equal("Hello World", await resp.Content.ReadAsStringAsync());
+ }
+ },
+ server => server.HandleRequestAsync(content: "Hello World"));
+
+ Assert.Equal(expectedToken, actualToken.Value);
+ }
+
private async Task ValidateClientCancellationAsync(Func<Task> clientBodyAsync)
{
var stopwatch = Stopwatch.StartNew();