{
Debug.Assert(!Monitor.IsEntered(StateUpdateLock), $"{nameof(StateUpdateLock)} must never be held when acquiring {nameof(ReceiveAsyncLock)}");
Task receiveTask;
+ bool usingExistingReceive;
lock (ReceiveAsyncLock)
{
// Now that we're holding the ReceiveAsyncLock, double-check that we've not yet received the close frame.
// a race condition here, e.g. if there's a in-flight receive that completes after we check, but that's fine: worst
// case is we then await it, find that it's not what we need, and try again.
receiveTask = _lastReceiveAsync;
- _lastReceiveAsync = receiveTask = ValidateAndReceiveAsync(receiveTask, closeBuffer, cancellationToken);
+ Task newReceiveTask = ValidateAndReceiveAsync(receiveTask, closeBuffer, cancellationToken);
+ usingExistingReceive = ReferenceEquals(receiveTask, newReceiveTask);
+ _lastReceiveAsync = receiveTask = newReceiveTask;
}
// Wait for whatever receive task we have. We'll then loop around again to re-check our state.
+ // If this is an existing receive, and if we have a cancelable token, we need to register with that
+ // token while we wait, since it may not be the same one that was given to the receive initially.
Debug.Assert(receiveTask != null);
- await receiveTask.ConfigureAwait(false);
+ using (usingExistingReceive ? cancellationToken.Register(s => ((ManagedWebSocket)s!).Abort(), this) : default)
+ {
+ await receiveTask.ConfigureAwait(false);
+ }
}
}
finally
// See the LICENSE file in the project root for more information.
using System.Collections.Generic;
+using System.Diagnostics;
using System.Net.Test.Common;
using System.Text;
using System.Threading;
}
}
}
+
+ [ConditionalFact(nameof(WebSocketsSupported))]
+ [ActiveIssue("https://github.com/dotnet/runtime/issues/34690", TestPlatforms.Windows, TargetFrameworkMonikers.Netcoreapp, TestRuntimes.Mono)]
+ public async Task CloseAsync_CancelableEvenWhenPendingReceive_Throws()
+ {
+ var tcs = new TaskCompletionSource<bool>(TaskCreationOptions.RunContinuationsAsynchronously);
+
+ await LoopbackServer.CreateClientAndServerAsync(async uri =>
+ {
+ try
+ {
+ using (var cws = new ClientWebSocket())
+ using (var cts = new CancellationTokenSource(TimeOutMilliseconds))
+ {
+ await cws.ConnectAsync(uri, cts.Token);
+
+ Task receiveTask = cws.ReceiveAsync(new byte[1], CancellationToken.None);
+
+ var cancelCloseCts = new CancellationTokenSource();
+ await Assert.ThrowsAnyAsync<OperationCanceledException>(async () =>
+ {
+ Task t = cws.CloseAsync(WebSocketCloseStatus.NormalClosure, null, cancelCloseCts.Token);
+ cancelCloseCts.Cancel();
+ await t;
+ });
+
+ await Assert.ThrowsAnyAsync<OperationCanceledException>(() => receiveTask);
+ }
+ }
+ finally
+ {
+ tcs.SetResult(true);
+ }
+ }, server => server.AcceptConnectionAsync(async connection =>
+ {
+ Dictionary<string, string> headers = await LoopbackHelper.WebSocketHandshakeAsync(connection);
+ Assert.NotNull(headers);
+
+ await tcs.Task;
+
+ }), new LoopbackServer.Options { WebSocketEndpoint = true });
+ }
}
}