From 8b8c390755189d45efc0c407992cb7c006b802b5 Mon Sep 17 00:00:00 2001 From: Geoff Kizer Date: Thu, 10 Sep 2020 09:13:39 -0700 Subject: [PATCH] Implement SocketsHttpHandler.ConnectCallback (#41955) * Implement SocketsHttpHandler.ConnectCallback --- .../System.Net.Http/ref/System.Net.Http.cs | 8 + .../System.Net.Http/src/Resources/Strings.resx | 3 + .../System.Net.Http/src/System.Net.Http.csproj | 3 +- .../Http/BrowserHttpHandler/SocketsHttpHandler.cs | 7 + .../Net/Http/SocketsHttpHandler/ConnectHelper.cs | 4 +- .../Http/SocketsHttpHandler/HttpConnectionPool.cs | 28 ++- .../SocketsHttpHandler/HttpConnectionSettings.cs | 4 + .../SocketsHttpHandler/SocketsConnectionFactory.cs | 106 --------- .../SocketsHttpConnectionContext.cs | 30 +++ .../Http/SocketsHttpHandler/SocketsHttpHandler.cs | 14 ++ .../FunctionalTests/SocketsHttpHandlerTest.cs | 249 +++++++++++++++++++++ 11 files changed, 344 insertions(+), 112 deletions(-) delete mode 100644 src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/SocketsConnectionFactory.cs create mode 100644 src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/SocketsHttpConnectionContext.cs diff --git a/src/libraries/System.Net.Http/ref/System.Net.Http.cs b/src/libraries/System.Net.Http/ref/System.Net.Http.cs index cc6c899..7469866 100644 --- a/src/libraries/System.Net.Http/ref/System.Net.Http.cs +++ b/src/libraries/System.Net.Http/ref/System.Net.Http.cs @@ -360,7 +360,15 @@ namespace System.Net.Http protected internal override System.Net.Http.HttpResponseMessage Send(System.Net.Http.HttpRequestMessage request, System.Threading.CancellationToken cancellationToken) { throw null; } protected internal override System.Threading.Tasks.Task SendAsync(System.Net.Http.HttpRequestMessage request, System.Threading.CancellationToken cancellationToken) { throw null; } public bool EnableMultipleHttp2Connections { get { throw null; } set { } } + public Func>? ConnectCallback { get { throw null; } set { } } } + public sealed class SocketsHttpConnectionContext + { + internal SocketsHttpConnectionContext() { } + public DnsEndPoint DnsEndPoint { get { throw null; } } + public HttpRequestMessage RequestMessage { get { throw null; } } + } + public enum HttpKeepAlivePingPolicy { WithActiveRequests, diff --git a/src/libraries/System.Net.Http/src/Resources/Strings.resx b/src/libraries/System.Net.Http/src/Resources/Strings.resx index 84bc59a..f3d2f81 100644 --- a/src/libraries/System.Net.Http/src/Resources/Strings.resx +++ b/src/libraries/System.Net.Http/src/Resources/Strings.resx @@ -585,4 +585,7 @@ Requesting HTTP version {0} with version policy {1} while server offers only version fallback. + + Synchronous operation is not supported when a ConnectCallback is specified on the SocketsHttpHandler instance. + diff --git a/src/libraries/System.Net.Http/src/System.Net.Http.csproj b/src/libraries/System.Net.Http/src/System.Net.Http.csproj index b77e735..039686d 100644 --- a/src/libraries/System.Net.Http/src/System.Net.Http.csproj +++ b/src/libraries/System.Net.Http/src/System.Net.Http.csproj @@ -173,7 +173,7 @@ - + + diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/BrowserHttpHandler/SocketsHttpHandler.cs b/src/libraries/System.Net.Http/src/System/Net/Http/BrowserHttpHandler/SocketsHttpHandler.cs index 1d7ae3a..5322486 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/BrowserHttpHandler/SocketsHttpHandler.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/BrowserHttpHandler/SocketsHttpHandler.cs @@ -2,6 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. using System.Collections.Generic; +using System.IO; using System.Net.Security; using System.Threading; using System.Threading.Tasks; @@ -170,5 +171,11 @@ namespace System.Net.Http get => throw new PlatformNotSupportedException(); set => throw new PlatformNotSupportedException(); } + + public Func>? ConnectCallback + { + get => throw new PlatformNotSupportedException(); + set => throw new PlatformNotSupportedException(); + } } } diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/ConnectHelper.cs b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/ConnectHelper.cs index 931b888..777b6e5 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/ConnectHelper.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/ConnectHelper.cs @@ -31,11 +31,11 @@ namespace System.Net.Http } } - public static async ValueTask ConnectAsync(SocketsConnectionFactory factory, DnsEndPoint endPoint, CancellationToken cancellationToken) + public static async ValueTask ConnectAsync(Func> callback, DnsEndPoint endPoint, HttpRequestMessage requestMessage, CancellationToken cancellationToken) { try { - return await factory.ConnectAsync(endPoint, cancellationToken).ConfigureAwait(false); + return await callback(new SocketsHttpConnectionContext(endPoint, requestMessage), cancellationToken).ConfigureAwait(false); } catch (OperationCanceledException ex) when (ex.CancellationToken == cancellationToken) { diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpConnectionPool.cs b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpConnectionPool.cs index df4332c..233b590 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpConnectionPool.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpConnectionPool.cs @@ -1286,20 +1286,42 @@ namespace System.Net.Http } } - private static readonly SocketsConnectionFactory s_defaultConnectionFactory = new SocketsConnectionFactory(SocketType.Stream, ProtocolType.Tcp); + private static async ValueTask DefaultConnectAsync(SocketsHttpConnectionContext context, CancellationToken cancellationToken) + { + Socket socket = new Socket(SocketType.Stream, ProtocolType.Tcp); + socket.NoDelay = true; + + try + { + await socket.ConnectAsync(context.DnsEndPoint, cancellationToken).ConfigureAwait(false); + return new NetworkStream(socket, ownsSocket: true); + } + catch + { + socket.Dispose(); + throw; + } + } + + private static readonly Func> s_defaultConnectCallback = DefaultConnectAsync; private ValueTask ConnectToTcpHostAsync(string host, int port, HttpRequestMessage initialRequest, bool async, CancellationToken cancellationToken) { if (async) { - SocketsConnectionFactory connectionFactory = s_defaultConnectionFactory; + Func> connectCallback = Settings._connectCallback ?? s_defaultConnectCallback; var endPoint = new DnsEndPoint(host, port); - return ConnectHelper.ConnectAsync(connectionFactory, endPoint, cancellationToken); + return ConnectHelper.ConnectAsync(connectCallback, endPoint, initialRequest, cancellationToken); } // Synchronous path. + if (Settings._connectCallback is not null) + { + throw new NotSupportedException(SR.net_http_sync_operations_not_allowed_with_connect_callback); + } + try { return new ValueTask(ConnectHelper.Connect(host, port, cancellationToken)); diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpConnectionSettings.cs b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpConnectionSettings.cs index 177d8bc..00ce715 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpConnectionSettings.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpConnectionSettings.cs @@ -3,6 +3,7 @@ using System.Collections.Generic; using System.Net.Security; +using System.IO; using System.Threading; using System.Threading.Tasks; @@ -55,6 +56,8 @@ namespace System.Net.Http internal bool _enableMultipleHttp2Connections; + internal Func>? _connectCallback; + internal IDictionary? _properties; public HttpConnectionSettings() @@ -108,6 +111,7 @@ namespace System.Net.Http _requestHeaderEncodingSelector = _requestHeaderEncodingSelector, _responseHeaderEncodingSelector = _responseHeaderEncodingSelector, _enableMultipleHttp2Connections = _enableMultipleHttp2Connections, + _connectCallback = _connectCallback, }; } diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/SocketsConnectionFactory.cs b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/SocketsConnectionFactory.cs deleted file mode 100644 index 98c8e89..0000000 --- a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/SocketsConnectionFactory.cs +++ /dev/null @@ -1,106 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. - -using System.IO; -using System.Net.Sockets; -using System.Threading; -using System.Threading.Tasks; - -namespace System.Net.Http -{ - /// - /// A factory to establish socket-based connections. - /// - /// - /// When constructed with , this factory will create connections with enabled. - /// In case of IPv6 sockets is also enabled. - /// - internal sealed class SocketsConnectionFactory - { - private readonly AddressFamily _addressFamily; - private readonly SocketType _socketType; - private readonly ProtocolType _protocolType; - - /// - /// Initializes a new instance of the class. - /// - /// The to forward to the socket. - /// The to forward to the socket. - /// The to forward to the socket. - public SocketsConnectionFactory( - AddressFamily addressFamily, - SocketType socketType, - ProtocolType protocolType) - { - _addressFamily = addressFamily; - _socketType = socketType; - _protocolType = protocolType; - } - - /// - /// Initializes a new instance of the class - /// that will forward to the Socket constructor. - /// - /// The to forward to the socket. - /// The to forward to the socket. - /// The created socket will be an IPv6 socket with enabled. - public SocketsConnectionFactory(SocketType socketType, ProtocolType protocolType) - : this(AddressFamily.InterNetworkV6, socketType, protocolType) - { - } - - public async ValueTask ConnectAsync( - EndPoint? endPoint, - CancellationToken cancellationToken = default) - { - if (endPoint == null) throw new ArgumentNullException(nameof(endPoint)); - cancellationToken.ThrowIfCancellationRequested(); - - Socket socket = CreateSocket(_addressFamily, _socketType, _protocolType, endPoint); - - try - { - await socket.ConnectAsync(endPoint, cancellationToken).ConfigureAwait(false); - return new NetworkStream(socket, true); - } - catch - { - socket.Dispose(); - throw; - } - } - - /// - /// Creates the socket that shall be used with the connection. - /// - /// The to forward to the socket. - /// The to forward to the socket. - /// The to forward to the socket. - /// The this socket will be connected to. - /// A new unconnected . - /// - /// In case of TCP sockets, the default implementation of this method will create a socket with enabled. - /// In case of IPv6 sockets is also be enabled. - /// - private Socket CreateSocket( - AddressFamily addressFamily, - SocketType socketType, - ProtocolType protocolType, - EndPoint? endPoint) - { - Socket socket = new Socket(addressFamily, socketType, protocolType); - - if (protocolType == ProtocolType.Tcp) - { - socket.NoDelay = true; - } - - if (addressFamily == AddressFamily.InterNetworkV6) - { - socket.DualMode = true; - } - - return socket; - } - } -} diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/SocketsHttpConnectionContext.cs b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/SocketsHttpConnectionContext.cs new file mode 100644 index 0000000..fbd38df --- /dev/null +++ b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/SocketsHttpConnectionContext.cs @@ -0,0 +1,30 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +namespace System.Net.Http +{ + /// + /// Represents the context passed to the ConnectCallback for a SocketsHttpHandler instance. + /// + public sealed class SocketsHttpConnectionContext + { + private readonly DnsEndPoint _dnsEndPoint; + private readonly HttpRequestMessage _requestMessage; + + internal SocketsHttpConnectionContext(DnsEndPoint dnsEndPoint, HttpRequestMessage requestMessage) + { + _dnsEndPoint = dnsEndPoint; + _requestMessage = requestMessage; + } + + /// + /// The DnsEndPoint to be used by the ConnectCallback to establish the connection. + /// + public DnsEndPoint DnsEndPoint => _dnsEndPoint; + + /// + /// The initial HttpRequestMessage that is causing the connection to be created. + /// + public HttpRequestMessage RequestMessage => _requestMessage; + } +} diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/SocketsHttpHandler.cs b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/SocketsHttpHandler.cs index 2aa06f2..5975c1c 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/SocketsHttpHandler.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/SocketsHttpHandler.cs @@ -3,6 +3,7 @@ using System.Collections.Generic; using System.Diagnostics; +using System.IO; using System.Net.Security; using System.Threading; using System.Threading.Tasks; @@ -362,6 +363,19 @@ namespace System.Net.Http internal bool SupportsProxy => true; internal bool SupportsRedirectConfiguration => true; + /// + /// When non-null, a custom callback used to open new connections. + /// + public Func>? ConnectCallback + { + get => _settings._connectCallback; + set + { + CheckDisposedOrStarted(); + _settings._connectCallback = value; + } + } + public IDictionary Properties => _settings._properties ?? (_settings._properties = new Dictionary()); diff --git a/src/libraries/System.Net.Http/tests/FunctionalTests/SocketsHttpHandlerTest.cs b/src/libraries/System.Net.Http/tests/FunctionalTests/SocketsHttpHandlerTest.cs index bff227f..959331e 100644 --- a/src/libraries/System.Net.Http/tests/FunctionalTests/SocketsHttpHandlerTest.cs +++ b/src/libraries/System.Net.Http/tests/FunctionalTests/SocketsHttpHandlerTest.cs @@ -10,6 +10,7 @@ using System.Net.Security; using System.Net.Sockets; using System.Net.Test.Common; using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; using System.Security.Authentication; using System.Security.Cryptography.X509Certificates; using System.Text; @@ -1835,6 +1836,23 @@ namespace System.Net.Http.Functional.Tests } } + [Fact] + public void ConnectCallback_GetSet_Roundtrips() + { + using (var handler = new SocketsHttpHandler()) + { + Assert.Null(handler.ConnectCallback); + + Func> f = (context, token) => default; + + handler.ConnectCallback = f; + Assert.Equal(f, handler.ConnectCallback); + + handler.ConnectCallback = null; + Assert.Null(handler.ConnectCallback); + } + } + [Theory] [InlineData(false)] [InlineData(true)] @@ -1872,6 +1890,7 @@ namespace System.Net.Http.Functional.Tests Assert.NotNull(handler.SslOptions); Assert.True(handler.UseCookies); Assert.True(handler.UseProxy); + Assert.Null(handler.ConnectCallback); Assert.Throws(expectedExceptionType, () => handler.AllowAutoRedirect = false); Assert.Throws(expectedExceptionType, () => handler.AutomaticDecompression = DecompressionMethods.GZip); @@ -1891,6 +1910,7 @@ namespace System.Net.Http.Functional.Tests Assert.Throws(expectedExceptionType, () => handler.KeepAlivePingTimeout = TimeSpan.FromSeconds(5)); Assert.Throws(expectedExceptionType, () => handler.KeepAlivePingDelay = TimeSpan.FromSeconds(5)); Assert.Throws(expectedExceptionType, () => handler.KeepAlivePingPolicy = HttpKeepAlivePingPolicy.WithActiveRequests); + Assert.Throws(expectedExceptionType, () => handler.ConnectCallback = (context, token) => default); } } } @@ -2252,6 +2272,235 @@ namespace System.Net.Http.Functional.Tests } } + + public abstract class SocketsHttpHandlerTest_ConnectCallback : HttpClientHandlerTestBase + { + public SocketsHttpHandlerTest_ConnectCallback(ITestOutputHelper output) : base(output) { } + + [Fact] + public void ConnectCallback_SyncRequest_Fails() + { + using SocketsHttpHandler handler = new SocketsHttpHandler + { + ConnectCallback = (context, token) => default, + }; + + using HttpClient client = CreateHttpClient(handler); + + Assert.ThrowsAny(() => client.Send(new HttpRequestMessage(HttpMethod.Get, $"http://bing.com"))); + } + + [Fact] + public async void ConnectCallback_ContextHasCorrectProperties_Success() + { + await LoopbackServerFactory.CreateClientAndServerAsync( + async uri => + { + HttpRequestMessage requestMessage = new HttpRequestMessage(HttpMethod.Get, uri); + + using HttpClientHandler handler = CreateHttpClientHandler(); + handler.ServerCertificateCustomValidationCallback = TestHelper.AllowAllCertificates; + var socketsHandler = (SocketsHttpHandler)GetUnderlyingSocketsHttpHandler(handler); + socketsHandler.ConnectCallback = async (context, token) => + { + Assert.Equal(uri.Host, context.DnsEndPoint.Host); + Assert.Equal(uri.Port, context.DnsEndPoint.Port); + Assert.Equal(requestMessage, context.RequestMessage); + + Socket s = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); + await s.ConnectAsync(context.DnsEndPoint, token); + return new NetworkStream(s, ownsSocket: true); + }; + + using HttpClient client = CreateHttpClient(handler); + + HttpResponseMessage response = await client.SendAsync(requestMessage); + Assert.Equal("foo", await response.Content.ReadAsStringAsync()); + }, + async server => + { + await server.AcceptConnectionSendResponseAndCloseAsync(content: "foo"); + }); + } + + [Fact] + public async Task ConnectCallback_BindLocalAddress_Success() + { + await LoopbackServerFactory.CreateClientAndServerAsync( + async uri => + { + using HttpClientHandler handler = CreateHttpClientHandler(); + handler.ServerCertificateCustomValidationCallback = TestHelper.AllowAllCertificates; + var socketsHandler = (SocketsHttpHandler)GetUnderlyingSocketsHttpHandler(handler); + socketsHandler.ConnectCallback = async (context, token) => + { + Socket s = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); + s.Bind(new IPEndPoint(IPAddress.Loopback, 0)); + await s.ConnectAsync(context.DnsEndPoint, token); + s.NoDelay = true; + return new NetworkStream(s, ownsSocket: true); + }; + + using HttpClient client = CreateHttpClient(handler); + + string response = await client.GetStringAsync(uri); + Assert.Equal("foo", response); + }, + async server => + { + await server.AcceptConnectionSendResponseAndCloseAsync(content: "foo"); + }); + } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task ConnectCallback_UseVirtualNetwork_Success(bool useSsl) + { + var vn = new VirtualNetwork(); + using var clientStream = new VirtualNetworkStream(vn, isServer: false, gracefulShutdown: true); + using var serverStream = new VirtualNetworkStream(vn, isServer: true, gracefulShutdown: true); + + GenericLoopbackOptions options = new GenericLoopbackOptions() { UseSsl = useSsl }; + + Task serverTask = Task.Run(async () => + { + using GenericLoopbackConnection loopbackConnection = await LoopbackServerFactory.CreateConnectionAsync(socket: null, serverStream, options); + await loopbackConnection.InitializeConnectionAsync(); + + HttpRequestData requestData = await loopbackConnection.ReadRequestDataAsync(); + await loopbackConnection.SendResponseAsync(content: "foo"); + + Assert.Equal("/foo", requestData.Path); + }); + + Task clientTask = Task.Run(async () => + { + using HttpClientHandler handler = CreateHttpClientHandler(); + handler.ServerCertificateCustomValidationCallback = TestHelper.AllowAllCertificates; + var socketsHandler = (SocketsHttpHandler)GetUnderlyingSocketsHttpHandler(handler); + socketsHandler.ConnectCallback = (context, token) => new ValueTask(clientStream); + + using HttpClient client = CreateHttpClient(handler); + + string response = await client.GetStringAsync($"{(options.UseSsl ? "https" : "http")}://nowhere.invalid/foo"); + Assert.Equal("foo", response); + }); + + await new[] { serverTask, clientTask }.WhenAllOrAnyFailed(60_000); + } + + [ConditionalTheory(nameof(PlatformSupportsUnixDomainSockets))] + [InlineData(true)] + [InlineData(false)] + public async Task ConnectCallback_UseUnixDomainSocket_Success(bool useSsl) + { + GenericLoopbackOptions options = new GenericLoopbackOptions() { UseSsl = useSsl }; + + string guid = $"{Guid.NewGuid():N}"; + UnixDomainSocketEndPoint serverEP = new UnixDomainSocketEndPoint(Path.Combine(Path.GetTempPath(), guid)); + Socket listenSocket = new Socket(AddressFamily.Unix, SocketType.Stream, ProtocolType.Unspecified); + listenSocket.Bind(serverEP); + listenSocket.Listen(); + + using HttpClientHandler handler = CreateHttpClientHandler(); + handler.ServerCertificateCustomValidationCallback = TestHelper.AllowAllCertificates; + var socketsHandler = (SocketsHttpHandler)GetUnderlyingSocketsHttpHandler(handler); + socketsHandler.ConnectCallback = async (context, token) => + { + string hostname = context.DnsEndPoint.Host; + UnixDomainSocketEndPoint clientEP = new UnixDomainSocketEndPoint(Path.Combine(Path.GetTempPath(), hostname)); + + Socket clientSocket = new Socket(AddressFamily.Unix, SocketType.Stream, ProtocolType.Unspecified); + await clientSocket.ConnectAsync(clientEP); + + return new NetworkStream(clientSocket, ownsSocket: true); + }; + + using HttpClient client = CreateHttpClient(handler); + + Task clientTask = client.GetStringAsync($"{(options.UseSsl ? "https" : "http")}://{guid}/foo"); + + Socket serverSocket = await listenSocket.AcceptAsync(); + using GenericLoopbackConnection loopbackConnection = await LoopbackServerFactory.CreateConnectionAsync(socket: null, new NetworkStream(serverSocket, ownsSocket: true), options); + await loopbackConnection.InitializeConnectionAsync(); + + HttpRequestData requestData = await loopbackConnection.ReadRequestDataAsync(); + Assert.Equal("/foo", requestData.Path); + + await loopbackConnection.SendResponseAsync(content: "foo"); + + string response = await clientTask; + Assert.Equal("foo", response); + } + + [Fact] + public async Task ConnectCallback_ConnectionPrefix_Success() + { + byte[] RequestPrefix = Encoding.UTF8.GetBytes("request prefix\r\n"); + byte[] ResponsePrefix = Encoding.UTF8.GetBytes("response prefix\r\n"); + + Socket listenSocket = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); + listenSocket.Bind(new IPEndPoint(IPAddress.Loopback, 0)); + listenSocket.Listen(); + + using HttpClientHandler handler = CreateHttpClientHandler(); + var socketsHandler = (SocketsHttpHandler)GetUnderlyingSocketsHttpHandler(handler); + socketsHandler.ConnectCallback = async (context, token) => + { + Socket clientSocket = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); + await clientSocket.ConnectAsync(listenSocket.LocalEndPoint); + + Stream clientStream = new NetworkStream(clientSocket, ownsSocket: true); + + await clientStream.WriteAsync(RequestPrefix); + + byte[] buffer = new byte[ResponsePrefix.Length]; + await clientStream.ReadAsync(buffer); + Assert.True(buffer.SequenceEqual(ResponsePrefix)); + + return clientStream; + }; + + using HttpClient client = CreateHttpClient(handler); + + Task clientTask = client.GetStringAsync($"http://nowhere.invalid/foo"); + + Socket serverSocket = await listenSocket.AcceptAsync(); + Stream serverStream = new NetworkStream(serverSocket, ownsSocket: true); + + byte[] buffer = new byte[RequestPrefix.Length]; + await serverStream.ReadAsync(buffer); + Assert.True(buffer.SequenceEqual(RequestPrefix)); + + await serverStream.WriteAsync(ResponsePrefix); + + using GenericLoopbackConnection loopbackConnection = await LoopbackServerFactory.CreateConnectionAsync(socket: null, serverStream); + await loopbackConnection.InitializeConnectionAsync(); + + HttpRequestData requestData = await loopbackConnection.ReadRequestDataAsync(); + Assert.Equal("/foo", requestData.Path); + + await loopbackConnection.SendResponseAsync(content: "foo"); + + string response = await clientTask; + Assert.Equal("foo", response); + } + + private static bool PlatformSupportsUnixDomainSockets => Socket.OSSupportsUnixDomainSockets; + } + + public sealed class SocketsHttpHandlerTest_ConnectCallback_Http11 : SocketsHttpHandlerTest_ConnectCallback + { + public SocketsHttpHandlerTest_ConnectCallback_Http11(ITestOutputHelper output) : base(output) { } + } + + [ConditionalClass(typeof(PlatformDetection), nameof(PlatformDetection.SupportsAlpn))] + public sealed class SocketsHttpHandlerTest_ConnectCallback_Http2 : SocketsHttpHandlerTest_ConnectCallback + { + public SocketsHttpHandlerTest_ConnectCallback_Http2(ITestOutputHelper output) : base(output) { } + } + [ConditionalClass(typeof(PlatformDetection), nameof(PlatformDetection.SupportsAlpn))] public sealed class SocketsHttpHandlerTest_Cookies_Http2 : HttpClientHandlerTest_Cookies { -- 2.7.4