Fix the token included by HttpClient.HandleFailure (#53133)
authorStephen Toub <stoub@microsoft.com>
Mon, 24 May 2021 18:22:10 +0000 (14:22 -0400)
committerGitHub <noreply@github.com>
Mon, 24 May 2021 18:22:10 +0000 (14:22 -0400)
src/libraries/System.Net.Http/src/System/Net/Http/HttpClient.cs
src/libraries/System.Net.Http/tests/FunctionalTests/HttpClientTest.cs

index 5716b66..b09850e 100644 (file)
@@ -583,17 +583,31 @@ namespace System.Net.Http
 
             Exception? toThrow = null;
 
-            if (e is OperationCanceledException oce && !cancellationToken.IsCancellationRequested && !pendingRequestsCts.IsCancellationRequested)
+            if (e is OperationCanceledException oce)
             {
-                // If this exception is for cancellation, but cancellation wasn't requested, either by the caller's token or by the pending requests source,
-                // the only other cause could be a timeout.  Treat it as such.
-                e = toThrow = new TaskCanceledException(SR.Format(SR.net_http_request_timedout, _timeout.TotalSeconds), new TimeoutException(e.Message, e), oce.CancellationToken);
+                if (cancellationToken.IsCancellationRequested)
+                {
+                    if (oce.CancellationToken != cancellationToken)
+                    {
+                        // We got a cancellation exception, and the caller requested cancellation, but the exception doesn't contain that token.
+                        // Massage things so that the cancellation exception we propagate appropriately contains the caller's token (it's possible
+                        // multiple things caused cancellation, in which case we can attribute it to the caller's token, or it's possible the
+                        // exception contains the linked token source, in which case that token isn't meaningful to the caller).
+                        e = toThrow = new TaskCanceledException(oce.Message, oce.InnerException, cancellationToken);
+                    }
+                }
+                else if (!pendingRequestsCts.IsCancellationRequested)
+                {
+                    // If this exception is for cancellation, but cancellation wasn't requested, either by the caller's token or by the pending requests source,
+                    // the only other cause could be a timeout.  Treat it as such.
+                    e = toThrow = new TaskCanceledException(SR.Format(SR.net_http_request_timedout, _timeout.TotalSeconds), new TimeoutException(e.Message, e), oce.CancellationToken);
+                }
             }
-            else if (cts.IsCancellationRequested && e is HttpRequestException) // if cancellationToken is canceled, cts will also be canceled
+            else if (e is HttpRequestException && cts.IsCancellationRequested) // if cancellationToken is canceled, cts will also be canceled
             {
                 // If the cancellation token source was canceled, race conditions abound, and we consider the failure to be
                 // caused by the cancellation (e.g. WebException when reading from canceled response stream).
-                e = toThrow = new OperationCanceledException(cts.Token);
+                e = toThrow = new OperationCanceledException(cancellationToken.IsCancellationRequested ? cancellationToken : cts.Token);
             }
 
             if (NetEventSource.Log.IsEnabled()) NetEventSource.Error(this, e);
index 883bb8c..06d3f2a 100644 (file)
@@ -289,11 +289,18 @@ namespace System.Net.Http.Functional.Tests
 
                 cts.Cancel();
 
-                await Assert.ThrowsAsync<TaskCanceledException>(() => t1);
-                await Assert.ThrowsAsync<TaskCanceledException>(() => t2);
-                await Assert.ThrowsAsync<TaskCanceledException>(() => t3);
-                await Assert.ThrowsAsync<TaskCanceledException>(() => t4);
-                await Assert.ThrowsAsync<TaskCanceledException>(() => t5);
+                async Task ValidateCancellationAsync(Task t)
+                {
+                    TaskCanceledException tce = await Assert.ThrowsAsync<TaskCanceledException>(() => t);
+                    Assert.Equal(cts.Token, tce.CancellationToken);
+
+                }
+
+                await ValidateCancellationAsync(t1);
+                await ValidateCancellationAsync(t2);
+                await ValidateCancellationAsync(t3);
+                await ValidateCancellationAsync(t4);
+                await ValidateCancellationAsync(t5);
             }
         }
 
@@ -382,7 +389,9 @@ namespace System.Net.Http.Functional.Tests
                     var cts = new CancellationTokenSource();
                     cts.Cancel();
 
-                    await Assert.ThrowsAsync<TaskCanceledException>(() => httpClient.GetStringAsync(uri, cts.Token));
+                    TaskCanceledException tce = await Assert.ThrowsAsync<TaskCanceledException>(() => httpClient.GetStringAsync(uri, cts.Token));
+                    Assert.Equal(cts.Token, tce.CancellationToken);
+
                     onClientFinished.Release();
                 },
                 async server =>
@@ -402,7 +411,8 @@ namespace System.Net.Http.Functional.Tests
                 {
                     using HttpClient httpClient = CreateHttpClient();
 
-                    await Assert.ThrowsAsync<TaskCanceledException>(() => httpClient.GetStringAsync(uri, cts.Token));
+                    TaskCanceledException tce = await Assert.ThrowsAsync<TaskCanceledException>(() => httpClient.GetStringAsync(uri, cts.Token));
+                    Assert.Equal(cts.Token, tce.CancellationToken);
                 },
                 async server =>
                 {
@@ -549,7 +559,9 @@ namespace System.Net.Http.Functional.Tests
                     var cts = new CancellationTokenSource();
                     cts.Cancel();
 
-                    await Assert.ThrowsAsync<TaskCanceledException>(() => httpClient.GetByteArrayAsync(uri, cts.Token));
+                    TaskCanceledException tce = await Assert.ThrowsAsync<TaskCanceledException>(() => httpClient.GetByteArrayAsync(uri, cts.Token));
+                    Assert.Equal(cts.Token, tce.CancellationToken);
+
                     onClientFinished.Release();
                 },
                 async server =>
@@ -569,7 +581,8 @@ namespace System.Net.Http.Functional.Tests
                 {
                     using HttpClient httpClient = CreateHttpClient();
 
-                    await Assert.ThrowsAsync<TaskCanceledException>(() => httpClient.GetByteArrayAsync(uri, cts.Token));
+                    TaskCanceledException tce = await Assert.ThrowsAsync<TaskCanceledException>(() => httpClient.GetByteArrayAsync(uri, cts.Token));
+                    Assert.Equal(cts.Token, tce.CancellationToken);
                 },
                 async server =>
                 {
@@ -623,7 +636,9 @@ namespace System.Net.Http.Functional.Tests
                     var cts = new CancellationTokenSource();
                     cts.Cancel();
 
-                    await Assert.ThrowsAsync<TaskCanceledException>(() => httpClient.GetStreamAsync(uri, cts.Token));
+                    TaskCanceledException tce = await Assert.ThrowsAsync<TaskCanceledException>(() => httpClient.GetStreamAsync(uri, cts.Token));
+                    Assert.Equal(cts.Token, tce.CancellationToken);
+
                     onClientFinished.Release();
                 },
                 async server =>
@@ -643,7 +658,8 @@ namespace System.Net.Http.Functional.Tests
                 {
                     using HttpClient httpClient = CreateHttpClient();
 
-                    await Assert.ThrowsAsync<TaskCanceledException>(() => httpClient.GetStreamAsync(uri, cts.Token));
+                    TaskCanceledException tce = await Assert.ThrowsAsync<TaskCanceledException>(() => httpClient.GetStreamAsync(uri, cts.Token));
+                    Assert.Equal(cts.Token, tce.CancellationToken);
                 },
                 async server =>
                 {
@@ -729,7 +745,7 @@ namespace System.Net.Http.Functional.Tests
                 cts.Cancel();
                 Task<HttpResponseMessage> task = client.GetAsync(CreateFakeUri(), completionOption, token);
                 OperationCanceledException e = await Assert.ThrowsAnyAsync<OperationCanceledException>(async () => await task);
-                Assert.Null(e.InnerException);
+                Assert.Equal(e.CancellationToken, token);
             }
         }
 
@@ -746,7 +762,7 @@ namespace System.Net.Http.Functional.Tests
                 Task<HttpResponseMessage> task = client.GetAsync(CreateFakeUri(), completionOption, cts.Token);
                 cts.Cancel();
                 OperationCanceledException e = Assert.ThrowsAny<OperationCanceledException>(() => task.GetAwaiter().GetResult());
-                Assert.Null(e.InnerException);
+                Assert.Equal(e.CancellationToken, cts.Token);
             }
         }
 
@@ -821,7 +837,8 @@ namespace System.Net.Http.Functional.Tests
 
                 cts.Cancel();
 
-                await Assert.ThrowsAsync<TaskCanceledException>(() => t1);
+                TaskCanceledException tce = await Assert.ThrowsAsync<TaskCanceledException>(() => t1);
+                Assert.Equal(cts.Token, tce.CancellationToken);
             }
         }
 
@@ -966,6 +983,7 @@ namespace System.Net.Http.Functional.Tests
                     });
 
                     TaskCanceledException ex = await Assert.ThrowsAsync<TaskCanceledException>(() => sendTask);
+                    Assert.Equal(cts.Token, ex.CancellationToken);
                     Assert.IsNotType<TimeoutException>(ex.InnerException);
                 },
                 async server =>
@@ -1050,6 +1068,7 @@ namespace System.Net.Http.Functional.Tests
                     });
 
                     TaskCanceledException ex = await Assert.ThrowsAsync<TaskCanceledException>(() => sendTask);
+                    Assert.Equal(cts.Token, ex.CancellationToken);
                     Assert.IsNotType<TimeoutException>(ex.InnerException);
                 },
                 async server =>