public class Http2LoopbackServer : IDisposable
{
private Socket _listenSocket;
+ private Socket _connectionSocket;
private Stream _connectionStream;
private Http2Options _options;
private Uri _uri;
await _connectionStream.WriteAsync(writeBuffer, 0, writeBuffer.Length).ConfigureAwait(false);
}
+ // Read until the buffer is full
+ // Return false on EOF, throw on partial read
+ private async Task<bool> FillBufferAsync(Memory<byte> buffer, CancellationToken cancellationToken = default(CancellationToken))
+ {
+ int readBytes = await _connectionStream.ReadAsync(buffer, cancellationToken).ConfigureAwait(false);
+ if (readBytes == 0)
+ {
+ return false;
+ }
+
+ buffer = buffer.Slice(readBytes);
+ while (buffer.Length > 0)
+ {
+ readBytes = await _connectionStream.ReadAsync(buffer, cancellationToken).ConfigureAwait(false);
+ if (readBytes == 0)
+ {
+ throw new Exception("Connection closed when expecting more data.");
+ }
+
+ buffer = buffer.Slice(readBytes);
+ }
+
+ return true;
+ }
+
public async Task<Frame> ReadFrameAsync(TimeSpan timeout)
{
// Prep the timeout cancellation token.
// First read the frame headers, which should tell us how long the rest of the frame is.
byte[] headerBytes = new byte[Frame.FrameHeaderLength];
-
- int totalReadBytes = 0;
- while(totalReadBytes < Frame.FrameHeaderLength)
+ if (!await FillBufferAsync(headerBytes, timeoutCts.Token).ConfigureAwait(false))
{
- int readBytes = await _connectionStream.ReadAsync(headerBytes, totalReadBytes, Frame.FrameHeaderLength - totalReadBytes, timeoutCts.Token).ConfigureAwait(false);
- totalReadBytes += readBytes;
- if (readBytes == 0)
- {
- throw new Exception("Connection stream closed while attempting to read frame header.");
- }
+ return null;
}
Frame header = Frame.ReadFrom(headerBytes);
// Read the data segment of the frame, if it is present.
byte[] data = new byte[header.Length];
-
- totalReadBytes = 0;
- while(totalReadBytes < header.Length)
+ if (header.Length > 0 && !await FillBufferAsync(data, timeoutCts.Token).ConfigureAwait(false))
{
- int readBytes = await _connectionStream.ReadAsync(data, totalReadBytes, header.Length - totalReadBytes, timeoutCts.Token).ConfigureAwait(false);
- totalReadBytes += readBytes;
- if (readBytes == 0)
- {
- throw new Exception("Connection stream closed while attempting to read frame body.");
- }
+ throw new Exception("Connection stream closed while attempting to read frame body.");
}
if (_ignoreSettingsAck && header.Type == FrameType.Settings && header.Flags == FrameFlags.Ack)
// Returns the first 24 bytes read, which should be the connection preface.
public async Task<string> AcceptConnectionAsync()
{
- _connectionStream = new NetworkStream(await _listenSocket.AcceptAsync().ConfigureAwait(false), true);
+ if (_connectionSocket != null)
+ {
+ throw new InvalidOperationException("Connection already established");
+ }
+
+ _connectionSocket = await _listenSocket.AcceptAsync().ConfigureAwait(false);
+ _connectionStream = new NetworkStream(_connectionSocket, true);
if (_options.UseSsl)
{
}
_connectionStream = sslStream;
}
+
byte[] prefix = new byte[24];
-
- int totalReadBytes = 0;
- while(totalReadBytes < Frame.FrameHeaderLength)
+ if (!await FillBufferAsync(prefix).ConfigureAwait(false))
{
- int readBytes = await _connectionStream.ReadAsync(prefix, totalReadBytes, prefix.Length).ConfigureAwait(false);;
- totalReadBytes += readBytes;
- if (readBytes == 0)
- {
- throw new Exception("Connection stream closed while attempting to read connection preface.");
- }
+ throw new Exception("Connection stream closed while attempting to read connection preface.");
}
return System.Text.Encoding.UTF8.GetString(prefix, 0, prefix.Length);
ExpectSettingsAck();
}
+ // This will wait for the client to close the connection,
+ // and ignore any meaningless frames -- i.e. WINDOW_UPDATE or expected SETTINGS ACK --
+ // that we see while waiting for the client to close.
+ // Only call this after sending a GOAWAY.
+ public async Task WaitForConnectionShutdownAsync()
+ {
+ IgnoreWindowUpdates();
+
+ // Shutdown our send side, so the client knows there won't be any more frames coming.
+ _connectionSocket.Shutdown(SocketShutdown.Send);
+
+ Frame frame = await ReadFrameAsync(TimeSpan.FromSeconds(30)).ConfigureAwait(false);
+ if (frame != null)
+ {
+ throw new Exception($"Unexpected frame received while waiting for client shutdown: {frame}");
+ }
+
+ _connectionStream.Close();
+
+ _connectionSocket = null;
+ _connectionStream = null;
+
+ _ignoreSettingsAck = false;
+ _ignoreWindowUpdates = false;
+ }
+
public async Task<int> ReadRequestHeaderAsync()
{
// Receive HEADERS frame for request.
await WriteFrameAsync(headersFrame).ConfigureAwait(false);
}
+ public async Task SendResponseBodyAsync(int streamId, byte[] data, bool endStream = false)
+ {
+ DataFrame dataFrame = new DataFrame(data, endStream ? FrameFlags.EndStream : FrameFlags.None, 0, streamId);
+ await WriteFrameAsync(dataFrame).ConfigureAwait(false);
+ }
+
public void Dispose()
{
if (_listenSocket != null)
}
}
+ private static async Task<int> EstablishConnectionAndProcessOneRequestAsync(HttpClient client, Http2LoopbackServer server)
+ {
+ // Establish connection and send initial request/response to ensure connection is available for subsequent use
+ Task<HttpResponseMessage> sendTask = client.GetAsync(server.Address);
+
+ await server.EstablishConnectionAsync();
+
+ int streamId = await server.ReadRequestHeaderAsync();
+ await server.SendDefaultResponseAsync(streamId);
+
+ HttpResponseMessage response = await sendTask;
+ Assert.Equal(HttpStatusCode.OK, response.StatusCode);
+ Assert.Equal(0, (await response.Content.ReadAsByteArrayAsync()).Length);
+
+ return streamId;
+ }
+
+ [ConditionalFact(nameof(SupportsAlpn))]
+ public async Task GoAwayFrame_NoPendingStreams_ConnectionClosed()
+ {
+ HttpClientHandler handler = CreateHttpClientHandler();
+ handler.ServerCertificateCustomValidationCallback = TestHelper.AllowAllCertificates;
+
+ using (var server = Http2LoopbackServer.CreateServer())
+ using (var client = new HttpClient(handler))
+ {
+ int streamId = await EstablishConnectionAndProcessOneRequestAsync(client, server);
+
+ // Send GOAWAY.
+ GoAwayFrame goAwayFrame = new GoAwayFrame(streamId, 0, new byte[0], 0);
+ await server.WriteFrameAsync(goAwayFrame);
+
+ // The client should close the connection.
+ await server.WaitForConnectionShutdownAsync();
+
+ // New request should cause a new connection
+ await EstablishConnectionAndProcessOneRequestAsync(client, server);
+ }
+ }
+
+ [ConditionalFact(nameof(SupportsAlpn))]
+ public async Task GoAwayFrame_AllPendingStreamsValid_RequestsSucceedAndConnectionClosed()
+ {
+ HttpClientHandler handler = CreateHttpClientHandler();
+ handler.ServerCertificateCustomValidationCallback = TestHelper.AllowAllCertificates;
+
+ using (var server = Http2LoopbackServer.CreateServer())
+ using (var client = new HttpClient(handler))
+ {
+ await EstablishConnectionAndProcessOneRequestAsync(client, server);
+
+ // Issue three requests
+ Task<HttpResponseMessage> sendTask1 = client.GetAsync(server.Address);
+ Task<HttpResponseMessage> sendTask2 = client.GetAsync(server.Address);
+ Task<HttpResponseMessage> sendTask3 = client.GetAsync(server.Address);
+
+ // Receive three requests
+ int streamId1 = await server.ReadRequestHeaderAsync();
+ int streamId2 = await server.ReadRequestHeaderAsync();
+ int streamId3 = await server.ReadRequestHeaderAsync();
+
+ Assert.True(streamId1 < streamId2);
+ Assert.True(streamId2 < streamId3);
+
+ // Send various partial responses
+
+ // First response: Don't send anything yet
+
+ // Second response: Send headers, no body yet
+ await server.SendDefaultResponseHeadersAsync(streamId2);
+
+ // Third response: Send headers, partial body
+ await server.SendDefaultResponseHeadersAsync(streamId3);
+ await server.SendResponseBodyAsync(streamId3, new byte[5], endStream: false);
+
+ // Send a GOAWAY frame that indicates that we will process all three streams
+ GoAwayFrame goAwayFrame = new GoAwayFrame(streamId3, 0, new byte[0], 0);
+ await server.WriteFrameAsync(goAwayFrame);
+
+ // Finish sending responses
+ await server.SendDefaultResponseHeadersAsync(streamId1);
+ await server.SendResponseBodyAsync(streamId1, new byte[10], endStream: true);
+ await server.SendResponseBodyAsync(streamId2, new byte[10], endStream: true);
+ await server.SendResponseBodyAsync(streamId3, new byte[5], endStream: true);
+
+ // Receive all responses
+ HttpResponseMessage response1 = await sendTask1;
+ Assert.Equal(HttpStatusCode.OK, response1.StatusCode);
+ Assert.Equal(10, (await response1.Content.ReadAsByteArrayAsync()).Length);
+ HttpResponseMessage response2 = await sendTask2;
+ Assert.Equal(HttpStatusCode.OK, response2.StatusCode);
+ Assert.Equal(10, (await response2.Content.ReadAsByteArrayAsync()).Length);
+ HttpResponseMessage response3 = await sendTask3;
+ Assert.Equal(HttpStatusCode.OK, response3.StatusCode);
+ Assert.Equal(10, (await response3.Content.ReadAsByteArrayAsync()).Length);
+
+ // Now that all pending responses have been sent, the client should close the connection.
+ await server.WaitForConnectionShutdownAsync();
+
+ // New request should cause a new connection
+ await EstablishConnectionAndProcessOneRequestAsync(client, server);
+ }
+ }
+
private static async Task<int> ReadToEndOfStream(Http2LoopbackServer server, int streamId)
{
int bytesReceived = 0;