Fix race condition in timeout handling and GetAsync_ContentCanBeCanceled (#44169)
authorStephen Toub <stoub@microsoft.com>
Tue, 3 Nov 2020 10:55:04 +0000 (05:55 -0500)
committerGitHub <noreply@github.com>
Tue, 3 Nov 2020 10:55:04 +0000 (05:55 -0500)
There are three issues here.

Two tests issues:
1. When the client times out, it could do so while the server is still reading the request or sending the response, in which case the server can fail and cause the test to fail.  The server failure needs to be ignored.
2. when the client times out, it could do so before it even initiates the request, in which case the test would hang while the server waits indefinitely for a request that'll never arrive.  The server waiting needs to factor in the client's completion.

And one product issue:
1. The timeout handling we added in .NET 5 has a race condition that can, in extreme conditions (e.g. a unit test trying to force the interleaving) result in us not throwing the appropriate TimeoutException.  There are two different timer mechanisms used, the one based on Timer inside of CancellationTokenSource, and then the use of Envrionment.TickCount64 to track how much time has progressed.  Their quantums are different, so it's possible for the timer to fire due to the timeout expiring but Environment.TickCount64 not ticking over, at which point we'll read it and determine that the cause of the cancellation wasn't for a timeout.  The fix is to use the pending requests token source to check whether it or the timeout occurred, rather than using the time elapsed to determine which of the two occurred.

src/libraries/System.Net.Http/src/System/Net/Http/HttpClient.cs
src/libraries/System.Net.Http/tests/FunctionalTests/HttpClientTest.cs

index cb68aaa..0a9aa07 100644 (file)
@@ -172,7 +172,7 @@ namespace System.Net.Http
             bool telemetryStarted = StartSend(request);
             bool responseContentTelemetryStarted = false;
 
-            (CancellationTokenSource cts, bool disposeCts, long timeoutTime) = PrepareCancellationTokenSource(cancellationToken);
+            (CancellationTokenSource cts, bool disposeCts, CancellationTokenSource pendingRequestsCts) = PrepareCancellationTokenSource(cancellationToken);
             HttpResponseMessage? response = null;
             try
             {
@@ -214,7 +214,7 @@ namespace System.Net.Http
             }
             catch (Exception e)
             {
-                HandleFailure(e, telemetryStarted, response, cts, cancellationToken, timeoutTime);
+                HandleFailure(e, telemetryStarted, response, cts, cancellationToken, pendingRequestsCts);
                 throw;
             }
             finally
@@ -247,7 +247,7 @@ namespace System.Net.Http
             bool telemetryStarted = StartSend(request);
             bool responseContentTelemetryStarted = false;
 
-            (CancellationTokenSource cts, bool disposeCts, long timeoutTime) = PrepareCancellationTokenSource(cancellationToken);
+            (CancellationTokenSource cts, bool disposeCts, CancellationTokenSource pendingRequestsCts) = PrepareCancellationTokenSource(cancellationToken);
             HttpResponseMessage? response = null;
             try
             {
@@ -293,7 +293,7 @@ namespace System.Net.Http
             }
             catch (Exception e)
             {
-                HandleFailure(e, telemetryStarted, response, cts, cancellationToken, timeoutTime);
+                HandleFailure(e, telemetryStarted, response, cts, cancellationToken, pendingRequestsCts);
                 throw;
             }
             finally
@@ -325,7 +325,7 @@ namespace System.Net.Http
         {
             bool telemetryStarted = StartSend(request);
 
-            (CancellationTokenSource cts, bool disposeCts, long timeoutTime) = PrepareCancellationTokenSource(cancellationToken);
+            (CancellationTokenSource cts, bool disposeCts, CancellationTokenSource pendingRequestsCts) = PrepareCancellationTokenSource(cancellationToken);
             HttpResponseMessage? response = null;
             try
             {
@@ -339,7 +339,7 @@ namespace System.Net.Http
             }
             catch (Exception e)
             {
-                HandleFailure(e, telemetryStarted, response, cts, cancellationToken, timeoutTime);
+                HandleFailure(e, telemetryStarted, response, cts, cancellationToken, pendingRequestsCts);
                 throw;
             }
             finally
@@ -458,8 +458,8 @@ namespace System.Net.Http
             // Called outside of async state machine to propagate certain exception even without awaiting the returned task.
             CheckRequestBeforeSend(request);
 
-            (CancellationTokenSource cts, bool disposeCts, long timeoutTime) = PrepareCancellationTokenSource(cancellationToken);
-            ValueTask<HttpResponseMessage> sendTask = SendAsyncCore(request, completionOption, async: false, cts, disposeCts, timeoutTime, cancellationToken);
+            (CancellationTokenSource cts, bool disposeCts, CancellationTokenSource pendingRequestsCts) = PrepareCancellationTokenSource(cancellationToken);
+            ValueTask<HttpResponseMessage> sendTask = SendAsyncCore(request, completionOption, async: false, cts, disposeCts, pendingRequestsCts, cancellationToken);
             Debug.Assert(sendTask.IsCompleted);
             return sendTask.GetAwaiter().GetResult();
         }
@@ -478,8 +478,8 @@ namespace System.Net.Http
             // Called outside of async state machine to propagate certain exception even without awaiting the returned task.
             CheckRequestBeforeSend(request);
 
-            (CancellationTokenSource cts, bool disposeCts, long timeoutTime) = PrepareCancellationTokenSource(cancellationToken);
-            return SendAsyncCore(request, completionOption, async: true, cts, disposeCts, timeoutTime, cancellationToken).AsTask();
+            (CancellationTokenSource cts, bool disposeCts, CancellationTokenSource pendingRequestsCts) = PrepareCancellationTokenSource(cancellationToken);
+            return SendAsyncCore(request, completionOption, async: true, cts, disposeCts, pendingRequestsCts, cancellationToken).AsTask();
         }
 
         private void CheckRequestBeforeSend(HttpRequestMessage request)
@@ -499,7 +499,8 @@ namespace System.Net.Http
 
         private async ValueTask<HttpResponseMessage> SendAsyncCore(
             HttpRequestMessage request, HttpCompletionOption completionOption,
-            bool async, CancellationTokenSource cts, bool disposeCts, long timeoutTime, CancellationToken originalCancellationToken)
+            bool async, CancellationTokenSource cts, bool disposeCts,
+            CancellationTokenSource pendingRequestsCts, CancellationToken originalCancellationToken)
         {
             bool telemetryStarted = StartSend(request);
             bool responseContentTelemetryStarted = false;
@@ -537,7 +538,7 @@ namespace System.Net.Http
             }
             catch (Exception e)
             {
-                HandleFailure(e, telemetryStarted, response, cts, originalCancellationToken, timeoutTime);
+                HandleFailure(e, telemetryStarted, response, cts, originalCancellationToken, pendingRequestsCts);
                 throw;
             }
             finally
@@ -554,7 +555,7 @@ namespace System.Net.Http
             }
         }
 
-        private void HandleFailure(Exception e, bool telemetryStarted, HttpResponseMessage? response, CancellationTokenSource cts, CancellationToken cancellationToken, long timeoutTime)
+        private void HandleFailure(Exception e, bool telemetryStarted, HttpResponseMessage? response, CancellationTokenSource cts, CancellationToken cancellationToken, CancellationTokenSource pendingRequestsCts)
         {
             LogRequestFailed(telemetryStarted);
 
@@ -562,10 +563,10 @@ namespace System.Net.Http
 
             Exception? toThrow = null;
 
-            if (e is OperationCanceledException oce && !cancellationToken.IsCancellationRequested && Environment.TickCount64 >= timeoutTime)
+            if (e is OperationCanceledException oce && !cancellationToken.IsCancellationRequested && !pendingRequestsCts.IsCancellationRequested)
             {
-                // If this exception is for cancellation, but cancellation wasn't requested and instead we find that we've passed a timeout end time,
-                // treat this instead as a timeout.
+                // If this exception is for cancellation, but cancellation wasn't requested, either by the caller's token or by the pending requests source,
+                // the only other cause could be a timeout.  Treat it as such.
                 e = toThrow = new TaskCanceledException(string.Format(SR.net_http_request_timedout, _timeout.TotalSeconds), new TimeoutException(e.Message, e), oce.CancellationToken);
             }
             else if (cts.IsCancellationRequested && e is HttpRequestException) // if cancellationToken is canceled, cts will also be canceled
@@ -745,28 +746,33 @@ namespace System.Net.Http
             }
         }
 
-        private (CancellationTokenSource TokenSource, bool DisposeTokenSource, long TimeoutTime) PrepareCancellationTokenSource(CancellationToken cancellationToken)
+        private (CancellationTokenSource TokenSource, bool DisposeTokenSource, CancellationTokenSource PendingRequestsCts) PrepareCancellationTokenSource(CancellationToken cancellationToken)
         {
             // We need a CancellationTokenSource to use with the request.  We always have the global
             // _pendingRequestsCts to use, plus we may have a token provided by the caller, and we may
             // have a timeout.  If we have a timeout or a caller-provided token, we need to create a new
             // CTS (we can't, for example, timeout the pending requests CTS, as that could cancel other
             // unrelated operations).  Otherwise, we can use the pending requests CTS directly.
+
+            // Snapshot the current pending requests cancellation source. It can change concurrently due to cancellation being requested
+            // and it being replaced, and we need a stable view of it: if cancellation occurs and the caller's token hasn't been canceled,
+            // it's either due to this source or due to the timeout, and checking whether this source is the culprit is reliable whereas
+            // it's more approximate checking elapsed time.
+            CancellationTokenSource pendingRequestsCts = _pendingRequestsCts;
+
             bool hasTimeout = _timeout != s_infiniteTimeout;
-            long timeoutTime = long.MaxValue;
             if (hasTimeout || cancellationToken.CanBeCanceled)
             {
-                CancellationTokenSource cts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken, _pendingRequestsCts.Token);
+                CancellationTokenSource cts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken, pendingRequestsCts.Token);
                 if (hasTimeout)
                 {
-                    timeoutTime = Environment.TickCount64 + (_timeout.Ticks / TimeSpan.TicksPerMillisecond);
                     cts.CancelAfter(_timeout);
                 }
 
-                return (cts, DisposeTokenSource: true, timeoutTime);
+                return (cts, DisposeTokenSource: true, pendingRequestsCts);
             }
 
-            return (_pendingRequestsCts, DisposeTokenSource: false, timeoutTime);
+            return (pendingRequestsCts, DisposeTokenSource: false, pendingRequestsCts);
         }
 
         private static void CheckBaseAddress(Uri? baseAddress, string parameterName)
index a129b72..94a15be 100644 (file)
@@ -479,7 +479,7 @@ namespace System.Net.Http.Functional.Tests
                 },
                 async server =>
                 {
-                    await server.AcceptConnectionAsync(async connection =>
+                    Task serverHandling = server.AcceptConnectionAsync(async connection =>
                     {
                         await connection.ReadRequestDataAsync(readBody: false);
                         await connection.SendResponseAsync(HttpStatusCode.OK, headers: new HttpHeaderData[] { new HttpHeaderData("Content-Length", "5") });
@@ -497,11 +497,20 @@ namespace System.Net.Http.Functional.Tests
                                 httpClient.CancelPendingRequests();
                                 break;
 
-                            // case 2: timeout fires on its own
+                                // case 2: timeout fires on its own
                         }
 
                         await tcs.Task;
                     });
+
+                    // The client may have completed before even sending a request when testing HttpClient.Timeout.
+                    await Task.WhenAny(serverHandling, tcs.Task);
+                    if (cancelMode != 2)
+                    {
+                        // If using a timeout to cancel requests, it's possible the server's processing could have gotten interrupted,
+                        // so we want to ignore any exceptions from the server when in that mode.  For anything else, let exceptions propagate.
+                        await serverHandling;
+                    }
                 });
         }