[release/6.0-rc1] Enable SocketsHttpHandler to decompress zlib or deflate (#57940)
authorgithub-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Mon, 23 Aug 2021 19:57:39 +0000 (15:57 -0400)
committerGitHub <noreply@github.com>
Mon, 23 Aug 2021 19:57:39 +0000 (15:57 -0400)
* Enable SocketsHttpHandler to decompress zlib or deflate

Some servers incorrectly implement the deflate content-coding with the raw deflate algorithm rather than with deflate wrapped with a zlib header/footer.  Auto-detect whether to use ZLibStream or DeflateStream in order to accommodate both kinds of responses.

* Fix test build for WinHttpHandler on .NET Framework

* Apply suggestions from code review

* Add decompression test for empty response body

* Add decompression tests for multiple source content lengths

Co-authored-by: Stephen Toub <stoub@microsoft.com>
src/libraries/Common/tests/System/Net/Http/HttpClientHandlerTest.Decompression.cs
src/libraries/Common/tests/System/Net/Http/HttpClientHandlerTest.RemoteServer.cs
src/libraries/Common/tests/System/Net/Http/HttpClientHandlerTestBase.cs
src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/DecompressionHandler.cs

index 88ac334..439cadf 100644 (file)
@@ -4,6 +4,7 @@
 using System.Collections.Generic;
 using System.IO;
 using System.IO.Compression;
+using System.Linq;
 using System.Net.Test.Common;
 using System.Text.RegularExpressions;
 using System.Threading.Tasks;
@@ -26,19 +27,29 @@ namespace System.Net.Http.Functional.Tests
 #endif
         public HttpClientHandler_Decompression_Test(ITestOutputHelper output) : base(output) { }
 
+        public static IEnumerable<object[]> DecompressedResponse_MethodSpecified_DecompressedContentReturned_MemberData() =>
+            from compressionName in new[] { "gzip", "zlib", "deflate", "br" }
+            from all in new[] { false, true }
+            from copyTo in new[] { false, true }
+            from contentLength in new[] { 0, 1, 12345 }
+            select new object[] { compressionName, all, copyTo, contentLength };
+
         [Theory]
-        [InlineData("gzip", false)]
-        [InlineData("gzip", true)]
-        [InlineData("deflate", false)]
-        [InlineData("deflate", true)]
-        [InlineData("br", false)]
-        [InlineData("br", true)]
+        [MemberData(nameof(DecompressedResponse_MethodSpecified_DecompressedContentReturned_MemberData))]
         [SkipOnPlatform(TestPlatforms.Browser, "AutomaticDecompression not supported on Browser")]
-        public async Task DecompressedResponse_MethodSpecified_DecompressedContentReturned(string encodingName, bool all)
+        public async Task DecompressedResponse_MethodSpecified_DecompressedContentReturned(string compressionName, bool all, bool useCopyTo, int contentLength)
         {
+            if (IsWinHttpHandler &&
+                (compressionName == "br" || compressionName == "zlib"))
+            {
+                // brotli and zlib not supported on WinHttpHandler
+                return;
+            }
+
             Func<Stream, Stream> compress;
             DecompressionMethods methods;
-            switch (encodingName)
+            string encodingName = compressionName;
+            switch (compressionName)
             {
                 case "gzip":
                     compress = s => new GZipStream(s, CompressionLevel.Optimal, leaveOpen: true);
@@ -47,32 +58,27 @@ namespace System.Net.Http.Functional.Tests
 
 #if !NETFRAMEWORK
                 case "br":
-                    if (IsWinHttpHandler)
-                    {
-                        // Brotli only supported on SocketsHttpHandler.
-                        return;
-                    }
-
                     compress = s => new BrotliStream(s, CompressionLevel.Optimal, leaveOpen: true);
                     methods = all ? DecompressionMethods.Brotli : _all;
                     break;
 
-                case "deflate":
-                    // WinHttpHandler continues to use DeflateStream as it doesn't have a newer build than netstandard2.0
-                    // and doesn't have access to ZLibStream.
-                    compress = IsWinHttpHandler ?
-                        new Func<Stream, Stream>(s => new DeflateStream(s, CompressionLevel.Optimal, leaveOpen: true)) :
-                        new Func<Stream, Stream>(s => new ZLibStream(s, CompressionLevel.Optimal, leaveOpen: true));
+                case "zlib":
+                    compress = s => new ZLibStream(s, CompressionLevel.Optimal, leaveOpen: true);
                     methods = all ? DecompressionMethods.Deflate : _all;
+                    encodingName = "deflate";
                     break;
 #endif
 
+                case "deflate":
+                    compress = s => new DeflateStream(s, CompressionLevel.Optimal, leaveOpen: true);
+                    methods = all ? DecompressionMethods.Deflate : _all;
+                    break;
+
                 default:
-                    Assert.Contains(encodingName, new[] { "br", "deflate", "gzip" });
-                    return;
+                    throw new Exception($"Unexpected compression: {compressionName}");
             }
 
-            var expectedContent = new byte[12345];
+            var expectedContent = new byte[contentLength];
             new Random(42).NextBytes(expectedContent);
 
             await LoopbackServer.CreateClientAndServerAsync(async uri =>
@@ -81,7 +87,7 @@ namespace System.Net.Http.Functional.Tests
                 using (HttpClient client = CreateHttpClient(handler))
                 {
                     handler.AutomaticDecompression = methods;
-                    Assert.Equal<byte>(expectedContent, await client.GetByteArrayAsync(uri));
+                    AssertExtensions.SequenceEqual(expectedContent, await client.GetByteArrayAsync(TestAsync, useCopyTo, uri));
                 }
             }, async server =>
             {
@@ -99,33 +105,39 @@ namespace System.Net.Http.Functional.Tests
 
         public static IEnumerable<object[]> DecompressedResponse_MethodNotSpecified_OriginalContentReturned_MemberData()
         {
-            yield return new object[]
+            foreach (bool useCopyTo in new[] { false, true })
             {
-                "gzip",
-                new Func<Stream, Stream>(s => new GZipStream(s, CompressionLevel.Optimal, leaveOpen: true)),
-                DecompressionMethods.None
-            };
+                yield return new object[]
+                {
+                    "gzip",
+                    new Func<Stream, Stream>(s => new GZipStream(s, CompressionLevel.Optimal, leaveOpen: true)),
+                    DecompressionMethods.None,
+                    useCopyTo
+                };
 #if !NETFRAMEWORK
-            yield return new object[]
-            {
-                "deflate",
-                new Func<Stream, Stream>(s => new ZLibStream(s, CompressionLevel.Optimal, leaveOpen: true)),
-                DecompressionMethods.Brotli
-            };
-            yield return new object[]
-            {
-                "br",
-                new Func<Stream, Stream>(s => new BrotliStream(s, CompressionLevel.Optimal, leaveOpen: true)),
-                DecompressionMethods.Deflate | DecompressionMethods.GZip
-            };
+                yield return new object[]
+                {
+                    "deflate",
+                    new Func<Stream, Stream>(s => new ZLibStream(s, CompressionLevel.Optimal, leaveOpen: true)),
+                    DecompressionMethods.Brotli,
+                    useCopyTo
+                };
+                yield return new object[]
+                {
+                    "br",
+                    new Func<Stream, Stream>(s => new BrotliStream(s, CompressionLevel.Optimal, leaveOpen: true)),
+                    DecompressionMethods.Deflate | DecompressionMethods.GZip,
+                    useCopyTo
+                };
 #endif
+            }
         }
 
         [Theory]
         [MemberData(nameof(DecompressedResponse_MethodNotSpecified_OriginalContentReturned_MemberData))]
         [SkipOnPlatform(TestPlatforms.Browser, "AutomaticDecompression not supported on Browser")]
         public async Task DecompressedResponse_MethodNotSpecified_OriginalContentReturned(
-            string encodingName, Func<Stream, Stream> compress, DecompressionMethods methods)
+            string encodingName, Func<Stream, Stream> compress, DecompressionMethods methods, bool useCopyTo)
         {
             var expectedContent = new byte[12345];
             new Random(42).NextBytes(expectedContent);
@@ -143,7 +155,7 @@ namespace System.Net.Http.Functional.Tests
                 using (HttpClient client = CreateHttpClient(handler))
                 {
                     handler.AutomaticDecompression = methods;
-                    Assert.Equal<byte>(compressedContent, await client.GetByteArrayAsync(uri));
+                    AssertExtensions.SequenceEqual(compressedContent, await client.GetByteArrayAsync(TestAsync, useCopyTo, uri));
                 }
             }, async server =>
             {
@@ -157,6 +169,33 @@ namespace System.Net.Http.Functional.Tests
         }
 
         [Theory]
+        [InlineData("gzip", DecompressionMethods.GZip)]
+#if !NETFRAMEWORK
+        [InlineData("deflate", DecompressionMethods.Deflate)]
+        [InlineData("br", DecompressionMethods.Brotli)]
+#endif
+        [SkipOnPlatform(TestPlatforms.Browser, "AutomaticDecompression not supported on Browser")]
+        public async Task DecompressedResponse_EmptyBody_Success(string encodingName, DecompressionMethods methods)
+        {
+            await LoopbackServer.CreateClientAndServerAsync(async uri =>
+            {
+                using (HttpClientHandler handler = CreateHttpClientHandler())
+                using (HttpClient client = CreateHttpClient(handler))
+                {
+                    handler.AutomaticDecompression = methods;
+                    Assert.Equal(Array.Empty<byte>(), await client.GetByteArrayAsync(TestAsync, useCopyTo: false, uri));
+                }
+            }, async server =>
+            {
+                await server.AcceptConnectionAsync(async connection =>
+                {
+                    await connection.ReadRequestHeaderAsync();
+                    await connection.WriteStringAsync($"HTTP/1.1 200 OK\r\nContent-Encoding: {encodingName}\r\n\r\n");
+                });
+            });
+        }
+
+        [Theory]
 #if NETCOREAPP
         [InlineData(DecompressionMethods.Brotli, "br", "")]
         [InlineData(DecompressionMethods.Brotli, "br", "br")]
index 9fd29a9..a1e86c6 100644 (file)
@@ -1237,8 +1237,9 @@ namespace System.Net.Http.Functional.Tests
             {
                 yield return new object[] { remoteServer, remoteServer.GZipUri };
 
-                // Remote deflate endpoint isn't correctly following the deflate protocol.
-                //yield return new object[] { remoteServer, remoteServer.DeflateUri };
+                // Remote deflate endpoint isn't correctly following the deflate protocol,
+                // but SocketsHttpHandler makes it work, anyway.
+                yield return new object[] { remoteServer, remoteServer.DeflateUri };
             }
         }
 
@@ -1271,10 +1272,6 @@ namespace System.Net.Http.Functional.Tests
             }
         }
 
-        // The remote server endpoint was written to use DeflateStream, which isn't actually a correct
-        // implementation of the deflate protocol (the deflate protocol requires the zlib wrapper around
-        // deflate).  Until we can get that updated (and deal with previous releases still testing it
-        // via a DeflateStream-based implementation), we utilize httpbin.org to help validate behavior.
         [OuterLoop("Uses external servers")]
         [Theory]
         [InlineData("http://httpbin.org/deflate", "\"deflated\": true")]
index 7d56c4c..798b239 100644 (file)
@@ -196,5 +196,47 @@ namespace System.Net.Http.Functional.Tests
 #endif
             }
         }
+
+        public static Task<byte[]> GetByteArrayAsync(this HttpClient client, bool async, bool useCopyTo, Uri uri)
+        {
+#if NETCOREAPP
+            return Task.Run(async () =>
+            {
+                var m = new HttpRequestMessage(HttpMethod.Get, uri);
+                using HttpResponseMessage r = async ? await client.SendAsync(m, HttpCompletionOption.ResponseHeadersRead) : client.Send(m, HttpCompletionOption.ResponseHeadersRead);
+                using Stream s = async ? await r.Content.ReadAsStreamAsync() : r.Content.ReadAsStream();
+
+                var result = new MemoryStream();
+                if (useCopyTo)
+                {
+                    if (async)
+                    {
+                        await s.CopyToAsync(result);
+                    }
+                    else
+                    {
+                        s.CopyTo(result);
+                    }
+                }
+                else
+                {
+                    byte[] buffer = new byte[100];
+                    while (true)
+                    {
+                        int bytesRead = async ? await s.ReadAsync(buffer) : s.Read(buffer);
+                        if (bytesRead == 0)
+                        {
+                            break;
+                        }
+                        result.Write(buffer.AsSpan(0, bytesRead));
+                    }
+                }
+                return result.ToArray();
+            });
+#else
+            // For WinHttpHandler on .NET Framework, we fall back to ignoring async and useCopyTo.
+            return client.GetByteArrayAsync(uri);
+#endif
+        }
     }
 }
index 0e354ff..939a29e 100644 (file)
@@ -6,7 +6,6 @@ using System.Diagnostics;
 using System.IO;
 using System.IO.Compression;
 using System.Net.Http.Headers;
-using System.Runtime.CompilerServices;
 using System.Threading;
 using System.Threading.Tasks;
 
@@ -224,12 +223,203 @@ namespace System.Net.Http
             { }
 
             protected override Stream GetDecompressedStream(Stream originalStream) =>
-                // As described in RFC 2616, the deflate content-coding is actually
-                // the "zlib" format (RFC 1950) in combination with the "deflate"
-                // compression algrithm (RFC 1951).  So while potentially
-                // counterintuitive based on naming, this needs to use ZLibStream
-                // rather than DeflateStream.
-                new ZLibStream(originalStream, CompressionMode.Decompress);
+                new ZLibOrDeflateStream(originalStream);
+
+            /// <summary>Stream that wraps either <see cref="ZLibStream"/> or <see cref="DeflateStream"/> for decompression.</summary>
+            private sealed class ZLibOrDeflateStream : HttpBaseStream
+            {
+                // As described in RFC 2616, the deflate content-coding is the "zlib" format (RFC 1950) in combination with
+                // the "deflate" compression algorithm (RFC 1951). Thus, the right stream to use here is ZLibStream.  However,
+                // some servers incorrectly interpret "deflate" to mean the raw, unwrapped deflate protocol.  To account for
+                // that, this switches between using ZLibStream (correct) and DeflateStream (incorrect) in order to maximize
+                // compatibility with servers.
+
+                private readonly PeekFirstByteReadStream _stream;
+                private Stream? _decompressionStream;
+
+                public ZLibOrDeflateStream(Stream stream) => _stream = new PeekFirstByteReadStream(stream);
+
+                protected override void Dispose(bool disposing)
+                {
+                    if (disposing)
+                    {
+                        _decompressionStream?.Dispose();
+                        _stream.Dispose();
+                    }
+                    base.Dispose(disposing);
+                }
+
+                public override bool CanRead => true;
+                public override bool CanWrite => false;
+                public override ValueTask WriteAsync(ReadOnlyMemory<byte> buffer, CancellationToken cancellationToken) => throw new NotSupportedException();
+
+                // On the first read request, peek at the first nibble of the response. If it's an 8, use ZLibStream, otherwise
+                // use DeflateStream. This heuristic works because we're deciding only between raw deflate and zlib wrapped around
+                // deflate, in which case the first nibble will always be 8 for zlib and never be 8 for deflate.
+                // https://stackoverflow.com/a/37528114 provides an explanation for why.
+
+                public override int Read(Span<byte> buffer)
+                {
+                    if (_decompressionStream is null)
+                    {
+                        int firstByte = _stream.PeekFirstByte();
+                        _decompressionStream = CreateDecompressionStream(firstByte, _stream);
+                    }
+
+                    return _decompressionStream.Read(buffer);
+                }
+
+                public override ValueTask<int> ReadAsync(Memory<byte> buffer, CancellationToken cancellationToken)
+                {
+                    if (_decompressionStream is null)
+                    {
+                        return CreateAndReadAsync(this, buffer, cancellationToken);
+
+                        static async ValueTask<int> CreateAndReadAsync(ZLibOrDeflateStream thisRef, Memory<byte> buffer, CancellationToken cancellationToken)
+                        {
+                            int firstByte = await thisRef._stream.PeekFirstByteAsync(cancellationToken).ConfigureAwait(false);
+                            thisRef._decompressionStream = CreateDecompressionStream(firstByte, thisRef._stream);
+                            return await thisRef._decompressionStream.ReadAsync(buffer, cancellationToken).ConfigureAwait(false);
+                        }
+                    }
+
+                    return _decompressionStream.ReadAsync(buffer, cancellationToken);
+                }
+
+                public override Task CopyToAsync(Stream destination, int bufferSize, CancellationToken cancellationToken)
+                {
+                    ValidateCopyToArguments(destination, bufferSize);
+                    return Core(destination, bufferSize, cancellationToken);
+                    async Task Core(Stream destination, int bufferSize, CancellationToken cancellationToken)
+                    {
+                        if (_decompressionStream is null)
+                        {
+                            int firstByte = await _stream.PeekFirstByteAsync(cancellationToken).ConfigureAwait(false);
+                            _decompressionStream = CreateDecompressionStream(firstByte, _stream);
+                        }
+
+                        await _decompressionStream.CopyToAsync(destination, bufferSize, cancellationToken).ConfigureAwait(false);
+                    }
+                }
+
+                private static Stream CreateDecompressionStream(int firstByte, Stream stream) =>
+                    (firstByte & 0xF) == 8 ?
+                        new ZLibStream(stream, CompressionMode.Decompress) :
+                        new DeflateStream(stream, CompressionMode.Decompress);
+
+                private sealed class PeekFirstByteReadStream : HttpBaseStream
+                {
+                    private readonly Stream _stream;
+                    private byte _firstByte;
+                    private FirstByteStatus _firstByteStatus;
+
+                    public PeekFirstByteReadStream(Stream stream) => _stream = stream;
+
+                    protected override void Dispose(bool disposing)
+                    {
+                        if (disposing)
+                        {
+                            _stream.Dispose();
+                        }
+                        base.Dispose(disposing);
+                    }
+
+                    public override bool CanRead => true;
+                    public override bool CanWrite => false;
+                    public override ValueTask WriteAsync(ReadOnlyMemory<byte> buffer, CancellationToken cancellationToken) => throw new NotSupportedException();
+
+                    public int PeekFirstByte()
+                    {
+                        Debug.Assert(_firstByteStatus == FirstByteStatus.None);
+
+                        int value = _stream.ReadByte();
+                        if (value == -1)
+                        {
+                            _firstByteStatus = FirstByteStatus.Consumed;
+                            return -1;
+                        }
+
+                        _firstByte = (byte)value;
+                        _firstByteStatus = FirstByteStatus.Available;
+                        return value;
+                    }
+
+                    public async ValueTask<int> PeekFirstByteAsync(CancellationToken cancellationToken)
+                    {
+                        Debug.Assert(_firstByteStatus == FirstByteStatus.None);
+
+                        var buffer = new byte[1];
+
+                        int bytesRead = await _stream.ReadAsync(buffer, cancellationToken).ConfigureAwait(false);
+                        if (bytesRead == 0)
+                        {
+                            _firstByteStatus = FirstByteStatus.Consumed;
+                            return -1;
+                        }
+
+                        _firstByte = buffer[0];
+                        _firstByteStatus = FirstByteStatus.Available;
+                        return buffer[0];
+                    }
+
+                    public override int Read(Span<byte> buffer)
+                    {
+                        if (_firstByteStatus == FirstByteStatus.Available)
+                        {
+                            if (buffer.Length != 0)
+                            {
+                                buffer[0] = _firstByte;
+                                _firstByteStatus = FirstByteStatus.Consumed;
+                                return 1;
+                            }
+
+                            return 0;
+                        }
+
+                        Debug.Assert(_firstByteStatus == FirstByteStatus.Consumed);
+                        return _stream.Read(buffer);
+                    }
+
+                    public override ValueTask<int> ReadAsync(Memory<byte> buffer, CancellationToken cancellationToken)
+                    {
+                        if (_firstByteStatus == FirstByteStatus.Available)
+                        {
+                            if (buffer.Length != 0)
+                            {
+                                buffer.Span[0] = _firstByte;
+                                _firstByteStatus = FirstByteStatus.Consumed;
+                                return new ValueTask<int>(1);
+                            }
+
+                            return new ValueTask<int>(0);
+                        }
+
+                        Debug.Assert(_firstByteStatus == FirstByteStatus.Consumed);
+                        return _stream.ReadAsync(buffer, cancellationToken);
+                    }
+
+                    public override async Task CopyToAsync(Stream destination, int bufferSize, CancellationToken cancellationToken)
+                    {
+                        Debug.Assert(_firstByteStatus != FirstByteStatus.None);
+
+                        ValidateCopyToArguments(destination, bufferSize);
+                        if (_firstByteStatus == FirstByteStatus.Available)
+                        {
+                            await destination.WriteAsync(new byte[] { _firstByte }, cancellationToken).ConfigureAwait(false);
+                            _firstByteStatus = FirstByteStatus.Consumed;
+                        }
+
+                        await _stream.CopyToAsync(destination, bufferSize, cancellationToken).ConfigureAwait(false);
+                    }
+
+                    private enum FirstByteStatus : byte
+                    {
+                        None = 0,
+                        Available = 1,
+                        Consumed = 2
+                    }
+                }
+            }
         }
 
         private sealed class BrotliDecompressedContent : DecompressedContent