Fix ManagedWebSocket.CloseAsync cancellation (#35715)
authorStephen Toub <stoub@microsoft.com>
Sun, 3 May 2020 21:49:22 +0000 (17:49 -0400)
committerGitHub <noreply@github.com>
Sun, 3 May 2020 21:49:22 +0000 (17:49 -0400)
We allow at most one pending receive operation on a web socket, and CloseAsync needs to both send and receive, so if there's a pending receive already, it just reuses / waits for that existing one, which records if it sees a close frame.  However, if the close operation is initiated with a cancellation token that's different from the cancellation token the receive operation was initiated with, the close won't respect the supplied cancellation token because it's just waiting on the existing receive.  The fix is simply to register with this new cancellation token as well when there's an existing receive that we wait on.

src/libraries/Common/src/System/Net/WebSockets/ManagedWebSocket.cs
src/libraries/System.Net.WebSockets.Client/tests/CloseTest.cs

index 6e54b0a..8a2d425 100644 (file)
@@ -1111,6 +1111,7 @@ namespace System.Net.WebSockets
                     {
                         Debug.Assert(!Monitor.IsEntered(StateUpdateLock), $"{nameof(StateUpdateLock)} must never be held when acquiring {nameof(ReceiveAsyncLock)}");
                         Task receiveTask;
+                        bool usingExistingReceive;
                         lock (ReceiveAsyncLock)
                         {
                             // Now that we're holding the ReceiveAsyncLock, double-check that we've not yet received the close frame.
@@ -1127,12 +1128,19 @@ namespace System.Net.WebSockets
                             // a race condition here, e.g. if there's a in-flight receive that completes after we check, but that's fine: worst
                             // case is we then await it, find that it's not what we need, and try again.
                             receiveTask = _lastReceiveAsync;
-                            _lastReceiveAsync = receiveTask = ValidateAndReceiveAsync(receiveTask, closeBuffer, cancellationToken);
+                            Task newReceiveTask = ValidateAndReceiveAsync(receiveTask, closeBuffer, cancellationToken);
+                            usingExistingReceive = ReferenceEquals(receiveTask, newReceiveTask);
+                            _lastReceiveAsync = receiveTask = newReceiveTask;
                         }
 
                         // Wait for whatever receive task we have.  We'll then loop around again to re-check our state.
+                        // If this is an existing receive, and if we have a cancelable token, we need to register with that
+                        // token while we wait, since it may not be the same one that was given to the receive initially.
                         Debug.Assert(receiveTask != null);
-                        await receiveTask.ConfigureAwait(false);
+                        using (usingExistingReceive ? cancellationToken.Register(s => ((ManagedWebSocket)s!).Abort(), this) : default)
+                        {
+                            await receiveTask.ConfigureAwait(false);
+                        }
                     }
                 }
                 finally
index 918c5e6..c0774aa 100644 (file)
@@ -3,6 +3,7 @@
 // See the LICENSE file in the project root for more information.
 
 using System.Collections.Generic;
+using System.Diagnostics;
 using System.Net.Test.Common;
 using System.Text;
 using System.Threading;
@@ -324,5 +325,47 @@ namespace System.Net.WebSockets.Client.Tests
                 }
             }
         }
+
+        [ConditionalFact(nameof(WebSocketsSupported))]
+        [ActiveIssue("https://github.com/dotnet/runtime/issues/34690", TestPlatforms.Windows, TargetFrameworkMonikers.Netcoreapp, TestRuntimes.Mono)]
+        public async Task CloseAsync_CancelableEvenWhenPendingReceive_Throws()
+        {
+            var tcs = new TaskCompletionSource<bool>(TaskCreationOptions.RunContinuationsAsynchronously);
+
+            await LoopbackServer.CreateClientAndServerAsync(async uri =>
+            {
+                try
+                {
+                    using (var cws = new ClientWebSocket())
+                    using (var cts = new CancellationTokenSource(TimeOutMilliseconds))
+                    {
+                        await cws.ConnectAsync(uri, cts.Token);
+
+                        Task receiveTask = cws.ReceiveAsync(new byte[1], CancellationToken.None);
+
+                        var cancelCloseCts = new CancellationTokenSource();
+                        await Assert.ThrowsAnyAsync<OperationCanceledException>(async () =>
+                        {
+                            Task t = cws.CloseAsync(WebSocketCloseStatus.NormalClosure, null, cancelCloseCts.Token);
+                            cancelCloseCts.Cancel();
+                            await t;
+                        });
+
+                        await Assert.ThrowsAnyAsync<OperationCanceledException>(() => receiveTask);
+                    }
+                }
+                finally
+                {
+                    tcs.SetResult(true);
+                }
+            }, server => server.AcceptConnectionAsync(async connection =>
+            {
+                Dictionary<string, string> headers = await LoopbackHelper.WebSocketHandshakeAsync(connection);
+                Assert.NotNull(headers);
+
+                await tcs.Task;
+
+            }), new LoopbackServer.Options { WebSocketEndpoint = true });
+        }
     }
 }