Add logging test for SocketsHttpHandler.PlaintextStreamFilter (#42690)
authorStephen Toub <stoub@microsoft.com>
Fri, 25 Sep 2020 02:46:57 +0000 (22:46 -0400)
committerGitHub <noreply@github.com>
Fri, 25 Sep 2020 02:46:57 +0000 (22:46 -0400)
* Add logging test for SocketsHttpHandler.PlaintextStreamFilter

* Make ByteLoggingStream more efficient

src/libraries/Common/tests/System/IO/ByteLoggingStream.cs [new file with mode: 0644]
src/libraries/Common/tests/System/Net/Http/LoopbackServer.cs
src/libraries/System.Net.Http/tests/FunctionalTests/SocketsHttpHandlerTest.cs
src/libraries/System.Net.Http/tests/FunctionalTests/System.Net.Http.Functional.Tests.csproj

diff --git a/src/libraries/Common/tests/System/IO/ByteLoggingStream.cs b/src/libraries/Common/tests/System/IO/ByteLoggingStream.cs
new file mode 100644 (file)
index 0000000..cd88159
--- /dev/null
@@ -0,0 +1,205 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+#nullable enable
+using System.IO;
+using System.Runtime.InteropServices;
+using System.Text;
+using System.Threading;
+using System.Threading.Tasks;
+
+namespace System.IO
+{
+    internal sealed class BytesLoggingStream : Stream
+    {
+        public delegate void FormattedBytesCallback(Stream stream, ReadOnlySpan<char> hex, ReadOnlySpan<char> ascii);
+
+        [ThreadStatic]
+        private static char[]? s_hexBuffer;
+
+        [ThreadStatic]
+        private static char[]? s_asciiBuffer;
+
+        private readonly Stream _stream;
+        private readonly FormattedBytesCallback _readCallback;
+        private readonly FormattedBytesCallback _writeCallback;
+        private int _bytesPerLine = 24;
+
+        public BytesLoggingStream(Stream stream, FormattedBytesCallback writeCallback, FormattedBytesCallback readCallback)
+        {
+            _stream = stream;
+            _readCallback = readCallback;
+            _writeCallback = writeCallback;
+        }
+
+        public override bool CanRead => _stream.CanRead;
+        public override bool CanSeek => _stream.CanSeek;
+        public override bool CanWrite => _stream.CanWrite;
+        public override bool CanTimeout => _stream.CanTimeout;
+
+        public override long Length => _stream.Length;
+        public override long Position { get => _stream.Position; set => _stream.Position = value; }
+
+        public override void Flush() => _stream.Flush();
+        public override Task FlushAsync(CancellationToken cancellationToken) => _stream.FlushAsync(cancellationToken);
+
+        public override long Seek(long offset, SeekOrigin origin) => _stream.Seek(offset, origin);
+        public override void SetLength(long value) => _stream.SetLength(value);
+
+        public override int ReadTimeout { get => _stream.ReadTimeout; set => _stream.ReadTimeout = value; }
+        public override int WriteTimeout { get => _stream.WriteTimeout; set => _stream.WriteTimeout = value; }
+
+        protected override void Dispose(bool disposing)
+        {
+            if (disposing)
+            {
+                _stream.Dispose();
+            }
+        }
+
+        public int BytesPerLine
+        {
+            get => _bytesPerLine;
+            set
+            {
+                if (value < 1) throw new ArgumentOutOfRangeException(nameof(BytesPerLine));
+                _bytesPerLine = value;
+            }
+        }
+
+        public override int ReadByte()
+        {
+            int read = _stream.ReadByte();
+            if (read != -1)
+            {
+                byte b = (byte)read;
+                FormatBytes(read: true, MemoryMarshal.CreateReadOnlySpan(ref b, 1));
+            }
+            return read;
+        }
+
+        public override int Read(Span<byte> buffer)
+        {
+            int read = _stream.Read(buffer);
+            FormatBytes(read: true, buffer);
+            return read;
+        }
+
+        public override int Read(byte[] buffer, int offset, int count)
+        {
+            int read = _stream.Read(buffer, offset, count);
+            FormatBytes(read: true, buffer.AsSpan(offset, read));
+            return read;
+        }
+
+        public override async Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
+        {
+            int read = await _stream.ReadAsync(buffer, offset, count, cancellationToken).ConfigureAwait(false);
+            FormatBytes(read: true, buffer.AsSpan(offset, read));
+            return read;
+        }
+
+        public override async ValueTask<int> ReadAsync(Memory<byte> buffer, CancellationToken cancellationToken = default)
+        {
+            int read = await _stream.ReadAsync(buffer, cancellationToken).ConfigureAwait(false);
+            FormatBytes(read: true, buffer.Span.Slice(0, read));
+            return read;
+        }
+
+        public override void WriteByte(byte value)
+        {
+            FormatBytes(read: false, MemoryMarshal.CreateReadOnlySpan(ref value, 1));
+            _stream.WriteByte(value);
+        }
+
+        public override void Write(ReadOnlySpan<byte> buffer)
+        {
+            FormatBytes(read: false, buffer);
+            _stream.Write(buffer);
+        }
+
+        public override void Write(byte[] buffer, int offset, int count)
+        {
+            FormatBytes(read: false, buffer.AsSpan(offset, count));
+            _stream.Write(buffer, offset, count);
+        }
+
+        public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
+        {
+            FormatBytes(read: false, buffer.AsSpan(offset, count));
+            return _stream.WriteAsync(buffer, offset, count, cancellationToken);
+        }
+
+        public override ValueTask WriteAsync(ReadOnlyMemory<byte> buffer, CancellationToken cancellationToken = default)
+        {
+            FormatBytes(read: false, buffer.Span);
+            return _stream.WriteAsync(buffer, cancellationToken);
+        }
+
+        public override IAsyncResult BeginWrite(byte[] buffer, int offset, int count, AsyncCallback? callback, object? state)
+        {
+            FormatBytes(read: false, buffer.AsSpan(offset, count));
+            return _stream.BeginWrite(buffer, offset, count, callback, state);
+        }
+
+        public override void EndWrite(IAsyncResult asyncResult) =>
+            _stream.EndWrite(asyncResult);
+
+        private void FormatBytes(bool read, ReadOnlySpan<byte> bytes)
+        {
+            if (bytes.IsEmpty)
+            {
+                return;
+            }
+
+            ReadOnlySpan<byte> hex = new byte[]
+            {
+                (byte)'0', (byte)'1', (byte)'2', (byte)'3', (byte)'4', (byte)'5', (byte)'6', (byte)'7',
+                (byte)'8', (byte)'9', (byte)'A', (byte)'B', (byte)'C', (byte)'D', (byte)'E', (byte)'F',
+            };
+
+            int bytesPerLine = _bytesPerLine;
+            int requiredHexLength = bytesPerLine * 3 - 1;
+
+            char[]? hexBuffer = s_hexBuffer;
+            if (hexBuffer is null || hexBuffer.Length < requiredHexLength)
+            {
+                s_hexBuffer = hexBuffer = new char[requiredHexLength];
+            }
+
+            char[]? asciiBuffer = s_asciiBuffer;
+            if (asciiBuffer is null || asciiBuffer.Length < bytesPerLine)
+            {
+                s_asciiBuffer = asciiBuffer = new char[bytesPerLine];
+            }
+
+            while (!bytes.IsEmpty)
+            {
+                ReadOnlySpan<byte> span = bytes.Slice(0, Math.Min(bytes.Length, bytesPerLine));
+                int hexPos = 0;
+                int asciiPos = 0;
+
+                for (int i = 0; i < span.Length; i++)
+                {
+                    byte b = span[i];
+                    hexBuffer[hexPos++] = (char)hex[b >> 4];
+                    hexBuffer[hexPos++] = (char)hex[b & 0XF];
+                    if (i != span.Length - 1)
+                    {
+                        hexBuffer[hexPos++] = ' ';
+                    }
+
+                    asciiBuffer[asciiPos++] =
+                        b switch
+                        {
+                            < 32 or >= 0x7F => '.',
+                            _ => (char)b,
+                        };
+                }
+
+                (read ? _readCallback : _writeCallback)(this, new ReadOnlySpan<char>(hexBuffer, 0, hexPos), new ReadOnlySpan<char>(asciiBuffer, 0, asciiPos));
+                bytes = bytes.Slice(span.Length);
+            }
+        }
+    }
+}
index f335a5d..214ebd3 100644 (file)
@@ -916,7 +916,7 @@ namespace System.Net.Test.Common
                 newHeaders.Add(new HttpHeaderData("Connection", "Close"));
                 if (!hasDate)
                 {
-                    newHeaders.Add(new HttpHeaderData("Date", "{DateTimeOffset.UtcNow:R}"));
+                    newHeaders.Add(new HttpHeaderData("Date", $"{DateTimeOffset.UtcNow:R}"));
                 }
 
                 await SendResponseAsync(statusCode, newHeaders, content: content).ConfigureAwait(false);
index 777bda4..e880668 100644 (file)
@@ -2309,7 +2309,7 @@ namespace System.Net.Http.Functional.Tests
         }
 
         [Fact]
-        public async void ConnectCallback_ContextHasCorrectProperties_Success()
+        public async Task ConnectCallback_ContextHasCorrectProperties_Success()
         {
             await LoopbackServerFactory.CreateClientAndServerAsync(
                 async uri =>
@@ -2592,7 +2592,7 @@ namespace System.Net.Http.Functional.Tests
         [Theory]
         [InlineData(true)]
         [InlineData(false)]
-        public async void PlaintextStreamFilter_ContextHasCorrectProperties_Success(bool useSsl)
+        public async Task PlaintextStreamFilter_ContextHasCorrectProperties_Success(bool useSsl)
         {
             GenericLoopbackOptions options = new GenericLoopbackOptions() { UseSsl = useSsl };
             await LoopbackServerFactory.CreateClientAndServerAsync(
@@ -2627,7 +2627,7 @@ namespace System.Net.Http.Functional.Tests
         [Theory]
         [InlineData(true)]
         [InlineData(false)]
-        public async void PlaintextStreamFilter_SimpleDelegatingStream_Success(bool useSsl)
+        public async Task PlaintextStreamFilter_SimpleDelegatingStream_Success(bool useSsl)
         {
             GenericLoopbackOptions options = new GenericLoopbackOptions() { UseSsl = useSsl };
             await LoopbackServerFactory.CreateClientAndServerAsync(
@@ -2745,7 +2745,7 @@ namespace System.Net.Http.Functional.Tests
         [Theory]
         [InlineData(true)]
         [InlineData(false)]
-        public async void PlaintextStreamFilter_ExceptionDuringCallback_ThrowsHttpRequestExceptionWithInnerException(bool useSsl)
+        public async Task PlaintextStreamFilter_ExceptionDuringCallback_ThrowsHttpRequestExceptionWithInnerException(bool useSsl)
         {
             Exception e = new Exception("hello!");
 
@@ -2783,7 +2783,7 @@ namespace System.Net.Http.Functional.Tests
         [Theory]
         [InlineData(true)]
         [InlineData(false)]
-        public async void PlaintextStreamFilter_ReturnsNull_ThrowsHttpRequestException(bool useSsl)
+        public async Task PlaintextStreamFilter_ReturnsNull_ThrowsHttpRequestException(bool useSsl)
         {
             GenericLoopbackOptions options = new GenericLoopbackOptions() { UseSsl = useSsl };
             await LoopbackServerFactory.CreateClientAndServerAsync(
@@ -2823,7 +2823,7 @@ namespace System.Net.Http.Functional.Tests
         [Theory]
         [InlineData(true)]
         [InlineData(false)]
-        public async void PlaintextStreamFilter_CustomStream_Success(bool useSsl)
+        public async Task PlaintextStreamFilter_CustomStream_Success(bool useSsl)
         {
             GenericLoopbackOptions options = new GenericLoopbackOptions() { UseSsl = useSsl };
             await LoopbackServerFactory.CreateClientAndServerAsync(
@@ -2866,6 +2866,55 @@ namespace System.Net.Http.Functional.Tests
                     catch (IOException) { }
                 }, options: options);
         }
+
+        [Theory]
+        [InlineData(false)]
+        [InlineData(true)]
+        public async Task PlaintextStreamFilter_Logging_Success(bool useSsl)
+        {
+            bool log = int.TryParse(Environment.GetEnvironmentVariable("DOTNET_TEST_SOCKETSHTTPHANDLERLOG"), out int value) && value == 1;
+
+            GenericLoopbackOptions options = new GenericLoopbackOptions() { UseSsl = useSsl };
+            await LoopbackServerFactory.CreateClientAndServerAsync(
+                async uri =>
+                {
+                    string sendText = "";
+                    string recvText = "";
+
+                    using HttpClientHandler handler = CreateHttpClientHandler();
+                    handler.ServerCertificateCustomValidationCallback = TestHelper.AllowAllCertificates;
+                    var socketsHandler = (SocketsHttpHandler)GetUnderlyingSocketsHttpHandler(handler);
+                    socketsHandler.PlaintextStreamFilter = (context, token) =>
+                    {
+                        Assert.Equal(HttpVersion.Version11, context.NegotiatedHttpVersion);
+
+                        static void Log(ref string text, bool log, string prefix, Stream stream, ReadOnlySpan<char> hex, ReadOnlySpan<char> ascii)
+                        {
+                            if (log) Console.WriteLine($"[{prefix} {stream.GetHashCode():X8}] {hex.ToString().PadRight(71)}  {ascii.ToString()}");
+                            text += ascii.ToString();
+                        }
+
+                        return ValueTask.FromResult<Stream>(new BytesLoggingStream(
+                            context.PlaintextStream,
+                            (stream, hex, ascii) => Log(ref sendText, log, "SEND", stream, hex, ascii),
+                            (stream, hex, ascii) => Log(ref recvText, log, "RECV", stream, hex, ascii)));
+                    };
+
+                    using HttpClient client = CreateHttpClient(handler);
+                    using HttpResponseMessage response = await client.GetAsync(uri);
+                    Assert.Equal("hello", await response.Content.ReadAsStringAsync());
+
+                    Assert.Contains("GET / HTTP/1.1", sendText);
+                    Assert.Contains("Host: ", sendText);
+
+                    Assert.Contains("HTTP/1.1 200 OK", recvText);
+                    Assert.Contains("hello", recvText);
+                },
+                async server =>
+                {
+                    await server.AcceptConnectionSendResponseAndCloseAsync(content: "hello");
+                }, options: options);
+        }
     }
 
     [ConditionalClass(typeof(PlatformDetection), nameof(PlatformDetection.SupportsAlpn))]
index 75b93f3..95d2fdf 100644 (file)
@@ -92,6 +92,8 @@
              Link="Common\System\Diagnostics\Tracing\TestEventListener.cs" />
     <Compile Include="$(CommonTestPath)System\Diagnostics\Tracing\ConsoleEventListener.cs"
              Link="Common\System\Diagnostics\Tracing\ConsoleEventListener.cs" />
+    <Compile Include="$(CommonTestPath)System\IO\ByteLoggingStream.cs"
+             Link="Common\System\IO\ByteLoggingStream.cs" />
     <Compile Include="$(CommonTestPath)System\IO\DelegateStream.cs"
              Link="Common\System\IO\DelegateStream.cs" />
     <Compile Include="$(CommonTestPath)System\Net\RemoteServerQuery.cs"