Refactor ManagedWebSocket to avoid forcing Task allocations for ReceiveAsync (#56282)
author Stephen Toub <stoub@microsoft.com>
Wed, 11 Aug 2021 02:14:13 +0000 (22:14 -0400)
committer GitHub <noreply@github.com>
Wed, 11 Aug 2021 02:14:13 +0000 (22:14 -0400)
* Refactor ManagedWebSocket to avoid forcing Task allocations for ReceiveAsync

The ManagedWebSocket implementation today supports CloseAsyncs being issued concurrently with ReceiveAsyncs, even though CloseAsync needs to issue receives (this allowance was carried over from the .NET Framework implementation).  Currently the implementation does that by storing the last ReceiveAsync task and awaiting it in CloseAsync if there is one, but that means multiple parties may try to await the same task multiple times (the original caller of ReceiveAsync and CloseAsync), which means we can't just use a ValueTask.  So today asynchronously completing ReceiveAsyncs always use AsTask to create a Task from the returned ValueTask.  This isn't actually an additional task allocation today, as the async ValueTask builder will create a Task for the asynchronously completing operation, and then AsTask will just return that (and when it completes synchronously, there's extra code to substitute a singleton).  But once we switch to using the new pooling builder, that's no longer the case.

This PR uses an async lock as part of the ReceiveAsync implementation, with the existing async method awaiting entering that lock.  CloseAsync is then rewritten to be in terms of calling ReceiveAsync in a loop.  This also lets us remove the existing Monitor used for synchronously coordinating state between these operations, as the async lock serves that purpose as well.  Rather than using a SemaphoreSlim, since we expect zero contention in the common case, we use a simple AsyncMutex that's optimized for the zero contention case, using a single interlocked to acquire and a single interlocked to release the lock.

* Fix misleading comment

src/libraries/System.Net.WebSockets/src/System.Net.WebSockets.csproj
src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/AsyncMutex.cs [new file with mode: 0644]
src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs

index 215cf6b..8a49207 100644 (file)
@@ -5,8 +5,10 @@
     <Nullable>enable</Nullable>
   </PropertyGroup>
   <ItemGroup>
+    <Compile Include="System\Net\WebSockets\AsyncMutex.cs" />
     <Compile Include="System\Net\WebSockets\Compression\WebSocketDeflater.cs" />
     <Compile Include="System\Net\WebSockets\Compression\WebSocketInflater.cs" />
+    <Compile Include="System\Net\WebSockets\ManagedWebSocket.cs" />
     <Compile Include="System\Net\WebSockets\ValueWebSocketReceiveResult.cs" />
     <Compile Include="System\Net\WebSockets\WebSocket.cs" />
     <Compile Include="System\Net\WebSockets\WebSocketCloseStatus.cs" />
@@ -19,7 +21,6 @@
     <Compile Include="System\Net\WebSockets\WebSocketMessageFlags.cs" />
     <Compile Include="System\Net\WebSockets\WebSocketReceiveResult.cs" />
     <Compile Include="System\Net\WebSockets\WebSocketState.cs" />
-    <Compile Include="System\Net\WebSockets\ManagedWebSocket.cs" />
     <Compile Include="$(CommonPath)System\Net\WebSockets\WebSocketValidate.cs"
              Link="Common\System\Net\WebSockets\WebSocketValidate.cs" />
     <Compile Include="$(CommonPath)System\IO\Compression\ZLibNative.cs"
diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/AsyncMutex.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/AsyncMutex.cs
new file mode 100644 (file)
index 0000000..4191466
--- /dev/null
@@ -0,0 +1,242 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using System.Diagnostics;
+using System.Threading.Tasks;
+
+namespace System.Threading
+{
+    /// <summary>Provides an async mutex.</summary>
+    /// <remarks>
+    /// This could be achieved with a <see cref="SemaphoreSlim"/> constructed with an initial
+    /// and max limit of 1.  However, this implementation is optimized to the needs of ManagedWebSocket,
+    /// which is that we expect zero contention in typical use cases.
+    /// </remarks>
+    internal sealed class AsyncMutex
+    {
+        /// <summary>Fast-path gate count tracking access to the mutex.</summary>
+        /// <remarks>
+        /// If the value is 1, the mutex can be entered atomically with an interlocked operation.
+        /// If the value is less than or equal to 0, the mutex is held and requires fallback to enter it.
+        /// </remarks>
+        private int _gate = 1;
+        /// <summary>Secondary check guarded by the lock to indicate whether the mutex is acquired.</summary>
+        /// <remarks>
+        /// This is only meaningful after having updated <see cref="_gate"/> via interlockeds and taken the appropriate path.
+        /// If after decrementing <see cref="_gate"/> we end up with a negative count, the mutex is contended, hence
+        /// <see cref="_lockedSemaphoreFull"/> starting as <c>true</c>.  The primary purpose of this field
+        /// is to handle the race condition between one thread acquiring the mutex, then another thread trying to acquire
+        /// and getting as far as completing the interlocked operation, and then the original thread releasing; at that point
+        /// it'll hit the lock and we need to store that the mutex is available to enter.  If we instead used a
+        /// SemaphoreSlim as the fallback from the interlockeds, this would have been its count, and it would have started
+        /// with an initial count of 0.
+        /// </remarks>
+        private bool _lockedSemaphoreFull = true;
+        /// <summary>The tail of the doubly-linked circular waiting queue.</summary>
+        /// <remarks>
+        /// Waiters are added at the tail.
+        /// Items are dequeued from the head (tail.Prev).
+        /// </remarks>
+        private Waiter? _waitersTail;
+
+        /// <summary>Gets whether the mutex is currently held by some operation (not necessarily the caller).</summary>
+        /// <remarks>This should be used only for asserts and debugging.</remarks>
+        public bool IsHeld => _gate != 1;
+
+        /// <summary>Gets the object used to synchronize contended operations.</summary>
+        private object SyncObj => this;
+
+        /// <summary>Asynchronously waits to enter the mutex.</summary>
+        /// <param name="cancellationToken">The <see cref="CancellationToken"/> to observe.</param>
+        /// <returns>A task that will complete when the mutex has been entered or the enter canceled.</returns>
+        public Task EnterAsync(CancellationToken cancellationToken)
+        {
+            // If cancellation was requested, bail immediately.
+            // If the mutex is not currently held nor contended, enter immediately.
+            // Otherwise, fall back to a more expensive likely-asynchronous wait.
+            return
+                cancellationToken.IsCancellationRequested ? Task.FromCanceled(cancellationToken) :
+                Interlocked.Decrement(ref _gate) >= 0 ? Task.CompletedTask :
+                Contended(cancellationToken);
+
+            // Everything that follows is the equivalent of:
+            //     return _sem.WaitAsync(cancellationToken);
+            // if _sem were to be constructed as `new SemaphoreSlim(0)`.
+
+            Task Contended(CancellationToken cancellationToken)
+            {
+                var w = new Waiter(this);
+
+                // We need to register for cancellation before storing the waiter into the list.
+                // If we registered after, we might leak a registration if the mutex was exited and the waiter
+                // removed from the list prior to CancellationRegistration being properly assigned. By registering before,
+                // there's a different race condition, that of cancellation being requested prior to storing the waiter into
+                // the list; if that happens, we could end up adding the waiter and have it still stored in the list even
+                // though OnCancellation was called. So once we hold the lock, which OnCancellation also needs to take, we
+                // check again whether cancellation has been requested, and avoid storing the waiter if it has.
+                w.CancellationRegistration = cancellationToken.UnsafeRegister((s, token) => OnCancellation(s, token), w);
+
+                lock (SyncObj)
+                {
+                    // Now that we're holding the lock, check to see whether the async lock is acquirable.
+                    if (!_lockedSemaphoreFull)
+                    {
+                        // If we are able to acquire the lock, we're done; we just need to clean up after the registration.
+                        w.CancellationRegistration.Unregister();
+                        _lockedSemaphoreFull = true;
+                        return Task.CompletedTask;
+                    }
+
+                    // Now that we're holding the lock and thus synchronized with OnCancellation, check to see
+                    // if cancellation has been requested.
+                    if (cancellationToken.IsCancellationRequested)
+                    {
+                        w.TrySetCanceled(cancellationToken);
+                        return w.Task;
+                    }
+
+                    // The lock couldn't be acquired.
+                    // Add the waiter to the linked list of waiters.
+                    if (_waitersTail is null)
+                    {
+                        w.Next = w.Prev = w;
+                    }
+                    else
+                    {
+                        Debug.Assert(_waitersTail.Next != null && _waitersTail.Prev != null);
+                        w.Next = _waitersTail;
+                        w.Prev = _waitersTail.Prev;
+                        w.Prev.Next = w.Next.Prev = w;
+                    }
+                    _waitersTail = w;
+                }
+
+                // Return the waiter's task; it completes when the mutex is handed to this waiter or the wait is canceled.
+                return w.Task;
+
+                // Cancels the specified waiter if it's still in the list.
+                static void OnCancellation(object? state, CancellationToken cancellationToken)
+                {
+                    Waiter? w = (Waiter)state!;
+                    AsyncMutex m = w.Owner;
+
+                    lock (m.SyncObj)
+                    {
+                        bool inList = w.Next != null;
+                        if (inList)
+                        {
+                            // The waiter is in the list.
+                            Debug.Assert(w.Prev != null);
+
+                            // The gate counter was decremented when this waiter was added.  We need
+                            // to undo that.  Since the waiter is still in the list, the lock must
+                            // still be held by someone, which means we don't need to do anything with
+                            // the result of this increment.  If it increments to < 1, then there are
+                            // still other waiters.  If it increments to 1, we're in a rare race condition
+                            // where there are no other waiters and the owner just incremented the gate
+                            // count; they would have seen it be < 1, so they will proceed to take the
+                            // contended code path and synchronize on the lock we're holding... once we
+                            // release it, they will appropriately update state.
+                            Interlocked.Increment(ref m._gate);
+
+                            if (w.Next == w)
+                            {
+                                Debug.Assert(m._waitersTail == w);
+                                m._waitersTail = null;
+                            }
+                            else
+                            {
+                                w.Next!.Prev = w.Prev;
+                                w.Prev.Next = w.Next;
+                                if (m._waitersTail == w)
+                                {
+                                    m._waitersTail = w.Next;
+                                }
+                            }
+
+                            // Remove it from the list.
+                            w.Next = w.Prev = null;
+                        }
+                        else
+                        {
+                            // The waiter was no longer in the list.  We must not cancel it.
+                            w = null;
+                        }
+                    }
+
+                    // If the waiter was in the list, we removed it under the lock and thus own
+                    // the ability to cancel it.  Do so.
+                    w?.TrySetCanceled(cancellationToken);
+                }
+            }
+        }
+
+        /// <summary>Releases the mutex.</summary>
+        /// <remarks>The caller must logically own the mutex.  This is not validated.</remarks>
+        public void Exit()
+        {
+            if (Interlocked.Increment(ref _gate) < 1)
+            {
+                // This is the equivalent of:
+                //     _sem.Release();
+                // if _sem were to be constructed as `new SemaphoreSlim(0)`.
+                Contended();
+            }
+
+            void Contended()
+            {
+                Waiter? w;
+                lock (SyncObj)
+                {
+                    Debug.Assert(_lockedSemaphoreFull);
+
+                    w = _waitersTail;
+                    if (w is null)
+                    {
+                        _lockedSemaphoreFull = false;
+                    }
+                    else
+                    {
+                        Debug.Assert(w.Next != null && w.Prev != null);
+                        Debug.Assert(w.Next != w || w.Prev == w);
+                        Debug.Assert(w.Prev != w || w.Next == w);
+
+                        if (w.Next == w)
+                        {
+                            _waitersTail = null;
+                        }
+                        else
+                        {
+                            w = w.Prev; // get the head
+                            Debug.Assert(w.Next != null && w.Prev != null);
+                            Debug.Assert(w.Next != w && w.Prev != w);
+
+                            w.Next.Prev = w.Prev;
+                            w.Prev.Next = w.Next;
+                        }
+
+                        w.Next = w.Prev = null;
+                    }
+                }
+
+                // Either there wasn't a waiter, or we got one and successfully removed it from the list,
+                // at which point we own the ability to complete it.  Do so.
+                if (w is not null)
+                {
+                    w.CancellationRegistration.Unregister();
+                    w.TrySetResult();
+                }
+            }
+        }
+
+        /// <summary>Represents a waiter for the mutex.</summary>
+        private sealed class Waiter : TaskCompletionSource
+        {
+            public Waiter(AsyncMutex owner) : base(TaskCreationOptions.RunContinuationsAsynchronously) => Owner = owner;
+            public AsyncMutex Owner { get; }
+            public CancellationTokenRegistration CancellationRegistration { get; set; }
+            public Waiter? Next { get; set; }
+            public Waiter? Prev { get; set; }
+        }
+    }
+}
index 297b4c5..05b171c 100644 (file)
@@ -38,9 +38,6 @@ namespace System.Net.WebSockets
         /// <summary>Valid states to be in when calling CloseAsync.</summary>
         private static readonly WebSocketState[] s_validCloseStates = { WebSocketState.Open, WebSocketState.CloseReceived, WebSocketState.CloseSent };
 
-        /// <summary>Successfully completed task representing a close message.</summary>
-        private static readonly Task<WebSocketReceiveResult> s_cachedCloseTask = Task.FromResult(new WebSocketReceiveResult(0, WebSocketMessageType.Close, true));
-
         /// <summary>The maximum size in bytes of a message frame header that includes mask bytes.</summary>
         internal const int MaxMessageHeaderLength = 14;
         /// <summary>The maximum size of a control message payload.</summary>
@@ -68,9 +65,15 @@ namespace System.Net.WebSockets
         /// </summary>
         private readonly Utf8MessageState _utf8TextState = new Utf8MessageState();
         /// <summary>
-        /// Semaphore used to ensure that calls to SendFrameAsync don't run concurrently.
+        /// Mutex used to ensure that calls to SendFrameAsync don't run concurrently.  We don't support multiple concurrent SendAsync calls,
+        /// but this is needed to support SendAsync concurrently with keep-alive pings and CloseAsync.
+        /// </summary>
+        private readonly AsyncMutex _sendMutex = new AsyncMutex();
+        /// <summary>
+        /// Mutex used to ensure that calls to ReceiveAsyncPrivate don't run concurrently.  We don't support multiple concurrent ReceiveAsync calls,
+        /// but this is needed to support ReceiveAsync concurrently with CloseAsync, which itself issues receives.
         /// </summary>
-        private readonly SemaphoreSlim _sendFrameAsyncLock = new SemaphoreSlim(1, 1);
+        private readonly AsyncMutex _receiveMutex = new AsyncMutex();
 
         // We maintain the current WebSocketState in _state.  However, we separately maintain _sentCloseFrame and _receivedCloseFrame
         // as there isn't a strict ordering between CloseSent and CloseReceived.  If we receive a close frame from the server, we need to
@@ -125,20 +128,9 @@ namespace System.Net.WebSockets
         /// Whether the last SendAsync had <seealso cref="WebSocketMessageFlags.DisableCompression" /> flag set.
         /// </summary>
         private bool _lastSendHadDisableCompression;
-        /// <summary>
-        /// The task returned from the last ReceiveAsync(ArraySegment, ...) operation to not complete synchronously.
-        /// If this is not null and not completed when a subsequent ReceiveAsync is issued, an exception occurs.
-        /// </summary>
-        private Task _lastReceiveAsync = Task.CompletedTask;
 
         /// <summary>Lock used to protect update and check-and-update operations on _state.</summary>
-        private object StateUpdateLock => _sendFrameAsyncLock;
-        /// <summary>
-        /// We need to coordinate between receives and close operations happening concurrently, as a ReceiveAsync may
-        /// be pending while a Close{Output}Async is issued, which itself needs to loop until a close frame is received.
-        /// As such, we need thread-safety in the management of <see cref="_lastReceiveAsync"/>.
-        /// </summary>
-        private object ReceiveAsyncLock => _utf8TextState; // some object, as we're simply lock'ing on it
+        private object StateUpdateLock => _sendMutex;
 
         private readonly WebSocketInflater? _inflater;
         private readonly WebSocketDeflater? _deflater;
@@ -151,13 +143,10 @@ namespace System.Net.WebSockets
         internal ManagedWebSocket(Stream stream, bool isServer, string? subprotocol, TimeSpan keepAliveInterval)
         {
             Debug.Assert(StateUpdateLock != null, $"Expected {nameof(StateUpdateLock)} to be non-null");
-            Debug.Assert(ReceiveAsyncLock != null, $"Expected {nameof(ReceiveAsyncLock)} to be non-null");
-            Debug.Assert(StateUpdateLock != ReceiveAsyncLock, "Locks should be different objects");
-
-            Debug.Assert(stream != null, $"Expected non-null stream");
-            Debug.Assert(stream.CanRead, $"Expected readable stream");
-            Debug.Assert(stream.CanWrite, $"Expected writeable stream");
-            Debug.Assert(keepAliveInterval == Timeout.InfiniteTimeSpan || keepAliveInterval >= TimeSpan.Zero, $"Invalid keepalive interval: {keepAliveInterval}");
+            Debug.Assert(stream != null, $"Expected non-null {nameof(stream)}");
+            Debug.Assert(stream.CanRead, $"Expected readable {nameof(stream)}");
+            Debug.Assert(stream.CanWrite, $"Expected writeable {nameof(stream)}");
+            Debug.Assert(keepAliveInterval == Timeout.InfiniteTimeSpan || keepAliveInterval >= TimeSpan.Zero, $"Invalid {nameof(keepAliveInterval)}: {keepAliveInterval}");
 
             _stream = stream;
             _isServer = isServer;
@@ -254,10 +243,13 @@ namespace System.Net.WebSockets
 
             WebSocketValidate.ValidateArraySegment(buffer, nameof(buffer));
 
-            return SendPrivateAsync(buffer, messageType, endOfMessage ? WebSocketMessageFlags.EndOfMessage : default, cancellationToken).AsTask();
+            return SendAsync(buffer, messageType, endOfMessage ? WebSocketMessageFlags.EndOfMessage : default, cancellationToken).AsTask();
         }
 
-        private ValueTask SendPrivateAsync(ReadOnlyMemory<byte> buffer, WebSocketMessageType messageType, WebSocketMessageFlags messageFlags, CancellationToken cancellationToken)
+        public override ValueTask SendAsync(ReadOnlyMemory<byte> buffer, WebSocketMessageType messageType, bool endOfMessage, CancellationToken cancellationToken) =>
+            SendAsync(buffer, messageType, endOfMessage ? WebSocketMessageFlags.EndOfMessage : default, cancellationToken);
+
+        public override ValueTask SendAsync(ReadOnlyMemory<byte> buffer, WebSocketMessageType messageType, WebSocketMessageFlags messageFlags, CancellationToken cancellationToken)
         {
             if (messageType != WebSocketMessageType.Text && messageType != WebSocketMessageType.Binary)
             {
@@ -308,14 +300,7 @@ namespace System.Net.WebSockets
             {
                 WebSocketValidate.ThrowIfInvalidState(_state, _disposed, s_validReceiveStates);
 
-                Debug.Assert(!Monitor.IsEntered(StateUpdateLock), $"{nameof(StateUpdateLock)} must never be held when acquiring {nameof(ReceiveAsyncLock)}");
-                lock (ReceiveAsyncLock) // synchronize with receives in CloseAsync
-                {
-                    ThrowIfOperationInProgress(_lastReceiveAsync.IsCompleted);
-                    Task<WebSocketReceiveResult> t = ReceiveAsyncPrivate<WebSocketReceiveResult>(buffer, cancellationToken).AsTask();
-                    _lastReceiveAsync = t;
-                    return t;
-                }
+                return ReceiveAsyncPrivate<WebSocketReceiveResult>(buffer, cancellationToken).AsTask();
             }
             catch (Exception exc)
             {
@@ -323,6 +308,20 @@ namespace System.Net.WebSockets
             }
         }
 
+        public override ValueTask<ValueWebSocketReceiveResult> ReceiveAsync(Memory<byte> buffer, CancellationToken cancellationToken)
+        {
+            try
+            {
+                WebSocketValidate.ThrowIfInvalidState(_state, _disposed, s_validReceiveStates);
+
+                return ReceiveAsyncPrivate<ValueWebSocketReceiveResult>(buffer, cancellationToken);
+            }
+            catch (Exception exc)
+            {
+                return ValueTask.FromException<ValueWebSocketReceiveResult>(exc);
+            }
+        }
+
         public override Task CloseAsync(WebSocketCloseStatus closeStatus, string? statusDescription, CancellationToken cancellationToken)
         {
             WebSocketValidate.ValidateCloseStatus(closeStatus, statusDescription);
@@ -382,65 +381,6 @@ namespace System.Net.WebSockets
             }
         }
 
-        public override ValueTask SendAsync(ReadOnlyMemory<byte> buffer, WebSocketMessageType messageType, bool endOfMessage, CancellationToken cancellationToken)
-        {
-            return SendPrivateAsync(buffer, messageType, endOfMessage ? WebSocketMessageFlags.EndOfMessage : default, cancellationToken);
-        }
-
-        public override ValueTask SendAsync(ReadOnlyMemory<byte> buffer, WebSocketMessageType messageType, WebSocketMessageFlags messageFlags, CancellationToken cancellationToken)
-        {
-            return SendPrivateAsync(buffer, messageType, messageFlags, cancellationToken);
-        }
-
-        public override ValueTask<ValueWebSocketReceiveResult> ReceiveAsync(Memory<byte> buffer, CancellationToken cancellationToken)
-        {
-            try
-            {
-                WebSocketValidate.ThrowIfInvalidState(_state, _disposed, s_validReceiveStates);
-
-                Debug.Assert(!Monitor.IsEntered(StateUpdateLock), $"{nameof(StateUpdateLock)} must never be held when acquiring {nameof(ReceiveAsyncLock)}");
-                lock (ReceiveAsyncLock) // synchronize with receives in CloseAsync
-                {
-                    ThrowIfOperationInProgress(_lastReceiveAsync.IsCompleted);
-
-                    ValueTask<ValueWebSocketReceiveResult> receiveValueTask = ReceiveAsyncPrivate<ValueWebSocketReceiveResult>(buffer, cancellationToken);
-                    if (receiveValueTask.IsCompletedSuccessfully)
-                    {
-                        _lastReceiveAsync = receiveValueTask.Result.MessageType == WebSocketMessageType.Close ? s_cachedCloseTask : Task.CompletedTask;
-                        return receiveValueTask;
-                    }
-
-                    // We need to both store the last receive task and return it, but we can't do that with a ValueTask,
-                    // as that could result in consuming it multiple times.  Instead, we use AsTask to consume it just once,
-                    // and then store that Task and return a new ValueTask that wraps it. (It would be nice in the future
-                    // to avoid this AsTask as well; currently it's used both for error detection and as part of close tracking.)
-                    Task<ValueWebSocketReceiveResult> receiveTask = receiveValueTask.AsTask();
-                    _lastReceiveAsync = receiveTask;
-                    return new ValueTask<ValueWebSocketReceiveResult>(receiveTask);
-                }
-            }
-            catch (Exception exc)
-            {
-                return ValueTask.FromException<ValueWebSocketReceiveResult>(exc);
-            }
-        }
-
-        private Task ValidateAndReceiveAsync(Task receiveTask, byte[] buffer, CancellationToken cancellationToken)
-        {
-            if (receiveTask == null ||
-                        (receiveTask.IsCompletedSuccessfully &&
-                         !(receiveTask is Task<WebSocketReceiveResult> wsrr && wsrr.Result.MessageType == WebSocketMessageType.Close) &&
-                         !(receiveTask is Task<ValueWebSocketReceiveResult> vwsrr && vwsrr.Result.MessageType == WebSocketMessageType.Close)))
-            {
-                ValueTask<ValueWebSocketReceiveResult> vt = ReceiveAsyncPrivate<ValueWebSocketReceiveResult>(buffer, cancellationToken);
-                receiveTask =
-                    vt.IsCompletedSuccessfully ? (vt.Result.MessageType == WebSocketMessageType.Close ? s_cachedCloseTask : Task.CompletedTask) :
-                    vt.AsTask();
-            }
-
-            return receiveTask;
-        }
-
         /// <summary>Sends a websocket frame to the network.</summary>
         /// <param name="opcode">The opcode for the message.</param>
         /// <param name="endOfMessage">The value of the FIN bit for the message.</param>
@@ -453,10 +393,9 @@ namespace System.Net.WebSockets
             // pass around (the CancellationTokenRegistration), so if it is cancelable, just immediately go to the fallback path.
             // Similarly, it should be rare that there are multiple outstanding calls to SendFrameAsync, but if there are, again
             // fall back to the fallback path.
-#pragma warning disable CA1416 // Validate platform compatibility, will not wait because timeout equals 0
-            return cancellationToken.CanBeCanceled || !_sendFrameAsyncLock.Wait(0, default) ?
-#pragma warning restore CA1416
-                SendFrameFallbackAsync(opcode, endOfMessage, disableCompression, payloadBuffer, cancellationToken) :
+            Task lockTask = _sendMutex.EnterAsync(cancellationToken);
+            return cancellationToken.CanBeCanceled || !lockTask.IsCompletedSuccessfully ?
+                SendFrameFallbackAsync(opcode, endOfMessage, disableCompression, payloadBuffer, lockTask, cancellationToken) :
                 SendFrameLockAcquiredNonCancelableAsync(opcode, endOfMessage, disableCompression, payloadBuffer);
         }
 
@@ -467,7 +406,7 @@ namespace System.Net.WebSockets
         /// <param name="payloadBuffer">The buffer containing the payload data for the message.</param>
         private ValueTask SendFrameLockAcquiredNonCancelableAsync(MessageOpcode opcode, bool endOfMessage, bool disableCompression, ReadOnlyMemory<byte> payloadBuffer)
         {
-            Debug.Assert(_sendFrameAsyncLock.CurrentCount == 0, "Caller should hold the _sendFrameAsyncLock");
+            Debug.Assert(_sendMutex.IsHeld, $"Caller should hold the {nameof(_sendMutex)}");
 
             // If we get here, the cancellation token is not cancelable so we don't have to worry about it,
             // and we own the semaphore, so we don't need to asynchronously wait for it.
@@ -504,7 +443,7 @@ namespace System.Net.WebSockets
                 if (releaseSendBufferAndSemaphore)
                 {
                     ReleaseSendBuffer();
-                    _sendFrameAsyncLock.Release();
+                    _sendMutex.Exit();
                 }
             }
 
@@ -526,13 +465,13 @@ namespace System.Net.WebSockets
             finally
             {
                 ReleaseSendBuffer();
-                _sendFrameAsyncLock.Release();
+                _sendMutex.Exit();
             }
         }
 
-        private async ValueTask SendFrameFallbackAsync(MessageOpcode opcode, bool endOfMessage, bool disableCompression, ReadOnlyMemory<byte> payloadBuffer, CancellationToken cancellationToken)
+        private async ValueTask SendFrameFallbackAsync(MessageOpcode opcode, bool endOfMessage, bool disableCompression, ReadOnlyMemory<byte> payloadBuffer, Task lockTask, CancellationToken cancellationToken)
         {
-            await _sendFrameAsyncLock.WaitAsync(cancellationToken).ConfigureAwait(false);
+            await lockTask.ConfigureAwait(false);
             try
             {
                 int sendBytes = WriteFrameToSendBuffer(opcode, endOfMessage, disableCompression, payloadBuffer.Span);
@@ -550,7 +489,7 @@ namespace System.Net.WebSockets
             finally
             {
                 ReleaseSendBuffer();
-                _sendFrameAsyncLock.Release();
+                _sendMutex.Exit();
             }
         }
 
@@ -606,32 +545,21 @@ namespace System.Net.WebSockets
 
         private void SendKeepAliveFrameAsync()
         {
-#pragma warning disable CA1416 // Validate platform compatibility, will not wait because timeout equals 0
-            bool acquiredLock = _sendFrameAsyncLock.Wait(0);
-#pragma warning restore CA1416
-            if (acquiredLock)
-            {
-                // This exists purely to keep the connection alive; don't wait for the result, and ignore any failures.
-                // The call will handle releasing the lock.  We send a pong rather than ping, since it's allowed by
-                // the RFC as a unidirectional heartbeat and we're not interested in waiting for a response.
-                ValueTask t = SendFrameLockAcquiredNonCancelableAsync(MessageOpcode.Pong, endOfMessage: true, disableCompression: true, ReadOnlyMemory<byte>.Empty);
-                if (t.IsCompletedSuccessfully)
-                {
-                    t.GetAwaiter().GetResult();
-                }
-                else
-                {
-                    // "Observe" any exception, ignoring it to prevent the unobserved exception event from being raised.
-                    t.AsTask().ContinueWith(static p => { _ = p.Exception; },
-                        CancellationToken.None,
-                        TaskContinuationOptions.OnlyOnFaulted | TaskContinuationOptions.ExecuteSynchronously,
-                        TaskScheduler.Default);
-                }
+            // This exists purely to keep the connection alive; don't wait for the result, and ignore any failures.
+            // The call will handle releasing the lock.  We send a pong rather than ping, since it's allowed by
+            // the RFC as a unidirectional heartbeat and we're not interested in waiting for a response.
+            ValueTask t = SendFrameAsync(MessageOpcode.Pong, endOfMessage: true, disableCompression: true, ReadOnlyMemory<byte>.Empty, CancellationToken.None);
+            if (t.IsCompletedSuccessfully)
+            {
+                t.GetAwaiter().GetResult();
             }
             else
             {
-                // If the lock is already held, something is already getting sent,
-                // so there's no need to send a keep-alive ping.
+                // "Observe" any exception, ignoring it to prevent the unobserved exception event from being raised.
+                t.AsTask().ContinueWith(static p => { _ = p.Exception; },
+                    CancellationToken.None,
+                    TaskContinuationOptions.OnlyOnFaulted | TaskContinuationOptions.ExecuteSynchronously,
+                    TaskScheduler.Default);
             }
         }
 
@@ -658,7 +586,7 @@ namespace System.Net.WebSockets
             // 0 or 4 bytes - Mask, if Masked is 1 - random value XOR'd with each 4 bytes of the payload, round-robin
             // Length bytes - Payload data
 
-            Debug.Assert(sendBuffer.Length >= MaxMessageHeaderLength, $"Expected sendBuffer to be at least {MaxMessageHeaderLength}, got {sendBuffer.Length}");
+            Debug.Assert(sendBuffer.Length >= MaxMessageHeaderLength, $"Expected {nameof(sendBuffer)} to be at least {MaxMessageHeaderLength}, got {sendBuffer.Length}");
 
             sendBuffer[0] = (byte)opcode; // 4 bits for the opcode
             if (endOfMessage)
@@ -722,6 +650,7 @@ namespace System.Net.WebSockets
         /// <param name="payloadBuffer">The buffer into which payload data should be written.</param>
         /// <param name="cancellationToken">The CancellationToken used to cancel the websocket.</param>
         /// <returns>Information about the received message.</returns>
+        [AsyncMethodBuilder(typeof(PoolingAsyncValueTaskMethodBuilder<>))]
         private async ValueTask<TResult> ReceiveAsyncPrivate<TResult>(Memory<byte> payloadBuffer, CancellationToken cancellationToken)
         {
             // This is a long method.  While splitting it up into pieces would arguably help with readability, doing so would
@@ -735,171 +664,179 @@ namespace System.Net.WebSockets
             CancellationTokenRegistration registration = cancellationToken.Register(static s => ((ManagedWebSocket)s!).Abort(), this);
             try
             {
-                while (true) // in case we get control frames that should be ignored from the user's perspective
+                await _receiveMutex.EnterAsync(cancellationToken).ConfigureAwait(false);
+                try
                 {
-                    // Get the last received header.  If its payload length is non-zero, that means we previously
-                    // received the header but were only able to read a part of the fragment, so we should skip
-                    // reading another header and just proceed to use that same header and read more data associated
-                    // with it.  If instead its payload length is zero, then we've completed the processing of
-                    // thta message, and we should read the next header.
-                    MessageHeader header = _lastReceiveHeader;
-                    if (header.Processed)
+                    while (true) // in case we get control frames that should be ignored from the user's perspective
                     {
-                        if (_receiveBufferCount < (_isServer ? MaxMessageHeaderLength : (MaxMessageHeaderLength - MaskLength)))
+                        // Get the last received header.  If its payload length is non-zero, that means we previously
+                        // received the header but were only able to read a part of the fragment, so we should skip
+                        // reading another header and just proceed to use that same header and read more data associated
+                        // with it.  If instead its payload length is zero, then we've completed the processing of
+                        // that message, and we should read the next header.
+                        MessageHeader header = _lastReceiveHeader;
+                        if (header.Processed)
                         {
-                            // Make sure we have the first two bytes, which includes the start of the payload length.
-                            if (_receiveBufferCount < 2)
+                            if (_receiveBufferCount < (_isServer ? MaxMessageHeaderLength : (MaxMessageHeaderLength - MaskLength)))
                             {
-                                await EnsureBufferContainsAsync(2, cancellationToken, throwOnPrematureClosure: true).ConfigureAwait(false);
+                                // Make sure we have the first two bytes, which includes the start of the payload length.
+                                if (_receiveBufferCount < 2)
+                                {
+                                    await EnsureBufferContainsAsync(2, cancellationToken, throwOnPrematureClosure: true).ConfigureAwait(false);
+                                }
+
+                                // Then make sure we have the full header based on the payload length.
+                                // If this is the server, we also need room for the received mask.
+                                long payloadLength = _receiveBuffer.Span[_receiveBufferOffset + 1] & 0x7F;
+                                if (_isServer || payloadLength > 125)
+                                {
+                                    int minNeeded =
+                                        2 +
+                                        (_isServer ? MaskLength : 0) +
+                                        (payloadLength <= 125 ? 0 : payloadLength == 126 ? sizeof(ushort) : sizeof(ulong)); // additional 2 or 8 bytes for 16-bit or 64-bit length
+                                    await EnsureBufferContainsAsync(minNeeded, cancellationToken).ConfigureAwait(false);
+                                }
                             }
 
-                            // Then make sure we have the full header based on the payload length.
-                            // If this is the server, we also need room for the received mask.
-                            long payloadLength = _receiveBuffer.Span[_receiveBufferOffset + 1] & 0x7F;
-                            if (_isServer || payloadLength > 125)
+                            string? headerErrorMessage = TryParseMessageHeaderFromReceiveBuffer(out header);
+                            if (headerErrorMessage != null)
                             {
-                                int minNeeded =
-                                    2 +
-                                    (_isServer ? MaskLength : 0) +
-                                    (payloadLength <= 125 ? 0 : payloadLength == 126 ? sizeof(ushort) : sizeof(ulong)); // additional 2 or 8 bytes for 16-bit or 64-bit length
-                                await EnsureBufferContainsAsync(minNeeded, cancellationToken).ConfigureAwait(false);
+                                await CloseWithReceiveErrorAndThrowAsync(WebSocketCloseStatus.ProtocolError, WebSocketError.Faulted, headerErrorMessage).ConfigureAwait(false);
+                            }
+                            _receivedMaskOffsetOffset = 0;
+
+                            if (header.PayloadLength == 0 && header.Compressed)
+                            {
+                                // In the rare case where we receive a compressed message with no payload
+                                // we need to tell the inflater about it, because the receive code below will
+                                // not try to do anything when PayloadLength == 0.
+                                _inflater!.AddBytes(0, endOfMessage: header.Fin);
                             }
                         }
 
-                        string? headerErrorMessage = TryParseMessageHeaderFromReceiveBuffer(out header);
-                        if (headerErrorMessage != null)
+                        // If the header represents a ping or a pong, it's a control message meant
+                        // to be transparent to the user, so handle it and then loop around to read again.
+                        // Alternatively, if it's a close message, handle it and exit.
+                        if (header.Opcode == MessageOpcode.Ping || header.Opcode == MessageOpcode.Pong)
+                        {
+                            await HandleReceivedPingPongAsync(header, cancellationToken).ConfigureAwait(false);
+                            continue;
+                        }
+                        else if (header.Opcode == MessageOpcode.Close)
                         {
-                            await CloseWithReceiveErrorAndThrowAsync(WebSocketCloseStatus.ProtocolError, WebSocketError.Faulted, headerErrorMessage).ConfigureAwait(false);
+                            await HandleReceivedCloseAsync(header, cancellationToken).ConfigureAwait(false);
+                            return GetReceiveResult<TResult>(0, WebSocketMessageType.Close, true);
                         }
-                        _receivedMaskOffsetOffset = 0;
 
-                        if (header.PayloadLength == 0 && header.Compressed)
+                        // If this is a continuation, replace the opcode with the one of the message it's continuing
+                        if (header.Opcode == MessageOpcode.Continuation)
                         {
-                            // In the rare case where we receive a compressed message with no payload
-                            // we need to tell the inflater about it, because the receive code bellow will
-                            // not try to do anything when PayloadLength == 0.
-                            _inflater!.AddBytes(0, endOfMessage: header.Fin);
+                            header.Opcode = _lastReceiveHeader.Opcode;
+                            header.Compressed = _lastReceiveHeader.Compressed;
                         }
-                    }
 
-                    // If the header represents a ping or a pong, it's a control message meant
-                    // to be transparent to the user, so handle it and then loop around to read again.
-                    // Alternatively, if it's a close message, handle it and exit.
-                    if (header.Opcode == MessageOpcode.Ping || header.Opcode == MessageOpcode.Pong)
-                    {
-                        await HandleReceivedPingPongAsync(header, cancellationToken).ConfigureAwait(false);
-                        continue;
-                    }
-                    else if (header.Opcode == MessageOpcode.Close)
-                    {
-                        await HandleReceivedCloseAsync(header, cancellationToken).ConfigureAwait(false);
-                        return GetReceiveResult<TResult>(0, WebSocketMessageType.Close, true);
-                    }
+                        // The message should now be a binary or text message.  Handle it by reading the payload and returning the contents.
+                        Debug.Assert(header.Opcode == MessageOpcode.Binary || header.Opcode == MessageOpcode.Text, $"Unexpected opcode {header.Opcode}");
 
-                    // If this is a continuation, replace the opcode with the one of the message it's continuing
-                    if (header.Opcode == MessageOpcode.Continuation)
-                    {
-                        header.Opcode = _lastReceiveHeader.Opcode;
-                        header.Compressed = _lastReceiveHeader.Compressed;
-                    }
+                        // If there's no data to read, return an appropriate result.
+                        if (header.Processed || payloadBuffer.Length == 0)
+                        {
+                            _lastReceiveHeader = header;
+                            return GetReceiveResult<TResult>(
+                                count: 0,
+                                messageType: header.Opcode == MessageOpcode.Text ? WebSocketMessageType.Text : WebSocketMessageType.Binary,
+                                endOfMessage: header.EndOfMessage);
+                        }
 
-                    // The message should now be a binary or text message.  Handle it by reading the payload and returning the contents.
-                    Debug.Assert(header.Opcode == MessageOpcode.Binary || header.Opcode == MessageOpcode.Text, $"Unexpected opcode {header.Opcode}");
+                        // Otherwise, read as much of the payload as we can efficiently, and update the header to reflect how much data
+                        // remains for future reads.  We first need to copy any data that may be lingering in the receive buffer
+                        // into the destination; then to minimize ReceiveAsync calls, we want to read as much as we can, stopping
+                        // only when we've either read the whole message or when we've filled the payload buffer.
 
-                    // If there's no data to read, return an appropriate result.
-                    if (header.Processed || payloadBuffer.Length == 0)
-                    {
-                        _lastReceiveHeader = header;
-                        return GetReceiveResult<TResult>(
-                            count: 0,
-                            messageType: header.Opcode == MessageOpcode.Text ? WebSocketMessageType.Text : WebSocketMessageType.Binary,
-                            endOfMessage: header.EndOfMessage);
-                    }
+                        // First copy any data lingering in the receive buffer.
+                        int totalBytesReceived = 0;
+
+                        // Only start a new receive if we haven't received the entire frame.
+                        if (header.PayloadLength > 0)
+                        {
+                            if (header.Compressed)
+                            {
+                                Debug.Assert(_inflater is not null);
+                                _inflater.Prepare(header.PayloadLength, payloadBuffer.Length);
+                            }
 
-                    // Otherwise, read as much of the payload as we can efficiently, and update the header to reflect how much data
-                    // remains for future reads.  We first need to copy any data that may be lingering in the receive buffer
-                    // into the destination; then to minimize ReceiveAsync calls, we want to read as much as we can, stopping
-                    // only when we've either read the whole message or when we've filled the payload buffer.
+                            // Read directly into the appropriate buffer until we've hit a limit.
+                            int limit = (int)Math.Min(header.Compressed ? _inflater!.Span.Length : payloadBuffer.Length, header.PayloadLength);
 
-                    // First copy any data lingering in the receive buffer.
-                    int totalBytesReceived = 0;
+                            if (_receiveBufferCount > 0)
+                            {
+                                int receiveBufferBytesToCopy = Math.Min(limit, _receiveBufferCount);
+                                Debug.Assert(receiveBufferBytesToCopy > 0);
 
-                    // Only start a new receive if we haven't received the entire frame.
-                    if (header.PayloadLength > 0)
-                    {
-                        if (header.Compressed)
-                        {
-                            Debug.Assert(_inflater is not null);
-                            _inflater.Prepare(header.PayloadLength, payloadBuffer.Length);
-                        }
+                                _receiveBuffer.Span.Slice(_receiveBufferOffset, receiveBufferBytesToCopy).CopyTo(
+                                    header.Compressed ? _inflater!.Span : payloadBuffer.Span);
+                                ConsumeFromBuffer(receiveBufferBytesToCopy);
+                                totalBytesReceived += receiveBufferBytesToCopy;
+                            }
 
-                        // Read directly into the appropriate buffer until we've hit a limit.
-                        int limit = (int)Math.Min(header.Compressed ? _inflater!.Span.Length : payloadBuffer.Length, header.PayloadLength);
+                            while (totalBytesReceived < limit)
+                            {
+                                int numBytesRead = await _stream.ReadAsync(header.Compressed ?
+                                        _inflater!.Memory.Slice(totalBytesReceived, limit - totalBytesReceived) :
+                                        payloadBuffer.Slice(totalBytesReceived, limit - totalBytesReceived),
+                                    cancellationToken).ConfigureAwait(false);
+                                if (numBytesRead <= 0)
+                                {
+                                    ThrowIfEOFUnexpected(throwOnPrematureClosure: true);
+                                    break;
+                                }
+                                totalBytesReceived += numBytesRead;
+                            }
 
-                        if (_receiveBufferCount > 0)
-                        {
-                            int receiveBufferBytesToCopy = Math.Min(limit, _receiveBufferCount);
-                            Debug.Assert(receiveBufferBytesToCopy > 0);
+                            if (_isServer)
+                            {
+                                _receivedMaskOffsetOffset = ApplyMask(header.Compressed ?
+                                    _inflater!.Span.Slice(0, totalBytesReceived) :
+                                    payloadBuffer.Span.Slice(0, totalBytesReceived), header.Mask, _receivedMaskOffsetOffset);
+                            }
 
-                            _receiveBuffer.Span.Slice(_receiveBufferOffset, receiveBufferBytesToCopy).CopyTo(
-                                header.Compressed ? _inflater!.Span : payloadBuffer.Span);
-                            ConsumeFromBuffer(receiveBufferBytesToCopy);
-                            totalBytesReceived += receiveBufferBytesToCopy;
-                        }
+                            header.PayloadLength -= totalBytesReceived;
 
-                        while (totalBytesReceived < limit)
-                        {
-                            int numBytesRead = await _stream.ReadAsync(header.Compressed ?
-                                    _inflater!.Memory.Slice(totalBytesReceived, limit - totalBytesReceived) :
-                                    payloadBuffer.Slice(totalBytesReceived, limit - totalBytesReceived),
-                                cancellationToken).ConfigureAwait(false);
-                            if (numBytesRead <= 0)
+                            if (header.Compressed)
                             {
-                                ThrowIfEOFUnexpected(throwOnPrematureClosure: true);
-                                break;
+                                _inflater!.AddBytes(totalBytesReceived, endOfMessage: header.Fin && header.PayloadLength == 0);
                             }
-                            totalBytesReceived += numBytesRead;
                         }
 
-                        if (_isServer)
+                        if (header.Compressed)
                         {
-                            _receivedMaskOffsetOffset = ApplyMask(header.Compressed ?
-                                _inflater!.Span.Slice(0, totalBytesReceived) :
-                                payloadBuffer.Span.Slice(0, totalBytesReceived), header.Mask, _receivedMaskOffsetOffset);
+                            // In case of compression totalBytesReceived should actually represent how much we've
+                            // inflated, rather than how much we've read from the stream.
+                            header.Processed = _inflater!.Inflate(payloadBuffer.Span, out totalBytesReceived) && header.PayloadLength == 0;
                         }
-
-                        header.PayloadLength -= totalBytesReceived;
-
-                        if (header.Compressed)
+                        else
                         {
-                            _inflater!.AddBytes(totalBytesReceived, endOfMessage: header.Fin && header.PayloadLength == 0);
+                            // Without compression the frame is processed as soon as we've received everything
+                            header.Processed = header.PayloadLength == 0;
                         }
-                    }
 
-                    if (header.Compressed)
-                    {
-                        // In case of compression totalBytesReceived should actually represent how much we've
-                        // inflated, rather than how much we've read from the stream.
-                        header.Processed = _inflater!.Inflate(payloadBuffer.Span, out totalBytesReceived) && header.PayloadLength == 0;
-                    }
-                    else
-                    {
-                        // Without compression the frame is processed as soon as we've received everything
-                        header.Processed = header.PayloadLength == 0;
-                    }
+                        // If this is a text message, validate that it contains valid UTF8.
+                        if (header.Opcode == MessageOpcode.Text &&
+                            !TryValidateUtf8(payloadBuffer.Span.Slice(0, totalBytesReceived), header.EndOfMessage, _utf8TextState))
+                        {
+                            await CloseWithReceiveErrorAndThrowAsync(WebSocketCloseStatus.InvalidPayloadData, WebSocketError.Faulted).ConfigureAwait(false);
+                        }
 
-                    // If this a text message, validate that it contains valid UTF8.
-                    if (header.Opcode == MessageOpcode.Text &&
-                        !TryValidateUtf8(payloadBuffer.Span.Slice(0, totalBytesReceived), header.EndOfMessage, _utf8TextState))
-                    {
-                        await CloseWithReceiveErrorAndThrowAsync(WebSocketCloseStatus.InvalidPayloadData, WebSocketError.Faulted).ConfigureAwait(false);
+                        _lastReceiveHeader = header;
+                        return GetReceiveResult<TResult>(
+                            totalBytesReceived,
+                            header.Opcode == MessageOpcode.Text ? WebSocketMessageType.Text : WebSocketMessageType.Binary,
+                            header.EndOfMessage);
                     }
-
-                    _lastReceiveHeader = header;
-                    return GetReceiveResult<TResult>(
-                        totalBytesReceived,
-                        header.Opcode == MessageOpcode.Text ? WebSocketMessageType.Text : WebSocketMessageType.Binary,
-                        header.EndOfMessage);
+                }
+                finally
+                {
+                    _receiveMutex.Exit();
                 }
             }
             catch (Exception exc) when (exc is not OperationCanceledException)
@@ -1018,20 +955,23 @@ namespace System.Net.WebSockets
             // additional data, but at this point we're about to close the connection and we're just stalling
             // to try to get the server to close first.
             ValueTask<int> finalReadTask = _stream.ReadAsync(_receiveBuffer, cancellationToken);
-            if (!finalReadTask.IsCompletedSuccessfully)
+            if (finalReadTask.IsCompletedSuccessfully)
             {
-                const int WaitForCloseTimeoutMs = 1_000; // arbitrary amount of time to give the server (same as netfx)
-                using (var finalCts = new CancellationTokenSource(WaitForCloseTimeoutMs))
-                using (finalCts.Token.Register(static s => ((ManagedWebSocket)s!).Abort(), this))
+                finalReadTask.GetAwaiter().GetResult();
+            }
+            else
+            {
+                const int WaitForCloseTimeoutMs = 1_000; // arbitrary amount of time to give the server (same duration as .NET Framework)
+                try
                 {
-                    try
-                    {
-                        await finalReadTask.ConfigureAwait(false);
-                    }
-                    catch
-                    {
-                        // Eat any resulting exceptions.  We were going to close the connection, anyway.
-                    }
+#pragma warning disable CA2016 // Token was already provided to the ReadAsync
+                    await finalReadTask.AsTask().WaitAsync(TimeSpan.FromMilliseconds(WaitForCloseTimeoutMs)).ConfigureAwait(false);
+#pragma warning restore CA2016
+                }
+                catch
+                {
+                    Abort();
+                    // Eat any resulting exceptions.  We were going to close the connection, anyway.
                 }
             }
         }
@@ -1137,7 +1077,7 @@ namespace System.Net.WebSockets
         /// <returns>null if a valid header was read; non-null containing the string error message to use if the header was invalid.</returns>
         private string? TryParseMessageHeaderFromReceiveBuffer(out MessageHeader resultHeader)
         {
-            Debug.Assert(_receiveBufferCount >= 2, $"Expected to at least have the first two bytes of the header.");
+            Debug.Assert(_receiveBufferCount >= 2, "Expected to at least have the first two bytes of the header.");
 
             MessageHeader header = default;
             Span<byte> receiveBufferSpan = _receiveBuffer.Span;
@@ -1155,13 +1095,13 @@ namespace System.Net.WebSockets
             // Read the remainder of the payload length, if necessary
             if (header.PayloadLength == 126)
             {
-                Debug.Assert(_receiveBufferCount >= 2, $"Expected to have two bytes for the payload length.");
+                Debug.Assert(_receiveBufferCount >= 2, "Expected to have two bytes for the payload length.");
                 header.PayloadLength = (receiveBufferSpan[_receiveBufferOffset] << 8) | receiveBufferSpan[_receiveBufferOffset + 1];
                 ConsumeFromBuffer(2);
             }
             else if (header.PayloadLength == 127)
             {
-                Debug.Assert(_receiveBufferCount >= 8, $"Expected to have eight bytes for the payload length.");
+                Debug.Assert(_receiveBufferCount >= 8, "Expected to have eight bytes for the payload length.");
                 header.PayloadLength = 0;
                 for (int i = 0; i < 8; i++)
                 {
@@ -1245,8 +1185,8 @@ namespace System.Net.WebSockets
             }
 
             // Return the read header
+            header.Processed = header.PayloadLength == 0 && !header.Compressed;
             resultHeader = header;
-            resultHeader.Processed = header.PayloadLength == 0 && !header.Compressed;
             return null;
         }
 
@@ -1281,40 +1221,40 @@ namespace System.Net.WebSockets
                 byte[] closeBuffer = ArrayPool<byte>.Shared.Rent(MaxMessageHeaderLength + MaxControlPayloadLength);
                 try
                 {
+                    // Loop until we've received a close frame.
                     while (!_receivedCloseFrame)
                     {
-                        Debug.Assert(!Monitor.IsEntered(StateUpdateLock), $"{nameof(StateUpdateLock)} must never be held when acquiring {nameof(ReceiveAsyncLock)}");
-                        Task receiveTask;
-                        bool usingExistingReceive;
-                        lock (ReceiveAsyncLock)
+                        // Enter the receive lock in order to get a consistent view of whether we've received a close
+                        // frame.  If we haven't, issue a receive.  Since that receive will try to take the same
+                        // non-reentrant receive lock, we then exit the lock before waiting for the receive to complete,
+                        // as it will always complete asynchronously and only after we've exited the lock.
+                        ValueTask<ValueWebSocketReceiveResult> receiveTask = default;
+                        try
                         {
-                            // Now that we're holding the ReceiveAsyncLock, double-check that we've not yet received the close frame.
-                            // It could have been received between our check above and now due to a concurrent receive completing.
-                            if (_receivedCloseFrame)
+                            await _receiveMutex.EnterAsync(cancellationToken).ConfigureAwait(false);
+                            try
                             {
-                                break;
-                            }
+                                if (!_receivedCloseFrame)
+                                {
+                                    receiveTask = ReceiveAsyncPrivate<ValueWebSocketReceiveResult>(closeBuffer, cancellationToken);
+                                }
 
-                            // We've not yet processed a received close frame, which means we need to wait for a received close to complete.
-                            // There may already be one in flight, in which case we want to just wait for that one rather than kicking off
-                            // another (we don't support concurrent receive operations).  We need to kick off a new receive if either we've
-                            // never issued a receive or if the last issued receive completed for reasons other than a close frame.  There is
-                            // a race condition here, e.g. if there's a in-flight receive that completes after we check, but that's fine: worst
-                            // case is we then await it, find that it's not what we need, and try again.
-                            receiveTask = _lastReceiveAsync;
-                            Task newReceiveTask = ValidateAndReceiveAsync(receiveTask, closeBuffer, cancellationToken);
-                            usingExistingReceive = ReferenceEquals(receiveTask, newReceiveTask);
-                            _lastReceiveAsync = receiveTask = newReceiveTask;
+                            }
+                            finally
+                            {
+                                _receiveMutex.Exit();
+                            }
                         }
-
-                        // Wait for whatever receive task we have.  We'll then loop around again to re-check our state.
-                        // If this is an existing receive, and if we have a cancelable token, we need to register with that
-                        // token while we wait, since it may not be the same one that was given to the receive initially.
-                        Debug.Assert(receiveTask != null);
-                        using (usingExistingReceive ? cancellationToken.Register(static s => ((ManagedWebSocket)s!).Abort(), this) : default)
+                        catch (OperationCanceledException)
                         {
-                            await receiveTask.ConfigureAwait(false);
+                            // If waiting on the receive lock was canceled, abort the connection, as we would do
+                            // as part of the receive itself.
+                            Abort();
+                            throw;
                         }
+
+                        // Wait for the receive to complete if we issued one.
+                        await receiveTask.ConfigureAwait(false);
                     }
                 }
                 finally
@@ -1351,7 +1291,7 @@ namespace System.Net.WebSockets
                     count += s_textEncoding.GetByteCount(closeStatusDescription);
                     buffer = ArrayPool<byte>.Shared.Rent(count);
                     int encodedLength = s_textEncoding.GetBytes(closeStatusDescription, 0, closeStatusDescription.Length, buffer, 2);
-                    Debug.Assert(count - 2 == encodedLength, $"GetByteCount and GetBytes encoded count didn't match");
+                    Debug.Assert(count - 2 == encodedLength, $"{nameof(s_textEncoding.GetByteCount)} and {nameof(s_textEncoding.GetBytes)} encoded count didn't match");
                 }
 
                 ushort closeStatusValue = (ushort)closeStatus;
@@ -1389,12 +1329,13 @@ namespace System.Net.WebSockets
 
         private void ConsumeFromBuffer(int count)
         {
-            Debug.Assert(count >= 0, $"Expected non-negative count, got {count}");
+            Debug.Assert(count >= 0, $"Expected non-negative {nameof(count)}, got {count}");
             Debug.Assert(count <= _receiveBufferCount, $"Trying to consume {count}, which is more than exists {_receiveBufferCount}");
             _receiveBufferCount -= count;
             _receiveBufferOffset += count;
         }
 
+        [AsyncMethodBuilder(typeof(PoolingAsyncValueTaskMethodBuilder))]
         private async ValueTask EnsureBufferContainsAsync(int minimumRequiredBytes, CancellationToken cancellationToken, bool throwOnPrematureClosure = true)
         {
             Debug.Assert(minimumRequiredBytes <= _receiveBuffer.Length, $"Requested number of bytes {minimumRequiredBytes} must not exceed {_receiveBuffer.Length}");
@@ -1450,7 +1391,7 @@ namespace System.Net.WebSockets
         /// <summary>Releases the send buffer to the pool.</summary>
         private void ReleaseSendBuffer()
         {
-            Debug.Assert(_sendFrameAsyncLock.CurrentCount == 0, "Caller should hold the _sendFrameAsyncLock");
+            Debug.Assert(_sendMutex.IsHeld, $"Caller should hold the {nameof(_sendMutex)}");
 
             if (_sendBuffer is byte[] toReturn)
             {