add basic GOAWAY tests and connection shutdown handling
authorGeoff Kizer <geoffrek>
Thu, 17 Jan 2019 11:11:35 +0000 (03:11 -0800)
committerGeoff Kizer <geoffrek>
Wed, 6 Feb 2019 22:01:45 +0000 (14:01 -0800)
Commit migrated from https://github.com/dotnet/corefx/commit/449a47ba2c6dbb568692d554787405dde7d14d8e

src/libraries/Common/tests/System/Net/Http/Http2LoopbackServer.cs
src/libraries/System.Net.Http/tests/FunctionalTests/HttpClientHandlerTest.Http2.cs

index 2eaba03..3a3aff2 100644 (file)
@@ -17,6 +17,7 @@ namespace System.Net.Test.Common
     public class Http2LoopbackServer : IDisposable
     {
         private Socket _listenSocket;
+        private Socket _connectionSocket;
         private Stream _connectionStream;
         private Http2Options _options;
         private Uri _uri;
@@ -80,6 +81,31 @@ namespace System.Net.Test.Common
             await _connectionStream.WriteAsync(writeBuffer, 0, writeBuffer.Length).ConfigureAwait(false);
         }
 
+        // Read until the buffer is full
+        // Return false on EOF, throw on partial read
+        private async Task<bool> FillBufferAsync(Memory<byte> buffer, CancellationToken cancellationToken = default(CancellationToken))
+        {
+            int readBytes = await _connectionStream.ReadAsync(buffer, cancellationToken).ConfigureAwait(false);
+            if (readBytes == 0)
+            {
+                return false;
+            }
+
+            buffer = buffer.Slice(readBytes);
+            while (buffer.Length > 0)
+            {
+                readBytes = await _connectionStream.ReadAsync(buffer, cancellationToken).ConfigureAwait(false);
+                if (readBytes == 0)
+                {
+                    throw new Exception("Connection closed when expecting more data.");
+                }
+
+                buffer = buffer.Slice(readBytes);
+            }
+
+            return true;
+        }
+
         public async Task<Frame> ReadFrameAsync(TimeSpan timeout)
         {
             // Prep the timeout cancellation token.
@@ -87,32 +113,18 @@ namespace System.Net.Test.Common
 
             // First read the frame headers, which should tell us how long the rest of the frame is.
             byte[] headerBytes = new byte[Frame.FrameHeaderLength];
-
-            int totalReadBytes = 0;
-            while(totalReadBytes < Frame.FrameHeaderLength)
+            if (!await FillBufferAsync(headerBytes, timeoutCts.Token).ConfigureAwait(false))
             {
-                int readBytes = await _connectionStream.ReadAsync(headerBytes, totalReadBytes, Frame.FrameHeaderLength - totalReadBytes, timeoutCts.Token).ConfigureAwait(false);
-                totalReadBytes += readBytes;
-                if (readBytes == 0)
-                {
-                    throw new Exception("Connection stream closed while attempting to read frame header.");
-                }
+                return null;
             }
 
             Frame header = Frame.ReadFrom(headerBytes);
 
             // Read the data segment of the frame, if it is present.
             byte[] data = new byte[header.Length];
-
-            totalReadBytes = 0;
-            while(totalReadBytes < header.Length)
+            if (header.Length > 0 && !await FillBufferAsync(data, timeoutCts.Token).ConfigureAwait(false))
             {
-                int readBytes = await _connectionStream.ReadAsync(data, totalReadBytes, header.Length - totalReadBytes, timeoutCts.Token).ConfigureAwait(false);
-                totalReadBytes += readBytes;
-                if (readBytes == 0)
-                {
-                    throw new Exception("Connection stream closed while attempting to read frame body.");
-                }
+                throw new Exception("Connection stream closed while attempting to read frame body.");
             }
 
             if (_ignoreSettingsAck && header.Type == FrameType.Settings && header.Flags == FrameFlags.Ack)
@@ -149,7 +161,13 @@ namespace System.Net.Test.Common
         // Returns the first 24 bytes read, which should be the connection preface.
         public async Task<string> AcceptConnectionAsync()
         {
-            _connectionStream = new NetworkStream(await _listenSocket.AcceptAsync().ConfigureAwait(false), true);
+            if (_connectionSocket != null)
+            {
+                throw new InvalidOperationException("Connection already established");
+            }
+
+            _connectionSocket = await _listenSocket.AcceptAsync().ConfigureAwait(false);
+            _connectionStream = new NetworkStream(_connectionSocket, true);
 
             if (_options.UseSsl)
             {
@@ -174,17 +192,11 @@ namespace System.Net.Test.Common
                 }
                 _connectionStream = sslStream;
             }
+
             byte[] prefix = new byte[24];
-            
-            int totalReadBytes = 0;
-            while(totalReadBytes < Frame.FrameHeaderLength)
+            if (!await FillBufferAsync(prefix).ConfigureAwait(false))
             {
-                int readBytes = await _connectionStream.ReadAsync(prefix, totalReadBytes, prefix.Length).ConfigureAwait(false);;
-                totalReadBytes += readBytes;
-                if (readBytes == 0)
-                {
-                    throw new Exception("Connection stream closed while attempting to read connection preface.");
-                }
+                throw new Exception("Connection stream closed while attempting to read connection preface.");
             }
 
             return System.Text.Encoding.UTF8.GetString(prefix, 0, prefix.Length);
@@ -234,6 +246,32 @@ namespace System.Net.Test.Common
             ExpectSettingsAck();
         }
 
+        // This will wait for the client to close the connection,
+        // and ignore any meaningless frames -- i.e. WINDOW_UPDATE or expected SETTINGS ACK --
+        // that we see while waiting for the client to close.
+        // Only call this after sending a GOAWAY.
+        public async Task WaitForConnectionShutdownAsync()
+        {
+            IgnoreWindowUpdates();
+
+            // Shutdown our send side, so the client knows there won't be any more frames coming.
+            _connectionSocket.Shutdown(SocketShutdown.Send);
+
+            Frame frame = await ReadFrameAsync(TimeSpan.FromSeconds(30)).ConfigureAwait(false);
+            if (frame != null)
+            {
+                throw new Exception($"Unexpected frame received while waiting for client shutdown: {frame}");
+            }
+
+            _connectionStream.Close();
+
+            _connectionSocket = null;
+            _connectionStream = null;
+
+            _ignoreSettingsAck = false;
+            _ignoreWindowUpdates = false;
+        }
+
         public async Task<int> ReadRequestHeaderAsync()
         {
             // Receive HEADERS frame for request.
@@ -259,6 +297,12 @@ namespace System.Net.Test.Common
             await WriteFrameAsync(headersFrame).ConfigureAwait(false);
         }
 
+        public async Task SendResponseBodyAsync(int streamId, byte[] data, bool endStream = false)
+        {
+            DataFrame dataFrame = new DataFrame(data, endStream ? FrameFlags.EndStream : FrameFlags.None, 0, streamId);
+            await WriteFrameAsync(dataFrame).ConfigureAwait(false);
+        }
+
         public void Dispose()
         {
             if (_listenSocket != null)
index 8fe1e8d..9c7de43 100644 (file)
@@ -517,6 +517,110 @@ namespace System.Net.Http.Functional.Tests
             }
         }
 
+        private static async Task<int> EstablishConnectionAndProcessOneRequestAsync(HttpClient client, Http2LoopbackServer server)
+        {
+            // Establish connection and send initial request/response to ensure connection is available for subsequent use
+            Task<HttpResponseMessage> sendTask = client.GetAsync(server.Address);
+
+            await server.EstablishConnectionAsync();
+
+            int streamId = await server.ReadRequestHeaderAsync();
+            await server.SendDefaultResponseAsync(streamId);
+
+            HttpResponseMessage response = await sendTask;
+            Assert.Equal(HttpStatusCode.OK, response.StatusCode);
+            Assert.Equal(0, (await response.Content.ReadAsByteArrayAsync()).Length);
+
+            return streamId;
+        }
+
+        [ConditionalFact(nameof(SupportsAlpn))]
+        public async Task GoAwayFrame_NoPendingStreams_ConnectionClosed()
+        {
+            HttpClientHandler handler = CreateHttpClientHandler();
+            handler.ServerCertificateCustomValidationCallback = TestHelper.AllowAllCertificates;
+
+            using (var server = Http2LoopbackServer.CreateServer())
+            using (var client = new HttpClient(handler))
+            {
+                int streamId = await EstablishConnectionAndProcessOneRequestAsync(client, server);
+
+                // Send GOAWAY.
+                GoAwayFrame goAwayFrame = new GoAwayFrame(streamId, 0, new byte[0], 0);
+                await server.WriteFrameAsync(goAwayFrame);
+
+                // The client should close the connection.
+                await server.WaitForConnectionShutdownAsync();
+
+                // New request should cause a new connection
+                await EstablishConnectionAndProcessOneRequestAsync(client, server);
+            }
+        }
+
+        [ConditionalFact(nameof(SupportsAlpn))]
+        public async Task GoAwayFrame_AllPendingStreamsValid_RequestsSucceedAndConnectionClosed()
+        {
+            HttpClientHandler handler = CreateHttpClientHandler();
+            handler.ServerCertificateCustomValidationCallback = TestHelper.AllowAllCertificates;
+
+            using (var server = Http2LoopbackServer.CreateServer())
+            using (var client = new HttpClient(handler))
+            {
+                await EstablishConnectionAndProcessOneRequestAsync(client, server);
+
+                // Issue three requests
+                Task<HttpResponseMessage> sendTask1 = client.GetAsync(server.Address);
+                Task<HttpResponseMessage> sendTask2 = client.GetAsync(server.Address);
+                Task<HttpResponseMessage> sendTask3 = client.GetAsync(server.Address);
+
+                // Receive three requests
+                int streamId1 = await server.ReadRequestHeaderAsync();
+                int streamId2 = await server.ReadRequestHeaderAsync();
+                int streamId3 = await server.ReadRequestHeaderAsync();
+
+                Assert.True(streamId1 < streamId2);
+                Assert.True(streamId2 < streamId3);
+
+                // Send various partial responses
+
+                // First response: Don't send anything yet
+
+                // Second response: Send headers, no body yet
+                await server.SendDefaultResponseHeadersAsync(streamId2);
+
+                // Third response: Send headers, partial body
+                await server.SendDefaultResponseHeadersAsync(streamId3);
+                await server.SendResponseBodyAsync(streamId3, new byte[5], endStream: false);
+
+                // Send a GOAWAY frame that indicates that we will process all three streams
+                GoAwayFrame goAwayFrame = new GoAwayFrame(streamId3, 0, new byte[0], 0);
+                await server.WriteFrameAsync(goAwayFrame);
+
+                // Finish sending responses
+                await server.SendDefaultResponseHeadersAsync(streamId1);
+                await server.SendResponseBodyAsync(streamId1, new byte[10], endStream: true);
+                await server.SendResponseBodyAsync(streamId2, new byte[10], endStream: true);
+                await server.SendResponseBodyAsync(streamId3, new byte[5], endStream: true);
+
+                // Receive all responses
+                HttpResponseMessage response1 = await sendTask1;
+                Assert.Equal(HttpStatusCode.OK, response1.StatusCode);
+                Assert.Equal(10, (await response1.Content.ReadAsByteArrayAsync()).Length);
+                HttpResponseMessage response2 = await sendTask2;
+                Assert.Equal(HttpStatusCode.OK, response2.StatusCode);
+                Assert.Equal(10, (await response2.Content.ReadAsByteArrayAsync()).Length);
+                HttpResponseMessage response3 = await sendTask3;
+                Assert.Equal(HttpStatusCode.OK, response3.StatusCode);
+                Assert.Equal(10, (await response3.Content.ReadAsByteArrayAsync()).Length);
+
+                // Now that all pending responses have been sent, the client should close the connection.
+                await server.WaitForConnectionShutdownAsync();
+
+                // New request should cause a new connection
+                await EstablishConnectionAndProcessOneRequestAsync(client, server);
+            }
+        }
+
         private static async Task<int> ReadToEndOfStream(Http2LoopbackServer server, int streamId)
         {
             int bytesReceived = 0;