From d59bb4ad4fe2ffe77b17d9a62ba35dc5c3767190 Mon Sep 17 00:00:00 2001 From: Stephen Toub Date: Tue, 30 Apr 2019 14:58:25 -0400 Subject: [PATCH] Harden several SocketsHttpHandler streams against unexpected Dispose (dotnet/corefx#37299) A few tweaks: - We hand out the ChunkedEncodingWriteStream and ContentLengthWriteStream to an HttpContent.SerializeToStreamAsync. If that user code disposes of the Stream, then we end up null ref'ing when we try to finish the processing. We should instead throw a more descriptive error about the misuse. - Make the non-pooled response streams slightly more robust against concurrent disposal (which is not a supported use, but it happens). Commit migrated from https://github.com/dotnet/corefx/commit/422565673deb849b42f656f4a68bec791ee4bed9 --- .../ChunkedEncodingWriteStream.cs | 28 ++++++----- .../ConnectionCloseReadStream.cs | 35 +++++++------ .../SocketsHttpHandler/ContentLengthWriteStream.cs | 6 +-- .../Net/Http/SocketsHttpHandler/Http2Stream.cs | 10 ++-- .../Http/SocketsHttpHandler/HttpContentStream.cs | 11 ++++ .../SocketsHttpHandler/HttpContentWriteStream.cs | 13 +++-- .../Http/SocketsHttpHandler/RawConnectionStream.cs | 58 ++++++++++++---------- .../FunctionalTests/SocketsHttpHandlerTest.cs | 50 +++++++++++++++++++ 8 files changed, 143 insertions(+), 68 deletions(-) diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/ChunkedEncodingWriteStream.cs b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/ChunkedEncodingWriteStream.cs index 54bb5cf..8e4e438 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/ChunkedEncodingWriteStream.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/ChunkedEncodingWriteStream.cs @@ -20,7 +20,8 @@ namespace System.Net.Http public override ValueTask WriteAsync(ReadOnlyMemory buffer, CancellationToken ignored) { - Debug.Assert(_connection._currentRequest != null); + HttpConnection connection = GetConnectionOrThrow(); + Debug.Assert(connection._currentRequest != null); // The token is ignored because it's coming from SendAsync and the only operations // here are those that are already covered by the token having been registered with @@ -29,28 +30,29 @@ namespace System.Net.Http ValueTask task = buffer.Length == 0 ? // Don't write if nothing was given, especially since we don't want to accidentally send a 0 chunk, // which would indicate end of body. Instead, just ensure no content is stuck in the buffer. - _connection.FlushAsync() : - new ValueTask(WriteChunkAsync(buffer)); + connection.FlushAsync() : + new ValueTask(WriteChunkAsync(connection, buffer)); return task; - } - private async Task WriteChunkAsync(ReadOnlyMemory buffer) - { - // Write chunk length in hex followed by \r\n - await _connection.WriteHexInt32Async(buffer.Length).ConfigureAwait(false); - await _connection.WriteTwoBytesAsync((byte)'\r', (byte)'\n').ConfigureAwait(false); + static async Task WriteChunkAsync(HttpConnection connection, ReadOnlyMemory buffer) + { + // Write chunk length in hex followed by \r\n + await connection.WriteHexInt32Async(buffer.Length).ConfigureAwait(false); + await connection.WriteTwoBytesAsync((byte)'\r', (byte)'\n').ConfigureAwait(false); - // Write chunk contents followed by \r\n - await _connection.WriteAsync(buffer).ConfigureAwait(false); - await _connection.WriteTwoBytesAsync((byte)'\r', (byte)'\n').ConfigureAwait(false); + // Write chunk contents followed by \r\n + await connection.WriteAsync(buffer).ConfigureAwait(false); + await connection.WriteTwoBytesAsync((byte)'\r', (byte)'\n').ConfigureAwait(false); + } } public override async Task FinishAsync() { // Send 0 byte chunk to indicate end, then final CrLf - await _connection.WriteBytesAsync(s_finalChunkBytes).ConfigureAwait(false); + HttpConnection connection = GetConnectionOrThrow(); _connection = null; + await connection.WriteBytesAsync(s_finalChunkBytes).ConfigureAwait(false); } } } diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/ConnectionCloseReadStream.cs b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/ConnectionCloseReadStream.cs index 8960ad2..262ab20 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/ConnectionCloseReadStream.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/ConnectionCloseReadStream.cs @@ -18,18 +18,19 @@ namespace System.Net.Http public override int Read(Span buffer) { - if (_connection == null || buffer.Length == 0) + HttpConnection connection = _connection; + if (connection == null || buffer.Length == 0) { // Response body fully consumed or the caller didn't ask for any data return 0; } - int bytesRead = _connection.Read(buffer); + int bytesRead = connection.Read(buffer); if (bytesRead == 0) { // We cannot reuse this connection, so close it. - _connection.Dispose(); _connection = null; + connection.Dispose(); } return bytesRead; @@ -39,13 +40,14 @@ namespace System.Net.Http { CancellationHelper.ThrowIfCancellationRequested(cancellationToken); - if (_connection == null || buffer.Length == 0) + HttpConnection connection = _connection; + if (connection == null || buffer.Length == 0) { // Response body fully consumed or the caller didn't ask for any data return 0; } - ValueTask readTask = _connection.ReadAsync(buffer); + ValueTask readTask = connection.ReadAsync(buffer); int bytesRead; if (readTask.IsCompletedSuccessfully) { @@ -53,7 +55,7 @@ namespace System.Net.Http } else { - CancellationTokenRegistration ctr = _connection.RegisterCancellation(cancellationToken); + CancellationTokenRegistration ctr = connection.RegisterCancellation(cancellationToken); try { bytesRead = await readTask.ConfigureAwait(false); @@ -79,8 +81,8 @@ namespace System.Net.Http CancellationHelper.ThrowIfCancellationRequested(cancellationToken); // We cannot reuse this connection, so close it. - _connection.Dispose(); _connection = null; + connection.Dispose(); } return bytesRead; @@ -95,25 +97,26 @@ namespace System.Net.Http return Task.FromCanceled(cancellationToken); } - if (_connection == null) + HttpConnection connection = _connection; + if (connection == null) { // null if response body fully consumed return Task.CompletedTask; } - Task copyTask = _connection.CopyToUntilEofAsync(destination, bufferSize, cancellationToken); + Task copyTask = connection.CopyToUntilEofAsync(destination, bufferSize, cancellationToken); if (copyTask.IsCompletedSuccessfully) { - Finish(); + Finish(connection); return Task.CompletedTask; } - return CompleteCopyToAsync(copyTask, cancellationToken); + return CompleteCopyToAsync(copyTask, connection, cancellationToken); } - private async Task CompleteCopyToAsync(Task copyTask, CancellationToken cancellationToken) + private async Task CompleteCopyToAsync(Task copyTask, HttpConnection connection, CancellationToken cancellationToken) { - CancellationTokenRegistration ctr = _connection.RegisterCancellation(cancellationToken); + CancellationTokenRegistration ctr = connection.RegisterCancellation(cancellationToken); try { await copyTask.ConfigureAwait(false); @@ -133,14 +136,14 @@ namespace System.Net.Http // been requested, we assume the copy completed due to cancellation and throw. CancellationHelper.ThrowIfCancellationRequested(cancellationToken); - Finish(); + Finish(connection); } - private void Finish() + private void Finish(HttpConnection connection) { // We cannot reuse this connection, so close it. - _connection.Dispose(); _connection = null; + connection.Dispose(); } } } diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/ContentLengthWriteStream.cs b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/ContentLengthWriteStream.cs index 3c79ead..4804427 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/ContentLengthWriteStream.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/ContentLengthWriteStream.cs @@ -18,12 +18,12 @@ namespace System.Net.Http public override ValueTask WriteAsync(ReadOnlyMemory buffer, CancellationToken ignored) // token ignored as it comes from SendAsync { - Debug.Assert(_connection._currentRequest != null); - // Have the connection write the data, skipping the buffer. Importantly, this will // force a flush of anything already in the buffer, i.e. any remaining request headers // that are still buffered. - return new ValueTask(_connection.WriteAsync(buffer)); + HttpConnection connection = GetConnectionOrThrow(); + Debug.Assert(connection._currentRequest != null); + return new ValueTask(connection.WriteAsync(buffer)); } public override Task FinishAsync() diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http2Stream.cs b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http2Stream.cs index 53f793c..1a45dca 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http2Stream.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http2Stream.cs @@ -599,12 +599,10 @@ namespace System.Net.Http public override ValueTask WriteAsync(ReadOnlyMemory buffer, CancellationToken cancellationToken) { Http2Stream http2Stream = _http2Stream; - if (http2Stream == null) - { - return new ValueTask(Task.FromException(new ObjectDisposedException(nameof(Http2WriteStream)))); - } - - return new ValueTask(http2Stream.SendDataAsync(buffer, cancellationToken)); + Task t = http2Stream != null ? + http2Stream.SendDataAsync(buffer, cancellationToken) : + Task.FromException(new ObjectDisposedException(nameof(Http2WriteStream))); + return new ValueTask(t); } } } diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpContentStream.cs b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpContentStream.cs index b1b40ea..860fb55 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpContentStream.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpContentStream.cs @@ -26,5 +26,16 @@ namespace System.Net.Http base.Dispose(disposing); } + + protected HttpConnection GetConnectionOrThrow() + { + return _connection ?? + // This should only ever happen if the user-code that was handed this instance disposed of + // it, which is misuse, or held onto it and tried to use it later after we've disposed of it, + // which is also misuse. + ThrowObjectDisposedException(); + } + + private HttpConnection ThrowObjectDisposedException() => throw new ObjectDisposedException(GetType().Name); } } diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpContentWriteStream.cs b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpContentWriteStream.cs index 6802a07..537ee6f 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpContentWriteStream.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpContentWriteStream.cs @@ -20,10 +20,15 @@ namespace System.Net.Http public sealed override bool CanWrite => true; public sealed override void Flush() => - _connection.Flush(); - - public sealed override Task FlushAsync(CancellationToken ignored) => - _connection.FlushAsync().AsTask(); + _connection?.Flush(); + + public sealed override Task FlushAsync(CancellationToken ignored) + { + HttpConnection connection = _connection; + return connection != null ? + connection.FlushAsync().AsTask() : + default; + } public sealed override int Read(Span buffer) => throw new NotSupportedException(); diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/RawConnectionStream.cs b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/RawConnectionStream.cs index b45475f..87c260f 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/RawConnectionStream.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/RawConnectionStream.cs @@ -22,18 +22,19 @@ namespace System.Net.Http public override int Read(Span buffer) { - if (_connection == null || buffer.Length == 0) + HttpConnection connection = _connection; + if (connection == null || buffer.Length == 0) { // Response body fully consumed or the caller didn't ask for any data return 0; } - int bytesRead = _connection.ReadBuffered(buffer); + int bytesRead = connection.ReadBuffered(buffer); if (bytesRead == 0) { // We cannot reuse this connection, so close it. - _connection.Dispose(); _connection = null; + connection.Dispose(); } return bytesRead; @@ -43,13 +44,14 @@ namespace System.Net.Http { CancellationHelper.ThrowIfCancellationRequested(cancellationToken); - if (_connection == null || buffer.Length == 0) + HttpConnection connection = _connection; + if (connection == null || buffer.Length == 0) { // Response body fully consumed or the caller didn't ask for any data return 0; } - ValueTask readTask = _connection.ReadBufferedAsync(buffer); + ValueTask readTask = connection.ReadBufferedAsync(buffer); int bytesRead; if (readTask.IsCompletedSuccessfully) { @@ -57,7 +59,7 @@ namespace System.Net.Http } else { - CancellationTokenRegistration ctr = _connection.RegisterCancellation(cancellationToken); + CancellationTokenRegistration ctr = connection.RegisterCancellation(cancellationToken); try { bytesRead = await readTask.ConfigureAwait(false); @@ -78,8 +80,8 @@ namespace System.Net.Http CancellationHelper.ThrowIfCancellationRequested(cancellationToken); // We cannot reuse this connection, so close it. - _connection.Dispose(); _connection = null; + connection.Dispose(); } return bytesRead; @@ -94,25 +96,26 @@ namespace System.Net.Http return Task.FromCanceled(cancellationToken); } - if (_connection == null) + HttpConnection connection = _connection; + if (connection == null) { // null if response body fully consumed return Task.CompletedTask; } - Task copyTask = _connection.CopyToUntilEofAsync(destination, bufferSize, cancellationToken); + Task copyTask = connection.CopyToUntilEofAsync(destination, bufferSize, cancellationToken); if (copyTask.IsCompletedSuccessfully) { - Finish(); + Finish(connection); return Task.CompletedTask; } - return CompleteCopyToAsync(copyTask, cancellationToken); + return CompleteCopyToAsync(copyTask, connection, cancellationToken); } - private async Task CompleteCopyToAsync(Task copyTask, CancellationToken cancellationToken) + private async Task CompleteCopyToAsync(Task copyTask, HttpConnection connection, CancellationToken cancellationToken) { - CancellationTokenRegistration ctr = _connection.RegisterCancellation(cancellationToken); + CancellationTokenRegistration ctr = connection.RegisterCancellation(cancellationToken); try { await copyTask.ConfigureAwait(false); @@ -132,13 +135,13 @@ namespace System.Net.Http // been requested, we assume the copy completed due to cancellation and throw. CancellationHelper.ThrowIfCancellationRequested(cancellationToken); - Finish(); + Finish(connection); } - private void Finish() + private void Finish(HttpConnection connection) { // We cannot reuse this connection, so close it. - _connection.Dispose(); + connection.Dispose(); _connection = null; } @@ -150,14 +153,15 @@ namespace System.Net.Http public override void Write(ReadOnlySpan buffer) { - if (_connection == null) + HttpConnection connection = _connection; + if (connection == null) { throw new IOException(SR.ObjectDisposed_StreamClosed); } if (buffer.Length != 0) { - _connection.WriteWithoutBuffering(buffer); + connection.WriteWithoutBuffering(buffer); } } @@ -168,7 +172,8 @@ namespace System.Net.Http return new ValueTask(Task.FromCanceled(cancellationToken)); } - if (_connection == null) + HttpConnection connection = _connection; + if (connection == null) { return new ValueTask(Task.FromException(new IOException(SR.ObjectDisposed_StreamClosed))); } @@ -178,10 +183,10 @@ namespace System.Net.Http return default; } - ValueTask writeTask = _connection.WriteWithoutBufferingAsync(buffer); + ValueTask writeTask = connection.WriteWithoutBufferingAsync(buffer); return writeTask.IsCompleted ? writeTask : - new ValueTask(WaitWithConnectionCancellationAsync(writeTask, cancellationToken)); + new ValueTask(WaitWithConnectionCancellationAsync(writeTask, connection, cancellationToken)); } public override void Flush() => _connection?.Flush(); @@ -193,20 +198,21 @@ namespace System.Net.Http return Task.FromCanceled(cancellationToken); } - if (_connection == null) + HttpConnection connection = _connection; + if (connection == null) { return Task.CompletedTask; } - ValueTask flushTask = _connection.FlushAsync(); + ValueTask flushTask = connection.FlushAsync(); return flushTask.IsCompleted ? flushTask.AsTask() : - WaitWithConnectionCancellationAsync(flushTask, cancellationToken); + WaitWithConnectionCancellationAsync(flushTask, connection, cancellationToken); } - private async Task WaitWithConnectionCancellationAsync(ValueTask task, CancellationToken cancellationToken) + private static async Task WaitWithConnectionCancellationAsync(ValueTask task, HttpConnection connection, CancellationToken cancellationToken) { - CancellationTokenRegistration ctr = _connection.RegisterCancellation(cancellationToken); + CancellationTokenRegistration ctr = connection.RegisterCancellation(cancellationToken); try { await task.ConfigureAwait(false); diff --git a/src/libraries/System.Net.Http/tests/FunctionalTests/SocketsHttpHandlerTest.cs b/src/libraries/System.Net.Http/tests/FunctionalTests/SocketsHttpHandlerTest.cs index 198eb27..e37537a 100644 --- a/src/libraries/System.Net.Http/tests/FunctionalTests/SocketsHttpHandlerTest.cs +++ b/src/libraries/System.Net.Http/tests/FunctionalTests/SocketsHttpHandlerTest.cs @@ -457,6 +457,56 @@ namespace System.Net.Http.Functional.Tests { public SocketsHttpHandler_PostScenarioTest(ITestOutputHelper output) : base(output) { } protected override bool UseSocketsHttpHandler => true; + + [Theory] + [InlineData(false)] + [InlineData(true)] + public async Task DisposeTargetStream_ThrowsObjectDisposedException(bool knownLength) + { + var tcs = new TaskCompletionSource(TaskContinuationOptions.RunContinuationsAsynchronously); + await LoopbackServerFactory.CreateClientAndServerAsync(async uri => + { + try + { + using (HttpClient client = CreateHttpClient()) + { + Task t = client.PostAsync(uri, new DisposeStreamWhileCopyingContent(knownLength)); + Assert.IsType((await Assert.ThrowsAsync(() => t)).InnerException); + } + } + finally + { + tcs.SetResult(0); + } + }, server => tcs.Task); + } + + private sealed class DisposeStreamWhileCopyingContent : HttpContent + { + private readonly bool _knownLength; + + public DisposeStreamWhileCopyingContent(bool knownLength) => _knownLength = knownLength; + + protected override async Task SerializeToStreamAsync(Stream stream, TransportContext context) + { + await stream.WriteAsync(new byte[42], 0, 42); + stream.Dispose(); + } + + protected override bool TryComputeLength(out long length) + { + if (_knownLength) + { + length = 42; + return true; + } + else + { + length = 0; + return false; + } + } + } } public sealed class SocketsHttpHandler_ResponseStreamTest : ResponseStreamTest -- 2.7.4