Improve ConnectCallback tests (#42562)
authorGeoff Kizer <geoffrek@microsoft.com>
Tue, 22 Sep 2020 16:11:59 +0000 (09:11 -0700)
committerGitHub <noreply@github.com>
Tue, 22 Sep 2020 16:11:59 +0000 (09:11 -0700)
* make existing tests run on HTTP2

* add exception test

Co-authored-by: Geoffrey Kizer <geoffrek@windows.microsoft.com>
src/libraries/System.Net.Http/tests/FunctionalTests/SocketsHttpHandlerTest.cs

index 959331e..b356b55 100644 (file)
@@ -2297,6 +2297,8 @@ namespace System.Net.Http.Functional.Tests
                 async uri =>
                 {
                     HttpRequestMessage requestMessage = new HttpRequestMessage(HttpMethod.Get, uri);
+                    requestMessage.Version = UseVersion;
+                    requestMessage.VersionPolicy = HttpVersionPolicy.RequestVersionExact;
 
                     using HttpClientHandler handler = CreateHttpClientHandler();
                     handler.ServerCertificateCustomValidationCallback = TestHelper.AllowAllCertificates;
@@ -2342,6 +2344,7 @@ namespace System.Net.Http.Functional.Tests
                     };
 
                     using HttpClient client = CreateHttpClient(handler);
+                    client.DefaultVersionPolicy = HttpVersionPolicy.RequestVersionExact;
 
                     string response = await client.GetStringAsync(uri);
                     Assert.Equal("foo", response);
@@ -2382,6 +2385,7 @@ namespace System.Net.Http.Functional.Tests
                 socketsHandler.ConnectCallback = (context, token) => new ValueTask<Stream>(clientStream);
 
                 using HttpClient client = CreateHttpClient(handler);
+                client.DefaultVersionPolicy = HttpVersionPolicy.RequestVersionExact;
 
                 string response = await client.GetStringAsync($"{(options.UseSsl ? "https" : "http")}://nowhere.invalid/foo");
                 Assert.Equal("foo", response);
@@ -2418,6 +2422,7 @@ namespace System.Net.Http.Functional.Tests
             };
 
             using HttpClient client = CreateHttpClient(handler);
+            client.DefaultVersionPolicy = HttpVersionPolicy.RequestVersionExact;
 
             Task<string> clientTask = client.GetStringAsync($"{(options.UseSsl ? "https" : "http")}://{guid}/foo");
 
@@ -2434,9 +2439,13 @@ namespace System.Net.Http.Functional.Tests
             Assert.Equal("foo", response);
         }
 
-        [Fact]
-        public async Task ConnectCallback_ConnectionPrefix_Success()
+        [Theory]
+        [InlineData(true)]
+        [InlineData(false)]
+        public async Task ConnectCallback_ConnectionPrefix_Success(bool useSsl)
         {
+            GenericLoopbackOptions options = new GenericLoopbackOptions() { UseSsl = useSsl };
+
             byte[] RequestPrefix = Encoding.UTF8.GetBytes("request prefix\r\n");
             byte[] ResponsePrefix = Encoding.UTF8.GetBytes("response prefix\r\n");
 
@@ -2445,6 +2454,7 @@ namespace System.Net.Http.Functional.Tests
             listenSocket.Listen();
 
             using HttpClientHandler handler = CreateHttpClientHandler();
+            handler.ServerCertificateCustomValidationCallback = TestHelper.AllowAllCertificates;
             var socketsHandler = (SocketsHttpHandler)GetUnderlyingSocketsHttpHandler(handler);
             socketsHandler.ConnectCallback = async (context, token) =>
             {
@@ -2463,8 +2473,9 @@ namespace System.Net.Http.Functional.Tests
             };
 
             using HttpClient client = CreateHttpClient(handler);
+            client.DefaultVersionPolicy = HttpVersionPolicy.RequestVersionExact;
 
-            Task<string> clientTask = client.GetStringAsync($"http://nowhere.invalid/foo");
+            Task<string> clientTask = client.GetStringAsync($"{(options.UseSsl ? "https" : "http")}://nowhere.invalid/foo");
 
             Socket serverSocket = await listenSocket.AcceptAsync();
             Stream serverStream = new NetworkStream(serverSocket, ownsSocket: true);
@@ -2475,7 +2486,7 @@ namespace System.Net.Http.Functional.Tests
 
             await serverStream.WriteAsync(ResponsePrefix);
 
-            using GenericLoopbackConnection loopbackConnection = await LoopbackServerFactory.CreateConnectionAsync(socket: null, serverStream);
+            using GenericLoopbackConnection loopbackConnection = await LoopbackServerFactory.CreateConnectionAsync(socket: null, serverStream, options);
             await loopbackConnection.InitializeConnectionAsync();
 
             HttpRequestData requestData = await loopbackConnection.ReadRequestDataAsync();
@@ -2487,6 +2498,28 @@ namespace System.Net.Http.Functional.Tests
             Assert.Equal("foo", response);
         }
 
+        [Theory]
+        [InlineData(true)]
+        [InlineData(false)]
+        public async Task ConnectCallback_ExceptionDuringCallback_ThrowsHttpRequestExceptionWithInnerException(bool useSsl)
+        {
+            Exception e = new Exception("hello!");
+
+            using HttpClientHandler handler = CreateHttpClientHandler();
+            handler.ServerCertificateCustomValidationCallback = TestHelper.AllowAllCertificates;
+            var socketsHandler = (SocketsHttpHandler)GetUnderlyingSocketsHttpHandler(handler);
+            socketsHandler.ConnectCallback = (context, token) =>
+            {
+                throw e;
+            };
+
+            using HttpClient client = CreateHttpClient(handler);
+            client.DefaultVersionPolicy = HttpVersionPolicy.RequestVersionExact;
+
+            HttpRequestException hre = await Assert.ThrowsAnyAsync<HttpRequestException>(async () => await client.GetAsync($"{(useSsl ? "https" : "http")}://nowhere.invalid/foo"));
+            Assert.Equal(e, hre.InnerException);
+        }
+
         private static bool PlatformSupportsUnixDomainSockets => Socket.OSSupportsUnixDomainSockets;
    }
 
@@ -2499,6 +2532,7 @@ namespace System.Net.Http.Functional.Tests
     public sealed class SocketsHttpHandlerTest_ConnectCallback_Http2 : SocketsHttpHandlerTest_ConnectCallback
     {
         public SocketsHttpHandlerTest_ConnectCallback_Http2(ITestOutputHelper output) : base(output) { }
+        protected override Version UseVersion => HttpVersion.Version20;
     }
 
     [ConditionalClass(typeof(PlatformDetection), nameof(PlatformDetection.SupportsAlpn))]