if (_kind == HttpConnectionKind.Http)
{
- http2Connection = new Http2Connection(this, stream);
- await http2Connection.SetupAsync().ConfigureAwait(false);
-
- AddHttp2Connection(http2Connection);
+ http2Connection = await ConstructHttp2Connection(stream, request, cancellationToken).ConfigureAwait(false);
if (NetEventSource.Log.IsEnabled())
{
if (sslStream.SslProtocol < SslProtocols.Tls12)
{
+ sslStream.Dispose();
throw new HttpRequestException(SR.Format(SR.net_ssl_http2_requires_tls12, sslStream.SslProtocol));
}
- http2Connection = new Http2Connection(this, stream);
- await http2Connection.SetupAsync().ConfigureAwait(false);
-
- AddHttp2Connection(http2Connection);
+ http2Connection = await ConstructHttp2Connection(stream, request, cancellationToken).ConfigureAwait(false);
if (NetEventSource.Log.IsEnabled())
{
if (canUse)
{
- return (ConstructHttp11Connection(stream!, transportContext), true, null);
+ return (await ConstructHttp11Connection(stream!, transportContext, request, cancellationToken).ConfigureAwait(false), true, null);
}
else
{
return (null, failureResponse);
}
- return (ConstructHttp11Connection(stream!, transportContext), null);
+ return (await ConstructHttp11Connection(stream!, transportContext, request, cancellationToken).ConfigureAwait(false), null);
}
private SslClientAuthenticationOptions GetSslOptionsForRequest(HttpRequestMessage request)
return _sslOptionsHttp11!;
}
- private HttpConnection ConstructHttp11Connection(Stream stream, TransportContext? transportContext)
+ private async ValueTask<Stream> ApplyPlaintextFilter(Stream stream, Version httpVersion, HttpRequestMessage request, CancellationToken cancellationToken)
+ {
+ if (Settings._plaintextStreamFilter is null)
+ {
+ return stream;
+ }
+
+ Stream newStream;
+ try
+ {
+ newStream = await Settings._plaintextStreamFilter(new SocketsHttpPlaintextStreamFilterContext(stream, httpVersion, request), cancellationToken).ConfigureAwait(false);
+ }
+ catch (Exception e)
+ {
+ stream.Dispose();
+ throw new HttpRequestException(SR.net_http_exception_during_plaintext_filter, e);
+ }
+
+ if (newStream == null)
+ {
+ stream.Dispose();
+ throw new HttpRequestException(SR.net_http_null_from_plaintext_filter);
+ }
+
+ return newStream;
+ }
+
+ private async ValueTask<HttpConnection> ConstructHttp11Connection(Stream stream, TransportContext? transportContext, HttpRequestMessage request, CancellationToken cancellationToken)
{
+ stream = await ApplyPlaintextFilter(stream, HttpVersion.Version11, request, cancellationToken).ConfigureAwait(false);
return new HttpConnection(this, stream, transportContext);
}
+ private async ValueTask<Http2Connection> ConstructHttp2Connection(Stream stream, HttpRequestMessage request, CancellationToken cancellationToken)
+ {
+ stream = await ApplyPlaintextFilter(stream, HttpVersion.Version20, request, cancellationToken).ConfigureAwait(false);
+
+ Http2Connection http2Connection = new Http2Connection(this, stream);
+ await http2Connection.SetupAsync().ConfigureAwait(false);
+
+ AddHttp2Connection(http2Connection);
+
+ return http2Connection;
+ }
+
+
// Returns the established stream or an HttpResponseMessage from the proxy indicating failure.
private async ValueTask<(Stream?, HttpResponseMessage?)> EstablishProxyTunnel(bool async, HttpRequestHeaders? headers, CancellationToken cancellationToken)
{
}
}
+ [Fact]
+ public void PlaintextStreamFilter_GetSet_Roundtrips()
+ {
+ using (var handler = new SocketsHttpHandler())
+ {
+ Assert.Null(handler.PlaintextStreamFilter);
+
+ Func<SocketsHttpPlaintextStreamFilterContext, CancellationToken, ValueTask<Stream>> f = (context, token) => default;
+
+ handler.PlaintextStreamFilter = f;
+ Assert.Equal(f, handler.PlaintextStreamFilter);
+
+ handler.PlaintextStreamFilter = null;
+ Assert.Null(handler.PlaintextStreamFilter);
+ }
+ }
+
[Theory]
[InlineData(false)]
[InlineData(true)]
Assert.True(handler.UseCookies);
Assert.True(handler.UseProxy);
Assert.Null(handler.ConnectCallback);
+ Assert.Null(handler.PlaintextStreamFilter);
Assert.Throws(expectedExceptionType, () => handler.AllowAutoRedirect = false);
Assert.Throws(expectedExceptionType, () => handler.AutomaticDecompression = DecompressionMethods.GZip);
Assert.Throws(expectedExceptionType, () => handler.KeepAlivePingDelay = TimeSpan.FromSeconds(5));
Assert.Throws(expectedExceptionType, () => handler.KeepAlivePingPolicy = HttpKeepAlivePingPolicy.WithActiveRequests);
Assert.Throws(expectedExceptionType, () => handler.ConnectCallback = (context, token) => default);
+ Assert.Throws(expectedExceptionType, () => handler.PlaintextStreamFilter = (context, token) => default);
}
}
}
}
}
-
public abstract class SocketsHttpHandlerTest_ConnectCallback : HttpClientHandlerTestBase
{
public SocketsHttpHandlerTest_ConnectCallback(ITestOutputHelper output) : base(output) { }
{
Assert.Equal(uri.Host, context.DnsEndPoint.Host);
Assert.Equal(uri.Port, context.DnsEndPoint.Port);
- Assert.Equal(requestMessage, context.RequestMessage);
+ Assert.Equal(requestMessage, context.InitialRequestMessage);
Socket s = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
await s.ConnectAsync(context.DnsEndPoint, token);
Assert.Equal(e, hre.InnerException);
}
+ [Theory]
+ [InlineData(true)]
+ [InlineData(false)]
+ public async Task ConnectCallback_ReturnsNull_ThrowsHttpRequestException(bool useSsl)
+ {
+ using HttpClientHandler handler = CreateHttpClientHandler();
+ handler.ServerCertificateCustomValidationCallback = TestHelper.AllowAllCertificates;
+ var socketsHandler = (SocketsHttpHandler)GetUnderlyingSocketsHttpHandler(handler);
+ socketsHandler.ConnectCallback = (context, token) =>
+ {
+ return ValueTask.FromResult<Stream>(null);
+ };
+
+ using HttpClient client = CreateHttpClient(handler);
+ client.DefaultVersionPolicy = HttpVersionPolicy.RequestVersionExact;
+
+ HttpRequestException hre = await Assert.ThrowsAnyAsync<HttpRequestException>(async () => await client.GetAsync($"{(useSsl ? "https" : "http")}://nowhere.invalid/foo"));
+ }
+
private static bool PlatformSupportsUnixDomainSockets => Socket.OSSupportsUnixDomainSockets;
}
protected override Version UseVersion => HttpVersion.Version20;
}
+ public abstract class SocketsHttpHandlerTest_PlaintextStreamFilter : HttpClientHandlerTestBase
+ {
+ public SocketsHttpHandlerTest_PlaintextStreamFilter(ITestOutputHelper output) : base(output) { }
+
+ [Fact]
+ public void PlaintextStreamFilter_SyncRequest_Fails()
+ {
+ using SocketsHttpHandler handler = new SocketsHttpHandler
+ {
+ PlaintextStreamFilter = (context, token) => default,
+ };
+
+ using HttpClient client = CreateHttpClient(handler);
+
+ Assert.ThrowsAny<NotSupportedException>(() => client.Send(new HttpRequestMessage(HttpMethod.Get, "http://bing.com")));
+ }
+
+ [Theory]
+ [InlineData(true)]
+ [InlineData(false)]
+ public async void PlaintextStreamFilter_ContextHasCorrectProperties_Success(bool useSsl)
+ {
+ GenericLoopbackOptions options = new GenericLoopbackOptions() { UseSsl = useSsl };
+ await LoopbackServerFactory.CreateClientAndServerAsync(
+ async uri =>
+ {
+ HttpRequestMessage requestMessage = new HttpRequestMessage(HttpMethod.Get, uri);
+ requestMessage.Version = UseVersion;
+ requestMessage.VersionPolicy = HttpVersionPolicy.RequestVersionExact;
+
+ using HttpClientHandler handler = CreateHttpClientHandler();
+ handler.ServerCertificateCustomValidationCallback = TestHelper.AllowAllCertificates;
+ var socketsHandler = (SocketsHttpHandler)GetUnderlyingSocketsHttpHandler(handler);
+ socketsHandler.PlaintextStreamFilter = (context, token) =>
+ {
+ Assert.Equal(UseVersion, context.NegotiatedHttpVersion);
+ Assert.Equal(requestMessage, context.InitialRequestMessage);
+
+ return ValueTask.FromResult(context.PlaintextStream);
+ };
+
+ using HttpClient client = CreateHttpClient(handler);
+
+ HttpResponseMessage response = await client.SendAsync(requestMessage);
+ Assert.Equal("foo", await response.Content.ReadAsStringAsync());
+ },
+ async server =>
+ {
+ await server.AcceptConnectionSendResponseAndCloseAsync(content: "foo");
+ }, options: options);
+ }
+
+ [Theory]
+ [InlineData(true)]
+ [InlineData(false)]
+ public async void PlaintextStreamFilter_SimpleDelegatingStream_Success(bool useSsl)
+ {
+ GenericLoopbackOptions options = new GenericLoopbackOptions() { UseSsl = useSsl };
+ await LoopbackServerFactory.CreateClientAndServerAsync(
+ async uri =>
+ {
+ using HttpClientHandler handler = CreateHttpClientHandler();
+ handler.ServerCertificateCustomValidationCallback = TestHelper.AllowAllCertificates;
+ var socketsHandler = (SocketsHttpHandler)GetUnderlyingSocketsHttpHandler(handler);
+ socketsHandler.PlaintextStreamFilter = (context, token) =>
+ {
+ Assert.Equal(UseVersion, context.NegotiatedHttpVersion);
+
+ DelegateStream newStream = new DelegateStream(
+ canReadFunc: () => true,
+ canWriteFunc: () => true,
+ readAsyncFunc: context.PlaintextStream.ReadAsync,
+ writeAsyncFunc: context.PlaintextStream.WriteAsync,
+ disposeFunc: (disposing) => { if (disposing) { context.PlaintextStream.Dispose(); } });
+
+ return ValueTask.FromResult<Stream>(newStream);
+ };
+
+ using HttpClient client = CreateHttpClient(handler);
+ client.DefaultVersionPolicy = HttpVersionPolicy.RequestVersionExact;
+
+ using HttpResponseMessage response = await client.GetAsync(uri);
+ Assert.Equal("foo", await response.Content.ReadAsStringAsync());
+ },
+ async server =>
+ {
+ await server.AcceptConnectionSendResponseAndCloseAsync(content: "foo");
+ }, options: options);
+ }
+
+ [Theory]
+ [InlineData(true)]
+ [InlineData(false)]
+ public async Task PlaintextStreamFilter_ConnectionPrefix_Success(bool useSsl)
+ {
+ byte[] RequestPrefix = Encoding.UTF8.GetBytes("request prefix\r\n");
+ byte[] ResponsePrefix = Encoding.UTF8.GetBytes("response prefix\r\n");
+
+ using var listenSocket = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
+ listenSocket.Bind(new IPEndPoint(IPAddress.Loopback, 0));
+ listenSocket.Listen();
+
+ Task clientTask = Task.Run(async () =>
+ {
+ using HttpClientHandler handler = CreateHttpClientHandler();
+ handler.ServerCertificateCustomValidationCallback = TestHelper.AllowAllCertificates;
+ var socketsHandler = (SocketsHttpHandler)GetUnderlyingSocketsHttpHandler(handler);
+ socketsHandler.PlaintextStreamFilter = async (context, token) =>
+ {
+ await context.PlaintextStream.WriteAsync(RequestPrefix);
+
+ byte[] buffer = new byte[ResponsePrefix.Length];
+ await context.PlaintextStream.ReadAsync(buffer);
+ Assert.True(buffer.SequenceEqual(ResponsePrefix));
+
+ return context.PlaintextStream;
+ };
+
+ using HttpClient client = CreateHttpClient(handler);
+ client.DefaultVersionPolicy = HttpVersionPolicy.RequestVersionExact;
+
+ string response = await client.GetStringAsync($"{(useSsl ? "https" : "http")}://{listenSocket.LocalEndPoint}/foo");
+ Assert.Equal("foo", response);
+ });
+
+ Task serverTask = Task.Run(async () =>
+ {
+ Socket serverSocket = await listenSocket.AcceptAsync();
+ Stream serverStream = new NetworkStream(serverSocket, ownsSocket: true);
+
+ if (useSsl)
+ {
+ var sslStream = new SslStream(serverStream, false, delegate { return true; });
+
+ using (X509Certificate2 cert = System.Net.Test.Common.Configuration.Certificates.GetServerCertificate())
+ {
+ SslServerAuthenticationOptions options = new SslServerAuthenticationOptions();
+
+ options.EnabledSslProtocols = SslProtocols.Tls12;
+
+ var protocols = new List<SslApplicationProtocol>();
+ protocols.Add(SslApplicationProtocol.Http2);
+ options.ApplicationProtocols = protocols;
+
+ options.ServerCertificate = cert;
+
+ await sslStream.AuthenticateAsServerAsync(options, CancellationToken.None).ConfigureAwait(false);
+ }
+
+ serverStream = sslStream;
+ }
+
+ byte[] buffer = new byte[RequestPrefix.Length];
+ await serverStream.ReadAsync(buffer);
+ Assert.True(buffer.SequenceEqual(RequestPrefix));
+
+ await serverStream.WriteAsync(ResponsePrefix);
+
+ using GenericLoopbackConnection loopbackConnection = await LoopbackServerFactory.CreateConnectionAsync(socket: null, serverStream, new GenericLoopbackOptions() { UseSsl = false });
+ await loopbackConnection.InitializeConnectionAsync();
+
+ HttpRequestData requestData = await loopbackConnection.ReadRequestDataAsync();
+ Assert.Equal("/foo", requestData.Path);
+
+ await loopbackConnection.SendResponseAsync(content: "foo");
+ });
+
+ await new Task[] { clientTask, serverTask }.WhenAllOrAnyFailed();
+ }
+
+ [Theory]
+ [InlineData(true)]
+ [InlineData(false)]
+ public async void PlaintextStreamFilter_ExceptionDuringCallback_ThrowsHttpRequestExceptionWithInnerException(bool useSsl)
+ {
+ Exception e = new Exception("hello!");
+
+ GenericLoopbackOptions options = new GenericLoopbackOptions() { UseSsl = useSsl };
+ await LoopbackServerFactory.CreateClientAndServerAsync(
+ async uri =>
+ {
+ HttpRequestMessage requestMessage = new HttpRequestMessage(HttpMethod.Get, uri);
+ requestMessage.Version = UseVersion;
+ requestMessage.VersionPolicy = HttpVersionPolicy.RequestVersionExact;
+
+ using HttpClientHandler handler = CreateHttpClientHandler();
+ handler.ServerCertificateCustomValidationCallback = TestHelper.AllowAllCertificates;
+ var socketsHandler = (SocketsHttpHandler)GetUnderlyingSocketsHttpHandler(handler);
+ socketsHandler.PlaintextStreamFilter = (context, token) =>
+ {
+ throw e;
+ };
+
+ using HttpClient client = CreateHttpClient(handler);
+
+ HttpRequestException hre = await Assert.ThrowsAnyAsync<HttpRequestException>(async () => await client.SendAsync(requestMessage));
+ Assert.Equal(e, hre.InnerException);
+ },
+ async server =>
+ {
+ try
+ {
+ await server.AcceptConnectionSendResponseAndCloseAsync(content: "foo");
+ }
+ catch { }
+ }, options: options);
+ }
+
+ [Theory]
+ [InlineData(true)]
+ [InlineData(false)]
+ public async void PlaintextStreamFilter_ReturnsNull_ThrowsHttpRequestException(bool useSsl)
+ {
+ GenericLoopbackOptions options = new GenericLoopbackOptions() { UseSsl = useSsl };
+ await LoopbackServerFactory.CreateClientAndServerAsync(
+ async uri =>
+ {
+ HttpRequestMessage requestMessage = new HttpRequestMessage(HttpMethod.Get, uri);
+ requestMessage.Version = UseVersion;
+ requestMessage.VersionPolicy = HttpVersionPolicy.RequestVersionExact;
+
+ using HttpClientHandler handler = CreateHttpClientHandler();
+ handler.ServerCertificateCustomValidationCallback = TestHelper.AllowAllCertificates;
+ var socketsHandler = (SocketsHttpHandler)GetUnderlyingSocketsHttpHandler(handler);
+ socketsHandler.PlaintextStreamFilter = (context, token) =>
+ {
+ return ValueTask.FromResult<Stream>(null);
+ };
+
+ using HttpClient client = CreateHttpClient(handler);
+
+ HttpRequestException hre = await Assert.ThrowsAnyAsync<HttpRequestException>(async () => await client.SendAsync(requestMessage));
+ },
+ async server =>
+ {
+ try
+ {
+ await server.AcceptConnectionSendResponseAndCloseAsync(content: "foo");
+ }
+ catch { }
+ }, options: options);
+ }
+ }
+
+ public sealed class SocketsHttpHandlerTest_PlaintextStreamFilter_Http11 : SocketsHttpHandlerTest_PlaintextStreamFilter
+ {
+ public SocketsHttpHandlerTest_PlaintextStreamFilter_Http11(ITestOutputHelper output) : base(output) { }
+
+ [Theory]
+ [InlineData(true)]
+ [InlineData(false)]
+ public async void PlaintextStreamFilter_CustomStream_Success(bool useSsl)
+ {
+ GenericLoopbackOptions options = new GenericLoopbackOptions() { UseSsl = useSsl };
+ await LoopbackServerFactory.CreateClientAndServerAsync(
+ async uri =>
+ {
+ using HttpClientHandler handler = CreateHttpClientHandler();
+ handler.ServerCertificateCustomValidationCallback = TestHelper.AllowAllCertificates;
+ var socketsHandler = (SocketsHttpHandler)GetUnderlyingSocketsHttpHandler(handler);
+ socketsHandler.PlaintextStreamFilter = (context, token) =>
+ {
+ Assert.Equal(UseVersion, context.NegotiatedHttpVersion);
+
+ context.PlaintextStream.Dispose();
+
+ MemoryStream memoryStream = new MemoryStream();
+ memoryStream.Write(Encoding.UTF8.GetBytes("HTTP/1.1 200 OK\r\nContent-Length: 3\r\n\r\nfoo"));
+ memoryStream.Seek(0, SeekOrigin.Begin);
+
+ DelegateStream newStream = new DelegateStream(
+ canReadFunc: () => true,
+ canWriteFunc: () => true,
+ readAsyncFunc: (buffer, offset, length, cancellationToken) => memoryStream.ReadAsync(buffer, offset, length, cancellationToken),
+ writeAsyncFunc: (buffer, offset, length, cancellationToken) => Task.CompletedTask);
+
+ return ValueTask.FromResult<Stream>(newStream);
+ };
+
+ using HttpClient client = CreateHttpClient(handler);
+
+ HttpResponseMessage response = await client.GetAsync(uri);
+ Assert.Equal("foo", await response.Content.ReadAsStringAsync());
+ },
+ async server =>
+ {
+ // Client intentionally disconnects. Ignore exception.
+ try
+ {
+ await server.AcceptConnectionSendResponseAndCloseAsync(content: "foo");
+ }
+ catch (IOException) { }
+ }, options: options);
+ }
+ }
+
+ [ConditionalClass(typeof(PlatformDetection), nameof(PlatformDetection.SupportsAlpn))]
+ public sealed class SocketsHttpHandlerTest_PlaintextStreamFilter_Http2 : SocketsHttpHandlerTest_PlaintextStreamFilter
+ {
+ public SocketsHttpHandlerTest_PlaintextStreamFilter_Http2(ITestOutputHelper output) : base(output) { }
+ protected override Version UseVersion => HttpVersion.Version20;
+ }
+
[ConditionalClass(typeof(PlatformDetection), nameof(PlatformDetection.SupportsAlpn))]
public sealed class SocketsHttpHandlerTest_Cookies_Http2 : HttpClientHandlerTest_Cookies
{