Fix Http2 MultiConnection test race conditions (#91343)
authorgithub-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Wed, 30 Aug 2023 19:40:56 +0000 (12:40 -0700)
committerGitHub <noreply@github.com>
Wed, 30 Aug 2023 19:40:56 +0000 (12:40 -0700)
Co-authored-by: Miha Zupan <mihazupan.zupan1@gmail.com>
src/libraries/System.Net.Http/tests/FunctionalTests/HttpClientHandlerTestBase.SocketsHttpHandler.cs
src/libraries/System.Net.Http/tests/FunctionalTests/MetricsTest.cs
src/libraries/System.Net.Http/tests/FunctionalTests/SocketsHttpHandlerTest.Cancellation.cs
src/libraries/System.Net.Http/tests/FunctionalTests/SocketsHttpHandlerTest.cs

index 602d177..888b38b 100644 (file)
@@ -3,14 +3,31 @@
 
 using System.IO;
 using System.Net.Quic;
+using System.Net.Sockets;
 using System.Net.Test.Common;
 using System.Reflection;
+using System.Threading;
 using System.Threading.Tasks;
 
 namespace System.Net.Http.Functional.Tests
 {
     public abstract partial class HttpClientHandlerTestBase : FileCleanupTestBase
     {
+        protected static async Task<Stream> DefaultConnectCallback(EndPoint endPoint, CancellationToken cancellationToken)
+        {
+            Socket socket = new Socket(SocketType.Stream, ProtocolType.Tcp) { NoDelay = true };
+            try
+            {
+                await socket.ConnectAsync(endPoint, cancellationToken);
+                return new NetworkStream(socket, ownsSocket: true);
+            }
+            catch
+            {
+                socket.Dispose();
+                throw;
+            }
+        }
+
         protected static bool IsWinHttpHandler => false;
 
         public static bool IsQuicSupported
index f3a97d2..da36366 100644 (file)
@@ -292,17 +292,8 @@ namespace System.Net.Http.Functional.Tests
                 GetUnderlyingSocketsHttpHandler(Handler).ConnectCallback = async (ctx, cancellationToken) =>
                 {
                     connectionStarted.SetResult();
-                    Socket socket = new Socket(SocketType.Stream, ProtocolType.Tcp) { NoDelay = true };
-                    try
-                    {
-                        await socket.ConnectAsync(ctx.DnsEndPoint, cancellationToken);
-                        return new NetworkStream(socket, ownsSocket: true);
-                    }
-                    catch
-                    {
-                        socket.Dispose();
-                        throw;
-                    }
+
+                    return await DefaultConnectCallback(ctx.DnsEndPoint, cancellationToken);
                 };
 
                 // Enable recording request-duration to test the path with metrics enabled.
index c793a1d..76d7086 100644 (file)
@@ -165,9 +165,7 @@ namespace System.Net.Http.Functional.Tests
                         else
                         {
                             // Succeed the second connection attempt
-                            Socket socket = new Socket(SocketType.Stream, ProtocolType.Tcp) { NoDelay = true };
-                            await socket.ConnectAsync(context.DnsEndPoint, token);
-                            return new NetworkStream(socket, ownsSocket: true);
+                            return await DefaultConnectCallback(context.DnsEndPoint, token);
                         }
                     };
 
index e25f695..2613b45 100644 (file)
@@ -1369,17 +1369,7 @@ namespace System.Net.Http.Functional.Tests
                 {
                     Assert.Equal("foo", context.DnsEndPoint.Host);
 
-                    Socket socket = new Socket(SocketType.Stream, ProtocolType.Tcp) { NoDelay = true };
-                    try
-                    {
-                        await socket.ConnectAsync(lastServerUri.IdnHost, lastServerUri.Port);
-                        return new NetworkStream(socket, ownsSocket: true);
-                    }
-                    catch
-                    {
-                        socket.Dispose();
-                        throw;
-                    }
+                    return await DefaultConnectCallback(new DnsEndPoint(lastServerUri.IdnHost, lastServerUri.Port), ct);
                 };
 
                 TaskCompletionSource waitingForLastRequest = new(TaskCreationOptions.RunContinuationsAsynchronously);
@@ -2659,30 +2649,18 @@ namespace System.Net.Http.Functional.Tests
 
             AcquireAllStreamSlots(server, client, sendTasks, RequestCount);
 
-            List<(Http2LoopbackConnection connection, int streamId)> acceptedRequests = new();
-
             await using Http2LoopbackConnection c1 = await server.EstablishConnectionAsync(new SettingsEntry { SettingId = SettingId.MaxConcurrentStreams, Value = 100 });
-            for (int i = 0; i < MaxConcurrentStreams; i++)
-            {
-                (int streamId, _) = await c1.ReadAndParseRequestHeaderAsync();
-                acceptedRequests.Add((c1, streamId));
-            }
+            int[] streamIds1 = await AcceptRequests(c1, MaxConcurrentStreams);
 
             await using Http2LoopbackConnection c2 = await server.EstablishConnectionAsync(new SettingsEntry { SettingId = SettingId.MaxConcurrentStreams, Value = 100 });
-            for (int i = 0; i < MaxConcurrentStreams; i++)
-            {
-                (int streamId, _) = await c2.ReadAndParseRequestHeaderAsync();
-                acceptedRequests.Add((c2, streamId));
-            }
+            int[] streamIds2 = await AcceptRequests(c2, MaxConcurrentStreams);
 
             await using Http2LoopbackConnection c3 = await server.EstablishConnectionAsync(new SettingsEntry { SettingId = SettingId.MaxConcurrentStreams, Value = 100 });
             (int finalStreamId, _) = await c3.ReadAndParseRequestHeaderAsync();
-            acceptedRequests.Add((c3, finalStreamId));
 
-            foreach ((Http2LoopbackConnection connection, int streamId) request in acceptedRequests)
-            {
-                await request.connection.SendDefaultResponseAsync(request.streamId);
-            }
+            await SendResponses(c1, streamIds1);
+            await SendResponses(c2, streamIds2);
+            await c3.SendDefaultResponseAsync(finalStreamId);
 
             await VerifySendTasks(sendTasks);
         }
@@ -2702,19 +2680,17 @@ namespace System.Net.Http.Functional.Tests
             Http2LoopbackConnection connection0 = await PrepareConnection(server, client, MaxConcurrentStreams).ConfigureAwait(false);
             AcquireAllStreamSlots(server, client, sendTasks, MaxConcurrentStreams);
 
-            // Block the first connection on infinite requests.
+            // Accept requests but don't send responses on connection 0
             int[] blockedStreamIds = await AcceptRequests(connection0, MaxConcurrentStreams).ConfigureAwait(false);
-            Assert.Equal(MaxConcurrentStreams, blockedStreamIds.Length);
 
             Http2LoopbackConnection connection1 = await PrepareConnection(server, client, MaxConcurrentStreams).ConfigureAwait(false);
             AcquireAllStreamSlots(server, client, sendTasks, MaxConcurrentStreams);
 
-            await HandleAllPendingRequests(connection1, MaxConcurrentStreams).ConfigureAwait(false);
+            // Send responses on connection 1
+            await SendResponses(connection1, await AcceptRequests(connection1, MaxConcurrentStreams).ConfigureAwait(false));
 
-            // Complete infinite requests.
-            int handledRequestCount = await SendResponses(connection0, blockedStreamIds);
-
-            Assert.Equal(MaxConcurrentStreams, handledRequestCount);
+            // Send responses on connection 0
+            await SendResponses(connection0, blockedStreamIds);
 
             await VerifySendTasks(sendTasks).ConfigureAwait(false);
         }
@@ -2729,44 +2705,62 @@ namespace System.Net.Http.Functional.Tests
 
             const int MaxConcurrentStreams = 2;
             using Http2LoopbackServer server = Http2LoopbackServer.CreateServer();
+            server.AllowMultipleConnections = true;
+
+            // Allow 5 connections through the ConnectCallback.
+            SemaphoreSlim connectCallbackSemaphore = new(initialCount: 5);
+
             using SocketsHttpHandler handler = CreateHandler();
+
+            handler.ConnectCallback = async (context, ct) =>
+            {
+                await connectCallbackSemaphore.WaitAsync(ct);
+
+                return await DefaultConnectCallback(context.DnsEndPoint, ct);
+            };
+
             using (HttpClient client = CreateHttpClient(handler))
             {
-                server.AllowMultipleConnections = true;
-                List<Task<HttpResponseMessage>> sendTasks = new List<Task<HttpResponseMessage>>();
+                List<Task<HttpResponseMessage>> sendTasks = new();
+
                 Http2LoopbackConnection connection0 = await PrepareConnection(server, client, MaxConcurrentStreams).ConfigureAwait(false);
                 AcquireAllStreamSlots(server, client, sendTasks, MaxConcurrentStreams);
+                int[] streamIds0 = await AcceptRequests(connection0, MaxConcurrentStreams).ConfigureAwait(false);
+
                 Http2LoopbackConnection connection1 = await PrepareConnection(server, client, MaxConcurrentStreams).ConfigureAwait(false);
                 AcquireAllStreamSlots(server, client, sendTasks, MaxConcurrentStreams);
+                int[] streamIds1 = await AcceptRequests(connection1, MaxConcurrentStreams).ConfigureAwait(false);
+
                 Http2LoopbackConnection connection2 = await PrepareConnection(server, client, MaxConcurrentStreams).ConfigureAwait(false);
                 AcquireAllStreamSlots(server, client, sendTasks, MaxConcurrentStreams);
+                int[] streamIds2 = await AcceptRequests(connection2, MaxConcurrentStreams).ConfigureAwait(false);
 
-                Task<int>[] handleRequestTasks = new[] {
-                    HandleAllPendingRequests(connection0, MaxConcurrentStreams),
-                    HandleAllPendingRequests(connection1, MaxConcurrentStreams),
-                    HandleAllPendingRequests(connection2, MaxConcurrentStreams)
-                };
-
-                await TestHelper.WhenAllCompletedOrAnyFailed(handleRequestTasks).ConfigureAwait(false);
+                await TestHelper.WhenAllCompletedOrAnyFailed(
+                    SendResponses(connection0, streamIds0),
+                    SendResponses(connection1, streamIds1),
+                    SendResponses(connection2, streamIds2))
+                    .ConfigureAwait(false);
 
-                await connection0.ShutdownIgnoringErrorsAsync(await handleRequestTasks[0]).ConfigureAwait(false);
-                await connection2.ShutdownIgnoringErrorsAsync(await handleRequestTasks[2]).ConfigureAwait(false);
+                await connection0.ShutdownIgnoringErrorsAsync(streamIds0[^1]).ConfigureAwait(false);
+                await connection2.ShutdownIgnoringErrorsAsync(streamIds2[^1]).ConfigureAwait(false);
 
-                //Fill all connection1's stream slots
+                // Fill all connection1's stream slots
                 AcquireAllStreamSlots(server, client, sendTasks, MaxConcurrentStreams);
+                streamIds1 = await AcceptRequests(connection1, MaxConcurrentStreams).ConfigureAwait(false);
 
                 Http2LoopbackConnection connection3 = await PrepareConnection(server, client, MaxConcurrentStreams).ConfigureAwait(false);
                 AcquireAllStreamSlots(server, client, sendTasks, MaxConcurrentStreams);
+                int[] streamIds3 = await AcceptRequests(connection3, MaxConcurrentStreams).ConfigureAwait(false);
+
                 Http2LoopbackConnection connection4 = await PrepareConnection(server, client, MaxConcurrentStreams).ConfigureAwait(false);
                 AcquireAllStreamSlots(server, client, sendTasks, MaxConcurrentStreams);
+                int[] streamIds4 = await AcceptRequests(connection4, MaxConcurrentStreams).ConfigureAwait(false);
 
-                Task[] finalHandleTasks = new[] {
-                    HandleAllPendingRequests(connection1, MaxConcurrentStreams),
-                    HandleAllPendingRequests(connection3, MaxConcurrentStreams),
-                    HandleAllPendingRequests(connection4, MaxConcurrentStreams)
-                };
-
-                await TestHelper.WhenAllCompletedOrAnyFailed(finalHandleTasks).ConfigureAwait(false);
+                await TestHelper.WhenAllCompletedOrAnyFailed(
+                   SendResponses(connection1, streamIds1),
+                   SendResponses(connection3, streamIds3),
+                   SendResponses(connection4, streamIds4))
+                   .ConfigureAwait(false);
 
                 await VerifySendTasks(sendTasks).ConfigureAwait(false);
             }
@@ -2778,24 +2772,36 @@ namespace System.Net.Http.Functional.Tests
         {
             const int MaxConcurrentStreams = 2;
             using Http2LoopbackServer server = Http2LoopbackServer.CreateServer();
+            server.AllowMultipleConnections = true;
+
+            SemaphoreSlim connectCallbackSemaphore = new(initialCount: 2);
+
             using SocketsHttpHandler handler = CreateHandler();
             handler.PooledConnectionIdleTimeout = TimeSpan.FromSeconds(20);
+
+            handler.ConnectCallback = async (context, ct) =>
+            {
+                await connectCallbackSemaphore.WaitAsync(ct);
+
+                return await DefaultConnectCallback(context.DnsEndPoint, ct);
+            };
+
             using (HttpClient client = CreateHttpClient(handler))
             {
-                server.AllowMultipleConnections = true;
-                List<Task<HttpResponseMessage>> sendTasks = new List<Task<HttpResponseMessage>>();
+                List<Task<HttpResponseMessage>> sendTasks0 = new();
+                List<Task<HttpResponseMessage>> sendTasks1 = new();
+                List<Task<HttpResponseMessage>> sendTasks2 = new();
+
                 Http2LoopbackConnection connection0 = await PrepareConnection(server, client, MaxConcurrentStreams).ConfigureAwait(false);
-                AcquireAllStreamSlots(server, client, sendTasks, MaxConcurrentStreams);
-                int[] acceptedStreamIds = await AcceptRequests(connection0, MaxConcurrentStreams).ConfigureAwait(false);
-                Assert.Equal(MaxConcurrentStreams, acceptedStreamIds.Length);
+                AcquireAllStreamSlots(server, client, sendTasks0, MaxConcurrentStreams);
+                int[] streamIds0 = await AcceptRequests(connection0, MaxConcurrentStreams).ConfigureAwait(false);
 
-                List<Task<HttpResponseMessage>> connection1SendTasks = new List<Task<HttpResponseMessage>>();
                 Http2LoopbackConnection connection1 = await PrepareConnection(server, client, MaxConcurrentStreams).ConfigureAwait(false);
-                AcquireAllStreamSlots(server, client, connection1SendTasks, MaxConcurrentStreams);
-                await HandleAllPendingRequests(connection1, MaxConcurrentStreams).ConfigureAwait(false);
+                AcquireAllStreamSlots(server, client, sendTasks1, MaxConcurrentStreams);
+                await SendResponses(connection1, await AcceptRequests(connection1, MaxConcurrentStreams).ConfigureAwait(false));
 
-                // Complete all the requests.
-                await VerifySendTasks(connection1SendTasks).ConfigureAwait(false);
+                // Complete all the requests on connection1.
+                await VerifySendTasks(sendTasks1).ConfigureAwait(false);
 
                 // Wait until the idle connection timeout expires.
                 await connection1.WaitForClientDisconnectAsync(false).WaitAsync(TestHelper.PassingTestTimeout).ConfigureAwait(false);
@@ -2803,17 +2809,20 @@ namespace System.Net.Http.Functional.Tests
                 Assert.True(connection1.IsInvalid);
                 Assert.False(connection0.IsInvalid);
 
-                Http2LoopbackConnection connection2 = await PrepareConnection(server, client, MaxConcurrentStreams).ConfigureAwait(false);
-
-                AcquireAllStreamSlots(server, client, sendTasks, MaxConcurrentStreams);
+                // Due to a race condition in how a new Http2 connection is returned to the pool, we may have started a third connection attempt in the background.
+                // We were blocking such attempts from going through to the Socket layer until now to avoid having to deal with the extra connect when accepting connection2 below.
+                // Allow the third connection through the ConnectCallback now.
+                connectCallbackSemaphore.Release();
 
-                await HandleAllPendingRequests(connection2, MaxConcurrentStreams).ConfigureAwait(false);
+                Http2LoopbackConnection connection2 = await PrepareConnection(server, client, MaxConcurrentStreams).ConfigureAwait(false);
+                AcquireAllStreamSlots(server, client, sendTasks2, MaxConcurrentStreams);
+                await SendResponses(connection2, await AcceptRequests(connection2, MaxConcurrentStreams).ConfigureAwait(false));
 
-                //Make sure connection0 is still alive.
-                int handledRequests0 = await SendResponses(connection0, acceptedStreamIds).ConfigureAwait(false);
-                Assert.Equal(MaxConcurrentStreams, handledRequests0);
+                // Make sure connection0 is still alive.
+                await SendResponses(connection0, streamIds0).ConfigureAwait(false);
 
-                await VerifySendTasks(sendTasks).ConfigureAwait(false);
+                await VerifySendTasks(sendTasks0).ConfigureAwait(false);
+                await VerifySendTasks(sendTasks2).ConfigureAwait(false);
             }
         }
 
@@ -2842,7 +2851,10 @@ namespace System.Net.Http.Functional.Tests
 
             Task<HttpResponseMessage> warmUpTask = client.GetAsync(server.Address);
 
-            Http2LoopbackConnection connection = await GetConnection(server, maxConcurrentStreams).WaitAsync(TestHelper.PassingTestTimeout).ConfigureAwait(false);
+            var concurrentStreamsSetting = new SettingsEntry { SettingId = SettingId.MaxConcurrentStreams, Value = maxConcurrentStreams };
+
+            Http2LoopbackConnection connection = await server.EstablishConnectionAsync(timeout: null, ackTimeout: TimeSpan.FromSeconds(10), concurrentStreamsSetting)
+                .WaitAsync(TestHelper.PassingTestTimeout).ConfigureAwait(false);
 
             (int streamId, _) = await connection.ReadAndParseRequestHeaderAsync().WaitAsync(TestHelper.PassingTestTimeout).ConfigureAwait(false);
             await connection.SendDefaultResponseAsync(streamId).WaitAsync(TestHelper.PassingTestTimeout).ConfigureAwait(false);
@@ -2862,49 +2874,25 @@ namespace System.Net.Http.Functional.Tests
             }
         }
 
-        private static async Task<Http2LoopbackConnection> GetConnection(Http2LoopbackServer server, uint maxConcurrentStreams)
-        {
-            var concurrentStreamsSetting = new SettingsEntry { SettingId = SettingId.MaxConcurrentStreams, Value = maxConcurrentStreams };
-
-            return await server.EstablishConnectionAsync(timeout: null, ackTimeout: TimeSpan.FromSeconds(10), concurrentStreamsSetting).ConfigureAwait(false);
-        }
-
-        private async Task<int> HandleAllPendingRequests(Http2LoopbackConnection connection, int totalRequestCount)
-        {
-            int lastStreamId = -1;
-            for (int i = 0; i < totalRequestCount; i++)
-            {
-                (int streamId, _) = await connection.ReadAndParseRequestHeaderAsync().ConfigureAwait(false);
-                await connection.SendDefaultResponseAsync(streamId).ConfigureAwait(false);
-                lastStreamId = streamId;
-            }
-
-            return lastStreamId;
-        }
-
         private async Task<int[]> AcceptRequests(Http2LoopbackConnection connection, int requestCount)
         {
             int[] streamIds = new int[requestCount];
 
             for (int i = 0; i < streamIds.Length; i++)
             {
-                (int streamId, _) = await connection.ReadAndParseRequestHeaderAsync().ConfigureAwait(false);
+                (int streamId, _) = await connection.ReadAndParseRequestHeaderAsync().WaitAsync(TestHelper.PassingTestTimeout).ConfigureAwait(false);
                 streamIds[i] = streamId;
             }
 
             return streamIds;
         }
 
-        private async Task<int> SendResponses(Http2LoopbackConnection connection, IEnumerable<int> streamIds)
+        private async Task SendResponses(Http2LoopbackConnection connection, IEnumerable<int> streamIds)
         {
-            int count = 0;
             foreach (int streamId in streamIds)
             {
-                count++;
-                await connection.SendDefaultResponseAsync(streamId).ConfigureAwait(false);
+                await connection.SendDefaultResponseAsync(streamId).WaitAsync(TestHelper.PassingTestTimeout).ConfigureAwait(false);
             }
-
-            return count;
         }
     }
 
@@ -3108,10 +3096,7 @@ namespace System.Net.Http.Functional.Tests
             var socketsHandler = (SocketsHttpHandler)GetUnderlyingSocketsHttpHandler(handler);
             socketsHandler.ConnectCallback = async (context, token) =>
             {
-                Socket clientSocket = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
-                await clientSocket.ConnectAsync(listenSocket.LocalEndPoint);
-
-                Stream clientStream = new NetworkStream(clientSocket, ownsSocket: true);
+                Stream clientStream = await DefaultConnectCallback(listenSocket.LocalEndPoint, token);
 
                 await clientStream.WriteAsync(RequestPrefix);