--- /dev/null
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using System.Diagnostics;
+using System.Threading.Tasks;
+
+namespace System.Threading
+{
+ /// <summary>Provides an async mutex.</summary>
+ /// <remarks>
+ /// This could be achieved with a <see cref="SemaphoreSlim"/> constructed with an initial
+ /// and max limit of 1. However, this implementation is optimized to the needs of ManagedWebSocket,
+ /// which is that we expect zero contention in typical use cases.
+ /// </remarks>
+ internal sealed class AsyncMutex
+ {
+ /// <summary>Fast-path gate count tracking access to the mutex.</summary>
+ /// <remarks>
+ /// If the value is 1, the mutex can be entered atomically with an interlocked operation.
+ /// If the value is less than or equal to 0, the mutex is held and requires fallback to enter it.
+ /// </remarks>
+ private int _gate = 1;
+ /// <summary>Secondary check guarded by the lock to indicate whether the mutex is acquired.</summary>
+ /// <remarks>
+ /// This is only meaningful after having updated <see cref="_gate"/> via interlockeds and taken the appropriate path.
+ /// If after decrementing <see cref="_gate"/> we end up with a negative count, the mutex is contended, hence
+ /// <see cref="_lockedSemaphoreFull"/> starting as <c>true</c>. The primary purpose of this field
+ /// is to handle the race condition between one thread acquiring the mutex, then another thread trying to acquire
+ /// and getting as far as completing the interlocked operation, and then the original thread releasing; at that point
+ /// it'll hit the lock and we need to store that the mutex is available to enter. If we instead used a
+ /// SemaphoreSlim as the fallback from the interlockeds, this would have been its count, and it would have started
+ /// with an initial count of 0.
+ /// </remarks>
+ private bool _lockedSemaphoreFull = true;
+ /// <summary>The tail of the double-linked circular waiting queue.</summary>
+ /// <remarks>
+ /// Waiters are added at the tail.
+ /// Items are dequeued from the head (tail.Prev).
+ /// </remarks>
+ private Waiter? _waitersTail;
+
+ /// <summary>Gets whether the mutex is currently held by some operation (not necessarily the caller).</summary>
+ /// <remarks>This should be used only for asserts and debugging.</remarks>
+ public bool IsHeld => _gate != 1;
+
+ /// <summary>Gets the object used to synchronize contended operations.</summary>
+ private object SyncObj => this;
+
+ /// <summary>Asynchronously waits to enter the mutex.</summary>
+ /// <param name="cancellationToken">The CancellationToken token to observe.</param>
+ /// <returns>A task that will complete when the mutex has been entered or the enter canceled.</returns>
+ public Task EnterAsync(CancellationToken cancellationToken)
+ {
+ // If cancellation was requested, bail immediately.
+ // If the mutex is not currently held nor contended, enter immediately.
+ // Otherwise, fall back to a more expensive likely-asynchronous wait.
+ return
+ cancellationToken.IsCancellationRequested ? Task.FromCanceled(cancellationToken) :
+ Interlocked.Decrement(ref _gate) >= 0 ? Task.CompletedTask :
+ Contended(cancellationToken);
+
+ // Everything that follows is the equivalent of:
+ // return _sem.WaitAsync(cancellationToken);
+ // if _sem were to be constructed as `new SemaphoreSlim(0)`.
+
+ Task Contended(CancellationToken cancellationToken)
+ {
+ var w = new Waiter(this);
+
+ // We need to register for cancellation before storing the waiter into the list.
+ // If we registered after, we might leak a registration if the mutex was exited and the waiter
+ // removed from the list prior to CancellationRegistration being properly assigned. By registering before,
+ // there's a different race condition, that of cancellation being requested prior to storing the waiter into
+ // the list; if that happens, we could end up adding the waiter and have it still stored in the list even
+ // though OnCancellation was called. So once we hold the lock, which OnCancellation also needs to take, we
+ // check again whether cancellation has been requested,and avoid storing the waiter if it has.
+ w.CancellationRegistration = cancellationToken.UnsafeRegister((s, token) => OnCancellation(s, token), w);
+
+ lock (SyncObj)
+ {
+ // Now that we're holding the lock, check to see whether the async lock is acquirable.
+ if (!_lockedSemaphoreFull)
+ {
+ // If we are able to acquire the lock, we're done; we just need to clean up after the registration.
+ w.CancellationRegistration.Unregister();
+ _lockedSemaphoreFull = true;
+ return Task.CompletedTask;
+ }
+
+ // Now that we're holding the lock and thus synchronized with OnCancellation, check to see
+ // if cancellation has been requested.
+ if (cancellationToken.IsCancellationRequested)
+ {
+ w.TrySetCanceled(cancellationToken);
+ return w.Task;
+ }
+
+ // The lock couldn't be acquired.
+ // Add the waiter to the linked list of waiters.
+ if (_waitersTail is null)
+ {
+ w.Next = w.Prev = w;
+ }
+ else
+ {
+ Debug.Assert(_waitersTail.Next != null && _waitersTail.Prev != null);
+ w.Next = _waitersTail;
+ w.Prev = _waitersTail.Prev;
+ w.Prev.Next = w.Next.Prev = w;
+ }
+ _waitersTail = w;
+ }
+
+ // Return the waiter as a value task.
+ return w.Task;
+
+ // Cancels the specified waiter if it's still in the list.
+ static void OnCancellation(object? state, CancellationToken cancellationToken)
+ {
+ Waiter? w = (Waiter)state!;
+ AsyncMutex m = w.Owner;
+
+ lock (m.SyncObj)
+ {
+ bool inList = w.Next != null;
+ if (inList)
+ {
+ // The waiter is in the list.
+ Debug.Assert(w.Prev != null);
+
+ // The gate counter was decremented when this waiter was added. We need
+ // to undo that. Since the waiter is still in the list, the lock must
+ // still be held by someone, which means we don't need to do anything with
+ // the result of this increment. If it increments to < 1, then there are
+ // still other waiters. If it increments to 1, we're in a rare race condition
+ // where there are no other waiters and the owner just incremented the gate
+ // count; they would have seen it be < 1, so they will proceed to take the
+ // contended code path and synchronize on the lock we're holding... once we
+ // release it, they will appropriately update state.
+ Interlocked.Increment(ref m._gate);
+
+ if (w.Next == w)
+ {
+ Debug.Assert(m._waitersTail == w);
+ m._waitersTail = null;
+ }
+ else
+ {
+ w.Next!.Prev = w.Prev;
+ w.Prev.Next = w.Next;
+ if (m._waitersTail == w)
+ {
+ m._waitersTail = w.Next;
+ }
+ }
+
+ // Remove it from the list.
+ w.Next = w.Prev = null;
+ }
+ else
+ {
+ // The waiter was no longer in the list. We must not cancel it.
+ w = null;
+ }
+ }
+
+ // If the waiter was in the list, we removed it under the lock and thus own
+ // the ability to cancel it. Do so.
+ w?.TrySetCanceled(cancellationToken);
+ }
+ }
+ }
+
+ /// <summary>Releases the mutex.</summary>
+ /// <remarks>The caller must logically own the mutex. This is not validated.</remarks>
+ public void Exit()
+ {
+ if (Interlocked.Increment(ref _gate) < 1)
+ {
+ // This is the equivalent of:
+ // _sem.Release();
+ // if _sem were to be constructed as `new SemaphoreSlim(0)`.
+ Contended();
+ }
+
+ void Contended()
+ {
+ Waiter? w;
+ lock (SyncObj)
+ {
+ Debug.Assert(_lockedSemaphoreFull);
+
+ w = _waitersTail;
+ if (w is null)
+ {
+ _lockedSemaphoreFull = false;
+ }
+ else
+ {
+ Debug.Assert(w.Next != null && w.Prev != null);
+ Debug.Assert(w.Next != w || w.Prev == w);
+ Debug.Assert(w.Prev != w || w.Next == w);
+
+ if (w.Next == w)
+ {
+ _waitersTail = null;
+ }
+ else
+ {
+ w = w.Prev; // get the head
+ Debug.Assert(w.Next != null && w.Prev != null);
+ Debug.Assert(w.Next != w && w.Prev != w);
+
+ w.Next.Prev = w.Prev;
+ w.Prev.Next = w.Next;
+ }
+
+ w.Next = w.Prev = null;
+ }
+ }
+
+ // Either there wasn't a waiter, or we got one and successfully removed it from the list,
+ // at which point we own the ability to complete it. Do so.
+ if (w is not null)
+ {
+ w.CancellationRegistration.Unregister();
+ w.TrySetResult();
+ }
+ }
+ }
+
+ /// <summary>Represents a waiter for the mutex.</summary>
+ private sealed class Waiter : TaskCompletionSource
+ {
+ public Waiter(AsyncMutex owner) : base(TaskCreationOptions.RunContinuationsAsynchronously) => Owner = owner;
+ public AsyncMutex Owner { get; }
+ public CancellationTokenRegistration CancellationRegistration { get; set; }
+ public Waiter? Next { get; set; }
+ public Waiter? Prev { get; set; }
+ }
+ }
+}
/// <summary>Valid states to be in when calling CloseAsync.</summary>
private static readonly WebSocketState[] s_validCloseStates = { WebSocketState.Open, WebSocketState.CloseReceived, WebSocketState.CloseSent };
- /// <summary>Successfully completed task representing a close message.</summary>
- private static readonly Task<WebSocketReceiveResult> s_cachedCloseTask = Task.FromResult(new WebSocketReceiveResult(0, WebSocketMessageType.Close, true));
-
/// <summary>The maximum size in bytes of a message frame header that includes mask bytes.</summary>
internal const int MaxMessageHeaderLength = 14;
/// <summary>The maximum size of a control message payload.</summary>
/// </summary>
private readonly Utf8MessageState _utf8TextState = new Utf8MessageState();
/// <summary>
- /// Semaphore used to ensure that calls to SendFrameAsync don't run concurrently.
+ /// Mutex used to ensure that calls to SendFrameAsync don't run concurrently. We don't support multiple concurrent SendAsync calls,
+ /// but this is needed to support SendAsync concurrently with keep-alive pings and CloseAsync.
+ /// </summary>
+ private readonly AsyncMutex _sendMutex = new AsyncMutex();
+ /// <summary>
+ /// Mutex used to ensure that calls to ReceiveAsyncPrivate don't run concurrently. We don't support multiple concurrent ReceiveAsync calls,
+ /// but this is needed to support SendAsync concurrently with keep-alive pings and CloseAsync.
/// </summary>
- private readonly SemaphoreSlim _sendFrameAsyncLock = new SemaphoreSlim(1, 1);
+ private readonly AsyncMutex _receiveMutex = new AsyncMutex();
// We maintain the current WebSocketState in _state. However, we separately maintain _sentCloseFrame and _receivedCloseFrame
// as there isn't a strict ordering between CloseSent and CloseReceived. If we receive a close frame from the server, we need to
/// Whether the last SendAsync had <seealso cref="WebSocketMessageFlags.DisableCompression" /> flag set.
/// </summary>
private bool _lastSendHadDisableCompression;
- /// <summary>
- /// The task returned from the last ReceiveAsync(ArraySegment, ...) operation to not complete synchronously.
- /// If this is not null and not completed when a subsequent ReceiveAsync is issued, an exception occurs.
- /// </summary>
- private Task _lastReceiveAsync = Task.CompletedTask;
/// <summary>Lock used to protect update and check-and-update operations on _state.</summary>
- private object StateUpdateLock => _sendFrameAsyncLock;
- /// <summary>
- /// We need to coordinate between receives and close operations happening concurrently, as a ReceiveAsync may
- /// be pending while a Close{Output}Async is issued, which itself needs to loop until a close frame is received.
- /// As such, we need thread-safety in the management of <see cref="_lastReceiveAsync"/>.
- /// </summary>
- private object ReceiveAsyncLock => _utf8TextState; // some object, as we're simply lock'ing on it
+ private object StateUpdateLock => _sendMutex;
private readonly WebSocketInflater? _inflater;
private readonly WebSocketDeflater? _deflater;
internal ManagedWebSocket(Stream stream, bool isServer, string? subprotocol, TimeSpan keepAliveInterval)
{
Debug.Assert(StateUpdateLock != null, $"Expected {nameof(StateUpdateLock)} to be non-null");
- Debug.Assert(ReceiveAsyncLock != null, $"Expected {nameof(ReceiveAsyncLock)} to be non-null");
- Debug.Assert(StateUpdateLock != ReceiveAsyncLock, "Locks should be different objects");
-
- Debug.Assert(stream != null, $"Expected non-null stream");
- Debug.Assert(stream.CanRead, $"Expected readable stream");
- Debug.Assert(stream.CanWrite, $"Expected writeable stream");
- Debug.Assert(keepAliveInterval == Timeout.InfiniteTimeSpan || keepAliveInterval >= TimeSpan.Zero, $"Invalid keepalive interval: {keepAliveInterval}");
+ Debug.Assert(stream != null, $"Expected non-null {nameof(stream)}");
+ Debug.Assert(stream.CanRead, $"Expected readable {nameof(stream)}");
+ Debug.Assert(stream.CanWrite, $"Expected writeable {nameof(stream)}");
+ Debug.Assert(keepAliveInterval == Timeout.InfiniteTimeSpan || keepAliveInterval >= TimeSpan.Zero, $"Invalid {nameof(keepAliveInterval)}: {keepAliveInterval}");
_stream = stream;
_isServer = isServer;
WebSocketValidate.ValidateArraySegment(buffer, nameof(buffer));
- return SendPrivateAsync(buffer, messageType, endOfMessage ? WebSocketMessageFlags.EndOfMessage : default, cancellationToken).AsTask();
+ return SendAsync(buffer, messageType, endOfMessage ? WebSocketMessageFlags.EndOfMessage : default, cancellationToken).AsTask();
}
- private ValueTask SendPrivateAsync(ReadOnlyMemory<byte> buffer, WebSocketMessageType messageType, WebSocketMessageFlags messageFlags, CancellationToken cancellationToken)
+ public override ValueTask SendAsync(ReadOnlyMemory<byte> buffer, WebSocketMessageType messageType, bool endOfMessage, CancellationToken cancellationToken) =>
+ SendAsync(buffer, messageType, endOfMessage ? WebSocketMessageFlags.EndOfMessage : default, cancellationToken);
+
+ public override ValueTask SendAsync(ReadOnlyMemory<byte> buffer, WebSocketMessageType messageType, WebSocketMessageFlags messageFlags, CancellationToken cancellationToken)
{
if (messageType != WebSocketMessageType.Text && messageType != WebSocketMessageType.Binary)
{
{
WebSocketValidate.ThrowIfInvalidState(_state, _disposed, s_validReceiveStates);
- Debug.Assert(!Monitor.IsEntered(StateUpdateLock), $"{nameof(StateUpdateLock)} must never be held when acquiring {nameof(ReceiveAsyncLock)}");
- lock (ReceiveAsyncLock) // synchronize with receives in CloseAsync
- {
- ThrowIfOperationInProgress(_lastReceiveAsync.IsCompleted);
- Task<WebSocketReceiveResult> t = ReceiveAsyncPrivate<WebSocketReceiveResult>(buffer, cancellationToken).AsTask();
- _lastReceiveAsync = t;
- return t;
- }
+ return ReceiveAsyncPrivate<WebSocketReceiveResult>(buffer, cancellationToken).AsTask();
}
catch (Exception exc)
{
}
}
+ public override ValueTask<ValueWebSocketReceiveResult> ReceiveAsync(Memory<byte> buffer, CancellationToken cancellationToken)
+ {
+ try
+ {
+ WebSocketValidate.ThrowIfInvalidState(_state, _disposed, s_validReceiveStates);
+
+ return ReceiveAsyncPrivate<ValueWebSocketReceiveResult>(buffer, cancellationToken);
+ }
+ catch (Exception exc)
+ {
+ return ValueTask.FromException<ValueWebSocketReceiveResult>(exc);
+ }
+ }
+
public override Task CloseAsync(WebSocketCloseStatus closeStatus, string? statusDescription, CancellationToken cancellationToken)
{
WebSocketValidate.ValidateCloseStatus(closeStatus, statusDescription);
}
}
- public override ValueTask SendAsync(ReadOnlyMemory<byte> buffer, WebSocketMessageType messageType, bool endOfMessage, CancellationToken cancellationToken)
- {
- return SendPrivateAsync(buffer, messageType, endOfMessage ? WebSocketMessageFlags.EndOfMessage : default, cancellationToken);
- }
-
- public override ValueTask SendAsync(ReadOnlyMemory<byte> buffer, WebSocketMessageType messageType, WebSocketMessageFlags messageFlags, CancellationToken cancellationToken)
- {
- return SendPrivateAsync(buffer, messageType, messageFlags, cancellationToken);
- }
-
- public override ValueTask<ValueWebSocketReceiveResult> ReceiveAsync(Memory<byte> buffer, CancellationToken cancellationToken)
- {
- try
- {
- WebSocketValidate.ThrowIfInvalidState(_state, _disposed, s_validReceiveStates);
-
- Debug.Assert(!Monitor.IsEntered(StateUpdateLock), $"{nameof(StateUpdateLock)} must never be held when acquiring {nameof(ReceiveAsyncLock)}");
- lock (ReceiveAsyncLock) // synchronize with receives in CloseAsync
- {
- ThrowIfOperationInProgress(_lastReceiveAsync.IsCompleted);
-
- ValueTask<ValueWebSocketReceiveResult> receiveValueTask = ReceiveAsyncPrivate<ValueWebSocketReceiveResult>(buffer, cancellationToken);
- if (receiveValueTask.IsCompletedSuccessfully)
- {
- _lastReceiveAsync = receiveValueTask.Result.MessageType == WebSocketMessageType.Close ? s_cachedCloseTask : Task.CompletedTask;
- return receiveValueTask;
- }
-
- // We need to both store the last receive task and return it, but we can't do that with a ValueTask,
- // as that could result in consuming it multiple times. Instead, we use AsTask to consume it just once,
- // and then store that Task and return a new ValueTask that wraps it. (It would be nice in the future
- // to avoid this AsTask as well; currently it's used both for error detection and as part of close tracking.)
- Task<ValueWebSocketReceiveResult> receiveTask = receiveValueTask.AsTask();
- _lastReceiveAsync = receiveTask;
- return new ValueTask<ValueWebSocketReceiveResult>(receiveTask);
- }
- }
- catch (Exception exc)
- {
- return ValueTask.FromException<ValueWebSocketReceiveResult>(exc);
- }
- }
-
- private Task ValidateAndReceiveAsync(Task receiveTask, byte[] buffer, CancellationToken cancellationToken)
- {
- if (receiveTask == null ||
- (receiveTask.IsCompletedSuccessfully &&
- !(receiveTask is Task<WebSocketReceiveResult> wsrr && wsrr.Result.MessageType == WebSocketMessageType.Close) &&
- !(receiveTask is Task<ValueWebSocketReceiveResult> vwsrr && vwsrr.Result.MessageType == WebSocketMessageType.Close)))
- {
- ValueTask<ValueWebSocketReceiveResult> vt = ReceiveAsyncPrivate<ValueWebSocketReceiveResult>(buffer, cancellationToken);
- receiveTask =
- vt.IsCompletedSuccessfully ? (vt.Result.MessageType == WebSocketMessageType.Close ? s_cachedCloseTask : Task.CompletedTask) :
- vt.AsTask();
- }
-
- return receiveTask;
- }
-
/// <summary>Sends a websocket frame to the network.</summary>
/// <param name="opcode">The opcode for the message.</param>
/// <param name="endOfMessage">The value of the FIN bit for the message.</param>
// pass around (the CancellationTokenRegistration), so if it is cancelable, just immediately go to the fallback path.
// Similarly, it should be rare that there are multiple outstanding calls to SendFrameAsync, but if there are, again
// fall back to the fallback path.
-#pragma warning disable CA1416 // Validate platform compatibility, will not wait because timeout equals 0
- return cancellationToken.CanBeCanceled || !_sendFrameAsyncLock.Wait(0, default) ?
-#pragma warning restore CA1416
- SendFrameFallbackAsync(opcode, endOfMessage, disableCompression, payloadBuffer, cancellationToken) :
+ Task lockTask = _sendMutex.EnterAsync(cancellationToken);
+ return cancellationToken.CanBeCanceled || !lockTask.IsCompletedSuccessfully ?
+ SendFrameFallbackAsync(opcode, endOfMessage, disableCompression, payloadBuffer, lockTask, cancellationToken) :
SendFrameLockAcquiredNonCancelableAsync(opcode, endOfMessage, disableCompression, payloadBuffer);
}
/// <param name="payloadBuffer">The buffer containing the payload data fro the message.</param>
private ValueTask SendFrameLockAcquiredNonCancelableAsync(MessageOpcode opcode, bool endOfMessage, bool disableCompression, ReadOnlyMemory<byte> payloadBuffer)
{
- Debug.Assert(_sendFrameAsyncLock.CurrentCount == 0, "Caller should hold the _sendFrameAsyncLock");
+ Debug.Assert(_sendMutex.IsHeld, $"Caller should hold the {nameof(_sendMutex)}");
// If we get here, the cancellation token is not cancelable so we don't have to worry about it,
// and we own the semaphore, so we don't need to asynchronously wait for it.
if (releaseSendBufferAndSemaphore)
{
ReleaseSendBuffer();
- _sendFrameAsyncLock.Release();
+ _sendMutex.Exit();
}
}
finally
{
ReleaseSendBuffer();
- _sendFrameAsyncLock.Release();
+ _sendMutex.Exit();
}
}
- private async ValueTask SendFrameFallbackAsync(MessageOpcode opcode, bool endOfMessage, bool disableCompression, ReadOnlyMemory<byte> payloadBuffer, CancellationToken cancellationToken)
+ private async ValueTask SendFrameFallbackAsync(MessageOpcode opcode, bool endOfMessage, bool disableCompression, ReadOnlyMemory<byte> payloadBuffer, Task lockTask, CancellationToken cancellationToken)
{
- await _sendFrameAsyncLock.WaitAsync(cancellationToken).ConfigureAwait(false);
+ await lockTask.ConfigureAwait(false);
try
{
int sendBytes = WriteFrameToSendBuffer(opcode, endOfMessage, disableCompression, payloadBuffer.Span);
finally
{
ReleaseSendBuffer();
- _sendFrameAsyncLock.Release();
+ _sendMutex.Exit();
}
}
private void SendKeepAliveFrameAsync()
{
-#pragma warning disable CA1416 // Validate platform compatibility, will not wait because timeout equals 0
- bool acquiredLock = _sendFrameAsyncLock.Wait(0);
-#pragma warning restore CA1416
- if (acquiredLock)
- {
- // This exists purely to keep the connection alive; don't wait for the result, and ignore any failures.
- // The call will handle releasing the lock. We send a pong rather than ping, since it's allowed by
- // the RFC as a unidirectional heartbeat and we're not interested in waiting for a response.
- ValueTask t = SendFrameLockAcquiredNonCancelableAsync(MessageOpcode.Pong, endOfMessage: true, disableCompression: true, ReadOnlyMemory<byte>.Empty);
- if (t.IsCompletedSuccessfully)
- {
- t.GetAwaiter().GetResult();
- }
- else
- {
- // "Observe" any exception, ignoring it to prevent the unobserved exception event from being raised.
- t.AsTask().ContinueWith(static p => { _ = p.Exception; },
- CancellationToken.None,
- TaskContinuationOptions.OnlyOnFaulted | TaskContinuationOptions.ExecuteSynchronously,
- TaskScheduler.Default);
- }
+ // This exists purely to keep the connection alive; don't wait for the result, and ignore any failures.
+ // The call will handle releasing the lock. We send a pong rather than ping, since it's allowed by
+ // the RFC as a unidirectional heartbeat and we're not interested in waiting for a response.
+ ValueTask t = SendFrameAsync(MessageOpcode.Pong, endOfMessage: true, disableCompression: true, ReadOnlyMemory<byte>.Empty, CancellationToken.None);
+ if (t.IsCompletedSuccessfully)
+ {
+ t.GetAwaiter().GetResult();
}
else
{
- // If the lock is already held, something is already getting sent,
- // so there's no need to send a keep-alive ping.
+ // "Observe" any exception, ignoring it to prevent the unobserved exception event from being raised.
+ t.AsTask().ContinueWith(static p => { _ = p.Exception; },
+ CancellationToken.None,
+ TaskContinuationOptions.OnlyOnFaulted | TaskContinuationOptions.ExecuteSynchronously,
+ TaskScheduler.Default);
}
}
// 0 or 4 bytes - Mask, if Masked is 1 - random value XOR'd with each 4 bytes of the payload, round-robin
// Length bytes - Payload data
- Debug.Assert(sendBuffer.Length >= MaxMessageHeaderLength, $"Expected sendBuffer to be at least {MaxMessageHeaderLength}, got {sendBuffer.Length}");
+ Debug.Assert(sendBuffer.Length >= MaxMessageHeaderLength, $"Expected {nameof(sendBuffer)} to be at least {MaxMessageHeaderLength}, got {sendBuffer.Length}");
sendBuffer[0] = (byte)opcode; // 4 bits for the opcode
if (endOfMessage)
/// <param name="payloadBuffer">The buffer into which payload data should be written.</param>
/// <param name="cancellationToken">The CancellationToken used to cancel the websocket.</param>
/// <returns>Information about the received message.</returns>
+ [AsyncMethodBuilder(typeof(PoolingAsyncValueTaskMethodBuilder<>))]
private async ValueTask<TResult> ReceiveAsyncPrivate<TResult>(Memory<byte> payloadBuffer, CancellationToken cancellationToken)
{
// This is a long method. While splitting it up into pieces would arguably help with readability, doing so would
CancellationTokenRegistration registration = cancellationToken.Register(static s => ((ManagedWebSocket)s!).Abort(), this);
try
{
- while (true) // in case we get control frames that should be ignored from the user's perspective
+ await _receiveMutex.EnterAsync(cancellationToken).ConfigureAwait(false);
+ try
{
- // Get the last received header. If its payload length is non-zero, that means we previously
- // received the header but were only able to read a part of the fragment, so we should skip
- // reading another header and just proceed to use that same header and read more data associated
- // with it. If instead its payload length is zero, then we've completed the processing of
- // thta message, and we should read the next header.
- MessageHeader header = _lastReceiveHeader;
- if (header.Processed)
+ while (true) // in case we get control frames that should be ignored from the user's perspective
{
- if (_receiveBufferCount < (_isServer ? MaxMessageHeaderLength : (MaxMessageHeaderLength - MaskLength)))
+ // Get the last received header. If its payload length is non-zero, that means we previously
+ // received the header but were only able to read a part of the fragment, so we should skip
+ // reading another header and just proceed to use that same header and read more data associated
+ // with it. If instead its payload length is zero, then we've completed the processing of
+ // thta message, and we should read the next header.
+ MessageHeader header = _lastReceiveHeader;
+ if (header.Processed)
{
- // Make sure we have the first two bytes, which includes the start of the payload length.
- if (_receiveBufferCount < 2)
+ if (_receiveBufferCount < (_isServer ? MaxMessageHeaderLength : (MaxMessageHeaderLength - MaskLength)))
{
- await EnsureBufferContainsAsync(2, cancellationToken, throwOnPrematureClosure: true).ConfigureAwait(false);
+ // Make sure we have the first two bytes, which includes the start of the payload length.
+ if (_receiveBufferCount < 2)
+ {
+ await EnsureBufferContainsAsync(2, cancellationToken, throwOnPrematureClosure: true).ConfigureAwait(false);
+ }
+
+ // Then make sure we have the full header based on the payload length.
+ // If this is the server, we also need room for the received mask.
+ long payloadLength = _receiveBuffer.Span[_receiveBufferOffset + 1] & 0x7F;
+ if (_isServer || payloadLength > 125)
+ {
+ int minNeeded =
+ 2 +
+ (_isServer ? MaskLength : 0) +
+ (payloadLength <= 125 ? 0 : payloadLength == 126 ? sizeof(ushort) : sizeof(ulong)); // additional 2 or 8 bytes for 16-bit or 64-bit length
+ await EnsureBufferContainsAsync(minNeeded, cancellationToken).ConfigureAwait(false);
+ }
}
- // Then make sure we have the full header based on the payload length.
- // If this is the server, we also need room for the received mask.
- long payloadLength = _receiveBuffer.Span[_receiveBufferOffset + 1] & 0x7F;
- if (_isServer || payloadLength > 125)
+ string? headerErrorMessage = TryParseMessageHeaderFromReceiveBuffer(out header);
+ if (headerErrorMessage != null)
{
- int minNeeded =
- 2 +
- (_isServer ? MaskLength : 0) +
- (payloadLength <= 125 ? 0 : payloadLength == 126 ? sizeof(ushort) : sizeof(ulong)); // additional 2 or 8 bytes for 16-bit or 64-bit length
- await EnsureBufferContainsAsync(minNeeded, cancellationToken).ConfigureAwait(false);
+ await CloseWithReceiveErrorAndThrowAsync(WebSocketCloseStatus.ProtocolError, WebSocketError.Faulted, headerErrorMessage).ConfigureAwait(false);
+ }
+ _receivedMaskOffsetOffset = 0;
+
+ if (header.PayloadLength == 0 && header.Compressed)
+ {
+ // In the rare case where we receive a compressed message with no payload
+ // we need to tell the inflater about it, because the receive code bellow will
+ // not try to do anything when PayloadLength == 0.
+ _inflater!.AddBytes(0, endOfMessage: header.Fin);
}
}
- string? headerErrorMessage = TryParseMessageHeaderFromReceiveBuffer(out header);
- if (headerErrorMessage != null)
+ // If the header represents a ping or a pong, it's a control message meant
+ // to be transparent to the user, so handle it and then loop around to read again.
+ // Alternatively, if it's a close message, handle it and exit.
+ if (header.Opcode == MessageOpcode.Ping || header.Opcode == MessageOpcode.Pong)
+ {
+ await HandleReceivedPingPongAsync(header, cancellationToken).ConfigureAwait(false);
+ continue;
+ }
+ else if (header.Opcode == MessageOpcode.Close)
{
- await CloseWithReceiveErrorAndThrowAsync(WebSocketCloseStatus.ProtocolError, WebSocketError.Faulted, headerErrorMessage).ConfigureAwait(false);
+ await HandleReceivedCloseAsync(header, cancellationToken).ConfigureAwait(false);
+ return GetReceiveResult<TResult>(0, WebSocketMessageType.Close, true);
}
- _receivedMaskOffsetOffset = 0;
- if (header.PayloadLength == 0 && header.Compressed)
+ // If this is a continuation, replace the opcode with the one of the message it's continuing
+ if (header.Opcode == MessageOpcode.Continuation)
{
- // In the rare case where we receive a compressed message with no payload
- // we need to tell the inflater about it, because the receive code bellow will
- // not try to do anything when PayloadLength == 0.
- _inflater!.AddBytes(0, endOfMessage: header.Fin);
+ header.Opcode = _lastReceiveHeader.Opcode;
+ header.Compressed = _lastReceiveHeader.Compressed;
}
- }
- // If the header represents a ping or a pong, it's a control message meant
- // to be transparent to the user, so handle it and then loop around to read again.
- // Alternatively, if it's a close message, handle it and exit.
- if (header.Opcode == MessageOpcode.Ping || header.Opcode == MessageOpcode.Pong)
- {
- await HandleReceivedPingPongAsync(header, cancellationToken).ConfigureAwait(false);
- continue;
- }
- else if (header.Opcode == MessageOpcode.Close)
- {
- await HandleReceivedCloseAsync(header, cancellationToken).ConfigureAwait(false);
- return GetReceiveResult<TResult>(0, WebSocketMessageType.Close, true);
- }
+ // The message should now be a binary or text message. Handle it by reading the payload and returning the contents.
+ Debug.Assert(header.Opcode == MessageOpcode.Binary || header.Opcode == MessageOpcode.Text, $"Unexpected opcode {header.Opcode}");
- // If this is a continuation, replace the opcode with the one of the message it's continuing
- if (header.Opcode == MessageOpcode.Continuation)
- {
- header.Opcode = _lastReceiveHeader.Opcode;
- header.Compressed = _lastReceiveHeader.Compressed;
- }
+ // If there's no data to read, return an appropriate result.
+ if (header.Processed || payloadBuffer.Length == 0)
+ {
+ _lastReceiveHeader = header;
+ return GetReceiveResult<TResult>(
+ count: 0,
+ messageType: header.Opcode == MessageOpcode.Text ? WebSocketMessageType.Text : WebSocketMessageType.Binary,
+ endOfMessage: header.EndOfMessage);
+ }
- // The message should now be a binary or text message. Handle it by reading the payload and returning the contents.
- Debug.Assert(header.Opcode == MessageOpcode.Binary || header.Opcode == MessageOpcode.Text, $"Unexpected opcode {header.Opcode}");
+ // Otherwise, read as much of the payload as we can efficiently, and update the header to reflect how much data
+ // remains for future reads. We first need to copy any data that may be lingering in the receive buffer
+ // into the destination; then to minimize ReceiveAsync calls, we want to read as much as we can, stopping
+ // only when we've either read the whole message or when we've filled the payload buffer.
- // If there's no data to read, return an appropriate result.
- if (header.Processed || payloadBuffer.Length == 0)
- {
- _lastReceiveHeader = header;
- return GetReceiveResult<TResult>(
- count: 0,
- messageType: header.Opcode == MessageOpcode.Text ? WebSocketMessageType.Text : WebSocketMessageType.Binary,
- endOfMessage: header.EndOfMessage);
- }
+ // First copy any data lingering in the receive buffer.
+ int totalBytesReceived = 0;
+
+ // Only start a new receive if we haven't received the entire frame.
+ if (header.PayloadLength > 0)
+ {
+ if (header.Compressed)
+ {
+ Debug.Assert(_inflater is not null);
+ _inflater.Prepare(header.PayloadLength, payloadBuffer.Length);
+ }
- // Otherwise, read as much of the payload as we can efficiently, and update the header to reflect how much data
- // remains for future reads. We first need to copy any data that may be lingering in the receive buffer
- // into the destination; then to minimize ReceiveAsync calls, we want to read as much as we can, stopping
- // only when we've either read the whole message or when we've filled the payload buffer.
+ // Read directly into the appropriate buffer until we've hit a limit.
+ int limit = (int)Math.Min(header.Compressed ? _inflater!.Span.Length : payloadBuffer.Length, header.PayloadLength);
- // First copy any data lingering in the receive buffer.
- int totalBytesReceived = 0;
+ if (_receiveBufferCount > 0)
+ {
+ int receiveBufferBytesToCopy = Math.Min(limit, _receiveBufferCount);
+ Debug.Assert(receiveBufferBytesToCopy > 0);
- // Only start a new receive if we haven't received the entire frame.
- if (header.PayloadLength > 0)
- {
- if (header.Compressed)
- {
- Debug.Assert(_inflater is not null);
- _inflater.Prepare(header.PayloadLength, payloadBuffer.Length);
- }
+ _receiveBuffer.Span.Slice(_receiveBufferOffset, receiveBufferBytesToCopy).CopyTo(
+ header.Compressed ? _inflater!.Span : payloadBuffer.Span);
+ ConsumeFromBuffer(receiveBufferBytesToCopy);
+ totalBytesReceived += receiveBufferBytesToCopy;
+ }
- // Read directly into the appropriate buffer until we've hit a limit.
- int limit = (int)Math.Min(header.Compressed ? _inflater!.Span.Length : payloadBuffer.Length, header.PayloadLength);
+ while (totalBytesReceived < limit)
+ {
+ int numBytesRead = await _stream.ReadAsync(header.Compressed ?
+ _inflater!.Memory.Slice(totalBytesReceived, limit - totalBytesReceived) :
+ payloadBuffer.Slice(totalBytesReceived, limit - totalBytesReceived),
+ cancellationToken).ConfigureAwait(false);
+ if (numBytesRead <= 0)
+ {
+ ThrowIfEOFUnexpected(throwOnPrematureClosure: true);
+ break;
+ }
+ totalBytesReceived += numBytesRead;
+ }
- if (_receiveBufferCount > 0)
- {
- int receiveBufferBytesToCopy = Math.Min(limit, _receiveBufferCount);
- Debug.Assert(receiveBufferBytesToCopy > 0);
+ if (_isServer)
+ {
+ _receivedMaskOffsetOffset = ApplyMask(header.Compressed ?
+ _inflater!.Span.Slice(0, totalBytesReceived) :
+ payloadBuffer.Span.Slice(0, totalBytesReceived), header.Mask, _receivedMaskOffsetOffset);
+ }
- _receiveBuffer.Span.Slice(_receiveBufferOffset, receiveBufferBytesToCopy).CopyTo(
- header.Compressed ? _inflater!.Span : payloadBuffer.Span);
- ConsumeFromBuffer(receiveBufferBytesToCopy);
- totalBytesReceived += receiveBufferBytesToCopy;
- }
+ header.PayloadLength -= totalBytesReceived;
- while (totalBytesReceived < limit)
- {
- int numBytesRead = await _stream.ReadAsync(header.Compressed ?
- _inflater!.Memory.Slice(totalBytesReceived, limit - totalBytesReceived) :
- payloadBuffer.Slice(totalBytesReceived, limit - totalBytesReceived),
- cancellationToken).ConfigureAwait(false);
- if (numBytesRead <= 0)
+ if (header.Compressed)
{
- ThrowIfEOFUnexpected(throwOnPrematureClosure: true);
- break;
+ _inflater!.AddBytes(totalBytesReceived, endOfMessage: header.Fin && header.PayloadLength == 0);
}
- totalBytesReceived += numBytesRead;
}
- if (_isServer)
+ if (header.Compressed)
{
- _receivedMaskOffsetOffset = ApplyMask(header.Compressed ?
- _inflater!.Span.Slice(0, totalBytesReceived) :
- payloadBuffer.Span.Slice(0, totalBytesReceived), header.Mask, _receivedMaskOffsetOffset);
+ // In case of compression totalBytesReceived should actually represent how much we've
+ // inflated, rather than how much we've read from the stream.
+ header.Processed = _inflater!.Inflate(payloadBuffer.Span, out totalBytesReceived) && header.PayloadLength == 0;
}
-
- header.PayloadLength -= totalBytesReceived;
-
- if (header.Compressed)
+ else
{
- _inflater!.AddBytes(totalBytesReceived, endOfMessage: header.Fin && header.PayloadLength == 0);
+ // Without compression the frame is processed as soon as we've received everything
+ header.Processed = header.PayloadLength == 0;
}
- }
- if (header.Compressed)
- {
- // In case of compression totalBytesReceived should actually represent how much we've
- // inflated, rather than how much we've read from the stream.
- header.Processed = _inflater!.Inflate(payloadBuffer.Span, out totalBytesReceived) && header.PayloadLength == 0;
- }
- else
- {
- // Without compression the frame is processed as soon as we've received everything
- header.Processed = header.PayloadLength == 0;
- }
+ // If this a text message, validate that it contains valid UTF8.
+ if (header.Opcode == MessageOpcode.Text &&
+ !TryValidateUtf8(payloadBuffer.Span.Slice(0, totalBytesReceived), header.EndOfMessage, _utf8TextState))
+ {
+ await CloseWithReceiveErrorAndThrowAsync(WebSocketCloseStatus.InvalidPayloadData, WebSocketError.Faulted).ConfigureAwait(false);
+ }
- // If this a text message, validate that it contains valid UTF8.
- if (header.Opcode == MessageOpcode.Text &&
- !TryValidateUtf8(payloadBuffer.Span.Slice(0, totalBytesReceived), header.EndOfMessage, _utf8TextState))
- {
- await CloseWithReceiveErrorAndThrowAsync(WebSocketCloseStatus.InvalidPayloadData, WebSocketError.Faulted).ConfigureAwait(false);
+ _lastReceiveHeader = header;
+ return GetReceiveResult<TResult>(
+ totalBytesReceived,
+ header.Opcode == MessageOpcode.Text ? WebSocketMessageType.Text : WebSocketMessageType.Binary,
+ header.EndOfMessage);
}
-
- _lastReceiveHeader = header;
- return GetReceiveResult<TResult>(
- totalBytesReceived,
- header.Opcode == MessageOpcode.Text ? WebSocketMessageType.Text : WebSocketMessageType.Binary,
- header.EndOfMessage);
+ }
+ finally
+ {
+ _receiveMutex.Exit();
}
}
catch (Exception exc) when (exc is not OperationCanceledException)
// additional data, but at this point we're about to close the connection and we're just stalling
// to try to get the server to close first.
ValueTask<int> finalReadTask = _stream.ReadAsync(_receiveBuffer, cancellationToken);
- if (!finalReadTask.IsCompletedSuccessfully)
+ if (finalReadTask.IsCompletedSuccessfully)
{
- const int WaitForCloseTimeoutMs = 1_000; // arbitrary amount of time to give the server (same as netfx)
- using (var finalCts = new CancellationTokenSource(WaitForCloseTimeoutMs))
- using (finalCts.Token.Register(static s => ((ManagedWebSocket)s!).Abort(), this))
+ finalReadTask.GetAwaiter().GetResult();
+ }
+ else
+ {
+ const int WaitForCloseTimeoutMs = 1_000; // arbitrary amount of time to give the server (same duration as .NET Framework)
+ try
{
- try
- {
- await finalReadTask.ConfigureAwait(false);
- }
- catch
- {
- // Eat any resulting exceptions. We were going to close the connection, anyway.
- }
+#pragma warning disable CA2016 // Token was already provided to the ReadAsync
+ await finalReadTask.AsTask().WaitAsync(TimeSpan.FromMilliseconds(WaitForCloseTimeoutMs)).ConfigureAwait(false);
+#pragma warning restore CA2016
+ }
+ catch
+ {
+ Abort();
+ // Eat any resulting exceptions. We were going to close the connection, anyway.
}
}
}
/// <returns>null if a valid header was read; non-null containing the string error message to use if the header was invalid.</returns>
private string? TryParseMessageHeaderFromReceiveBuffer(out MessageHeader resultHeader)
{
- Debug.Assert(_receiveBufferCount >= 2, $"Expected to at least have the first two bytes of the header.");
+ Debug.Assert(_receiveBufferCount >= 2, "Expected to at least have the first two bytes of the header.");
MessageHeader header = default;
Span<byte> receiveBufferSpan = _receiveBuffer.Span;
// Read the remainder of the payload length, if necessary
if (header.PayloadLength == 126)
{
- Debug.Assert(_receiveBufferCount >= 2, $"Expected to have two bytes for the payload length.");
+ Debug.Assert(_receiveBufferCount >= 2, "Expected to have two bytes for the payload length.");
header.PayloadLength = (receiveBufferSpan[_receiveBufferOffset] << 8) | receiveBufferSpan[_receiveBufferOffset + 1];
ConsumeFromBuffer(2);
}
else if (header.PayloadLength == 127)
{
- Debug.Assert(_receiveBufferCount >= 8, $"Expected to have eight bytes for the payload length.");
+ Debug.Assert(_receiveBufferCount >= 8, "Expected to have eight bytes for the payload length.");
header.PayloadLength = 0;
for (int i = 0; i < 8; i++)
{
}
// Return the read header
+ header.Processed = header.PayloadLength == 0 && !header.Compressed;
resultHeader = header;
- resultHeader.Processed = header.PayloadLength == 0 && !header.Compressed;
return null;
}
byte[] closeBuffer = ArrayPool<byte>.Shared.Rent(MaxMessageHeaderLength + MaxControlPayloadLength);
try
{
+ // Loop until we've received a close frame.
while (!_receivedCloseFrame)
{
- Debug.Assert(!Monitor.IsEntered(StateUpdateLock), $"{nameof(StateUpdateLock)} must never be held when acquiring {nameof(ReceiveAsyncLock)}");
- Task receiveTask;
- bool usingExistingReceive;
- lock (ReceiveAsyncLock)
+ // Enter the receive lock in order to get a consistent view of whether we've received a close
+ // frame. If we haven't, issue a receive. Since that receive will try to take the same
+ // non-entrant receive lock, we then exit the lock before waiting for the receive to complete,
+ // as it will always complete asynchronously and only after we've exited the lock.
+ ValueTask<ValueWebSocketReceiveResult> receiveTask = default;
+ try
{
- // Now that we're holding the ReceiveAsyncLock, double-check that we've not yet received the close frame.
- // It could have been received between our check above and now due to a concurrent receive completing.
- if (_receivedCloseFrame)
+ await _receiveMutex.EnterAsync(cancellationToken).ConfigureAwait(false);
+ try
{
- break;
- }
+ if (!_receivedCloseFrame)
+ {
+ receiveTask = ReceiveAsyncPrivate<ValueWebSocketReceiveResult>(closeBuffer, cancellationToken);
+ }
- // We've not yet processed a received close frame, which means we need to wait for a received close to complete.
- // There may already be one in flight, in which case we want to just wait for that one rather than kicking off
- // another (we don't support concurrent receive operations). We need to kick off a new receive if either we've
- // never issued a receive or if the last issued receive completed for reasons other than a close frame. There is
- // a race condition here, e.g. if there's a in-flight receive that completes after we check, but that's fine: worst
- // case is we then await it, find that it's not what we need, and try again.
- receiveTask = _lastReceiveAsync;
- Task newReceiveTask = ValidateAndReceiveAsync(receiveTask, closeBuffer, cancellationToken);
- usingExistingReceive = ReferenceEquals(receiveTask, newReceiveTask);
- _lastReceiveAsync = receiveTask = newReceiveTask;
+ }
+ finally
+ {
+ _receiveMutex.Exit();
+ }
}
-
- // Wait for whatever receive task we have. We'll then loop around again to re-check our state.
- // If this is an existing receive, and if we have a cancelable token, we need to register with that
- // token while we wait, since it may not be the same one that was given to the receive initially.
- Debug.Assert(receiveTask != null);
- using (usingExistingReceive ? cancellationToken.Register(static s => ((ManagedWebSocket)s!).Abort(), this) : default)
+ catch (OperationCanceledException)
{
- await receiveTask.ConfigureAwait(false);
+ // If waiting on the receive lock was canceled, abort the connection, as we would do
+ // as part of the receive itself.
+ Abort();
+ throw;
}
+
+ // Wait for the receive to complete if we issued one.
+ await receiveTask.ConfigureAwait(false);
}
}
finally
count += s_textEncoding.GetByteCount(closeStatusDescription);
buffer = ArrayPool<byte>.Shared.Rent(count);
int encodedLength = s_textEncoding.GetBytes(closeStatusDescription, 0, closeStatusDescription.Length, buffer, 2);
- Debug.Assert(count - 2 == encodedLength, $"GetByteCount and GetBytes encoded count didn't match");
+ Debug.Assert(count - 2 == encodedLength, $"{nameof(s_textEncoding.GetByteCount)} and {nameof(s_textEncoding.GetBytes)} encoded count didn't match");
}
ushort closeStatusValue = (ushort)closeStatus;
private void ConsumeFromBuffer(int count)
{
- Debug.Assert(count >= 0, $"Expected non-negative count, got {count}");
+ Debug.Assert(count >= 0, $"Expected non-negative {nameof(count)}, got {count}");
Debug.Assert(count <= _receiveBufferCount, $"Trying to consume {count}, which is more than exists {_receiveBufferCount}");
_receiveBufferCount -= count;
_receiveBufferOffset += count;
}
+ [AsyncMethodBuilder(typeof(PoolingAsyncValueTaskMethodBuilder))]
private async ValueTask EnsureBufferContainsAsync(int minimumRequiredBytes, CancellationToken cancellationToken, bool throwOnPrematureClosure = true)
{
Debug.Assert(minimumRequiredBytes <= _receiveBuffer.Length, $"Requested number of bytes {minimumRequiredBytes} must not exceed {_receiveBuffer.Length}");
/// <summary>Releases the send buffer to the pool.</summary>
private void ReleaseSendBuffer()
{
- Debug.Assert(_sendFrameAsyncLock.CurrentCount == 0, "Caller should hold the _sendFrameAsyncLock");
+ Debug.Assert(_sendMutex.IsHeld, $"Caller should hold the {nameof(_sendMutex)}");
if (_sendBuffer is byte[] toReturn)
{