[release/6.0-rc1] [HTTP/3] Abort response stream on dispose if content not finished...
authorgithub-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Tue, 24 Aug 2021 17:02:08 +0000 (11:02 -0600)
committerGitHub <noreply@github.com>
Tue, 24 Aug 2021 17:02:08 +0000 (11:02 -0600)
* Sends abort read/write if H/3 stream is disposed before respective contents are finsihed

* Minor tweaks in abort conditions

* Prevent reverting SendState from Aborted/ConnectionClosed back to sending state within Send* methods.

Co-authored-by: ManickaP <mapichov@microsoft.com>
Co-authored-by: Marie Píchová <11718369+ManickaP@users.noreply.github.com>
src/libraries/Common/tests/System/Net/Http/Http3LoopbackStream.cs
src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http3RequestStream.cs
src/libraries/System.Net.Http/tests/FunctionalTests/HttpClientHandlerTest.Http3.cs
src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/MsQuicStream.cs

index 958cb2a..3460984 100644 (file)
@@ -20,10 +20,10 @@ namespace System.Net.Test.Common
         private const int MaximumVarIntBytes = 8;
         private const long VarIntMax = (1L << 62) - 1;
 
-        private const long DataFrame = 0x0;
-        private const long HeadersFrame = 0x1;
-        private const long SettingsFrame = 0x4;
-        private const long GoAwayFrame = 0x7;
+        public const long DataFrame = 0x0;
+        public const long HeadersFrame = 0x1;
+        public const long SettingsFrame = 0x4;
+        public const long GoAwayFrame = 0x7;
 
         public const long ControlStream = 0x0;
         public const long PushStream = 0x1;
index aa5f754..852507a 100644 (file)
@@ -84,6 +84,7 @@ namespace System.Net.Http
             if (!_disposed)
             {
                 _disposed = true;
+                AbortStream();
                 _stream.Dispose();
                 DisposeSyncHelper();
             }
@@ -94,6 +95,7 @@ namespace System.Net.Http
             if (!_disposed)
             {
                 _disposed = true;
+                AbortStream();
                 await _stream.DisposeAsync().ConfigureAwait(false);
                 DisposeSyncHelper();
             }
@@ -365,6 +367,9 @@ namespace System.Net.Http
                 await content.CopyToAsync(writeStream, null, cancellationToken).ConfigureAwait(false);
             }
 
+            // Set to 0 to recognize that the whole request body has been sent and therefore there's no need to abort write side in case of a premature disposal.
+            _requestContentLengthRemaining = 0;
+
             if (_sendBuffer.ActiveLength != 0)
             {
                 // Our initial send buffer, which has our headers, is normally sent out on the first write to the Http3WriteStream.
@@ -1217,6 +1222,20 @@ namespace System.Net.Http
         public void Trace(string message, [CallerMemberName] string? memberName = null) =>
             _connection.Trace(StreamId, message, memberName);
 
+        private void AbortStream()
+        {
+            // If the request body isn't completed, cancel it now.
+            if (_requestContentLengthRemaining != 0) // 0 is used for the end of content writing, -1 is used for unknown Content-Length
+            {
+                _stream.AbortWrite((long)Http3ErrorCode.RequestCancelled);
+            }
+            // If the response body isn't completed, cancel it now.
+            if (_responseDataPayloadRemaining != -1) // -1 is used for EOF, 0 for consumed DATA frame payload before the next read
+            {
+                _stream.AbortRead((long)Http3ErrorCode.RequestCancelled);
+            }
+        }
+
         // TODO: it may be possible for Http3RequestStream to implement Stream directly and avoid this allocation.
         private sealed class Http3ReadStream : HttpBaseStream
         {
@@ -1240,36 +1259,42 @@ namespace System.Net.Http
 
             protected override void Dispose(bool disposing)
             {
-                if (_stream != null)
+                Http3RequestStream? stream = Interlocked.Exchange(ref _stream, null);
+                if (stream is null)
                 {
-                    if (disposing)
-                    {
-                        // This will remove the stream from the connection properly.
-                        _stream.Dispose();
-                    }
-                    else
-                    {
-                        // We shouldn't be using a managed instance here, but don't have much choice -- we
-                        // need to remove the stream from the connection's GOAWAY collection.
-                        _stream._connection.RemoveStream(_stream._stream);
-                        _stream._connection = null!;
-                    }
+                    return;
+                }
 
-                    _stream = null;
-                    _response = null;
+                if (disposing)
+                {
+                    // This will remove the stream from the connection properly.
+                    stream.Dispose();
+                }
+                else
+                {
+                    // We shouldn't be using a managed instance here, but don't have much choice -- we
+                    // need to remove the stream from the connection's GOAWAY collection.
+                    stream._connection.RemoveStream(stream._stream);
+                    stream._connection = null!;
                 }
 
+                _response = null;
+
                 base.Dispose(disposing);
             }
 
             public override async ValueTask DisposeAsync()
             {
-                if (_stream != null)
+                Http3RequestStream? stream = Interlocked.Exchange(ref _stream, null);
+                if (stream is null)
                 {
-                    await _stream.DisposeAsync().ConfigureAwait(false);
-                    _stream = null!;
+                    return;
                 }
 
+                await stream.DisposeAsync().ConfigureAwait(false);
+
+                _response = null;
+
                 await base.DisposeAsync().ConfigureAwait(false);
             }
 
index 547f91d..febbf3c 100644 (file)
@@ -320,6 +320,119 @@ namespace System.Net.Http.Functional.Tests
         }
 
         [Fact]
+        public async Task RequestSentResponseDisposed_ThrowsOnServer()
+        {
+            byte[] data = Encoding.UTF8.GetBytes(new string('a', 1024));
+
+            using Http3LoopbackServer server = CreateHttp3LoopbackServer();
+
+            Task serverTask = Task.Run(async () =>
+            {
+                using Http3LoopbackConnection connection = (Http3LoopbackConnection)await server.EstablishGenericConnectionAsync();
+                using Http3LoopbackStream stream = await connection.AcceptRequestStreamAsync();
+                HttpRequestData request = await stream.ReadRequestDataAsync();
+                await stream.SendResponseHeadersAsync();
+
+                Stopwatch sw = Stopwatch.StartNew();
+                bool hasFailed = false;
+                while (sw.Elapsed < TimeSpan.FromSeconds(15))
+                {
+                    try
+                    {
+                        await stream.SendResponseBodyAsync(data, isFinal: false);
+                    }
+                    catch (QuicStreamAbortedException)
+                    {
+                        hasFailed = true;
+                        break;
+                    }
+                }
+                Assert.True(hasFailed, $"Expected {nameof(QuicStreamAbortedException)}, instead ran successfully for {sw.Elapsed}");
+            });
+
+            Task clientTask = Task.Run(async () =>
+            {
+                using HttpClient client = CreateHttpClient();
+                using HttpRequestMessage request = new()
+                {
+                    Method = HttpMethod.Get,
+                    RequestUri = server.Address,
+                    Version = HttpVersion30,
+                    VersionPolicy = HttpVersionPolicy.RequestVersionExact
+                };
+
+                var response = await client.SendAsync(request, HttpCompletionOption.ResponseHeadersRead);
+                var stream = await response.Content.ReadAsStreamAsync();
+                byte[] buffer = new byte[512];
+                for (int i = 0; i < 5; ++i)
+                {
+                    var count = await stream.ReadAsync(buffer);
+                }
+
+                // We haven't finished reading the whole respose, but we're disposing it, which should turn into an exception on the server-side.
+                response.Dispose();
+                await serverTask;
+            });
+
+            await new[] { clientTask, serverTask }.WhenAllOrAnyFailed(20_000);
+        }
+
+        [Fact]
+        public async Task RequestSendingResponseDisposed_ThrowsOnServer()
+        {
+            byte[] data = Encoding.UTF8.GetBytes(new string('a', 1024));
+
+            using Http3LoopbackServer server = CreateHttp3LoopbackServer();
+
+            Task serverTask = Task.Run(async () =>
+            {
+                using Http3LoopbackConnection connection = (Http3LoopbackConnection)await server.EstablishGenericConnectionAsync();
+                using Http3LoopbackStream stream = await connection.AcceptRequestStreamAsync();
+                HttpRequestData request = await stream.ReadRequestDataAsync(false);
+                await stream.SendResponseHeadersAsync();
+
+                Stopwatch sw = Stopwatch.StartNew();
+                bool hasFailed = false;
+                while (sw.Elapsed < TimeSpan.FromSeconds(15))
+                {
+                    try
+                    {
+                        var (frameType, payload) = await stream.ReadFrameAsync();
+                        Assert.Equal(Http3LoopbackStream.DataFrame, frameType);
+                    }
+                    catch (QuicStreamAbortedException)
+                    {
+                        hasFailed = true;
+                        break;
+                    }
+                }
+                Assert.True(hasFailed, $"Expected {nameof(QuicStreamAbortedException)}, instead ran successfully for {sw.Elapsed}");
+            });
+
+            Task clientTask = Task.Run(async () =>
+            {
+                using HttpClient client = CreateHttpClient();
+                using HttpRequestMessage request = new()
+                {
+                    Method = HttpMethod.Get,
+                    RequestUri = server.Address,
+                    Version = HttpVersion30,
+                    VersionPolicy = HttpVersionPolicy.RequestVersionExact,
+                    Content = new ByteAtATimeContent(60*4, Task.CompletedTask, new TaskCompletionSource<bool>(), 250)
+                };
+
+                var response = await client.SendAsync(request, HttpCompletionOption.ResponseHeadersRead);
+                var stream = await response.Content.ReadAsStreamAsync();
+
+                // We haven't finished sending the whole request, but we're disposing the response, which should turn into an exception on the server-side.
+                response.Dispose();
+                await serverTask;
+            });
+
+            await new[] { clientTask, serverTask }.WhenAllOrAnyFailed(20_000);
+        }
+
+        [Fact]
         public async Task ServerCertificateCustomValidationCallback_Succeeds()
         {
             // Mock doesn't make use of cart validation callback.
@@ -885,7 +998,7 @@ namespace System.Net.Http.Functional.Tests
                 VersionPolicy = HttpVersionPolicy.RequestVersionExact
             };
             HttpResponseMessage response = await client.SendAsync(request).WaitAsync(TimeSpan.FromSeconds(10));
-            
+
             Assert.Equal(statusCode, response.StatusCode);
 
             await serverTask;
index bbaf9c4..827029a 100644 (file)
@@ -265,7 +265,7 @@ namespace System.Net.Quic.Implementations.MsQuic
         {
             ThrowIfDisposed();
 
-            using CancellationTokenRegistration registration = HandleWriteStartState(cancellationToken);
+            using CancellationTokenRegistration registration = HandleWriteStartState(buffers.IsEmpty, cancellationToken);
 
             await SendReadOnlySequenceAsync(buffers, endStream ? QUIC_SEND_FLAGS.FIN : QUIC_SEND_FLAGS.NONE).ConfigureAwait(false);
 
@@ -281,7 +281,7 @@ namespace System.Net.Quic.Implementations.MsQuic
         {
             ThrowIfDisposed();
 
-            using CancellationTokenRegistration registration = HandleWriteStartState(cancellationToken);
+            using CancellationTokenRegistration registration = HandleWriteStartState(buffers.IsEmpty, cancellationToken);
 
             await SendReadOnlyMemoryListAsync(buffers, endStream ? QUIC_SEND_FLAGS.FIN : QUIC_SEND_FLAGS.NONE).ConfigureAwait(false);
 
@@ -292,20 +292,20 @@ namespace System.Net.Quic.Implementations.MsQuic
         {
             ThrowIfDisposed();
 
-            using CancellationTokenRegistration registration = HandleWriteStartState(cancellationToken);
+            using CancellationTokenRegistration registration = HandleWriteStartState(buffer.IsEmpty, cancellationToken);
 
             await SendReadOnlyMemoryAsync(buffer, endStream ? QUIC_SEND_FLAGS.FIN : QUIC_SEND_FLAGS.NONE).ConfigureAwait(false);
 
             HandleWriteCompletedState();
         }
 
-        private CancellationTokenRegistration HandleWriteStartState(CancellationToken cancellationToken)
+        private CancellationTokenRegistration HandleWriteStartState(bool emptyBuffer, CancellationToken cancellationToken)
         {
             if (_state.SendState == SendState.Closed)
             {
                 throw new InvalidOperationException(SR.net_quic_writing_notallowed);
             }
-            else if ( _state.SendState == SendState.Aborted)
+            if (_state.SendState == SendState.Aborted)
             {
                 if (_state.SendErrorCode != -1)
                 {
@@ -363,10 +363,14 @@ namespace System.Net.Quic.Implementations.MsQuic
 
                     throw new OperationCanceledException(SR.net_quic_sending_aborted);
                 }
-                else if (_state.SendState == SendState.ConnectionClosed)
+                if (_state.SendState == SendState.ConnectionClosed)
                 {
                     throw GetConnectionAbortedException(_state);
                 }
+
+                // Change the state in the same lock where we check for final states to prevent coming back from Aborted/ConnectionClosed.
+                Debug.Assert(_state.SendState != SendState.Pending);
+                _state.SendState = emptyBuffer ? SendState.Finished : SendState.Pending;
             }
 
             return registration;
@@ -632,7 +636,10 @@ namespace System.Net.Quic.Implementations.MsQuic
 
             lock (_state)
             {
-                _state.SendState = SendState.Finished;
+                if (_state.SendState < SendState.Finished)
+                {
+                    _state.SendState = SendState.Finished;
+                }
             }
 
             // it is ok to send shutdown several times, MsQuic will ignore it
@@ -1157,12 +1164,6 @@ namespace System.Net.Quic.Implementations.MsQuic
            ReadOnlyMemory<byte> buffer,
            QUIC_SEND_FLAGS flags)
         {
-            lock (_state)
-            {
-                Debug.Assert(_state.SendState != SendState.Pending);
-                _state.SendState = buffer.IsEmpty ? SendState.Finished : SendState.Pending;
-            }
-
             if (buffer.IsEmpty)
             {
                 if ((flags & QUIC_SEND_FLAGS.FIN) == QUIC_SEND_FLAGS.FIN)
@@ -1211,13 +1212,6 @@ namespace System.Net.Quic.Implementations.MsQuic
            ReadOnlySequence<byte> buffers,
            QUIC_SEND_FLAGS flags)
         {
-
-            lock (_state)
-            {
-                Debug.Assert(_state.SendState != SendState.Pending);
-                _state.SendState = buffers.IsEmpty ? SendState.Finished : SendState.Pending;
-            }
-
             if (buffers.IsEmpty)
             {
                 if ((flags & QUIC_SEND_FLAGS.FIN) == QUIC_SEND_FLAGS.FIN)
@@ -1281,12 +1275,6 @@ namespace System.Net.Quic.Implementations.MsQuic
            ReadOnlyMemory<ReadOnlyMemory<byte>> buffers,
            QUIC_SEND_FLAGS flags)
         {
-            lock (_state)
-            {
-                Debug.Assert(_state.SendState != SendState.Pending);
-                _state.SendState = buffers.IsEmpty ? SendState.Finished : SendState.Pending;
-            }
-
             if (buffers.IsEmpty)
             {
                 if ((flags & QUIC_SEND_FLAGS.FIN) == QUIC_SEND_FLAGS.FIN)