Fix race when receiving HEADERS and RST_STREAM in rapid succession. (#72932)
authorRadek Zikmund <32671551+rzikm@users.noreply.github.com>
Thu, 28 Jul 2022 15:29:01 +0000 (17:29 +0200)
committerGitHub <noreply@github.com>
Thu, 28 Jul 2022 15:29:01 +0000 (17:29 +0200)
* Fix race when receiving HEADERS and RST_STREAM in rapid succession.

* Improve test

src/libraries/Common/tests/System/Net/Http/Http2LoopbackConnection.cs
src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http2Stream.cs
src/libraries/System.Net.Http/tests/FunctionalTests/HttpClientHandlerTest.Http2.cs

index 4ca66fa..e9e7254 100644 (file)
@@ -130,6 +130,19 @@ namespace System.Net.Test.Common
             await _connectionStream.WriteAsync(writeBuffer, 0, writeBuffer.Length, cancellationToken).ConfigureAwait(false);
         }
 
+        public async Task WriteFramesAsync(Frame[] frames, CancellationToken cancellationToken = default)
+        {
+            byte[] writeBuffer = new byte[frames.Sum(frame => Frame.FrameHeaderLength + frame.Length)];
+
+            int offset = 0;
+            foreach (Frame frame in frames)
+            {
+                frame.WriteTo(writeBuffer.AsSpan(offset));
+                offset += Frame.FrameHeaderLength + frame.Length;
+            }
+            await _connectionStream.WriteAsync(writeBuffer, 0, writeBuffer.Length, cancellationToken).ConfigureAwait(false);
+        }
+
         // Read until the buffer is full
         // Return false on EOF, throw on partial read
         private async Task<bool> FillBufferAsync(Memory<byte> buffer, CancellationToken cancellationToken = default(CancellationToken))
@@ -567,7 +580,7 @@ namespace System.Net.Test.Common
                 }
                 else if (frame == null || frame.Type == FrameType.RstStream)
                 {
-                    throw new IOException( frame == null ? "End of stream" : "Got RST");
+                    throw new IOException(frame == null ? "End of stream" : "Got RST");
                 }
 
                 Assert.Equal(FrameType.Data, frame.Type);
@@ -586,7 +599,7 @@ namespace System.Net.Test.Common
 
                         body.CopyTo(newBuffer, 0);
                         dataFrame.Data.Span.CopyTo(newBuffer.AsSpan(body.Length));
-                        body= newBuffer;
+                        body = newBuffer;
                     }
                 }
             }
@@ -947,11 +960,11 @@ namespace System.Net.Test.Common
 
             if (string.IsNullOrEmpty(content))
             {
-                await SendResponseHeadersAsync(streamId, endStream: true, statusCode, isTrailingHeader: false, headers : headers).ConfigureAwait(false);
+                await SendResponseHeadersAsync(streamId, endStream: true, statusCode, isTrailingHeader: false, headers: headers).ConfigureAwait(false);
             }
             else
             {
-                await SendResponseHeadersAsync(streamId, endStream: false, statusCode, isTrailingHeader: false, headers : headers).ConfigureAwait(false);
+                await SendResponseHeadersAsync(streamId, endStream: false, statusCode, isTrailingHeader: false, headers: headers).ConfigureAwait(false);
                 await SendResponseBodyAsync(streamId, Encoding.ASCII.GetBytes(content)).ConfigureAwait(false);
             }
 
index b3f8277..350e8d1 100644 (file)
@@ -45,6 +45,7 @@ namespace System.Net.Http
             private StreamCompletionState _requestCompletionState;
             private StreamCompletionState _responseCompletionState;
             private ResponseProtocolState _responseProtocolState;
+            private bool _responseHeadersReceived;
             private bool _webSocketEstablished;
 
             // If this is not null, then we have received a reset from the server
@@ -775,6 +776,7 @@ namespace System.Net.Http
 
                         case ResponseProtocolState.ExpectingHeaders:
                             _responseProtocolState = endStream ? ResponseProtocolState.Complete : ResponseProtocolState.ExpectingData;
+                            _responseHeadersReceived = true;
                             break;
 
                         case ResponseProtocolState.ExpectingTrailingHeaders:
@@ -988,24 +990,16 @@ namespace System.Net.Http
                 Debug.Assert(!Monitor.IsEntered(SyncObject));
                 lock (SyncObject)
                 {
-                    CheckResponseBodyState();
-
-                    if (_responseProtocolState == ResponseProtocolState.ExpectingHeaders || _responseProtocolState == ResponseProtocolState.ExpectingIgnoredHeaders || _responseProtocolState == ResponseProtocolState.ExpectingStatus)
+                    if (!_responseHeadersReceived)
                     {
+                        CheckResponseBodyState();
                         Debug.Assert(!_hasWaiter);
                         _hasWaiter = true;
                         _waitSource.Reset();
                         return (true, false);
                     }
-                    else if (_responseProtocolState == ResponseProtocolState.ExpectingData || _responseProtocolState == ResponseProtocolState.ExpectingTrailingHeaders)
-                    {
-                        return (false, false);
-                    }
-                    else
-                    {
-                        Debug.Assert(_responseProtocolState == ResponseProtocolState.Complete);
-                        return (false, _responseBuffer.IsEmpty);
-                    }
+
+                    return (false, _responseProtocolState == ResponseProtocolState.Complete && _responseBuffer.IsEmpty);
                 }
             }
 
index 4349c6c..7b14f99 100644 (file)
@@ -329,6 +329,31 @@ namespace System.Net.Http.Functional.Tests
         }
 
         [ConditionalFact(nameof(SupportsAlpn))]
+        public async Task Http2_StreamResetByServerAfterHeadersSent_ResponseHeadersRead_ContentThrows()
+        {
+            using (Http2LoopbackServer server = Http2LoopbackServer.CreateServer())
+            using (HttpClient client = CreateHttpClient())
+            {
+                Task<HttpResponseMessage> sendTask = client.GetAsync(server.Address, HttpCompletionOption.ResponseHeadersRead);
+
+                Http2LoopbackConnection connection = await server.EstablishConnectionAsync();
+                int streamId = await connection.ReadRequestHeaderAsync();
+
+                // Send response headers and RST_STREAM combined
+                await connection.WriteFramesAsync(new Frame[] {
+                    new HeadersFrame(new byte[] { 0x88 /* :status: 200 */}, FrameFlags.EndHeaders, 0, 0, 0, streamId),
+                    new RstStreamFrame(FrameFlags.None, (int)ProtocolErrors.NO_ERROR, streamId)
+                });
+
+                // Headers should be received successfully
+                HttpResponseMessage response = await sendTask;
+
+                // Reading the actual content should throw
+                await AssertHttpProtocolException((await response.Content.ReadAsStreamAsync()).ReadAsync(new byte[10]).AsTask(), ProtocolErrors.NO_ERROR);
+            }
+        }
+
+        [ConditionalFact(nameof(SupportsAlpn))]
         public async Task Http2_StreamResetByServerAfterPartialBodySent_RequestFails()
         {
             using (Http2LoopbackServer server = Http2LoopbackServer.CreateServer())
@@ -2021,7 +2046,7 @@ namespace System.Net.Http.Functional.Tests
                 using (HttpClient client = CreateHttpClient())
                 {
                     var request = new HttpRequestMessage(HttpMethod.Post, url);
-                    request.Version = new Version(2,0);
+                    request.Version = new Version(2, 0);
                     request.Content = new CustomContent(stream);
 
                     await Assert.ThrowsAnyAsync<OperationCanceledException>(async () => await client.SendAsync(request, cts.Token));
@@ -2031,7 +2056,7 @@ namespace System.Net.Http.Functional.Tests
 
                     // Send another request to verify that connection is still functional.
                     request = new HttpRequestMessage(HttpMethod.Get, url);
-                    request.Version = new Version(2,0);
+                    request.Version = new Version(2, 0);
 
                     await client.SendAsync(request);
                 }
@@ -2040,12 +2065,13 @@ namespace System.Net.Http.Functional.Tests
             {
                 Http2LoopbackConnection connection = await server.EstablishConnectionAsync();
 
-                (int streamId, HttpRequestData requestData) = await connection.ReadAndParseRequestHeaderAsync(readBody : false);
+                (int streamId, HttpRequestData requestData) = await connection.ReadAndParseRequestHeaderAsync(readBody: false);
                 int frameCount = 0;
                 Frame frame;
                 do
                 {
-                    if (frameCount == (waitForData ? 1 : 0)) {
+                    if (frameCount == (waitForData ? 1 : 0))
+                    {
                         // Cancel client after receiving Headers or part of request body.
                         cts.Cancel();
                     }
@@ -2053,7 +2079,7 @@ namespace System.Net.Http.Functional.Tests
                     Assert.NotNull(frame); // We should get Rst before closing connection.
                     Assert.Equal(0, (int)(frame.Flags & FrameFlags.EndStream));
                     frameCount++;
-                 } while (frame.Type != FrameType.RstStream);
+                } while (frame.Type != FrameType.RstStream);
 
                 Assert.Equal(1, frame.StreamId);
 
@@ -2089,7 +2115,7 @@ namespace System.Net.Http.Functional.Tests
                 await Http2LoopbackServer.CreateClientAndServerAsync(async url =>
                 {
                     var request = new HttpRequestMessage(HttpMethod.Get, url);
-                    request.Version = new Version(2,0);
+                    request.Version = new Version(2, 0);
 
                     response = await client.SendAsync(request, HttpCompletionOption.ResponseHeadersRead, cts.Token);
                     using (Stream stream = await response.Content.ReadAsStreamAsync())
@@ -2123,7 +2149,7 @@ namespace System.Net.Http.Functional.Tests
                 {
                     Http2LoopbackConnection connection = await server.EstablishConnectionAsync();
 
-                    (int streamId, HttpRequestData requestData) = await connection.ReadAndParseRequestHeaderAsync(readBody : false);
+                    (int streamId, HttpRequestData requestData) = await connection.ReadAndParseRequestHeaderAsync(readBody: false);
                     _output.WriteLine($"{DateTime.Now} Connection established");
 
                     await connection.SendResponseHeadersAsync(streamId, endStream: false, HttpStatusCode.OK);
@@ -2213,7 +2239,7 @@ namespace System.Net.Http.Functional.Tests
                 using (HttpClient client = CreateHttpClient())
                 {
                     var request = new HttpRequestMessage(HttpMethod.Post, url);
-                    request.Version = new Version(2,0);
+                    request.Version = new Version(2, 0);
                     request.Content = new StringContent(new string('*', 3000));
                     request.Headers.ExpectContinue = true;
                     request.Headers.Add("x-test", $"PostAsyncExpect100Continue_SendRequest_Ok({send100Continue}");
@@ -2226,7 +2252,7 @@ namespace System.Net.Http.Functional.Tests
             {
                 Http2LoopbackConnection connection = await server.EstablishConnectionAsync();
 
-                (int streamId, HttpRequestData requestData) = await connection.ReadAndParseRequestHeaderAsync(readBody : false);
+                (int streamId, HttpRequestData requestData) = await connection.ReadAndParseRequestHeaderAsync(readBody: false);
                 Assert.Equal("100-continue", requestData.GetSingleHeaderValue("Expect"));
 
                 if (send100Continue)
@@ -2254,7 +2280,7 @@ namespace System.Net.Http.Functional.Tests
                     handler.Expect100ContinueTimeout = TimeSpan.FromSeconds(300);
 
                     var request = new HttpRequestMessage(HttpMethod.Post, url);
-                    request.Version = new Version(2,0);
+                    request.Version = new Version(2, 0);
                     request.VersionPolicy = HttpVersionPolicy.RequestVersionExact;
                     request.Content = new StringContent(new string('*', 3000));
                     request.Headers.ExpectContinue = true;
@@ -2269,7 +2295,7 @@ namespace System.Net.Http.Functional.Tests
             {
                 Http2LoopbackConnection connection = await server.EstablishConnectionAsync();
 
-                (int streamId, HttpRequestData requestData) = await connection.ReadAndParseRequestHeaderAsync(readBody : false);
+                (int streamId, HttpRequestData requestData) = await connection.ReadAndParseRequestHeaderAsync(readBody: false);
                 Assert.Equal("100-continue", requestData.GetSingleHeaderValue("Expect"));
 
                 // Reject content with 403.
@@ -2277,7 +2303,7 @@ namespace System.Net.Http.Functional.Tests
                 await connection.SendResponseBodyAsync(streamId, Encoding.ASCII.GetBytes(responseContent));
 
                 // Client should send empty request body
-                byte[] requestBody = await connection.ReadBodyAsync(expectEndOfStream:true);
+                byte[] requestBody = await connection.ReadBodyAsync(expectEndOfStream: true);
                 Assert.Null(requestBody);
 
                 await connection.ShutdownIgnoringErrorsAsync(streamId);
@@ -3136,12 +3162,12 @@ namespace System.Net.Http.Functional.Tests
                     Task<HttpResponseMessage> responseTask = client.SendAsync(request, HttpCompletionOption.ResponseHeadersRead);
                     connection = await server.EstablishConnectionAsync();
 
-                     // Client should have sent the request headers, and the request stream should now be available
+                    // Client should have sent the request headers, and the request stream should now be available
                     Stream requestStream = await duplexContent.WaitForStreamAsync();
                     // Flush the content stream. Otherwise, the request headers are not guaranteed to be sent.
                     await requestStream.FlushAsync();
 
-                    (int streamId, HttpRequestData requestData) = await connection.ReadAndParseRequestHeaderAsync(readBody : false);
+                    (int streamId, HttpRequestData requestData) = await connection.ReadAndParseRequestHeaderAsync(readBody: false);
 
                     // Client finished sending request headers and we received them.
                     // Send request body.
@@ -3170,7 +3196,7 @@ namespace System.Net.Http.Functional.Tests
 
                     // Send trailing headers for good measure and close stream.
                     var headers = new HttpHeaderData[] { new HttpHeaderData("x-last", "done") };
-                    await connection.SendResponseHeadersAsync(streamId, endStream: true, isTrailingHeader : true, headers: headers);
+                    await connection.SendResponseHeadersAsync(streamId, endStream: true, isTrailingHeader: true, headers: headers);
 
                     // Finish reading response body and verify it for all cases.
                     string responseBody = await response.Content.ReadAsStringAsync();
@@ -3189,7 +3215,7 @@ namespace System.Net.Http.Functional.Tests
             TaskCompletionSource<bool> tsc = new TaskCompletionSource<bool>();
             string requestContent = new string('*', 300);
             const string responseContent = "SendAsync_ConcurentSendReceive_Fail";
-            var stream = new CustomContent.SlowTestStream(Encoding.UTF8.GetBytes(requestContent), tsc, trigger : 1, count : 50);
+            var stream = new CustomContent.SlowTestStream(Encoding.UTF8.GetBytes(requestContent), tsc, trigger: 1, count: 50);
             bool stopSending = false;
 
             await Http2LoopbackServer.CreateClientAndServerAsync(async url =>
@@ -3197,7 +3223,7 @@ namespace System.Net.Http.Functional.Tests
                 using (HttpClient client = CreateHttpClient())
                 {
                     var request = new HttpRequestMessage(HttpMethod.Post, url);
-                    request.Version = new Version(2,0);
+                    request.Version = new Version(2, 0);
                     request.Content = new CustomContent(stream);
 
                     // This should fail either while getting response headers or while reading response body.
@@ -3217,7 +3243,7 @@ namespace System.Net.Http.Functional.Tests
             {
                 Http2LoopbackConnection connection = await server.EstablishConnectionAsync();
 
-                (int streamId, HttpRequestData requestData) = await connection.ReadAndParseRequestHeaderAsync(readBody : false);
+                (int streamId, HttpRequestData requestData) = await connection.ReadAndParseRequestHeaderAsync(readBody: false);
                 await connection.SendResponseHeadersAsync(streamId, endStream: false, HttpStatusCode.OK);
 
                 // Wait for client so start sending body.
@@ -3226,7 +3252,7 @@ namespace System.Net.Http.Functional.Tests
                 int maxCount = 120;
                 while (!stopSending && maxCount != 0)
                 {
-                   try
+                    try
                     {
                         await connection.SendResponseDataAsync(streamId, Encoding.ASCII.GetBytes(responseContent), endStream: false);
                     }
@@ -3237,7 +3263,7 @@ namespace System.Net.Http.Functional.Tests
                         break;
                     }
                     await Task.Delay(500);
-                    maxCount --;
+                    maxCount--;
                 }
                 // We should not reach retry limit without failing.
                 Assert.NotEqual(0, maxCount);
@@ -3349,7 +3375,7 @@ namespace System.Net.Http.Functional.Tests
 
                 Http2LoopbackConnection connection = await server.EstablishConnectionAsync();
                 int streamId = await connection.ReadRequestHeaderAsync();
-                await connection.SendResponseHeadersAsync(streamId, endStream : true, headers: headers);
+                await connection.SendResponseHeadersAsync(streamId, endStream: true, headers: headers);
                 await Assert.ThrowsAsync<HttpRequestException>(() => sendTask);
             }
         }
@@ -3366,7 +3392,7 @@ namespace System.Net.Http.Functional.Tests
 
                 Http2LoopbackConnection connection = await server.EstablishConnectionAsync();
                 int streamId = await connection.ReadRequestHeaderAsync();
-                await connection.SendResponseHeadersAsync(streamId, endStream : true, isTrailingHeader : true, headers: headers);
+                await connection.SendResponseHeadersAsync(streamId, endStream: true, isTrailingHeader: true, headers: headers);
 
                 await Assert.ThrowsAsync<HttpRequestException>(() => sendTask);
             }
@@ -3386,7 +3412,7 @@ namespace System.Net.Http.Functional.Tests
                 int streamId = await connection.ReadRequestHeaderAsync();
                 await connection.SendDefaultResponseHeadersAsync(streamId);
                 await connection.SendResponseDataAsync(streamId, "hello"u8.ToArray(), endStream: false);
-                await connection.SendResponseHeadersAsync(streamId, endStream : true, isTrailingHeader : true, headers: headers);
+                await connection.SendResponseHeadersAsync(streamId, endStream: true, isTrailingHeader: true, headers: headers);
 
                 await Assert.ThrowsAsync<HttpRequestException>(() => sendTask);
             }