From 294a0381c564bb5008c6bc5a0b7da2a9ea66335d Mon Sep 17 00:00:00 2001 From: Geoff Kizer Date: Wed, 20 Feb 2019 12:54:39 -0800 Subject: [PATCH] add StreamState and rework HTTP2 stream state handling cleanup Commit migrated from https://github.com/dotnet/corefx/commit/1eba83637d47fde09c8e1748b0ebb4ca42da5ca4 --- .../Common/tests/System/Net/Http/Http2Frames.cs | 59 ++++ .../Net/Http/SocketsHttpHandler/Http2Stream.cs | 190 +++++++++---- .../FunctionalTests/HttpClientHandlerTest.Http2.cs | 315 +++++++++++++++------ 3 files changed, 411 insertions(+), 153 deletions(-) diff --git a/src/libraries/Common/tests/System/Net/Http/Http2Frames.cs b/src/libraries/Common/tests/System/Net/Http/Http2Frames.cs index f95fa23..11f4309 100644 --- a/src/libraries/Common/tests/System/Net/Http/Http2Frames.cs +++ b/src/libraries/Common/tests/System/Net/Http/Http2Frames.cs @@ -207,6 +207,65 @@ namespace System.Net.Test.Common } } + public class ContinuationFrame : Frame + { + public byte PadLength = 0; + public int StreamDependency = 0; + public byte Weight = 0; + public Memory Data; + + public ContinuationFrame(Memory data, FrameFlags flags, byte padLength, int streamDependency, byte weight, int streamId) : + base(0, FrameType.Continuation, flags, streamId) + { + Length = data.Length + (PaddedFlag ? padLength + 1 : 0) + (PriorityFlag ? 5 : 0); + + Data = data; + PadLength = padLength; + StreamDependency = streamDependency; + Weight = weight; + } + + public static ContinuationFrame ReadFrom(Frame header, ReadOnlySpan buffer) + { + int idx = 0; + + byte padLength = (byte)(header.PaddedFlag ? buffer[idx++] : 0); + int streamDependency = header.PriorityFlag ? (int)((uint)((buffer[idx++] << 24) | (buffer[idx++] << 16) | (buffer[idx++] << idx++) | buffer[idx++]) & 0x7FFFFFFF) : 0; + byte weight = (byte)(header.PaddedFlag ? buffer[idx++] : 0); + + byte[] data = buffer.Slice(idx).ToArray(); + + return new ContinuationFrame(data, header.Flags, padLength, streamDependency, weight, header.StreamId); + } + + public override void WriteTo(Span buffer) + { + base.WriteTo(buffer); + + int idx = Frame.FrameHeaderLength; + if (PaddedFlag) + { + buffer[idx++] = PadLength; + } + + if (PriorityFlag) + { + buffer[idx++] = (byte)((StreamDependency & 0xFF000000) >> 24); + buffer[idx++] = (byte)((StreamDependency & 0x00FF0000) >> 16); + buffer[idx++] = (byte)((StreamDependency & 0x0000FF00) >> 8); + buffer[idx++] = (byte)(StreamDependency & 0x000000FF); + + buffer[idx++] = Weight; + } + Data.Span.CopyTo(buffer.Slice(idx)); + } + + public override string ToString() + { + return base.ToString() + $"\nPadding: {PadLength}\nStream Dependency: {StreamDependency}\nWeight: {Weight}"; + } + } + public class PriorityFrame : Frame { public int StreamDependency = 0; diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http2Stream.cs b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http2Stream.cs index cc846d8..cc0e3a9 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http2Stream.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http2Stream.cs @@ -15,6 +15,14 @@ namespace System.Net.Http { private sealed class Http2Stream : IDisposable { + private enum StreamState : byte + { + ExpectingHeaders, + ExpectingData, + Complete, + Aborted + } + private const int InitialStreamBufferSize = #if DEBUG 10; @@ -27,13 +35,12 @@ namespace System.Net.Http private readonly CreditManager _streamWindow; private readonly HttpRequestMessage _request; private readonly HttpResponseMessage _response; - private readonly TaskCompletionSource _responseHeadersAvailable; private ArrayBuffer _responseBuffer; // mutable struct, do not make this readonly private int _pendingWindowUpdate; - private TaskCompletionSource _responseDataAvailable; - private bool _responseComplete; - private bool _responseAborted; + + private StreamState _state; + private TaskCompletionSource _waiterTaskSource; private bool _disposed; private const int StreamWindowSize = DefaultInitialWindowSize; @@ -46,6 +53,8 @@ namespace System.Net.Http _connection = connection; _streamId = streamId; + _state = StreamState.ExpectingHeaders; + _request = request; _response = new HttpResponseMessage() { @@ -62,11 +71,9 @@ namespace System.Net.Http _streamWindow = new CreditManager(initialWindowSize); - _responseHeadersAvailable = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); - // TODO: ISSUE 31313: Avoid allocating a TaskCompletionSource repeatedly by using a resettable ValueTaskSource. // See: https://github.com/dotnet/corefx/blob/master/src/Common/tests/System/Threading/Tasks/Sources/ManualResetValueTaskSource.cs - _responseDataAvailable = null; + _waiterTaskSource = null; } private object SyncObject => _streamWindow; @@ -92,23 +99,6 @@ namespace System.Net.Http } } - public async Task ReadResponseHeadersAsync() - { - // Wait for response headers to be read. - bool emptyResponse = await _responseHeadersAvailable.Task.ConfigureAwait(false); - - // Start to process the response body. - ((HttpConnectionResponseContent)_response.Content).SetStream(emptyResponse ? - EmptyReadStream.Instance : - (Stream)new Http2ReadStream(this)); - - // Process Set-Cookie headers. - if (_connection._pool.Settings._useCookies) - { - CookieHelper.ProcessReceivedCookies(_response, _connection._pool.Settings._cookieContainer); - } - } - public void OnWindowUpdate(int amount) { _streamWindow.AdjustCredit(amount); @@ -124,6 +114,11 @@ namespace System.Net.Http { // TODO: ISSUE 31309: Optimize HPACK static table decoding + if (_state != StreamState.ExpectingHeaders) + { + throw new Http2ProtocolException(Http2ProtocolErrorCode.ProtocolError); + } + if (name.SequenceEqual(s_statusHeaderName)) { if (value.Length != 3) @@ -163,12 +158,33 @@ namespace System.Net.Http public void OnResponseHeadersComplete(bool endStream) { - _responseHeadersAvailable.SetResult(endStream); + TaskCompletionSource waiterTaskSource = null; + + lock (SyncObject) + { + if (_state != StreamState.ExpectingHeaders) + { + throw new Http2ProtocolException(Http2ProtocolErrorCode.ProtocolError); + } + + _state = endStream ? StreamState.Complete : StreamState.ExpectingData; + + if (_waiterTaskSource != null) + { + waiterTaskSource = _waiterTaskSource; + _waiterTaskSource = null; + } + } + + if (waiterTaskSource != null) + { + waiterTaskSource.SetResult(true); + } } public void OnResponseData(ReadOnlySpan buffer, bool endStream) { - TaskCompletionSource readDataAvailable = null; + TaskCompletionSource waiterTaskSource = null; lock (SyncObject) { @@ -177,7 +193,10 @@ namespace System.Net.Http return; } - Debug.Assert(!_responseComplete); + if (_state != StreamState.ExpectingData) + { + throw new Http2ProtocolException(Http2ProtocolErrorCode.ProtocolError); + } if (_responseBuffer.ActiveSpan.Length + buffer.Length > StreamWindowSize) { @@ -191,24 +210,26 @@ namespace System.Net.Http if (endStream) { - _responseComplete = true; + _state = StreamState.Complete; } - if (_responseDataAvailable != null) + if (_waiterTaskSource != null) { - readDataAvailable = _responseDataAvailable; - _responseDataAvailable = null; + waiterTaskSource = _waiterTaskSource; + _waiterTaskSource = null; } } - if (readDataAvailable != null) + if (waiterTaskSource != null) { - readDataAvailable.SetResult(true); + waiterTaskSource.SetResult(true); } } public void OnResponseAbort() { + TaskCompletionSource waiterTaskSource = null; + lock (SyncObject) { if (_disposed) @@ -216,39 +237,91 @@ namespace System.Net.Http return; } - Debug.Assert(!_responseComplete); + if (_state == StreamState.Aborted) + { + return; + } + + _state = StreamState.Aborted; - _responseComplete = true; - _responseAborted = true; + if (_waiterTaskSource != null) + { + waiterTaskSource = _waiterTaskSource; + _waiterTaskSource = null; + } + } - if (!_responseHeadersAvailable.Task.IsCompleted) + if (waiterTaskSource != null) + { + waiterTaskSource.SetResult(true); + } + } + + private (Task waiterTask, bool isEmptyResponse) TryEnsureHeaders() + { + lock (SyncObject) + { + if (_disposed) { - // We are still waiting for response headers, so fail that task - _responseHeadersAvailable.SetException(new IOException(SR.net_http_invalid_response)); + throw new ObjectDisposedException(nameof(Http2Stream)); + } - // We shouldn't be waiting on data, since we haven't processed headers yet - Debug.Assert(_responseDataAvailable == null); + if (_state == StreamState.Aborted) + { + throw new IOException(SR.net_http_invalid_response); + } + else if (_state == StreamState.ExpectingHeaders) + { + Debug.Assert(_waiterTaskSource == null); + _waiterTaskSource = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + return (_waiterTaskSource.Task, false); + } + else if (_state == StreamState.ExpectingData) + { + return (null, false); } else { - if (_responseDataAvailable != null) - { - _responseDataAvailable.SetException(new IOException(SR.net_http_invalid_response)); - _responseDataAvailable = null; - } + Debug.Assert(_state == StreamState.Complete); + return (null, _responseBuffer.ActiveSpan.Length == 0); } } } + public async Task ReadResponseHeadersAsync() + { + // Wait for response headers to be read. + Task waiterTask; + bool emptyResponse; + (waiterTask, emptyResponse) = TryEnsureHeaders(); + if (waiterTask != null) + { + await waiterTask; + (waiterTask, emptyResponse) = TryEnsureHeaders(); + Debug.Assert(waiterTask == null); + } + + // Start to process the response body. + ((HttpConnectionResponseContent)_response.Content).SetStream(emptyResponse ? + EmptyReadStream.Instance : + (Stream)new Http2ReadStream(this)); + + // Process Set-Cookie headers. + if (_connection._pool.Settings._useCookies) + { + CookieHelper.ProcessReceivedCookies(_response, _connection._pool.Settings._cookieContainer); + } + } + private void ExtendWindow(int amount) { Debug.Assert(amount > 0); Debug.Assert(_pendingWindowUpdate < StreamWindowThreshold); - if (_responseComplete) + if (_state != StreamState.ExpectingData) { - // We have already read to the end of the response, so there's no need to send - // WINDOW_UPDATEs any more. + // We are not expecting any more data (because we've either completed or aborted). + // So no need to send any more WINDOW_UPDATEs. return; } @@ -283,21 +356,20 @@ namespace System.Net.Http return (null, bytesRead); } - else if (_responseComplete) + else if (_state == StreamState.Complete) { - if (_responseAborted) - { - throw new IOException(SR.net_http_invalid_response); - } - return (null, 0); } + else if (_state == StreamState.Aborted) + { + throw new IOException(SR.net_http_invalid_response); + } - Debug.Assert(_responseDataAvailable == null); - Debug.Assert(!_responseAborted); + Debug.Assert(_state == StreamState.ExpectingData); - _responseDataAvailable = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); - return (_responseDataAvailable.Task, 0); + Debug.Assert(_waiterTaskSource == null); + _waiterTaskSource = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + return (_waiterTaskSource.Task, 0); } } diff --git a/src/libraries/System.Net.Http/tests/FunctionalTests/HttpClientHandlerTest.Http2.cs b/src/libraries/System.Net.Http/tests/FunctionalTests/HttpClientHandlerTest.Http2.cs index e003241..0533ce4 100644 --- a/src/libraries/System.Net.Http/tests/FunctionalTests/HttpClientHandlerTest.Http2.cs +++ b/src/libraries/System.Net.Http/tests/FunctionalTests/HttpClientHandlerTest.Http2.cs @@ -12,16 +12,15 @@ namespace System.Net.Http.Functional.Tests public abstract class HttpClientHandlerTest_Http2 : HttpClientHandlerTestBase { protected override bool UseSocketsHttpHandler => true; + protected override bool UseHttp2LoopbackServer => true; + public static bool SupportsAlpn => PlatformDetection.SupportsAlpn; [ConditionalFact(nameof(SupportsAlpn))] public async Task Http2_ClientPreface_Sent() { - HttpClientHandler handler = CreateHttpClientHandler(); - handler.ServerCertificateCustomValidationCallback = TestHelper.AllowAllCertificates; - using (var server = Http2LoopbackServer.CreateServer()) - using (var client = new HttpClient(handler)) + using (var client = CreateHttpClient()) { Task sendTask = client.GetAsync(server.Address); @@ -34,11 +33,8 @@ namespace System.Net.Http.Functional.Tests [ConditionalFact(nameof(SupportsAlpn))] public async Task Http2_InitialSettings_SentAndAcked() { - HttpClientHandler handler = CreateHttpClientHandler(); - handler.ServerCertificateCustomValidationCallback = TestHelper.AllowAllCertificates; - using (var server = Http2LoopbackServer.CreateServer()) - using (var client = new HttpClient(handler)) + using (var client = CreateHttpClient()) { Task sendTask = client.GetAsync(server.Address); @@ -69,11 +65,8 @@ namespace System.Net.Http.Functional.Tests [ConditionalFact(nameof(SupportsAlpn))] public async Task Http2_DataSentBeforeServerPreface_ProtocolError() { - HttpClientHandler handler = CreateHttpClientHandler(); - handler.ServerCertificateCustomValidationCallback = TestHelper.AllowAllCertificates; - using (var server = Http2LoopbackServer.CreateServer()) - using (var client = new HttpClient(handler)) + using (var client = CreateHttpClient()) { Task sendTask = client.GetAsync(server.Address); @@ -90,11 +83,8 @@ namespace System.Net.Http.Functional.Tests [ConditionalFact(nameof(SupportsAlpn))] public async Task Http2_NoResponseBody_Success() { - HttpClientHandler handler = CreateHttpClientHandler(); - handler.ServerCertificateCustomValidationCallback = TestHelper.AllowAllCertificates; - using (var server = Http2LoopbackServer.CreateServer()) - using (var client = new HttpClient(handler)) + using (var client = CreateHttpClient()) { Task sendTask = client.GetAsync(server.Address); @@ -113,11 +103,8 @@ namespace System.Net.Http.Functional.Tests [ConditionalFact(nameof(SupportsAlpn))] public async Task Http2_ZeroLengthResponseBody_Success() { - HttpClientHandler handler = CreateHttpClientHandler(); - handler.ServerCertificateCustomValidationCallback = TestHelper.AllowAllCertificates; - using (var server = Http2LoopbackServer.CreateServer()) - using (var client = new HttpClient(handler)) + using (var client = CreateHttpClient()) { Task sendTask = client.GetAsync(server.Address); @@ -140,11 +127,8 @@ namespace System.Net.Http.Functional.Tests [ConditionalFact(nameof(SupportsAlpn))] public async Task Http2_ServerSendsValidSettingsValues_Success() { - HttpClientHandler handler = CreateHttpClientHandler(); - handler.ServerCertificateCustomValidationCallback = TestHelper.AllowAllCertificates; - using (var server = Http2LoopbackServer.CreateServer()) - using (var client = new HttpClient(handler)) + using (var client = CreateHttpClient()) { Task sendTask = client.GetAsync(server.Address); @@ -184,11 +168,8 @@ namespace System.Net.Http.Functional.Tests return; } - HttpClientHandler handler = CreateHttpClientHandler(); - handler.ServerCertificateCustomValidationCallback = TestHelper.AllowAllCertificates; - using (var server = Http2LoopbackServer.CreateServer()) - using (var client = new HttpClient(handler)) + using (var client = CreateHttpClient()) { Task sendTask = client.GetAsync(server.Address); @@ -209,11 +190,8 @@ namespace System.Net.Http.Functional.Tests return; } - HttpClientHandler handler = CreateHttpClientHandler(); - handler.ServerCertificateCustomValidationCallback = TestHelper.AllowAllCertificates; - using (var server = Http2LoopbackServer.CreateServer()) - using (var client = new HttpClient(handler)) + using (var client = CreateHttpClient()) { Task sendTask = client.GetAsync(server.Address); @@ -231,11 +209,8 @@ namespace System.Net.Http.Functional.Tests [ConditionalFact(nameof(SupportsAlpn))] public async Task Http2_StreamResetByServerAfterHeadersSent_RequestFails() { - HttpClientHandler handler = CreateHttpClientHandler(); - handler.ServerCertificateCustomValidationCallback = TestHelper.AllowAllCertificates; - using (var server = Http2LoopbackServer.CreateServer()) - using (var client = new HttpClient(handler)) + using (var client = CreateHttpClient()) { Task sendTask = client.GetAsync(server.Address); @@ -256,11 +231,8 @@ namespace System.Net.Http.Functional.Tests [ConditionalFact(nameof(SupportsAlpn))] public async Task Http2_StreamResetByServerAfterPartialBodySent_RequestFails() { - HttpClientHandler handler = CreateHttpClientHandler(); - handler.ServerCertificateCustomValidationCallback = TestHelper.AllowAllCertificates; - using (var server = Http2LoopbackServer.CreateServer()) - using (var client = new HttpClient(handler)) + using (var client = CreateHttpClient()) { Task sendTask = client.GetAsync(server.Address); @@ -286,11 +258,8 @@ namespace System.Net.Http.Functional.Tests [ConditionalFact(nameof(SupportsAlpn))] public async Task DataFrame_NoStream_ConnectionError() { - HttpClientHandler handler = CreateHttpClientHandler(); - handler.ServerCertificateCustomValidationCallback = TestHelper.AllowAllCertificates; - using (var server = Http2LoopbackServer.CreateServer()) - using (var client = new HttpClient(handler)) + using (var client = CreateHttpClient()) { Task sendTask = client.GetAsync(server.Address); @@ -303,6 +272,9 @@ namespace System.Net.Http.Functional.Tests // As this is a connection level error, the client should see the request fail. await Assert.ThrowsAsync(async () => await sendTask); + + // The client should close the connection as this is a fatal connection level error. + Assert.Null(await server.ReadFrameAsync(TimeSpan.FromSeconds(30))); } } @@ -318,11 +290,8 @@ namespace System.Net.Http.Functional.Tests return; } - HttpClientHandler handler = CreateHttpClientHandler(); - handler.ServerCertificateCustomValidationCallback = TestHelper.AllowAllCertificates; - using (var server = Http2LoopbackServer.CreateServer()) - using (var client = new HttpClient(handler)) + using (var client = CreateHttpClient()) { Task sendTask = client.GetAsync(server.Address); @@ -335,6 +304,9 @@ namespace System.Net.Http.Functional.Tests // As this is a connection level error, the client should see the request fail. await Assert.ThrowsAsync(async () => await sendTask); + + // The client should close the connection as this is a fatal connection level error. + Assert.Null(await server.ReadFrameAsync(TimeSpan.FromSeconds(30))); } } @@ -350,11 +322,8 @@ namespace System.Net.Http.Functional.Tests return; } - HttpClientHandler handler = CreateHttpClientHandler(); - handler.ServerCertificateCustomValidationCallback = TestHelper.AllowAllCertificates; - using (var server = Http2LoopbackServer.CreateServer()) - using (var client = new HttpClient(handler)) + using (var client = CreateHttpClient()) { Task sendTask = client.GetAsync(server.Address); @@ -366,6 +335,193 @@ namespace System.Net.Http.Functional.Tests // As this is a connection level error, the client should see the request fail. await Assert.ThrowsAsync(async () => await sendTask); + + // The client should close the connection as this is a fatal connection level error. + Assert.Null(await server.ReadFrameAsync(TimeSpan.FromSeconds(30))); + } + } + + private static Frame MakeSimpleHeadersFrame(int streamId, bool endHeaders = false, bool endStream = false) => + new HeadersFrame(new byte[] { 0x88 }, // :status: 200 + (endHeaders ? FrameFlags.EndHeaders : FrameFlags.None) | (endStream ? FrameFlags.EndStream : FrameFlags.None), + 0, 0, 0, streamId); + + private static Frame MakeSimpleContinuationFrame(int streamId, bool endHeaders = false) => + new ContinuationFrame(new byte[] { 0x88 }, // :status: 200 + (endHeaders ? FrameFlags.EndHeaders : FrameFlags.None), + 0, 0, 0, streamId); + + private static Frame MakeSimpleDataFrame(int streamId, bool endStream = false) => + new DataFrame(new byte[] { 0x00 }, + (endStream ? FrameFlags.EndStream : FrameFlags.None), + 0, streamId); + + [ConditionalFact(nameof(SupportsAlpn))] + public async Task ResponseStreamFrames_ContinuationBeforeHeaders_ConnectionError() + { + using (var server = Http2LoopbackServer.CreateServer()) + using (var client = CreateHttpClient()) + { + Task sendTask = client.GetAsync(server.Address); + await server.EstablishConnectionAsync(); + int streamId = await server.ReadRequestHeaderAsync(); + + await server.WriteFrameAsync(MakeSimpleContinuationFrame(streamId)); + + // As this is a connection level error, the client should see the request fail. + await Assert.ThrowsAsync(async () => await sendTask); + + // The client should close the connection as this is a fatal connection level error. + Assert.Null(await server.ReadFrameAsync(TimeSpan.FromSeconds(30))); + } + } + + [ConditionalFact(nameof(SupportsAlpn))] + public async Task ResponseStreamFrames_DataBeforeHeaders_ConnectionError() + { + using (var server = Http2LoopbackServer.CreateServer()) + using (var client = CreateHttpClient()) + { + Task sendTask = client.GetAsync(server.Address); + await server.EstablishConnectionAsync(); + int streamId = await server.ReadRequestHeaderAsync(); + + await server.WriteFrameAsync(MakeSimpleDataFrame(streamId)); + + // As this is a connection level error, the client should see the request fail. + await Assert.ThrowsAsync(async () => await sendTask); + + // The client should close the connection as this is a fatal connection level error. + Assert.Null(await server.ReadFrameAsync(TimeSpan.FromSeconds(30))); + } + } + + [ConditionalFact(nameof(SupportsAlpn))] + public async Task ResponseStreamFrames_HeadersAfterHeadersWithoutEndHeaders_ConnectionError() + { + using (var server = Http2LoopbackServer.CreateServer()) + using (var client = CreateHttpClient()) + { + Task sendTask = client.GetAsync(server.Address); + await server.EstablishConnectionAsync(); + int streamId = await server.ReadRequestHeaderAsync(); + + await server.WriteFrameAsync(MakeSimpleHeadersFrame(streamId, endHeaders: false)); + await server.WriteFrameAsync(MakeSimpleHeadersFrame(streamId, endHeaders: false)); + + // As this is a connection level error, the client should see the request fail. + await Assert.ThrowsAsync(async () => await sendTask); + + // The client should close the connection as this is a fatal connection level error. + Assert.Null(await server.ReadFrameAsync(TimeSpan.FromSeconds(30))); + } + } + + [ConditionalFact(nameof(SupportsAlpn))] + public async Task ResponseStreamFrames_HeadersAfterHeadersAndContinuationWithoutEndHeaders_ConnectionError() + { + using (var server = Http2LoopbackServer.CreateServer()) + using (var client = CreateHttpClient()) + { + Task sendTask = client.GetAsync(server.Address); + await server.EstablishConnectionAsync(); + int streamId = await server.ReadRequestHeaderAsync(); + + await server.WriteFrameAsync(MakeSimpleHeadersFrame(streamId, endHeaders: false)); + await server.WriteFrameAsync(MakeSimpleContinuationFrame(streamId, endHeaders: false)); + await server.WriteFrameAsync(MakeSimpleHeadersFrame(streamId, endHeaders: false)); + + // As this is a connection level error, the client should see the request fail. + await Assert.ThrowsAsync(async () => await sendTask); + + // The client should close the connection as this is a fatal connection level error. + Assert.Null(await server.ReadFrameAsync(TimeSpan.FromSeconds(30))); + } + } + + [ConditionalFact(nameof(SupportsAlpn))] + public async Task ResponseStreamFrames_HeadersAfterHeadersWithEndHeaders_ConnectionError() + { + using (var server = Http2LoopbackServer.CreateServer()) + using (var client = CreateHttpClient()) + { + Task sendTask = client.GetAsync(server.Address); + await server.EstablishConnectionAsync(); + int streamId = await server.ReadRequestHeaderAsync(); + + await server.WriteFrameAsync(MakeSimpleHeadersFrame(streamId, endHeaders: true)); + await server.WriteFrameAsync(MakeSimpleHeadersFrame(streamId, endHeaders: false)); + + // As this is a connection level error, the client should see the request fail. + await Assert.ThrowsAsync(async () => await sendTask); + + // The client should close the connection as this is a fatal connection level error. + Assert.Null(await server.ReadFrameAsync(TimeSpan.FromSeconds(30))); + } + } + + [ConditionalFact(nameof(SupportsAlpn))] + public async Task ResponseStreamFrames_HeadersAfterHeadersAndContinuationWithEndHeaders_ConnectionError() + { + using (var server = Http2LoopbackServer.CreateServer()) + using (var client = CreateHttpClient()) + { + Task sendTask = client.GetAsync(server.Address); + await server.EstablishConnectionAsync(); + int streamId = await server.ReadRequestHeaderAsync(); + + await server.WriteFrameAsync(MakeSimpleHeadersFrame(streamId, endHeaders: false)); + await server.WriteFrameAsync(MakeSimpleContinuationFrame(streamId, endHeaders: true)); + await server.WriteFrameAsync(MakeSimpleHeadersFrame(streamId, endHeaders: false)); + + // As this is a connection level error, the client should see the request fail. + await Assert.ThrowsAsync(async () => await sendTask); + + // The client should close the connection as this is a fatal connection level error. + Assert.Null(await server.ReadFrameAsync(TimeSpan.FromSeconds(30))); + } + } + + [ConditionalFact(nameof(SupportsAlpn))] + public async Task ResponseStreamFrames_DataAfterHeadersWithoutEndHeaders_ConnectionError() + { + using (var server = Http2LoopbackServer.CreateServer()) + using (var client = CreateHttpClient()) + { + Task sendTask = client.GetAsync(server.Address); + await server.EstablishConnectionAsync(); + int streamId = await server.ReadRequestHeaderAsync(); + + await server.WriteFrameAsync(MakeSimpleHeadersFrame(streamId, endHeaders: false)); + await server.WriteFrameAsync(MakeSimpleDataFrame(streamId)); + + // As this is a connection level error, the client should see the request fail. + await Assert.ThrowsAsync(async () => await sendTask); + + // The client should close the connection as this is a fatal connection level error. + Assert.Null(await server.ReadFrameAsync(TimeSpan.FromSeconds(30))); + } + } + + [ConditionalFact(nameof(SupportsAlpn))] + public async Task ResponseStreamFrames_DataAfterHeadersAndContinuationWithoutEndHeaders_ConnectionError() + { + using (var server = Http2LoopbackServer.CreateServer()) + using (var client = CreateHttpClient()) + { + Task sendTask = client.GetAsync(server.Address); + await server.EstablishConnectionAsync(); + int streamId = await server.ReadRequestHeaderAsync(); + + await server.WriteFrameAsync(MakeSimpleHeadersFrame(streamId, endHeaders: false)); + await server.WriteFrameAsync(MakeSimpleContinuationFrame(streamId, endHeaders: false)); + await server.WriteFrameAsync(MakeSimpleDataFrame(streamId)); + + // As this is a connection level error, the client should see the request fail. + await Assert.ThrowsAsync(async () => await sendTask); + + // The client should close the connection as this is a fatal connection level error. + Assert.Null(await server.ReadFrameAsync(TimeSpan.FromSeconds(30))); } } @@ -375,11 +531,8 @@ namespace System.Net.Http.Functional.Tests [ConditionalFact(nameof(SupportsAlpn))] public async Task GoAwayFrame_NonzeroStream_ConnectionError() { - HttpClientHandler handler = CreateHttpClientHandler(); - handler.ServerCertificateCustomValidationCallback = TestHelper.AllowAllCertificates; - using (var server = Http2LoopbackServer.CreateServer()) - using (var client = new HttpClient(handler)) + using (var client = CreateHttpClient()) { Task sendTask = client.GetAsync(server.Address); @@ -392,17 +545,17 @@ namespace System.Net.Http.Functional.Tests // As this is a connection level error, the client should see the request fail. await Assert.ThrowsAsync(async () => await sendTask); + + // The client should close the connection as this is a fatal connection level error. + Assert.Null(await server.ReadFrameAsync(TimeSpan.FromSeconds(30))); } } [ConditionalFact(nameof(SupportsAlpn))] public async Task DataFrame_TooLong_ConnectionError() { - HttpClientHandler handler = CreateHttpClientHandler(); - handler.ServerCertificateCustomValidationCallback = TestHelper.AllowAllCertificates; - using (var server = Http2LoopbackServer.CreateServer()) - using (var client = new HttpClient(handler)) + using (var client = CreateHttpClient()) { Task sendTask = client.GetAsync(server.Address); @@ -421,11 +574,8 @@ namespace System.Net.Http.Functional.Tests [ConditionalFact(nameof(SupportsAlpn))] public async Task CompletedResponse_FrameReceived_ConnectionError() { - HttpClientHandler handler = CreateHttpClientHandler(); - handler.ServerCertificateCustomValidationCallback = TestHelper.AllowAllCertificates; - using (var server = Http2LoopbackServer.CreateServer()) - using (var client = new HttpClient(handler)) + using (var client = CreateHttpClient()) { Task sendTask = client.GetAsync(server.Address); @@ -456,11 +606,8 @@ namespace System.Net.Http.Functional.Tests [ConditionalFact(nameof(SupportsAlpn))] public async Task EmptyResponse_FrameReceived_ConnectionError() { - HttpClientHandler handler = CreateHttpClientHandler(); - handler.ServerCertificateCustomValidationCallback = TestHelper.AllowAllCertificates; - using (var server = Http2LoopbackServer.CreateServer()) - using (var client = new HttpClient(handler)) + using (var client = CreateHttpClient()) { Task sendTask = client.GetAsync(server.Address); @@ -488,11 +635,8 @@ namespace System.Net.Http.Functional.Tests [ConditionalFact(nameof(SupportsAlpn))] public async Task ResetResponseStream_FrameReceived_ConnectionError() { - HttpClientHandler handler = CreateHttpClientHandler(); - handler.ServerCertificateCustomValidationCallback = TestHelper.AllowAllCertificates; - using (var server = Http2LoopbackServer.CreateServer()) - using (var client = new HttpClient(handler)) + using (var client = CreateHttpClient()) { Task sendTask = client.GetAsync(server.Address); @@ -538,11 +682,8 @@ namespace System.Net.Http.Functional.Tests [ConditionalFact(nameof(SupportsAlpn))] public async Task GoAwayFrame_NoPendingStreams_ConnectionClosed() { - HttpClientHandler handler = CreateHttpClientHandler(); - handler.ServerCertificateCustomValidationCallback = TestHelper.AllowAllCertificates; - using (var server = Http2LoopbackServer.CreateServer()) - using (var client = new HttpClient(handler)) + using (var client = CreateHttpClient()) { int streamId = await EstablishConnectionAndProcessOneRequestAsync(client, server); @@ -561,11 +702,8 @@ namespace System.Net.Http.Functional.Tests [ConditionalFact(nameof(SupportsAlpn))] public async Task GoAwayFrame_AllPendingStreamsValid_RequestsSucceedAndConnectionClosed() { - HttpClientHandler handler = CreateHttpClientHandler(); - handler.ServerCertificateCustomValidationCallback = TestHelper.AllowAllCertificates; - using (var server = Http2LoopbackServer.CreateServer()) - using (var client = new HttpClient(handler)) + using (var client = CreateHttpClient()) { await EstablishConnectionAndProcessOneRequestAsync(client, server); @@ -656,14 +794,10 @@ namespace System.Net.Http.Functional.Tests const int InitialWindowSize = 65535; const int ContentSize = 100_000; - HttpClientHandler handler = CreateHttpClientHandler(); - handler.ServerCertificateCustomValidationCallback = TestHelper.AllowAllCertificates; - TestHelper.EnsureHttp2Feature(handler); - var content = new ByteArrayContent(TestHelper.GenerateRandomContent(ContentSize)); using (var server = Http2LoopbackServer.CreateServer()) - using (var client = new HttpClient(handler)) + using (var client = CreateHttpClient()) { Task clientTask = client.PostAsync(server.Address, content); @@ -764,14 +898,10 @@ namespace System.Net.Http.Functional.Tests const int DefaultInitialWindowSize = 65535; const int ContentSize = 100_000; - HttpClientHandler handler = CreateHttpClientHandler(); - handler.ServerCertificateCustomValidationCallback = TestHelper.AllowAllCertificates; - TestHelper.EnsureHttp2Feature(handler); - var content = new ByteArrayContent(TestHelper.GenerateRandomContent(ContentSize)); using (var server = Http2LoopbackServer.CreateServer()) - using (var client = new HttpClient(handler)) + using (var client = CreateHttpClient()) { Task clientTask = client.PostAsync(server.Address, content); @@ -889,11 +1019,8 @@ namespace System.Net.Http.Functional.Tests [ConditionalFact(nameof(SupportsAlpn))] public async Task Http2_MaxConcurrentStreams_LimitEnforced() { - HttpClientHandler handler = CreateHttpClientHandler(); - handler.ServerCertificateCustomValidationCallback = TestHelper.AllowAllCertificates; - using (var server = Http2LoopbackServer.CreateServer()) - using (var client = new HttpClient(handler)) + using (var client = CreateHttpClient()) { Task sendTask = client.GetAsync(server.Address); -- 2.7.4