Detect truncated GZip streams (#61768)
author mfkl <me@martinfinkel.com>
Mon, 16 May 2022 20:07:48 +0000 (03:07 +0700)
committer GitHub <noreply@github.com>
Mon, 16 May 2022 20:07:48 +0000 (13:07 -0700)
* System.IO.Compression tests: add StreamTruncation_IsDetected

* System.IO.Compression: detect and throw on truncated streams

* account for edge case

* account for edge case in async version

* rename nonZeroInput to nonEmptyInput

* remove else to improve codegen

* run StreamTruncation_IsDetected for all three types of compression streams

* fixup test build and run

* add stream corruption test

* review feedback - cosmetics

* make BrotliStream detect truncation

* make StreamCorruption_IsDetected run for gzip only

other types of stream can't detect corruption properly

* skip byte corruption which results in correct decompression

* code style

* add zlib corruption test, no skipping needed

* clarify why we skip bytes in gzip test

* add and use truncated error data message

Co-authored-by: Marcin Krystianc <marcin.krystianc@gmail.com>
src/libraries/Common/tests/System/IO/Compression/CompressionStreamUnitTestBase.cs
src/libraries/Common/tests/System/IO/Compression/ZipTestHelper.cs
src/libraries/System.IO.Compression.Brotli/src/Resources/Strings.resx
src/libraries/System.IO.Compression.Brotli/src/System/IO/Compression/dec/BrotliStream.Decompress.cs
src/libraries/System.IO.Compression.Brotli/tests/System.IO.Compression.Brotli.Tests.csproj
src/libraries/System.IO.Compression.ZipFile/tests/ZipFile.Create.cs
src/libraries/System.IO.Compression/src/Resources/Strings.resx
src/libraries/System.IO.Compression/src/System/IO/Compression/DeflateZLib/DeflateStream.cs
src/libraries/System.IO.Compression/src/System/IO/Compression/DeflateZLib/Inflater.cs
src/libraries/System.IO.Compression/tests/CompressionStreamUnitTests.Gzip.cs
src/libraries/System.IO.Compression/tests/CompressionStreamUnitTests.ZLib.cs

index ba9e6b8..a7a2dc9 100644 (file)
@@ -2,10 +2,13 @@
 // The .NET Foundation licenses this file to you under the MIT license.
 
 using System.Collections.Generic;
+using System.IO.Compression.Tests;
+using System.Linq;
 using System.Text;
 using System.Threading;
 using System.Threading.Tasks;
 using Xunit;
+using Xunit.Sdk;
 
 namespace System.IO.Compression
 {
@@ -13,6 +16,16 @@ namespace System.IO.Compression
     {
         private const int TaskTimeout = 30 * 1000; // Generous timeout for official test runs
 
+        public enum TestScenario // selects which stream read/copy API a parameterized test drives
+        {
+            ReadByte,      // byte-at-a-time via ReadByte()
+            ReadByteAsync, // byte-at-a-time via ReadByteAsync()
+            Read,          // buffer reads (StreamTruncation_IsDetected goes through ZipFileTestBase.ReadAllBytes)
+            ReadAsync,     // buffer reads (StreamTruncation_IsDetected goes through ZipFileTestBase.ReadAllBytesAsync)
+            Copy,          // CopyTo into a destination stream
+            CopyAsync,     // CopyToAsync into a destination stream
+        }
+
         [ConditionalFact(typeof(PlatformDetection), nameof(PlatformDetection.IsThreadingSupported))]
         public virtual void FlushAsync_DuringWriteAsync()
         {
@@ -475,6 +488,85 @@ namespace System.IO.Compression
             Assert.True(fastestLength >= optimalLength);
             Assert.True(optimalLength >= smallestLength);
         }
+
+        // Compresses 64 known bytes, then decompresses every prefix of the compressed data through the
+        // API selected by <scenario>: each proper prefix must fail with InvalidDataException, and only
+        // the complete data (i == compressedData.Length) must decompress without an exception.
+        [Theory]
+        [InlineData(TestScenario.ReadAsync)]
+        [InlineData(TestScenario.Read)]
+        [InlineData(TestScenario.Copy)]
+        [InlineData(TestScenario.CopyAsync)]
+        [InlineData(TestScenario.ReadByte)]
+        [InlineData(TestScenario.ReadByteAsync)]
+        public async Task StreamTruncation_IsDetected(TestScenario scenario)
+        {
+            var buffer = new byte[16];
+            byte[] source = Enumerable.Range(0, 64).Select(i => (byte)i).ToArray();
+            byte[] compressedData;
+            using (var compressed = new MemoryStream())
+            using (Stream compressor = CreateStream(compressed, CompressionMode.Compress))
+            {
+                foreach (byte b in source)
+                {
+                    compressor.WriteByte(b);
+                }
+                // Dispose before ToArray, presumably so all pending compressed output has been written.
+                compressor.Dispose();
+                compressedData = compressed.ToArray();
+            }
+
+            for (var i = 1; i <= compressedData.Length; i += 1)
+            {
+                bool expectException = i < compressedData.Length; // only the full-length data is valid
+                using (var compressedStream = new MemoryStream(compressedData.Take(i).ToArray()))
+                using (Stream decompressor = CreateStream(compressedStream, CompressionMode.Decompress))
+                using (var decompressedStream = new MemoryStream())
+                {
+                    try
+                    {
+                        switch (scenario)
+                        {
+                            case TestScenario.Copy:
+                                decompressor.CopyTo(decompressedStream);
+                                break;
+
+                            case TestScenario.CopyAsync:
+                                await decompressor.CopyToAsync(decompressedStream);
+                                break;
+
+                            case TestScenario.Read:
+                                while (ZipFileTestBase.ReadAllBytes(decompressor, buffer, 0, buffer.Length) != 0) { }
+                                break;
+
+                            case TestScenario.ReadAsync:
+                                while (await ZipFileTestBase.ReadAllBytesAsync(decompressor, buffer, 0, buffer.Length) != 0) { }
+                                break;
+
+                            case TestScenario.ReadByte:
+                                while (decompressor.ReadByte() != -1) { }
+                                break;
+
+                            case TestScenario.ReadByteAsync:
+                                while (await decompressor.ReadByteAsync() != -1) { }
+                                break;
+                        }
+                    }
+                    catch (InvalidDataException e)
+                    {
+                        if (expectException)
+                            continue; // truncation was detected as required; try the next prefix length
+
+                        throw new XunitException($"An unexpected error occurred while decompressing data:{e}");
+                    }
+
+                    if (expectException)
+                    {
+                        throw new XunitException($"Truncated stream was decompressed successfully but exception was expected: length={i}/{compressedData.Length}");
+                    }
+                }
+            }
+        }
     }
 
     internal sealed class BadWrappedStream : MemoryStream
index 9308fe1..4e105b0 100644 (file)
@@ -65,6 +65,17 @@ namespace System.IO.Compression.Tests
             }
         }
 
+        public static async Task<int> ReadAllBytesAsync(Stream stream, byte[] buffer, int offset, int count)
+        {
+            int filled = 0; // bytes stored in buffer so far, starting at offset
+            while (true)
+            {
+                int chunk = await stream.ReadAsync(buffer, offset + filled, count - filled);
+                if (chunk == 0) return filled; // a read of 0 bytes ends the loop; report the total read
+                filled += chunk;
+            }
+        }
+
         public static int ReadAllBytes(Stream stream, byte[] buffer, int offset, int count)
         {
             int bytesRead;
@@ -100,6 +111,11 @@ namespace System.IO.Compression.Tests
             StreamsEqual(ast, bst, -1);
         }
 
+        // Async counterpart of StreamsEqual(ast, bst): compares both streams in full
+        // (blocksToRead: -1 disables the block limit).
+        public static async Task StreamsEqualAsync(Stream ast, Stream bst) =>
+            await StreamsEqualAsync(ast, bst, blocksToRead: -1);
+
         public static void StreamsEqual(Stream ast, Stream bst, int blocksToRead)
         {
             if (ast.CanSeek)
@@ -122,8 +138,44 @@ namespace System.IO.Compression.Tests
                 if (blocksToRead != -1 && blocksRead >= blocksToRead)
                     break;
 
-                ac = ReadAllBytes(ast, ad, 0, 4096);
-                bc = ReadAllBytes(bst, bd, 0, 4096);
+                ac = ReadAllBytes(ast, ad, 0, bufSize);
+                bc = ReadAllBytes(bst, bd, 0, bufSize);
+
+                if (ac != bc)
+                {
+                    bd = NormalizeLineEndings(bd);
+                }
+
+                Assert.True(ArraysEqual<byte>(ad, bd, ac), "Stream contents not equal: " + ast.ToString() + ", " + bst.ToString());
+
+                blocksRead++;
+            } while (ac == bufSize);
+        }
+
+        public static async Task StreamsEqualAsync(Stream ast, Stream bst, int blocksToRead)
+        {
+            if (ast.CanSeek)
+                ast.Seek(0, SeekOrigin.Begin);
+            if (bst.CanSeek)
+                bst.Seek(0, SeekOrigin.Begin);
+
+            const int bufSize = 4096;
+            byte[] ad = new byte[bufSize];
+            byte[] bd = new byte[bufSize];
+
+            int ac = 0;
+            int bc = 0;
+
+            int blocksRead = 0;
+
+            //assume read doesn't do weird things
+            do
+            {
+                if (blocksToRead != -1 && blocksRead >= blocksToRead)
+                    break;
+
+                ac = await ReadAllBytesAsync(ast, ad, 0, bufSize);
+                bc = await ReadAllBytesAsync(bst, bd, 0, bufSize);
 
                 if (ac != bc)
                 {
@@ -133,7 +185,7 @@ namespace System.IO.Compression.Tests
                 Assert.True(ArraysEqual<byte>(ad, bd, ac), "Stream contents not equal: " + ast.ToString() + ", " + bst.ToString());
 
                 blocksRead++;
-            } while (ac == 4096);
+            } while (ac == bufSize);
         }
 
         public static async Task IsZipSameAsDirAsync(string archiveFile, string directory, ZipArchiveMode mode)
index 355df1f..4f892bf 100644 (file)
   <data name="BrotliStream_Decompress_InvalidStream" xml:space="preserve">
     <value>BrotliStream.BaseStream returned more bytes than requested in Read.</value>
   </data>
+  <data name="BrotliStream_Decompress_TruncatedData" xml:space="preserve">
+    <value>Decoder ran into truncated data.</value>
+  </data>
   <data name="IOCompressionBrotli_PlatformNotSupported" xml:space="preserve">
     <value>System.IO.Compression.Brotli is not supported on this platform.</value>
   </data>
index 894a937..37740cd 100644 (file)
@@ -15,6 +15,7 @@ namespace System.IO.Compression
         private BrotliDecoder _decoder;
         private int _bufferOffset;
         private int _bufferCount;
+        private bool _nonEmptyInput;
 
         /// <summary>Reads a number of decompressed bytes into the specified byte array.</summary>
         /// <param name="buffer">The array used to store decompressed bytes.</param>
@@ -65,9 +66,12 @@ namespace System.IO.Compression
                 int bytesRead = _stream.Read(_buffer, _bufferCount, _buffer.Length - _bufferCount);
                 if (bytesRead <= 0)
                 {
+                    if (_nonEmptyInput && !buffer.IsEmpty)
+                        ThrowTruncatedInvalidData();
                     break;
                 }
 
+                _nonEmptyInput = true;
                 _bufferCount += bytesRead;
 
                 if (_bufferCount > _buffer.Length)
@@ -150,9 +154,12 @@ namespace System.IO.Compression
                         int bytesRead = await _stream.ReadAsync(_buffer.AsMemory(_bufferCount), cancellationToken).ConfigureAwait(false);
                         if (bytesRead <= 0)
                         {
+                            if (_nonEmptyInput && !buffer.IsEmpty)
+                                ThrowTruncatedInvalidData();
                             break;
                         }
 
+                        _nonEmptyInput = true;
                         _bufferCount += bytesRead;
 
                         if (_bufferCount > _buffer.Length)
@@ -231,5 +238,8 @@ namespace System.IO.Compression
             // The stream is either malicious or poorly implemented and returned a number of
             // bytes larger than the buffer supplied to it.
             throw new InvalidDataException(SR.BrotliStream_Decompress_InvalidStream);
+
+        private static void ThrowTruncatedInvalidData() => // throw helper: called when the base stream returns no more bytes after earlier input was read
+            throw new InvalidDataException(SR.BrotliStream_Decompress_TruncatedData); // resource text: "Decoder ran into truncated data."
     }
 }
index d68a7be..6a60c88 100644 (file)
@@ -1,4 +1,4 @@
-<Project Sdk="Microsoft.NET.Sdk">
+<Project Sdk="Microsoft.NET.Sdk">
   <PropertyGroup>
     <TargetFrameworks>$(NetCoreAppCurrent)-windows;$(NetCoreAppCurrent)-Unix;$(NetCoreAppCurrent)-Browser</TargetFrameworks>
     <AllowUnsafeBlocks>true</AllowUnsafeBlocks>
              Link="Common\System\IO\Compression\CompressionStreamTestBase.cs" />
     <Compile Include="$(CommonTestPath)System\IO\Compression\CompressionStreamUnitTestBase.cs"
              Link="Common\System\IO\Compression\CompressionStreamUnitTestBase.cs" />
+    <Compile Include="$(CommonTestPath)System\IO\Compression\CRC.cs"
+             Link="Common\System\IO\Compression\CRC.cs" />
+    <Compile Include="$(CommonTestPath)System\IO\Compression\FileData.cs"
+             Link="Common\System\IO\Compression\FileData.cs" />
     <Compile Include="$(CommonTestPath)System\IO\Compression\LocalMemoryStream.cs"
              Link="Common\System\IO\Compression\LocalMemoryStream.cs" />
     <Compile Include="$(CommonTestPath)System\IO\Compression\StreamHelpers.cs"
              Link="Common\System\IO\Compression\StreamHelpers.cs" />
+    <Compile Include="$(CommonTestPath)System\IO\Compression\ZipTestHelper.cs"
+         Link="Common\System\IO\Compression\ZipTestHelper.cs" />
     <Compile Include="$(CommonTestPath)System\IO\TempFile.cs"
              Link="Common\System\IO\TempFile.cs" />
     <Compile Include="$(CommonPath)System\Threading\Tasks\TaskToApm.cs"
index 25bd7c6..7f450e7 100644 (file)
@@ -46,6 +46,30 @@ namespace System.IO.Compression.Tests
         }
 
         [Fact]
+        public async Task CreateFromDirectory_IncludeBaseDirectoryAsync()
+        {
+            string folderName = zfolder("normal");
+            string withBaseDir = GetTestFilePath();
+            ZipFile.CreateFromDirectory(folderName, withBaseDir, CompressionLevel.Optimal, true); // true: includeBaseDirectory
+            // Every entry must sit under "normal" and match its source file in length and (read asynchronously) content.
+            string[] expected = Directory.GetFiles(folderName, "*", SearchOption.AllDirectories); // materialized once; searched per entry below
+            using (ZipArchive actual_withbasedir = ZipFile.Open(withBaseDir, ZipArchiveMode.Read))
+            {
+                foreach (ZipArchiveEntry actualEntry in actual_withbasedir.Entries)
+                {
+                    string expectedFile = expected.Single(i => Path.GetFileName(i).Equals(actualEntry.Name));
+                    Assert.StartsWith("normal", actualEntry.FullName);
+                    Assert.Equal(new FileInfo(expectedFile).Length, actualEntry.Length);
+                    using (Stream expectedStream = File.OpenRead(expectedFile))
+                    using (Stream actualStream = actualEntry.Open())
+                    {
+                        await StreamsEqualAsync(expectedStream, actualStream);
+                    }
+                }
+            }
+        }
+
+        [Fact]
         public void CreateFromDirectoryUnicode()
         {
             string folderName = zfolder("unicode");
index 1f209ff..21b2964 100644 (file)
   <data name="SplitSpanned" xml:space="preserve">
     <value>Split or spanned archives are not supported.</value>
   </data>
+  <data name="TruncatedData" xml:space="preserve">
+    <value>Found truncated data while decoding.</value>
+  </data>
   <data name="UnexpectedEndOfStream" xml:space="preserve">
     <value>Zip file corrupt: unexpected end of stream reached.</value>
   </data>
index 70c4747..7d8a309 100644 (file)
@@ -279,6 +279,15 @@ namespace System.IO.Compression
                     int n = _stream.Read(_buffer, 0, _buffer.Length);
                     if (n <= 0)
                     {
+                        // - Inflater didn't return any data although a non-empty output buffer was passed by the caller.
+                        // - More input is needed but there is no more input available.
+                        // - Inflation is not finished yet.
+                        // - Provided input wasn't completely empty
+                        // In such case, we are dealing with a truncated input stream.
+                        if (!buffer.IsEmpty && !_inflater.Finished() && _inflater.NonEmptyInput())
+                        {
+                            ThrowTruncatedInvalidData();
+                        }
                         break;
                     }
                     else if (n > _buffer.Length)
@@ -347,6 +356,9 @@ namespace System.IO.Compression
             // bytes < 0 || > than the buffer supplied to it.
             throw new InvalidDataException(SR.GenericInvalidData);
 
+        private static void ThrowTruncatedInvalidData() => // throw helper: used by the read and copy paths when input runs out while the inflater is not finished
+            throw new InvalidDataException(SR.TruncatedData); // resource text: "Found truncated data while decoding."
+
         public override IAsyncResult BeginRead(byte[] buffer, int offset, int count, AsyncCallback? asyncCallback, object? asyncState) =>
             TaskToApm.Begin(ReadAsync(buffer, offset, count, CancellationToken.None), asyncCallback, asyncState);
 
@@ -416,6 +428,15 @@ namespace System.IO.Compression
                             int n = await _stream.ReadAsync(new Memory<byte>(_buffer, 0, _buffer.Length), cancellationToken).ConfigureAwait(false);
                             if (n <= 0)
                             {
+                                // - Inflater didn't return any data although a non-empty output buffer was passed by the caller.
+                                // - More input is needed but there is no more input available.
+                                // - Inflation is not finished yet.
+                                // - Provided input wasn't completely empty
+                                // In such case, we are dealing with a truncated input stream.
+                                if (!_inflater.Finished() && _inflater.NonEmptyInput() && !buffer.IsEmpty)
+                                {
+                                    ThrowTruncatedInvalidData();
+                                }
                                 break;
                             }
                             else if (n > _buffer.Length)
@@ -436,6 +457,7 @@ namespace System.IO.Compression
                             // decompress into at least one byte of output, but it's a reasonable approximation for the 99% case.  If it's
                             // wrong, it just means that a caller using zero-byte reads as a way to delay getting a buffer to use for a
                             // subsequent call may end up getting one earlier than otherwise preferred.
+                            Debug.Assert(bytesRead == 0);
                             break;
                         }
                     }
@@ -893,6 +915,10 @@ namespace System.IO.Compression
 
                     // Now, use the source stream's CopyToAsync to push directly to our inflater via this helper stream
                     await _deflateStream._stream.CopyToAsync(this, _arrayPoolBuffer.Length, _cancellationToken).ConfigureAwait(false);
+                    if (!_deflateStream._inflater.Finished())
+                    {
+                        ThrowTruncatedInvalidData();
+                    }
                 }
                 finally
                 {
@@ -925,6 +951,10 @@ namespace System.IO.Compression
 
                     // Now, use the source stream's CopyToAsync to push directly to our inflater via this helper stream
                     _deflateStream._stream.CopyTo(this, _arrayPoolBuffer.Length);
+                    if (!_deflateStream._inflater.Finished())
+                    {
+                        ThrowTruncatedInvalidData();
+                    }
                 }
                 finally
                 {
index 4544353..3028ce5 100644 (file)
@@ -17,6 +17,7 @@ namespace System.IO.Compression
         private const int MinWindowBits = -15;              // WindowBits must be between -8..-15 to ignore the header, 8..15 for
         private const int MaxWindowBits = 47;               // zlib headers, 24..31 for GZip headers, or 40..47 for either Zlib or GZip
 
+        private bool _nonEmptyInput;                        // Whether there is any non empty input
         private bool _finished;                             // Whether the end of the stream has been reached
         private bool _isDisposed;                           // Prevents multiple disposals
         private readonly int _windowBits;                   // The WindowBits parameter passed to Inflater construction
@@ -34,6 +35,7 @@ namespace System.IO.Compression
         {
             Debug.Assert(windowBits >= MinWindowBits && windowBits <= MaxWindowBits);
             _finished = false;
+            _nonEmptyInput = false;
             _isDisposed = false;
             _windowBits = windowBits;
             InflateInit(windowBits);
@@ -176,6 +178,8 @@ namespace System.IO.Compression
 
         public bool NeedsInput() => _zlibStream.AvailIn == 0;
 
+        public bool NonEmptyInput() => _nonEmptyInput; // false after construction; the flag is set where an input buffer is attached to the zlib stream
+
         public void SetInput(byte[] inputBuffer, int startIndex, int count)
         {
             Debug.Assert(NeedsInput(), "We have something left in previous input!");
@@ -200,6 +204,7 @@ namespace System.IO.Compression
                 _zlibStream.NextIn = (IntPtr)_inputBufferHandle.Pointer;
                 _zlibStream.AvailIn = (uint)inputBuffer.Length;
                 _finished = false;
+                _nonEmptyInput = true;
             }
         }
 
index 2f7195e..20a1d1a 100644 (file)
@@ -2,8 +2,10 @@
 // The .NET Foundation licenses this file to you under the MIT license.
 
 using System.Buffers;
+using System.Linq;
 using System.Threading;
 using System.Threading.Tasks;
+using System.IO.Compression.Tests;
 using Xunit;
 
 namespace System.IO.Compression
@@ -153,15 +155,6 @@ namespace System.IO.Compression
             await TestConcatenatedGzipStreams(streamCount, scenario, bufferSize, bytesPerStream);
         }
 
-        public enum TestScenario
-        {
-            ReadByte,
-            Read,
-            ReadAsync,
-            Copy,
-            CopyAsync
-        }
-
         private async Task TestConcatenatedGzipStreams(int streamCount, TestScenario scenario, int bufferSize, int bytesPerStream = 1)
         {
             bool isCopy = scenario == TestScenario.Copy || scenario == TestScenario.CopyAsync;
@@ -281,6 +274,52 @@ namespace System.IO.Compression
             }
         }
 
+        [Fact]
+        public void StreamCorruption_IsDetected()
+        {
+            // Compress the 64 byte values 0..63, one WriteByte call per value.
+            byte[] compressedData;
+            using (var compressed = new MemoryStream())
+            using (Stream compressor = CreateStream(compressed, CompressionMode.Compress))
+            {
+                for (int value = 0; value < 64; value++)
+                {
+                    compressor.WriteByte((byte)value);
+                }
+
+                compressor.Dispose();
+                compressedData = compressed.ToArray();
+            }
+
+            var buffer = new byte[64];
+
+            // the last 7 bytes of the 10-byte gzip header can be changed with no decompression error
+            // this is by design, so offsets 3..9 are left untouched by the test
+            for (int index = 0; index < compressedData.Length; index++)
+            {
+                bool isLenientHeaderByte = index >= 3 && index <= 9;
+                if (isLenientHeaderByte)
+                {
+                    continue;
+                }
+
+                // corrupt one byte, remembering its value so it can be put back afterwards
+                byte original = compressedData[index];
+                compressedData[index]++;
+
+                using (var compressedStream = new MemoryStream(compressedData))
+                using (Stream decompressor = CreateStream(compressedStream, CompressionMode.Decompress))
+                {
+                    Assert.Throws<InvalidDataException>(() =>
+                    {
+                        while (ZipFileTestBase.ReadAllBytes(decompressor, buffer, 0, buffer.Length) != 0) { }
+                    });
+                }
+
+                compressedData[index] = original; // restore the data
+            }
+        }
+
         private sealed class DerivedGZipStream : GZipStream
         {
             public bool ReadArrayInvoked = false, WriteArrayInvoked = false;
index 5f43177..df3a315 100644 (file)
@@ -2,6 +2,8 @@
 // The .NET Foundation licenses this file to you under the MIT license.
 
 using System.Buffers;
+using System.IO.Compression.Tests;
+using System.Linq;
 using System.Threading;
 using System.Threading.Tasks;
 using Xunit;
@@ -16,5 +18,44 @@ namespace System.IO.Compression
         public override Stream CreateStream(Stream stream, CompressionLevel level, bool leaveOpen) => new ZLibStream(stream, level, leaveOpen);
         public override Stream BaseStream(Stream stream) => ((ZLibStream)stream).BaseStream;
         protected override string CompressedTestFile(string uncompressedPath) => Path.Combine("ZLibTestData", Path.GetFileName(uncompressedPath) + ".z");
+
+        [Fact]
+        public void StreamCorruption_IsDetected()
+        {
+            // Compress the 64 byte values 0..63, one WriteByte call per value.
+            byte[] compressedData;
+            using (var compressed = new MemoryStream())
+            using (Stream compressor = CreateStream(compressed, CompressionMode.Compress))
+            {
+                for (int value = 0; value < 64; value++)
+                {
+                    compressor.WriteByte((byte)value);
+                }
+
+                compressor.Dispose();
+                compressedData = compressed.ToArray();
+            }
+
+            var buffer = new byte[64];
+
+            // Every byte position is corrupted in turn; no offset is skipped.
+            for (int index = 0; index < compressedData.Length; index++)
+            {
+                // corrupt one byte, remembering its value so it can be put back afterwards
+                byte original = compressedData[index];
+                compressedData[index]++;
+
+                using (var compressedStream = new MemoryStream(compressedData))
+                using (Stream decompressor = CreateStream(compressedStream, CompressionMode.Decompress))
+                {
+                    Assert.Throws<InvalidDataException>(() =>
+                    {
+                        while (ZipFileTestBase.ReadAllBytes(decompressor, buffer, 0, buffer.Length) != 0) { }
+                    });
+                }
+
+                compressedData[index] = original; // restore the data
+            }
+        }
     }
 }