From 56fb662a62dcd780c5d0fbda510e2a80c877c627 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Tue, 24 Aug 2021 11:02:08 -0600 Subject: [PATCH] [release/6.0-rc1] [HTTP/3] Abort response stream on dispose if content not finished (#57999) MIME-Version: 1.0 Content-Type: text/plain; charset=utf8 Content-Transfer-Encoding: 8bit * Sends abort read/write if H/3 stream is disposed before respective contents are finsihed * Minor tweaks in abort conditions * Prevent reverting SendState from Aborted/ConnectionClosed back to sending state within Send* methods. Co-authored-by: ManickaP Co-authored-by: Marie Píchová <11718369+ManickaP@users.noreply.github.com> --- .../tests/System/Net/Http/Http3LoopbackStream.cs | 8 +- .../Http/SocketsHttpHandler/Http3RequestStream.cs | 61 +++++++---- .../FunctionalTests/HttpClientHandlerTest.Http3.cs | 115 ++++++++++++++++++++- .../Quic/Implementations/MsQuic/MsQuicStream.cs | 40 +++---- 4 files changed, 175 insertions(+), 49 deletions(-) diff --git a/src/libraries/Common/tests/System/Net/Http/Http3LoopbackStream.cs b/src/libraries/Common/tests/System/Net/Http/Http3LoopbackStream.cs index 958cb2a..3460984 100644 --- a/src/libraries/Common/tests/System/Net/Http/Http3LoopbackStream.cs +++ b/src/libraries/Common/tests/System/Net/Http/Http3LoopbackStream.cs @@ -20,10 +20,10 @@ namespace System.Net.Test.Common private const int MaximumVarIntBytes = 8; private const long VarIntMax = (1L << 62) - 1; - private const long DataFrame = 0x0; - private const long HeadersFrame = 0x1; - private const long SettingsFrame = 0x4; - private const long GoAwayFrame = 0x7; + public const long DataFrame = 0x0; + public const long HeadersFrame = 0x1; + public const long SettingsFrame = 0x4; + public const long GoAwayFrame = 0x7; public const long ControlStream = 0x0; public const long PushStream = 0x1; diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http3RequestStream.cs b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http3RequestStream.cs index aa5f754..852507a 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http3RequestStream.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http3RequestStream.cs @@ -84,6 +84,7 @@ namespace System.Net.Http if (!_disposed) { _disposed = true; + AbortStream(); _stream.Dispose(); DisposeSyncHelper(); } @@ -94,6 +95,7 @@ namespace System.Net.Http if (!_disposed) { _disposed = true; + AbortStream(); await _stream.DisposeAsync().ConfigureAwait(false); DisposeSyncHelper(); } @@ -365,6 +367,9 @@ namespace System.Net.Http await content.CopyToAsync(writeStream, null, cancellationToken).ConfigureAwait(false); } + // Set to 0 to recognize that the whole request body has been sent and therefore there's no need to abort write side in case of a premature disposal. + _requestContentLengthRemaining = 0; + if (_sendBuffer.ActiveLength != 0) { // Our initial send buffer, which has our headers, is normally sent out on the first write to the Http3WriteStream. @@ -1217,6 +1222,20 @@ namespace System.Net.Http public void Trace(string message, [CallerMemberName] string? memberName = null) => _connection.Trace(StreamId, message, memberName); + private void AbortStream() + { + // If the request body isn't completed, cancel it now. + if (_requestContentLengthRemaining != 0) // 0 is used for the end of content writing, -1 is used for unknown Content-Length + { + _stream.AbortWrite((long)Http3ErrorCode.RequestCancelled); + } + // If the response body isn't completed, cancel it now. + if (_responseDataPayloadRemaining != -1) // -1 is used for EOF, 0 for consumed DATA frame payload before the next read + { + _stream.AbortRead((long)Http3ErrorCode.RequestCancelled); + } + } + // TODO: it may be possible for Http3RequestStream to implement Stream directly and avoid this allocation. private sealed class Http3ReadStream : HttpBaseStream { @@ -1240,36 +1259,42 @@ namespace System.Net.Http protected override void Dispose(bool disposing) { - if (_stream != null) + Http3RequestStream? stream = Interlocked.Exchange(ref _stream, null); + if (stream is null) { - if (disposing) - { - // This will remove the stream from the connection properly. - _stream.Dispose(); - } - else - { - // We shouldn't be using a managed instance here, but don't have much choice -- we - // need to remove the stream from the connection's GOAWAY collection. - _stream._connection.RemoveStream(_stream._stream); - _stream._connection = null!; - } + return; + } - _stream = null; - _response = null; + if (disposing) + { + // This will remove the stream from the connection properly. + stream.Dispose(); + } + else + { + // We shouldn't be using a managed instance here, but don't have much choice -- we + // need to remove the stream from the connection's GOAWAY collection. + stream._connection.RemoveStream(stream._stream); + stream._connection = null!; } + _response = null; + base.Dispose(disposing); } public override async ValueTask DisposeAsync() { - if (_stream != null) + Http3RequestStream? stream = Interlocked.Exchange(ref _stream, null); + if (stream is null) { - await _stream.DisposeAsync().ConfigureAwait(false); - _stream = null!; + return; } + await stream.DisposeAsync().ConfigureAwait(false); + + _response = null; + await base.DisposeAsync().ConfigureAwait(false); } diff --git a/src/libraries/System.Net.Http/tests/FunctionalTests/HttpClientHandlerTest.Http3.cs b/src/libraries/System.Net.Http/tests/FunctionalTests/HttpClientHandlerTest.Http3.cs index 547f91d..febbf3c 100644 --- a/src/libraries/System.Net.Http/tests/FunctionalTests/HttpClientHandlerTest.Http3.cs +++ b/src/libraries/System.Net.Http/tests/FunctionalTests/HttpClientHandlerTest.Http3.cs @@ -320,6 +320,119 @@ namespace System.Net.Http.Functional.Tests } [Fact] + public async Task RequestSentResponseDisposed_ThrowsOnServer() + { + byte[] data = Encoding.UTF8.GetBytes(new string('a', 1024)); + + using Http3LoopbackServer server = CreateHttp3LoopbackServer(); + + Task serverTask = Task.Run(async () => + { + using Http3LoopbackConnection connection = (Http3LoopbackConnection)await server.EstablishGenericConnectionAsync(); + using Http3LoopbackStream stream = await connection.AcceptRequestStreamAsync(); + HttpRequestData request = await stream.ReadRequestDataAsync(); + await stream.SendResponseHeadersAsync(); + + Stopwatch sw = Stopwatch.StartNew(); + bool hasFailed = false; + while (sw.Elapsed < TimeSpan.FromSeconds(15)) + { + try + { + await stream.SendResponseBodyAsync(data, isFinal: false); + } + catch (QuicStreamAbortedException) + { + hasFailed = true; + break; + } + } + Assert.True(hasFailed, $"Expected {nameof(QuicStreamAbortedException)}, instead ran successfully for {sw.Elapsed}"); + }); + + Task clientTask = Task.Run(async () => + { + using HttpClient client = CreateHttpClient(); + using HttpRequestMessage request = new() + { + Method = HttpMethod.Get, + RequestUri = server.Address, + Version = HttpVersion30, + VersionPolicy = HttpVersionPolicy.RequestVersionExact + }; + + var response = await client.SendAsync(request, HttpCompletionOption.ResponseHeadersRead); + var stream = await response.Content.ReadAsStreamAsync(); + byte[] buffer = new byte[512]; + for (int i = 0; i < 5; ++i) + { + var count = await stream.ReadAsync(buffer); + } + + // We haven't finished reading the whole respose, but we're disposing it, which should turn into an exception on the server-side. + response.Dispose(); + await serverTask; + }); + + await new[] { clientTask, serverTask }.WhenAllOrAnyFailed(20_000); + } + + [Fact] + public async Task RequestSendingResponseDisposed_ThrowsOnServer() + { + byte[] data = Encoding.UTF8.GetBytes(new string('a', 1024)); + + using Http3LoopbackServer server = CreateHttp3LoopbackServer(); + + Task serverTask = Task.Run(async () => + { + using Http3LoopbackConnection connection = (Http3LoopbackConnection)await server.EstablishGenericConnectionAsync(); + using Http3LoopbackStream stream = await connection.AcceptRequestStreamAsync(); + HttpRequestData request = await stream.ReadRequestDataAsync(false); + await stream.SendResponseHeadersAsync(); + + Stopwatch sw = Stopwatch.StartNew(); + bool hasFailed = false; + while (sw.Elapsed < TimeSpan.FromSeconds(15)) + { + try + { + var (frameType, payload) = await stream.ReadFrameAsync(); + Assert.Equal(Http3LoopbackStream.DataFrame, frameType); + } + catch (QuicStreamAbortedException) + { + hasFailed = true; + break; + } + } + Assert.True(hasFailed, $"Expected {nameof(QuicStreamAbortedException)}, instead ran successfully for {sw.Elapsed}"); + }); + + Task clientTask = Task.Run(async () => + { + using HttpClient client = CreateHttpClient(); + using HttpRequestMessage request = new() + { + Method = HttpMethod.Get, + RequestUri = server.Address, + Version = HttpVersion30, + VersionPolicy = HttpVersionPolicy.RequestVersionExact, + Content = new ByteAtATimeContent(60*4, Task.CompletedTask, new TaskCompletionSource(), 250) + }; + + var response = await client.SendAsync(request, HttpCompletionOption.ResponseHeadersRead); + var stream = await response.Content.ReadAsStreamAsync(); + + // We haven't finished sending the whole request, but we're disposing the response, which should turn into an exception on the server-side. + response.Dispose(); + await serverTask; + }); + + await new[] { clientTask, serverTask }.WhenAllOrAnyFailed(20_000); + } + + [Fact] public async Task ServerCertificateCustomValidationCallback_Succeeds() { // Mock doesn't make use of cart validation callback. @@ -885,7 +998,7 @@ namespace System.Net.Http.Functional.Tests VersionPolicy = HttpVersionPolicy.RequestVersionExact }; HttpResponseMessage response = await client.SendAsync(request).WaitAsync(TimeSpan.FromSeconds(10)); - + Assert.Equal(statusCode, response.StatusCode); await serverTask; diff --git a/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/MsQuicStream.cs b/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/MsQuicStream.cs index bbaf9c4..827029a 100644 --- a/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/MsQuicStream.cs +++ b/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/MsQuicStream.cs @@ -265,7 +265,7 @@ namespace System.Net.Quic.Implementations.MsQuic { ThrowIfDisposed(); - using CancellationTokenRegistration registration = HandleWriteStartState(cancellationToken); + using CancellationTokenRegistration registration = HandleWriteStartState(buffers.IsEmpty, cancellationToken); await SendReadOnlySequenceAsync(buffers, endStream ? QUIC_SEND_FLAGS.FIN : QUIC_SEND_FLAGS.NONE).ConfigureAwait(false); @@ -281,7 +281,7 @@ namespace System.Net.Quic.Implementations.MsQuic { ThrowIfDisposed(); - using CancellationTokenRegistration registration = HandleWriteStartState(cancellationToken); + using CancellationTokenRegistration registration = HandleWriteStartState(buffers.IsEmpty, cancellationToken); await SendReadOnlyMemoryListAsync(buffers, endStream ? QUIC_SEND_FLAGS.FIN : QUIC_SEND_FLAGS.NONE).ConfigureAwait(false); @@ -292,20 +292,20 @@ namespace System.Net.Quic.Implementations.MsQuic { ThrowIfDisposed(); - using CancellationTokenRegistration registration = HandleWriteStartState(cancellationToken); + using CancellationTokenRegistration registration = HandleWriteStartState(buffer.IsEmpty, cancellationToken); await SendReadOnlyMemoryAsync(buffer, endStream ? QUIC_SEND_FLAGS.FIN : QUIC_SEND_FLAGS.NONE).ConfigureAwait(false); HandleWriteCompletedState(); } - private CancellationTokenRegistration HandleWriteStartState(CancellationToken cancellationToken) + private CancellationTokenRegistration HandleWriteStartState(bool emptyBuffer, CancellationToken cancellationToken) { if (_state.SendState == SendState.Closed) { throw new InvalidOperationException(SR.net_quic_writing_notallowed); } - else if ( _state.SendState == SendState.Aborted) + if (_state.SendState == SendState.Aborted) { if (_state.SendErrorCode != -1) { @@ -363,10 +363,14 @@ namespace System.Net.Quic.Implementations.MsQuic throw new OperationCanceledException(SR.net_quic_sending_aborted); } - else if (_state.SendState == SendState.ConnectionClosed) + if (_state.SendState == SendState.ConnectionClosed) { throw GetConnectionAbortedException(_state); } + + // Change the state in the same lock where we check for final states to prevent coming back from Aborted/ConnectionClosed. + Debug.Assert(_state.SendState != SendState.Pending); + _state.SendState = emptyBuffer ? SendState.Finished : SendState.Pending; } return registration; @@ -632,7 +636,10 @@ namespace System.Net.Quic.Implementations.MsQuic lock (_state) { - _state.SendState = SendState.Finished; + if (_state.SendState < SendState.Finished) + { + _state.SendState = SendState.Finished; + } } // it is ok to send shutdown several times, MsQuic will ignore it @@ -1157,12 +1164,6 @@ namespace System.Net.Quic.Implementations.MsQuic ReadOnlyMemory buffer, QUIC_SEND_FLAGS flags) { - lock (_state) - { - Debug.Assert(_state.SendState != SendState.Pending); - _state.SendState = buffer.IsEmpty ? SendState.Finished : SendState.Pending; - } - if (buffer.IsEmpty) { if ((flags & QUIC_SEND_FLAGS.FIN) == QUIC_SEND_FLAGS.FIN) @@ -1211,13 +1212,6 @@ namespace System.Net.Quic.Implementations.MsQuic ReadOnlySequence buffers, QUIC_SEND_FLAGS flags) { - - lock (_state) - { - Debug.Assert(_state.SendState != SendState.Pending); - _state.SendState = buffers.IsEmpty ? SendState.Finished : SendState.Pending; - } - if (buffers.IsEmpty) { if ((flags & QUIC_SEND_FLAGS.FIN) == QUIC_SEND_FLAGS.FIN) @@ -1281,12 +1275,6 @@ namespace System.Net.Quic.Implementations.MsQuic ReadOnlyMemory> buffers, QUIC_SEND_FLAGS flags) { - lock (_state) - { - Debug.Assert(_state.SendState != SendState.Pending); - _state.SendState = buffers.IsEmpty ? SendState.Finished : SendState.Pending; - } - if (buffers.IsEmpty) { if ((flags & QUIC_SEND_FLAGS.FIN) == QUIC_SEND_FLAGS.FIN) -- 2.7.4