Override more Stream members on System.IO.Compression streams (#54518)
authorStephen Toub <stoub@microsoft.com>
Fri, 25 Jun 2021 07:35:29 +0000 (03:35 -0400)
committerGitHub <noreply@github.com>
Fri, 25 Jun 2021 07:35:29 +0000 (09:35 +0200)
src/libraries/System.IO.Compression/src/System/IO/Compression/DeflateManaged/DeflateManagedStream.cs
src/libraries/System.IO.Compression/src/System/IO/Compression/DeflateManaged/InflaterManaged.cs
src/libraries/System.IO.Compression/src/System/IO/Compression/DeflateManaged/OutputWindow.cs
src/libraries/System.IO.Compression/src/System/IO/Compression/ZipArchiveEntry.cs
src/libraries/System.IO.Compression/src/System/IO/Compression/ZipCustomStreams.cs

index f24e15d..c3a2765 100644 (file)
@@ -2,7 +2,7 @@
 // The .NET Foundation licenses this file to you under the MIT license.
 
 using System.Diagnostics;
-using System.Runtime.CompilerServices;
+using System.Runtime.InteropServices;
 using System.Threading;
 using System.Threading.Tasks;
 
@@ -97,19 +97,22 @@ namespace System.IO.Compression
         public override int Read(byte[] buffer, int offset, int count)
         {
             ValidateBufferArguments(buffer, offset, count);
+            return Read(new Span<byte>(buffer, offset, count));
+        }
+
+        public override int Read(Span<byte> buffer)
+        {
             EnsureNotDisposed();
 
-            int bytesRead;
-            int currentOffset = offset;
-            int remainingCount = count;
+            int initialLength = buffer.Length;
 
+            int bytesRead;
             while (true)
             {
-                bytesRead = _inflater.Inflate(buffer, currentOffset, remainingCount);
-                currentOffset += bytesRead;
-                remainingCount -= bytesRead;
+                bytesRead = _inflater.Inflate(buffer);
+                buffer = buffer.Slice(bytesRead);
 
-                if (remainingCount == 0)
+                if (buffer.Length == 0)
                 {
                     break;
                 }
@@ -136,7 +139,13 @@ namespace System.IO.Compression
                 _inflater.SetInput(_buffer, 0, bytes);
             }
 
-            return count - remainingCount;
+            return initialLength - buffer.Length;
+        }
+
+        public override int ReadByte()
+        {
+            byte b = default;
+            return Read(MemoryMarshal.CreateSpan(ref b, 1)) == 1 ? b : -1;
         }
 
         private void EnsureNotDisposed()
@@ -169,7 +178,7 @@ namespace System.IO.Compression
             try
             {
                 // Try to read decompressed data in output buffer
-                int bytesRead = _inflater.Inflate(buffer);
+                int bytesRead = _inflater.Inflate(buffer.Span);
                 if (bytesRead != 0)
                 {
                     // If decompression output buffer is not empty, return immediately.
@@ -224,7 +233,7 @@ namespace System.IO.Compression
 
                     // Feed the data from base stream into decompression engine
                     _inflater.SetInput(_buffer, 0, bytesRead);
-                    bytesRead = _inflater.Inflate(buffer);
+                    bytesRead = _inflater.Inflate(buffer.Span);
 
                     if (bytesRead == 0 && !_inflater.Finished())
                     {
index 8372368..78b93e3 100644 (file)
@@ -95,7 +95,7 @@ namespace System.IO.Compression
 
         public int AvailableOutput => _output.AvailableBytes;
 
-        public int Inflate(Memory<byte> bytes)
+        public int Inflate(Span<byte> bytes)
         {
             // copy bytes from output to outputbytes if we have available bytes
             // if buffer is not filled up. keep decoding until no input are available
@@ -139,7 +139,7 @@ namespace System.IO.Compression
             return count;
         }
 
-        public int Inflate(byte[] bytes, int offset, int length) => Inflate(bytes.AsMemory(offset, length));
+        public int Inflate(byte[] bytes, int offset, int length) => Inflate(bytes.AsSpan(offset, length));
 
         //Each block of compressed data begins with 3 header bits
         // containing the following data:
index 2cdcf7d..471e866 100644 (file)
@@ -118,7 +118,7 @@ namespace System.IO.Compression
         public int AvailableBytes => _bytesUsed;
 
         /// <summary>Copy the decompressed bytes to output buffer.</summary>
-        public int CopyTo(Memory<byte> output)
+        public int CopyTo(Span<byte> output)
         {
             int copy_end;
 
@@ -140,19 +140,13 @@ namespace System.IO.Compression
             {
                 // this means we need to copy two parts separately
                 // copy the taillen bytes from the end of the output window
-                _window.AsSpan(WindowSize - tailLen, tailLen).CopyTo(output.Span);
+                _window.AsSpan(WindowSize - tailLen, tailLen).CopyTo(output);
                 output = output.Slice(tailLen, copy_end);
             }
-            _window.AsSpan(copy_end - output.Length, output.Length).CopyTo(output.Span);
+            _window.AsSpan(copy_end - output.Length, output.Length).CopyTo(output);
             _bytesUsed -= copied;
             Debug.Assert(_bytesUsed >= 0, "check this function and find why we copied more bytes than we have");
             return copied;
         }
-
-        /// <summary>Copy the decompressed bytes to output array.</summary>
-        public int CopyTo(byte[] output, int offset, int length)
-        {
-            return CopyTo(output.AsMemory(offset, length));
-        }
     }
 }
index 4f0072b..0481ee7 100644 (file)
@@ -4,7 +4,10 @@
 using System.Collections.Generic;
 using System.Diagnostics;
 using System.Diagnostics.CodeAnalysis;
+using System.Runtime.InteropServices;
 using System.Text;
+using System.Threading;
+using System.Threading.Tasks;
 
 namespace System.IO.Compression
 {
@@ -1222,6 +1225,38 @@ namespace System.IO.Compression
                 _position += source.Length;
             }
 
+            public override void WriteByte(byte value) =>
+                Write(MemoryMarshal.CreateReadOnlySpan(ref value, 1));
+
+            public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
+            {
+                ValidateBufferArguments(buffer, offset, count);
+                return WriteAsync(new ReadOnlyMemory<byte>(buffer, offset, count), cancellationToken).AsTask();
+            }
+
+            public override ValueTask WriteAsync(ReadOnlyMemory<byte> buffer, CancellationToken cancellationToken = default)
+            {
+                ThrowIfDisposed();
+                Debug.Assert(CanWrite);
+
+                return !buffer.IsEmpty ?
+                    Core(buffer, cancellationToken) :
+                    default;
+
+                async ValueTask Core(ReadOnlyMemory<byte> buffer, CancellationToken cancellationToken)
+                {
+                    if (!_everWritten)
+                    {
+                        _everWritten = true;
+                        // write local header, we are good to go
+                        _usedZip64inLH = _entry.WriteLocalFileHeader(isEmptyFile: false);
+                    }
+
+                    await _crcSizeStream.WriteAsync(buffer, cancellationToken).ConfigureAwait(false);
+                    _position += buffer.Length;
+                }
+            }
+
             public override void Flush()
             {
                 ThrowIfDisposed();
@@ -1230,6 +1265,14 @@ namespace System.IO.Compression
                 _crcSizeStream.Flush();
             }
 
+            public override Task FlushAsync(CancellationToken cancellationToken)
+            {
+                ThrowIfDisposed();
+                Debug.Assert(CanWrite);
+
+                return _crcSizeStream.FlushAsync(cancellationToken);
+            }
+
             protected override void Dispose(bool disposing)
             {
                 if (disposing && !_isDisposed)
index dbc4f56..7f43043 100644 (file)
@@ -2,6 +2,9 @@
 // The .NET Foundation licenses this file to you under the MIT license.
 
 using System.Diagnostics;
+using System.Runtime.InteropServices;
+using System.Threading;
+using System.Threading.Tasks;
 
 namespace System.IO.Compression
 {
@@ -95,6 +98,38 @@ namespace System.IO.Compression
             return _baseStream.Read(buffer, offset, count);
         }
 
+        public override int Read(Span<byte> buffer)
+        {
+            ThrowIfDisposed();
+            ThrowIfCantRead();
+
+            return _baseStream.Read(buffer);
+        }
+
+        public override int ReadByte()
+        {
+            ThrowIfDisposed();
+            ThrowIfCantRead();
+
+            return _baseStream.ReadByte();
+        }
+
+        public override Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
+        {
+            ThrowIfDisposed();
+            ThrowIfCantRead();
+
+            return _baseStream.ReadAsync(buffer, offset, count, cancellationToken);
+        }
+
+        public override ValueTask<int> ReadAsync(Memory<byte> buffer, CancellationToken cancellationToken = default)
+        {
+            ThrowIfDisposed();
+            ThrowIfCantRead();
+
+            return _baseStream.ReadAsync(buffer, cancellationToken);
+        }
+
         public override long Seek(long offset, SeekOrigin origin)
         {
             ThrowIfDisposed();
@@ -128,6 +163,30 @@ namespace System.IO.Compression
             _baseStream.Write(source);
         }
 
+        public override void WriteByte(byte value)
+        {
+            ThrowIfDisposed();
+            ThrowIfCantWrite();
+
+            _baseStream.WriteByte(value);
+        }
+
+        public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
+        {
+            ThrowIfDisposed();
+            ThrowIfCantWrite();
+
+            return _baseStream.WriteAsync(buffer, offset, count, cancellationToken);
+        }
+
+        public override ValueTask WriteAsync(ReadOnlyMemory<byte> buffer, CancellationToken cancellationToken = default)
+        {
+            ThrowIfDisposed();
+            ThrowIfCantWrite();
+
+            return _baseStream.WriteAsync(buffer, cancellationToken);
+        }
+
         public override void Flush()
         {
             ThrowIfDisposed();
@@ -136,6 +195,14 @@ namespace System.IO.Compression
             _baseStream.Flush();
         }
 
+        public override Task FlushAsync(CancellationToken cancellationToken)
+        {
+            ThrowIfDisposed();
+            ThrowIfCantWrite();
+
+            return _baseStream.FlushAsync(cancellationToken);
+        }
+
         protected override void Dispose(bool disposing)
         {
             if (disposing && !_isDisposed)
@@ -259,6 +326,43 @@ namespace System.IO.Compression
             return ret;
         }
 
+        public override int ReadByte()
+        {
+            byte b = default;
+            return Read(MemoryMarshal.CreateSpan(ref b, 1)) == 1 ? b : -1;
+        }
+
+        public override Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
+        {
+            ValidateBufferArguments(buffer, offset, count);
+            return ReadAsync(new Memory<byte>(buffer, offset, count), cancellationToken).AsTask();
+        }
+
+        public override ValueTask<int> ReadAsync(Memory<byte> buffer, CancellationToken cancellationToken = default)
+        {
+            ThrowIfDisposed();
+            ThrowIfCantRead();
+            return Core(buffer, cancellationToken);
+
+            async ValueTask<int> Core(Memory<byte> buffer, CancellationToken cancellationToken)
+            {
+                if (_superStream.Position != _positionInSuperStream)
+                {
+                    _superStream.Seek(_positionInSuperStream, SeekOrigin.Begin);
+                }
+
+                if (_positionInSuperStream > _endInSuperStream - buffer.Length)
+                {
+                    buffer = buffer.Slice(0, (int)(_endInSuperStream - _positionInSuperStream));
+                }
+
+                int ret = await _superStream.ReadAsync(buffer, cancellationToken).ConfigureAwait(false);
+
+                _positionInSuperStream += ret;
+                return ret;
+            }
+        }
+
         public override long Seek(long offset, SeekOrigin origin)
         {
             ThrowIfDisposed();
@@ -437,6 +541,39 @@ namespace System.IO.Compression
             _position += source.Length;
         }
 
+        public override void WriteByte(byte value) =>
+            Write(MemoryMarshal.CreateReadOnlySpan(ref value, 1));
+
+        public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
+        {
+            ValidateBufferArguments(buffer, offset, count);
+            return WriteAsync(new ReadOnlyMemory<byte>(buffer, offset, count), cancellationToken).AsTask();
+        }
+
+        public override ValueTask WriteAsync(ReadOnlyMemory<byte> buffer, CancellationToken cancellationToken = default)
+        {
+            ThrowIfDisposed();
+            Debug.Assert(CanWrite);
+
+            return !buffer.IsEmpty ?
+                Core(buffer, cancellationToken) :
+                default;
+
+            async ValueTask Core(ReadOnlyMemory<byte> buffer, CancellationToken cancellationToken = default)
+            {
+                if (!_everWritten)
+                {
+                    _initialPosition = _baseBaseStream.Position;
+                    _everWritten = true;
+                }
+
+                _checksum = Crc32Helper.UpdateCrc32(_checksum, buffer.Span);
+
+                await _baseStream.WriteAsync(buffer, cancellationToken).ConfigureAwait(false);
+                _position += buffer.Length;
+            }
+        }
+
         public override void Flush()
         {
             ThrowIfDisposed();
@@ -447,6 +584,12 @@ namespace System.IO.Compression
             _baseStream.Flush();
         }
 
+        public override Task FlushAsync(CancellationToken cancellationToken)
+        {
+            ThrowIfDisposed();
+            return _baseStream.FlushAsync(cancellationToken);
+        }
+
         protected override void Dispose(bool disposing)
         {
             if (disposing && !_isDisposed)