HTTP/2 Request Cancellation (dotnet/corefx#35118)
authorMax Kerr <richard.kerr@microsoft.com>
Thu, 11 Apr 2019 18:23:45 +0000 (11:23 -0700)
committerGitHub <noreply@github.com>
Thu, 11 Apr 2019 18:23:45 +0000 (11:23 -0700)
HTTP/2 cancellation support, plus improvements to outgoing write buffering.

Commit migrated from https://github.com/dotnet/corefx/commit/d6e36e4ecb0d004de54eb7bdf0ed6cdcb042e3c9

src/libraries/System.Net.Http/src/System.Net.Http.csproj
src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/ArrayBuffer.cs
src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/CreditManager.cs
src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http2Connection.cs
src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http2Stream.cs
src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpConnectionPool.cs
src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/TaskCompletionSourceWithCancellation.cs [new file with mode: 0644]
src/libraries/System.Net.Http/tests/FunctionalTests/HttpClientHandlerTest.Cancellation.cs
src/libraries/System.Net.Http/tests/FunctionalTests/HttpClientHandlerTest.Http2.cs
src/libraries/System.Net.Http/tests/FunctionalTests/HttpClientHandlerTest.cs
src/libraries/System.Net.Http/tests/FunctionalTests/SocketsHttpHandlerTest.cs

index ca8b821..12c44f1 100644 (file)
     <Compile Include="System\Net\Http\SocketsHttpHandler\SocketsHttpHandler.cs" />
     <Compile Include="System\Net\Http\SocketsHttpHandler\RawConnectionStream.cs" />
     <Compile Include="System\Net\Http\SocketsHttpHandler\RedirectHandler.cs" />
+    <Compile Include="System\Net\Http\SocketsHttpHandler\TaskCompletionSourceWithCancellation.cs" />
     <Compile Include="$(CommonPath)\CoreLib\System\Collections\Concurrent\ConcurrentQueueSegment.cs">
       <Link>Common\CoreLib\System\Collections\Concurrent\ConcurrentQueueSegment.cs</Link>
     </Compile>
index 7d82ab4..74eee0d 100644 (file)
@@ -59,6 +59,8 @@ namespace System.Net.Http
         public Memory<byte> ActiveMemory => new Memory<byte>(_bytes, _activeStart, _availableStart - _activeStart);
         public Memory<byte> AvailableMemory => new Memory<byte>(_bytes, _availableStart, _bytes.Length - _availableStart);
 
+        public int Capacity => _bytes.Length;
+
         public void Discard(int byteCount)
         {
             Debug.Assert(byteCount <= ActiveSpan.Length, $"Expected {byteCount} <= {ActiveSpan.Length}");
index 6b5f922..3fd0da1 100644 (file)
@@ -4,6 +4,7 @@
 
 using System.Diagnostics;
 using System.Collections.Generic;
+using System.Threading;
 using System.Threading.Tasks;
 
 namespace System.Net.Http
@@ -13,7 +14,7 @@ namespace System.Net.Http
         private struct Waiter
         {
             public int Amount;
-            public TaskCompletionSource<int> TaskCompletionSource;
+            public TaskCompletionSourceWithCancellation<int> TaskCompletionSource;
         }
 
         private int _current;
@@ -37,7 +38,7 @@ namespace System.Net.Http
             }
         }
 
-        public ValueTask<int> RequestCreditAsync(int amount)
+        public ValueTask<int> RequestCreditAsync(int amount, CancellationToken cancellationToken)
         {
             lock (SyncObject)
             {
@@ -55,16 +56,21 @@ namespace System.Net.Http
                     return new ValueTask<int>(granted);
                 }
 
-                var tcs = new TaskCompletionSource<int>(TaskContinuationOptions.RunContinuationsAsynchronously);
+                // Uses RunContinuationsAsynchronously internally.
+                var tcs = new TaskCompletionSourceWithCancellation<int>();
 
                 if (_waiters == null)
                 {
                     _waiters = new Queue<Waiter>();
                 }
 
-                _waiters.Enqueue(new Waiter { Amount = amount, TaskCompletionSource = tcs });
+                Waiter waiter = new Waiter { Amount = amount, TaskCompletionSource = tcs };
 
-                return new ValueTask<int>(tcs.Task);
+                _waiters.Enqueue(waiter);
+
+                return new ValueTask<int>(cancellationToken.CanBeCanceled ?
+                                          tcs.WaitWithCancellationAsync(cancellationToken) :
+                                          tcs.Task);
             }
         }
 
@@ -92,8 +98,12 @@ namespace System.Net.Http
                     while (_current > 0 && _waiters.TryDequeue(out Waiter waiter))
                     {
                         int granted = Math.Min(waiter.Amount, _current);
-                        _current -= granted;
-                        waiter.TaskCompletionSource.SetResult(granted);
+
+                        // Ensure that we grant credit only if the task has not been canceled.
+                        if (waiter.TaskCompletionSource.TrySetResult(granted))
+                        {
+                            _current -= granted;
+                        }
                     }
                 }
             }
@@ -114,7 +124,7 @@ namespace System.Net.Http
                 {
                     while (_waiters.TryDequeue(out Waiter waiter))
                     {
-                        waiter.TaskCompletionSource.SetException(new ObjectDisposedException(nameof(CreditManager)));
+                        waiter.TaskCompletionSource.TrySetException(new ObjectDisposedException(nameof(CreditManager)));
                     }
                 }
             }
index c19ac03..b8e46de 100644 (file)
@@ -31,6 +31,7 @@ namespace System.Net.Http
         private readonly Dictionary<int, Http2Stream> _httpStreams;
 
         private readonly SemaphoreSlim _writerLock;
+        private readonly SemaphoreSlim _headerSerializationLock;
 
         private readonly CreditManager _connectionWindow;
         private readonly CreditManager _concurrentStreams;
@@ -41,9 +42,16 @@ namespace System.Net.Http
         private int _maxConcurrentStreams;
         private int _pendingWindowUpdate;
         private int _idleSinceTickCount;
+        private int _pendingWriters;
 
         private bool _disposed;
 
+        // If an in-progress write is canceled we need to be able to immediately
+        // report a cancellation to the user, but also block the connection until
+        // the write completes. We avoid actually canceling the write, as we would
+        // then have to close the whole connection.
+        private Task _inProgressWrite = null;
+
         private const int MaxStreamId = int.MaxValue;
 
         private static readonly byte[] s_http2ConnectionPreface = Encoding.ASCII.GetBytes("PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n");
@@ -65,6 +73,11 @@ namespace System.Net.Http
         // rather than just increase the threshold.
         private const int ConnectionWindowThreshold = ConnectionWindowSize / 8;
 
+        // When buffering outgoing writes, we will automatically buffer up to this number of bytes.
+        // Single writes that are larger than the buffer can cause the buffer to expand beyond
+        // this value, so this is not a hard maximum size.
+        private const int UnflushedOutgoingBufferSize = 32 * 1024;
+
         public Http2Connection(HttpConnectionPool pool, SslStream stream)
         {
             _pool = pool;
@@ -78,6 +91,7 @@ namespace System.Net.Http
             _httpStreams = new Dictionary<int, Http2Stream>();
 
             _writerLock = new SemaphoreSlim(1, 1);
+            _headerSerializationLock = new SemaphoreSlim(1, 1);
             _connectionWindow = new CreditManager(DefaultInitialWindowSize);
             _concurrentStreams = new CreditManager(int.MaxValue);
 
@@ -458,7 +472,7 @@ namespace System.Net.Http
 
                 // Send acknowledgement
                 // Don't wait for completion, which could happen asynchronously.
-                ValueTask ignored = SendSettingsAckAsync();
+                Task ignored = SendSettingsAckAsync();
             }
         }
 
@@ -527,7 +541,7 @@ namespace System.Net.Http
 
             // Send PING ACK
             // Don't wait for completion, which could happen asynchronously.
-            ValueTask ignored = SendPingAckAsync(_incomingBuffer.ActiveMemory.Slice(0, FrameHeader.PingLength));
+            Task ignored = SendPingAckAsync(_incomingBuffer.ActiveMemory.Slice(0, FrameHeader.PingLength));
 
             _incomingBuffer.Discard(frameHeader.Length);
         }
@@ -623,80 +637,116 @@ namespace System.Net.Http
             _incomingBuffer.Discard(frameHeader.Length);
         }
 
-        private async ValueTask AcquireWriteLockAsync()
+        private async Task StartWriteAsync(int writeBytes, CancellationToken cancellationToken = default)
         {
-            await _writerLock.WaitAsync().ConfigureAwait(false);
+            await AcquireWriteLockAsync(cancellationToken).ConfigureAwait(false);
 
-            // If the connection has been aborted, then fail now instead of trying to send more data.
-            if (IsAborted())
+            try
             {
-                throw new IOException(SR.net_http_invalid_response);
-            }
-        }
+                // If there is a pending write that was canceled while in progress, wait for it to complete.
+                if (_inProgressWrite != null)
+                {
+                    await _inProgressWrite.ConfigureAwait(false);
+                    _inProgressWrite = null;
+                }
 
-        private void ReleaseWriteLock()
-        {
-            // Currently, we always flush the write buffer before releasing the lock.
-            // If we change this in the future, we will need to revisit this assert.
-            Debug.Assert(_outgoingBuffer.ActiveMemory.IsEmpty);
+                int totalBufferLength = _outgoingBuffer.Capacity;
+                int activeBufferLength = _outgoingBuffer.ActiveSpan.Length;
 
-            _writerLock.Release();
+                if (totalBufferLength >= UnflushedOutgoingBufferSize &&
+                    writeBytes >= totalBufferLength - activeBufferLength &&
+                    activeBufferLength > 0)
+                {
+                    // If the buffer has already grown to 32k, does not have room for the next request,
+                    // and is non-empty, flush the current contents to the wire.
+                    await FlushOutgoingBytesAsync().ConfigureAwait(false);
+                }
+
+                _outgoingBuffer.EnsureAvailableSpace(writeBytes);
+            }
+            catch
+            {
+                _writerLock.Release();
+                throw;
+            }
         }
 
-        private async ValueTask SendSettingsAckAsync()
+        // This method handles flushing bytes to the wire. Writes here need to be atomic, so as to avoid
+        // killing the whole connection. Callers must hold the write lock, but can specify whether or not
+        // they want to release it.
+        private void FinishWrite(bool mustFlush)
         {
-            await AcquireWriteLockAsync().ConfigureAwait(false);
+            // We can't validate that we hold the semaphore, but we can at least validate that someone is
+            // holding it.
+            Debug.Assert(_writerLock.CurrentCount == 0);
+
             try
             {
-                _outgoingBuffer.EnsureAvailableSpace(FrameHeader.Size);
-                WriteFrameHeader(new FrameHeader(0, FrameType.Settings, FrameFlags.Ack, 0));
+                // We must flush if the caller requires it, or if there are no other pending writes.
+                if (mustFlush || _pendingWriters == 0)
+                {
+                    Debug.Assert(_inProgressWrite == null);
 
-                await FlushOutgoingBytesAsync().ConfigureAwait(false);
+                    _inProgressWrite = FlushOutgoingBytesAsync();
+                }
             }
             finally
             {
-                ReleaseWriteLock();
+                _writerLock.Release();
             }
         }
 
-        private async ValueTask SendPingAckAsync(ReadOnlyMemory<byte> pingContent)
+        private async Task AcquireWriteLockAsync(CancellationToken cancellationToken)
         {
-            Debug.Assert(pingContent.Length == FrameHeader.PingLength);
-
-            await AcquireWriteLockAsync().ConfigureAwait(false);
-            try
+            Task acquireLockTask = _writerLock.WaitAsync(cancellationToken);
+            if (!acquireLockTask.IsCompletedSuccessfully)
             {
-                _outgoingBuffer.EnsureAvailableSpace(FrameHeader.Size + FrameHeader.PingLength);
-                WriteFrameHeader(new FrameHeader(FrameHeader.PingLength, FrameType.Ping, FrameFlags.Ack, 0));
-                pingContent.CopyTo(_outgoingBuffer.AvailableMemory);
-                _outgoingBuffer.Commit(FrameHeader.PingLength);
-
-                await FlushOutgoingBytesAsync().ConfigureAwait(false);
+                Interlocked.Increment(ref _pendingWriters);
+                try
+                {
+                    await acquireLockTask.ConfigureAwait(false);
+                }
+                finally
+                {
+                    Interlocked.Decrement(ref _pendingWriters);
+                }
             }
-            finally
+
+            // If the connection has been aborted, then fail now instead of trying to send more data.
+            if (IsAborted())
             {
-                ReleaseWriteLock();
+                throw new IOException(SR.net_http_invalid_response);
             }
         }
 
-        private async Task SendRstStreamAsync(int streamId, Http2ProtocolErrorCode errorCode)
+        private async Task SendSettingsAckAsync()
         {
-            await AcquireWriteLockAsync().ConfigureAwait(false);
-            try
-            {
-                _outgoingBuffer.EnsureAvailableSpace(FrameHeader.Size + FrameHeader.RstStreamLength);
-                WriteFrameHeader(new FrameHeader(FrameHeader.RstStreamLength, FrameType.RstStream, FrameFlags.None, streamId));
+            await StartWriteAsync(FrameHeader.Size).ConfigureAwait(false);
+            WriteFrameHeader(new FrameHeader(0, FrameType.Settings, FrameFlags.Ack, 0));
 
-                BinaryPrimitives.WriteInt32BigEndian(_outgoingBuffer.AvailableSpan, (int)errorCode);
+            FinishWrite(mustFlush: true);
+        }
 
-                _outgoingBuffer.Commit(FrameHeader.RstStreamLength);
+        private async Task SendPingAckAsync(ReadOnlyMemory<byte> pingContent)
+        {
+            Debug.Assert(pingContent.Length == FrameHeader.PingLength);
 
-                await FlushOutgoingBytesAsync().ConfigureAwait(false);
-            }
-            finally
-            {
-                ReleaseWriteLock();
-            }
+            await StartWriteAsync(FrameHeader.Size + FrameHeader.PingLength).ConfigureAwait(false);
+            WriteFrameHeader(new FrameHeader(FrameHeader.PingLength, FrameType.Ping, FrameFlags.Ack, 0));
+            pingContent.CopyTo(_outgoingBuffer.AvailableMemory);
+            _outgoingBuffer.Commit(FrameHeader.PingLength);
+
+            FinishWrite(mustFlush: false);
+        }
+
+        private async Task SendRstStreamAsync(int streamId, Http2ProtocolErrorCode errorCode)
+        {
+            await StartWriteAsync(FrameHeader.Size + FrameHeader.RstStreamLength).ConfigureAwait(false);
+            WriteFrameHeader(new FrameHeader(FrameHeader.RstStreamLength, FrameType.RstStream, FrameFlags.None, streamId));
+            BinaryPrimitives.WriteInt32BigEndian(_outgoingBuffer.AvailableSpan, (int)errorCode);
+            _outgoingBuffer.Commit(FrameHeader.RstStreamLength);
+
+            FinishWrite(mustFlush: true);
         }
 
         private static (ReadOnlyMemory<byte> first, ReadOnlyMemory<byte> rest) SplitBuffer(ReadOnlyMemory<byte> buffer, int maxSize) =>
@@ -902,28 +952,35 @@ namespace System.Net.Http
             }
         }
 
-        private async ValueTask<Http2Stream> SendHeadersAsync(HttpRequestMessage request)
+        private async ValueTask<Http2Stream> SendHeadersAsync(HttpRequestMessage request, CancellationToken cancellationToken)
         {
             // Ensure we don't exceed the max concurrent streams setting.
-            await _concurrentStreams.RequestCreditAsync(1).ConfigureAwait(false);
+            await _concurrentStreams.RequestCreditAsync(1, cancellationToken).ConfigureAwait(false);
 
-            // Note, HEADERS and CONTINUATION frames must be together, so hold the writer lock across sending all of them.
-            // We also serialize usage of the header encoder and the header buffer this way.
-            // (If necessary, we could have a separate semaphore just for creating and encoding header blocks,
-            // and defer taking the actual _writerLock until we're ready to do the write below.)
-            await _writerLock.WaitAsync().ConfigureAwait(false);
+            // We serialize usage of the header encoder and the header buffer separately from the
+            // write lock
+            await _headerSerializationLock.WaitAsync(cancellationToken).ConfigureAwait(false);
 
-            Http2Stream http2Stream = AddStream(request);
-            int streamId = http2Stream.StreamId;
+            Http2Stream http2Stream = null;
 
             try
             {
+                http2Stream = AddStream(request);
+                int streamId = http2Stream.StreamId;
+
+                http2Stream = AddStream(request);
+                streamId = http2Stream.StreamId;
+
                 // Generate the entire header block, without framing, into the connection header buffer.
                 WriteHeaders(request);
 
                 ReadOnlyMemory<byte> remaining = _headerBuffer.ActiveMemory;
                 Debug.Assert(remaining.Length > 0);
 
+                // Calculate the total number of bytes we're going to use (content + headers).
+                int totalSize = remaining.Length + (remaining.Length / FrameHeader.MaxLength) * FrameHeader.Size +
+                                (remaining.Length % FrameHeader.MaxLength == 0 ? FrameHeader.Size : 0);
+
                 // Split into frames and send.
                 ReadOnlyMemory<byte> current;
                 (current, remaining) = SplitBuffer(remaining, FrameHeader.MaxLength);
@@ -932,42 +989,43 @@ namespace System.Net.Http
                     (remaining.Length == 0 ? FrameFlags.EndHeaders : FrameFlags.None) |
                     (request.Content == null ? FrameFlags.EndStream : FrameFlags.None);
 
-                _outgoingBuffer.EnsureAvailableSpace(FrameHeader.Size + current.Length);
+                // Note, HEADERS and CONTINUATION frames must be together, so hold the writer lock across sending all of them.
+                await StartWriteAsync(totalSize).ConfigureAwait(false);
+
                 WriteFrameHeader(new FrameHeader(current.Length, FrameType.Headers, flags, streamId));
                 current.CopyTo(_outgoingBuffer.AvailableMemory);
                 _outgoingBuffer.Commit(current.Length);
 
-                await FlushOutgoingBytesAsync().ConfigureAwait(false);
-
                 while (remaining.Length > 0)
                 {
                     (current, remaining) = SplitBuffer(remaining, FrameHeader.MaxLength);
 
                     flags = (remaining.Length == 0 ? FrameFlags.EndHeaders : FrameFlags.None);
 
-                    _outgoingBuffer.EnsureAvailableSpace(FrameHeader.Size + current.Length);
                     WriteFrameHeader(new FrameHeader(current.Length, FrameType.Continuation, flags, streamId));
                     current.CopyTo(_outgoingBuffer.AvailableMemory);
                     _outgoingBuffer.Commit(current.Length);
-
-                    await FlushOutgoingBytesAsync().ConfigureAwait(false);
                 }
+
+                // If this is not the end of the stream, we can put off flushing the buffer
+                // since we know that there are going to be data frames following.
+                FinishWrite(mustFlush: (flags & FrameFlags.EndStream) != 0);
             }
             catch
             {
-                http2Stream.Dispose();
+                http2Stream?.Dispose();
                 throw;
             }
             finally
             {
                 _headerBuffer.Discard(_headerBuffer.ActiveMemory.Length);
-                _writerLock.Release();
+                _headerSerializationLock.Release();
             }
 
             return http2Stream;
         }
 
-        private async ValueTask SendStreamDataAsync(int streamId, ReadOnlyMemory<byte> buffer)
+        private async Task SendStreamDataAsync(int streamId, ReadOnlyMemory<byte> buffer, CancellationToken cancellationToken)
         {
             ReadOnlyMemory<byte> remaining = buffer;
 
@@ -975,63 +1033,53 @@ namespace System.Net.Http
             {
                 int frameSize = Math.Min(remaining.Length, FrameHeader.MaxLength);
 
-                frameSize = await _connectionWindow.RequestCreditAsync(frameSize).ConfigureAwait(false);
+                // Once credit had been granted, we want to actually consume those bytes.
+                frameSize = await _connectionWindow.RequestCreditAsync(frameSize, cancellationToken).ConfigureAwait(false);
 
                 ReadOnlyMemory<byte> current;
                 (current, remaining) = SplitBuffer(remaining, frameSize);
 
-                await AcquireWriteLockAsync().ConfigureAwait(false);
+                // It's possible that a cancellation will occur while we wait for the write lock. In that case, we need to
+                // return the credit that we have acquired and don't plan to use.
                 try
                 {
-                    _outgoingBuffer.EnsureAvailableSpace(FrameHeader.Size + current.Length);
-                    WriteFrameHeader(new FrameHeader(current.Length, FrameType.Data, FrameFlags.None, streamId));
-                    current.CopyTo(_outgoingBuffer.AvailableMemory);
-                    _outgoingBuffer.Commit(current.Length);
-
-                    await FlushOutgoingBytesAsync().ConfigureAwait(false);
+                    await StartWriteAsync(FrameHeader.Size + current.Length, cancellationToken).ConfigureAwait(false);
                 }
-                finally
+                catch (OperationCanceledException)
                 {
-                    ReleaseWriteLock();
+                    _connectionWindow.AdjustCredit(frameSize);
+                    throw;
                 }
+
+                WriteFrameHeader(new FrameHeader(current.Length, FrameType.Data, FrameFlags.None, streamId));
+                current.CopyTo(_outgoingBuffer.AvailableMemory);
+                _outgoingBuffer.Commit(current.Length);
+
+                FinishWrite(mustFlush: false);
             }
         }
 
-        private async ValueTask SendEndStreamAsync(int streamId)
+        private async Task SendEndStreamAsync(int streamId)
         {
-            await AcquireWriteLockAsync().ConfigureAwait(false);
-            try
-            {
-                _outgoingBuffer.EnsureAvailableSpace(FrameHeader.Size);
-                WriteFrameHeader(new FrameHeader(0, FrameType.Data, FrameFlags.EndStream, streamId));
+            await StartWriteAsync(FrameHeader.Size).ConfigureAwait(false);
 
-                await FlushOutgoingBytesAsync().ConfigureAwait(false);
-            }
-            finally
-            {
-                ReleaseWriteLock();
-            }
+            WriteFrameHeader(new FrameHeader(0, FrameType.Data, FrameFlags.EndStream, streamId));
+
+            FinishWrite(mustFlush: true);
         }
 
-        private async ValueTask SendWindowUpdateAsync(int streamId, int amount)
+        private async Task SendWindowUpdateAsync(int streamId, int amount)
         {
             Debug.Assert(amount > 0);
 
-            await _writerLock.WaitAsync().ConfigureAwait(false);
-            try
-            {
-                _outgoingBuffer.EnsureAvailableSpace(FrameHeader.Size + FrameHeader.WindowUpdateLength);
+            // We update both the connection-level and stream-level windows at the same time
+            await StartWriteAsync(FrameHeader.Size + FrameHeader.WindowUpdateLength).ConfigureAwait(false);
 
-                WriteFrameHeader(new FrameHeader(FrameHeader.WindowUpdateLength, FrameType.WindowUpdate, FrameFlags.None, streamId));
-                BinaryPrimitives.WriteInt32BigEndian(_outgoingBuffer.AvailableSpan, amount);
-                _outgoingBuffer.Commit(FrameHeader.WindowUpdateLength);
+            WriteFrameHeader(new FrameHeader(FrameHeader.WindowUpdateLength, FrameType.WindowUpdate, FrameFlags.None, streamId));
+            BinaryPrimitives.WriteInt32BigEndian(_outgoingBuffer.AvailableSpan, amount);
+            _outgoingBuffer.Commit(FrameHeader.WindowUpdateLength);
 
-                await FlushOutgoingBytesAsync().ConfigureAwait(false);
-            }
-            finally
-            {
-                _writerLock.Release();
-            }
+            FinishWrite(mustFlush: true);
         }
 
         private void ExtendWindow(int amount)
@@ -1053,7 +1101,7 @@ namespace System.Net.Http
                 _pendingWindowUpdate = 0;
             }
 
-            ValueTask ignored = SendWindowUpdateAsync(0, windowUpdateSize);
+            Task ignored = SendWindowUpdateAsync(0, windowUpdateSize);
         }
 
         private void WriteFrameHeader(FrameHeader frameHeader)
@@ -1295,10 +1343,10 @@ namespace System.Net.Http
             try
             {
                 // Send headers
-                http2Stream = await SendHeadersAsync(request).ConfigureAwait(false);
+                http2Stream = await SendHeadersAsync(request, cancellationToken).ConfigureAwait(false);
 
                 // Send request body, if any
-                await http2Stream.SendRequestBodyAsync().ConfigureAwait(false);
+                await http2Stream.SendRequestBodyAsync(cancellationToken).ConfigureAwait(false);
 
                 // Wait for response headers to be read.
                 await http2Stream.ReadResponseHeadersAsync().ConfigureAwait(false);
@@ -1320,6 +1368,21 @@ namespace System.Net.Http
                     // ISSUE 31315: Determine if/how to expose HTTP2 error codes
                     throw new HttpRequestException(SR.net_http_client_execution_error, e);
                 }
+                else if (e is OperationCanceledException oce)
+                {
+                    // If the operation has been canceled after the stream was allocated an ID, send a RST_STREAM.
+                    if (http2Stream != null && http2Stream.StreamId != 0)
+                    {
+                        http2Stream.Cancel();
+                    }
+
+                    if (oce.CancellationToken == cancellationToken)
+                    {
+                        throw;
+                    }
+
+                    throw new OperationCanceledException(cancellationToken);
+                }
                 else
                 {
                     throw;
index a93acb9..c492381 100644 (file)
@@ -88,7 +88,7 @@ namespace System.Net.Http
             public HttpRequestMessage Request => _request;
             public HttpResponseMessage Response => _response;
 
-            public async Task SendRequestBodyAsync()
+            public async Task SendRequestBodyAsync(CancellationToken cancellationToken)
             {
                 // TODO: ISSUE 31312: Expect: 100-continue and early response handling
                 // Note that in an "early response" scenario, where we get a response before we've finished sending the request body
@@ -100,8 +100,11 @@ namespace System.Net.Http
                 {
                     using (Http2WriteStream writeStream = new Http2WriteStream(this))
                     {
-                        await _request.Content.CopyToAsync(writeStream).ConfigureAwait(false);
+                        await _request.Content.CopyToAsync(writeStream, null, cancellationToken).ConfigureAwait(false);
                     }
+
+                    // Don't wait for completion, which could happen asynchronously.
+                    Task ignored = _connection.SendEndStreamAsync(_streamId);
                 }
             }
 
@@ -364,7 +367,7 @@ namespace System.Net.Http
                 int windowUpdateSize = _pendingWindowUpdate;
                 _pendingWindowUpdate = 0;
 
-                ValueTask ignored = _connection.SendWindowUpdateAsync(_streamId, windowUpdateSize);
+                Task ignored = _connection.SendWindowUpdateAsync(_streamId, windowUpdateSize);
             }
 
             private (bool wait, int bytesRead) TryReadFromBuffer(Span<byte> buffer)
@@ -429,18 +432,18 @@ namespace System.Net.Http
                 return bytesRead;
             }
 
-            private async ValueTask SendDataAsync(ReadOnlyMemory<byte> buffer)
+            private async ValueTask SendDataAsync(ReadOnlyMemory<byte> buffer, CancellationToken cancellationToken)
             {
                 ReadOnlyMemory<byte> remaining = buffer;
 
                 while (remaining.Length > 0)
                 {
-                    int sendSize = await _streamWindow.RequestCreditAsync(remaining.Length).ConfigureAwait(false);
+                    int sendSize = await _streamWindow.RequestCreditAsync(remaining.Length, cancellationToken).ConfigureAwait(false);
 
                     ReadOnlyMemory<byte> current;
                     (current, remaining) = SplitBuffer(remaining, sendSize);
 
-                    await _connection.SendStreamDataAsync(_streamId, current).ConfigureAwait(false);
+                    await _connection.SendStreamDataAsync(_streamId, current, cancellationToken).ConfigureAwait(false);
                 }
             }
 
@@ -460,6 +463,25 @@ namespace System.Net.Http
                 }
             }
 
+            public void Cancel()
+            {
+                bool signalWaiter;
+                lock (SyncObject)
+                {
+                    Task ignored = _connection.SendRstStreamAsync(_streamId, Http2ProtocolErrorCode.Cancel);
+                    _state = StreamState.Aborted;
+
+                    signalWaiter = _hasWaiter;
+                    _hasWaiter = false;
+                }
+                if (signalWaiter)
+                {
+                    _waitSource.SetResult(true);
+                }
+
+                _connection.RemoveStream(this);
+            }
+
             // This object is itself usable as a backing source for ValueTask.  Since there's only ever one awaiter
             // for this object's state transitions at a time, we allow the object to be awaited directly. All functionality
             // associated with the implementation is just delegated to the ManualResetValueTaskSourceCore.
@@ -532,9 +554,6 @@ namespace System.Net.Http
                         return;
                     }
 
-                    // Don't wait for completion, which could happen asynchronously.
-                    ValueTask ignored = http2Stream._connection.SendEndStreamAsync(http2Stream.StreamId);
-
                     base.Dispose(disposing);
                 }
 
@@ -551,7 +570,7 @@ namespace System.Net.Http
                         return new ValueTask(Task.FromException(new ObjectDisposedException(nameof(Http2WriteStream))));
                     }
 
-                    return http2Stream.SendDataAsync(buffer);
+                    return http2Stream.SendDataAsync(buffer, cancellationToken);
                 }
 
                 public override Task FlushAsync(CancellationToken cancellationToken) => Task.CompletedTask;
index f25c0cd..3861882 100644 (file)
@@ -1071,28 +1071,5 @@ namespace System.Net.Http
             public override bool Equals(object obj) => obj is CachedConnection && Equals((CachedConnection)obj);
             public override int GetHashCode() => _connection?.GetHashCode() ?? 0;
         }
-
-        private sealed class TaskCompletionSourceWithCancellation<T> : TaskCompletionSource<T>
-        {
-            private CancellationToken _cancellationToken;
-
-            public TaskCompletionSourceWithCancellation() : base(TaskCreationOptions.RunContinuationsAsynchronously)
-            {
-            }
-
-            private void OnCancellation()
-            {
-                TrySetCanceled(_cancellationToken);
-            }
-
-            public async Task<T> WaitWithCancellationAsync(CancellationToken cancellationToken)
-            {
-                _cancellationToken = cancellationToken;
-                using (cancellationToken.Register(s => ((TaskCompletionSourceWithCancellation<HttpConnection>)s).OnCancellation(), this))
-                {
-                    return await Task.ConfigureAwait(false);
-                }
-            }
-        }
     }
 }
diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/TaskCompletionSourceWithCancellation.cs b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/TaskCompletionSourceWithCancellation.cs
new file mode 100644 (file)
index 0000000..d462b16
--- /dev/null
@@ -0,0 +1,32 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System.Threading;
+using System.Threading.Tasks;
+
+namespace System.Net.Http
+{
+    internal sealed class TaskCompletionSourceWithCancellation<T> : TaskCompletionSource<T>
+    {
+        private CancellationToken _cancellationToken;
+
+        public TaskCompletionSourceWithCancellation() : base(TaskCreationOptions.RunContinuationsAsynchronously)
+        {
+        }
+
+        private void OnCancellation()
+        {
+            TrySetCanceled(_cancellationToken);
+        }
+
+        public async Task<T> WaitWithCancellationAsync(CancellationToken cancellationToken)
+        {
+            _cancellationToken = cancellationToken;
+            using (cancellationToken.UnsafeRegister(s => ((TaskCompletionSourceWithCancellation<T>)s).OnCancellation(), this))
+            {
+                return await Task.ConfigureAwait(false);
+            }
+        }
+    }
+}
index 9c6e64c..77d4caf 100644 (file)
@@ -343,6 +343,46 @@ namespace System.Net.Http.Functional.Tests
             }
         }
 
+        [Fact]
+        public async Task SendAsync_Cancel_CancellationTokenPropagates()
+        {
+            TaskCompletionSource<bool> clientCanceled = new TaskCompletionSource<bool>();
+            await LoopbackServerFactory.CreateClientAndServerAsync(
+                async uri =>
+                {
+                    var cts = new CancellationTokenSource();
+                    cts.Cancel();
+
+                    using (HttpClient client = CreateHttpClient())
+                    {
+                        OperationCanceledException ex = null;
+                        try
+                        {
+                            await client.GetAsync(uri, cts.Token);
+                        }
+                        catch(OperationCanceledException e)
+                        {
+                            ex = e;
+                        }
+                        Assert.True(ex != null, "Expected OperationCancelledException, but no exception was thrown.");
+
+                        Assert.True(cts.Token.IsCancellationRequested, "cts token IsCancellationRequested");
+
+                        if (!PlatformDetection.IsFullFramework)
+                        {
+                            // .NET Framework has bug where it doesn't propagate token information.
+                            Assert.True(ex.CancellationToken.IsCancellationRequested, "exception token IsCancellationRequested");
+                        }
+                        clientCanceled.SetResult(true);
+                    }
+                },
+                async server =>
+                {
+                    Task serverTask = server.HandleRequestAsync();
+                    await clientCanceled.Task;
+                });
+        }
+
         private async Task ValidateClientCancellationAsync(Func<Task> clientBodyAsync)
         {
             var stopwatch = Stopwatch.StartNew();
index 3e6c56d..a82e73f 100644 (file)
@@ -2,7 +2,9 @@
 // The .NET Foundation licenses this file to you under the MIT license.
 // See the LICENSE file in the project root for more information.
 
+using System.Diagnostics;
 using System.Net.Test.Common;
+using System.Threading;
 using System.Threading.Tasks;
 
 using Xunit;
@@ -1108,5 +1110,157 @@ namespace System.Net.Http.Functional.Tests
                 Assert.Equal(HttpStatusCode.OK, response.StatusCode);
             }
         }
+
+        [OuterLoop("Uses Task.Delay")]
+        [ConditionalFact(nameof(SupportsAlpn))]
+        public async Task Http2_WaitingForStream_Cancellation()
+        {
+            HttpClientHandler handler = CreateHttpClientHandler();
+            handler.ServerCertificateCustomValidationCallback = TestHelper.AllowAllCertificates;
+
+            using (var server = Http2LoopbackServer.CreateServer())
+            using (var client = new HttpClient(handler))
+            {
+                Task<HttpResponseMessage> sendTask = client.GetAsync(server.Address);
+
+                await server.EstablishConnectionAsync();
+                server.IgnoreWindowUpdates();
+
+                // Process first request and send response.
+                int streamId = await server.ReadRequestHeaderAsync();
+                await server.SendDefaultResponseAsync(streamId);
+
+                HttpResponseMessage response = await sendTask;
+                Assert.Equal(HttpStatusCode.OK, response.StatusCode);
+
+                // Change MaxConcurrentStreams setting and wait for ack.
+                // (We don't want to send any new requests until we receive the ack, otherwise we may have a timing issue.)
+                SettingsFrame settingsFrame = new SettingsFrame(new SettingsEntry { SettingId = SettingId.MaxConcurrentStreams, Value = 0 });
+                await server.WriteFrameAsync(settingsFrame);
+                Frame settingsAckFrame = await server.ReadFrameAsync(TimeSpan.FromSeconds(30));
+                Assert.Equal(FrameType.Settings, settingsAckFrame.Type);
+                Assert.Equal(FrameFlags.Ack, settingsAckFrame.Flags);
+
+                // Issue a new request, so that we can cancel it while it waits for a stream.
+                var cts = new CancellationTokenSource();
+                sendTask = client.GetAsync(server.Address, cts.Token);
+
+                // Make sure that the request makes it to the point where it's waiting for a connection.
+                // It's possible that we'll still initiate a cancellation before it makes it to the queue,
+                // but it should still behave in the same way if so.
+                await Task.Delay(500);
+
+                Stopwatch stopwatch = Stopwatch.StartNew();
+                cts.Cancel();
+
+                await Assert.ThrowsAnyAsync<OperationCanceledException>(async () => await sendTask);
+
+                // Ensure that the cancellation occurs promptly
+                stopwatch.Stop();
+                Assert.True(stopwatch.ElapsedMilliseconds < 30000);
+
+                // As the client has not allocated a stream ID when the corresponding request is cancelled,
+                // we do not send a RST stream frame.
+            }
+        }
+
+        [ConditionalFact(nameof(SupportsAlpn))]
+        public async Task Http2_WaitingOnWindowCredit_Cancellation()
+        {
+            // The goal of this test is to get the client into the state where it has sent the headers,
+            // but is waiting on window credit before it will send the body. We then issue a cancellation
+            // to ensure the request is cancelled as expected.
+            const int InitialWindowSize = 65535;
+            const int ContentSize = InitialWindowSize + 1;
+
+            HttpClientHandler handler = CreateHttpClientHandler();
+            handler.ServerCertificateCustomValidationCallback = TestHelper.AllowAllCertificates;
+            TestHelper.EnsureHttp2Feature(handler);
+
+            var content = new ByteArrayContent(TestHelper.GenerateRandomContent(ContentSize));
+
+            using (var server = Http2LoopbackServer.CreateServer())
+            using (var client = new HttpClient(handler))
+            {
+                var cts = new CancellationTokenSource();
+                Task<HttpResponseMessage> clientTask = client.PostAsync(server.Address, content, cts.Token);
+
+                await server.EstablishConnectionAsync();
+
+                Frame frame = await server.ReadFrameAsync(TimeSpan.FromSeconds(30));
+                int streamId = frame.StreamId;
+                Assert.Equal(FrameType.Headers, frame.Type);
+                Assert.Equal(FrameFlags.EndHeaders, frame.Flags);
+
+                // Receive up to initial window size
+                int bytesReceived = 0;
+                while (bytesReceived < InitialWindowSize)
+                {
+                    frame = await server.ReadFrameAsync(TimeSpan.FromSeconds(30));
+                    Assert.Equal(streamId, frame.StreamId);
+                    Assert.Equal(FrameType.Data, frame.Type);
+                    Assert.Equal(FrameFlags.None, frame.Flags);
+                    Assert.True(frame.Length > 0);
+
+                    bytesReceived += frame.Length;
+                }
+
+                // The client is waiting for more credit in order to send the last byte of the
+                // request body. Test cancellation at this point.
+                Stopwatch stopwatch = Stopwatch.StartNew();
+
+                cts.Cancel();
+                await Assert.ThrowsAnyAsync<OperationCanceledException>(async () => await clientTask);
+
+                // Ensure that the cancellation occurs promptly
+                stopwatch.Stop();
+                Assert.True(stopwatch.ElapsedMilliseconds < 30000);
+
+                // The server should receive a RstStream frame.
+                frame = await server.ReadFrameAsync(TimeSpan.FromSeconds(30));
+                Assert.Equal(FrameType.RstStream, frame.Type);
+            }
+        }
+
+        [OuterLoop("Uses Task.Delay")]
+        [ConditionalFact(nameof(SupportsAlpn))]
+        public async Task Http2_PendingSend_Cancellation()
+        {
+            // The goal of this test is to get the client into the state where it is sending content,
+            // but the send pends because the TCP window is full.
+            const int InitialWindowSize = 65535;
+            const int ContentSize = InitialWindowSize * 2; // Double the default TCP window size.
+
+            HttpClientHandler handler = CreateHttpClientHandler();
+            handler.ServerCertificateCustomValidationCallback = TestHelper.AllowAllCertificates;
+            TestHelper.EnsureHttp2Feature(handler);
+
+            var content = new ByteArrayContent(TestHelper.GenerateRandomContent(ContentSize));
+
+            using (var server = Http2LoopbackServer.CreateServer())
+            using (var client = new HttpClient(handler))
+            {
+                var cts = new CancellationTokenSource();
+
+                Task<HttpResponseMessage> clientTask = client.PostAsync(server.Address, content, cts.Token);
+
+                await server.EstablishConnectionAsync();
+
+                Frame frame = await server.ReadFrameAsync(TimeSpan.FromSeconds(30));
+                int streamId = frame.StreamId;
+                Assert.Equal(FrameType.Headers, frame.Type);
+                Assert.Equal(FrameFlags.EndHeaders, frame.Flags);
+
+                // Increase the size of the HTTP/2 Window, so that it is large enough to fill the
+                // TCP window when we do not perform any reads on the server side.
+                await server.WriteFrameAsync(new WindowUpdateFrame(InitialWindowSize, streamId));
+
+                // Give the client time to read the window update frame, and for the write to pend.
+                await Task.Delay(1000);
+                cts.Cancel();
+
+                await Assert.ThrowsAnyAsync<OperationCanceledException>(async () => await clientTask);
+            }
+        }
     }
 }
index 63a68da..4e51c73 100644 (file)
@@ -367,35 +367,6 @@ namespace System.Net.Http.Functional.Tests
             }
         }
 
-        [OuterLoop("Uses external server")]
-        [Fact]
-        public async Task SendAsync_Cancel_CancellationTokenPropagates()
-        {
-            var cts = new CancellationTokenSource();
-            cts.Cancel();
-            using (HttpClient client = CreateHttpClient())
-            {
-                var request = new HttpRequestMessage(HttpMethod.Post, Configuration.Http.RemoteEchoServer);
-                Task t = client.SendAsync(request, cts.Token);
-                OperationCanceledException ex;
-                if (PlatformDetection.IsUap)
-                {
-                    ex = await Assert.ThrowsAsync<OperationCanceledException>(() => t);
-                }
-                else
-                {
-                    ex = await Assert.ThrowsAsync<TaskCanceledException>(() => t);
-                }
-
-                Assert.True(cts.Token.IsCancellationRequested, "cts token IsCancellationRequested");
-                if (!PlatformDetection.IsFullFramework)
-                {
-                    // .NET Framework has bug where it doesn't propagate token information.
-                    Assert.True(ex.CancellationToken.IsCancellationRequested, "exception token IsCancellationRequested");
-                }
-            }
-        }
-
         [SkipOnTargetFramework(TargetFrameworkMonikers.Uap, "UAP HTTP stack doesn't support .Proxy property")]
         [Theory]
         [InlineData("[::1234]")]
index c11693c..9a963d3 100644 (file)
@@ -1665,4 +1665,11 @@ namespace System.Net.Http.Functional.Tests
         protected override bool UseSocketsHttpHandler => true;
         protected override bool UseHttp2LoopbackServer => true;
     }
+    
+    [ConditionalClass(typeof(PlatformDetection), nameof(PlatformDetection.SupportsAlpn))]
+    public sealed class SocketsHttpHandler_HttpClientHandler_Cancellation_Test_Http2 : HttpClientHandler_Cancellation_Test
+    {
+        protected override bool UseSocketsHttpHandler => true;
+        protected override bool UseHttp2LoopbackServer => true;
+    }
 }