if (array != null)
{
- ArrayPool<byte>.Shared.Return(array, true);
+ ArrayPool<byte>.Shared.Return(array);
}
}
}
<Compile Include="System\Net\Http\Headers\UriHeaderParser.cs" />
<Compile Include="System\Net\Http\Headers\ViaHeaderValue.cs" />
<Compile Include="System\Net\Http\Headers\WarningHeaderValue.cs" />
+ <Compile Include="System\Threading\AsyncMutex.cs" />
<Compile Include="$(CommonPath)System\IO\StreamHelpers.CopyValidation.cs"
Link="Common\System\IO\StreamHelpers.CopyValidation.cs" />
<Compile Include="$(CommonPath)System\Net\Security\SslClientAuthenticationOptionsExtensions.cs"
<Compile Include="$(CommonPath)\System\Threading\Tasks\RendezvousAwaitable.cs"
Link="Common\System\Threading\Tasks\RendezvousAwaitable.cs" />
<Compile Include="$(CommonPath)System\Threading\Tasks\TaskToApm.cs"
- Link="System\System\Threading\Tasks\TaskToApm.cs" />
+ Link="Common\System\Threading\Tasks\TaskToApm.cs" />
<Compile Include="$(CommonPath)Interop\Windows\SChannel\Interop.SecPkgContext_ApplicationProtocol.cs"
Link="Common\Interop\Windows\SChannel\Interop.SecPkgContext_ApplicationProtocol.cs" />
<Compile Include="$(CommonPath)System\Net\Security\SecurityBuffer.Windows.cs"
}
}
- if (_knownHeader == KnownHeaders.Location)
+ if (_knownHeader == KnownHeaders.ContentType)
+ {
+ string? contentType = GetKnownContentType(headerValue);
+ if (contentType != null)
+ {
+ return contentType;
+ }
+ }
+ else if (_knownHeader == KnownHeaders.Location)
{
// Normally Location should be in ISO-8859-1 but occasionally some servers respond with UTF-8.
if (TryDecodeUtf8(headerValue, out string? decoded))
return HttpRuleParser.DefaultHttpEncoding.GetString(headerValue);
}
+ internal static string? GetKnownContentType(ReadOnlySpan<byte> contentTypeValue)
+ {
+ string? candidate = null;
+ switch (contentTypeValue.Length)
+ {
+ case 8:
+ switch (contentTypeValue[7] | 0x20)
+ {
+ case 'l': candidate = "text/xml"; break; // text/xm[l]
+ case 's': candidate = "text/css"; break; // text/cs[s]
+ case 'v': candidate = "text/csv"; break; // text/cs[v]
+ }
+ break;
+
+ case 9:
+ switch (contentTypeValue[6] | 0x20)
+ {
+ case 'g': candidate = "image/gif"; break; // image/[g]if
+ case 'p': candidate = "image/png"; break; // image/[p]ng
+ case 't': candidate = "text/html"; break; // text/h[t]ml
+ }
+ break;
+
+ case 10:
+ switch (contentTypeValue[0] | 0x20)
+ {
+ case 't': candidate = "text/plain"; break; // [t]ext/plain
+ case 'i': candidate = "image/jpeg"; break; // [i]mage/jpeg
+ }
+ break;
+
+ case 15:
+ switch (contentTypeValue[12] | 0x20)
+ {
+ case 'p': candidate = "application/pdf"; break; // application/[p]df
+ case 'x': candidate = "application/xml"; break; // application/[x]ml
+ case 'z': candidate = "application/zip"; break; // application/[z]ip
+ }
+ break;
+
+ case 16:
+ switch (contentTypeValue[12] | 0x20)
+ {
+ case 'g': candidate = "application/grpc"; break; // application/[g]rpc
+ case 'j': candidate = "application/json"; break; // application/[j]son
+ }
+ break;
+
+ case 19:
+ candidate = "multipart/form-data"; // multipart/form-data
+ break;
+
+ case 22:
+ candidate = "application/javascript"; // application/javascript
+ break;
+
+ case 24:
+ switch (contentTypeValue[0] | 0x20)
+ {
+ case 'a': candidate = "application/octet-stream"; break; // application/octet-stream
+ case 't': candidate = "text/html; charset=utf-8"; break; // text/html; charset=utf-8
+ }
+ break;
+
+ case 25:
+ candidate = "text/plain; charset=utf-8"; // text/plain; charset=utf-8
+ break;
+
+ case 31:
+ candidate = "application/json; charset=utf-8"; // application/json; charset=utf-8
+ break;
+
+ case 33:
+ candidate = "application/x-www-form-urlencoded"; // application/x-www-form-urlencoded
+ break;
+ }
+
+ Debug.Assert(candidate is null || candidate.Length == contentTypeValue.Length);
+
+ return candidate != null && ByteArrayHelpers.EqualsOrdinalAsciiIgnoreCase(candidate, contentTypeValue) ?
+ candidate :
+ null;
+ }
+
private static bool TryDecodeUtf8(ReadOnlySpan<byte> input, [NotNullWhen(true)] out string? decoded)
{
char[] rented = ArrayPool<char>.Shared.Rent(input.Length);
return values;
}
- internal static int GetValuesAsStrings(HeaderDescriptor descriptor, object sourceValues, ref string[] values)
+ internal static int GetValuesAsStrings(HeaderDescriptor descriptor, object sourceValues, [NotNull] ref string[]? values)
{
+ values ??= Array.Empty<string>();
+
HeaderStoreItemInfo? info = sourceValues as HeaderStoreItemInfo;
if (info is null)
{
return 1;
}
- Debug.Assert(values != null);
int length = GetValueCount(info);
if (length > 0)
public static readonly KnownHeader ExpectCT = new KnownHeader("Expect-CT");
public static readonly KnownHeader Expires = new KnownHeader("Expires", HttpHeaderType.Content | HttpHeaderType.NonTrailing, DateHeaderParser.Parser, null, H2StaticTable.Expires);
public static readonly KnownHeader From = new KnownHeader("From", HttpHeaderType.Request, GenericHeaderParser.MailAddressParser, null, H2StaticTable.From);
+ public static readonly KnownHeader GrpcEncoding = new KnownHeader("grpc-encoding", HttpHeaderType.Custom, null, new string[] { "identity", "gzip", "deflate" });
+ public static readonly KnownHeader GrpcMessage = new KnownHeader("grpc-message");
+ public static readonly KnownHeader GrpcStatus = new KnownHeader("grpc-status", HttpHeaderType.Custom, null, new string[] { "0" });
public static readonly KnownHeader Host = new KnownHeader("Host", HttpHeaderType.Request | HttpHeaderType.NonTrailing, GenericHeaderParser.HostParser, null, H2StaticTable.Host);
public static readonly KnownHeader IfMatch = new KnownHeader("If-Match", HttpHeaderType.Request | HttpHeaderType.NonTrailing, GenericHeaderParser.MultipleValueEntityTagParser, null, H2StaticTable.IfMatch);
public static readonly KnownHeader IfModifiedSince = new KnownHeader("If-Modified-Since", HttpHeaderType.Request | HttpHeaderType.NonTrailing, DateHeaderParser.Parser, null, H2StaticTable.IfModifiedSince, H3StaticTable.IfModifiedSince);
switch (key[0] | 0x20)
{
case 'c': return ContentMD5; // [C]ontent-MD5
+ case 'g': return GrpcStatus; // [g]rpc-status
case 'r': return RetryAfter; // [R]etry-After
case 's': return SetCookie2; // [S]et-Cookie2
}
break;
case 12:
- switch (key[2] | 0x20)
+ switch (key[5] | 0x20)
{
- case 'c': return AcceptPatch; // Ac[c]ept-Patch
- case 'm': return XMSEdgeRef; // X-[M]SEdge-Ref
- case 'n': return ContentType; // Co[n]tent-Type
- case 'p': return XPoweredBy; // X-[P]owered-By
- case 'r': return XRequestID; // X-[R]equest-ID
- case 'x': return MaxForwards; // Ma[x]-Forwards
+ case 'd': return XMSEdgeRef; // X-MSE[d]ge-Ref
+ case 'e': return XPoweredBy; // X-Pow[e]red-By
+ case 'm': return GrpcMessage; // grpc-[m]essage
+ case 'n': return ContentType; // Conte[n]t-Type
+ case 'o': return MaxForwards; // Max-F[o]rwards
+ case 't': return AcceptPatch; // Accep[t]-Patch
+ case 'u': return XRequestID; // X-Req[u]est-ID
}
break;
{
case 'd': return LastModified; // Last-Modifie[d]
case 'e': return ContentRange; // Content-Rang[e]
- case 'g': return ServerTiming; // Server-Timin[g]
+ case 'g':
+ switch (key[0] | 0x20)
+ {
+ case 's': return ServerTiming; // [S]erver-Timin[g]
+ case 'g': return GrpcEncoding; // [g]rpc-encodin[g]
+ }
+ break;
case 'h': return IfNoneMatch; // If-None-Matc[h]
case 'l': return CacheControl; // Cache-Contro[l]
case 'n': return Authorization; // Authorizatio[n]
try
{
- ArraySegment<byte> buffer;
- if (TryGetBuffer(out buffer))
- {
- return CopyToAsyncCore(stream.WriteAsync(new ReadOnlyMemory<byte>(buffer.Array, buffer.Offset, buffer.Count), cancellationToken));
- }
- else
- {
- Task task = SerializeToStreamAsync(stream, context, cancellationToken);
- CheckTaskNotNull(task);
- return CopyToAsyncCore(new ValueTask(task));
- }
+ return WaitAsync(InternalCopyToAsync(stream, context, cancellationToken));
}
catch (Exception e) when (StreamCopyExceptionNeedsWrapping(e))
{
return Task.FromException(GetStreamCopyException(e));
}
- }
- private static async Task CopyToAsyncCore(ValueTask copyTask)
- {
- try
+ static async Task WaitAsync(ValueTask copyTask)
{
- await copyTask.ConfigureAwait(false);
- }
- catch (Exception e) when (StreamCopyExceptionNeedsWrapping(e))
- {
- throw WrapStreamCopyException(e);
+ try
+ {
+ await copyTask.ConfigureAwait(false);
+ }
+ catch (Exception e) when (StreamCopyExceptionNeedsWrapping(e))
+ {
+ throw WrapStreamCopyException(e);
+ }
}
}
- public Task LoadIntoBufferAsync()
+ internal ValueTask InternalCopyToAsync(Stream stream, TransportContext? context, CancellationToken cancellationToken)
{
- return LoadIntoBufferAsync(MaxBufferSize);
+ if (TryGetBuffer(out ArraySegment<byte> buffer))
+ {
+ return stream.WriteAsync(buffer, cancellationToken);
+ }
+
+ Task task = SerializeToStreamAsync(stream, context, cancellationToken);
+ CheckTaskNotNull(task);
+ return new ValueTask(task);
}
+ public Task LoadIntoBufferAsync() =>
+ LoadIntoBufferAsync(MaxBufferSize);
+
// No "CancellationToken" parameter needed since canceling the CTS will close the connection, resulting
// in an exception being thrown while we're buffering.
// If buffering is used without a connection, it is supposed to be fast, thus no cancellation required.
/// </summary>
internal static HttpMethod Normalize(HttpMethod method)
{
+ // _http3EncodedBytes is only set for the singleton instances, so if it's not null,
+ // we can avoid the dictionary lookup. Otherwise, look up the method instance in the
+ // dictionary and return the normalized instance if it's found.
Debug.Assert(method != null);
- return s_knownMethods.TryGetValue(method, out HttpMethod? normalized) ?
- normalized :
+ return
+ method._http3EncodedBytes is null && s_knownMethods.TryGetValue(method, out HttpMethod? normalized) ? normalized :
method;
}
// NOTE: These are mutable structs; do not make these readonly.
private ArrayBuffer _incomingBuffer;
private ArrayBuffer _outgoingBuffer;
- private ArrayBuffer _headerBuffer;
/// <summary>Reusable array used to get the values for each header being written to the wire.</summary>
- private string[] _headerValues = Array.Empty<string>();
+ [ThreadStatic]
+ private static string[]? t_headerValues;
private int _currentWriteSize; // as passed to StartWriteAsync
private readonly Dictionary<int, Http2Stream> _httpStreams;
- private readonly SemaphoreSlim _writerLock;
- private readonly SemaphoreSlim _headerSerializationLock;
-
+ private readonly AsyncMutex _writerLock;
private readonly CreditManager _connectionWindow;
private readonly CreditManager _concurrentStreams;
_stream = stream;
_incomingBuffer = new ArrayBuffer(InitialConnectionBufferSize);
_outgoingBuffer = new ArrayBuffer(InitialConnectionBufferSize);
- _headerBuffer = new ArrayBuffer(InitialConnectionBufferSize);
_hpackDecoder = new HPackDecoder(maxHeadersLength: pool.Settings._maxResponseHeadersLength * 1024);
_httpStreams = new Dictionary<int, Http2Stream>();
- _writerLock = new SemaphoreSlim(1, 1);
- _headerSerializationLock = new SemaphoreSlim(1, 1);
+ _writerLock = new AsyncMutex();
_connectionWindow = new CreditManager(this, nameof(_connectionWindow), DefaultInitialWindowSize);
_concurrentStreams = new CreditManager(this, nameof(_concurrentStreams), int.MaxValue);
private async ValueTask<Memory<byte>> StartWriteAsync(int writeBytes, CancellationToken cancellationToken = default)
{
if (NetEventSource.IsEnabled) Trace($"{nameof(writeBytes)}={writeBytes}");
- await AcquireWriteLockAsync(cancellationToken).ConfigureAwait(false);
+ // Acquire the write lock
+ ValueTask acquireLockTask = _writerLock.EnterAsync(cancellationToken);
+ if (acquireLockTask.IsCompletedSuccessfully)
+ {
+ acquireLockTask.GetAwaiter().GetResult(); // to enable the value task sources to be pooled
+ }
+ else
+ {
+ Interlocked.Increment(ref _pendingWriters);
+ try
+ {
+ await acquireLockTask.ConfigureAwait(false);
+ }
+ catch
+ {
+ if (Interlocked.Decrement(ref _pendingWriters) == 0)
+ {
+ // If a pending waiter is canceled, we may end up in a situation where a previously written frame
+ // saw that there were pending writers and as such deferred its flush to them, but if/when that pending
+ // writer is canceled, nothing may end up flushing the deferred work (at least not promptly). To compensate,
+ // if a pending writer does end up being canceled, we flush asynchronously. We can't check whether there's such
+ // a pending operation because we failed to acquire the lock that protects that state. But we can at least only
+ // do the flush if our decrement caused the pending count to reach 0: if it's still higher than zero, then there's
+ // at least one other pending writer who can handle the flush. Worst case, we pay for a flush that ends up being
+ // a nop. Note: we explicitly do not pass in the cancellationToken; if we're here, it's almost certainly because
+ // cancellation was requested, and it's because of that cancellation that we need to flush.
+ LogExceptions(FlushAsync(cancellationToken: default));
+ }
+
+ throw;
+ }
+ Interlocked.Decrement(ref _pendingWriters);
+ }
+
+ // If the connection has been aborted, then fail now instead of trying to send more data.
+ if (_abortException != null)
+ {
+ _writerLock.Exit();
+ throw new IOException(SR.net_http_request_aborted, _abortException);
+ }
+
+ // Flush anything necessary, and return back the write buffer to use.
try
{
// If there is a pending write that was canceled while in progress, wait for it to complete.
if (_inProgressWrite != null)
{
- await _inProgressWrite.ConfigureAwait(false);
+ await new ValueTask(_inProgressWrite).ConfigureAwait(false); // await ValueTask to minimize number of awaiter fields
_inProgressWrite = null;
}
int totalBufferLength = _outgoingBuffer.Capacity;
int activeBufferLength = _outgoingBuffer.ActiveLength;
+ // If the buffer has already grown to 32k, does not have room for the next request,
+ // and is non-empty, flush the current contents to the wire.
if (totalBufferLength >= UnflushedOutgoingBufferSize &&
writeBytes >= totalBufferLength - activeBufferLength &&
activeBufferLength > 0)
{
- // If the buffer has already grown to 32k, does not have room for the next request,
- // and is non-empty, flush the current contents to the wire.
- await FlushOutgoingBytesAsync().ConfigureAwait(false); // we explicitly do not pass cancellationToken here, as this flush impacts more than just this operation
+ // We explicitly do not pass cancellationToken here, as this flush impacts more than just this operation.
+ await new ValueTask(FlushOutgoingBytesAsync()).ConfigureAwait(false); // await ValueTask to minimize number of awaiter fields
}
_outgoingBuffer.EnsureAvailableSpace(writeBytes);
}
catch
{
- _writerLock.Release();
+ _writerLock.Exit();
throw;
}
}
{
if (NetEventSource.IsEnabled) Trace($"{nameof(flush)}={flush}");
- // We can't validate that we hold the semaphore, but we can at least validate that someone is holding it.
- Debug.Assert(_writerLock.CurrentCount == 0);
+ // We can't validate that we hold the mutex, but we can at least validate that someone is holding it.
+ Debug.Assert(_writerLock.IsHeld);
_outgoingBuffer.Commit(_currentWriteSize);
_lastPendingWriterShouldFlush |= (flush == FlushTiming.AfterPendingWrites);
{
if (NetEventSource.IsEnabled) Trace("");
- // We can't validate that we hold the semaphore, but we can at least validate that someone is holding it.
- Debug.Assert(_writerLock.CurrentCount == 0);
+ // We can't validate that we hold the mutex, but we can at least validate that someone is holding it.
+ Debug.Assert(_writerLock.IsHeld);
EndWrite(forceFlush: false);
}
private void EndWrite(bool forceFlush)
{
- // We can't validate that we hold the semaphore, but we can at least validate that someone is holding it.
- Debug.Assert(_writerLock.CurrentCount == 0);
+ // We can't validate that we hold the mutex, but we can at least validate that someone is holding it.
+ Debug.Assert(_writerLock.IsHeld);
try
{
}
finally
{
- _writerLock.Release();
+ _writerLock.Exit();
}
}
private async ValueTask AcquireWriteLockAsync(CancellationToken cancellationToken)
{
- Task acquireLockTask = _writerLock.WaitAsync(cancellationToken);
- if (!acquireLockTask.IsCompletedSuccessfully)
+ ValueTask acquireLockTask = _writerLock.EnterAsync(cancellationToken);
+ if (acquireLockTask.IsCompletedSuccessfully)
+ {
+ acquireLockTask.GetAwaiter().GetResult(); // to enable the value task sources to be pooled
+ }
+ else
{
Interlocked.Increment(ref _pendingWriters);
// at least one other pending writer who can handle the flush. Worst case, we pay for a flush that ends up being
// a nop. Note: we explicitly do not pass in the cancellationToken; if we're here, it's almost certainly because
// cancellation was requested, and it's because of that cancellation that we need to flush.
- LogExceptions(FlushAsync());
+ LogExceptions(FlushAsync(cancellationToken: default));
}
throw;
// If the connection has been aborted, then fail now instead of trying to send more data.
if (_abortException != null)
{
- _writerLock.Release();
+ _writerLock.Exit();
throw new IOException(SR.net_http_request_aborted, _abortException);
}
}
(buffer.Slice(0, maxSize), buffer.Slice(maxSize)) :
(buffer, Memory<byte>.Empty);
- private void WriteIndexedHeader(int index)
+ private void WriteIndexedHeader(int index, ref ArrayBuffer headerBuffer)
{
if (NetEventSource.IsEnabled) Trace($"{nameof(index)}={index}");
int bytesWritten;
- while (!HPackEncoder.EncodeIndexedHeaderField(index, _headerBuffer.AvailableSpan, out bytesWritten))
+ while (!HPackEncoder.EncodeIndexedHeaderField(index, headerBuffer.AvailableSpan, out bytesWritten))
{
- _headerBuffer.EnsureAvailableSpace(_headerBuffer.AvailableLength + 1);
+ headerBuffer.EnsureAvailableSpace(headerBuffer.AvailableLength + 1);
}
- _headerBuffer.Commit(bytesWritten);
+ headerBuffer.Commit(bytesWritten);
}
- private void WriteIndexedHeader(int index, string value)
+ private void WriteIndexedHeader(int index, string value, ref ArrayBuffer headerBuffer)
{
if (NetEventSource.IsEnabled) Trace($"{nameof(index)}={index}, {nameof(value)}={value}");
int bytesWritten;
- while (!HPackEncoder.EncodeLiteralHeaderFieldWithoutIndexing(index, value, _headerBuffer.AvailableSpan, out bytesWritten))
+ while (!HPackEncoder.EncodeLiteralHeaderFieldWithoutIndexing(index, value, headerBuffer.AvailableSpan, out bytesWritten))
{
- _headerBuffer.EnsureAvailableSpace(_headerBuffer.AvailableLength + 1);
+ headerBuffer.EnsureAvailableSpace(headerBuffer.AvailableLength + 1);
}
- _headerBuffer.Commit(bytesWritten);
+ headerBuffer.Commit(bytesWritten);
}
- private void WriteLiteralHeader(string name, ReadOnlySpan<string> values)
+ private void WriteLiteralHeader(string name, ReadOnlySpan<string> values, ref ArrayBuffer headerBuffer)
{
if (NetEventSource.IsEnabled) Trace($"{nameof(name)}={name}, {nameof(values)}={string.Join(", ", values.ToArray())}");
int bytesWritten;
- while (!HPackEncoder.EncodeLiteralHeaderFieldWithoutIndexingNewName(name, values, HttpHeaderParser.DefaultSeparator, _headerBuffer.AvailableSpan, out bytesWritten))
+ while (!HPackEncoder.EncodeLiteralHeaderFieldWithoutIndexingNewName(name, values, HttpHeaderParser.DefaultSeparator, headerBuffer.AvailableSpan, out bytesWritten))
{
- _headerBuffer.EnsureAvailableSpace(_headerBuffer.AvailableLength + 1);
+ headerBuffer.EnsureAvailableSpace(headerBuffer.AvailableLength + 1);
}
- _headerBuffer.Commit(bytesWritten);
+ headerBuffer.Commit(bytesWritten);
}
- private void WriteLiteralHeaderValues(ReadOnlySpan<string> values, string? separator)
+ private void WriteLiteralHeaderValues(ReadOnlySpan<string> values, string? separator, ref ArrayBuffer headerBuffer)
{
if (NetEventSource.IsEnabled) Trace($"{nameof(values)}={string.Join(separator, values.ToArray())}");
int bytesWritten;
- while (!HPackEncoder.EncodeStringLiterals(values, separator, _headerBuffer.AvailableSpan, out bytesWritten))
+ while (!HPackEncoder.EncodeStringLiterals(values, separator, headerBuffer.AvailableSpan, out bytesWritten))
{
- _headerBuffer.EnsureAvailableSpace(_headerBuffer.AvailableLength + 1);
+ headerBuffer.EnsureAvailableSpace(headerBuffer.AvailableLength + 1);
}
- _headerBuffer.Commit(bytesWritten);
+ headerBuffer.Commit(bytesWritten);
}
- private void WriteLiteralHeaderValue(string value)
+ private void WriteLiteralHeaderValue(string value, ref ArrayBuffer headerBuffer)
{
if (NetEventSource.IsEnabled) Trace($"{nameof(value)}={value}");
int bytesWritten;
- while (!HPackEncoder.EncodeStringLiteral(value, _headerBuffer.AvailableSpan, out bytesWritten))
+ while (!HPackEncoder.EncodeStringLiteral(value, headerBuffer.AvailableSpan, out bytesWritten))
{
- _headerBuffer.EnsureAvailableSpace(_headerBuffer.AvailableLength + 1);
+ headerBuffer.EnsureAvailableSpace(headerBuffer.AvailableLength + 1);
}
- _headerBuffer.Commit(bytesWritten);
+ headerBuffer.Commit(bytesWritten);
}
- private void WriteBytes(ReadOnlySpan<byte> bytes)
+ private void WriteBytes(ReadOnlySpan<byte> bytes, ref ArrayBuffer headerBuffer)
{
if (NetEventSource.IsEnabled) Trace($"{nameof(bytes.Length)}={bytes.Length}");
- if (bytes.Length > _headerBuffer.AvailableLength)
+ if (bytes.Length > headerBuffer.AvailableLength)
{
- _headerBuffer.EnsureAvailableSpace(bytes.Length);
+ headerBuffer.EnsureAvailableSpace(bytes.Length);
}
- bytes.CopyTo(_headerBuffer.AvailableSpan);
- _headerBuffer.Commit(bytes.Length);
+ bytes.CopyTo(headerBuffer.AvailableSpan);
+ headerBuffer.Commit(bytes.Length);
}
- private void WriteHeaderCollection(HttpHeaders headers)
+ private void WriteHeaderCollection(HttpHeaders headers, ref ArrayBuffer headerBuffer)
{
if (NetEventSource.IsEnabled) Trace("");
return;
}
+ ref string[]? tmpHeaderValuesArray = ref t_headerValues;
foreach (KeyValuePair<HeaderDescriptor, object> header in headers.HeaderStore)
{
- int headerValuesCount = HttpHeaders.GetValuesAsStrings(header.Key, header.Value, ref _headerValues);
+ int headerValuesCount = HttpHeaders.GetValuesAsStrings(header.Key, header.Value, ref tmpHeaderValuesArray);
Debug.Assert(headerValuesCount > 0, "No values for header??");
- ReadOnlySpan<string> headerValues = _headerValues.AsSpan(0, headerValuesCount);
+ ReadOnlySpan<string> headerValues = tmpHeaderValuesArray.AsSpan(0, headerValuesCount);
KnownHeader? knownHeader = header.Key.KnownHeader;
if (knownHeader != null)
{
if (string.Equals(value, "trailers", StringComparison.OrdinalIgnoreCase))
{
- WriteBytes(knownHeader.Http2EncodedName);
- WriteLiteralHeaderValue(value);
+ WriteBytes(knownHeader.Http2EncodedName, ref headerBuffer);
+ WriteLiteralHeaderValue(value, ref headerBuffer);
break;
}
}
}
// For all other known headers, send them via their pre-encoded name and the associated value.
- WriteBytes(knownHeader.Http2EncodedName);
+ WriteBytes(knownHeader.Http2EncodedName, ref headerBuffer);
string? separator = null;
if (headerValues.Length > 1)
{
}
}
- WriteLiteralHeaderValues(headerValues, separator);
+ WriteLiteralHeaderValues(headerValues, separator, ref headerBuffer);
}
}
else
{
// The header is not known: fall back to just encoding the header name and value(s).
- WriteLiteralHeader(header.Key.Name, headerValues);
+ WriteLiteralHeader(header.Key.Name, headerValues, ref headerBuffer);
}
}
}
- private void WriteHeaders(HttpRequestMessage request)
+ private void WriteHeaders(HttpRequestMessage request, ref ArrayBuffer headerBuffer)
{
if (NetEventSource.IsEnabled) Trace("");
- Debug.Assert(_headerBuffer.ActiveLength == 0);
// HTTP2 does not support Transfer-Encoding: chunked, so disable this on the request.
if (request.HasHeaders && request.Headers.TransferEncodingChunked == true)
// Method is normalized so we can do reference equality here.
if (ReferenceEquals(normalizedMethod, HttpMethod.Get))
{
- WriteIndexedHeader(H2StaticTable.MethodGet);
+ WriteIndexedHeader(H2StaticTable.MethodGet, ref headerBuffer);
}
else if (ReferenceEquals(normalizedMethod, HttpMethod.Post))
{
- WriteIndexedHeader(H2StaticTable.MethodPost);
+ WriteIndexedHeader(H2StaticTable.MethodPost, ref headerBuffer);
}
else
{
- WriteIndexedHeader(H2StaticTable.MethodGet, normalizedMethod.Method);
+ WriteIndexedHeader(H2StaticTable.MethodGet, normalizedMethod.Method, ref headerBuffer);
}
- WriteIndexedHeader(_stream is SslStream ? H2StaticTable.SchemeHttps : H2StaticTable.SchemeHttp);
+ WriteIndexedHeader(_stream is SslStream ? H2StaticTable.SchemeHttps : H2StaticTable.SchemeHttp, ref headerBuffer);
if (request.HasHeaders && request.Headers.Host != null)
{
- WriteIndexedHeader(H2StaticTable.Authority, request.Headers.Host);
+ WriteIndexedHeader(H2StaticTable.Authority, request.Headers.Host, ref headerBuffer);
}
else
{
- WriteBytes(_pool._http2EncodedAuthorityHostHeader);
+ WriteBytes(_pool._http2EncodedAuthorityHostHeader, ref headerBuffer);
}
Debug.Assert(request.RequestUri != null);
string pathAndQuery = request.RequestUri.PathAndQuery;
if (pathAndQuery == "/")
{
- WriteIndexedHeader(H2StaticTable.PathSlash);
+ WriteIndexedHeader(H2StaticTable.PathSlash, ref headerBuffer);
}
else
{
- WriteIndexedHeader(H2StaticTable.PathSlash, pathAndQuery);
+ WriteIndexedHeader(H2StaticTable.PathSlash, pathAndQuery, ref headerBuffer);
}
if (request.HasHeaders)
{
- WriteHeaderCollection(request.Headers);
+ WriteHeaderCollection(request.Headers, ref headerBuffer);
}
// Determine cookies to send.
string cookiesFromContainer = _pool.Settings._cookieContainer!.GetCookieHeader(request.RequestUri);
if (cookiesFromContainer != string.Empty)
{
- WriteBytes(KnownHeaders.Cookie.Http2EncodedName);
- WriteLiteralHeaderValue(cookiesFromContainer);
+ WriteBytes(KnownHeaders.Cookie.Http2EncodedName, ref headerBuffer);
+ WriteLiteralHeaderValue(cookiesFromContainer, ref headerBuffer);
}
}
// unless this is a method that never has a body.
if (normalizedMethod.MustHaveRequestBody)
{
- WriteBytes(KnownHeaders.ContentLength.Http2EncodedName);
- WriteLiteralHeaderValue("0");
+ WriteBytes(KnownHeaders.ContentLength.Http2EncodedName, ref headerBuffer);
+ WriteLiteralHeaderValue("0", ref headerBuffer);
}
}
else
{
- WriteHeaderCollection(request.Content.Headers);
+ WriteHeaderCollection(request.Content.Headers, ref headerBuffer);
}
}
private async ValueTask<Http2Stream> SendHeadersAsync(HttpRequestMessage request, CancellationToken cancellationToken, bool mustFlush)
{
- // We serialize usage of the header encoder and the header buffer.
- // This also ensures that new streams are always created in ascending order.
- await _headerSerializationLock.WaitAsync(cancellationToken).ConfigureAwait(false);
+ // Enforce MAX_CONCURRENT_STREAMS setting value. We do this before anything else, e.g. renting buffers to serialize headers,
+ // in order to avoid consuming resources in potentially many requests waiting for access.
try
{
- // Generate the entire header block, without framing, into the connection header buffer.
- WriteHeaders(request);
-
- try
- {
- // Enforce MAX_CONCURRENT_STREAMS setting value.
- await _concurrentStreams.RequestCreditAsync(1, cancellationToken).ConfigureAwait(false);
- }
- catch (ObjectDisposedException)
+ await _concurrentStreams.RequestCreditAsync(1, cancellationToken).ConfigureAwait(false);
+ }
+ catch (ObjectDisposedException)
+ {
+ // We have race condition between shutting down and initiating new requests.
+ // When we are shutting down the connection (e.g. due to receiving GOAWAY, etc)
+ // we will wait until the stream count goes to 0, and then we will close the connetion
+ // and perform clean up, including disposing _concurrentStreams.
+ // So if we get ObjectDisposedException here, we must have shut down the connection.
+ // Throw a retryable request exception if this is not result of some other error.
+ // This will cause retry logic to kick in and perform another connection attempt.
+ // The user should never see this exception. See similar handling below.
+ // Throw a retryable request exception if this is not result of some other error.
+ // This will cause retry logic to kick in and perform another connection attempt.
+ // The user should never see this exception. See also below.
+ lock (SyncObject)
{
- // We have race condition between shutting down and initiating new requests.
- // When we are shutting down the connection (e.g. due to receiving GOAWAY, etc)
- // we will wait until the stream count goes to 0, and then we will close the connetion
- // and perform clean up, including disposing _concurrentStreams.
- // So if we get ObjectDisposedException here, we must have shut down the connection.
- // Throw a retryable request exception if this is not result of some other error.
- // This will cause retry logic to kick in and perform another connection attempt.
- // The user should never see this exception. See similar handling below.
- // Throw a retryable request exception if this is not result of some other error.
- // This will cause retry logic to kick in and perform another connection attempt.
- // The user should never see this exception. See also below.
Debug.Assert(_disposed || _lastStreamId != -1);
Debug.Assert(_httpStreams.Count == 0);
-
- lock (SyncObject)
- {
- throw GetShutdownException();
- }
+ throw GetShutdownException();
}
+ }
+
+ ArrayBuffer headerBuffer = default;
+ try
+ {
+ // Serialize headers to a temporary buffer, and do as much work to prepare to send the headers as we can
+ // before taking the write lock.
+ headerBuffer = new ArrayBuffer(InitialConnectionBufferSize, usePool: true);
+ WriteHeaders(request, ref headerBuffer);
+ ReadOnlyMemory<byte> remaining = headerBuffer.ActiveMemory;
+ Debug.Assert(remaining.Length > 0);
+
+ // Calculate the total number of bytes we're going to use (content + headers).
+ int frameCount = ((remaining.Length - 1) / FrameHeader.MaxLength) + 1;
+ int totalSize = remaining.Length + (frameCount * FrameHeader.Size);
+ ReadOnlyMemory<byte> current;
+ (current, remaining) = SplitBuffer(remaining, FrameHeader.MaxLength);
+ FrameFlags flags =
+ (remaining.Length == 0 ? FrameFlags.EndHeaders : FrameFlags.None) |
+ (request.Content == null ? FrameFlags.EndStream : FrameFlags.None);
+
+ // Start the write. This serializes access to write to the connection, and ensures that HEADERS
+ // and CONTINUATION frames stay together, as they must do. We use the lock as well to ensure new
+ // streams are created and started in order.
+ Memory<byte> writeBuffer = await StartWriteAsync(totalSize, cancellationToken).ConfigureAwait(false);
try
{
- // Allocate the next available stream ID.
- // Note that if we fail before sending the headers, we'll just skip this stream ID, which is fine.
+ // Allocate the next available stream ID. Note that if we fail before sending the headers,
+ // we'll just skip this stream ID, which is fine.
int streamId;
lock (SyncObject)
{
_nextStream += 2;
}
- ReadOnlyMemory<byte> remaining = _headerBuffer.ActiveMemory;
- Debug.Assert(remaining.Length > 0);
-
- // Calculate the total number of bytes we're going to use (content + headers).
- int frameCount = ((remaining.Length - 1) / FrameHeader.MaxLength) + 1;
- int totalSize = remaining.Length + frameCount * FrameHeader.Size;
-
- // Note, HEADERS and CONTINUATION frames must be together, so hold the writer lock across sending all of them.
- Memory<byte> writeBuffer = await StartWriteAsync(totalSize, cancellationToken).ConfigureAwait(false);
if (NetEventSource.IsEnabled) Trace(streamId, $"Started writing. {nameof(totalSize)}={totalSize}");
- // Send the HEADERS frame.
- ReadOnlyMemory<byte> current;
- (current, remaining) = SplitBuffer(remaining, FrameHeader.MaxLength);
-
- FrameFlags flags =
- (remaining.Length == 0 ? FrameFlags.EndHeaders : FrameFlags.None) |
- (request.Content == null ? FrameFlags.EndStream : FrameFlags.None);
-
- FrameHeader frameHeader = new FrameHeader(current.Length, FrameType.Headers, flags, streamId);
- frameHeader.WriteTo(writeBuffer.Span);
+ // Copy the HEADERS frame.
+ new FrameHeader(current.Length, FrameType.Headers, flags, streamId).WriteTo(writeBuffer.Span);
writeBuffer = writeBuffer.Slice(FrameHeader.Size);
-
current.CopyTo(writeBuffer);
writeBuffer = writeBuffer.Slice(current.Length);
-
if (NetEventSource.IsEnabled) Trace(streamId, $"Wrote HEADERS frame. Length={current.Length}, flags={flags}");
- // Send CONTINUATION frames, if any.
+ // Copy CONTINUATION frames, if any.
while (remaining.Length > 0)
{
(current, remaining) = SplitBuffer(remaining, FrameHeader.MaxLength);
+ flags = remaining.Length == 0 ? FrameFlags.EndHeaders : FrameFlags.None;
- flags = (remaining.Length == 0 ? FrameFlags.EndHeaders : FrameFlags.None);
-
- frameHeader = new FrameHeader(current.Length, FrameType.Continuation, flags, streamId);
- frameHeader.WriteTo(writeBuffer.Span);
+ new FrameHeader(current.Length, FrameType.Continuation, flags, streamId).WriteTo(writeBuffer.Span);
writeBuffer = writeBuffer.Slice(FrameHeader.Size);
-
current.CopyTo(writeBuffer);
writeBuffer = writeBuffer.Slice(current.Length);
-
if (NetEventSource.IsEnabled) Trace(streamId, $"Wrote CONTINUATION frame. Length={current.Length}, flags={flags}");
}
Debug.Assert(writeBuffer.Length == 0);
- Http2Stream http2Stream;
- try
- {
- // We're about to write the HEADERS frame, so add the stream to the dictionary now.
- // The lifetime of the stream is now controlled by the stream itself and the connection.
- // This can fail if the connection is shutting down, in which case we will cancel sending this frame.
- http2Stream = AddStream(streamId, request);
- }
- catch
- {
- CancelWrite();
- throw;
- }
+ // We're about to flush the HEADERS frame, so add the stream to the dictionary now.
+ // The lifetime of the stream is now controlled by the stream itself and the connection.
+ // This can fail if the connection is shutting down, in which case we will cancel sending this frame.
+ Http2Stream http2Stream = AddStream(streamId, request);
FinishWrite(mustFlush || (flags & FrameFlags.EndStream) != 0 ? FlushTiming.AfterPendingWrites : FlushTiming.Eventually);
-
return http2Stream;
}
catch
{
- _concurrentStreams.AdjustCredit(1);
+ CancelWrite();
throw;
}
}
+ catch
+ {
+ _concurrentStreams.AdjustCredit(1);
+ throw;
+ }
finally
{
- _headerBuffer.Discard(_headerBuffer.ActiveLength);
- _headerSerializationLock.Release();
+ headerBuffer.Dispose();
}
}
writeBuffer = await StartWriteAsync(FrameHeader.Size + current.Length, cancellationToken).ConfigureAwait(false);
if (NetEventSource.IsEnabled) Trace(streamId, $"Started writing. {nameof(writeBuffer.Length)}={writeBuffer.Length}");
}
- catch (OperationCanceledException)
+ catch
{
_connectionWindow.AdjustCredit(frameSize);
throw;
/// Reset _waitSource.
/// </summary>
private ManualResetValueTaskSourceCore<bool> _waitSource = new ManualResetValueTaskSourceCore<bool> { RunContinuationsAsynchronously = true }; // mutable struct, do not make this readonly
+ /// <summary>Cancellation registration used to cancel the <see cref="_waitSource"/>.</summary>
+ private CancellationTokenRegistration _waitSourceCancellation;
/// <summary>
/// Whether code has requested or is about to request a wait be performed and thus requires a call to SetResult to complete it.
/// This is read and written while holding the lock so that most operations on _waitSource don't need to be.
{
using (Http2WriteStream writeStream = new Http2WriteStream(this))
{
- await _request.Content.CopyToAsync(writeStream, null, _requestBodyCancellationToken).ConfigureAwait(false);
+ await _request.Content.InternalCopyToAsync(writeStream, null, _requestBodyCancellationToken).ConfigureAwait(false);
}
}
// associated with the implementation is just delegated to the ManualResetValueTaskSourceCore.
ValueTaskSourceStatus IValueTaskSource.GetStatus(short token) => _waitSource.GetStatus(token);
void IValueTaskSource.OnCompleted(Action<object?> continuation, object? state, short token, ValueTaskSourceOnCompletedFlags flags) => _waitSource.OnCompleted(continuation, state, token, flags);
- void IValueTaskSource.GetResult(short token) => _waitSource.GetResult(token);
+ void IValueTaskSource.GetResult(short token)
+ {
+ Debug.Assert(!Monitor.IsEntered(SyncObject));
+
+ // Clean up the registration. It's important to Dispose rather than Unregister, so that we wait
+ // for any in-flight cancellation to complete.
+ _waitSourceCancellation.Dispose();
+ _waitSourceCancellation = default;
+
+ // Propagate any exceptions if there were any.
+ _waitSource.GetResult(token);
+ }
private void WaitForData()
{
// Reset'ing it is this code here. It's possible for this to race with the _waitSource being completed, but that's ok and is
// handled by _waitSource as one of its primary purposes. We can't assert _hasWaiter here, though, as once we released the
// lock, a producer could have seen _hasWaiter as true and both set it to false and signaled _waitSource.
- if (!cancellationToken.CanBeCanceled)
- {
- return new ValueTask(this, _waitSource.Version);
- }
// With HttpClient, the supplied cancellation token will always be cancelable, as HttpClient supplies a token that
// will have cancellation requested if CancelPendingRequests is called (or when a non-infinite Timeout expires).
// However, this could still be non-cancelable if HttpMessageInvoker was used, at which point this will only be
- // cancelable if the caller's token was cancelable. To avoid the extra allocation here in such a case, we make
- // this pay-for-play: if the token isn't cancelable, return a ValueTask wrapping this object directly, and only
- // if it is cancelable, then register for the cancellation callback, allocate a task for the asynchronously
- // completing case, etc.
- return GetCancelableWaiterTask(cancellationToken);
+ // cancelable if the caller's token was cancelable.
- async ValueTask GetCancelableWaiterTask(CancellationToken cancellationToken)
+ _waitSourceCancellation = cancellationToken.UnsafeRegister(s =>
{
- using (cancellationToken.UnsafeRegister(s =>
- {
- var thisRef = (Http2Stream)s!;
+ var thisRef = (Http2Stream)s!;
- bool signalWaiter;
- Debug.Assert(!Monitor.IsEntered(thisRef.SyncObject));
- lock (thisRef.SyncObject)
- {
- signalWaiter = thisRef._hasWaiter;
- thisRef._hasWaiter = false;
- }
+ bool signalWaiter;
+ Debug.Assert(!Monitor.IsEntered(thisRef.SyncObject));
+ lock (thisRef.SyncObject)
+ {
+ signalWaiter = thisRef._hasWaiter;
+ thisRef._hasWaiter = false;
+ }
- if (signalWaiter)
- {
- // Wake up the wait. It will then immediately check whether cancellation was requested and throw if it was.
- thisRef._waitSource.SetResult(true);
- }
- }, this))
+ if (signalWaiter)
{
- await new ValueTask(this, _waitSource.Version).ConfigureAwait(false);
+ // Wake up the wait. It will then immediately check whether cancellation was requested and throw if it was.
+ thisRef._waitSource.SetException(ExceptionDispatchInfo.SetCurrentStackTrace(
+ CancellationHelper.CreateOperationCanceledException(null, _waitSourceCancellation.Token)));
}
+ }, this);
- CancellationHelper.ThrowIfCancellationRequested(cancellationToken);
- }
+ // There's a race condition in UnsafeRegister above. If cancellation is requested prior to UnsafeRegister,
+ // the delegate may be invoked synchronously as part of the UnsafeRegister call. In that case, it will execute
+ // before _waitSourceCancellation has been set, which means UnsafeRegister will have set a cancellation
+ // exception into the wait source with a default token rather than the ideal one. To handle that,
+ // we check for cancellation again, and throw here with the right token. Worst case, if cancellation is
+ // requested prior to here, we end up allocating an extra OCE object.
+ CancellationHelper.ThrowIfCancellationRequested(cancellationToken);
+
+ return new ValueTask(this, _waitSource.Version);
}
public void Trace(string message, [CallerMemberName] string? memberName = null) =>
/// <summary>Awaits a task, ignoring any resulting exceptions.</summary>
internal static void IgnoreExceptions(ValueTask<int> task)
{
- _ = IgnoreExceptionsAsync(task);
-
- static async Task IgnoreExceptionsAsync(ValueTask<int> task)
+ // Avoid TaskScheduler.UnobservedTaskException firing for any exceptions.
+ if (task.IsCompleted)
{
- try { await task.ConfigureAwait(false); } catch { }
+ if (task.IsFaulted)
+ {
+ _ = task.AsTask().Exception;
+ }
}
- }
-
- /// <summary>Awaits a task, ignoring any resulting exceptions.</summary>
- internal static void IgnoreExceptions(Task task)
- {
- _ = IgnoreExceptionsAsync(task);
-
- static async Task IgnoreExceptionsAsync(Task task)
+ else
{
- try { await task.ConfigureAwait(false); } catch { }
+ task.AsTask().ContinueWith(t => _ = t.Exception,
+ CancellationToken.None, TaskContinuationOptions.ExecuteSynchronously | TaskContinuationOptions.OnlyOnFaulted, TaskScheduler.Default);
}
}
/// <summary>Awaits a task, logging any resulting exceptions (which are otherwise ignored).</summary>
- internal void LogExceptions(Task task)
- {
- _ = LogExceptionsAsync(task);
-
- async Task LogExceptionsAsync(Task task)
+ internal void LogExceptions(Task task) =>
+ task.ContinueWith(t =>
{
- try
- {
- await task.ConfigureAwait(false);
- }
- catch (Exception e)
- {
- if (NetEventSource.IsEnabled) Trace($"Exception from asynchronous processing: {e}");
- }
- }
- }
+ Exception? e = t.Exception?.InnerException; // Access Exception even if not tracing, to avoid TaskScheduler.UnobservedTaskException firing
+ if (NetEventSource.IsEnabled) Trace($"Exception from asynchronous processing: {e}");
+ }, CancellationToken.None, TaskContinuationOptions.ExecuteSynchronously | TaskContinuationOptions.OnlyOnFaulted, TaskScheduler.Default);
}
}
--- /dev/null
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System.Collections.Concurrent;
+using System.Diagnostics;
+using System.Runtime.CompilerServices;
+using System.Runtime.ExceptionServices;
+using System.Threading.Tasks;
+using System.Threading.Tasks.Sources;
+
+namespace System.Threading
+{
+ /// <summary>Provides an async mutex.</summary>
+ /// <remarks>
+ /// This could be achieved with a <see cref="SemaphoreSlim"/> constructed with an initial
+ /// and max limit of 1. However, this implementation is optimized to the needs of HTTP/2,
+ /// where the mutex is held for a very short period of time, when it is held any other
+ /// attempts to access it must wait asynchronously, where it's only binary rather than counting, and where
+ /// we want to minimize contention that a releaser incurs while trying to unblock a waiter. The primary
+ /// value-add is the fast-path interlocked checks that minimize contention for these use cases (essentially
+ /// making it an async futex), and then as long as we're wrapping something and we know exactly how all
+ /// consumers use the type, we can offer a ValueTask-based implementation that reuses waiter nodes.
+ /// </remarks>
+ internal sealed class AsyncMutex
+ {
+ /// <summary>Fast-path gate count tracking access to the mutex.</summary>
+ /// <remarks>
+ /// If the value is 1, the mutex can be entered atomically with an interlocked operation.
+ /// If the value is less than or equal to 0, the mutex is held and requires fallback to enter it.
+ /// </remarks>
+ private int _gate = 1;
+ /// <summary>Secondary check guarded by the lock to indicate whether the mutex is acquired.</summary>
+ /// <remarks>
+ /// This is only meaningful after having updated <see cref="_gate"/> via interlockeds and taken the appropriate path.
+ /// If after decrementing <see cref="_gate"/> we end up with a negative count, the mutex is contended, hence
+ /// <see cref="_lockedSemaphoreFull"/> starting as <c>true</c>. The primary purpose of this field
+ /// is to handle the race condition between one thread acquiring the mutex, then another thread trying to acquire
+ /// and getting as far as completing the interlocked operation, and then the original thread releasing; at that point
+ /// it'll hit the lock and we need to store that the mutex is available to enter. If we instead used a
+ /// SemaphoreSlim as the fallback from the interlockeds, this would have been its count, and it would have started
+ /// with an initial count of 0.
+ /// </remarks>
+ private bool _lockedSemaphoreFull = true;
+ /// <summary>The head of the double-linked waiting queue. Waiters are dequeued from the head.</summary>
+ private Waiter? _waitersHead;
+ /// <summary>The tail of the double-linked waiting queue. Waiters are added at the tail.</summary>
+ private Waiter? _waitersTail;
+ /// <summary>A pool of waiter objects that are ready to be reused.</summary>
+ /// <remarks>
+ /// There is no bound on this pool, but it ends up being implicitly bounded by the maximum number of concurrent
+ /// waiters there ever were, which for our uses in HTTP/2 will end up being the high-water mark of concurrent streams
+ /// on a single connection.
+ /// </remarks>
+ private readonly ConcurrentQueue<Waiter> _unusedWaiters = new ConcurrentQueue<Waiter>();
+
+ /// <summary>Gets whether the mutex is currently held by some operation (not necessarily the caller).</summary>
+ /// <remarks>This should be used only for asserts and debugging.</remarks>
+ public bool IsHeld => _gate != 1;
+
+ /// <summary>Objects used to synchronize operations on the instance.</summary>
+ private object SyncObj => _unusedWaiters;
+
+ /// <summary>Asynchronously waits to enter the mutex.</summary>
+ /// <param name="cancellationToken">The CancellationToken token to observe.</param>
+ /// <returns>A task that will complete when the mutex has been entered or the enter canceled.</returns>
+ public ValueTask EnterAsync(CancellationToken cancellationToken)
+ {
+ // If cancellation was requested, bail immediately.
+ // If the mutex is not currently held nor contended, enter immediately.
+ // Otherwise, fall back to a more expensive likely-asynchronous wait.
+ return
+ cancellationToken.IsCancellationRequested ? FromCanceled(cancellationToken) :
+ Interlocked.Decrement(ref _gate) >= 0 ? default :
+ Contended(cancellationToken);
+
+ // Everything that follows is the equivalent of:
+ // return _sem.WaitAsync(cancellationToken);
+ // if _sem were to be constructed as `new SemaphoreSlim(0)`.
+
+ ValueTask Contended(CancellationToken cancellationToken)
+ {
+ // Get a reusable waiter object. We do this before the lock to minimize work (and especially allocation)
+ // done while holding the lock. It's possible we'll end up dequeuing a waiter and then under the lock
+ // discovering the mutex is now available, at which point we will have wasted an object. That's currently
+ // showing to be the better alternative (including not trying to put it back in that case).
+ if (!_unusedWaiters.TryDequeue(out Waiter? w))
+ {
+ w = new Waiter(this);
+ }
+
+ lock (SyncObj)
+ {
+ // Now that we're holding the lock, check to see whether the async lock is acquirable.
+ if (!_lockedSemaphoreFull)
+ {
+ _lockedSemaphoreFull = true;
+ return default;
+ }
+ else
+ {
+ // Add it to the linked list of waiters.
+ if (_waitersTail is null)
+ {
+ Debug.Assert(_waitersHead is null);
+ _waitersTail = _waitersHead = w;
+ }
+ else
+ {
+ Debug.Assert(_waitersHead != null);
+ w.Prev = _waitersTail;
+ _waitersTail.Next = w;
+ _waitersTail = w;
+ }
+ }
+ }
+
+ // At this point the waiter was added to the list of waiters, so we want to
+ // register for cancellation in order to cancel it and remove it from the list
+ // if cancellation is requested. However, since we've released the lock, it's
+ // possible the waiter could have actually already been completed and removed
+ // from the list by another thread releasing the mutex. That's ok; we'll
+ // end up registering for cancellation here, and then when the consumer awaits
+ // it, the act of awaiting it will Dispose of the registration, ensuring that
+ // it won't run after that point, making it safe to pool that instance.
+ w.CancellationRegistration = cancellationToken.UnsafeRegister(s => OnCancellation(s), w);
+
+ // Return the waiter as a value task.
+ return new ValueTask(w, w.Version);
+
+ // Cancels the specified waiter if it's still in the list.
+ static void OnCancellation(object? state)
+ {
+ Waiter? w = (Waiter)state!;
+ AsyncMutex m = w.Owner;
+
+ lock (m.SyncObj)
+ {
+ bool inList = w.Next != null || w.Prev != null || m._waitersHead == w;
+ if (inList)
+ {
+ // The waiter was still in the list.
+ Debug.Assert(
+ m._waitersHead == w ||
+ (m._waitersTail == w && w.Prev != null && w.Next is null) ||
+ (w.Next != null && w.Prev != null));
+
+ // The gate counter was decremented when this waiter was added. We need
+ // to undo that. Since the waiter is still in the list, the lock must
+ // still be held by someone, which means we don't need to do anything with
+ // the result of this increment. If it increments to < 1, then there are
+ // still other waiters. If it increments to 1, we're in a rare race condition
+ // where there are no other waiters and the owner just incremented the gate
+ // count; they would have seen it be < 1, so they will proceed to take the
+ // contended code path and synchronize on the lock we're holding... once we
+ // release it, they will appropriately update state.
+ Interlocked.Increment(ref m._gate);
+
+ // Remove it from the list.
+ if (m._waitersHead == w && m._waitersTail == w)
+ {
+ // It's the only node in the list.
+ m._waitersHead = m._waitersTail = null;
+ }
+ else if (m._waitersTail == w)
+ {
+ // It's the most recently queued item in the list.
+ m._waitersTail = w.Prev;
+ Debug.Assert(m._waitersTail != null);
+ m._waitersTail.Next = null;
+ }
+ else if (m._waitersHead == w)
+ {
+ // It's the next item to be removed from the list.
+ m._waitersHead = w.Next;
+ Debug.Assert(m._waitersHead != null);
+ m._waitersHead.Prev = null;
+ }
+ else
+ {
+ // It's in the middle of the list.
+ Debug.Assert(w.Next != null);
+ Debug.Assert(w.Prev != null);
+ w.Next.Prev = w.Prev;
+ w.Prev.Next = w.Next;
+ }
+
+ // Remove it from the list.
+ w.Next = w.Prev = null;
+ }
+ else
+ {
+ // The waiter was no longer in the list. We must not cancel it.
+ w = null;
+ }
+ }
+
+ // If the waiter was in the list, we removed it under the lock and thus own
+ // the ability to cancel it. Do so.
+ w?.Cancel();
+ }
+ }
+ }
+
+ /// <summary>Releases the mutex.</summary>
+ /// <remarks>The caller must logically own the mutex. This is not validated.</remarks>
+ public void Exit()
+ {
+ if (Interlocked.Increment(ref _gate) < 1)
+ {
+ // This is the equivalent of:
+ // _sem.Release();
+ // if _sem were to be constructed as `new SemaphoreSlim(0)`.
+ Contended();
+ }
+
+ void Contended()
+ {
+ Waiter? w;
+
+ lock (SyncObj)
+ {
+ Debug.Assert(_lockedSemaphoreFull);
+
+ // Wake up the next waiter in the list.
+ w = _waitersHead;
+ if (w != null)
+ {
+ // Remove the waiter.
+ _waitersHead = w.Next;
+ if (w.Next != null)
+ {
+ w.Next.Prev = null;
+ }
+ else
+ {
+ Debug.Assert(_waitersTail == w);
+ _waitersTail = null;
+ }
+ w.Next = w.Prev = null;
+ }
+ else
+ {
+ // There wasn't a waiter. Mark that the async lock is no longer full.
+ Debug.Assert(_waitersTail is null);
+ _lockedSemaphoreFull = false;
+ }
+ }
+
+ // Either there wasn't a waiter, or we got one and successfully removed it from the list,
+ // at which point we own the ability to complete it. Do so.
+ w?.Set();
+ }
+ }
+
+ /// <summary>Creates a canceled ValueTask.</summary>
+ /// <remarks>Separated out to reduce asm for this rare path in the call site.</remarks>
+ [MethodImpl(MethodImplOptions.NoInlining)]
+ private static ValueTask FromCanceled(CancellationToken cancellationToken) =>
+ new ValueTask(Task.FromCanceled(cancellationToken));
+
+ /// <summary>Represents a waiter for the mutex.</summary>
+ /// <remarks>Implemented as a reusable backing source for a value task.</remarks>
+ private sealed class Waiter : IValueTaskSource
+ {
+ private ManualResetValueTaskSourceCore<bool> _mrvtsc; // mutable struct; do not make this readonly
+
+ public Waiter(AsyncMutex owner)
+ {
+ Owner = owner;
+ _mrvtsc.RunContinuationsAsynchronously = true;
+ }
+
+ public AsyncMutex Owner { get; }
+ public CancellationTokenRegistration CancellationRegistration { get; set; }
+ public Waiter? Next { get; set; }
+ public Waiter? Prev { get; set; }
+
+ public short Version => _mrvtsc.Version;
+
+ public void Set() => _mrvtsc.SetResult(true);
+ public void Cancel() => _mrvtsc.SetException(ExceptionDispatchInfo.SetCurrentStackTrace(new OperationCanceledException(CancellationRegistration.Token)));
+
+ void IValueTaskSource.GetResult(short token)
+ {
+ Debug.Assert(Next is null && Prev is null);
+
+ // Dispose of the registration. It's critical that this Dispose rather than Unregister,
+ // so that we can be guaranteed all cancellation-related work has completed by the time
+ // we return the instance to the pool. Otherwise, a race condition could result in
+ // a cancellation request for this operation canceling another unlucky request that
+ // happened to reuse the same node.
+ Debug.Assert(!Monitor.IsEntered(Owner.SyncObj));
+ CancellationRegistration.Dispose();
+
+ // Complete the operation, propagating any exceptions.
+ _mrvtsc.GetResult(token);
+
+ // Reset the instance and return it to the pool.
+ // We don't bother with a try/finally to return instances
+ // to the pool in the case of exceptions.
+ _mrvtsc.Reset();
+ Owner._unusedWaiters.Enqueue(this);
+ }
+
+ public ValueTaskSourceStatus GetStatus(short token) =>
+ _mrvtsc.GetStatus(token);
+
+ public void OnCompleted(Action<object?> continuation, object? state, short token, ValueTaskSourceOnCompletedFlags flags) =>
+ _mrvtsc.OnCompleted(continuation, state, token, flags);
+ }
+ }
+}
[InlineData("Expect-CT")]
[InlineData("Expires")]
[InlineData("From")]
+ [InlineData("grpc-encoding")]
+ [InlineData("grpc-message")]
+ [InlineData("grpc-status")]
[InlineData("Host")]
[InlineData("If-Match")]
[InlineData("If-Modified-Since")]
Assert.Null(KnownHeaders.TryGetKnownHeader(casedName.Select(c => (byte)c).ToArray()));
}
}
+
+ [Theory]
+ [InlineData("Access-Control-Allow-Credentials", "true")]
+ [InlineData("Access-Control-Allow-Headers", "*")]
+ [InlineData("Access-Control-Allow-Methods", "*")]
+ [InlineData("Access-Control-Allow-Origin", "*")]
+ [InlineData("Access-Control-Allow-Origin", "null")]
+ [InlineData("Access-Control-Expose-Headers", "*")]
+ [InlineData("Cache-Control", "must-revalidate")]
+ [InlineData("Cache-Control", "no-cache")]
+ [InlineData("Cache-Control", "no-store")]
+ [InlineData("Cache-Control", "no-transform")]
+ [InlineData("Cache-Control", "private")]
+ [InlineData("Cache-Control", "proxy-revalidate")]
+ [InlineData("Cache-Control", "public")]
+ [InlineData("Connection", "close")]
+ [InlineData("Content-Disposition", "attachment")]
+ [InlineData("Content-Disposition", "inline")]
+ [InlineData("Content-Encoding", "gzip")]
+ [InlineData("Content-Encoding", "deflate")]
+ [InlineData("Content-Encoding", "br")]
+ [InlineData("Content-Encoding", "compress")]
+ [InlineData("Content-Encoding", "identity")]
+ [InlineData("Content-Type", "text/xml")]
+ [InlineData("Content-Type", "text/css")]
+ [InlineData("Content-Type", "text/csv")]
+ [InlineData("Content-Type", "image/gif")]
+ [InlineData("Content-Type", "image/png")]
+ [InlineData("Content-Type", "text/html")]
+ [InlineData("Content-Type", "text/plain")]
+ [InlineData("Content-Type", "image/jpeg")]
+ [InlineData("Content-Type", "application/pdf")]
+ [InlineData("Content-Type", "application/xml")]
+ [InlineData("Content-Type", "application/zip")]
+ [InlineData("Content-Type", "application/grpc")]
+ [InlineData("Content-Type", "application/json")]
+ [InlineData("Content-Type", "multipart/form-data")]
+ [InlineData("Content-Type", "application/javascript")]
+ [InlineData("Content-Type", "application/octet-stream")]
+ [InlineData("Content-Type", "text/html; charset=utf-8")]
+ [InlineData("Content-Type", "text/plain; charset=utf-8")]
+ [InlineData("Content-Type", "application/json; charset=utf-8")]
+ [InlineData("Content-Type", "application/x-www-form-urlencoded")]
+ [InlineData("Expect", "100-continue")]
+ [InlineData("grpc-encoding", "identity")]
+ [InlineData("grpc-encoding", "gzip")]
+ [InlineData("grpc-encoding", "deflate")]
+ [InlineData("grpc-status", "0")]
+ [InlineData("Pragma", "no-cache")]
+ [InlineData("Referrer-Policy", "strict-origin-when-cross-origin")]
+ [InlineData("Referrer-Policy", "origin-when-cross-origin")]
+ [InlineData("Referrer-Policy", "strict-origin")]
+ [InlineData("Referrer-Policy", "origin")]
+ [InlineData("Referrer-Policy", "same-origin")]
+ [InlineData("Referrer-Policy", "no-referrer-when-downgrade")]
+ [InlineData("Referrer-Policy", "no-referrer")]
+ [InlineData("Referrer-Policy", "unsafe-url")]
+ [InlineData("TE", "trailers")]
+ [InlineData("TE", "compress")]
+ [InlineData("TE", "deflate")]
+ [InlineData("TE", "gzip")]
+ [InlineData("Transfer-Encoding", "chunked")]
+ [InlineData("Transfer-Encoding", "compress")]
+ [InlineData("Transfer-Encoding", "deflate")]
+ [InlineData("Transfer-Encoding", "gzip")]
+ [InlineData("Transfer-Encoding", "identity")]
+ [InlineData("Upgrade-Insecure-Requests", "1")]
+ [InlineData("Vary", "*")]
+ [InlineData("X-Content-Type-Options", "nosniff")]
+ [InlineData("X-Frame-Options", "DENY")]
+ [InlineData("X-Frame-Options", "SAMEORIGIN")]
+ [InlineData("X-XSS-Protection", "0")]
+ [InlineData("X-XSS-Protection", "1")]
+ [InlineData("X-XSS-Protection", "1; mode=block")]
+ public void GetKnownHeaderValue_Known_Found(string name, string value)
+ {
+ foreach (string casedValue in new[] { value, value.ToUpperInvariant(), value.ToLowerInvariant() })
+ {
+ Validate(KnownHeaders.TryGetKnownHeader(name), casedValue);
+ }
+
+ static void Validate(KnownHeader knownHeader, string value)
+ {
+ Assert.NotNull(knownHeader);
+
+ string v1 = knownHeader.Descriptor.GetHeaderValue(value.Select(c => (byte)c).ToArray());
+ Assert.NotNull(v1);
+ Assert.Equal(value, v1, StringComparer.OrdinalIgnoreCase);
+
+ string v2 = knownHeader.Descriptor.GetHeaderValue(value.Select(c => (byte)c).ToArray());
+ Assert.Same(v1, v2);
+ }
+ }
+
+ [Theory]
+ [InlineData("Content-Type", "application/jsot")]
+ [InlineData("Content-Type", "application/jsons")]
+ public void GetKnownHeaderValue_Unknown_NotFound(string name, string value)
+ {
+ KnownHeader knownHeader = KnownHeaders.TryGetKnownHeader(name);
+ Assert.NotNull(knownHeader);
+
+ string v1 = knownHeader.Descriptor.GetHeaderValue(value.Select(c => (byte)c).ToArray());
+ string v2 = knownHeader.Descriptor.GetHeaderValue(value.Select(c => (byte)c).ToArray());
+ Assert.Equal(value, v1);
+ Assert.Equal(value, v2);
+ Assert.NotSame(v1, v2);
+ }
}
}