From 5b112c310f9ea28ca7a06941787f2bea7eb315eb Mon Sep 17 00:00:00 2001 From: Stephen Toub Date: Fri, 15 May 2020 17:30:17 -0400 Subject: [PATCH] More HTTP/2 performance (and a few functional) improvements (#36246) * Use span instead of array for StatusHeaderName * Fix potential leak into CancellationToken We need to dispose of the linked token source we create. Also cleaned up some unnecessarily complicated code nearby. * Fix HttpConnectionBase.LogExceptions My previous changes here were flawed for the sync-completing case, and also accidentally introduced a closure. * Clean up protocol state if/else cascades into switches * Consolidate a bunch of exception throws into helpers * Fix cancellation handling of WaitFor100ContinueAsync * Change AsyncMutex's linked list to be circular * Remove linked token sources Rather than creating temporary linked token sources with the request body source and the supplied cancellation token, we can instead just register with the supplied token to cancel the request body source. This is valid because canceling any part of sending a request cancels any further sending of that request, not just that one constituent operation. * Avoid registering for linked cancellation until absolutely necessary We can avoid registering with the cancellation token until after we know that our send is completing asynchronously. * Remove closure/delegate allocation from WaitForDataAsync `this` was being closed over accidentally. I can't wait for static lambdas. * Avoid a temporary list for storing trailers Since it only exists to be defensive but we don't expect response.TrailingHeaders to be accessed until after the whole response has been received, we can store the headers into an HttpResponseHeaders instance and swap that instance in at the end. Best and common case, we avoid the list. Worst and uncommon case, we pay the overhead of the extra HttpResponseHeaders instead of the List. * Delete dead AcquireWriteLockAsync method * Reduce header frame overhead Minor optimizations to improve the asm * Remove unnecessary throws with GetShutdownException * Avoid extra lock in SendHeadersAsync * Move Http2Stream construction out of lock Makes a significant impact on reducing lock contention. * Streamline RemoveStream Including moving credit adjustment out of the lock * Move response message allocation to ctor Remove it from within the lock * Reorder interfaces on Http2Stream IHttpTrace doesn't need to be prioritized. * Address PR feedback --- .../Net/Http/HttpClientHandlerTest.Cancellation.cs | 36 +- .../System/Net/Http/Headers/HttpResponseHeaders.cs | 2 + .../src/System/Net/Http/HttpResponseMessage.cs | 37 ++- .../Net/Http/SocketsHttpHandler/Http2Connection.cs | 366 ++++++++------------- .../Net/Http/SocketsHttpHandler/Http2Stream.cs | 316 +++++++++--------- .../Http/SocketsHttpHandler/HttpConnectionBase.cs | 27 +- .../src/System/Threading/AsyncMutex.cs | 108 +++--- 7 files changed, 416 insertions(+), 476 deletions(-) diff --git a/src/libraries/Common/tests/System/Net/Http/HttpClientHandlerTest.Cancellation.cs b/src/libraries/Common/tests/System/Net/Http/HttpClientHandlerTest.Cancellation.cs index 6d87c67..679b540 100644 --- a/src/libraries/Common/tests/System/Net/Http/HttpClientHandlerTest.Cancellation.cs +++ b/src/libraries/Common/tests/System/Net/Http/HttpClientHandlerTest.Cancellation.cs @@ -437,7 +437,7 @@ namespace System.Net.Http.Functional.Tests bool called = false; var content = new StreamContent(new DelegateStream( canReadFunc: () => true, - readAsyncFunc: (buffer, offset, count, cancellationToken) => + readAsyncFunc: async (buffer, offset, count, cancellationToken) => { int result = 1; if (called) @@ -445,11 +445,17 @@ namespace System.Net.Http.Functional.Tests result = 0; Assert.False(cancellationToken.IsCancellationRequested); tokenSource.Cancel(); - Assert.True(cancellationToken.IsCancellationRequested); + + // Wait for cancellation to occur. It should be very quickly after it's been requested. + var tcs = new TaskCompletionSource(); + using (cancellationToken.Register(() => tcs.SetResult(true))) + { + await tcs.Task; + } } called = true; - return Task.FromResult(result); + return result; } )); yield return new object[] { content, tokenSource }; @@ -467,7 +473,7 @@ namespace System.Net.Http.Functional.Tests lengthFunc: () => 1, positionGetFunc: () => 0, positionSetFunc: _ => {}, - readAsyncFunc: (buffer, offset, count, cancellationToken) => + readAsyncFunc: async (buffer, offset, count, cancellationToken) => { int result = 1; if (called) @@ -475,11 +481,17 @@ namespace System.Net.Http.Functional.Tests result = 0; Assert.False(cancellationToken.IsCancellationRequested); tokenSource.Cancel(); - Assert.True(cancellationToken.IsCancellationRequested); + + // Wait for cancellation to occur. It should be very quickly after it's been requested. + var tcs = new TaskCompletionSource(); + using (cancellationToken.Register(() => tcs.SetResult(true))) + { + await tcs.Task; + } } called = true; - return Task.FromResult(result); + return result; } ))); yield return new object[] { content, tokenSource }; @@ -497,7 +509,7 @@ namespace System.Net.Http.Functional.Tests lengthFunc: () => 1, positionGetFunc: () => 0, positionSetFunc: _ => {}, - readAsyncFunc: (buffer, offset, count, cancellationToken) => + readAsyncFunc: async (buffer, offset, count, cancellationToken) => { int result = 1; if (called) @@ -505,11 +517,17 @@ namespace System.Net.Http.Functional.Tests result = 0; Assert.False(cancellationToken.IsCancellationRequested); tokenSource.Cancel(); - Assert.True(cancellationToken.IsCancellationRequested); + + // Wait for cancellation to occur. It should be very quickly after it's been requested. + var tcs = new TaskCompletionSource(); + using (cancellationToken.Register(() => tcs.SetResult(true))) + { + await tcs.Task; + } } called = true; - return Task.FromResult(result); + return result; } ))); yield return new object[] { content, tokenSource }; diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HttpResponseHeaders.cs b/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HttpResponseHeaders.cs index 8da6741..3864380 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HttpResponseHeaders.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HttpResponseHeaders.cs @@ -147,6 +147,8 @@ namespace System.Net.Http.Headers _containsTrailingHeaders = containsTrailingHeaders; } + internal bool ContainsTrailingHeaders => _containsTrailingHeaders; + internal override void AddHeaders(HttpHeaders sourceHeaders) { base.AddHeaders(sourceHeaders); diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/HttpResponseMessage.cs b/src/libraries/System.Net.Http/src/System/Net/Http/HttpResponseMessage.cs index b0763db..9873b42 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/HttpResponseMessage.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/HttpResponseMessage.cs @@ -2,6 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. +using System.Diagnostics; using System.Diagnostics.CodeAnalysis; using System.Net.Http.Headers; using System.Text; @@ -106,28 +107,28 @@ namespace System.Net.Http internal void SetReasonPhraseWithoutValidation(string value) => _reasonPhrase = value; - public HttpResponseHeaders Headers + public HttpResponseHeaders Headers => _headers ??= new HttpResponseHeaders(); + + public HttpResponseHeaders TrailingHeaders => _trailingHeaders ??= new HttpResponseHeaders(containsTrailingHeaders: true); + + /// Stores the supplied trailing headers into this instance. + /// + /// In the common/desired case where response.TrailingHeaders isn't accessed until after the whole payload has been + /// received, will still be null, and we can simply store the supplied instance into + /// and assume ownership of the instance. In the uncommon case where it was accessed, + /// we add all of the headers to the existing instance. + /// + internal void StoreReceivedTrailingHeaders(HttpResponseHeaders headers) { - get + Debug.Assert(headers.ContainsTrailingHeaders); + + if (_trailingHeaders is null) { - if (_headers == null) - { - _headers = new HttpResponseHeaders(); - } - return _headers; + _trailingHeaders = headers; } - } - - public HttpResponseHeaders TrailingHeaders - { - get + else { - if (_trailingHeaders == null) - { - _trailingHeaders = new HttpResponseHeaders(containsTrailingHeaders: true); - } - - return _trailingHeaders; + _trailingHeaders.AddHeaders(headers); } } diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http2Connection.cs b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http2Connection.cs index 00b9bc1..9af6445 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http2Connection.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http2Connection.cs @@ -5,6 +5,7 @@ using System.Buffers.Binary; using System.Collections.Generic; using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; using System.IO; using System.Net.Http.Headers; using System.Net.Http.HPack; @@ -133,17 +134,17 @@ namespace System.Net.Http s_http2ConnectionPreface.AsSpan().CopyTo(_outgoingBuffer.AvailableSpan); _outgoingBuffer.Commit(s_http2ConnectionPreface.Length); - // Send SETTINGS frame - WriteFrameHeader(new FrameHeader(FrameHeader.SettingLength, FrameType.Settings, FrameFlags.None, 0)); - - // Disable push promise + // Send SETTINGS frame. Disable push promise. + FrameHeader.WriteTo(_outgoingBuffer.AvailableSpan, FrameHeader.SettingLength, FrameType.Settings, FrameFlags.None, streamId: 0); + _outgoingBuffer.Commit(FrameHeader.Size); BinaryPrimitives.WriteUInt16BigEndian(_outgoingBuffer.AvailableSpan, (ushort)SettingId.EnablePush); _outgoingBuffer.Commit(2); BinaryPrimitives.WriteUInt32BigEndian(_outgoingBuffer.AvailableSpan, 0); _outgoingBuffer.Commit(4); // Send initial connection-level WINDOW_UPDATE - WriteFrameHeader(new FrameHeader(FrameHeader.WindowUpdateLength, FrameType.WindowUpdate, FrameFlags.None, 0)); + FrameHeader.WriteTo(_outgoingBuffer.AvailableSpan, FrameHeader.WindowUpdateLength, FrameType.WindowUpdate, FrameFlags.None, streamId: 0); + _outgoingBuffer.Commit(FrameHeader.Size); BinaryPrimitives.WriteUInt32BigEndian(_outgoingBuffer.AvailableSpan, (ConnectionWindowSize - DefaultInitialWindowSize)); _outgoingBuffer.Commit(4); @@ -195,7 +196,7 @@ namespace System.Net.Http // Parse the frame header from our read buffer and validate it. FrameHeader frameHeader = FrameHeader.ReadFrom(_incomingBuffer.ActiveSpan); - if (frameHeader.Length > FrameHeader.MaxLength) + if (frameHeader.PayloadLength > FrameHeader.MaxPayloadLength) { if (initialFrame && NetEventSource.IsEnabled) { @@ -204,21 +205,21 @@ namespace System.Net.Http } _incomingBuffer.Discard(FrameHeader.Size); - throw new Http2ConnectionException(initialFrame ? Http2ProtocolErrorCode.ProtocolError : Http2ProtocolErrorCode.FrameSizeError); + ThrowProtocolError(initialFrame ? Http2ProtocolErrorCode.ProtocolError : Http2ProtocolErrorCode.FrameSizeError); } _incomingBuffer.Discard(FrameHeader.Size); // Ensure we've read the frame contents into our buffer. - if (_incomingBuffer.ActiveLength < frameHeader.Length) + if (_incomingBuffer.ActiveLength < frameHeader.PayloadLength) { - _incomingBuffer.EnsureAvailableSpace(frameHeader.Length - _incomingBuffer.ActiveLength); + _incomingBuffer.EnsureAvailableSpace(frameHeader.PayloadLength - _incomingBuffer.ActiveLength); do { int bytesRead = await _stream.ReadAsync(_incomingBuffer.AvailableMemory).ConfigureAwait(false); _incomingBuffer.Commit(bytesRead); - if (bytesRead == 0) ThrowPrematureEOF(frameHeader.Length); + if (bytesRead == 0) ThrowPrematureEOF(frameHeader.PayloadLength); } - while (_incomingBuffer.ActiveLength < frameHeader.Length); + while (_incomingBuffer.ActiveLength < frameHeader.PayloadLength); } // Return the read frame header. @@ -236,7 +237,7 @@ namespace System.Net.Http FrameHeader frameHeader = await ReadFrameAsync(initialFrame: true).ConfigureAwait(false); if (frameHeader.Type != FrameType.Settings || frameHeader.AckFlag) { - throw new Http2ConnectionException(Http2ProtocolErrorCode.ProtocolError); + ThrowProtocolError(); } if (NetEventSource.IsEnabled) Trace($"Frame 0: {frameHeader}."); @@ -318,7 +319,8 @@ namespace System.Net.Http case FrameType.PushPromise: // Should not happen, since we disable this in our initial SETTINGS case FrameType.Continuation: // Should only be received while processing headers in ProcessHeadersFrame default: - throw new Http2ConnectionException(Http2ProtocolErrorCode.ProtocolError); + ThrowProtocolError(); + break; } } } @@ -338,7 +340,7 @@ namespace System.Net.Http { if (streamId <= 0 || streamId >= _nextStream) { - throw new Http2ConnectionException(Http2ProtocolErrorCode.ProtocolError); + ThrowProtocolError(); } lock (SyncObject) @@ -369,10 +371,10 @@ namespace System.Net.Http http2Stream?.OnHeadersStart(); _hpackDecoder.Decode( - GetFrameData(_incomingBuffer.ActiveSpan.Slice(0, frameHeader.Length), frameHeader.PaddedFlag, frameHeader.PriorityFlag), + GetFrameData(_incomingBuffer.ActiveSpan.Slice(0, frameHeader.PayloadLength), frameHeader.PaddedFlag, frameHeader.PriorityFlag), frameHeader.EndHeadersFlag, http2Stream); - _incomingBuffer.Discard(frameHeader.Length); + _incomingBuffer.Discard(frameHeader.PayloadLength); while (!frameHeader.EndHeadersFlag) { @@ -380,14 +382,14 @@ namespace System.Net.Http if (frameHeader.Type != FrameType.Continuation || frameHeader.StreamId != streamId) { - throw new Http2ConnectionException(Http2ProtocolErrorCode.ProtocolError); + ThrowProtocolError(); } _hpackDecoder.Decode( - _incomingBuffer.ActiveSpan.Slice(0, frameHeader.Length), + _incomingBuffer.ActiveSpan.Slice(0, frameHeader.PayloadLength), frameHeader.EndHeadersFlag, http2Stream); - _incomingBuffer.Discard(frameHeader.Length); + _incomingBuffer.Discard(frameHeader.PayloadLength); } _hpackDecoder.CompleteDecode(); @@ -404,7 +406,7 @@ namespace System.Net.Http { if (frameData.Length == 0) { - throw new Http2ConnectionException(Http2ProtocolErrorCode.ProtocolError); + ThrowProtocolError(); } int padLength = frameData[0]; @@ -412,7 +414,7 @@ namespace System.Net.Http if (frameData.Length < padLength) { - throw new Http2ConnectionException(Http2ProtocolErrorCode.ProtocolError); + ThrowProtocolError(); } frameData = frameData.Slice(0, frameData.Length - padLength); @@ -422,7 +424,7 @@ namespace System.Net.Http { if (frameData.Length < FrameHeader.PriorityInfoLength) { - throw new Http2ConnectionException(Http2ProtocolErrorCode.ProtocolError); + ThrowProtocolError(); } // We ignore priority info. @@ -443,7 +445,7 @@ namespace System.Net.Http if (NetEventSource.IsEnabled) Trace($"{frameHeader}"); Debug.Assert(frameHeader.Type == FrameType.AltSvc); - ReadOnlySpan span = _incomingBuffer.ActiveSpan.Slice(0, frameHeader.Length); + ReadOnlySpan span = _incomingBuffer.ActiveSpan.Slice(0, frameHeader.PayloadLength); if (BinaryPrimitives.TryReadUInt16BigEndian(span, out ushort originLength)) { @@ -464,7 +466,7 @@ namespace System.Net.Http } } - _incomingBuffer.Discard(frameHeader.Length); + _incomingBuffer.Discard(frameHeader.PayloadLength); } private void ProcessDataFrame(FrameHeader frameHeader) @@ -476,7 +478,7 @@ namespace System.Net.Http // Note, http2Stream will be null if this is a closed stream. // Just ignore the frame in this case. - ReadOnlySpan frameData = GetFrameData(_incomingBuffer.ActiveSpan.Slice(0, frameHeader.Length), hasPad: frameHeader.PaddedFlag, hasPriority: false); + ReadOnlySpan frameData = GetFrameData(_incomingBuffer.ActiveSpan.Slice(0, frameHeader.PayloadLength), hasPad: frameHeader.PaddedFlag, hasPriority: false); if (http2Stream != null) { @@ -490,7 +492,7 @@ namespace System.Net.Http ExtendWindow(frameData.Length); } - _incomingBuffer.Discard(frameHeader.Length); + _incomingBuffer.Discard(frameHeader.PayloadLength); } private void ProcessSettingsFrame(FrameHeader frameHeader) @@ -499,19 +501,19 @@ namespace System.Net.Http if (frameHeader.StreamId != 0) { - throw new Http2ConnectionException(Http2ProtocolErrorCode.ProtocolError); + ThrowProtocolError(); } if (frameHeader.AckFlag) { - if (frameHeader.Length != 0) + if (frameHeader.PayloadLength != 0) { - throw new Http2ConnectionException(Http2ProtocolErrorCode.FrameSizeError); + ThrowProtocolError(Http2ProtocolErrorCode.FrameSizeError); } if (!_expectingSettingsAck) { - throw new Http2ConnectionException(Http2ProtocolErrorCode.ProtocolError); + ThrowProtocolError(); } // We only send SETTINGS once initially, so we don't need to do anything in response to the ACK. @@ -520,13 +522,13 @@ namespace System.Net.Http } else { - if ((frameHeader.Length % 6) != 0) + if ((frameHeader.PayloadLength % 6) != 0) { - throw new Http2ConnectionException(Http2ProtocolErrorCode.FrameSizeError); + ThrowProtocolError(Http2ProtocolErrorCode.FrameSizeError); } // Parse settings and process the ones we care about. - ReadOnlySpan settings = _incomingBuffer.ActiveSpan.Slice(0, frameHeader.Length); + ReadOnlySpan settings = _incomingBuffer.ActiveSpan.Slice(0, frameHeader.PayloadLength); while (settings.Length > 0) { Debug.Assert((settings.Length % 6) == 0); @@ -545,7 +547,7 @@ namespace System.Net.Http case SettingId.InitialWindowSize: if (settingValue > 0x7FFFFFFF) { - throw new Http2ConnectionException(Http2ProtocolErrorCode.FlowControlError); + ThrowProtocolError(Http2ProtocolErrorCode.FlowControlError); } ChangeInitialWindowSize((int)settingValue); @@ -554,7 +556,7 @@ namespace System.Net.Http case SettingId.MaxFrameSize: if (settingValue < 16384 || settingValue > 16777215) { - throw new Http2ConnectionException(Http2ProtocolErrorCode.ProtocolError); + ThrowProtocolError(); } // We don't actually store this value; we always send frames of the minimum size (16K). @@ -567,7 +569,7 @@ namespace System.Net.Http } } - _incomingBuffer.Discard(frameHeader.Length); + _incomingBuffer.Discard(frameHeader.PayloadLength); // Send acknowledgement // Don't wait for completion, which could happen asynchronously. @@ -611,14 +613,14 @@ namespace System.Net.Http { Debug.Assert(frameHeader.Type == FrameType.Priority); - if (frameHeader.StreamId == 0 || frameHeader.Length != FrameHeader.PriorityInfoLength) + if (frameHeader.StreamId == 0 || frameHeader.PayloadLength != FrameHeader.PriorityInfoLength) { - throw new Http2ConnectionException(Http2ProtocolErrorCode.ProtocolError); + ThrowProtocolError(); } // Ignore priority info. - _incomingBuffer.Discard(frameHeader.Length); + _incomingBuffer.Discard(frameHeader.PayloadLength); } private void ProcessPingFrame(FrameHeader frameHeader) @@ -627,18 +629,18 @@ namespace System.Net.Http if (frameHeader.StreamId != 0) { - throw new Http2ConnectionException(Http2ProtocolErrorCode.ProtocolError); + ThrowProtocolError(); } if (frameHeader.AckFlag) { // We never send PING, so an ACK indicates a protocol error - throw new Http2ConnectionException(Http2ProtocolErrorCode.ProtocolError); + ThrowProtocolError(); } - if (frameHeader.Length != FrameHeader.PingLength) + if (frameHeader.PayloadLength != FrameHeader.PingLength) { - throw new Http2ConnectionException(Http2ProtocolErrorCode.FrameSizeError); + ThrowProtocolError(Http2ProtocolErrorCode.FrameSizeError); } // We don't wait for SendPingAckAsync to complete before discarding @@ -650,16 +652,16 @@ namespace System.Net.Http LogExceptions(SendPingAckAsync(pingContentLong)); - _incomingBuffer.Discard(frameHeader.Length); + _incomingBuffer.Discard(frameHeader.PayloadLength); } private void ProcessWindowUpdateFrame(FrameHeader frameHeader) { Debug.Assert(frameHeader.Type == FrameType.WindowUpdate); - if (frameHeader.Length != FrameHeader.WindowUpdateLength) + if (frameHeader.PayloadLength != FrameHeader.WindowUpdateLength) { - throw new Http2ConnectionException(Http2ProtocolErrorCode.FrameSizeError); + ThrowProtocolError(Http2ProtocolErrorCode.FrameSizeError); } int amount = BinaryPrimitives.ReadInt32BigEndian(_incomingBuffer.ActiveSpan) & 0x7FFFFFFF; @@ -668,10 +670,10 @@ namespace System.Net.Http Debug.Assert(amount >= 0); if (amount == 0) { - throw new Http2ConnectionException(Http2ProtocolErrorCode.ProtocolError); + ThrowProtocolError(); } - _incomingBuffer.Discard(frameHeader.Length); + _incomingBuffer.Discard(frameHeader.PayloadLength); if (frameHeader.StreamId == 0) { @@ -694,28 +696,28 @@ namespace System.Net.Http { Debug.Assert(frameHeader.Type == FrameType.RstStream); - if (frameHeader.Length != FrameHeader.RstStreamLength) + if (frameHeader.PayloadLength != FrameHeader.RstStreamLength) { - throw new Http2ConnectionException(Http2ProtocolErrorCode.FrameSizeError); + ThrowProtocolError(Http2ProtocolErrorCode.FrameSizeError); } if (frameHeader.StreamId == 0) { - throw new Http2ConnectionException(Http2ProtocolErrorCode.ProtocolError); + ThrowProtocolError(); } Http2Stream? http2Stream = GetStream(frameHeader.StreamId); if (http2Stream == null) { // Ignore invalid stream ID, as per RFC - _incomingBuffer.Discard(frameHeader.Length); + _incomingBuffer.Discard(frameHeader.PayloadLength); return; } var protocolError = (Http2ProtocolErrorCode)BinaryPrimitives.ReadInt32BigEndian(_incomingBuffer.ActiveSpan); if (NetEventSource.IsEnabled) Trace(frameHeader.StreamId, $"{nameof(protocolError)}={protocolError}"); - _incomingBuffer.Discard(frameHeader.Length); + _incomingBuffer.Discard(frameHeader.PayloadLength); if (protocolError == Http2ProtocolErrorCode.RefusedStream) { @@ -731,16 +733,16 @@ namespace System.Net.Http { Debug.Assert(frameHeader.Type == FrameType.GoAway); - if (frameHeader.Length < FrameHeader.GoAwayMinLength) + if (frameHeader.PayloadLength < FrameHeader.GoAwayMinLength) { - throw new Http2ConnectionException(Http2ProtocolErrorCode.FrameSizeError); + ThrowProtocolError(Http2ProtocolErrorCode.FrameSizeError); } // GoAway frames always apply to the whole connection, never to a single stream. // According to RFC 7540 section 6.8, this should be a connection error. if (frameHeader.StreamId != 0) { - throw new Http2ConnectionException(Http2ProtocolErrorCode.ProtocolError); + ThrowProtocolError(); } int lastValidStream = (int)(BinaryPrimitives.ReadUInt32BigEndian(_incomingBuffer.ActiveSpan) & 0x7FFFFFFF); @@ -749,7 +751,7 @@ namespace System.Net.Http StartTerminatingConnection(lastValidStream, new Http2ConnectionException(errorCode)); - _incomingBuffer.Discard(frameHeader.Length); + _incomingBuffer.Discard(frameHeader.PayloadLength); } internal async Task FlushAsync(CancellationToken cancellationToken = default) @@ -800,7 +802,7 @@ namespace System.Net.Http if (_abortException != null) { _writerLock.Exit(); - throw new IOException(SR.net_http_request_aborted, _abortException); + ThrowRequestAborted(_abortException); } // Flush anything necessary, and return back the write buffer to use. @@ -891,58 +893,12 @@ namespace System.Net.Http } } - private async ValueTask AcquireWriteLockAsync(CancellationToken cancellationToken) - { - ValueTask acquireLockTask = _writerLock.EnterAsync(cancellationToken); - if (acquireLockTask.IsCompletedSuccessfully) - { - acquireLockTask.GetAwaiter().GetResult(); // to enable the value task sources to be pooled - } - else - { - Interlocked.Increment(ref _pendingWriters); - - try - { - await acquireLockTask.ConfigureAwait(false); - } - catch - { - if (Interlocked.Decrement(ref _pendingWriters) == 0) - { - // If a pending waiter is canceled, we may end up in a situation where a previously written frame - // saw that there were pending writers and as such deferred its flush to them, but if/when that pending - // writer is canceled, nothing may end up flushing the deferred work (at least not promptly). To compensate, - // if a pending writer does end up being canceled, we flush asynchronously. We can't check whether there's such - // a pending operation because we failed to acquire the lock that protects that state. But we can at least only - // do the flush if our decrement caused the pending count to reach 0: if it's still higher than zero, then there's - // at least one other pending writer who can handle the flush. Worst case, we pay for a flush that ends up being - // a nop. Note: we explicitly do not pass in the cancellationToken; if we're here, it's almost certainly because - // cancellation was requested, and it's because of that cancellation that we need to flush. - LogExceptions(FlushAsync(cancellationToken: default)); - } - - throw; - } - - Interlocked.Decrement(ref _pendingWriters); - } - - // If the connection has been aborted, then fail now instead of trying to send more data. - if (_abortException != null) - { - _writerLock.Exit(); - throw new IOException(SR.net_http_request_aborted, _abortException); - } - } - private async Task SendSettingsAckAsync() { Memory writeBuffer = await StartWriteAsync(FrameHeader.Size).ConfigureAwait(false); if (NetEventSource.IsEnabled) Trace("Started writing."); - FrameHeader frameHeader = new FrameHeader(0, FrameType.Settings, FrameFlags.Ack, 0); - frameHeader.WriteTo(writeBuffer); + FrameHeader.WriteTo(writeBuffer.Span, 0, FrameType.Settings, FrameFlags.Ack, streamId: 0); FinishWrite(FlushTiming.AfterPendingWrites); } @@ -953,12 +909,9 @@ namespace System.Net.Http Memory writeBuffer = await StartWriteAsync(FrameHeader.Size + FrameHeader.PingLength).ConfigureAwait(false); if (NetEventSource.IsEnabled) Trace("Started writing."); - FrameHeader frameHeader = new FrameHeader(FrameHeader.PingLength, FrameType.Ping, FrameFlags.Ack, 0); - frameHeader.WriteTo(writeBuffer); - writeBuffer = writeBuffer.Slice(FrameHeader.Size); - Debug.Assert(sizeof(long) == FrameHeader.PingLength); - BinaryPrimitives.WriteInt64BigEndian(writeBuffer.Span, pingContent); + FrameHeader.WriteTo(writeBuffer.Span, FrameHeader.PingLength, FrameType.Ping, FrameFlags.Ack, streamId: 0); + BinaryPrimitives.WriteInt64BigEndian(writeBuffer.Span.Slice(FrameHeader.Size), pingContent); FinishWrite(FlushTiming.AfterPendingWrites); } @@ -968,11 +921,8 @@ namespace System.Net.Http Memory writeBuffer = await StartWriteAsync(FrameHeader.Size + FrameHeader.RstStreamLength).ConfigureAwait(false); if (NetEventSource.IsEnabled) Trace(streamId, $"Started writing. {nameof(errorCode)}={errorCode}"); - FrameHeader frameHeader = new FrameHeader(FrameHeader.RstStreamLength, FrameType.RstStream, FrameFlags.None, streamId); - frameHeader.WriteTo(writeBuffer); - writeBuffer = writeBuffer.Slice(FrameHeader.Size); - - BinaryPrimitives.WriteInt32BigEndian(writeBuffer.Span, (int)errorCode); + FrameHeader.WriteTo(writeBuffer.Span, FrameHeader.RstStreamLength, FrameType.RstStream, FrameFlags.None, streamId); + BinaryPrimitives.WriteInt32BigEndian(writeBuffer.Span.Slice(FrameHeader.Size), (int)errorCode); FinishWrite(FlushTiming.Now); // ensure cancellation is seen as soon as possible } @@ -1206,7 +1156,8 @@ namespace System.Net.Http } } - private HttpRequestException GetShutdownException() + [DoesNotReturn] + private void ThrowShutdownException() { Debug.Assert(Monitor.IsEntered(SyncObject)); @@ -1233,7 +1184,7 @@ namespace System.Net.Http innerException = new ObjectDisposedException(nameof(Http2Connection)); } - return new HttpRequestException(SR.net_http_client_execution_error, innerException, allowRetry: RequestRetryType.RetryOnSameOrNextProxy); + ThrowRetry(SR.net_http_client_execution_error, innerException); } private async ValueTask SendHeadersAsync(HttpRequestMessage request, CancellationToken cancellationToken, bool mustFlush) @@ -1261,7 +1212,8 @@ namespace System.Net.Http { Debug.Assert(_disposed || _lastStreamId != -1); Debug.Assert(_httpStreams.Count == 0); - throw GetShutdownException(); + ThrowShutdownException(); + throw; // unreachable } } @@ -1276,15 +1228,20 @@ namespace System.Net.Http Debug.Assert(remaining.Length > 0); // Calculate the total number of bytes we're going to use (content + headers). - int frameCount = ((remaining.Length - 1) / FrameHeader.MaxLength) + 1; + int frameCount = ((remaining.Length - 1) / FrameHeader.MaxPayloadLength) + 1; int totalSize = remaining.Length + (frameCount * FrameHeader.Size); ReadOnlyMemory current; - (current, remaining) = SplitBuffer(remaining, FrameHeader.MaxLength); + (current, remaining) = SplitBuffer(remaining, FrameHeader.MaxPayloadLength); FrameFlags flags = (remaining.Length == 0 ? FrameFlags.EndHeaders : FrameFlags.None) | (request.Content == null ? FrameFlags.EndStream : FrameFlags.None); + // Construct and initialize the new Http2Stream instance. It's stream ID must be set below + // before the instance is used and stored into the dictionary. However, we construct it here + // so as to avoid the allocation and initialization expense while holding multiple locks. + var http2Stream = new Http2Stream(request, this, _initialWindowSize); + // Start the write. This serializes access to write to the connection, and ensures that HEADERS // and CONTINUATION frames stay together, as they must do. We use the lock as well to ensure new // streams are created and started in order. @@ -1293,7 +1250,6 @@ namespace System.Net.Http { // Allocate the next available stream ID. Note that if we fail before sending the headers, // we'll just skip this stream ID, which is fine. - int streamId; lock (SyncObject) { if (_nextStream == MaxStreamId || _disposed || _lastStreamId != -1) @@ -1301,44 +1257,43 @@ namespace System.Net.Http // We ran out of stream IDs or we raced between acquiring the connection from the pool and shutting down. // Throw a retryable request exception. This will cause retry logic to kick in // and perform another connection attempt. The user should never see this exception. - throw GetShutdownException(); + ThrowShutdownException(); } - streamId = _nextStream; - // Client-initiated streams are always odd-numbered, so increase by 2. + http2Stream.StreamId = _nextStream; _nextStream += 2; + + // We're about to flush the HEADERS frame, so add the stream to the dictionary now. + // The lifetime of the stream is now controlled by the stream itself and the connection. + // This can fail if the connection is shutting down, in which case we will cancel sending this frame. + _httpStreams.Add(http2Stream.StreamId, http2Stream); } - if (NetEventSource.IsEnabled) Trace(streamId, $"Started writing. {nameof(totalSize)}={totalSize}"); + if (NetEventSource.IsEnabled) Trace(http2Stream.StreamId, $"Started writing. {nameof(totalSize)}={totalSize}"); // Copy the HEADERS frame. - new FrameHeader(current.Length, FrameType.Headers, flags, streamId).WriteTo(writeBuffer.Span); + FrameHeader.WriteTo(writeBuffer.Span, current.Length, FrameType.Headers, flags, http2Stream.StreamId); writeBuffer = writeBuffer.Slice(FrameHeader.Size); current.CopyTo(writeBuffer); writeBuffer = writeBuffer.Slice(current.Length); - if (NetEventSource.IsEnabled) Trace(streamId, $"Wrote HEADERS frame. Length={current.Length}, flags={flags}"); + if (NetEventSource.IsEnabled) Trace(http2Stream.StreamId, $"Wrote HEADERS frame. Length={current.Length}, flags={flags}"); // Copy CONTINUATION frames, if any. while (remaining.Length > 0) { - (current, remaining) = SplitBuffer(remaining, FrameHeader.MaxLength); + (current, remaining) = SplitBuffer(remaining, FrameHeader.MaxPayloadLength); flags = remaining.Length == 0 ? FrameFlags.EndHeaders : FrameFlags.None; - new FrameHeader(current.Length, FrameType.Continuation, flags, streamId).WriteTo(writeBuffer.Span); + FrameHeader.WriteTo(writeBuffer.Span, current.Length, FrameType.Continuation, flags, http2Stream.StreamId); writeBuffer = writeBuffer.Slice(FrameHeader.Size); current.CopyTo(writeBuffer); writeBuffer = writeBuffer.Slice(current.Length); - if (NetEventSource.IsEnabled) Trace(streamId, $"Wrote CONTINUATION frame. Length={current.Length}, flags={flags}"); + if (NetEventSource.IsEnabled) Trace(http2Stream.StreamId, $"Wrote CONTINUATION frame. Length={current.Length}, flags={flags}"); } Debug.Assert(writeBuffer.Length == 0); - // We're about to flush the HEADERS frame, so add the stream to the dictionary now. - // The lifetime of the stream is now controlled by the stream itself and the connection. - // This can fail if the connection is shutting down, in which case we will cancel sending this frame. - Http2Stream http2Stream = AddStream(streamId, request); - FinishWrite(mustFlush || (flags & FrameFlags.EndStream) != 0 ? FlushTiming.AfterPendingWrites : FlushTiming.Eventually); return http2Stream; } @@ -1365,7 +1320,7 @@ namespace System.Net.Http while (remaining.Length > 0) { - int frameSize = Math.Min(remaining.Length, FrameHeader.MaxLength); + int frameSize = Math.Min(remaining.Length, FrameHeader.MaxPayloadLength); // Once credit had been granted, we want to actually consume those bytes. frameSize = await _connectionWindow.RequestCreditAsync(frameSize, cancellationToken).ConfigureAwait(false); @@ -1387,14 +1342,8 @@ namespace System.Net.Http throw; } - FrameHeader frameHeader = new FrameHeader(current.Length, FrameType.Data, FrameFlags.None, streamId); - frameHeader.WriteTo(writeBuffer); - writeBuffer = writeBuffer.Slice(FrameHeader.Size); - - current.CopyTo(writeBuffer); - writeBuffer = writeBuffer.Slice(current.Length); - - Debug.Assert(writeBuffer.Length == 0); + FrameHeader.WriteTo(writeBuffer.Span, current.Length, FrameType.Data, FrameFlags.None, streamId); + current.CopyTo(writeBuffer.Slice(FrameHeader.Size)); FinishWrite(FlushTiming.Eventually); // no need to flush, as the request content may do so explicitly, or worst case we'll do so as part of the end data frame } @@ -1405,8 +1354,7 @@ namespace System.Net.Http Memory writeBuffer = await StartWriteAsync(FrameHeader.Size).ConfigureAwait(false); if (NetEventSource.IsEnabled) Trace(streamId, "Started writing."); - FrameHeader frameHeader = new FrameHeader(0, FrameType.Data, FrameFlags.EndStream, streamId); - frameHeader.WriteTo(writeBuffer); + FrameHeader.WriteTo(writeBuffer.Span, 0, FrameType.Data, FrameFlags.EndStream, streamId); FinishWrite(FlushTiming.AfterPendingWrites); // finished sending request body, so flush soon (but ok to wait for pending packets) } @@ -1419,11 +1367,8 @@ namespace System.Net.Http Memory writeBuffer = await StartWriteAsync(FrameHeader.Size + FrameHeader.WindowUpdateLength).ConfigureAwait(false); if (NetEventSource.IsEnabled) Trace(streamId, $"Started writing. {nameof(amount)}={amount}"); - FrameHeader frameHeader = new FrameHeader(FrameHeader.WindowUpdateLength, FrameType.WindowUpdate, FrameFlags.None, streamId); - frameHeader.WriteTo(writeBuffer); - writeBuffer = writeBuffer.Slice(FrameHeader.Size); - - BinaryPrimitives.WriteInt32BigEndian(writeBuffer.Span, amount); + FrameHeader.WriteTo(writeBuffer.Span, FrameHeader.WindowUpdateLength, FrameType.WindowUpdate, FrameFlags.None, streamId); + BinaryPrimitives.WriteInt32BigEndian(writeBuffer.Span.Slice(FrameHeader.Size), amount); FinishWrite(FlushTiming.Now); // make sure window updates are seen as soon as possible } @@ -1452,15 +1397,6 @@ namespace System.Net.Http LogExceptions(SendWindowUpdateAsync(0, windowUpdateSize)); } - private void WriteFrameHeader(FrameHeader frameHeader) - { - if (NetEventSource.IsEnabled) Trace($"{frameHeader}"); - Debug.Assert(_outgoingBuffer.AvailableMemory.Length >= FrameHeader.Size); - - frameHeader.WriteTo(_outgoingBuffer.AvailableSpan); - _outgoingBuffer.Commit(FrameHeader.Size); - } - /// Abort all streams and cause further processing to fail. /// Exception causing Abort to be called. private void Abort(Exception abortException) @@ -1648,15 +1584,15 @@ namespace System.Net.Http Last = 10 } - private struct FrameHeader + private readonly struct FrameHeader { - public int Length; - public FrameType Type; - public FrameFlags Flags; - public int StreamId; + public readonly int PayloadLength; + public readonly FrameType Type; + public readonly FrameFlags Flags; + public readonly int StreamId; public const int Size = 9; - public const int MaxLength = 16384; + public const int MaxPayloadLength = 16384; public const int SettingLength = 6; // per setting (total SETTINGS length must be a multiple of this) public const int PriorityInfoLength = 5; // for both PRIORITY frame and priority info within HEADERS @@ -1665,11 +1601,11 @@ namespace System.Net.Http public const int RstStreamLength = 4; public const int GoAwayMinLength = 8; - public FrameHeader(int length, FrameType type, FrameFlags flags, int streamId) + public FrameHeader(int payloadLength, FrameType type, FrameFlags flags, int streamId) { Debug.Assert(streamId >= 0); - Length = length; + PayloadLength = payloadLength; Type = type; Flags = flags; StreamId = streamId; @@ -1685,36 +1621,31 @@ namespace System.Net.Http { Debug.Assert(buffer.Length >= Size); - return new FrameHeader( - (buffer[0] << 16) | (buffer[1] << 8) | buffer[2], - (FrameType)buffer[3], - (FrameFlags)buffer[4], - (int)((uint)((buffer[5] << 24) | (buffer[6] << 16) | (buffer[7] << 8) | buffer[8]) & 0x7FFFFFFF)); + FrameFlags flags = (FrameFlags)buffer[4]; // do first to avoid some bounds checks + int payloadLength = (buffer[0] << 16) | (buffer[1] << 8) | buffer[2]; + FrameType type = (FrameType)buffer[3]; + int streamId = (int)(BinaryPrimitives.ReadUInt32BigEndian(buffer.Slice(5)) & 0x7FFFFFFF); + + return new FrameHeader(payloadLength, type, flags, streamId); } - public void WriteTo(Span buffer) + public static void WriteTo(Span destination, int payloadLength, FrameType type, FrameFlags flags, int streamId) { - Debug.Assert(buffer.Length >= Size); - Debug.Assert(Type <= FrameType.Last); - Debug.Assert((Flags & FrameFlags.ValidBits) == Flags); - Debug.Assert(Length <= MaxLength); + Debug.Assert(destination.Length >= Size); + Debug.Assert(type <= FrameType.Last); + Debug.Assert((flags & FrameFlags.ValidBits) == flags); + Debug.Assert((uint)payloadLength <= MaxPayloadLength); - buffer[0] = (byte)((Length & 0x00FF0000) >> 16); - buffer[1] = (byte)((Length & 0x0000FF00) >> 8); - buffer[2] = (byte)(Length & 0x000000FF); - - buffer[3] = (byte)Type; - buffer[4] = (byte)Flags; - - buffer[5] = (byte)((StreamId & 0xFF000000) >> 24); - buffer[6] = (byte)((StreamId & 0x00FF0000) >> 16); - buffer[7] = (byte)((StreamId & 0x0000FF00) >> 8); - buffer[8] = (byte)(StreamId & 0x000000FF); + // This ordering helps eliminate bounds checks. + BinaryPrimitives.WriteInt32BigEndian(destination.Slice(5), streamId); + destination[4] = (byte)flags; + destination[0] = (byte)((payloadLength & 0x00FF0000) >> 16); + destination[1] = (byte)((payloadLength & 0x0000FF00) >> 8); + destination[2] = (byte)(payloadLength & 0x000000FF); + destination[3] = (byte)type; } - public void WriteTo(Memory buffer) => WriteTo(buffer.Span); - - public override string ToString() => $"StreamId={StreamId}; Type={Type}; Flags={Flags}; Length={Length}"; // Description for diagnostic purposes + public override string ToString() => $"StreamId={StreamId}; Type={Type}; Flags={Flags}; PayloadLength={PayloadLength}"; // Description for diagnostic purposes } [Flags] @@ -1835,26 +1766,6 @@ namespace System.Net.Http } } - private Http2Stream AddStream(int streamId, HttpRequestMessage request) - { - lock (SyncObject) - { - if (_disposed || _lastStreamId != -1) - { - // The connection is shutting down. - // Throw a retryable request exception. This will cause retry logic to kick in - // and perform another connection attempt. The user should never see this exception. - throw GetShutdownException(); - } - - Http2Stream http2Stream = new Http2Stream(request, this, streamId, _initialWindowSize); - - _httpStreams.Add(streamId, http2Stream); - - return http2Stream; - } - } - private void RemoveStream(Http2Stream http2Stream) { if (NetEventSource.IsEnabled) Trace(http2Stream.StreamId, ""); @@ -1862,27 +1773,25 @@ namespace System.Net.Http lock (SyncObject) { - if (!_httpStreams.Remove(http2Stream.StreamId, out Http2Stream? removed)) + if (!_httpStreams.Remove(http2Stream.StreamId)) { Debug.Fail($"Stream {http2Stream.StreamId} not found in dictionary during RemoveStream???"); return; } - _concurrentStreams.AdjustCredit(1); - - Debug.Assert(removed == http2Stream, "_httpStreams.TryRemove returned unexpected stream"); - if (_httpStreams.Count == 0) { // If this was last pending request, get timestamp so we can monitor idle time. _idleSinceTickCount = Environment.TickCount64; - } - if (_disposed || _lastStreamId != -1) - { - CheckForShutdown(); + if (_disposed || _lastStreamId != -1) + { + CheckForShutdown(); + } } } + + _concurrentStreams.AdjustCredit(1); } public sealed override string ToString() => $"{nameof(Http2Connection)}({_pool})"; // Description for diagnostic purposes @@ -1898,5 +1807,20 @@ namespace System.Net.Http memberName, // method name message); // message + [DoesNotReturn] + private static void ThrowRetry(string message, Exception innerException) => + throw new HttpRequestException(message, innerException, allowRetry: RequestRetryType.RetryOnSameOrNextProxy); + + [DoesNotReturn] + private static void ThrowRequestAborted(Exception? innerException = null) => + throw new IOException(SR.net_http_request_aborted, innerException); + + [DoesNotReturn] + private static void ThrowProtocolError() => + ThrowProtocolError(Http2ProtocolErrorCode.ProtocolError); + + [DoesNotReturn] + private static void ThrowProtocolError(Http2ProtocolErrorCode errorCode) => + throw new Http2ConnectionException(errorCode); } } diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http2Stream.cs b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http2Stream.cs index 6c66a67..146a48d 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http2Stream.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http2Stream.cs @@ -18,7 +18,7 @@ namespace System.Net.Http { internal sealed partial class Http2Connection { - private sealed class Http2Stream : IValueTaskSource, IHttpTrace, IHttpHeadersHandler + private sealed class Http2Stream : IValueTaskSource, IHttpHeadersHandler, IHttpTrace { private const int InitialStreamBufferSize = #if DEBUG @@ -27,14 +27,13 @@ namespace System.Net.Http 1024; #endif - private static readonly byte[] s_statusHeaderName = Encoding.ASCII.GetBytes(":status"); + private static ReadOnlySpan StatusHeaderName => new byte[] { (byte)':', (byte)'s', (byte)'t', (byte)'a', (byte)'t', (byte)'u', (byte)'s' }; private readonly Http2Connection _connection; - private readonly int _streamId; private readonly HttpRequestMessage _request; private HttpResponseMessage? _response; /// Stores any trailers received after returning the response content to the caller. - private List>? _trailers; + private HttpResponseHeaders? _trailers; private ArrayBuffer _responseBuffer; // mutable struct, do not make this readonly private int _pendingWindowUpdate; @@ -81,9 +80,6 @@ namespace System.Net.Http private readonly CancellationTokenSource? _requestBodyCancellationSource; - // This is a linked token combining the above source and the user-supplied token to SendRequestBodyAsync - private CancellationToken _requestBodyCancellationToken; - private readonly TaskCompletionSource? _expect100ContinueWaiter; private int _headerBudgetRemaining; @@ -93,11 +89,10 @@ namespace System.Net.Http // See comment on ConnectionWindowThreshold. private const int StreamWindowThreshold = StreamWindowSize / 8; - public Http2Stream(HttpRequestMessage request, Http2Connection connection, int streamId, int initialWindowSize) + public Http2Stream(HttpRequestMessage request, Http2Connection connection, int initialWindowSize) { _request = request; _connection = connection; - _streamId = streamId; _requestCompletionState = StreamCompletionState.InProgress; _responseCompletionState = StreamCompletionState.InProgress; @@ -134,12 +129,19 @@ namespace System.Net.Http } } + _response = new HttpResponseMessage() + { + Version = HttpVersion.Version20, + RequestMessage = _request, + Content = new HttpConnectionResponseContent() + }; + if (NetEventSource.IsEnabled) Trace($"{request}, {nameof(initialWindowSize)}={initialWindowSize}"); } private object SyncObject => this; // this isn't handed out to code that may lock on it - public int StreamId => _streamId; + public int StreamId { get; set; } public HttpResponseMessage GetAndClearResponse() { @@ -162,28 +164,44 @@ namespace System.Net.Http } if (NetEventSource.IsEnabled) Trace($"{_request.Content}"); - Debug.Assert(_requestBodyCancellationSource != null); - // Create a linked cancellation token source so that we can cancel the request in the event of receiving RST_STREAM - // and similiar situations where we need to cancel the request body (see Cancel method). - _requestBodyCancellationToken = cancellationToken.CanBeCanceled ? - CancellationTokenSource.CreateLinkedTokenSource(cancellationToken, _requestBodyCancellationSource.Token).Token : - _requestBodyCancellationSource.Token; - + // Cancel the request body sending if cancellation is requested on the supplied cancellation token. + // Normally we might create a linked token, but once cancellation is requested, we can't recover anyway, + // so it's fine to cancel the source representing the whole request body, and doing so allows us to avoid + // creating another CTS instance and the associated nodes inside of it. With this, cancellation will be + // requested on _requestBodyCancellationSource when we need to cancel the request stream for any reason, + // such as receiving an RST_STREAM or when the passed in token has cancellation requested. However, to + // avoid unnecessarily registering with the cancellation token unless we have to, we wait to do so until + // either we know we need to do a Expect: 100-continue send or until we know that the copying of our + // content completed asynchronously. + CancellationTokenRegistration linkedRegistration = default; try { bool sendRequestContent = true; if (_expect100ContinueWaiter != null) { - sendRequestContent = await WaitFor100ContinueAsync(_requestBodyCancellationToken).ConfigureAwait(false); + linkedRegistration = RegisterRequestBodyCancellation(cancellationToken); + sendRequestContent = await WaitFor100ContinueAsync(_requestBodyCancellationSource.Token).ConfigureAwait(false); } if (sendRequestContent) { - using (Http2WriteStream writeStream = new Http2WriteStream(this)) + using var writeStream = new Http2WriteStream(this); + + ValueTask vt = _request.Content.InternalCopyToAsync(writeStream, context: null, _requestBodyCancellationSource.Token); + if (vt.IsCompleted) { - await _request.Content.InternalCopyToAsync(writeStream, null, _requestBodyCancellationToken).ConfigureAwait(false); + vt.GetAwaiter().GetResult(); + } + else + { + if (linkedRegistration.Equals(default)) + { + linkedRegistration = RegisterRequestBodyCancellation(cancellationToken); + } + + await vt.ConfigureAwait(false); } } @@ -192,9 +210,7 @@ namespace System.Net.Http catch (Exception e) { if (NetEventSource.IsEnabled) Trace($"Failed to send request body: {e}"); - - bool signalWaiter = false; - bool sendReset = false; + bool signalWaiter; Debug.Assert(!Monitor.IsEntered(SyncObject)); lock (SyncObject) @@ -213,19 +229,15 @@ namespace System.Net.Http } // This should not cause RST_STREAM to be sent because the request is still marked as in progress. + bool sendReset; (signalWaiter, sendReset) = CancelResponseBody(); Debug.Assert(!sendReset); _requestCompletionState = StreamCompletionState.Failed; - sendReset = true; Complete(); } - if (sendReset) - { - SendReset(); - } - + SendReset(); if (signalWaiter) { _waitSource.SetResult(true); @@ -233,6 +245,10 @@ namespace System.Net.Http throw; } + finally + { + linkedRegistration.Dispose(); + } // New scope here to avoid variable name conflict on "sendReset" { @@ -241,21 +257,14 @@ namespace System.Net.Http lock (SyncObject) { Debug.Assert(_requestCompletionState == StreamCompletionState.InProgress, $"Request already completed with state={_requestCompletionState}"); - _requestCompletionState = StreamCompletionState.Completed; - if (_responseCompletionState == StreamCompletionState.Failed) + + if (_responseCompletionState != StreamCompletionState.InProgress) { // Note, we can reach this point if the response stream failed but cancellation didn't propagate before we finished. - sendReset = true; + sendReset = _responseCompletionState == StreamCompletionState.Failed; Complete(); } - else - { - if (_responseCompletionState == StreamCompletionState.Completed) - { - Complete(); - } - } } if (sendReset) @@ -266,7 +275,7 @@ namespace System.Net.Http { // Send EndStream asynchronously and without cancellation. // If this fails, it means that the connection is aborting and we will be reset. - _connection.LogExceptions(_connection.SendEndStreamAsync(_streamId)); + _connection.LogExceptions(_connection.SendEndStreamAsync(StreamId)); } } } @@ -277,29 +286,31 @@ namespace System.Net.Http // If we get response status >= 300, we will not send the request body. public async ValueTask WaitFor100ContinueAsync(CancellationToken cancellationToken) { - Debug.Assert(_request.Content != null); + Debug.Assert(_request?.Content != null); if (NetEventSource.IsEnabled) Trace($"Waiting to send request body content for 100-Continue."); - // use TCS created in constructor. It will complete when one of two things occurs: - // 1. if a timer fires before we receive the relevant response from the server. - // 2. if we receive the relevant response from the server before a timer fires. - // In the first case, we could run this continuation synchronously, but in the latter, we shouldn't, - // as we could end up starting the body copy operation on the main event loop thread, which could - // then starve the processing of other requests. So, we make the TCS RunContinuationsAsynchronously. - bool sendRequestContent; + // Use TCS created in constructor. It will complete when one of three things occurs: + // 1. we receive the relevant response from the server. + // 2. the timer fires before we receive the relevant response from the server. + // 3. cancellation is requested before we receive the relevant response from the server. + // We need to run the continuation asynchronously for cases 1 and 3 (for 1 so that we don't starve the body copy operation, and + // for 3 so that we don't run a lot of work as part of code calling Cancel), so the TCS is created to run continuations asynchronously. + // We await the created Timer's disposal so that we ensure any work associated with it has quiesced prior to this method + // returning, just in case this object is pooled and potentially reused for another operation in the future. TaskCompletionSource waiter = _expect100ContinueWaiter!; - using (var expect100Timer = new Timer(s => + using (cancellationToken.UnsafeRegister(s => ((TaskCompletionSource)s!).TrySetResult(false), waiter)) + await using (new Timer(s => { var thisRef = (Http2Stream)s!; if (NetEventSource.IsEnabled) thisRef.Trace($"100-Continue timer expired."); thisRef._expect100ContinueWaiter?.TrySetResult(true); - }, this, _connection._pool.Settings._expect100ContinueTimeout, Timeout.InfiniteTimeSpan)) + }, this, _connection._pool.Settings._expect100ContinueTimeout, Timeout.InfiniteTimeSpan).ConfigureAwait(false)) { - sendRequestContent = await waiter.Task.ConfigureAwait(false); - // By now, either we got a response from the server or the timer expired. + bool shouldSendContent = await waiter.Task.ConfigureAwait(false); + // By now, either we got a response from the server or the timer expired or cancellation was requested. + CancellationHelper.ThrowIfCancellationRequested(cancellationToken); + return shouldSendContent; } - - return sendRequestContent; } private void SendReset() @@ -315,7 +326,7 @@ namespace System.Net.Http // Don't send a RST_STREAM if we've already received one from the server. if (_resetException == null) { - _connection.LogExceptions(_connection.SendRstStreamAsync(_streamId, Http2ProtocolErrorCode.Cancel)); + _connection.LogExceptions(_connection.SendRstStreamAsync(StreamId, Http2ProtocolErrorCode.Cancel)); } } @@ -357,11 +368,8 @@ namespace System.Net.Http (signalWaiter, sendReset) = CancelResponseBody(); } - if (requestBodyCancellationSource != null) - { - // When cancellation propagates, SendRequestBodyAsync will set _requestCompletionState to Failed - requestBodyCancellationSource.Cancel(); - } + // When cancellation propagates, SendRequestBodyAsync will set _requestCompletionState to Failed + requestBodyCancellationSource?.Cancel(); if (sendReset) { @@ -436,7 +444,7 @@ namespace System.Net.Http public void OnHeader(ReadOnlySpan name, ReadOnlySpan value) { if (NetEventSource.IsEnabled) Trace($"{Encoding.ASCII.GetString(name)}: {Encoding.ASCII.GetString(value)}"); - Debug.Assert(name != null && name.Length > 0); + Debug.Assert(name.Length > 0); _headerBudgetRemaining -= name.Length + value.Length; if (_headerBudgetRemaining < 0) @@ -462,7 +470,7 @@ namespace System.Net.Http throw new HttpRequestException(SR.net_http_invalid_response_pseudo_header_in_trailer); } - if (name.SequenceEqual(s_statusHeaderName)) + if (name.SequenceEqual(StatusHeaderName)) { if (_responseProtocolState != ResponseProtocolState.ExpectingStatus) { @@ -471,13 +479,8 @@ namespace System.Net.Http } int statusValue = ParseStatusCode(value); - _response = new HttpResponseMessage() - { - Version = HttpVersion.Version20, - RequestMessage = _request, - Content = new HttpConnectionResponseContent(), - StatusCode = (HttpStatusCode)statusValue - }; + Debug.Assert(_response != null); + _response.StatusCode = (HttpStatusCode)statusValue; if (statusValue < 200) { @@ -537,7 +540,7 @@ namespace System.Net.Http { Debug.Assert(_trailers != null); string headerValue = descriptor.GetHeaderValue(value); - _trailers.Add(KeyValuePair.Create((descriptor.HeaderType & HttpHeaderType.Request) == HttpHeaderType.Request ? descriptor.AsCustomHeader() : descriptor, headerValue)); + _trailers.TryAddWithoutValidation((descriptor.HeaderType & HttpHeaderType.Request) == HttpHeaderType.Request ? descriptor.AsCustomHeader() : descriptor, headerValue); } else if ((descriptor.HeaderType & HttpHeaderType.Content) == HttpHeaderType.Content) { @@ -560,20 +563,20 @@ namespace System.Net.Http Debug.Assert(!Monitor.IsEntered(SyncObject)); lock (SyncObject) { - if (_responseProtocolState == ResponseProtocolState.Aborted) + switch (_responseProtocolState) { - return; - } - - if (_responseProtocolState != ResponseProtocolState.ExpectingStatus && _responseProtocolState != ResponseProtocolState.ExpectingData) - { - throw new Http2ConnectionException(Http2ProtocolErrorCode.ProtocolError); - } - - if (_responseProtocolState == ResponseProtocolState.ExpectingData) - { - _responseProtocolState = ResponseProtocolState.ExpectingTrailingHeaders; - _trailers ??= new List>(); + case ResponseProtocolState.ExpectingStatus: + case ResponseProtocolState.Aborted: + break; + + case ResponseProtocolState.ExpectingData: + _responseProtocolState = ResponseProtocolState.ExpectingTrailingHeaders; + _trailers ??= new HttpResponseHeaders(containsTrailingHeaders: true); + break; + + default: + ThrowProtocolError(); + break; } } } @@ -584,45 +587,38 @@ namespace System.Net.Http bool signalWaiter; lock (SyncObject) { - if (_responseProtocolState == ResponseProtocolState.Aborted) + switch (_responseProtocolState) { - return; - } + case ResponseProtocolState.Aborted: + return; - if (_responseProtocolState != ResponseProtocolState.ExpectingHeaders && _responseProtocolState != ResponseProtocolState.ExpectingTrailingHeaders && _responseProtocolState != ResponseProtocolState.ExpectingIgnoredHeaders) - { - throw new Http2ConnectionException(Http2ProtocolErrorCode.ProtocolError); - } + case ResponseProtocolState.ExpectingHeaders: + _responseProtocolState = endStream ? ResponseProtocolState.Complete : ResponseProtocolState.ExpectingData; + break; - if (_responseProtocolState == ResponseProtocolState.ExpectingHeaders) - { - _responseProtocolState = endStream ? ResponseProtocolState.Complete : ResponseProtocolState.ExpectingData; - } - else if (_responseProtocolState == ResponseProtocolState.ExpectingTrailingHeaders) - { - if (!endStream) - { - if (NetEventSource.IsEnabled) Trace("Trailing headers received without endStream"); - throw new Http2ConnectionException(Http2ProtocolErrorCode.ProtocolError); - } + case ResponseProtocolState.ExpectingTrailingHeaders: + if (!endStream) + { + if (NetEventSource.IsEnabled) Trace("Trailing headers received without endStream"); + ThrowProtocolError(); + } + _responseProtocolState = ResponseProtocolState.Complete; + break; - _responseProtocolState = ResponseProtocolState.Complete; - } - else if (_responseProtocolState == ResponseProtocolState.ExpectingIgnoredHeaders) - { - if (endStream) - { - // we should not get endStream while processing 1xx response. - throw new Http2ConnectionException(Http2ProtocolErrorCode.ProtocolError); - } + case ResponseProtocolState.ExpectingIgnoredHeaders: + if (endStream) + { + // we should not get endStream while processing 1xx response. + ThrowProtocolError(); + } - _responseProtocolState = ResponseProtocolState.ExpectingStatus; - // We should wait for final response before signaling to waiter. - return; - } - else - { - _responseProtocolState = ResponseProtocolState.ExpectingData; + // We should wait for final response before signaling to waiter. + _responseProtocolState = ResponseProtocolState.ExpectingStatus; + return; + + default: + ThrowProtocolError(); + break; } if (endStream) @@ -656,21 +652,24 @@ namespace System.Net.Http bool signalWaiter; lock (SyncObject) { - if (_responseProtocolState == ResponseProtocolState.Aborted) + switch (_responseProtocolState) { - return; - } + case ResponseProtocolState.ExpectingData: + break; - if (_responseProtocolState != ResponseProtocolState.ExpectingData) - { - // Flow control messages are not valid in this state. - throw new Http2ConnectionException(Http2ProtocolErrorCode.ProtocolError); + case ResponseProtocolState.Aborted: + return; + + default: + // Flow control messages are not valid in this state. + ThrowProtocolError(); + break; } if (_responseBuffer.ActiveLength + buffer.Length > StreamWindowSize) { // Window size exceeded. - throw new Http2ConnectionException(Http2ProtocolErrorCode.FlowControlError); + ThrowProtocolError(Http2ProtocolErrorCode.FlowControlError); } _responseBuffer.EnsureAvailableSpace(buffer.Length); @@ -782,19 +781,19 @@ namespace System.Net.Http { Debug.Assert(Monitor.IsEntered(SyncObject)); - if (_resetException != null) + if (_resetException is Exception resetException) { if (_canRetry) { - throw new HttpRequestException(SR.net_http_request_aborted, _resetException, allowRetry: RequestRetryType.RetryOnSameOrNextProxy); + ThrowRetry(SR.net_http_request_aborted, resetException); } - throw new IOException(SR.net_http_request_aborted, _resetException); + ThrowRequestAborted(resetException); } if (_responseProtocolState == ResponseProtocolState.Aborted) { - throw new IOException(SR.net_http_request_aborted); + ThrowRequestAborted(); } } @@ -856,7 +855,7 @@ namespace System.Net.Http { // If there are any trailers, copy them over to the response. Normally this would be handled by // the response stream hitting EOF, but if there is no response body, we do it here. - CopyTrailersToResponseMessage(_response); + MoveTrailersToResponseMessage(_response); responseContent.SetStream(EmptyReadStream.Instance); } else @@ -892,7 +891,7 @@ namespace System.Net.Http int windowUpdateSize = _pendingWindowUpdate; _pendingWindowUpdate = 0; - _connection.LogExceptions(_connection.SendWindowUpdateAsync(_streamId, windowUpdateSize)); + _connection.LogExceptions(_connection.SendWindowUpdateAsync(StreamId, windowUpdateSize)); } private (bool wait, int bytesRead) TryReadFromBuffer(Span buffer, bool partOfSyncRead = false) @@ -951,7 +950,7 @@ namespace System.Net.Http else { // We've hit EOF. Pull in from the Http2Stream any trailers that were temporarily stored there. - CopyTrailersToResponseMessage(responseMessage); + MoveTrailersToResponseMessage(responseMessage); } return bytesRead; @@ -980,7 +979,7 @@ namespace System.Net.Http else { // We've hit EOF. Pull in from the Http2Stream any trailers that were temporarily stored there. - CopyTrailersToResponseMessage(responseMessage); + MoveTrailersToResponseMessage(responseMessage); } return bytesRead; @@ -991,7 +990,7 @@ namespace System.Net.Http byte[] buffer = ArrayPool.Shared.Rent(bufferSize); try { - // Generallly the same logic as in ReadData, but wrapped in a loop where every read segment is written to the destination. + // Generally the same logic as in ReadData, but wrapped in a loop where every read segment is written to the destination. while (true) { (bool wait, int bytesRead) = TryReadFromBuffer(buffer, partOfSyncRead: true); @@ -1011,7 +1010,7 @@ namespace System.Net.Http else { // We've hit EOF. Pull in from the Http2Stream any trailers that were temporarily stored there. - CopyTrailersToResponseMessage(responseMessage); + MoveTrailersToResponseMessage(responseMessage); return; } } @@ -1027,7 +1026,7 @@ namespace System.Net.Http byte[] buffer = ArrayPool.Shared.Rent(bufferSize); try { - // Generallly the same logic as in ReadDataAsync, but wrapped in a loop where every read segment is written to the destination. + // Generally the same logic as in ReadDataAsync, but wrapped in a loop where every read segment is written to the destination. while (true) { (bool wait, int bytesRead) = TryReadFromBuffer(buffer); @@ -1047,7 +1046,7 @@ namespace System.Net.Http else { // We've hit EOF. Pull in from the Http2Stream any trailers that were temporarily stored there. - CopyTrailersToResponseMessage(responseMessage); + MoveTrailersToResponseMessage(responseMessage); return; } } @@ -1058,36 +1057,24 @@ namespace System.Net.Http } } - private void CopyTrailersToResponseMessage(HttpResponseMessage responseMessage) + private void MoveTrailersToResponseMessage(HttpResponseMessage responseMessage) { - if (_trailers != null && _trailers.Count > 0) + if (_trailers != null) { - foreach (KeyValuePair trailer in _trailers) - { - responseMessage.TrailingHeaders.TryAddWithoutValidation(trailer.Key, trailer.Value); - } - _trailers.Clear(); + responseMessage.StoreReceivedTrailingHeaders(_trailers); } } private async ValueTask SendDataAsync(ReadOnlyMemory buffer, CancellationToken cancellationToken) { + Debug.Assert(_requestBodyCancellationSource != null); + // Deal with [ActiveIssue("https://github.com/dotnet/runtime/issues/17492")] // Custom HttpContent classes do not get passed the cancellationToken. // So, inject the expected CancellationToken here, to ensure we can cancel the request body send if needed. - CancellationTokenSource? customCancellationSource = null; - if (!cancellationToken.CanBeCanceled) - { - cancellationToken = _requestBodyCancellationToken; - } - else if (cancellationToken != _requestBodyCancellationToken) - { - // User passed a custom CancellationToken. - // We can't tell if it includes our Token or not, so assume it doesn't. - customCancellationSource = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken, _requestBodyCancellationSource!.Token); - cancellationToken = customCancellationSource.Token; - } - + CancellationTokenRegistration linkedRegistration = cancellationToken.CanBeCanceled && cancellationToken != _requestBodyCancellationSource.Token ? + RegisterRequestBodyCancellation(cancellationToken) : + default; try { while (buffer.Length > 0) @@ -1104,12 +1091,12 @@ namespace System.Net.Http { if (_creditWaiter is null) { - _creditWaiter = new CancelableCreditWaiter(SyncObject, cancellationToken); + _creditWaiter = new CancelableCreditWaiter(SyncObject, _requestBodyCancellationSource.Token); } else { Debug.Assert(!_creditWaiter.IsPending); - _creditWaiter.ResetForAwait(cancellationToken); + _creditWaiter.ResetForAwait(_requestBodyCancellationSource.Token); } _creditWaiter.Amount = buffer.Length; } @@ -1125,12 +1112,12 @@ namespace System.Net.Http ReadOnlyMemory current; (current, buffer) = SplitBuffer(buffer, sendSize); - await _connection.SendStreamDataAsync(_streamId, current, cancellationToken).ConfigureAwait(false); + await _connection.SendStreamDataAsync(StreamId, current, _requestBodyCancellationSource.Token).ConfigureAwait(false); } } finally { - customCancellationSource?.Dispose(); + linkedRegistration.Dispose(); } } @@ -1156,6 +1143,9 @@ namespace System.Net.Http _responseBuffer.Dispose(); } + private CancellationTokenRegistration RegisterRequestBodyCancellation(CancellationToken cancellationToken) => + cancellationToken.UnsafeRegister(s => ((CancellationTokenSource)s!).Cancel(), _requestBodyCancellationSource); + // This object is itself usable as a backing source for ValueTask. Since there's only ever one awaiter // for this object's state transitions at a time, we allow the object to be awaited directly. All functionality // associated with the implementation is just delegated to the ManualResetValueTaskSourceCore. @@ -1212,7 +1202,7 @@ namespace System.Net.Http { // Wake up the wait. It will then immediately check whether cancellation was requested and throw if it was. thisRef._waitSource.SetException(ExceptionDispatchInfo.SetCurrentStackTrace( - CancellationHelper.CreateOperationCanceledException(null, _waitSourceCancellation.Token))); + CancellationHelper.CreateOperationCanceledException(null, thisRef._waitSourceCancellation.Token))); } }, this); @@ -1228,7 +1218,7 @@ namespace System.Net.Http } public void Trace(string message, [CallerMemberName] string? memberName = null) => - _connection.Trace(_streamId, message, memberName); + _connection.Trace(StreamId, message, memberName); private enum ResponseProtocolState : byte { diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpConnectionBase.cs b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpConnectionBase.cs index f02edd7..a516450 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpConnectionBase.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpConnectionBase.cs @@ -2,6 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. +using System.Diagnostics; using System.Diagnostics.CodeAnalysis; using System.IO; using System.Net.Http.Headers; @@ -117,11 +118,27 @@ namespace System.Net.Http } /// Awaits a task, logging any resulting exceptions (which are otherwise ignored). - internal void LogExceptions(Task task) => - task.ContinueWith(t => + internal void LogExceptions(Task task) + { + if (task.IsCompleted) + { + if (task.IsFaulted) + { + LogFaulted(this, task); + } + } + else { - Exception? e = t.Exception?.InnerException; // Access Exception even if not tracing, to avoid TaskScheduler.UnobservedTaskException firing - if (NetEventSource.IsEnabled) Trace($"Exception from asynchronous processing: {e}"); - }, CancellationToken.None, TaskContinuationOptions.ExecuteSynchronously | TaskContinuationOptions.OnlyOnFaulted, TaskScheduler.Default); + task.ContinueWith((t, state) => LogFaulted((HttpConnectionBase)state!, t), this, + CancellationToken.None, TaskContinuationOptions.ExecuteSynchronously | TaskContinuationOptions.OnlyOnFaulted, TaskScheduler.Default); + } + + static void LogFaulted(HttpConnectionBase connection, Task task) + { + Debug.Assert(task.IsFaulted); + Exception? e = task.Exception!.InnerException; // Access Exception even if not tracing, to avoid TaskScheduler.UnobservedTaskException firing + if (NetEventSource.IsEnabled) connection.Trace($"Exception from asynchronous processing: {e}"); + } + } } } diff --git a/src/libraries/System.Net.Http/src/System/Threading/AsyncMutex.cs b/src/libraries/System.Net.Http/src/System/Threading/AsyncMutex.cs index 6c85c58..1065304 100644 --- a/src/libraries/System.Net.Http/src/System/Threading/AsyncMutex.cs +++ b/src/libraries/System.Net.Http/src/System/Threading/AsyncMutex.cs @@ -42,9 +42,11 @@ namespace System.Threading /// with an initial count of 0. /// private bool _lockedSemaphoreFull = true; - /// The head of the double-linked waiting queue. Waiters are dequeued from the head. - private Waiter? _waitersHead; - /// The tail of the double-linked waiting queue. Waiters are added at the tail. + /// The tail of the double-linked circular waiting queue. + /// + /// Waiters are added at the tail. + /// Items are dequeued from the head (tail.Prev). + /// private Waiter? _waitersTail; /// A pool of waiter objects that are ready to be reused. /// @@ -94,25 +96,25 @@ namespace System.Threading // Now that we're holding the lock, check to see whether the async lock is acquirable. if (!_lockedSemaphoreFull) { + // If we are able to acquire the lock, we're done. _lockedSemaphoreFull = true; return default; } + + // The lock couldn't be acquired. + // Add the waiter to the linked list of waiters. + if (_waitersTail is null) + { + w.Next = w.Prev = w; + } else { - // Add it to the linked list of waiters. - if (_waitersTail is null) - { - Debug.Assert(_waitersHead is null); - _waitersTail = _waitersHead = w; - } - else - { - Debug.Assert(_waitersHead != null); - w.Prev = _waitersTail; - _waitersTail.Next = w; - _waitersTail = w; - } + Debug.Assert(_waitersTail.Next != null && _waitersTail.Prev != null); + w.Next = _waitersTail; + w.Prev = _waitersTail.Prev; + w.Prev.Next = w.Next.Prev = w; } + _waitersTail = w; } // At this point the waiter was added to the list of waiters, so we want to @@ -136,14 +138,11 @@ namespace System.Threading lock (m.SyncObj) { - bool inList = w.Next != null || w.Prev != null || m._waitersHead == w; + bool inList = w.Next != null; if (inList) { - // The waiter was still in the list. - Debug.Assert( - m._waitersHead == w || - (m._waitersTail == w && w.Prev != null && w.Next is null) || - (w.Next != null && w.Prev != null)); + // The waiter is in the list. + Debug.Assert(w.Prev != null); // The gate counter was decremented when this waiter was added. We need // to undo that. Since the waiter is still in the list, the lock must @@ -156,33 +155,19 @@ namespace System.Threading // release it, they will appropriately update state. Interlocked.Increment(ref m._gate); - // Remove it from the list. - if (m._waitersHead == w && m._waitersTail == w) - { - // It's the only node in the list. - m._waitersHead = m._waitersTail = null; - } - else if (m._waitersTail == w) - { - // It's the most recently queued item in the list. - m._waitersTail = w.Prev; - Debug.Assert(m._waitersTail != null); - m._waitersTail.Next = null; - } - else if (m._waitersHead == w) + if (w.Next == w) { - // It's the next item to be removed from the list. - m._waitersHead = w.Next; - Debug.Assert(m._waitersHead != null); - m._waitersHead.Prev = null; + Debug.Assert(m._waitersTail == w); + m._waitersTail = null; } else { - // It's in the middle of the list. - Debug.Assert(w.Next != null); - Debug.Assert(w.Prev != null); - w.Next.Prev = w.Prev; + w.Next!.Prev = w.Prev; w.Prev.Next = w.Next; + if (m._waitersTail == w) + { + m._waitersTail = w.Next; + } } // Remove it from the list. @@ -217,34 +202,37 @@ namespace System.Threading void Contended() { Waiter? w; - lock (SyncObj) { Debug.Assert(_lockedSemaphoreFull); - // Wake up the next waiter in the list. - w = _waitersHead; - if (w != null) + w = _waitersTail; + if (w is null) { - // Remove the waiter. - _waitersHead = w.Next; - if (w.Next != null) + _lockedSemaphoreFull = false; + } + else + { + Debug.Assert(w.Next != null && w.Prev != null); + Debug.Assert(w.Next != w || w.Prev == w); + Debug.Assert(w.Prev != w || w.Next == w); + + if (w.Next == w) { - w.Next.Prev = null; + _waitersTail = null; } else { - Debug.Assert(_waitersTail == w); - _waitersTail = null; + w = w.Prev; // get the head + Debug.Assert(w.Next != null && w.Prev != null); + Debug.Assert(w.Next != w && w.Prev != w); + + w.Next.Prev = w.Prev; + w.Prev.Next = w.Next; } + w.Next = w.Prev = null; } - else - { - // There wasn't a waiter. Mark that the async lock is no longer full. - Debug.Assert(_waitersTail is null); - _lockedSemaphoreFull = false; - } } // Either there wasn't a waiter, or we got one and successfully removed it from the list, -- 2.7.4