Implement SocketsHttpHandler's Expect100ContinueTimeout and ConnectTimeout (dotnet...
authorStephen Toub <stoub@microsoft.com>
Mon, 19 Feb 2018 16:06:18 +0000 (11:06 -0500)
committerGitHub <noreply@github.com>
Mon, 19 Feb 2018 16:06:18 +0000 (11:06 -0500)
With the expectation that we'll want to expose this in 2.1, implement Expect100ContinueTimeout and ConnectTimeout.  The members are currently internal but can be flipped public easily once the APIs are approved.  This also fixes an issue with cancellation around the connect phase, where a cancellation request that came in during the SSL auth phase would not be respected.

Commit migrated from https://github.com/dotnet/corefx/commit/01fa16ffd618846a913cad719269c8bb441ceb28

src/libraries/Common/src/System/Net/Http/HttpHandlerDefaults.cs
src/libraries/System.Net.Http/src/ILLinkTrim.xml
src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpConnection.cs
src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpConnectionPool.cs
src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpConnectionSettings.cs
src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/SocketsHttpHandler.cs
src/libraries/System.Net.Http/tests/FunctionalTests/SocketsHttpHandlerTest.cs
src/libraries/System.Net.Http/tests/FunctionalTests/System.Net.Http.Functional.Tests.csproj
src/libraries/System.Net.Http/tests/UnitTests/System.Net.Http.Unit.Tests.csproj

index f7269c1..bf60536 100644 (file)
@@ -24,5 +24,7 @@ namespace System.Net.Http
         public const bool DefaultCheckCertificateRevocationList = false;
         public static readonly TimeSpan DefaultPooledConnectionLifetime = Timeout.InfiniteTimeSpan;
         public static readonly TimeSpan DefaultPooledConnectionIdleTimeout = TimeSpan.FromMinutes(2);
+        public static readonly TimeSpan DefaultExpect100ContinueTimeout = TimeSpan.FromSeconds(1);
+        public static readonly TimeSpan DefaultConnectTimeout = Timeout.InfiniteTimeSpan;
     }
 }
index 3957caf..e628fc1 100644 (file)
@@ -2,5 +2,6 @@
   <assembly fullname="System.Net.Http">
     <!-- Anonymous types are used with DiagnosticSource logging and subscribers reflect over those, calling their public getters. -->
     <type fullname="*f__AnonymousType*" />
+    <type fullname="System.Net.Http.SocketsHttpHandler" /> <!-- TODO #27235, #27145: Remove once public -->
   </assembly>
 </linker>
index c6cbcef..6f83457 100644 (file)
@@ -28,11 +28,6 @@ namespace System.Net.Http
         /// <summary>Default size of the write buffer used for the connection.</summary>
         private const int InitialWriteBufferSize = InitialReadBufferSize;
         /// <summary>
-        /// Delay after which we'll send the request payload for ExpectContinue if
-        /// the server hasn't yet responded.
-        /// </summary>
-        private const int Expect100TimeoutMilliseconds = 1000;
-        /// <summary>
         /// Size after which we'll close the connection rather than send the payload in response
         /// to final error status code sent by the server when using Expect: 100-continue.
         /// </summary>
@@ -353,7 +348,7 @@ namespace System.Net.Http
                         allowExpect100ToContinue = new TaskCompletionSource<bool>();
                         var expect100Timer = new Timer(
                             s => ((TaskCompletionSource<bool>)s).TrySetResult(true),
-                            allowExpect100ToContinue, TimeSpan.FromMilliseconds(Expect100TimeoutMilliseconds), Timeout.InfiniteTimeSpan);
+                            allowExpect100ToContinue, _pool.Settings._expect100ContinueTimeout, Timeout.InfiniteTimeSpan);
                         _sendRequestContentTask = SendRequestContentWithExpect100ContinueAsync(
                             request, allowExpect100ToContinue.Task, stream, expect100Timer, cancellationToken);
                     }
@@ -580,10 +575,10 @@ namespace System.Net.Http
             }, _weakThisRef);
         }
 
-        private static bool ShouldWrapInOperationCanceledException(Exception error, CancellationToken cancellationToken) =>
+        internal static bool ShouldWrapInOperationCanceledException(Exception error, CancellationToken cancellationToken) =>
             !(error is OperationCanceledException) && cancellationToken.IsCancellationRequested;
 
-        private static Exception CreateOperationCanceledException(Exception error, CancellationToken cancellationToken) =>
+        internal static Exception CreateOperationCanceledException(Exception error, CancellationToken cancellationToken) =>
             new OperationCanceledException(s_cancellationMessage, error, cancellationToken);
 
         private static bool LineIsEmpty(ArraySegment<byte> line) => line.Count == 0;
index c04b76c..5829e52 100644 (file)
@@ -206,24 +206,60 @@ namespace System.Net.Http
 
         private async ValueTask<HttpConnection> CreateConnectionAsync(HttpRequestMessage request, CancellationToken cancellationToken)
         {
-            Stream stream = await
-                (_proxyUri == null ?
-                    ConnectHelper.ConnectAsync(_host, _port, cancellationToken) :
-                    (_sslOptions == null ?
-                        ConnectHelper.ConnectAsync(_proxyUri.IdnHost, _proxyUri.Port, cancellationToken) :
-                        EstablishProxyTunnel(cancellationToken))).ConfigureAwait(false);
-
-            TransportContext transportContext = null;
-            if (_sslOptions != null)
+            // If a non-infinite connect timeout has been set, create and use a new CancellationToken that'll be canceled
+            // when either the original token is canceled or a connect timeout occurs.
+            CancellationTokenSource cancellationWithConnectTimeout = null;
+            if (Settings._connectTimeout != Timeout.InfiniteTimeSpan)
             {
-                SslStream sslStream = await ConnectHelper.EstablishSslConnectionAsync(_sslOptions, request, stream, cancellationToken).ConfigureAwait(false);
-                stream = sslStream;
-                transportContext = sslStream.TransportContext;
+                cancellationWithConnectTimeout = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken, default);
+                cancellationWithConnectTimeout.CancelAfter(Settings._connectTimeout);
+                cancellationToken = cancellationWithConnectTimeout.Token;
             }
 
-            return _maxConnections == int.MaxValue ?
-                new HttpConnection(this, stream, transportContext) :
-                new HttpConnectionWithFinalizer(this, stream, transportContext); // finalizer needed to signal the pool when a connection is dropped
+            try
+            {
+                Stream stream = await
+                    (_proxyUri == null ?
+                        ConnectHelper.ConnectAsync(_host, _port, cancellationToken) :
+                        (_sslOptions == null ?
+                            ConnectHelper.ConnectAsync(_proxyUri.IdnHost, _proxyUri.Port, cancellationToken) :
+                            EstablishProxyTunnel(cancellationToken))).ConfigureAwait(false);
+
+                TransportContext transportContext = null;
+                if (_sslOptions != null)
+                {
+                    // TODO #25206 and #24430: Register/IsCancellationRequested should be removable once SslStream auth and sockets respect cancellation.
+                    CancellationTokenRegistration ctr = cancellationToken.Register(s => ((Stream)s).Dispose(), stream);
+                    try
+                    {
+                        SslStream sslStream = await ConnectHelper.EstablishSslConnectionAsync(_sslOptions, request, stream, cancellationToken).ConfigureAwait(false);
+                        stream = sslStream;
+                        transportContext = sslStream.TransportContext;
+                        cancellationToken.ThrowIfCancellationRequested(); // to handle race condition where stream is dispose of by cancellation after auth
+                    }
+                    catch (Exception exc)
+                    {
+                        stream.Dispose(); // in case cancellation occurs after successful SSL auth
+                        if (HttpConnection.ShouldWrapInOperationCanceledException(exc, cancellationToken))
+                        {
+                            throw HttpConnection.CreateOperationCanceledException(exc, cancellationToken);
+                        }
+                        throw;
+                    }
+                    finally
+                    {
+                        ctr.Dispose();
+                    }
+                }
+
+                return _maxConnections == int.MaxValue ?
+                    new HttpConnection(this, stream, transportContext) :
+                    new HttpConnectionWithFinalizer(this, stream, transportContext); // finalizer needed to signal the pool when a connection is dropped
+            }
+            finally
+            {
+                cancellationWithConnectTimeout?.Dispose();
+            }
         }
 
         // TODO (#23136):
index 2f12b19..05f0d1c 100644 (file)
@@ -30,6 +30,8 @@ namespace System.Net.Http
 
         internal TimeSpan _pooledConnectionLifetime = HttpHandlerDefaults.DefaultPooledConnectionLifetime;
         internal TimeSpan _pooledConnectionIdleTimeout = HttpHandlerDefaults.DefaultPooledConnectionIdleTimeout;
+        internal TimeSpan _expect100ContinueTimeout = HttpHandlerDefaults.DefaultExpect100ContinueTimeout;
+        internal TimeSpan _connectTimeout = HttpHandlerDefaults.DefaultConnectTimeout;
 
         internal SslClientAuthenticationOptions _sslOptions;
 
@@ -48,8 +50,10 @@ namespace System.Net.Http
                 _allowAutoRedirect = _allowAutoRedirect,
                 _automaticDecompression = _automaticDecompression,
                 _cookieContainer = _cookieContainer,
+                _connectTimeout = _connectTimeout,
                 _credentials = _credentials,
                 _defaultProxyCredentials = _defaultProxyCredentials,
+                _expect100ContinueTimeout = _expect100ContinueTimeout,
                 _maxAutomaticRedirections = _maxAutomaticRedirections,
                 _maxConnectionsPerServer = _maxConnectionsPerServer,
                 _maxResponseHeadersLength = _maxResponseHeadersLength,
index 681c513..455fdcb 100644 (file)
@@ -207,6 +207,38 @@ namespace System.Net.Http
             }
         }
 
+        internal TimeSpan ConnectTimeout // TODO #27235: Expose publicly
+        {
+            get => _settings._connectTimeout;
+            set
+            {
+                if ((value <= TimeSpan.Zero && value != Timeout.InfiniteTimeSpan) ||
+                    (value.TotalMilliseconds > int.MaxValue))
+                {
+                    throw new ArgumentOutOfRangeException(nameof(value));
+                }
+
+                CheckDisposedOrStarted();
+                _settings._connectTimeout = value;
+            }
+        }
+
+        internal TimeSpan Expect100ContinueTimeout // TODO #27145: Expose publicly
+        {
+            get => _settings._expect100ContinueTimeout;
+            set
+            {
+                if ((value < TimeSpan.Zero && value != Timeout.InfiniteTimeSpan) ||
+                    (value.TotalMilliseconds > int.MaxValue))
+                {
+                    throw new ArgumentOutOfRangeException(nameof(value));
+                }
+
+                CheckDisposedOrStarted();
+                _settings._expect100ContinueTimeout = value;
+            }
+        }
+
         public IDictionary<string, object> Properties =>
             _settings._properties ?? (_settings._properties = new Dictionary<string, object>());
 
index c1357e0..162dc19 100644 (file)
@@ -4,11 +4,13 @@
 
 using System.Collections.Concurrent;
 using System.Collections.Generic;
+using System.Diagnostics;
 using System.IO;
 using System.Linq;
 using System.Net.Security;
 using System.Net.Sockets;
 using System.Net.Test.Common;
+using System.Reflection;
 using System.Security.Authentication;
 using System.Security.Cryptography.X509Certificates;
 using System.Text;
@@ -129,6 +131,209 @@ namespace System.Net.Http.Functional.Tests
     public sealed class SocketsHttpHandler_HttpClientHandler_Cancellation_Test : HttpClientHandler_Cancellation_Test
     {
         protected override bool UseSocketsHttpHandler => true;
+
+        // TODO #27235:
+        // Remove these reflection helpers once the property is exposed.
+        private TimeSpan GetConnectTimeout(SocketsHttpHandler handler) =>
+            (TimeSpan)typeof(SocketsHttpHandler).GetProperty("ConnectTimeout", BindingFlags.Instance | BindingFlags.NonPublic).GetValue(handler);
+        private void SetConnectTimeout(SocketsHttpHandler handler, TimeSpan timeout)
+        {
+            try
+            {
+                typeof(SocketsHttpHandler).GetProperty("ConnectTimeout", BindingFlags.Instance | BindingFlags.NonPublic).SetValue(handler, timeout);
+            }
+            catch (TargetInvocationException tie)
+            {
+                if (tie.InnerException != null) throw tie.InnerException;
+                throw;
+            }
+        }
+
+        // TODO #27145:
+        // Remove these reflection helpers once the property is exposed.
+        private TimeSpan GetExpect100ContinueTimeout(SocketsHttpHandler handler) =>
+            (TimeSpan)typeof(SocketsHttpHandler).GetProperty("Expect100ContinueTimeout", BindingFlags.Instance | BindingFlags.NonPublic).GetValue(handler);
+        private void SetExpect100ContinueTimeout(SocketsHttpHandler handler, TimeSpan timeout)
+        {
+            try
+            {
+                typeof(SocketsHttpHandler).GetProperty("Expect100ContinueTimeout", BindingFlags.Instance | BindingFlags.NonPublic).SetValue(handler, timeout);
+            }
+            catch (TargetInvocationException tie)
+            {
+                if (tie.InnerException != null) throw tie.InnerException;
+                throw;
+            }
+        }
+
+        [Fact]
+        public void ConnectTimeout_Default()
+        {
+            using (var handler = new SocketsHttpHandler())
+            {
+                Assert.Equal(Timeout.InfiniteTimeSpan, GetConnectTimeout(handler));
+            }
+        }
+
+        [Theory]
+        [InlineData(0)]
+        [InlineData(-2)]
+        [InlineData(int.MaxValue + 1L)]
+        public void ConnectTimeout_InvalidValues(long ms)
+        {
+            using (var handler = new SocketsHttpHandler())
+            {
+                Assert.Throws<ArgumentOutOfRangeException>(() => SetConnectTimeout(handler, TimeSpan.FromMilliseconds(ms)));
+            }
+        }
+
+        [Theory]
+        [InlineData(-1)]
+        [InlineData(1)]
+        [InlineData(int.MaxValue - 1)]
+        [InlineData(int.MaxValue)]
+        public void ConnectTimeout_ValidValues_Roundtrip(long ms)
+        {
+            using (var handler = new SocketsHttpHandler())
+            {
+                SetConnectTimeout(handler, TimeSpan.FromMilliseconds(ms));
+                Assert.Equal(TimeSpan.FromMilliseconds(ms), GetConnectTimeout(handler));
+            }
+        }
+
+        [Fact]
+        public void ConnectTimeout_SetAfterUse_Throws()
+        {
+            using (var handler = new SocketsHttpHandler())
+            using (var client = new HttpClient(handler))
+            {
+                SetConnectTimeout(handler, TimeSpan.FromMilliseconds(int.MaxValue));
+                client.GetAsync("http://" + Guid.NewGuid().ToString("N")); // ignoring failure
+                Assert.Equal(TimeSpan.FromMilliseconds(int.MaxValue), GetConnectTimeout(handler));
+                Assert.Throws<InvalidOperationException>(() => SetConnectTimeout(handler, TimeSpan.FromMilliseconds(1)));
+            }
+        }
+
+        [OuterLoop]
+        [Fact]
+        public async Task ConnectTimeout_TimesOutSSLAuth_Throws()
+        {
+            var releaseServer = new TaskCompletionSource<bool>();
+            await LoopbackServer.CreateClientAndServerAsync(async uri =>
+            {
+                using (var handler = new SocketsHttpHandler())
+                using (var invoker = new HttpMessageInvoker(handler))
+                {
+                    SetConnectTimeout(handler, TimeSpan.FromSeconds(1));
+
+                    var sw = Stopwatch.StartNew();
+                    await Assert.ThrowsAsync<OperationCanceledException>(() =>
+                        invoker.SendAsync(new HttpRequestMessage(HttpMethod.Get,
+                            new UriBuilder(uri) { Scheme = "https" }.ToString()), default));
+                    sw.Stop();
+
+                    Assert.InRange(sw.ElapsedMilliseconds, 500, 10_000);
+                    releaseServer.SetResult(true);
+                }
+            }, server => releaseServer.Task); // doesn't establish SSL connection
+        }
+
+
+        [Fact]
+        public void Expect100ContinueTimeout_Default()
+        {
+            using (var handler = new SocketsHttpHandler())
+            {
+                Assert.Equal(TimeSpan.FromSeconds(1), GetExpect100ContinueTimeout(handler));
+            }
+        }
+
+        [Theory]
+        [InlineData(-2)]
+        [InlineData(int.MaxValue + 1L)]
+        public void Expect100ContinueTimeout_InvalidValues(long ms)
+        {
+            using (var handler = new SocketsHttpHandler())
+            {
+                Assert.Throws<ArgumentOutOfRangeException>(() => SetExpect100ContinueTimeout(handler, TimeSpan.FromMilliseconds(ms)));
+            }
+        }
+
+        [Theory]
+        [InlineData(-1)]
+        [InlineData(1)]
+        [InlineData(int.MaxValue - 1)]
+        [InlineData(int.MaxValue)]
+        public void Expect100ContinueTimeout_ValidValues_Roundtrip(long ms)
+        {
+            using (var handler = new SocketsHttpHandler())
+            {
+                SetExpect100ContinueTimeout(handler, TimeSpan.FromMilliseconds(ms));
+                Assert.Equal(TimeSpan.FromMilliseconds(ms), GetExpect100ContinueTimeout(handler));
+            }
+        }
+
+        [Fact]
+        public void Expect100ContinueTimeout_SetAfterUse_Throws()
+        {
+            using (var handler = new SocketsHttpHandler())
+            using (var client = new HttpClient(handler))
+            {
+                SetExpect100ContinueTimeout(handler, TimeSpan.FromMilliseconds(int.MaxValue));
+                client.GetAsync("http://" + Guid.NewGuid().ToString("N")); // ignoring failure
+                Assert.Equal(TimeSpan.FromMilliseconds(int.MaxValue), GetExpect100ContinueTimeout(handler));
+                Assert.Throws<InvalidOperationException>(() => SetExpect100ContinueTimeout(handler, TimeSpan.FromMilliseconds(1)));
+            }
+        }
+
+        [OuterLoop("Incurs significant delay")]
+        [Fact]
+        public async Task Expect100Continue_WaitsExpectedPeriodOfTimeBeforeSendingContent()
+        {
+            await LoopbackServer.CreateClientAndServerAsync(async uri =>
+            {
+                using (var handler = new SocketsHttpHandler())
+                using (var invoker = new HttpMessageInvoker(handler))
+                {
+                    TimeSpan delay = TimeSpan.FromSeconds(3);
+
+                    // TODO #27145: Remove reflection once publicly exposed
+                    typeof(SocketsHttpHandler).GetProperty("Expect100ContinueTimeout", BindingFlags.Instance | BindingFlags.NonPublic).SetValue(handler, delay);
+
+                    var tcs = new TaskCompletionSource<bool>();
+                    var content = new SetTcsContent(new MemoryStream(new byte[1]), tcs);
+                    var request = new HttpRequestMessage(HttpMethod.Post, uri) { Content = content };
+                    request.Headers.ExpectContinue = true;
+
+                    var sw = Stopwatch.StartNew();
+                    (await invoker.SendAsync(request, default)).Dispose();
+                    sw.Stop();
+
+                    Assert.InRange(sw.Elapsed, delay - TimeSpan.FromSeconds(.5), delay * 5); // arbitrary wiggle room
+                }
+            }, async server =>
+            {
+                await server.AcceptConnectionAsync(async connection =>
+                {
+                    await connection.ReadRequestHeaderAsync();
+                    await connection.Reader.ReadAsync(new char[1]);
+                    await connection.SendResponseAsync();
+                });
+            });
+        }
+
+        private sealed class SetTcsContent : StreamContent
+        {
+            private readonly TaskCompletionSource<bool> _tcs;
+
+            public SetTcsContent(Stream stream, TaskCompletionSource<bool> tcs) : base(stream) => _tcs = tcs;
+
+            protected override Task SerializeToStreamAsync(Stream stream, TransportContext context)
+            {
+                _tcs.SetResult(true);
+                return base.SerializeToStreamAsync(stream, context);
+            }
+        }
     }
 
     public sealed class SocketsHttpHandler_HttpClientHandler_MaxResponseHeadersLength_Test : HttpClientHandler_MaxResponseHeadersLength_Test
index 523c9d8..3446979 100644 (file)
   <ItemGroup Condition="'$(TargetsOSX)'=='true'">
     <TestCommandLines Include="ulimit -Sn 4096" />
   </ItemGroup>
+  <ItemGroup>
+    <Service Include="{82A7F48D-3B50-4B1E-B82E-3ADA8210C358}" />
+  </ItemGroup>
   <Import Project="$([MSBuild]::GetDirectoryNameOfFileAbove($(MSBuildThisFileDirectory), dir.targets))\dir.targets" />
-</Project>
+</Project>
\ No newline at end of file
index d299cdd..33dc3a7 100644 (file)
       <Name>RemoteExecutorConsoleApp</Name>
     </ProjectReference>
   </ItemGroup>
+  <ItemGroup>
+    <Service Include="{82A7F48D-3B50-4B1E-B82E-3ADA8210C358}" />
+  </ItemGroup>
   <Import Project="$([MSBuild]::GetDirectoryNameOfFileAbove($(MSBuildThisFileDirectory), dir.targets))\dir.targets" />
-</Project>
+</Project>
\ No newline at end of file