add SocketsHttpHandler.PlaintextStreamFilter (#42660)
authorGeoff Kizer <geoffrek@microsoft.com>
Thu, 24 Sep 2020 15:37:00 +0000 (08:37 -0700)
committerGitHub <noreply@github.com>
Thu, 24 Sep 2020 15:37:00 +0000 (08:37 -0700)
src/libraries/System.Net.Http/ref/System.Net.Http.cs
src/libraries/System.Net.Http/src/Resources/Strings.resx
src/libraries/System.Net.Http/src/System.Net.Http.csproj
src/libraries/System.Net.Http/src/System/Net/Http/BrowserHttpHandler/SocketsHttpHandler.cs
src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/ConnectHelper.cs
src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpConnectionPool.cs
src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpConnectionSettings.cs
src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/SocketsHttpConnectionContext.cs
src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/SocketsHttpHandler.cs
src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/SocketsHttpPlaintextStreamFilterContext.cs [new file with mode: 0644]
src/libraries/System.Net.Http/tests/FunctionalTests/SocketsHttpHandlerTest.cs

index 72bffc4..43b8ddb 100644 (file)
@@ -378,14 +378,21 @@ namespace System.Net.Http
         protected internal override System.Threading.Tasks.Task<System.Net.Http.HttpResponseMessage> SendAsync(System.Net.Http.HttpRequestMessage request, System.Threading.CancellationToken cancellationToken) { throw null; }
         public bool EnableMultipleHttp2Connections { get { throw null; } set { } }
         public Func<SocketsHttpConnectionContext, System.Threading.CancellationToken, System.Threading.Tasks.ValueTask<System.IO.Stream>>? ConnectCallback { get { throw null; } set { } }
+        public Func<SocketsHttpPlaintextStreamFilterContext, System.Threading.CancellationToken, System.Threading.Tasks.ValueTask<System.IO.Stream>>? PlaintextStreamFilter { get { throw null; } set { } }
     }
     public sealed class SocketsHttpConnectionContext
     {
         internal SocketsHttpConnectionContext() { }
         public DnsEndPoint DnsEndPoint { get { throw null; } }
-        public HttpRequestMessage RequestMessage { get { throw null; } }
+        public HttpRequestMessage InitialRequestMessage { get { throw null; } }
+    }
+    public sealed class SocketsHttpPlaintextStreamFilterContext
+    {
+        internal SocketsHttpPlaintextStreamFilterContext() { }
+        public System.IO.Stream PlaintextStream { get { throw null; } }
+        public Version NegotiatedHttpVersion { get { throw null; } }
+        public HttpRequestMessage InitialRequestMessage { get { throw null; } }
     }
-
     public enum HttpKeepAlivePingPolicy
     {
         WithActiveRequests,
index f9ad9e4..a5f53f1 100644 (file)
   <data name="net_http_sync_operations_not_allowed_with_connect_callback" xml:space="preserve">
     <value>Synchronous operation is not supported when a ConnectCallback is specified on the SocketsHttpHandler instance.</value>
   </data>
+  <data name="net_http_sync_operations_not_allowed_with_plaintext_filter" xml:space="preserve">
+    <value>Synchronous operation is not supported when a PlaintextStreamFilter is specified on the SocketsHttpHandler instance.</value>
+  </data>
+  <data name="net_http_exception_during_plaintext_filter" xml:space="preserve">
+    <value>An exception occurred while invoking the PlaintextStreamFilter.</value>
+  </data>
+  <data name="net_http_null_from_connect_callback" xml:space="preserve">
+    <value>The user-supplied ConnectCallback returned null.</value>
+  </data>
+  <data name="net_http_null_from_plaintext_filter" xml:space="preserve">
+    <value>The user-supplied PlaintextStreamFilter returned null.</value>
+  </data>
 </root>
index 1bdcb88..7f474d2 100644 (file)
   <ItemGroup>
     <Compile Include="System\Net\Http\DiagnosticsHandler.cs" />
     <Compile Include="System\Net\Http\DiagnosticsHandlerLoggingStrings.cs" />
+    <Compile Include="System\Net\Http\SocketsHttpHandler\SocketsHttpPlaintextStreamFilterContext.cs" />
     <Compile Include="$(CommonPath)System\Net\Mail\DomainLiteralReader.cs"
              Link="Common\System\Net\Mail\DomainLiteralReader.cs" />
     <Compile Include="$(CommonPath)System\Net\Mail\DotAtomReader.cs"
index 4ad7756..2c77415 100644 (file)
@@ -179,5 +179,11 @@ namespace System.Net.Http
             get => throw new PlatformNotSupportedException();
             set => throw new PlatformNotSupportedException();
         }
+
+        public Func<SocketsHttpPlaintextStreamFilterContext, CancellationToken, ValueTask<Stream>>? PlaintextStreamFilter
+        {
+            get => throw new PlatformNotSupportedException();
+            set => throw new PlatformNotSupportedException();
+        }
     }
 }
index 777b6e5..2e17fad 100644 (file)
@@ -33,9 +33,10 @@ namespace System.Net.Http
 
         public static async ValueTask<Stream> ConnectAsync(Func<SocketsHttpConnectionContext, CancellationToken, ValueTask<Stream>> callback, DnsEndPoint endPoint, HttpRequestMessage requestMessage, CancellationToken cancellationToken)
         {
+            Stream stream;
             try
             {
-                return await callback(new SocketsHttpConnectionContext(endPoint, requestMessage), cancellationToken).ConfigureAwait(false);
+                stream = await callback(new SocketsHttpConnectionContext(endPoint, requestMessage), cancellationToken).ConfigureAwait(false);
             }
             catch (OperationCanceledException ex) when (ex.CancellationToken == cancellationToken)
             {
@@ -45,6 +46,13 @@ namespace System.Net.Http
             {
                 throw CreateWrappedException(ex, endPoint.Host, endPoint.Port, cancellationToken);
             }
+
+            if (stream == null)
+            {
+                throw new HttpRequestException(SR.net_http_null_from_connect_callback);
+            }
+
+            return stream;
         }
 
         public static Stream Connect(string host, int port, CancellationToken cancellationToken)
index 233b590..715cab8 100644 (file)
@@ -614,10 +614,7 @@ namespace System.Net.Http
 
                     if (_kind == HttpConnectionKind.Http)
                     {
-                        http2Connection = new Http2Connection(this, stream);
-                        await http2Connection.SetupAsync().ConfigureAwait(false);
-
-                        AddHttp2Connection(http2Connection);
+                        http2Connection = await ConstructHttp2Connection(stream, request, cancellationToken).ConfigureAwait(false);
 
                         if (NetEventSource.Log.IsEnabled())
                         {
@@ -635,13 +632,11 @@ namespace System.Net.Http
 
                         if (sslStream.SslProtocol < SslProtocols.Tls12)
                         {
+                            sslStream.Dispose();
                             throw new HttpRequestException(SR.Format(SR.net_ssl_http2_requires_tls12, sslStream.SslProtocol));
                         }
 
-                        http2Connection = new Http2Connection(this, stream);
-                        await http2Connection.SetupAsync().ConfigureAwait(false);
-
-                        AddHttp2Connection(http2Connection);
+                        http2Connection = await ConstructHttp2Connection(stream, request, cancellationToken).ConfigureAwait(false);
 
                         if (NetEventSource.Log.IsEnabled())
                         {
@@ -695,7 +690,7 @@ namespace System.Net.Http
 
                 if (canUse)
                 {
-                    return (ConstructHttp11Connection(stream!, transportContext), true, null);
+                    return (await ConstructHttp11Connection(stream!, transportContext, request, cancellationToken).ConfigureAwait(false), true, null);
                 }
                 else
                 {
@@ -1342,7 +1337,7 @@ namespace System.Net.Http
                 return (null, failureResponse);
             }
 
-            return (ConstructHttp11Connection(stream!, transportContext), null);
+            return (await ConstructHttp11Connection(stream!, transportContext, request, cancellationToken).ConfigureAwait(false), null);
         }
 
         private SslClientAuthenticationOptions GetSslOptionsForRequest(HttpRequestMessage request)
@@ -1362,11 +1357,52 @@ namespace System.Net.Http
             return _sslOptionsHttp11!;
         }
 
-        private HttpConnection ConstructHttp11Connection(Stream stream, TransportContext? transportContext)
+        private async ValueTask<Stream> ApplyPlaintextFilter(Stream stream, Version httpVersion, HttpRequestMessage request, CancellationToken cancellationToken)
+        {
+            if (Settings._plaintextStreamFilter is null)
+            {
+                return stream;
+            }
+
+            Stream newStream;
+            try
+            {
+                newStream = await Settings._plaintextStreamFilter(new SocketsHttpPlaintextStreamFilterContext(stream, httpVersion, request), cancellationToken).ConfigureAwait(false);
+            }
+            catch (Exception e)
+            {
+                stream.Dispose();
+                throw new HttpRequestException(SR.net_http_exception_during_plaintext_filter, e);
+            }
+
+            if (newStream == null)
+            {
+                stream.Dispose();
+                throw new HttpRequestException(SR.net_http_null_from_plaintext_filter);
+            }
+
+            return newStream;
+        }
+
+        private async ValueTask<HttpConnection> ConstructHttp11Connection(Stream stream, TransportContext? transportContext, HttpRequestMessage request, CancellationToken cancellationToken)
         {
+            stream = await ApplyPlaintextFilter(stream, HttpVersion.Version11, request, cancellationToken).ConfigureAwait(false);
             return new HttpConnection(this, stream, transportContext);
         }
 
+        private async ValueTask<Http2Connection> ConstructHttp2Connection(Stream stream, HttpRequestMessage request, CancellationToken cancellationToken)
+        {
+            stream = await ApplyPlaintextFilter(stream, HttpVersion.Version20, request, cancellationToken).ConfigureAwait(false);
+
+            Http2Connection http2Connection = new Http2Connection(this, stream);
+            await http2Connection.SetupAsync().ConfigureAwait(false);
+
+            AddHttp2Connection(http2Connection);
+
+            return http2Connection;
+        }
+
+
         // Returns the established stream or an HttpResponseMessage from the proxy indicating failure.
         private async ValueTask<(Stream?, HttpResponseMessage?)> EstablishProxyTunnel(bool async, HttpRequestHeaders? headers, CancellationToken cancellationToken)
         {
index 00ce715..2e3fd77 100644 (file)
@@ -57,6 +57,7 @@ namespace System.Net.Http
         internal bool _enableMultipleHttp2Connections;
 
         internal Func<SocketsHttpConnectionContext, CancellationToken, ValueTask<Stream>>? _connectCallback;
+        internal Func<SocketsHttpPlaintextStreamFilterContext, CancellationToken, ValueTask<Stream>>? _plaintextStreamFilter;
 
         internal IDictionary<string, object?>? _properties;
 
@@ -112,6 +113,7 @@ namespace System.Net.Http
                 _responseHeaderEncodingSelector = _responseHeaderEncodingSelector,
                 _enableMultipleHttp2Connections = _enableMultipleHttp2Connections,
                 _connectCallback = _connectCallback,
+                _plaintextStreamFilter = _plaintextStreamFilter,
             };
         }
 
index fbd38df..21f8fe7 100644 (file)
@@ -9,12 +9,12 @@ namespace System.Net.Http
     public sealed class SocketsHttpConnectionContext
     {
         private readonly DnsEndPoint _dnsEndPoint;
-        private readonly HttpRequestMessage _requestMessage;
+        private readonly HttpRequestMessage _initialRequestMessage;
 
-        internal SocketsHttpConnectionContext(DnsEndPoint dnsEndPoint, HttpRequestMessage requestMessage)
+        internal SocketsHttpConnectionContext(DnsEndPoint dnsEndPoint, HttpRequestMessage initialRequestMessage)
         {
             _dnsEndPoint = dnsEndPoint;
-            _requestMessage = requestMessage;
+            _initialRequestMessage = initialRequestMessage;
         }
 
         /// <summary>
@@ -25,6 +25,6 @@ namespace System.Net.Http
         /// <summary>
         /// The initial HttpRequestMessage that is causing the connection to be created.
         /// </summary>
-        public HttpRequestMessage RequestMessage => _requestMessage;
+        public HttpRequestMessage InitialRequestMessage => _initialRequestMessage;
     }
 }
index 7a5827e..1e8ce69 100644 (file)
@@ -378,6 +378,19 @@ namespace System.Net.Http
             }
         }
 
+        /// <summary>
+        /// Gets or sets a custom callback that provides access to the plaintext HTTP protocol stream.
+        /// </summary>
+        public Func<SocketsHttpPlaintextStreamFilterContext, CancellationToken, ValueTask<Stream>>? PlaintextStreamFilter
+        {
+            get => _settings._plaintextStreamFilter;
+            set
+            {
+                CheckDisposedOrStarted();
+                _settings._plaintextStreamFilter = value;
+            }
+        }
+
         public IDictionary<string, object?> Properties =>
             _settings._properties ?? (_settings._properties = new Dictionary<string, object?>());
 
@@ -483,6 +496,11 @@ namespace System.Net.Http
             CheckDisposed();
             HttpMessageHandlerStage handler = _handler ?? SetupHandlerChain();
 
+            if (_settings._plaintextStreamFilter is not null)
+            {
+                throw new NotSupportedException(SR.net_http_sync_operations_not_allowed_with_plaintext_filter);
+            }
+
             Exception? error = ValidateAndNormalizeRequest(request);
             if (error != null)
             {
diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/SocketsHttpPlaintextStreamFilterContext.cs b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/SocketsHttpPlaintextStreamFilterContext.cs
new file mode 100644 (file)
index 0000000..8611943
--- /dev/null
@@ -0,0 +1,39 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using System.IO;
+
+namespace System.Net.Http
+{
+    /// <summary>
+    /// Represents the context passed to the PlaintextStreamFilter for a SocketsHttpHandler instance.
+    /// </summary>
+    public sealed class SocketsHttpPlaintextStreamFilterContext
+    {
+        private readonly Stream _plaintextStream;
+        private readonly Version _negotiatedHttpVersion;
+        private readonly HttpRequestMessage _initialRequestMessage;
+
+        internal SocketsHttpPlaintextStreamFilterContext(Stream plaintextStream, Version negotiatedHttpVersion, HttpRequestMessage initialRequestMessage)
+        {
+            _plaintextStream = plaintextStream;
+            _negotiatedHttpVersion = negotiatedHttpVersion;
+            _initialRequestMessage = initialRequestMessage;
+        }
+
+        /// <summary>
+        /// The plaintext Stream that will be used for HTTP protocol requests and responses.
+        /// </summary>
+        public Stream PlaintextStream => _plaintextStream;
+
+        /// <summary>
+        /// The version of HTTP in use for this stream.
+        /// </summary>
+        public Version NegotiatedHttpVersion => _negotiatedHttpVersion;
+
+        /// <summary>
+        /// The initial HttpRequestMessage that is causing the stream to be used.
+        /// </summary>
+        public HttpRequestMessage InitialRequestMessage => _initialRequestMessage;
+    }
+}
index b356b55..777bda4 100644 (file)
@@ -1853,6 +1853,23 @@ namespace System.Net.Http.Functional.Tests
             }
         }
 
+        [Fact]
+        public void PlaintextStreamFilter_GetSet_Roundtrips()
+        {
+            using (var handler = new SocketsHttpHandler())
+            {
+                Assert.Null(handler.PlaintextStreamFilter);
+
+                Func<SocketsHttpPlaintextStreamFilterContext, CancellationToken, ValueTask<Stream>> f = (context, token) => default;
+
+                handler.PlaintextStreamFilter = f;
+                Assert.Equal(f, handler.PlaintextStreamFilter);
+
+                handler.PlaintextStreamFilter = null;
+                Assert.Null(handler.PlaintextStreamFilter);
+            }
+        }
+
         [Theory]
         [InlineData(false)]
         [InlineData(true)]
@@ -1891,6 +1908,7 @@ namespace System.Net.Http.Functional.Tests
                 Assert.True(handler.UseCookies);
                 Assert.True(handler.UseProxy);
                 Assert.Null(handler.ConnectCallback);
+                Assert.Null(handler.PlaintextStreamFilter);
 
                 Assert.Throws(expectedExceptionType, () => handler.AllowAutoRedirect = false);
                 Assert.Throws(expectedExceptionType, () => handler.AutomaticDecompression = DecompressionMethods.GZip);
@@ -1911,6 +1929,7 @@ namespace System.Net.Http.Functional.Tests
                 Assert.Throws(expectedExceptionType, () => handler.KeepAlivePingDelay = TimeSpan.FromSeconds(5));
                 Assert.Throws(expectedExceptionType, () => handler.KeepAlivePingPolicy = HttpKeepAlivePingPolicy.WithActiveRequests);
                 Assert.Throws(expectedExceptionType, () => handler.ConnectCallback = (context, token) => default);
+                Assert.Throws(expectedExceptionType, () => handler.PlaintextStreamFilter = (context, token) => default);
             }
         }
     }
@@ -2272,7 +2291,6 @@ namespace System.Net.Http.Functional.Tests
         }
     }
 
-
     public abstract class SocketsHttpHandlerTest_ConnectCallback : HttpClientHandlerTestBase
     {
         public SocketsHttpHandlerTest_ConnectCallback(ITestOutputHelper output) : base(output) { }
@@ -2307,7 +2325,7 @@ namespace System.Net.Http.Functional.Tests
                     {
                         Assert.Equal(uri.Host, context.DnsEndPoint.Host);
                         Assert.Equal(uri.Port, context.DnsEndPoint.Port);
-                        Assert.Equal(requestMessage, context.RequestMessage);
+                        Assert.Equal(requestMessage, context.InitialRequestMessage);
 
                         Socket s = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
                         await s.ConnectAsync(context.DnsEndPoint, token);
@@ -2520,6 +2538,25 @@ namespace System.Net.Http.Functional.Tests
             Assert.Equal(e, hre.InnerException);
         }
 
+        [Theory]
+        [InlineData(true)]
+        [InlineData(false)]
+        public async Task ConnectCallback_ReturnsNull_ThrowsHttpRequestException(bool useSsl)
+        {
+            using HttpClientHandler handler = CreateHttpClientHandler();
+            handler.ServerCertificateCustomValidationCallback = TestHelper.AllowAllCertificates;
+            var socketsHandler = (SocketsHttpHandler)GetUnderlyingSocketsHttpHandler(handler);
+            socketsHandler.ConnectCallback = (context, token) =>
+            {
+                return ValueTask.FromResult<Stream>(null);
+            };
+
+            using HttpClient client = CreateHttpClient(handler);
+            client.DefaultVersionPolicy = HttpVersionPolicy.RequestVersionExact;
+
+            HttpRequestException hre = await Assert.ThrowsAnyAsync<HttpRequestException>(async () => await client.GetAsync($"{(useSsl ? "https" : "http")}://nowhere.invalid/foo"));
+        }
+
         private static bool PlatformSupportsUnixDomainSockets => Socket.OSSupportsUnixDomainSockets;
    }
 
@@ -2535,6 +2572,309 @@ namespace System.Net.Http.Functional.Tests
         protected override Version UseVersion => HttpVersion.Version20;
     }
 
+    public abstract class SocketsHttpHandlerTest_PlaintextStreamFilter : HttpClientHandlerTestBase
+    {
+        public SocketsHttpHandlerTest_PlaintextStreamFilter(ITestOutputHelper output) : base(output) { }
+
+        [Fact]
+        public void PlaintextStreamFilter_SyncRequest_Fails()
+        {
+            using SocketsHttpHandler handler = new SocketsHttpHandler
+            {
+                PlaintextStreamFilter = (context, token) => default,
+            };
+
+            using HttpClient client = CreateHttpClient(handler);
+
+            Assert.ThrowsAny<NotSupportedException>(() => client.Send(new HttpRequestMessage(HttpMethod.Get, "http://bing.com")));
+        }
+
+        [Theory]
+        [InlineData(true)]
+        [InlineData(false)]
+        public async void PlaintextStreamFilter_ContextHasCorrectProperties_Success(bool useSsl)
+        {
+            GenericLoopbackOptions options = new GenericLoopbackOptions() { UseSsl = useSsl };
+            await LoopbackServerFactory.CreateClientAndServerAsync(
+                async uri =>
+                {
+                    HttpRequestMessage requestMessage = new HttpRequestMessage(HttpMethod.Get, uri);
+                    requestMessage.Version = UseVersion;
+                    requestMessage.VersionPolicy = HttpVersionPolicy.RequestVersionExact;
+
+                    using HttpClientHandler handler = CreateHttpClientHandler();
+                    handler.ServerCertificateCustomValidationCallback = TestHelper.AllowAllCertificates;
+                    var socketsHandler = (SocketsHttpHandler)GetUnderlyingSocketsHttpHandler(handler);
+                    socketsHandler.PlaintextStreamFilter = (context, token) =>
+                    {
+                        Assert.Equal(UseVersion, context.NegotiatedHttpVersion);
+                        Assert.Equal(requestMessage, context.InitialRequestMessage);
+
+                        return ValueTask.FromResult(context.PlaintextStream);
+                    };
+
+                    using HttpClient client = CreateHttpClient(handler);
+
+                    HttpResponseMessage response = await client.SendAsync(requestMessage);
+                    Assert.Equal("foo", await response.Content.ReadAsStringAsync());
+                },
+                async server =>
+                {
+                    await server.AcceptConnectionSendResponseAndCloseAsync(content: "foo");
+                }, options: options);
+        }
+
+        [Theory]
+        [InlineData(true)]
+        [InlineData(false)]
+        public async void PlaintextStreamFilter_SimpleDelegatingStream_Success(bool useSsl)
+        {
+            GenericLoopbackOptions options = new GenericLoopbackOptions() { UseSsl = useSsl };
+            await LoopbackServerFactory.CreateClientAndServerAsync(
+                async uri =>
+                {
+                    using HttpClientHandler handler = CreateHttpClientHandler();
+                    handler.ServerCertificateCustomValidationCallback = TestHelper.AllowAllCertificates;
+                    var socketsHandler = (SocketsHttpHandler)GetUnderlyingSocketsHttpHandler(handler);
+                    socketsHandler.PlaintextStreamFilter = (context, token) =>
+                    {
+                        Assert.Equal(UseVersion, context.NegotiatedHttpVersion);
+
+                        DelegateStream newStream = new DelegateStream(
+                            canReadFunc: () => true,
+                            canWriteFunc: () => true,
+                            readAsyncFunc: context.PlaintextStream.ReadAsync,
+                            writeAsyncFunc: context.PlaintextStream.WriteAsync,
+                            disposeFunc: (disposing) => { if (disposing) { context.PlaintextStream.Dispose(); } });
+
+                        return ValueTask.FromResult<Stream>(newStream);
+                    };
+
+                    using HttpClient client = CreateHttpClient(handler);
+                    client.DefaultVersionPolicy = HttpVersionPolicy.RequestVersionExact;
+
+                    using HttpResponseMessage response = await client.GetAsync(uri);
+                    Assert.Equal("foo", await response.Content.ReadAsStringAsync());
+                },
+                async server =>
+                {
+                    await server.AcceptConnectionSendResponseAndCloseAsync(content: "foo");
+                }, options: options);
+        }
+
+        [Theory]
+        [InlineData(true)]
+        [InlineData(false)]
+        public async Task PlaintextStreamFilter_ConnectionPrefix_Success(bool useSsl)
+        {
+            byte[] RequestPrefix = Encoding.UTF8.GetBytes("request prefix\r\n");
+            byte[] ResponsePrefix = Encoding.UTF8.GetBytes("response prefix\r\n");
+
+            using var listenSocket = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
+            listenSocket.Bind(new IPEndPoint(IPAddress.Loopback, 0));
+            listenSocket.Listen();
+
+            Task clientTask = Task.Run(async () =>
+            {
+                using HttpClientHandler handler = CreateHttpClientHandler();
+                handler.ServerCertificateCustomValidationCallback = TestHelper.AllowAllCertificates;
+                var socketsHandler = (SocketsHttpHandler)GetUnderlyingSocketsHttpHandler(handler);
+                socketsHandler.PlaintextStreamFilter = async (context, token) =>
+                {
+                    await context.PlaintextStream.WriteAsync(RequestPrefix);
+
+                    byte[] buffer = new byte[ResponsePrefix.Length];
+                    await context.PlaintextStream.ReadAsync(buffer);
+                    Assert.True(buffer.SequenceEqual(ResponsePrefix));
+
+                    return context.PlaintextStream;
+                };
+
+                using HttpClient client = CreateHttpClient(handler);
+                client.DefaultVersionPolicy = HttpVersionPolicy.RequestVersionExact;
+
+                string response = await client.GetStringAsync($"{(useSsl ? "https" : "http")}://{listenSocket.LocalEndPoint}/foo");
+                Assert.Equal("foo", response);
+            });
+
+            Task serverTask = Task.Run(async () =>
+            {
+                Socket serverSocket = await listenSocket.AcceptAsync();
+                Stream serverStream = new NetworkStream(serverSocket, ownsSocket: true);
+
+                if (useSsl)
+                {
+                    var sslStream = new SslStream(serverStream, false, delegate { return true; });
+
+                    using (X509Certificate2 cert = System.Net.Test.Common.Configuration.Certificates.GetServerCertificate())
+                    {
+                        SslServerAuthenticationOptions options = new SslServerAuthenticationOptions();
+
+                        options.EnabledSslProtocols = SslProtocols.Tls12;
+
+                        var protocols = new List<SslApplicationProtocol>();
+                        protocols.Add(SslApplicationProtocol.Http2);
+                        options.ApplicationProtocols = protocols;
+
+                        options.ServerCertificate = cert;
+
+                        await sslStream.AuthenticateAsServerAsync(options, CancellationToken.None).ConfigureAwait(false);
+                    }
+
+                    serverStream = sslStream;
+                }
+
+                byte[] buffer = new byte[RequestPrefix.Length];
+                await serverStream.ReadAsync(buffer);
+                Assert.True(buffer.SequenceEqual(RequestPrefix));
+
+                await serverStream.WriteAsync(ResponsePrefix);
+
+                using GenericLoopbackConnection loopbackConnection = await LoopbackServerFactory.CreateConnectionAsync(socket: null, serverStream, new GenericLoopbackOptions() { UseSsl = false });
+                await loopbackConnection.InitializeConnectionAsync();
+
+                HttpRequestData requestData = await loopbackConnection.ReadRequestDataAsync();
+                Assert.Equal("/foo", requestData.Path);
+
+                await loopbackConnection.SendResponseAsync(content: "foo");
+            });
+
+            await new Task[] { clientTask, serverTask }.WhenAllOrAnyFailed();
+        }
+
+        [Theory]
+        [InlineData(true)]
+        [InlineData(false)]
+        public async void PlaintextStreamFilter_ExceptionDuringCallback_ThrowsHttpRequestExceptionWithInnerException(bool useSsl)
+        {
+            Exception e = new Exception("hello!");
+
+            GenericLoopbackOptions options = new GenericLoopbackOptions() { UseSsl = useSsl };
+            await LoopbackServerFactory.CreateClientAndServerAsync(
+                async uri =>
+                {
+                    HttpRequestMessage requestMessage = new HttpRequestMessage(HttpMethod.Get, uri);
+                    requestMessage.Version = UseVersion;
+                    requestMessage.VersionPolicy = HttpVersionPolicy.RequestVersionExact;
+
+                    using HttpClientHandler handler = CreateHttpClientHandler();
+                    handler.ServerCertificateCustomValidationCallback = TestHelper.AllowAllCertificates;
+                    var socketsHandler = (SocketsHttpHandler)GetUnderlyingSocketsHttpHandler(handler);
+                    socketsHandler.PlaintextStreamFilter = (context, token) =>
+                    {
+                        throw e;
+                    };
+
+                    using HttpClient client = CreateHttpClient(handler);
+
+                    HttpRequestException hre = await Assert.ThrowsAnyAsync<HttpRequestException>(async () => await client.SendAsync(requestMessage));
+                    Assert.Equal(e, hre.InnerException);
+                },
+                async server =>
+                {
+                    try
+                    {
+                        await server.AcceptConnectionSendResponseAndCloseAsync(content: "foo");
+                    }
+                    catch { }
+                }, options: options);
+        }
+
+        [Theory]
+        [InlineData(true)]
+        [InlineData(false)]
+        public async void PlaintextStreamFilter_ReturnsNull_ThrowsHttpRequestException(bool useSsl)
+        {
+            GenericLoopbackOptions options = new GenericLoopbackOptions() { UseSsl = useSsl };
+            await LoopbackServerFactory.CreateClientAndServerAsync(
+                async uri =>
+                {
+                    HttpRequestMessage requestMessage = new HttpRequestMessage(HttpMethod.Get, uri);
+                    requestMessage.Version = UseVersion;
+                    requestMessage.VersionPolicy = HttpVersionPolicy.RequestVersionExact;
+
+                    using HttpClientHandler handler = CreateHttpClientHandler();
+                    handler.ServerCertificateCustomValidationCallback = TestHelper.AllowAllCertificates;
+                    var socketsHandler = (SocketsHttpHandler)GetUnderlyingSocketsHttpHandler(handler);
+                    socketsHandler.PlaintextStreamFilter = (context, token) =>
+                    {
+                        return ValueTask.FromResult<Stream>(null);
+                    };
+
+                    using HttpClient client = CreateHttpClient(handler);
+
+                    HttpRequestException hre = await Assert.ThrowsAnyAsync<HttpRequestException>(async () => await client.SendAsync(requestMessage));
+                },
+                async server =>
+                {
+                    try
+                    {
+                        await server.AcceptConnectionSendResponseAndCloseAsync(content: "foo");
+                    }
+                    catch { }
+                }, options: options);
+        }
+    }
+
+    public sealed class SocketsHttpHandlerTest_PlaintextStreamFilter_Http11 : SocketsHttpHandlerTest_PlaintextStreamFilter
+    {
+        public SocketsHttpHandlerTest_PlaintextStreamFilter_Http11(ITestOutputHelper output) : base(output) { }
+
+        [Theory]
+        [InlineData(true)]
+        [InlineData(false)]
+        public async void PlaintextStreamFilter_CustomStream_Success(bool useSsl)
+        {
+            GenericLoopbackOptions options = new GenericLoopbackOptions() { UseSsl = useSsl };
+            await LoopbackServerFactory.CreateClientAndServerAsync(
+                async uri =>
+                {
+                    using HttpClientHandler handler = CreateHttpClientHandler();
+                    handler.ServerCertificateCustomValidationCallback = TestHelper.AllowAllCertificates;
+                    var socketsHandler = (SocketsHttpHandler)GetUnderlyingSocketsHttpHandler(handler);
+                    socketsHandler.PlaintextStreamFilter = (context, token) =>
+                    {
+                        Assert.Equal(UseVersion, context.NegotiatedHttpVersion);
+
+                        context.PlaintextStream.Dispose();
+
+                        MemoryStream memoryStream = new MemoryStream();
+                        memoryStream.Write(Encoding.UTF8.GetBytes("HTTP/1.1 200 OK\r\nContent-Length: 3\r\n\r\nfoo"));
+                        memoryStream.Seek(0, SeekOrigin.Begin);
+
+                        DelegateStream newStream = new DelegateStream(
+                            canReadFunc: () => true,
+                            canWriteFunc: () => true,
+                            readAsyncFunc: (buffer, offset, length, cancellationToken) => memoryStream.ReadAsync(buffer, offset, length, cancellationToken),
+                            writeAsyncFunc: (buffer, offset, length, cancellationToken) => Task.CompletedTask);
+
+                        return ValueTask.FromResult<Stream>(newStream);
+                    };
+
+                    using HttpClient client = CreateHttpClient(handler);
+
+                    HttpResponseMessage response = await client.GetAsync(uri);
+                    Assert.Equal("foo", await response.Content.ReadAsStringAsync());
+                },
+                async server =>
+                {
+                    // Client intentionally disconnects. Ignore exception.
+                    try
+                    {
+                        await server.AcceptConnectionSendResponseAndCloseAsync(content: "foo");
+                    }
+                    catch (IOException) { }
+                }, options: options);
+        }
+    }
+
+    [ConditionalClass(typeof(PlatformDetection), nameof(PlatformDetection.SupportsAlpn))]
+    public sealed class SocketsHttpHandlerTest_PlaintextStreamFilter_Http2 : SocketsHttpHandlerTest_PlaintextStreamFilter
+    {
+        public SocketsHttpHandlerTest_PlaintextStreamFilter_Http2(ITestOutputHelper output) : base(output) { }
+        protected override Version UseVersion => HttpVersion.Version20;
+    }
+
     [ConditionalClass(typeof(PlatformDetection), nameof(PlatformDetection.SupportsAlpn))]
     public sealed class SocketsHttpHandlerTest_Cookies_Http2 : HttpClientHandlerTest_Cookies
     {