From cd8355bd2aafff1776df9e009940d91697752152 Mon Sep 17 00:00:00 2001 From: Stephen Toub Date: Thu, 7 May 2020 06:08:20 -0400 Subject: [PATCH] Improve HTTP/2 scaling performance (#35694) * Rearrange Http2Connection.SendHeadersAsync to increase scale Many concurrent streams are currently bottlenecking in SendHeadersAsync. They: - take the headers serialization lock - request a stream credit - serialize all their headers to a buffer stored on the connection - take the write lock - write out the headers - release the write lock - release the headers serialization lock Instead of using a header buffer pooled on the connection, we can instead temporarily grab ArrayPool buffers into which we serialize the headers, which allows us to eliminate the headers serialization lock. With that, we get: - request a stream credit - serialize the headers to a pooled buffer - take the write lock - write out the headers - release the write lock This has a significant impact on the ability for many concurrent streams to scale, as all of the header serialization work happens outside of the lock. * Use ContinueWith instead of async/await in Ignore/LogExceptions The current implementation will always invoke the continuation, even in the majority case where there's no failure (and if there is a failure, it's cheaper not to throw it again). We can avoid that and decrease both overhead and allocation in the common case. * Remove CopyToAsync wrapping task from SendRequestBodyAsync * Reduce contention on semaphore locks We're seeing a lot of contention trying to acquire the monitors inside of semaphore slims, even when the contention is to release the semaphore, which should entail minimal contention (and delay of a release will just cause more contention). We can use a small interlocked gate to add a fast path. * Stop clearing ArrayBuffer's pooled byte arrays We already weren't clearing as the array grew, nor are we clearing in most other places in the library. It's not clear why we were clearing in this one spot, but the zero'ing is showing up meaningfully in profiles. * Add several gRPC known headers * Avoid dictionary lookup in HttpMethod.Normalize * Avoid allocating strings for known Content-Types * Combine AcquireWriteLockAsync into StartWriteAsync There's not a good reason to keep them separate, and StartWriteAsync is AcquireWriteLockAsync's only caller. * Remove GetCancelableWaiterTask async method We can achieve the same thing without the extra async method by putting the cleanup logic into the awaiter's GetResult. * Address PR feedback --- src/libraries/Common/src/System/Net/ArrayBuffer.cs | 2 +- .../System.Net.Http/src/System.Net.Http.csproj | 3 +- .../System/Net/Http/Headers/HeaderDescriptor.cs | 94 +++++- .../src/System/Net/Http/Headers/HttpHeaders.cs | 5 +- .../src/System/Net/Http/Headers/KnownHeaders.cs | 27 +- .../src/System/Net/Http/HttpContent.cs | 44 +-- .../src/System/Net/Http/HttpMethod.cs | 7 +- .../Net/Http/SocketsHttpHandler/Http2Connection.cs | 329 +++++++++++---------- .../Net/Http/SocketsHttpHandler/Http2Stream.cs | 70 +++-- .../Http/SocketsHttpHandler/HttpConnectionBase.cs | 42 +-- .../src/System/Threading/AsyncMutex.cs | 313 ++++++++++++++++++++ .../tests/UnitTests/Headers/KnownHeadersTest.cs | 111 +++++++ 12 files changed, 801 insertions(+), 246 deletions(-) create mode 100644 src/libraries/System.Net.Http/src/System/Threading/AsyncMutex.cs diff --git a/src/libraries/Common/src/System/Net/ArrayBuffer.cs b/src/libraries/Common/src/System/Net/ArrayBuffer.cs index ca1a1c4..fc0bbeb 100644 --- a/src/libraries/Common/src/System/Net/ArrayBuffer.cs +++ b/src/libraries/Common/src/System/Net/ArrayBuffer.cs @@ -50,7 +50,7 @@ namespace System.Net if (array != null) { - ArrayPool.Shared.Return(array, true); + ArrayPool.Shared.Return(array); } } } diff --git a/src/libraries/System.Net.Http/src/System.Net.Http.csproj b/src/libraries/System.Net.Http/src/System.Net.Http.csproj index dbdc3aa..0744c3d 100644 --- a/src/libraries/System.Net.Http/src/System.Net.Http.csproj +++ b/src/libraries/System.Net.Http/src/System.Net.Http.csproj @@ -89,6 +89,7 @@ + + Link="Common\System\Threading\Tasks\TaskToApm.cs" /> contentTypeValue) + { + string? candidate = null; + switch (contentTypeValue.Length) + { + case 8: + switch (contentTypeValue[7] | 0x20) + { + case 'l': candidate = "text/xml"; break; // text/xm[l] + case 's': candidate = "text/css"; break; // text/cs[s] + case 'v': candidate = "text/csv"; break; // text/cs[v] + } + break; + + case 9: + switch (contentTypeValue[6] | 0x20) + { + case 'g': candidate = "image/gif"; break; // image/[g]if + case 'p': candidate = "image/png"; break; // image/[p]ng + case 't': candidate = "text/html"; break; // text/h[t]ml + } + break; + + case 10: + switch (contentTypeValue[0] | 0x20) + { + case 't': candidate = "text/plain"; break; // [t]ext/plain + case 'i': candidate = "image/jpeg"; break; // [i]mage/jpeg + } + break; + + case 15: + switch (contentTypeValue[12] | 0x20) + { + case 'p': candidate = "application/pdf"; break; // application/[p]df + case 'x': candidate = "application/xml"; break; // application/[x]ml + case 'z': candidate = "application/zip"; break; // application/[z]ip + } + break; + + case 16: + switch (contentTypeValue[12] | 0x20) + { + case 'g': candidate = "application/grpc"; break; // application/[g]rpc + case 'j': candidate = "application/json"; break; // application/[j]son + } + break; + + case 19: + candidate = "multipart/form-data"; // multipart/form-data + break; + + case 22: + candidate = "application/javascript"; // application/javascript + break; + + case 24: + switch (contentTypeValue[0] | 0x20) + { + case 'a': candidate = "application/octet-stream"; break; // application/octet-stream + case 't': candidate = "text/html; charset=utf-8"; break; // text/html; charset=utf-8 + } + break; + + case 25: + candidate = "text/plain; charset=utf-8"; // text/plain; charset=utf-8 + break; + + case 31: + candidate = "application/json; charset=utf-8"; // application/json; charset=utf-8 + break; + + case 33: + candidate = "application/x-www-form-urlencoded"; // application/x-www-form-urlencoded + break; + } + + Debug.Assert(candidate is null || candidate.Length == contentTypeValue.Length); + + return candidate != null && ByteArrayHelpers.EqualsOrdinalAsciiIgnoreCase(candidate, contentTypeValue) ? + candidate : + null; + } + private static bool TryDecodeUtf8(ReadOnlySpan input, [NotNullWhen(true)] out string? decoded) { char[] rented = ArrayPool.Shared.Rent(input.Length); diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HttpHeaders.cs b/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HttpHeaders.cs index ac0849f..840a55a 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HttpHeaders.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HttpHeaders.cs @@ -1194,8 +1194,10 @@ namespace System.Net.Http.Headers return values; } - internal static int GetValuesAsStrings(HeaderDescriptor descriptor, object sourceValues, ref string[] values) + internal static int GetValuesAsStrings(HeaderDescriptor descriptor, object sourceValues, [NotNull] ref string[]? values) { + values ??= Array.Empty(); + HeaderStoreItemInfo? info = sourceValues as HeaderStoreItemInfo; if (info is null) { @@ -1210,7 +1212,6 @@ namespace System.Net.Http.Headers return 1; } - Debug.Assert(values != null); int length = GetValueCount(info); if (length > 0) diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/Headers/KnownHeaders.cs b/src/libraries/System.Net.Http/src/System/Net/Http/Headers/KnownHeaders.cs index 34163e9..b8f9b98 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/Headers/KnownHeaders.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/Headers/KnownHeaders.cs @@ -49,6 +49,9 @@ namespace System.Net.Http.Headers public static readonly KnownHeader ExpectCT = new KnownHeader("Expect-CT"); public static readonly KnownHeader Expires = new KnownHeader("Expires", HttpHeaderType.Content | HttpHeaderType.NonTrailing, DateHeaderParser.Parser, null, H2StaticTable.Expires); public static readonly KnownHeader From = new KnownHeader("From", HttpHeaderType.Request, GenericHeaderParser.MailAddressParser, null, H2StaticTable.From); + public static readonly KnownHeader GrpcEncoding = new KnownHeader("grpc-encoding", HttpHeaderType.Custom, null, new string[] { "identity", "gzip", "deflate" }); + public static readonly KnownHeader GrpcMessage = new KnownHeader("grpc-message"); + public static readonly KnownHeader GrpcStatus = new KnownHeader("grpc-status", HttpHeaderType.Custom, null, new string[] { "0" }); public static readonly KnownHeader Host = new KnownHeader("Host", HttpHeaderType.Request | HttpHeaderType.NonTrailing, GenericHeaderParser.HostParser, null, H2StaticTable.Host); public static readonly KnownHeader IfMatch = new KnownHeader("If-Match", HttpHeaderType.Request | HttpHeaderType.NonTrailing, GenericHeaderParser.MultipleValueEntityTagParser, null, H2StaticTable.IfMatch); public static readonly KnownHeader IfModifiedSince = new KnownHeader("If-Modified-Since", HttpHeaderType.Request | HttpHeaderType.NonTrailing, DateHeaderParser.Parser, null, H2StaticTable.IfModifiedSince, H3StaticTable.IfModifiedSince); @@ -250,20 +253,22 @@ namespace System.Net.Http.Headers switch (key[0] | 0x20) { case 'c': return ContentMD5; // [C]ontent-MD5 + case 'g': return GrpcStatus; // [g]rpc-status case 'r': return RetryAfter; // [R]etry-After case 's': return SetCookie2; // [S]et-Cookie2 } break; case 12: - switch (key[2] | 0x20) + switch (key[5] | 0x20) { - case 'c': return AcceptPatch; // Ac[c]ept-Patch - case 'm': return XMSEdgeRef; // X-[M]SEdge-Ref - case 'n': return ContentType; // Co[n]tent-Type - case 'p': return XPoweredBy; // X-[P]owered-By - case 'r': return XRequestID; // X-[R]equest-ID - case 'x': return MaxForwards; // Ma[x]-Forwards + case 'd': return XMSEdgeRef; // X-MSE[d]ge-Ref + case 'e': return XPoweredBy; // X-Pow[e]red-By + case 'm': return GrpcMessage; // grpc-[m]essage + case 'n': return ContentType; // Conte[n]t-Type + case 'o': return MaxForwards; // Max-F[o]rwards + case 't': return AcceptPatch; // Accep[t]-Patch + case 'u': return XRequestID; // X-Req[u]est-ID } break; @@ -272,7 +277,13 @@ namespace System.Net.Http.Headers { case 'd': return LastModified; // Last-Modifie[d] case 'e': return ContentRange; // Content-Rang[e] - case 'g': return ServerTiming; // Server-Timin[g] + case 'g': + switch (key[0] | 0x20) + { + case 's': return ServerTiming; // [S]erver-Timin[g] + case 'g': return GrpcEncoding; // [g]rpc-encodin[g] + } + break; case 'h': return IfNoneMatch; // If-None-Matc[h] case 'l': return CacheControl; // Cache-Contro[l] case 'n': return Authorization; // Authorizatio[n] diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/HttpContent.cs b/src/libraries/System.Net.Http/src/System/Net/Http/HttpContent.cs index 9575952..23abf0a 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/HttpContent.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/HttpContent.cs @@ -346,41 +346,41 @@ namespace System.Net.Http try { - ArraySegment buffer; - if (TryGetBuffer(out buffer)) - { - return CopyToAsyncCore(stream.WriteAsync(new ReadOnlyMemory(buffer.Array, buffer.Offset, buffer.Count), cancellationToken)); - } - else - { - Task task = SerializeToStreamAsync(stream, context, cancellationToken); - CheckTaskNotNull(task); - return CopyToAsyncCore(new ValueTask(task)); - } + return WaitAsync(InternalCopyToAsync(stream, context, cancellationToken)); } catch (Exception e) when (StreamCopyExceptionNeedsWrapping(e)) { return Task.FromException(GetStreamCopyException(e)); } - } - private static async Task CopyToAsyncCore(ValueTask copyTask) - { - try + static async Task WaitAsync(ValueTask copyTask) { - await copyTask.ConfigureAwait(false); - } - catch (Exception e) when (StreamCopyExceptionNeedsWrapping(e)) - { - throw WrapStreamCopyException(e); + try + { + await copyTask.ConfigureAwait(false); + } + catch (Exception e) when (StreamCopyExceptionNeedsWrapping(e)) + { + throw WrapStreamCopyException(e); + } } } - public Task LoadIntoBufferAsync() + internal ValueTask InternalCopyToAsync(Stream stream, TransportContext? context, CancellationToken cancellationToken) { - return LoadIntoBufferAsync(MaxBufferSize); + if (TryGetBuffer(out ArraySegment buffer)) + { + return stream.WriteAsync(buffer, cancellationToken); + } + + Task task = SerializeToStreamAsync(stream, context, cancellationToken); + CheckTaskNotNull(task); + return new ValueTask(task); } + public Task LoadIntoBufferAsync() => + LoadIntoBufferAsync(MaxBufferSize); + // No "CancellationToken" parameter needed since canceling the CTS will close the connection, resulting // in an exception being thrown while we're buffering. // If buffering is used without a connection, it is supposed to be fast, thus no cancellation required. diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/HttpMethod.cs b/src/libraries/System.Net.Http/src/System/Net/Http/HttpMethod.cs index fef238f..44b1e4f 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/HttpMethod.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/HttpMethod.cs @@ -176,9 +176,12 @@ namespace System.Net.Http /// internal static HttpMethod Normalize(HttpMethod method) { + // _http3EncodedBytes is only set for the singleton instances, so if it's not null, + // we can avoid the dictionary lookup. Otherwise, look up the method instance in the + // dictionary and return the normalized instance if it's found. Debug.Assert(method != null); - return s_knownMethods.TryGetValue(method, out HttpMethod? normalized) ? - normalized : + return + method._http3EncodedBytes is null && s_knownMethods.TryGetValue(method, out HttpMethod? normalized) ? normalized : method; } diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http2Connection.cs b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http2Connection.cs index 9246892..00b9bc1 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http2Connection.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http2Connection.cs @@ -24,10 +24,10 @@ namespace System.Net.Http // NOTE: These are mutable structs; do not make these readonly. private ArrayBuffer _incomingBuffer; private ArrayBuffer _outgoingBuffer; - private ArrayBuffer _headerBuffer; /// Reusable array used to get the values for each header being written to the wire. - private string[] _headerValues = Array.Empty(); + [ThreadStatic] + private static string[]? t_headerValues; private int _currentWriteSize; // as passed to StartWriteAsync @@ -35,9 +35,7 @@ namespace System.Net.Http private readonly Dictionary _httpStreams; - private readonly SemaphoreSlim _writerLock; - private readonly SemaphoreSlim _headerSerializationLock; - + private readonly AsyncMutex _writerLock; private readonly CreditManager _connectionWindow; private readonly CreditManager _concurrentStreams; @@ -106,14 +104,12 @@ namespace System.Net.Http _stream = stream; _incomingBuffer = new ArrayBuffer(InitialConnectionBufferSize); _outgoingBuffer = new ArrayBuffer(InitialConnectionBufferSize); - _headerBuffer = new ArrayBuffer(InitialConnectionBufferSize); _hpackDecoder = new HPackDecoder(maxHeadersLength: pool.Settings._maxResponseHeadersLength * 1024); _httpStreams = new Dictionary(); - _writerLock = new SemaphoreSlim(1, 1); - _headerSerializationLock = new SemaphoreSlim(1, 1); + _writerLock = new AsyncMutex(); _connectionWindow = new CreditManager(this, nameof(_connectionWindow), DefaultInitialWindowSize); _concurrentStreams = new CreditManager(this, nameof(_concurrentStreams), int.MaxValue); @@ -765,27 +761,69 @@ namespace System.Net.Http private async ValueTask> StartWriteAsync(int writeBytes, CancellationToken cancellationToken = default) { if (NetEventSource.IsEnabled) Trace($"{nameof(writeBytes)}={writeBytes}"); - await AcquireWriteLockAsync(cancellationToken).ConfigureAwait(false); + // Acquire the write lock + ValueTask acquireLockTask = _writerLock.EnterAsync(cancellationToken); + if (acquireLockTask.IsCompletedSuccessfully) + { + acquireLockTask.GetAwaiter().GetResult(); // to enable the value task sources to be pooled + } + else + { + Interlocked.Increment(ref _pendingWriters); + try + { + await acquireLockTask.ConfigureAwait(false); + } + catch + { + if (Interlocked.Decrement(ref _pendingWriters) == 0) + { + // If a pending waiter is canceled, we may end up in a situation where a previously written frame + // saw that there were pending writers and as such deferred its flush to them, but if/when that pending + // writer is canceled, nothing may end up flushing the deferred work (at least not promptly). To compensate, + // if a pending writer does end up being canceled, we flush asynchronously. We can't check whether there's such + // a pending operation because we failed to acquire the lock that protects that state. But we can at least only + // do the flush if our decrement caused the pending count to reach 0: if it's still higher than zero, then there's + // at least one other pending writer who can handle the flush. Worst case, we pay for a flush that ends up being + // a nop. Note: we explicitly do not pass in the cancellationToken; if we're here, it's almost certainly because + // cancellation was requested, and it's because of that cancellation that we need to flush. + LogExceptions(FlushAsync(cancellationToken: default)); + } + + throw; + } + Interlocked.Decrement(ref _pendingWriters); + } + + // If the connection has been aborted, then fail now instead of trying to send more data. + if (_abortException != null) + { + _writerLock.Exit(); + throw new IOException(SR.net_http_request_aborted, _abortException); + } + + // Flush anything necessary, and return back the write buffer to use. try { // If there is a pending write that was canceled while in progress, wait for it to complete. if (_inProgressWrite != null) { - await _inProgressWrite.ConfigureAwait(false); + await new ValueTask(_inProgressWrite).ConfigureAwait(false); // await ValueTask to minimize number of awaiter fields _inProgressWrite = null; } int totalBufferLength = _outgoingBuffer.Capacity; int activeBufferLength = _outgoingBuffer.ActiveLength; + // If the buffer has already grown to 32k, does not have room for the next request, + // and is non-empty, flush the current contents to the wire. if (totalBufferLength >= UnflushedOutgoingBufferSize && writeBytes >= totalBufferLength - activeBufferLength && activeBufferLength > 0) { - // If the buffer has already grown to 32k, does not have room for the next request, - // and is non-empty, flush the current contents to the wire. - await FlushOutgoingBytesAsync().ConfigureAwait(false); // we explicitly do not pass cancellationToken here, as this flush impacts more than just this operation + // We explicitly do not pass cancellationToken here, as this flush impacts more than just this operation. + await new ValueTask(FlushOutgoingBytesAsync()).ConfigureAwait(false); // await ValueTask to minimize number of awaiter fields } _outgoingBuffer.EnsureAvailableSpace(writeBytes); @@ -796,7 +834,7 @@ namespace System.Net.Http } catch { - _writerLock.Release(); + _writerLock.Exit(); throw; } } @@ -811,8 +849,8 @@ namespace System.Net.Http { if (NetEventSource.IsEnabled) Trace($"{nameof(flush)}={flush}"); - // We can't validate that we hold the semaphore, but we can at least validate that someone is holding it. - Debug.Assert(_writerLock.CurrentCount == 0); + // We can't validate that we hold the mutex, but we can at least validate that someone is holding it. + Debug.Assert(_writerLock.IsHeld); _outgoingBuffer.Commit(_currentWriteSize); _lastPendingWriterShouldFlush |= (flush == FlushTiming.AfterPendingWrites); @@ -823,16 +861,16 @@ namespace System.Net.Http { if (NetEventSource.IsEnabled) Trace(""); - // We can't validate that we hold the semaphore, but we can at least validate that someone is holding it. - Debug.Assert(_writerLock.CurrentCount == 0); + // We can't validate that we hold the mutex, but we can at least validate that someone is holding it. + Debug.Assert(_writerLock.IsHeld); EndWrite(forceFlush: false); } private void EndWrite(bool forceFlush) { - // We can't validate that we hold the semaphore, but we can at least validate that someone is holding it. - Debug.Assert(_writerLock.CurrentCount == 0); + // We can't validate that we hold the mutex, but we can at least validate that someone is holding it. + Debug.Assert(_writerLock.IsHeld); try { @@ -849,14 +887,18 @@ namespace System.Net.Http } finally { - _writerLock.Release(); + _writerLock.Exit(); } } private async ValueTask AcquireWriteLockAsync(CancellationToken cancellationToken) { - Task acquireLockTask = _writerLock.WaitAsync(cancellationToken); - if (!acquireLockTask.IsCompletedSuccessfully) + ValueTask acquireLockTask = _writerLock.EnterAsync(cancellationToken); + if (acquireLockTask.IsCompletedSuccessfully) + { + acquireLockTask.GetAwaiter().GetResult(); // to enable the value task sources to be pooled + } + else { Interlocked.Increment(ref _pendingWriters); @@ -877,7 +919,7 @@ namespace System.Net.Http // at least one other pending writer who can handle the flush. Worst case, we pay for a flush that ends up being // a nop. Note: we explicitly do not pass in the cancellationToken; if we're here, it's almost certainly because // cancellation was requested, and it's because of that cancellation that we need to flush. - LogExceptions(FlushAsync()); + LogExceptions(FlushAsync(cancellationToken: default)); } throw; @@ -889,7 +931,7 @@ namespace System.Net.Http // If the connection has been aborted, then fail now instead of trying to send more data. if (_abortException != null) { - _writerLock.Release(); + _writerLock.Exit(); throw new IOException(SR.net_http_request_aborted, _abortException); } } @@ -940,85 +982,85 @@ namespace System.Net.Http (buffer.Slice(0, maxSize), buffer.Slice(maxSize)) : (buffer, Memory.Empty); - private void WriteIndexedHeader(int index) + private void WriteIndexedHeader(int index, ref ArrayBuffer headerBuffer) { if (NetEventSource.IsEnabled) Trace($"{nameof(index)}={index}"); int bytesWritten; - while (!HPackEncoder.EncodeIndexedHeaderField(index, _headerBuffer.AvailableSpan, out bytesWritten)) + while (!HPackEncoder.EncodeIndexedHeaderField(index, headerBuffer.AvailableSpan, out bytesWritten)) { - _headerBuffer.EnsureAvailableSpace(_headerBuffer.AvailableLength + 1); + headerBuffer.EnsureAvailableSpace(headerBuffer.AvailableLength + 1); } - _headerBuffer.Commit(bytesWritten); + headerBuffer.Commit(bytesWritten); } - private void WriteIndexedHeader(int index, string value) + private void WriteIndexedHeader(int index, string value, ref ArrayBuffer headerBuffer) { if (NetEventSource.IsEnabled) Trace($"{nameof(index)}={index}, {nameof(value)}={value}"); int bytesWritten; - while (!HPackEncoder.EncodeLiteralHeaderFieldWithoutIndexing(index, value, _headerBuffer.AvailableSpan, out bytesWritten)) + while (!HPackEncoder.EncodeLiteralHeaderFieldWithoutIndexing(index, value, headerBuffer.AvailableSpan, out bytesWritten)) { - _headerBuffer.EnsureAvailableSpace(_headerBuffer.AvailableLength + 1); + headerBuffer.EnsureAvailableSpace(headerBuffer.AvailableLength + 1); } - _headerBuffer.Commit(bytesWritten); + headerBuffer.Commit(bytesWritten); } - private void WriteLiteralHeader(string name, ReadOnlySpan values) + private void WriteLiteralHeader(string name, ReadOnlySpan values, ref ArrayBuffer headerBuffer) { if (NetEventSource.IsEnabled) Trace($"{nameof(name)}={name}, {nameof(values)}={string.Join(", ", values.ToArray())}"); int bytesWritten; - while (!HPackEncoder.EncodeLiteralHeaderFieldWithoutIndexingNewName(name, values, HttpHeaderParser.DefaultSeparator, _headerBuffer.AvailableSpan, out bytesWritten)) + while (!HPackEncoder.EncodeLiteralHeaderFieldWithoutIndexingNewName(name, values, HttpHeaderParser.DefaultSeparator, headerBuffer.AvailableSpan, out bytesWritten)) { - _headerBuffer.EnsureAvailableSpace(_headerBuffer.AvailableLength + 1); + headerBuffer.EnsureAvailableSpace(headerBuffer.AvailableLength + 1); } - _headerBuffer.Commit(bytesWritten); + headerBuffer.Commit(bytesWritten); } - private void WriteLiteralHeaderValues(ReadOnlySpan values, string? separator) + private void WriteLiteralHeaderValues(ReadOnlySpan values, string? separator, ref ArrayBuffer headerBuffer) { if (NetEventSource.IsEnabled) Trace($"{nameof(values)}={string.Join(separator, values.ToArray())}"); int bytesWritten; - while (!HPackEncoder.EncodeStringLiterals(values, separator, _headerBuffer.AvailableSpan, out bytesWritten)) + while (!HPackEncoder.EncodeStringLiterals(values, separator, headerBuffer.AvailableSpan, out bytesWritten)) { - _headerBuffer.EnsureAvailableSpace(_headerBuffer.AvailableLength + 1); + headerBuffer.EnsureAvailableSpace(headerBuffer.AvailableLength + 1); } - _headerBuffer.Commit(bytesWritten); + headerBuffer.Commit(bytesWritten); } - private void WriteLiteralHeaderValue(string value) + private void WriteLiteralHeaderValue(string value, ref ArrayBuffer headerBuffer) { if (NetEventSource.IsEnabled) Trace($"{nameof(value)}={value}"); int bytesWritten; - while (!HPackEncoder.EncodeStringLiteral(value, _headerBuffer.AvailableSpan, out bytesWritten)) + while (!HPackEncoder.EncodeStringLiteral(value, headerBuffer.AvailableSpan, out bytesWritten)) { - _headerBuffer.EnsureAvailableSpace(_headerBuffer.AvailableLength + 1); + headerBuffer.EnsureAvailableSpace(headerBuffer.AvailableLength + 1); } - _headerBuffer.Commit(bytesWritten); + headerBuffer.Commit(bytesWritten); } - private void WriteBytes(ReadOnlySpan bytes) + private void WriteBytes(ReadOnlySpan bytes, ref ArrayBuffer headerBuffer) { if (NetEventSource.IsEnabled) Trace($"{nameof(bytes.Length)}={bytes.Length}"); - if (bytes.Length > _headerBuffer.AvailableLength) + if (bytes.Length > headerBuffer.AvailableLength) { - _headerBuffer.EnsureAvailableSpace(bytes.Length); + headerBuffer.EnsureAvailableSpace(bytes.Length); } - bytes.CopyTo(_headerBuffer.AvailableSpan); - _headerBuffer.Commit(bytes.Length); + bytes.CopyTo(headerBuffer.AvailableSpan); + headerBuffer.Commit(bytes.Length); } - private void WriteHeaderCollection(HttpHeaders headers) + private void WriteHeaderCollection(HttpHeaders headers, ref ArrayBuffer headerBuffer) { if (NetEventSource.IsEnabled) Trace(""); @@ -1027,11 +1069,12 @@ namespace System.Net.Http return; } + ref string[]? tmpHeaderValuesArray = ref t_headerValues; foreach (KeyValuePair header in headers.HeaderStore) { - int headerValuesCount = HttpHeaders.GetValuesAsStrings(header.Key, header.Value, ref _headerValues); + int headerValuesCount = HttpHeaders.GetValuesAsStrings(header.Key, header.Value, ref tmpHeaderValuesArray); Debug.Assert(headerValuesCount > 0, "No values for header??"); - ReadOnlySpan headerValues = _headerValues.AsSpan(0, headerValuesCount); + ReadOnlySpan headerValues = tmpHeaderValuesArray.AsSpan(0, headerValuesCount); KnownHeader? knownHeader = header.Key.KnownHeader; if (knownHeader != null) @@ -1048,8 +1091,8 @@ namespace System.Net.Http { if (string.Equals(value, "trailers", StringComparison.OrdinalIgnoreCase)) { - WriteBytes(knownHeader.Http2EncodedName); - WriteLiteralHeaderValue(value); + WriteBytes(knownHeader.Http2EncodedName, ref headerBuffer); + WriteLiteralHeaderValue(value, ref headerBuffer); break; } } @@ -1057,7 +1100,7 @@ namespace System.Net.Http } // For all other known headers, send them via their pre-encoded name and the associated value. - WriteBytes(knownHeader.Http2EncodedName); + WriteBytes(knownHeader.Http2EncodedName, ref headerBuffer); string? separator = null; if (headerValues.Length > 1) { @@ -1072,21 +1115,20 @@ namespace System.Net.Http } } - WriteLiteralHeaderValues(headerValues, separator); + WriteLiteralHeaderValues(headerValues, separator, ref headerBuffer); } } else { // The header is not known: fall back to just encoding the header name and value(s). - WriteLiteralHeader(header.Key.Name, headerValues); + WriteLiteralHeader(header.Key.Name, headerValues, ref headerBuffer); } } } - private void WriteHeaders(HttpRequestMessage request) + private void WriteHeaders(HttpRequestMessage request, ref ArrayBuffer headerBuffer) { if (NetEventSource.IsEnabled) Trace(""); - Debug.Assert(_headerBuffer.ActiveLength == 0); // HTTP2 does not support Transfer-Encoding: chunked, so disable this on the request. if (request.HasHeaders && request.Headers.TransferEncodingChunked == true) @@ -1099,42 +1141,42 @@ namespace System.Net.Http // Method is normalized so we can do reference equality here. if (ReferenceEquals(normalizedMethod, HttpMethod.Get)) { - WriteIndexedHeader(H2StaticTable.MethodGet); + WriteIndexedHeader(H2StaticTable.MethodGet, ref headerBuffer); } else if (ReferenceEquals(normalizedMethod, HttpMethod.Post)) { - WriteIndexedHeader(H2StaticTable.MethodPost); + WriteIndexedHeader(H2StaticTable.MethodPost, ref headerBuffer); } else { - WriteIndexedHeader(H2StaticTable.MethodGet, normalizedMethod.Method); + WriteIndexedHeader(H2StaticTable.MethodGet, normalizedMethod.Method, ref headerBuffer); } - WriteIndexedHeader(_stream is SslStream ? H2StaticTable.SchemeHttps : H2StaticTable.SchemeHttp); + WriteIndexedHeader(_stream is SslStream ? H2StaticTable.SchemeHttps : H2StaticTable.SchemeHttp, ref headerBuffer); if (request.HasHeaders && request.Headers.Host != null) { - WriteIndexedHeader(H2StaticTable.Authority, request.Headers.Host); + WriteIndexedHeader(H2StaticTable.Authority, request.Headers.Host, ref headerBuffer); } else { - WriteBytes(_pool._http2EncodedAuthorityHostHeader); + WriteBytes(_pool._http2EncodedAuthorityHostHeader, ref headerBuffer); } Debug.Assert(request.RequestUri != null); string pathAndQuery = request.RequestUri.PathAndQuery; if (pathAndQuery == "/") { - WriteIndexedHeader(H2StaticTable.PathSlash); + WriteIndexedHeader(H2StaticTable.PathSlash, ref headerBuffer); } else { - WriteIndexedHeader(H2StaticTable.PathSlash, pathAndQuery); + WriteIndexedHeader(H2StaticTable.PathSlash, pathAndQuery, ref headerBuffer); } if (request.HasHeaders) { - WriteHeaderCollection(request.Headers); + WriteHeaderCollection(request.Headers, ref headerBuffer); } // Determine cookies to send. @@ -1143,8 +1185,8 @@ namespace System.Net.Http string cookiesFromContainer = _pool.Settings._cookieContainer!.GetCookieHeader(request.RequestUri); if (cookiesFromContainer != string.Empty) { - WriteBytes(KnownHeaders.Cookie.Http2EncodedName); - WriteLiteralHeaderValue(cookiesFromContainer); + WriteBytes(KnownHeaders.Cookie.Http2EncodedName, ref headerBuffer); + WriteLiteralHeaderValue(cookiesFromContainer, ref headerBuffer); } } @@ -1154,13 +1196,13 @@ namespace System.Net.Http // unless this is a method that never has a body. if (normalizedMethod.MustHaveRequestBody) { - WriteBytes(KnownHeaders.ContentLength.Http2EncodedName); - WriteLiteralHeaderValue("0"); + WriteBytes(KnownHeaders.ContentLength.Http2EncodedName, ref headerBuffer); + WriteLiteralHeaderValue("0", ref headerBuffer); } } else { - WriteHeaderCollection(request.Content.Headers); + WriteHeaderCollection(request.Content.Headers, ref headerBuffer); } } @@ -1196,45 +1238,61 @@ namespace System.Net.Http private async ValueTask SendHeadersAsync(HttpRequestMessage request, CancellationToken cancellationToken, bool mustFlush) { - // We serialize usage of the header encoder and the header buffer. - // This also ensures that new streams are always created in ascending order. - await _headerSerializationLock.WaitAsync(cancellationToken).ConfigureAwait(false); + // Enforce MAX_CONCURRENT_STREAMS setting value. We do this before anything else, e.g. renting buffers to serialize headers, + // in order to avoid consuming resources in potentially many requests waiting for access. try { - // Generate the entire header block, without framing, into the connection header buffer. - WriteHeaders(request); - - try - { - // Enforce MAX_CONCURRENT_STREAMS setting value. - await _concurrentStreams.RequestCreditAsync(1, cancellationToken).ConfigureAwait(false); - } - catch (ObjectDisposedException) + await _concurrentStreams.RequestCreditAsync(1, cancellationToken).ConfigureAwait(false); + } + catch (ObjectDisposedException) + { + // We have race condition between shutting down and initiating new requests. + // When we are shutting down the connection (e.g. due to receiving GOAWAY, etc) + // we will wait until the stream count goes to 0, and then we will close the connetion + // and perform clean up, including disposing _concurrentStreams. + // So if we get ObjectDisposedException here, we must have shut down the connection. + // Throw a retryable request exception if this is not result of some other error. + // This will cause retry logic to kick in and perform another connection attempt. + // The user should never see this exception. See similar handling below. + // Throw a retryable request exception if this is not result of some other error. + // This will cause retry logic to kick in and perform another connection attempt. + // The user should never see this exception. See also below. + lock (SyncObject) { - // We have race condition between shutting down and initiating new requests. - // When we are shutting down the connection (e.g. due to receiving GOAWAY, etc) - // we will wait until the stream count goes to 0, and then we will close the connetion - // and perform clean up, including disposing _concurrentStreams. - // So if we get ObjectDisposedException here, we must have shut down the connection. - // Throw a retryable request exception if this is not result of some other error. - // This will cause retry logic to kick in and perform another connection attempt. - // The user should never see this exception. See similar handling below. - // Throw a retryable request exception if this is not result of some other error. - // This will cause retry logic to kick in and perform another connection attempt. - // The user should never see this exception. See also below. Debug.Assert(_disposed || _lastStreamId != -1); Debug.Assert(_httpStreams.Count == 0); - - lock (SyncObject) - { - throw GetShutdownException(); - } + throw GetShutdownException(); } + } + + ArrayBuffer headerBuffer = default; + try + { + // Serialize headers to a temporary buffer, and do as much work to prepare to send the headers as we can + // before taking the write lock. + headerBuffer = new ArrayBuffer(InitialConnectionBufferSize, usePool: true); + WriteHeaders(request, ref headerBuffer); + ReadOnlyMemory remaining = headerBuffer.ActiveMemory; + Debug.Assert(remaining.Length > 0); + + // Calculate the total number of bytes we're going to use (content + headers). + int frameCount = ((remaining.Length - 1) / FrameHeader.MaxLength) + 1; + int totalSize = remaining.Length + (frameCount * FrameHeader.Size); + ReadOnlyMemory current; + (current, remaining) = SplitBuffer(remaining, FrameHeader.MaxLength); + FrameFlags flags = + (remaining.Length == 0 ? FrameFlags.EndHeaders : FrameFlags.None) | + (request.Content == null ? FrameFlags.EndStream : FrameFlags.None); + + // Start the write. This serializes access to write to the connection, and ensures that HEADERS + // and CONTINUATION frames stay together, as they must do. We use the lock as well to ensure new + // streams are created and started in order. + Memory writeBuffer = await StartWriteAsync(totalSize, cancellationToken).ConfigureAwait(false); try { - // Allocate the next available stream ID. - // Note that if we fail before sending the headers, we'll just skip this stream ID, which is fine. + // Allocate the next available stream ID. Note that if we fail before sending the headers, + // we'll just skip this stream ID, which is fine. int streamId; lock (SyncObject) { @@ -1252,81 +1310,52 @@ namespace System.Net.Http _nextStream += 2; } - ReadOnlyMemory remaining = _headerBuffer.ActiveMemory; - Debug.Assert(remaining.Length > 0); - - // Calculate the total number of bytes we're going to use (content + headers). - int frameCount = ((remaining.Length - 1) / FrameHeader.MaxLength) + 1; - int totalSize = remaining.Length + frameCount * FrameHeader.Size; - - // Note, HEADERS and CONTINUATION frames must be together, so hold the writer lock across sending all of them. - Memory writeBuffer = await StartWriteAsync(totalSize, cancellationToken).ConfigureAwait(false); if (NetEventSource.IsEnabled) Trace(streamId, $"Started writing. {nameof(totalSize)}={totalSize}"); - // Send the HEADERS frame. - ReadOnlyMemory current; - (current, remaining) = SplitBuffer(remaining, FrameHeader.MaxLength); - - FrameFlags flags = - (remaining.Length == 0 ? FrameFlags.EndHeaders : FrameFlags.None) | - (request.Content == null ? FrameFlags.EndStream : FrameFlags.None); - - FrameHeader frameHeader = new FrameHeader(current.Length, FrameType.Headers, flags, streamId); - frameHeader.WriteTo(writeBuffer.Span); + // Copy the HEADERS frame. + new FrameHeader(current.Length, FrameType.Headers, flags, streamId).WriteTo(writeBuffer.Span); writeBuffer = writeBuffer.Slice(FrameHeader.Size); - current.CopyTo(writeBuffer); writeBuffer = writeBuffer.Slice(current.Length); - if (NetEventSource.IsEnabled) Trace(streamId, $"Wrote HEADERS frame. Length={current.Length}, flags={flags}"); - // Send CONTINUATION frames, if any. + // Copy CONTINUATION frames, if any. while (remaining.Length > 0) { (current, remaining) = SplitBuffer(remaining, FrameHeader.MaxLength); + flags = remaining.Length == 0 ? FrameFlags.EndHeaders : FrameFlags.None; - flags = (remaining.Length == 0 ? FrameFlags.EndHeaders : FrameFlags.None); - - frameHeader = new FrameHeader(current.Length, FrameType.Continuation, flags, streamId); - frameHeader.WriteTo(writeBuffer.Span); + new FrameHeader(current.Length, FrameType.Continuation, flags, streamId).WriteTo(writeBuffer.Span); writeBuffer = writeBuffer.Slice(FrameHeader.Size); - current.CopyTo(writeBuffer); writeBuffer = writeBuffer.Slice(current.Length); - if (NetEventSource.IsEnabled) Trace(streamId, $"Wrote CONTINUATION frame. Length={current.Length}, flags={flags}"); } Debug.Assert(writeBuffer.Length == 0); - Http2Stream http2Stream; - try - { - // We're about to write the HEADERS frame, so add the stream to the dictionary now. - // The lifetime of the stream is now controlled by the stream itself and the connection. - // This can fail if the connection is shutting down, in which case we will cancel sending this frame. - http2Stream = AddStream(streamId, request); - } - catch - { - CancelWrite(); - throw; - } + // We're about to flush the HEADERS frame, so add the stream to the dictionary now. + // The lifetime of the stream is now controlled by the stream itself and the connection. + // This can fail if the connection is shutting down, in which case we will cancel sending this frame. + Http2Stream http2Stream = AddStream(streamId, request); FinishWrite(mustFlush || (flags & FrameFlags.EndStream) != 0 ? FlushTiming.AfterPendingWrites : FlushTiming.Eventually); - return http2Stream; } catch { - _concurrentStreams.AdjustCredit(1); + CancelWrite(); throw; } } + catch + { + _concurrentStreams.AdjustCredit(1); + throw; + } finally { - _headerBuffer.Discard(_headerBuffer.ActiveLength); - _headerSerializationLock.Release(); + headerBuffer.Dispose(); } } @@ -1352,7 +1381,7 @@ namespace System.Net.Http writeBuffer = await StartWriteAsync(FrameHeader.Size + current.Length, cancellationToken).ConfigureAwait(false); if (NetEventSource.IsEnabled) Trace(streamId, $"Started writing. {nameof(writeBuffer.Length)}={writeBuffer.Length}"); } - catch (OperationCanceledException) + catch { _connectionWindow.AdjustCredit(frameSize); throw; diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http2Stream.cs b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http2Stream.cs index 5e4a383..2948820 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http2Stream.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http2Stream.cs @@ -71,6 +71,8 @@ namespace System.Net.Http /// Reset _waitSource. /// private ManualResetValueTaskSourceCore _waitSource = new ManualResetValueTaskSourceCore { RunContinuationsAsynchronously = true }; // mutable struct, do not make this readonly + /// Cancellation registration used to cancel the . + private CancellationTokenRegistration _waitSourceCancellation; /// /// Whether code has requested or is about to request a wait be performed and thus requires a call to SetResult to complete it. /// This is read and written while holding the lock so that most operations on _waitSource don't need to be. @@ -181,7 +183,7 @@ namespace System.Net.Http { using (Http2WriteStream writeStream = new Http2WriteStream(this)) { - await _request.Content.CopyToAsync(writeStream, null, _requestBodyCancellationToken).ConfigureAwait(false); + await _request.Content.InternalCopyToAsync(writeStream, null, _requestBodyCancellationToken).ConfigureAwait(false); } } @@ -1162,7 +1164,18 @@ namespace System.Net.Http // associated with the implementation is just delegated to the ManualResetValueTaskSourceCore. ValueTaskSourceStatus IValueTaskSource.GetStatus(short token) => _waitSource.GetStatus(token); void IValueTaskSource.OnCompleted(Action continuation, object? state, short token, ValueTaskSourceOnCompletedFlags flags) => _waitSource.OnCompleted(continuation, state, token, flags); - void IValueTaskSource.GetResult(short token) => _waitSource.GetResult(token); + void IValueTaskSource.GetResult(short token) + { + Debug.Assert(!Monitor.IsEntered(SyncObject)); + + // Clean up the registration. It's important to Dispose rather than Unregister, so that we wait + // for any in-flight cancellation to complete. + _waitSourceCancellation.Dispose(); + _waitSourceCancellation = default; + + // Propagate any exceptions if there were any. + _waitSource.GetResult(token); + } private void WaitForData() { @@ -1180,46 +1193,41 @@ namespace System.Net.Http // Reset'ing it is this code here. It's possible for this to race with the _waitSource being completed, but that's ok and is // handled by _waitSource as one of its primary purposes. We can't assert _hasWaiter here, though, as once we released the // lock, a producer could have seen _hasWaiter as true and both set it to false and signaled _waitSource. - if (!cancellationToken.CanBeCanceled) - { - return new ValueTask(this, _waitSource.Version); - } // With HttpClient, the supplied cancellation token will always be cancelable, as HttpClient supplies a token that // will have cancellation requested if CancelPendingRequests is called (or when a non-infinite Timeout expires). // However, this could still be non-cancelable if HttpMessageInvoker was used, at which point this will only be - // cancelable if the caller's token was cancelable. To avoid the extra allocation here in such a case, we make - // this pay-for-play: if the token isn't cancelable, return a ValueTask wrapping this object directly, and only - // if it is cancelable, then register for the cancellation callback, allocate a task for the asynchronously - // completing case, etc. - return GetCancelableWaiterTask(cancellationToken); + // cancelable if the caller's token was cancelable. - async ValueTask GetCancelableWaiterTask(CancellationToken cancellationToken) + _waitSourceCancellation = cancellationToken.UnsafeRegister(s => { - using (cancellationToken.UnsafeRegister(s => - { - var thisRef = (Http2Stream)s!; + var thisRef = (Http2Stream)s!; - bool signalWaiter; - Debug.Assert(!Monitor.IsEntered(thisRef.SyncObject)); - lock (thisRef.SyncObject) - { - signalWaiter = thisRef._hasWaiter; - thisRef._hasWaiter = false; - } + bool signalWaiter; + Debug.Assert(!Monitor.IsEntered(thisRef.SyncObject)); + lock (thisRef.SyncObject) + { + signalWaiter = thisRef._hasWaiter; + thisRef._hasWaiter = false; + } - if (signalWaiter) - { - // Wake up the wait. It will then immediately check whether cancellation was requested and throw if it was. - thisRef._waitSource.SetResult(true); - } - }, this)) + if (signalWaiter) { - await new ValueTask(this, _waitSource.Version).ConfigureAwait(false); + // Wake up the wait. It will then immediately check whether cancellation was requested and throw if it was. + thisRef._waitSource.SetException(ExceptionDispatchInfo.SetCurrentStackTrace( + CancellationHelper.CreateOperationCanceledException(null, _waitSourceCancellation.Token))); } + }, this); - CancellationHelper.ThrowIfCancellationRequested(cancellationToken); - } + // There's a race condition in UnsafeRegister above. If cancellation is requested prior to UnsafeRegister, + // the delegate may be invoked synchronously as part of the UnsafeRegister call. In that case, it will execute + // before _waitSourceCancellation has been set, which means UnsafeRegister will have set a cancellation + // exception into the wait source with a default token rather than the ideal one. To handle that, + // we check for cancellation again, and throw here with the right token. Worst case, if cancellation is + // requested prior to here, we end up allocating an extra OCE object. + CancellationHelper.ThrowIfCancellationRequested(cancellationToken); + + return new ValueTask(this, _waitSource.Version); } public void Trace(string message, [CallerMemberName] string? memberName = null) => diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpConnectionBase.cs b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpConnectionBase.cs index 7104057..f02edd7 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpConnectionBase.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpConnectionBase.cs @@ -101,41 +101,27 @@ namespace System.Net.Http /// Awaits a task, ignoring any resulting exceptions. internal static void IgnoreExceptions(ValueTask task) { - _ = IgnoreExceptionsAsync(task); - - static async Task IgnoreExceptionsAsync(ValueTask task) + // Avoid TaskScheduler.UnobservedTaskException firing for any exceptions. + if (task.IsCompleted) { - try { await task.ConfigureAwait(false); } catch { } + if (task.IsFaulted) + { + _ = task.AsTask().Exception; + } } - } - - /// Awaits a task, ignoring any resulting exceptions. - internal static void IgnoreExceptions(Task task) - { - _ = IgnoreExceptionsAsync(task); - - static async Task IgnoreExceptionsAsync(Task task) + else { - try { await task.ConfigureAwait(false); } catch { } + task.AsTask().ContinueWith(t => _ = t.Exception, + CancellationToken.None, TaskContinuationOptions.ExecuteSynchronously | TaskContinuationOptions.OnlyOnFaulted, TaskScheduler.Default); } } /// Awaits a task, logging any resulting exceptions (which are otherwise ignored). - internal void LogExceptions(Task task) - { - _ = LogExceptionsAsync(task); - - async Task LogExceptionsAsync(Task task) + internal void LogExceptions(Task task) => + task.ContinueWith(t => { - try - { - await task.ConfigureAwait(false); - } - catch (Exception e) - { - if (NetEventSource.IsEnabled) Trace($"Exception from asynchronous processing: {e}"); - } - } - } + Exception? e = t.Exception?.InnerException; // Access Exception even if not tracing, to avoid TaskScheduler.UnobservedTaskException firing + if (NetEventSource.IsEnabled) Trace($"Exception from asynchronous processing: {e}"); + }, CancellationToken.None, TaskContinuationOptions.ExecuteSynchronously | TaskContinuationOptions.OnlyOnFaulted, TaskScheduler.Default); } } diff --git a/src/libraries/System.Net.Http/src/System/Threading/AsyncMutex.cs b/src/libraries/System.Net.Http/src/System/Threading/AsyncMutex.cs new file mode 100644 index 0000000..6c85c58 --- /dev/null +++ b/src/libraries/System.Net.Http/src/System/Threading/AsyncMutex.cs @@ -0,0 +1,313 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.Collections.Concurrent; +using System.Diagnostics; +using System.Runtime.CompilerServices; +using System.Runtime.ExceptionServices; +using System.Threading.Tasks; +using System.Threading.Tasks.Sources; + +namespace System.Threading +{ + /// Provides an async mutex. + /// + /// This could be achieved with a constructed with an initial + /// and max limit of 1. However, this implementation is optimized to the needs of HTTP/2, + /// where the mutex is held for a very short period of time, when it is held any other + /// attempts to access it must wait asynchronously, where it's only binary rather than counting, and where + /// we want to minimize contention that a releaser incurs while trying to unblock a waiter. The primary + /// value-add is the fast-path interlocked checks that minimize contention for these use cases (essentially + /// making it an async futex), and then as long as we're wrapping something and we know exactly how all + /// consumers use the type, we can offer a ValueTask-based implementation that reuses waiter nodes. + /// + internal sealed class AsyncMutex + { + /// Fast-path gate count tracking access to the mutex. + /// + /// If the value is 1, the mutex can be entered atomically with an interlocked operation. + /// If the value is less than or equal to 0, the mutex is held and requires fallback to enter it. + /// + private int _gate = 1; + /// Secondary check guarded by the lock to indicate whether the mutex is acquired. + /// + /// This is only meaningful after having updated via interlockeds and taken the appropriate path. + /// If after decrementing we end up with a negative count, the mutex is contended, hence + /// starting as true. The primary purpose of this field + /// is to handle the race condition between one thread acquiring the mutex, then another thread trying to acquire + /// and getting as far as completing the interlocked operation, and then the original thread releasing; at that point + /// it'll hit the lock and we need to store that the mutex is available to enter. If we instead used a + /// SemaphoreSlim as the fallback from the interlockeds, this would have been its count, and it would have started + /// with an initial count of 0. + /// + private bool _lockedSemaphoreFull = true; + /// The head of the double-linked waiting queue. Waiters are dequeued from the head. + private Waiter? _waitersHead; + /// The tail of the double-linked waiting queue. Waiters are added at the tail. + private Waiter? _waitersTail; + /// A pool of waiter objects that are ready to be reused. + /// + /// There is no bound on this pool, but it ends up being implicitly bounded by the maximum number of concurrent + /// waiters there ever were, which for our uses in HTTP/2 will end up being the high-water mark of concurrent streams + /// on a single connection. + /// + private readonly ConcurrentQueue _unusedWaiters = new ConcurrentQueue(); + + /// Gets whether the mutex is currently held by some operation (not necessarily the caller). + /// This should be used only for asserts and debugging. + public bool IsHeld => _gate != 1; + + /// Objects used to synchronize operations on the instance. + private object SyncObj => _unusedWaiters; + + /// Asynchronously waits to enter the mutex. + /// The CancellationToken token to observe. + /// A task that will complete when the mutex has been entered or the enter canceled. + public ValueTask EnterAsync(CancellationToken cancellationToken) + { + // If cancellation was requested, bail immediately. + // If the mutex is not currently held nor contended, enter immediately. + // Otherwise, fall back to a more expensive likely-asynchronous wait. + return + cancellationToken.IsCancellationRequested ? FromCanceled(cancellationToken) : + Interlocked.Decrement(ref _gate) >= 0 ? default : + Contended(cancellationToken); + + // Everything that follows is the equivalent of: + // return _sem.WaitAsync(cancellationToken); + // if _sem were to be constructed as `new SemaphoreSlim(0)`. + + ValueTask Contended(CancellationToken cancellationToken) + { + // Get a reusable waiter object. We do this before the lock to minimize work (and especially allocation) + // done while holding the lock. It's possible we'll end up dequeuing a waiter and then under the lock + // discovering the mutex is now available, at which point we will have wasted an object. That's currently + // showing to be the better alternative (including not trying to put it back in that case). + if (!_unusedWaiters.TryDequeue(out Waiter? w)) + { + w = new Waiter(this); + } + + lock (SyncObj) + { + // Now that we're holding the lock, check to see whether the async lock is acquirable. + if (!_lockedSemaphoreFull) + { + _lockedSemaphoreFull = true; + return default; + } + else + { + // Add it to the linked list of waiters. + if (_waitersTail is null) + { + Debug.Assert(_waitersHead is null); + _waitersTail = _waitersHead = w; + } + else + { + Debug.Assert(_waitersHead != null); + w.Prev = _waitersTail; + _waitersTail.Next = w; + _waitersTail = w; + } + } + } + + // At this point the waiter was added to the list of waiters, so we want to + // register for cancellation in order to cancel it and remove it from the list + // if cancellation is requested. However, since we've released the lock, it's + // possible the waiter could have actually already been completed and removed + // from the list by another thread releasing the mutex. That's ok; we'll + // end up registering for cancellation here, and then when the consumer awaits + // it, the act of awaiting it will Dispose of the registration, ensuring that + // it won't run after that point, making it safe to pool that instance. + w.CancellationRegistration = cancellationToken.UnsafeRegister(s => OnCancellation(s), w); + + // Return the waiter as a value task. + return new ValueTask(w, w.Version); + + // Cancels the specified waiter if it's still in the list. + static void OnCancellation(object? state) + { + Waiter? w = (Waiter)state!; + AsyncMutex m = w.Owner; + + lock (m.SyncObj) + { + bool inList = w.Next != null || w.Prev != null || m._waitersHead == w; + if (inList) + { + // The waiter was still in the list. + Debug.Assert( + m._waitersHead == w || + (m._waitersTail == w && w.Prev != null && w.Next is null) || + (w.Next != null && w.Prev != null)); + + // The gate counter was decremented when this waiter was added. We need + // to undo that. Since the waiter is still in the list, the lock must + // still be held by someone, which means we don't need to do anything with + // the result of this increment. If it increments to < 1, then there are + // still other waiters. If it increments to 1, we're in a rare race condition + // where there are no other waiters and the owner just incremented the gate + // count; they would have seen it be < 1, so they will proceed to take the + // contended code path and synchronize on the lock we're holding... once we + // release it, they will appropriately update state. + Interlocked.Increment(ref m._gate); + + // Remove it from the list. + if (m._waitersHead == w && m._waitersTail == w) + { + // It's the only node in the list. + m._waitersHead = m._waitersTail = null; + } + else if (m._waitersTail == w) + { + // It's the most recently queued item in the list. + m._waitersTail = w.Prev; + Debug.Assert(m._waitersTail != null); + m._waitersTail.Next = null; + } + else if (m._waitersHead == w) + { + // It's the next item to be removed from the list. + m._waitersHead = w.Next; + Debug.Assert(m._waitersHead != null); + m._waitersHead.Prev = null; + } + else + { + // It's in the middle of the list. + Debug.Assert(w.Next != null); + Debug.Assert(w.Prev != null); + w.Next.Prev = w.Prev; + w.Prev.Next = w.Next; + } + + // Remove it from the list. + w.Next = w.Prev = null; + } + else + { + // The waiter was no longer in the list. We must not cancel it. + w = null; + } + } + + // If the waiter was in the list, we removed it under the lock and thus own + // the ability to cancel it. Do so. + w?.Cancel(); + } + } + } + + /// Releases the mutex. + /// The caller must logically own the mutex. This is not validated. + public void Exit() + { + if (Interlocked.Increment(ref _gate) < 1) + { + // This is the equivalent of: + // _sem.Release(); + // if _sem were to be constructed as `new SemaphoreSlim(0)`. + Contended(); + } + + void Contended() + { + Waiter? w; + + lock (SyncObj) + { + Debug.Assert(_lockedSemaphoreFull); + + // Wake up the next waiter in the list. + w = _waitersHead; + if (w != null) + { + // Remove the waiter. + _waitersHead = w.Next; + if (w.Next != null) + { + w.Next.Prev = null; + } + else + { + Debug.Assert(_waitersTail == w); + _waitersTail = null; + } + w.Next = w.Prev = null; + } + else + { + // There wasn't a waiter. Mark that the async lock is no longer full. + Debug.Assert(_waitersTail is null); + _lockedSemaphoreFull = false; + } + } + + // Either there wasn't a waiter, or we got one and successfully removed it from the list, + // at which point we own the ability to complete it. Do so. + w?.Set(); + } + } + + /// Creates a canceled ValueTask. + /// Separated out to reduce asm for this rare path in the call site. + [MethodImpl(MethodImplOptions.NoInlining)] + private static ValueTask FromCanceled(CancellationToken cancellationToken) => + new ValueTask(Task.FromCanceled(cancellationToken)); + + /// Represents a waiter for the mutex. + /// Implemented as a reusable backing source for a value task. + private sealed class Waiter : IValueTaskSource + { + private ManualResetValueTaskSourceCore _mrvtsc; // mutable struct; do not make this readonly + + public Waiter(AsyncMutex owner) + { + Owner = owner; + _mrvtsc.RunContinuationsAsynchronously = true; + } + + public AsyncMutex Owner { get; } + public CancellationTokenRegistration CancellationRegistration { get; set; } + public Waiter? Next { get; set; } + public Waiter? Prev { get; set; } + + public short Version => _mrvtsc.Version; + + public void Set() => _mrvtsc.SetResult(true); + public void Cancel() => _mrvtsc.SetException(ExceptionDispatchInfo.SetCurrentStackTrace(new OperationCanceledException(CancellationRegistration.Token))); + + void IValueTaskSource.GetResult(short token) + { + Debug.Assert(Next is null && Prev is null); + + // Dispose of the registration. It's critical that this Dispose rather than Unregister, + // so that we can be guaranteed all cancellation-related work has completed by the time + // we return the instance to the pool. Otherwise, a race condition could result in + // a cancellation request for this operation canceling another unlucky request that + // happened to reuse the same node. + Debug.Assert(!Monitor.IsEntered(Owner.SyncObj)); + CancellationRegistration.Dispose(); + + // Complete the operation, propagating any exceptions. + _mrvtsc.GetResult(token); + + // Reset the instance and return it to the pool. + // We don't bother with a try/finally to return instances + // to the pool in the case of exceptions. + _mrvtsc.Reset(); + Owner._unusedWaiters.Enqueue(this); + } + + public ValueTaskSourceStatus GetStatus(short token) => + _mrvtsc.GetStatus(token); + + public void OnCompleted(Action continuation, object? state, short token, ValueTaskSourceOnCompletedFlags flags) => + _mrvtsc.OnCompleted(continuation, state, token, flags); + } + } +} diff --git a/src/libraries/System.Net.Http/tests/UnitTests/Headers/KnownHeadersTest.cs b/src/libraries/System.Net.Http/tests/UnitTests/Headers/KnownHeadersTest.cs index b2bd2be..3d65a6c 100644 --- a/src/libraries/System.Net.Http/tests/UnitTests/Headers/KnownHeadersTest.cs +++ b/src/libraries/System.Net.Http/tests/UnitTests/Headers/KnownHeadersTest.cs @@ -48,6 +48,9 @@ namespace System.Net.Http.Tests [InlineData("Expect-CT")] [InlineData("Expires")] [InlineData("From")] + [InlineData("grpc-encoding")] + [InlineData("grpc-message")] + [InlineData("grpc-status")] [InlineData("Host")] [InlineData("If-Match")] [InlineData("If-Modified-Since")] @@ -135,5 +138,113 @@ namespace System.Net.Http.Tests Assert.Null(KnownHeaders.TryGetKnownHeader(casedName.Select(c => (byte)c).ToArray())); } } + + [Theory] + [InlineData("Access-Control-Allow-Credentials", "true")] + [InlineData("Access-Control-Allow-Headers", "*")] + [InlineData("Access-Control-Allow-Methods", "*")] + [InlineData("Access-Control-Allow-Origin", "*")] + [InlineData("Access-Control-Allow-Origin", "null")] + [InlineData("Access-Control-Expose-Headers", "*")] + [InlineData("Cache-Control", "must-revalidate")] + [InlineData("Cache-Control", "no-cache")] + [InlineData("Cache-Control", "no-store")] + [InlineData("Cache-Control", "no-transform")] + [InlineData("Cache-Control", "private")] + [InlineData("Cache-Control", "proxy-revalidate")] + [InlineData("Cache-Control", "public")] + [InlineData("Connection", "close")] + [InlineData("Content-Disposition", "attachment")] + [InlineData("Content-Disposition", "inline")] + [InlineData("Content-Encoding", "gzip")] + [InlineData("Content-Encoding", "deflate")] + [InlineData("Content-Encoding", "br")] + [InlineData("Content-Encoding", "compress")] + [InlineData("Content-Encoding", "identity")] + [InlineData("Content-Type", "text/xml")] + [InlineData("Content-Type", "text/css")] + [InlineData("Content-Type", "text/csv")] + [InlineData("Content-Type", "image/gif")] + [InlineData("Content-Type", "image/png")] + [InlineData("Content-Type", "text/html")] + [InlineData("Content-Type", "text/plain")] + [InlineData("Content-Type", "image/jpeg")] + [InlineData("Content-Type", "application/pdf")] + [InlineData("Content-Type", "application/xml")] + [InlineData("Content-Type", "application/zip")] + [InlineData("Content-Type", "application/grpc")] + [InlineData("Content-Type", "application/json")] + [InlineData("Content-Type", "multipart/form-data")] + [InlineData("Content-Type", "application/javascript")] + [InlineData("Content-Type", "application/octet-stream")] + [InlineData("Content-Type", "text/html; charset=utf-8")] + [InlineData("Content-Type", "text/plain; charset=utf-8")] + [InlineData("Content-Type", "application/json; charset=utf-8")] + [InlineData("Content-Type", "application/x-www-form-urlencoded")] + [InlineData("Expect", "100-continue")] + [InlineData("grpc-encoding", "identity")] + [InlineData("grpc-encoding", "gzip")] + [InlineData("grpc-encoding", "deflate")] + [InlineData("grpc-status", "0")] + [InlineData("Pragma", "no-cache")] + [InlineData("Referrer-Policy", "strict-origin-when-cross-origin")] + [InlineData("Referrer-Policy", "origin-when-cross-origin")] + [InlineData("Referrer-Policy", "strict-origin")] + [InlineData("Referrer-Policy", "origin")] + [InlineData("Referrer-Policy", "same-origin")] + [InlineData("Referrer-Policy", "no-referrer-when-downgrade")] + [InlineData("Referrer-Policy", "no-referrer")] + [InlineData("Referrer-Policy", "unsafe-url")] + [InlineData("TE", "trailers")] + [InlineData("TE", "compress")] + [InlineData("TE", "deflate")] + [InlineData("TE", "gzip")] + [InlineData("Transfer-Encoding", "chunked")] + [InlineData("Transfer-Encoding", "compress")] + [InlineData("Transfer-Encoding", "deflate")] + [InlineData("Transfer-Encoding", "gzip")] + [InlineData("Transfer-Encoding", "identity")] + [InlineData("Upgrade-Insecure-Requests", "1")] + [InlineData("Vary", "*")] + [InlineData("X-Content-Type-Options", "nosniff")] + [InlineData("X-Frame-Options", "DENY")] + [InlineData("X-Frame-Options", "SAMEORIGIN")] + [InlineData("X-XSS-Protection", "0")] + [InlineData("X-XSS-Protection", "1")] + [InlineData("X-XSS-Protection", "1; mode=block")] + public void GetKnownHeaderValue_Known_Found(string name, string value) + { + foreach (string casedValue in new[] { value, value.ToUpperInvariant(), value.ToLowerInvariant() }) + { + Validate(KnownHeaders.TryGetKnownHeader(name), casedValue); + } + + static void Validate(KnownHeader knownHeader, string value) + { + Assert.NotNull(knownHeader); + + string v1 = knownHeader.Descriptor.GetHeaderValue(value.Select(c => (byte)c).ToArray()); + Assert.NotNull(v1); + Assert.Equal(value, v1, StringComparer.OrdinalIgnoreCase); + + string v2 = knownHeader.Descriptor.GetHeaderValue(value.Select(c => (byte)c).ToArray()); + Assert.Same(v1, v2); + } + } + + [Theory] + [InlineData("Content-Type", "application/jsot")] + [InlineData("Content-Type", "application/jsons")] + public void GetKnownHeaderValue_Unknown_NotFound(string name, string value) + { + KnownHeader knownHeader = KnownHeaders.TryGetKnownHeader(name); + Assert.NotNull(knownHeader); + + string v1 = knownHeader.Descriptor.GetHeaderValue(value.Select(c => (byte)c).ToArray()); + string v2 = knownHeader.Descriptor.GetHeaderValue(value.Select(c => (byte)c).ToArray()); + Assert.Equal(value, v1); + Assert.Equal(value, v2); + Assert.NotSame(v1, v2); + } } } -- 2.7.4