HTTP2: Fix handling of RST_STREAM with error = NO_ERROR after response EndStream...
authorGeoff Kizer <geoffrek@microsoft.com>
Mon, 5 Aug 2019 18:13:46 +0000 (11:13 -0700)
committerGitHub <noreply@github.com>
Mon, 5 Aug 2019 18:13:46 +0000 (11:13 -0700)
* fix handling of RST_STREAM with error = NO_ERROR after response EndStream

* eat exception and ensure we don't cancel the response body

* add comment

* add CancelResponseBody to avoid code duplication, and add checks we are not holding the stream object lock when we don't expect to

* fix _expect100ContinueWaiter to use TaskCreationOptions.RunContinuationsAsynchronously

Commit migrated from https://github.com/dotnet/corefx/commit/ca6c44b72b868c112387b0fb7def42d91f5577aa

src/libraries/Common/tests/System/Net/Http/Http2LoopbackConnection.cs
src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http2Connection.cs
src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http2Stream.cs
src/libraries/System.Net.Http/tests/FunctionalTests/HttpClientHandlerTest.Http2.cs

index 16c971b..2a23915 100644 (file)
@@ -236,33 +236,16 @@ namespace System.Net.Test.Common
         }
 
         // Wait for the client to close the connection, e.g. after the HttpClient is disposed.
-        public async Task WaitForClientDisconnectAsync()
+        public async Task WaitForClientDisconnectAsync(bool ignoreUnexpectedFrames = false)
         {
-            Frame frame = await ReadFrameAsync(Timeout);
-            Assert.Null(frame);
-        }
-
-        public void ShutdownSend()
-        {
-            _connectionSocket.Shutdown(SocketShutdown.Send);
-        }
-
-        // This will wait for the client to close the connection,
-        // and ignore any meaningless frames -- i.e. WINDOW_UPDATE or expected SETTINGS ACK --
-        // that we see while waiting for the client to close.
-        // Only call this after sending a GOAWAY.
-        public async Task WaitForConnectionShutdownAsync(bool ignoreUnexpectedFrames = false)
-        {
-            // Shutdown our send side, so the client knows there won't be any more frames coming.
-            ShutdownSend();
-
             IgnoreWindowUpdates();
+
             Frame frame = await ReadFrameAsync(Timeout).ConfigureAwait(false);
             if (frame != null)
             {
                 if (!ignoreUnexpectedFrames)
                 {
-                    throw new Exception($"Unexpected frame received while waiting for client shutdown: {frame}");
+                    throw new Exception($"Unexpected frame received while waiting for client disconnect: {frame}");
                 }
             }
 
@@ -275,6 +258,22 @@ namespace System.Net.Test.Common
             _ignoreWindowUpdates = false;
         }
 
+        public void ShutdownSend()
+        {
+            _connectionSocket.Shutdown(SocketShutdown.Send);
+        }
+
+        // This will cause a server-initiated shutdown of the connection.
+        // For normal operation, you should send a GOAWAY and complete any remaining streams
+        // before calling this method.
+        public async Task WaitForConnectionShutdownAsync(bool ignoreUnexpectedFrames = false)
+        {
+            // Shutdown our send side, so the client knows there won't be any more frames coming.
+            ShutdownSend();
+
+            await WaitForClientDisconnectAsync(ignoreUnexpectedFrames: ignoreUnexpectedFrames);
+        }
+
         // This is similar to WaitForConnectionShutdownAsync but will send GOAWAY for you
         // and will ignore any errors if client has already shutdown
         public async Task ShutdownIgnoringErrorsAsync(int lastStreamId, ProtocolErrors errorCode = ProtocolErrors.NO_ERROR)
index 50e90e8..6f4781d 100644 (file)
@@ -655,11 +655,11 @@ namespace System.Net.Http
 
             if (protocolError == Http2ProtocolErrorCode.RefusedStream)
             {
-                http2Stream.OnReset(new Http2StreamException(protocolError), canRetry: true);
+                http2Stream.OnReset(new Http2StreamException(protocolError), resetStreamErrorCode: protocolError, canRetry: true);
             }
             else
             {
-                http2Stream.OnReset(new Http2StreamException(protocolError));
+                http2Stream.OnReset(new Http2StreamException(protocolError), resetStreamErrorCode: protocolError);
             }
         }
 
index eff4730..e22db8b 100644 (file)
@@ -47,6 +47,10 @@ namespace System.Net.Http
             private Exception _resetException;
             private bool _canRetry;             // if _resetException != null, this indicates the stream was refused and so the request is retryable
 
+            // This flag indicates that, per section 8.1 of the RFC, the server completed the response and then sent a RST_STREAM with error = NO_ERROR.
+            // This is a signal to stop sending the request body, but the request is still considered successful.
+            private bool _requestBodyAbandoned;
+
             /// <summary>
             /// The core logic for the IValueTaskSource implementation.
             /// 
@@ -116,7 +120,7 @@ namespace System.Net.Http
                     {
                         // Create a TCS for handling Expect: 100-continue semantics. See WaitFor100ContinueAsync.
                         // Note we need to create this in the constructor, because we can receive a 100 Continue response at any time after the constructor finishes.
-                        _expect100ContinueWaiter = new TaskCompletionSource<bool>(TaskContinuationOptions.RunContinuationsAsynchronously);
+                        _expect100ContinueWaiter = new TaskCompletionSource<bool>(TaskCreationOptions.RunContinuationsAsynchronously);
                     }
                 }
 
@@ -176,25 +180,40 @@ namespace System.Net.Http
                 {
                     if (NetEventSource.IsEnabled) Trace($"Failed to send request body: {e}");
 
-                    // Cancel the stream before we set _requestCompletionState below.
-                    // Otherwise, a response stream reader may race with the actual Cancel.
-                    Cancel();
+                    bool signalWaiter = false;
 
+                    Debug.Assert(!Monitor.IsEntered(SyncObject));
                     lock (SyncObject)
                     {
                         Debug.Assert(_requestCompletionState == StreamCompletionState.InProgress, $"Request already completed with state={_requestCompletionState}");
 
-                        _requestCompletionState = StreamCompletionState.Failed;
+                        if (_requestBodyAbandoned)
+                        {
+                            // See comments on _requestBodyAbandoned.
+                            // In this case, the request is still considered successful and we do not want to send a RST_STREAM, 
+                            // and we also don't want to propagate any error to the caller, in particular for non-duplex scenarios.
+                            Debug.Assert(_responseCompletionState == StreamCompletionState.Completed);
+                            _requestCompletionState = StreamCompletionState.Completed;
+                            Complete();
+                            return;
+                        }
 
-                        // Cancel above should ensure that the response is either Completed or Failed now.
-                        Debug.Assert(_responseCompletionState != StreamCompletionState.InProgress);
+                        // This should not cause RST_STREAM to be sent because the request is still marked as in progress.
+                        signalWaiter = CancelResponseBody();
 
+                        _requestCompletionState = StreamCompletionState.Failed;
                         Reset();
                     }
 
+                    if (signalWaiter)
+                    {
+                        _waitSource.SetResult(true);
+                    }
+
                     throw;
                 }
 
+                Debug.Assert(!Monitor.IsEntered(SyncObject));
                 lock (SyncObject)
                 {
                     Debug.Assert(_requestCompletionState == StreamCompletionState.InProgress, $"Request already completed with state={_requestCompletionState}");
@@ -289,6 +308,7 @@ namespace System.Net.Http
                 CancellationTokenSource requestBodyCancellationSource = null;
                 bool signalWaiter = false;
 
+                Debug.Assert(!Monitor.IsEntered(SyncObject));
                 lock (SyncObject)
                 {
                     if (_requestCompletionState == StreamCompletionState.InProgress)
@@ -297,25 +317,7 @@ namespace System.Net.Http
                         Debug.Assert(requestBodyCancellationSource != null);
                     }
 
-                    if (_responseCompletionState == StreamCompletionState.InProgress)
-                    {
-                        _responseCompletionState = StreamCompletionState.Failed;
-                        if (_requestCompletionState != StreamCompletionState.InProgress)
-                        {
-                            Reset();
-                        }
-                    }
-
-                    // Discard any remaining buffered response data
-                    if (_responseBuffer.ActiveLength != 0)
-                    {
-                        _responseBuffer.Discard(_responseBuffer.ActiveLength);
-                    }
-
-                    _responseProtocolState = ResponseProtocolState.Aborted;
-
-                    signalWaiter = _hasWaiter;
-                    _hasWaiter = false;
+                    signalWaiter = CancelResponseBody();
                 }
 
                 if (requestBodyCancellationSource != null)
@@ -330,6 +332,34 @@ namespace System.Net.Http
                 }
             }
 
+            // Returns whether the waiter should be signalled or not.
+            private bool CancelResponseBody()
+            {
+                Debug.Assert(Monitor.IsEntered(SyncObject));
+
+                if (_responseCompletionState == StreamCompletionState.InProgress)
+                {
+                    _responseCompletionState = StreamCompletionState.Failed;
+                    if (_requestCompletionState != StreamCompletionState.InProgress)
+                    {
+                        Reset();
+                    }
+                }
+
+                // Discard any remaining buffered response data
+                if (_responseBuffer.ActiveLength != 0)
+                {
+                    _responseBuffer.Discard(_responseBuffer.ActiveLength);
+                }
+
+                _responseProtocolState = ResponseProtocolState.Aborted;
+
+                bool signalWaiter = _hasWaiter;
+                _hasWaiter = false;
+
+                return signalWaiter;
+            }
+
             public void OnWindowUpdate(int amount) => _streamWindow.AdjustCredit(amount);
 
             public void OnResponseHeader(ReadOnlySpan<byte> name, ReadOnlySpan<byte> value)
@@ -345,6 +375,7 @@ namespace System.Net.Http
 
                 // TODO: ISSUE 31309: Optimize HPACK static table decoding
 
+                Debug.Assert(!Monitor.IsEntered(SyncObject));
                 lock (SyncObject)
                 {
                     if (_responseProtocolState == ResponseProtocolState.Aborted)
@@ -465,6 +496,7 @@ namespace System.Net.Http
 
             public void OnResponseHeadersStart()
             {
+                Debug.Assert(!Monitor.IsEntered(SyncObject));
                 lock (SyncObject)
                 {
                     if (_responseProtocolState == ResponseProtocolState.Aborted)
@@ -487,6 +519,7 @@ namespace System.Net.Http
 
             public void OnResponseHeadersComplete(bool endStream)
             {
+                Debug.Assert(!Monitor.IsEntered(SyncObject));
                 bool signalWaiter;
                 lock (SyncObject)
                 {
@@ -558,6 +591,7 @@ namespace System.Net.Http
 
             public void OnResponseData(ReadOnlySpan<byte> buffer, bool endStream)
             {
+                Debug.Assert(!Monitor.IsEntered(SyncObject));
                 bool signalWaiter;
                 lock (SyncObject)
                 {
@@ -609,10 +643,19 @@ namespace System.Net.Http
                 }
             }
 
-            public void OnReset(Exception resetException, bool canRetry = false)
+            // This is called in several different cases:
+            // (1) Receiving RST_STREAM on this stream. If so, the resetStreamErrorCode will be non-null, and canRetry will be true only if the error code was REFUSED_STREAM.
+            // (2) Receiving GOAWAY that indicates this stream has not been processed. If so, canRetry will be true.
+            // (3) Connection IO failure or protocol violation. If so, resetException will contain the relevant exception and canRetry will be false.
+            // (4) Receiving EOF from the server. If so, resetException will contain an exception like "expected 9 bytes of data", and canRetry will be false.
+            public void OnReset(Exception resetException, Http2ProtocolErrorCode? resetStreamErrorCode = null, bool canRetry = false)
             {
-                if (NetEventSource.IsEnabled) Trace($"{nameof(resetException)}={resetException}");
+                if (NetEventSource.IsEnabled) Trace($"{nameof(resetException)}={resetException}, {nameof(resetStreamErrorCode )}={resetStreamErrorCode}");
 
+                bool cancel = false;
+                CancellationTokenSource requestBodyCancellationSource = null;
+
+                Debug.Assert(!Monitor.IsEntered(SyncObject));
                 lock (SyncObject)
                 {
                     // If we've already finished, don't actually reset the stream.
@@ -639,11 +682,39 @@ namespace System.Net.Http
                         canRetry = false;
                     }
 
-                    _resetException = resetException;
-                    _canRetry = canRetry;
+                    // Per section 8.1 in the RFC:
+                    // If the server has completed the response body (i.e. we've received EndStream)
+                    // but the request body is still sending, and we then receive a RST_STREAM with errorCode = NO_ERROR,
+                    // we treat this specially and simply cancel sending the request body, rather than treating
+                    // the entire request as failed.
+                    if (resetStreamErrorCode == Http2ProtocolErrorCode.NoError && 
+                        _responseCompletionState == StreamCompletionState.Completed)
+                    {
+                        if (_requestCompletionState == StreamCompletionState.InProgress)
+                        {
+                            _requestBodyAbandoned = true;
+                            requestBodyCancellationSource = _requestBodyCancellationSource;
+                            Debug.Assert(requestBodyCancellationSource != null);
+                        }
+                    }
+                    else
+                    {
+                        _resetException = resetException;
+                        _canRetry = canRetry;
+                        cancel = true;
+                    }
                 }
 
-                Cancel();
+                if (requestBodyCancellationSource != null)
+                {
+                    Debug.Assert(_requestBodyAbandoned);
+                    Debug.Assert(!cancel);
+                    requestBodyCancellationSource.Cancel();
+                }
+                else
+                {
+                    Cancel();
+                }
             }
 
             private void CheckResponseBodyState()
@@ -669,6 +740,7 @@ namespace System.Net.Http
             // Determine if we have enough data to process up to complete final response headers.
             private (bool wait, bool isEmptyResponse) TryEnsureHeaders()
             {
+                Debug.Assert(!Monitor.IsEntered(SyncObject));
                 lock (SyncObject)
                 {
                     CheckResponseBodyState();
@@ -765,6 +837,7 @@ namespace System.Net.Http
             {
                 Debug.Assert(buffer.Length > 0);
 
+                Debug.Assert(!Monitor.IsEntered(SyncObject));
                 lock (SyncObject)
                 {
                     CheckResponseBodyState();
@@ -900,6 +973,7 @@ namespace System.Net.Http
             {
                 // Check if the response body has been fully consumed.
                 bool fullyConsumed = false;
+                Debug.Assert(!Monitor.IsEntered(SyncObject));
                 lock (SyncObject)
                 {
                     if (_responseBuffer.ActiveLength == 0 && _responseProtocolState == ResponseProtocolState.Complete)
@@ -950,6 +1024,7 @@ namespace System.Net.Http
                     var thisRef = (Http2Stream)s;
 
                     bool signalWaiter;
+                    Debug.Assert(!Monitor.IsEntered(thisRef.SyncObject));
                     lock (thisRef.SyncObject)
                     {
                         signalWaiter = thisRef._hasWaiter;
index f11ace0..2348613 100644 (file)
@@ -2585,6 +2585,122 @@ namespace System.Net.Http.Functional.Tests
             }
         }
 
+        [Fact]
+        public async Task PostAsyncDuplex_ServerCompletesResponseBodyThenResetsStreamWithNoError_SuccessAndRequestBodyCancelled()
+        {
+            // Per section 8.1 of the RFC:
+            // Receiving RST_STREAM with NO_ERROR after receiving EndStream on the response body is a special case.
+            // We should stop sending the request body, but treat the request as successful and 
+            // return the completed response body to the user.
+
+            byte[] contentBytes = Encoding.UTF8.GetBytes("Hello world");
+
+            using (var server = Http2LoopbackServer.CreateServer())
+            {
+                Http2LoopbackConnection connection;
+                using (HttpClient client = CreateHttpClient())
+                {
+                    var duplexContent = new DuplexContent();
+
+                    var request = new HttpRequestMessage(HttpMethod.Post, server.Address);
+                    request.Version = new Version(2, 0);
+                    request.Content = duplexContent;
+                    Task<HttpResponseMessage> responseTask = client.SendAsync(request, HttpCompletionOption.ResponseHeadersRead);
+
+                    connection = await server.EstablishConnectionAsync();
+
+                    // Client should have sent the request headers, and the request stream should now be available
+                    Stream requestStream = await duplexContent.WaitForStreamAsync();
+
+                    // Flush the content stream. Otherwise, the request headers are not guaranteed to be sent.
+                    await requestStream.FlushAsync();
+
+                    (int streamId, _) = await connection.ReadAndParseRequestHeaderAsync(readBody: false);
+
+                    // Send data to the server, even before we've received response headers.
+                    await SendAndReceiveRequestDataAsync(contentBytes, requestStream, connection, streamId);
+
+                    // Send response headers
+                    await connection.SendResponseHeadersAsync(streamId, endStream: false);
+                    HttpResponseMessage response = await responseTask;
+                    Stream responseStream = await response.Content.ReadAsStreamAsync();
+
+                    // Send response body and complete response
+                    await connection.SendResponseDataAsync(streamId, contentBytes, endStream: true);
+
+                    // Send RST_STREAM to client with error = NO_ERROR.
+                    await connection.WriteFrameAsync(new RstStreamFrame(FrameFlags.None, (int)ProtocolErrors.NO_ERROR, streamId));
+
+                    // Ensure client has processed the RST_STREAM.
+                    await connection.PingPong();
+
+                    // Attempting to write on the request body should now fail with OperationCanceledException.
+                    Exception e = await Assert.ThrowsAnyAsync<OperationCanceledException>(async () => { await SendAndReceiveRequestDataAsync(contentBytes, requestStream, connection, streamId); });
+
+                    // Propagate the exception to the request stream serialization task.
+                    // This allows the request processing to complete.
+                    duplexContent.Fail(e);
+
+                    // We should receive the response body and EOF.
+                    byte[] readBuffer = new byte[contentBytes.Length];
+                    int bytesRead = await responseStream.ReadAsync(readBuffer);
+                    Assert.True(contentBytes.SequenceEqual(readBuffer));
+                    bytesRead = await responseStream.ReadAsync(readBuffer);
+                    Assert.Equal(0, bytesRead);
+                }
+
+                // On handler dispose, client should shutdown the connection without sending additional frames.
+                await connection.WaitForClientDisconnectAsync();
+            }
+        }
+
+        [Fact]
+        public async Task PostAsyncNonDuplex_ServerCompletesResponseBodyThenResetsStreamWithNoError_SuccessAndRequestBodyCancelled()
+        {
+            // Per section 8.1 of the RFC:
+            // Receiving RST_STREAM with NO_ERROR after receiving EndStream on the response body is a special case.
+            // We should stop sending the request body, but treat the request as successful and 
+            // return the completed response body to the user.
+
+            byte[] contentBytes = Encoding.UTF8.GetBytes("Hello world");
+
+            using (var server = Http2LoopbackServer.CreateServer())
+            {
+                Http2LoopbackConnection connection;
+                using (HttpClient client = CreateHttpClient())
+                {
+                    // We want non-duplex content, so use ByteArrayContent,
+                    // but make it large enough to ensure that the content can't be fully sent because of flow control limitations.
+                    // This allows us to validate that the content is actually canceled, not just fully sent and completed.
+                    const int ContentSize = 100_000;
+                    var requestContent = new ByteArrayContent(TestHelper.GenerateRandomContent(ContentSize));
+
+                    var request = new HttpRequestMessage(HttpMethod.Post, server.Address);
+                    request.Version = new Version(2, 0);
+                    request.Content = requestContent;
+                    Task<HttpResponseMessage> responseTask = client.SendAsync(request, HttpCompletionOption.ResponseHeadersRead);
+
+                    connection = await server.EstablishConnectionAsync();
+
+                    (int streamId, _) = await connection.ReadAndParseRequestHeaderAsync(readBody: false);
+
+                    // Send full response
+                    await connection.SendResponseHeadersAsync(streamId, endStream: false);
+                    await connection.SendResponseDataAsync(streamId, contentBytes, endStream: true);
+
+                    // Send RST_STREAM to client with error = NO_ERROR.
+                    await connection.WriteFrameAsync(new RstStreamFrame(FrameFlags.None, (int)ProtocolErrors.NO_ERROR, streamId));
+
+                    // Response should now complete successfully
+                    HttpResponseMessage response = await responseTask;
+                    Assert.Equal("Hello world", await response.Content.ReadAsStringAsync());
+                }
+
+                // On handler dispose, client should shutdown the connection. Ignore any request stream frames already sent.
+                await connection.WaitForClientDisconnectAsync(ignoreUnexpectedFrames: true);
+            }
+        }
+
         [Theory]
         [InlineData(true, HttpStatusCode.Forbidden)]
         [InlineData(false, HttpStatusCode.Forbidden)]