improve handling of 100continue for H2 (dotnet/corefx#39869)
authorTomas Weinfurt <tweinfurt@yahoo.com>
Tue, 30 Jul 2019 20:43:57 +0000 (13:43 -0700)
committerGitHub <noreply@github.com>
Tue, 30 Jul 2019 20:43:57 +0000 (13:43 -0700)
* improve handling of 100continue for H2

* feedback from review

Commit migrated from https://github.com/dotnet/corefx/commit/556792fdc3f76af238e3955b3d5716ede558d42f

src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http2Stream.cs
src/libraries/System.Net.Http/tests/FunctionalTests/HttpClientHandlerTest.Http2.cs
src/libraries/System.Net.Http/tests/FunctionalTests/NtAuthTests.cs

index 4ed6157..eff4730 100644 (file)
@@ -111,6 +111,13 @@ namespace System.Net.Http
                 {
                     // Create this here because it can be canceled before SendRequestBodyAsync is even called.
                     _requestBodyCancellationSource = new CancellationTokenSource();
+
+                    if (_request.HasHeaders && _request.Headers.ExpectContinue == true)
+                    {
+                        // Create a TCS for handling Expect: 100-continue semantics. See WaitFor100ContinueAsync.
+                        // Note we need to create this in the constructor, because we can receive a 100 Continue response at any time after the constructor finishes.
+                        _expect100ContinueWaiter = new TaskCompletionSource<bool>(TaskContinuationOptions.RunContinuationsAsynchronously);
+                    }
                 }
 
                 if (NetEventSource.IsEnabled) Trace($"{request}, {nameof(initialWindowSize)}={initialWindowSize}");
@@ -150,7 +157,7 @@ namespace System.Net.Http
                 try
                 {
                     bool sendRequestContent = true;
-                    if (_request.HasHeaders && _request.Headers.ExpectContinue == true)
+                    if (_expect100ContinueWaiter != null)
                     {
                         sendRequestContent = await WaitFor100ContinueAsync(_requestBodyCancellationToken).ConfigureAwait(false);
                     }
@@ -221,14 +228,14 @@ namespace System.Net.Http
                 Debug.Assert(_request.Content != null);
                 if (NetEventSource.IsEnabled) Trace($"Waiting to send request body content for 100-Continue.");
 
-                // Create a TCS that will complete when one of two things occurs:
+                // use TCS created in constructor. It will complete when one of two things occurs:
                 // 1. if a timer fires before we receive the relevant response from the server.
                 // 2. if we receive the relevant response from the server before a timer fires.
                 // In the first case, we could run this continuation synchronously, but in the latter, we shouldn't,
                 // as we could end up starting the body copy operation on the main event loop thread, which could
                 // then starve the processing of other requests.  So, we make the TCS RunContinuationsAsynchronously.
                 bool sendRequestContent;
-                var waiter = _expect100ContinueWaiter = new TaskCompletionSource<bool>(TaskContinuationOptions.RunContinuationsAsynchronously);
+                TaskCompletionSource<bool> waiter = _expect100ContinueWaiter;
                 using (var expect100Timer = new Timer(s =>
                 {
                     var thisRef = (Http2Stream)s;
@@ -240,7 +247,6 @@ namespace System.Net.Http
                     // By now, either we got a response from the server or the timer expired.
                 }
 
-                _expect100ContinueWaiter = null;
                 return sendRequestContent;
             }
 
@@ -382,17 +388,15 @@ namespace System.Net.Http
                                 StatusCode = (HttpStatusCode)statusValue
                             };
 
-                            TaskCompletionSource<bool> expect100ContinueWaiter = _expect100ContinueWaiter;
                             if (statusValue < 200)
                             {
                                 // We do not process headers from 1xx responses.
                                 _responseProtocolState = ResponseProtocolState.ExpectingIgnoredHeaders;
 
-                                if (_response.StatusCode == HttpStatusCode.Continue && expect100ContinueWaiter != null)
+                                if (_response.StatusCode == HttpStatusCode.Continue && _expect100ContinueWaiter != null)
                                 {
                                     if (NetEventSource.IsEnabled) Trace("Received 100-Continue status.");
-                                    expect100ContinueWaiter.TrySetResult(true);
-                                    _expect100ContinueWaiter = null;
+                                    _expect100ContinueWaiter.TrySetResult(true);
                                 }
                             }
                             else
@@ -400,14 +404,13 @@ namespace System.Net.Http
                                 _responseProtocolState = ResponseProtocolState.ExpectingHeaders;
 
                                 // If we are waiting for a 100-continue response, signal the waiter now.
-                                if (expect100ContinueWaiter != null)
+                                if (_expect100ContinueWaiter != null)
                                 {
                                     // If the final status code is >= 300, skip sending the body.
                                     bool shouldSendBody = (statusValue < 300);
 
                                     if (NetEventSource.IsEnabled) Trace($"Expecting 100 Continue but received final status {statusValue}.");
-                                    expect100ContinueWaiter.TrySetResult(shouldSendBody);
-                                    _expect100ContinueWaiter = null;
+                                    _expect100ContinueWaiter.TrySetResult(shouldSendBody);
                                 }
                             }
                         }
index 8369924..0219b37 100644 (file)
@@ -1897,8 +1897,13 @@ namespace System.Net.Http.Functional.Tests
 
             await Http2LoopbackServer.CreateClientAndServerAsync(async url =>
             {
+                using (var handler = new SocketsHttpHandler())
                 using (HttpClient client = CreateHttpClient())
                 {
+                    handler.SslOptions.RemoteCertificateValidationCallback = delegate { return true; };
+                    // Increase default Expect: 100-continue timeout to ensure that we don't accidentally fire the timer and send the request body.
+                    handler.Expect100ContinueTimeout = TimeSpan.FromSeconds(300);
+
                     var request = new HttpRequestMessage(HttpMethod.Post, url);
                     request.Version = new Version(2,0);
                     request.Content = new StringContent(new string('*', 3000));
index 94d5875..2d10585 100644 (file)
@@ -137,5 +137,29 @@ namespace System.Net.Http.Functional.Tests
                 Assert.Equal(expectedStatusCode, response.StatusCode);
             }
         }
+
+        [OuterLoop]
+        [ConditionalTheory(typeof(PlatformDetection), nameof(PlatformDetection.IsWindows), nameof(PlatformDetection.IsNotWindowsNanoServer))]    // HttpListener doesn't support nt auth on non-Windows platforms
+        [InlineData(true, 1023)]
+        [InlineData(true, 1024)]
+        [InlineData(true, 1025)]
+        [InlineData(false, 1023)]
+        [InlineData(false, 1024)]
+        [InlineData(false, 1025)]
+        public async Task PostAsync_NtAuthServer_UseExpect100Header_Success(bool ntlm, int contentSize)
+        {
+            NtAuthServer server = ntlm ? _servers.NtlmServer : _servers.NegotiateServer;
+
+            var handler = new HttpClientHandler() { UseDefaultCredentials = true };
+            using (var client = new HttpClient(handler))
+            {
+                client.DefaultRequestHeaders.ExpectContinue = true;
+                var content = new StringContent(new string('A', contentSize));
+                using (HttpResponseMessage response = await client.PostAsync(server.BaseUrl, content))
+                {
+                    Assert.Equal(HttpStatusCode.OK, response.StatusCode);
+                }
+            }
+        }
     }
 }