From: Stephen Toub Date: Fri, 11 Jun 2021 21:31:15 +0000 (-0400) Subject: Fix Deflate/Brotli/CryptoStream handling of partial and zero-byte reads (#53644) X-Git-Tag: submit/tizen/20210909.063632~810 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=68dec6ac17053b1d2c7bfead7e785c18ccdc8dd0;p=platform%2Fupstream%2Fdotnet%2Fruntime.git Fix Deflate/Brotli/CryptoStream handling of partial and zero-byte reads (#53644) Stream.Read{Async} is supposed to return once at least a byte of data is available, and in particular, if there's any data already available, it shouldn't block. But Read{Async} on DeflateStream (and thus also GZipStream and ZLibStream), BrotliStream, and CryptoStream won't return until either it hits the end of the stream or the caller's buffer is filled. This makes it behave very unexpectedly when used in a context where the app is using a large read buffer but expects to be able to process data as it's available, e.g. in networked streaming scenarios where messages are being sent as part of bidirectional communication. This fixes that by stopping looping once any data is consumed. Just doing that, though, caused problems for zero-byte reads. Zero-byte reads are typically used by code that's trying to delay-allocate a buffer for the read data until data will be available to read. At present, however, zero-byte reads return immediately regardless of whether data is available to be consumed. I've changed the flow to make it so that zero-byte reads don't return until there's at least some data available as input to the inflater/transform (this, though, doesn't 100% guarantee the inflater/transform will be able to produce output data). Note that both of these changes have the potential to introduce breaks into an app that erroneously depended on these implementation details: - If an app passing in a buffer of size N to Read{Async} depended on that call always producing the requested number of bytes (rather than what the Stream contract defines), they might experience behavioral changes. - If an app passing in a zero-byte buffer expected it to return immediately, it might instead end up waiting until data was actually available. --- diff --git a/src/libraries/Common/tests/StreamConformanceTests/System/IO/StreamConformanceTests.cs b/src/libraries/Common/tests/StreamConformanceTests/System/IO/StreamConformanceTests.cs index d41e072..43451d4 100644 --- a/src/libraries/Common/tests/StreamConformanceTests/System/IO/StreamConformanceTests.cs +++ b/src/libraries/Common/tests/StreamConformanceTests/System/IO/StreamConformanceTests.cs @@ -1562,11 +1562,6 @@ namespace System.IO.Tests /// Gets whether the stream guarantees that all data written to it will be flushed as part of Flush{Async}. /// protected virtual bool FlushGuaranteesAllDataWritten => true; - /// - /// Gets whether a stream implements an aggressive read that tries to fill the supplied buffer and only - /// stops when it does so or hits EOF. - /// - protected virtual bool ReadsMayBlockUntilBufferFullOrEOF => false; /// Gets whether reads for a count of 0 bytes block if no bytes are available to read. protected virtual bool BlocksOnZeroByteReads => false; /// @@ -1709,6 +1704,10 @@ namespace System.IO.Tests } } + public static IEnumerable ReadWrite_Modes => + from mode in Enum.GetValues() + select new object[] { mode }; + public static IEnumerable ReadWrite_Success_MemberData() => from mode in Enum.GetValues() from writeSize in new[] { 1, 42, 10 * 1024 } @@ -1786,6 +1785,54 @@ namespace System.IO.Tests } [Theory] + [MemberData(nameof(ReadWrite_Modes))] + [ActiveIssue("https://github.com/dotnet/runtime/issues/51371", TestPlatforms.iOS | TestPlatforms.tvOS | TestPlatforms.MacCatalyst)] + public virtual async Task ReadWrite_MessagesSmallerThanReadBuffer_Success(ReadWriteMode mode) + { + if (!FlushGuaranteesAllDataWritten) + { + return; + } + + foreach (CancellationToken nonCanceledToken in new[] { CancellationToken.None, new CancellationTokenSource().Token }) + { + using StreamPair streams = await CreateConnectedStreamsAsync(); + + foreach ((Stream writeable, Stream readable) in GetReadWritePairs(streams)) + { + byte[] writerBytes = RandomNumberGenerator.GetBytes(512); + var readerBytes = new byte[writerBytes.Length * 2]; + + // Repeatedly write then read a message smaller in size than the read buffer + for (int i = 0; i < 5; i++) + { + Task writes = Task.Run(async () => + { + await WriteAsync(mode, writeable, writerBytes, 0, writerBytes.Length, nonCanceledToken); + if (FlushRequiredToWriteData) + { + await writeable.FlushAsync(); + } + }); + + int n = 0; + while (n < writerBytes.Length) + { + int r = await ReadAsync(mode, readable, readerBytes, n, readerBytes.Length - n); + Assert.InRange(r, 1, writerBytes.Length - n); + n += r; + } + + Assert.Equal(writerBytes.Length, n); + AssertExtensions.SequenceEqual(writerBytes, readerBytes.AsSpan(0, writerBytes.Length)); + + await writes; + } + } + } + } + + [Theory] [MemberData(nameof(AllReadWriteModesAndValue), false)] [MemberData(nameof(AllReadWriteModesAndValue), true)] [ActiveIssue("https://github.com/dotnet/runtime/issues/51371", TestPlatforms.iOS | TestPlatforms.tvOS | TestPlatforms.MacCatalyst)] @@ -2160,6 +2207,10 @@ namespace System.IO.Tests }); Assert.Equal(0, await zeroByteRead); + // Perform a second zero-byte read. + await Task.Run(() => ReadAsync(mode, readable, Array.Empty(), 0, 0)); + + // Now consume all the data. var readBytes = new byte[5]; int count = 0; while (count < readBytes.Length) @@ -2684,7 +2735,7 @@ namespace System.IO.Tests [InlineData(true, true)] public virtual async Task Dispose_Flushes(bool useAsync, bool leaveOpen) { - if (leaveOpen && (!SupportsLeaveOpen || ReadsMayBlockUntilBufferFullOrEOF)) + if (leaveOpen && !SupportsLeaveOpen) { return; } diff --git a/src/libraries/Common/tests/System/IO/Compression/CompressionStreamTestBase.cs b/src/libraries/Common/tests/System/IO/Compression/CompressionStreamTestBase.cs index 366c547..24f32fa 100644 --- a/src/libraries/Common/tests/System/IO/Compression/CompressionStreamTestBase.cs +++ b/src/libraries/Common/tests/System/IO/Compression/CompressionStreamTestBase.cs @@ -54,6 +54,6 @@ namespace System.IO.Compression protected override Type UnsupportedReadWriteExceptionType => typeof(InvalidOperationException); protected override bool WrappedUsableAfterClose => false; protected override bool FlushRequiredToWriteData => true; - protected override bool FlushGuaranteesAllDataWritten => false; + protected override bool BlocksOnZeroByteReads => true; } } diff --git a/src/libraries/Common/tests/System/IO/Compression/ZipTestHelper.cs b/src/libraries/Common/tests/System/IO/Compression/ZipTestHelper.cs index e8a0c08..6f31946 100644 --- a/src/libraries/Common/tests/System/IO/Compression/ZipTestHelper.cs +++ b/src/libraries/Common/tests/System/IO/Compression/ZipTestHelper.cs @@ -65,6 +65,17 @@ namespace System.IO.Compression.Tests } } + public static int ReadAllBytes(Stream stream, byte[] buffer, int offset, int count) + { + int bytesRead; + int totalRead = 0; + while ((bytesRead = stream.Read(buffer, offset + totalRead, count - totalRead)) != 0) + { + totalRead += bytesRead; + } + return totalRead; + } + public static bool ArraysEqual(T[] a, T[] b) where T : IComparable { if (a.Length != b.Length) return false; @@ -111,8 +122,8 @@ namespace System.IO.Compression.Tests if (blocksToRead != -1 && blocksRead >= blocksToRead) break; - ac = ast.Read(ad, 0, 4096); - bc = bst.Read(bd, 0, 4096); + ac = ReadAllBytes(ast, ad, 0, 4096); + bc = ReadAllBytes(bst, bd, 0, 4096); if (ac != bc) { @@ -170,7 +181,7 @@ namespace System.IO.Compression.Tests var buffer = new byte[entry.Length]; using (Stream entrystream = entry.Open()) { - entrystream.Read(buffer, 0, buffer.Length); + ReadAllBytes(entrystream, buffer, 0, buffer.Length); #if NETCOREAPP uint zipcrc = entry.Crc32; Assert.Equal(CRC.CalculateCRC(buffer), zipcrc); diff --git a/src/libraries/System.IO.Compression.Brotli/src/System/IO/Compression/BrotliStream.cs b/src/libraries/System.IO.Compression.Brotli/src/System/IO/Compression/BrotliStream.cs index 4401a47..73ecccf 100644 --- a/src/libraries/System.IO.Compression.Brotli/src/System/IO/Compression/BrotliStream.cs +++ b/src/libraries/System.IO.Compression.Brotli/src/System/IO/Compression/BrotliStream.cs @@ -173,7 +173,7 @@ namespace System.IO.Compression private void AsyncOperationStarting() { - if (Interlocked.CompareExchange(ref _activeAsyncOperation, 1, 0) != 0) + if (Interlocked.Exchange(ref _activeAsyncOperation, 1) != 0) { ThrowInvalidBeginCall(); } @@ -181,13 +181,11 @@ namespace System.IO.Compression private void AsyncOperationCompleting() { - int oldValue = Interlocked.CompareExchange(ref _activeAsyncOperation, 0, 1); - Debug.Assert(oldValue == 1, $"Expected {nameof(_activeAsyncOperation)} to be 1, got {oldValue}"); + Debug.Assert(_activeAsyncOperation == 1); + Volatile.Write(ref _activeAsyncOperation, 0); } - private static void ThrowInvalidBeginCall() - { + private static void ThrowInvalidBeginCall() => throw new InvalidOperationException(SR.InvalidBeginCall); - } } } diff --git a/src/libraries/System.IO.Compression.Brotli/src/System/IO/Compression/dec/BrotliStream.Decompress.cs b/src/libraries/System.IO.Compression.Brotli/src/System/IO/Compression/dec/BrotliStream.Decompress.cs index f2f7d72..b9708e0 100644 --- a/src/libraries/System.IO.Compression.Brotli/src/System/IO/Compression/dec/BrotliStream.Decompress.cs +++ b/src/libraries/System.IO.Compression.Brotli/src/System/IO/Compression/dec/BrotliStream.Decompress.cs @@ -2,6 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. using System.Buffers; +using System.Diagnostics; using System.Runtime.InteropServices; using System.Threading; using System.Threading.Tasks; @@ -42,8 +43,8 @@ namespace System.IO.Compression public override int ReadByte() { byte b = default; - int numRead = Read(MemoryMarshal.CreateSpan(ref b, 1)); - return numRead != 0 ? b : -1; + int bytesRead = Read(MemoryMarshal.CreateSpan(ref b, 1)); + return bytesRead != 0 ? b : -1; } /// Reads a sequence of bytes from the current Brotli stream to a byte span and advances the position within the Brotli stream by the number of bytes read. @@ -57,59 +58,25 @@ namespace System.IO.Compression if (_mode != CompressionMode.Decompress) throw new InvalidOperationException(SR.BrotliStream_Compress_UnsupportedOperation); EnsureNotDisposed(); - int totalWritten = 0; - OperationStatus lastResult = OperationStatus.DestinationTooSmall; - // We want to continue calling Decompress until we're either out of space for output or until Decompress indicates it is finished. - while (buffer.Length > 0 && lastResult != OperationStatus.Done) + int bytesWritten; + while (!TryDecompress(buffer, out bytesWritten)) { - if (lastResult == OperationStatus.NeedMoreData) + int bytesRead = _stream.Read(_buffer, _bufferCount, _buffer.Length - _bufferCount); + if (bytesRead <= 0) { - // Ensure any left over data is at the beginning of the array so we can fill the remainder. - if (_bufferCount > 0 && _bufferOffset != 0) - { - _buffer.AsSpan(_bufferOffset, _bufferCount).CopyTo(_buffer); - } - _bufferOffset = 0; - - int numRead = 0; - while (_bufferCount < _buffer.Length && ((numRead = _stream.Read(_buffer, _bufferCount, _buffer.Length - _bufferCount)) > 0)) - { - _bufferCount += numRead; - if (_bufferCount > _buffer.Length) - { - // The stream is either malicious or poorly implemented and returned a number of - // bytes larger than the buffer supplied to it. - throw new InvalidDataException(SR.BrotliStream_Decompress_InvalidStream); - } - } - - if (_bufferCount <= 0) - { - break; - } - } - - lastResult = _decoder.Decompress(new ReadOnlySpan(_buffer, _bufferOffset, _bufferCount), buffer, out int bytesConsumed, out int bytesWritten); - if (lastResult == OperationStatus.InvalidData) - { - throw new InvalidOperationException(SR.BrotliStream_Decompress_InvalidData); + break; } - if (bytesConsumed > 0) - { - _bufferOffset += bytesConsumed; - _bufferCount -= bytesConsumed; - } + _bufferCount += bytesRead; - if (bytesWritten > 0) + if (_bufferCount > _buffer.Length) { - totalWritten += bytesWritten; - buffer = buffer.Slice(bytesWritten); + ThrowInvalidStream(); } } - return totalWritten; + return bytesWritten; } /// Begins an asynchronous read operation. (Consider using the method instead.) @@ -169,73 +136,100 @@ namespace System.IO.Compression { return ValueTask.FromCanceled(cancellationToken); } - return FinishReadAsyncMemory(buffer, cancellationToken); - } - private async ValueTask FinishReadAsyncMemory(Memory buffer, CancellationToken cancellationToken) - { - AsyncOperationStarting(); - try + return Core(buffer, cancellationToken); + + async ValueTask Core(Memory buffer, CancellationToken cancellationToken) { - int totalWritten = 0; - OperationStatus lastResult = OperationStatus.DestinationTooSmall; - // We want to continue calling Decompress until we're either out of space for output or until Decompress indicates it is finished. - while (buffer.Length > 0 && lastResult != OperationStatus.Done) + AsyncOperationStarting(); + try { - if (lastResult == OperationStatus.NeedMoreData) + int bytesWritten; + while (!TryDecompress(buffer.Span, out bytesWritten)) { - // Ensure any left over data is at the beginning of the array so we can fill the remainder. - if (_bufferCount > 0 && _bufferOffset != 0) + int bytesRead = await _stream.ReadAsync(_buffer.AsMemory(_bufferCount), cancellationToken).ConfigureAwait(false); + if (bytesRead <= 0) { - _buffer.AsSpan(_bufferOffset, _bufferCount).CopyTo(_buffer); + break; } - _bufferOffset = 0; - int numRead = 0; - while (_bufferCount < _buffer.Length && - ((numRead = await _stream.ReadAsync(new Memory(_buffer, _bufferCount, _buffer.Length - _bufferCount), cancellationToken).ConfigureAwait(false)) > 0)) - { - _bufferCount += numRead; - if (_bufferCount > _buffer.Length) - { - // The stream is either malicious or poorly implemented and returned a number of - // bytes larger than the buffer supplied to it. - throw new InvalidDataException(SR.BrotliStream_Decompress_InvalidStream); - } - } + _bufferCount += bytesRead; - if (_bufferCount <= 0) + if (_bufferCount > _buffer.Length) { - break; + ThrowInvalidStream(); } } - cancellationToken.ThrowIfCancellationRequested(); - lastResult = _decoder.Decompress(new ReadOnlySpan(_buffer, _bufferOffset, _bufferCount), buffer.Span, out int bytesConsumed, out int bytesWritten); - if (lastResult == OperationStatus.InvalidData) - { - throw new InvalidOperationException(SR.BrotliStream_Decompress_InvalidData); - } + return bytesWritten; + } + finally + { + AsyncOperationCompleting(); + } + } + } - if (bytesConsumed > 0) - { - _bufferOffset += bytesConsumed; - _bufferCount -= bytesConsumed; - } + /// Tries to decode available data into the destination buffer. + /// The destination buffer for the decompressed data. + /// The number of bytes written to destination. + /// true if the caller should consider the read operation completed; otherwise, false. + private bool TryDecompress(Span destination, out int bytesWritten) + { + // Decompress any data we may have in our buffer. + OperationStatus lastResult = _decoder.Decompress(new ReadOnlySpan(_buffer, _bufferOffset, _bufferCount), destination, out int bytesConsumed, out bytesWritten); + if (lastResult == OperationStatus.InvalidData) + { + throw new InvalidOperationException(SR.BrotliStream_Decompress_InvalidData); + } - if (bytesWritten > 0) - { - totalWritten += bytesWritten; - buffer = buffer.Slice(bytesWritten); - } - } + if (bytesConsumed != 0) + { + _bufferOffset += bytesConsumed; + _bufferCount -= bytesConsumed; + } + + // If we successfully decompressed any bytes, or if we've reached the end of the decompression, we're done. + if (bytesWritten != 0 || lastResult == OperationStatus.Done) + { + return true; + } - return totalWritten; + if (destination.IsEmpty) + { + // The caller provided a zero-byte buffer. This is typically done in order to avoid allocating/renting + // a buffer until data is known to be available. We don't have perfect knowledge here, as _decoder.Decompress + // will return DestinationTooSmall whether or not more data is required. As such, we assume that if there's + // any data in our input buffer, it would have been decompressible into at least one byte of output, and + // otherwise we need to do a read on the underlying stream. This isn't perfect, because having input data + // doesn't necessarily mean it'll decompress into at least one byte of output, but it's a reasonable approximation + // for the 99% case. If it's wrong, it just means that a caller using zero-byte reads as a way to delay + // getting a buffer to use for a subsequent call may end up getting one earlier than otherwise preferred. + Debug.Assert(lastResult == OperationStatus.DestinationTooSmall); + if (_bufferCount != 0) + { + Debug.Assert(bytesWritten == 0); + return true; + } } - finally + + Debug.Assert( + lastResult == OperationStatus.NeedMoreData || + (lastResult == OperationStatus.DestinationTooSmall && destination.IsEmpty && _bufferCount == 0), $"{nameof(lastResult)} == {lastResult}, {nameof(destination.Length)} == {destination.Length}"); + + // Ensure any left over data is at the beginning of the array so we can fill the remainder. + if (_bufferCount != 0 && _bufferOffset != 0) { - AsyncOperationCompleting(); + new ReadOnlySpan(_buffer, _bufferOffset, _bufferCount).CopyTo(_buffer); } + _bufferOffset = 0; + + return false; } + + private static void ThrowInvalidStream() => + // The stream is either malicious or poorly implemented and returned a number of + // bytes larger than the buffer supplied to it. + throw new InvalidDataException(SR.BrotliStream_Decompress_InvalidStream); } } diff --git a/src/libraries/System.IO.Compression.Brotli/src/System/IO/Compression/enc/BrotliStream.Compress.cs b/src/libraries/System.IO.Compression.Brotli/src/System/IO/Compression/enc/BrotliStream.Compress.cs index 44a368f..efa17bc 100644 --- a/src/libraries/System.IO.Compression.Brotli/src/System/IO/Compression/enc/BrotliStream.Compress.cs +++ b/src/libraries/System.IO.Compression.Brotli/src/System/IO/Compression/enc/BrotliStream.Compress.cs @@ -68,8 +68,8 @@ namespace System.IO.Compression Span output = new Span(_buffer); while (lastResult == OperationStatus.DestinationTooSmall) { - int bytesConsumed = 0; - int bytesWritten = 0; + int bytesConsumed; + int bytesWritten; lastResult = _encoder.Compress(buffer, output, out bytesConsumed, out bytesWritten, isFinalBlock); if (lastResult == OperationStatus.InvalidData) throw new InvalidOperationException(SR.BrotliStream_Compress_InvalidData); @@ -176,7 +176,7 @@ namespace System.IO.Compression Span output = new Span(_buffer); while (lastResult == OperationStatus.DestinationTooSmall) { - int bytesWritten = 0; + int bytesWritten; lastResult = _encoder.Flush(output, out bytesWritten); if (lastResult == OperationStatus.InvalidData) throw new InvalidDataException(SR.BrotliStream_Compress_InvalidData); diff --git a/src/libraries/System.IO.Compression.Brotli/tests/CompressionStreamUnitTests.Brotli.cs b/src/libraries/System.IO.Compression.Brotli/tests/CompressionStreamUnitTests.Brotli.cs index 3960233..96eff30 100644 --- a/src/libraries/System.IO.Compression.Brotli/tests/CompressionStreamUnitTests.Brotli.cs +++ b/src/libraries/System.IO.Compression.Brotli/tests/CompressionStreamUnitTests.Brotli.cs @@ -14,7 +14,8 @@ namespace System.IO.Compression public override Stream CreateStream(Stream stream, CompressionLevel level) => new BrotliStream(stream, level); public override Stream CreateStream(Stream stream, CompressionLevel level, bool leaveOpen) => new BrotliStream(stream, level, leaveOpen); public override Stream BaseStream(Stream stream) => ((BrotliStream)stream).BaseStream; - protected override bool ReadsMayBlockUntilBufferFullOrEOF => true; + + protected override bool FlushGuaranteesAllDataWritten => false; // The tests are relying on an implementation detail of BrotliStream, using knowledge of its internal buffer size // in various test calculations. Currently the implementation is using the ArrayPool, which will round up to a diff --git a/src/libraries/System.IO.Compression/src/System/IO/Compression/DeflateZLib/DeflateStream.cs b/src/libraries/System.IO.Compression/src/System/IO/Compression/DeflateZLib/DeflateStream.cs index d9739b0..84b7ad0 100644 --- a/src/libraries/System.IO.Compression/src/System/IO/Compression/DeflateZLib/DeflateStream.cs +++ b/src/libraries/System.IO.Compression/src/System/IO/Compression/DeflateZLib/DeflateStream.cs @@ -104,12 +104,14 @@ namespace System.IO.Compression InitializeBuffer(); } + [MemberNotNull(nameof(_buffer))] private void InitializeBuffer() { Debug.Assert(_buffer == null); _buffer = ArrayPool.Shared.Rent(DefaultBufferSize); } + [MemberNotNull(nameof(_buffer))] private void EnsureBufferInitialized() { if (_buffer == null) @@ -259,83 +261,94 @@ namespace System.IO.Compression EnsureDecompressionMode(); EnsureNotDisposed(); EnsureBufferInitialized(); - - int totalRead = 0; - Debug.Assert(_inflater != null); + + int bytesRead; while (true) { - int bytesRead = _inflater.Inflate(buffer.Slice(totalRead)); - totalRead += bytesRead; - if (totalRead == buffer.Length) - { - break; - } - - // If the stream is finished then we have a few potential cases here: - // 1. DeflateStream => return - // 2. GZipStream that is finished but may have an additional GZipStream appended => feed more input - // 3. GZipStream that is finished and appended with garbage => return - if (_inflater.Finished() && (!_inflater.IsGzipStream() || !_inflater.NeedsInput())) + // Try to decompress any data from the inflater into the caller's buffer. + // If we're able to decompress any bytes, or if decompression is completed, we're done. + bytesRead = _inflater.Inflate(buffer); + if (bytesRead != 0 || InflatorIsFinished) { break; } + // We were unable to decompress any data. If the inflater needs additional input + // data to proceed, read some to populate it. if (_inflater.NeedsInput()) { - Debug.Assert(_buffer != null); - int bytes = _stream.Read(_buffer, 0, _buffer.Length); - if (bytes <= 0) + int n = _stream.Read(_buffer, 0, _buffer.Length); + if (n <= 0) { break; } - else if (bytes > _buffer.Length) + else if (n > _buffer.Length) + { + ThrowGenericInvalidData(); + } + else { - // The stream is either malicious or poorly implemented and returned a number of - // bytes larger than the buffer supplied to it. - throw new InvalidDataException(SR.GenericInvalidData); + _inflater.SetInput(_buffer, 0, n); } + } - _inflater.SetInput(_buffer, 0, bytes); + if (buffer.IsEmpty) + { + // The caller provided a zero-byte buffer. This is typically done in order to avoid allocating/renting + // a buffer until data is known to be available. We don't have perfect knowledge here, as _inflater.Inflate + // will return 0 whether or not more data is required, and having input data doesn't necessarily mean it'll + // decompress into at least one byte of output, but it's a reasonable approximation for the 99% case. If it's + // wrong, it just means that a caller using zero-byte reads as a way to delay getting a buffer to use for a + // subsequent call may end up getting one earlier than otherwise preferred. + Debug.Assert(bytesRead == 0); + break; } } - return totalRead; + return bytesRead; } + private bool InflatorIsFinished => + // If the stream is finished then we have a few potential cases here: + // 1. DeflateStream => return + // 2. GZipStream that is finished but may have an additional GZipStream appended => feed more input + // 3. GZipStream that is finished and appended with garbage => return + _inflater!.Finished() && + (!_inflater.IsGzipStream() || !_inflater.NeedsInput()); + private void EnsureNotDisposed() { if (_stream == null) ThrowStreamClosedException(); - } - private static void ThrowStreamClosedException() - { - throw new ObjectDisposedException(nameof(DeflateStream), SR.ObjectDisposed_StreamClosed); + static void ThrowStreamClosedException() => + throw new ObjectDisposedException(nameof(DeflateStream), SR.ObjectDisposed_StreamClosed); } private void EnsureDecompressionMode() { if (_mode != CompressionMode.Decompress) ThrowCannotReadFromDeflateStreamException(); - } - private static void ThrowCannotReadFromDeflateStreamException() - { - throw new InvalidOperationException(SR.CannotReadFromDeflateStream); + static void ThrowCannotReadFromDeflateStreamException() => + throw new InvalidOperationException(SR.CannotReadFromDeflateStream); } private void EnsureCompressionMode() { if (_mode != CompressionMode.Compress) ThrowCannotWriteToDeflateStreamException(); - } - private static void ThrowCannotWriteToDeflateStreamException() - { - throw new InvalidOperationException(SR.CannotWriteToDeflateStream); + static void ThrowCannotWriteToDeflateStreamException() => + throw new InvalidOperationException(SR.CannotWriteToDeflateStream); } + private static void ThrowGenericInvalidData() => + // The stream is either malicious or poorly implemented and returned a number of + // bytes < 0 || > than the buffer supplied to it. + throw new InvalidDataException(SR.GenericInvalidData); + public override IAsyncResult BeginRead(byte[] buffer, int offset, int count, AsyncCallback? asyncCallback, object? asyncState) => TaskToApm.Begin(ReadAsync(buffer, offset, count, CancellationToken.None), asyncCallback, asyncState); @@ -378,6 +391,7 @@ namespace System.IO.Compression } EnsureBufferInitialized(); + Debug.Assert(_inflater != null); return Core(buffer, cancellationToken); @@ -386,48 +400,49 @@ namespace System.IO.Compression AsyncOperationStarting(); try { - int totalRead = 0; - - Debug.Assert(_inflater != null); + int bytesRead; while (true) { - int bytesRead = _inflater.Inflate(buffer.Span.Slice(totalRead)); - totalRead += bytesRead; - if (totalRead == buffer.Length) - { - break; - } - - // If the stream is finished then we have a few potential cases here: - // 1. DeflateStream => return - // 2. GZipStream that is finished but may have an additional GZipStream appended => feed more input - // 3. GZipStream that is finished and appended with garbage => return - if (_inflater.Finished() && (!_inflater.IsGzipStream() || !_inflater.NeedsInput())) + // Try to decompress any data from the inflater into the caller's buffer. + // If we're able to decompress any bytes, or if decompression is completed, we're done. + bytesRead = _inflater.Inflate(buffer.Span); + if (bytesRead != 0 || InflatorIsFinished) { break; } + // We were unable to decompress any data. If the inflater needs additional input + // data to proceed, read some to populate it. if (_inflater.NeedsInput()) { - Debug.Assert(_buffer != null); - int bytes = await _stream.ReadAsync(_buffer, cancellationToken).ConfigureAwait(false); - EnsureNotDisposed(); - if (bytes <= 0) + int n = await _stream.ReadAsync(new Memory(_buffer, 0, _buffer.Length), cancellationToken).ConfigureAwait(false); + if (n <= 0) { break; } - else if (bytes > _buffer.Length) + else if (n > _buffer.Length) { - // The stream is either malicious or poorly implemented and returned a number of - // bytes larger than the buffer supplied to it. - throw new InvalidDataException(SR.GenericInvalidData); + ThrowGenericInvalidData(); } + else + { + _inflater.SetInput(_buffer, 0, n); + } + } - _inflater.SetInput(_buffer, 0, bytes); + if (buffer.IsEmpty) + { + // The caller provided a zero-byte buffer. This is typically done in order to avoid allocating/renting + // a buffer until data is known to be available. We don't have perfect knowledge here, as _inflater.Inflate + // will return 0 whether or not more data is required, and having input data doesn't necessarily mean it'll + // decompress into at least one byte of output, but it's a reasonable approximation for the 99% case. If it's + // wrong, it just means that a caller using zero-byte reads as a way to delay getting a buffer to use for a + // subsequent call may end up getting one earlier than otherwise preferred. + break; } } - return totalRead; + return bytesRead; } finally { @@ -1014,21 +1029,16 @@ namespace System.IO.Compression private void AsyncOperationStarting() { - if (Interlocked.CompareExchange(ref _activeAsyncOperation, 1, 0) != 0) + if (Interlocked.Exchange(ref _activeAsyncOperation, 1) != 0) { ThrowInvalidBeginCall(); } } - private void AsyncOperationCompleting() - { - int oldValue = Interlocked.CompareExchange(ref _activeAsyncOperation, 0, 1); - Debug.Assert(oldValue == 1, $"Expected {nameof(_activeAsyncOperation)} to be 1, got {oldValue}"); - } + private void AsyncOperationCompleting() => + Volatile.Write(ref _activeAsyncOperation, 0); - private static void ThrowInvalidBeginCall() - { + private static void ThrowInvalidBeginCall() => throw new InvalidOperationException(SR.InvalidBeginCall); - } } } diff --git a/src/libraries/System.Security.Cryptography.Encoding/tests/Base64TransformsTests.cs b/src/libraries/System.Security.Cryptography.Encoding/tests/Base64TransformsTests.cs index d2e8136..aab283c 100644 --- a/src/libraries/System.Security.Cryptography.Encoding/tests/Base64TransformsTests.cs +++ b/src/libraries/System.Security.Cryptography.Encoding/tests/Base64TransformsTests.cs @@ -2,6 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. using System.Collections.Generic; +using System.Diagnostics; using System.IO; using Xunit; @@ -123,7 +124,7 @@ namespace System.Security.Cryptography.Encoding.Tests using (var ms = new MemoryStream(inputBytes)) using (var cs = new CryptoStream(ms, transform, CryptoStreamMode.Read)) { - int bytesRead = cs.Read(outputBytes, 0, outputBytes.Length); + int bytesRead = ReadAll(cs, outputBytes); string outputString = Text.Encoding.ASCII.GetString(outputBytes, 0, bytesRead); Assert.Equal(expected, outputString); } @@ -195,7 +196,7 @@ namespace System.Security.Cryptography.Encoding.Tests using (var ms = new MemoryStream(inputBytes)) using (var cs = new CryptoStream(ms, transform, CryptoStreamMode.Read)) { - int bytesRead = cs.Read(outputBytes, 0, outputBytes.Length); + int bytesRead = ReadAll(cs, outputBytes); // Missing padding bytes not supported (no exception, however) Assert.NotEqual(inputBytes.Length, bytesRead); @@ -230,7 +231,7 @@ namespace System.Security.Cryptography.Encoding.Tests using (var ms = new MemoryStream(inputBytes)) using (var cs = new CryptoStream(ms, base64Transform, CryptoStreamMode.Read)) { - int bytesRead = cs.Read(outputBytes, 0, outputBytes.Length); + int bytesRead = ReadAll(cs, outputBytes); string outputString = Text.Encoding.ASCII.GetString(outputBytes, 0, bytesRead); Assert.Equal(expected, outputString); } @@ -240,7 +241,7 @@ namespace System.Security.Cryptography.Encoding.Tests using (var ms = new MemoryStream(inputBytes)) using (var cs = new CryptoStream(ms, base64Transform, CryptoStreamMode.Read)) { - int bytesRead = cs.Read(outputBytes, 0, outputBytes.Length); + int bytesRead = ReadAll(cs, outputBytes); string outputString = Text.Encoding.ASCII.GetString(outputBytes, 0, bytesRead); Assert.Equal(expected, outputString); } @@ -293,5 +294,22 @@ namespace System.Security.Cryptography.Encoding.Tests Assert.True(transform.CanReuseTransform); } } + + private static int ReadAll(Stream stream, Span buffer) + { + int totalRead = 0; + while (totalRead < buffer.Length) + { + int bytesRead = stream.Read(buffer.Slice(totalRead)); + if (bytesRead == 0) + { + break; + } + + totalRead += bytesRead; + } + + return totalRead; + } } } diff --git a/src/libraries/System.Security.Cryptography.Primitives/src/System/Security/Cryptography/CryptoStream.cs b/src/libraries/System.Security.Cryptography.Primitives/src/System/Security/Cryptography/CryptoStream.cs index c023c90..afe137d 100644 --- a/src/libraries/System.Security.Cryptography.Primitives/src/System/Security/Cryptography/CryptoStream.cs +++ b/src/libraries/System.Security.Cryptography.Primitives/src/System/Security/Cryptography/CryptoStream.cs @@ -16,10 +16,10 @@ namespace System.Security.Cryptography // Member variables private readonly Stream _stream; private readonly ICryptoTransform _transform; - private byte[]? _inputBuffer; // read from _stream before _Transform + private byte[] _inputBuffer; // read from _stream before _Transform private int _inputBufferIndex; private int _inputBlockSize; - private byte[]? _outputBuffer; // buffered output of _Transform + private byte[] _outputBuffer; // buffered output of _Transform private int _outputBufferIndex; private int _outputBlockSize; private bool _canRead; @@ -37,24 +37,41 @@ namespace System.Security.Cryptography public CryptoStream(Stream stream, ICryptoTransform transform, CryptoStreamMode mode, bool leaveOpen) { + if (transform is null) + { + throw new ArgumentNullException(nameof(transform)); + } _stream = stream; _transform = transform; _leaveOpen = leaveOpen; + switch (mode) { case CryptoStreamMode.Read: - if (!(_stream.CanRead)) throw new ArgumentException(SR.Format(SR.Argument_StreamNotReadable, nameof(stream))); + if (!_stream.CanRead) + { + throw new ArgumentException(SR.Format(SR.Argument_StreamNotReadable, nameof(stream))); + } _canRead = true; break; + case CryptoStreamMode.Write: - if (!(_stream.CanWrite)) throw new ArgumentException(SR.Format(SR.Argument_StreamNotWritable, nameof(stream))); + if (!_stream.CanWrite) + { + throw new ArgumentException(SR.Format(SR.Argument_StreamNotWritable, nameof(stream))); + } _canWrite = true; break; + default: - throw new ArgumentException(SR.Argument_InvalidValue); + throw new ArgumentException(SR.Argument_InvalidValue, nameof(mode)); } - InitializeBuffer(); + + _inputBlockSize = _transform.InputBlockSize; + _inputBuffer = new byte[_inputBlockSize]; + _outputBlockSize = _transform.OutputBlockSize; + _outputBuffer = new byte[_outputBlockSize]; } public override bool CanRead @@ -293,198 +310,149 @@ namespace System.Security.Cryptography private async ValueTask ReadAsyncCore(Memory buffer, CancellationToken cancellationToken, bool useAsync) { - // read <= count bytes from the input stream, transforming as we go. - // Basic idea: first we deliver any bytes we already have in the - // _OutputBuffer, because we know they're good. Then, if asked to deliver - // more bytes, we read & transform a block at a time until either there are - // no bytes ready or we've delivered enough. - int bytesToDeliver = buffer.Length; - int currentOutputIndex = 0; - Debug.Assert(_outputBuffer != null); - if (_outputBufferIndex != 0) + while (true) { - // we have some already-transformed bytes in the output buffer - if (_outputBufferIndex <= buffer.Length) + // If there are currently any bytes stored in the output buffer, hand back as many as we can. + if (_outputBufferIndex != 0) { - _outputBuffer.AsSpan(0, _outputBufferIndex).CopyTo(buffer.Span); - bytesToDeliver -= _outputBufferIndex; - currentOutputIndex += _outputBufferIndex; - int toClear = _outputBuffer.Length - _outputBufferIndex; - CryptographicOperations.ZeroMemory(new Span(_outputBuffer, _outputBufferIndex, toClear)); - _outputBufferIndex = 0; + int bytesToCopy = Math.Min(_outputBufferIndex, buffer.Length); + if (bytesToCopy != 0) + { + // Copy as many bytes as we can, then shift down the remaining bytes. + new ReadOnlySpan(_outputBuffer, 0, bytesToCopy).CopyTo(buffer.Span); + _outputBufferIndex -= bytesToCopy; + _outputBuffer.AsSpan(bytesToCopy).CopyTo(_outputBuffer); + CryptographicOperations.ZeroMemory(_outputBuffer.AsSpan(_outputBufferIndex, bytesToCopy)); + } + return bytesToCopy; } - else - { - _outputBuffer.AsSpan(0, buffer.Length).CopyTo(buffer.Span); - Buffer.BlockCopy(_outputBuffer, buffer.Length, _outputBuffer, 0, _outputBufferIndex - buffer.Length); - _outputBufferIndex -= buffer.Length; - int toClear = _outputBuffer.Length - _outputBufferIndex; - CryptographicOperations.ZeroMemory(new Span(_outputBuffer, _outputBufferIndex, toClear)); - - return buffer.Length; + // If we've already hit the end of the stream, there's nothing more to do. + Debug.Assert(_outputBufferIndex == 0); + if (_finalBlockTransformed) + { + Debug.Assert(_inputBufferIndex == 0); + return 0; } - } - // _finalBlockTransformed == true implies we're at the end of the input stream - // if we got through the previous if block then _OutputBufferIndex = 0, meaning - // we have no more transformed bytes to give - // so return count-bytesToDeliver, the amount we were able to hand back - // eventually, we'll just always return 0 here because there's no more to read - if (_finalBlockTransformed) - { - return buffer.Length - bytesToDeliver; - } - // ok, now loop until we've delivered enough or there's nothing available - int amountRead = 0; - int numOutputBytes; - // OK, see first if it's a multi-block transform and we can speed up things - int blocksToProcess = bytesToDeliver / _outputBlockSize; + int bytesRead = 0; + bool eof = false; - Debug.Assert(_inputBuffer != null); - if (blocksToProcess > 1 && _transform.CanTransformMultipleBlocks) - { - int numWholeBlocksInBytes = blocksToProcess * _inputBlockSize; - - // Use ArrayPool.Shared instead of CryptoPool because the array is passed out. - byte[]? tempInputBuffer = ArrayPool.Shared.Rent(numWholeBlocksInBytes); - byte[]? tempOutputBuffer = null; - - try + // If the transform supports transforming multiple blocks, try to read as large a chunk as would yield + // data to fill the output buffer and do the appropriate transform directly into the output buffer. + int blocksToProcess = buffer.Length / _outputBlockSize; + if (blocksToProcess > 1 && _transform.CanTransformMultipleBlocks) { - amountRead = useAsync ? - await _stream.ReadAsync(new Memory(tempInputBuffer, _inputBufferIndex, numWholeBlocksInBytes - _inputBufferIndex), cancellationToken).ConfigureAwait(false) : - _stream.Read(tempInputBuffer, _inputBufferIndex, numWholeBlocksInBytes - _inputBufferIndex); - - int totalInput = _inputBufferIndex + amountRead; - - // If there's still less than a block, copy the new data into the hold buffer and move to the slow read. - if (totalInput < _inputBlockSize) - { - Buffer.BlockCopy(tempInputBuffer, _inputBufferIndex, _inputBuffer, _inputBufferIndex, amountRead); - _inputBufferIndex = totalInput; - } - else + // Use ArrayPool.Shared instead of CryptoPool because the array is passed out. + int numWholeBlocksInBytes = blocksToProcess * _inputBlockSize; + byte[] tempInputBuffer = ArrayPool.Shared.Rent(numWholeBlocksInBytes); + try { - // Copy any held data into tempInputBuffer now that we know we're proceeding - Buffer.BlockCopy(_inputBuffer, 0, tempInputBuffer, 0, _inputBufferIndex); - CryptographicOperations.ZeroMemory(new Span(_inputBuffer, 0, _inputBufferIndex)); - amountRead += _inputBufferIndex; - _inputBufferIndex = 0; - - // Make amountRead an integral multiple of _InputBlockSize - int numWholeReadBlocks = amountRead / _inputBlockSize; - int numWholeReadBlocksInBytes = numWholeReadBlocks * _inputBlockSize; - int numIgnoredBytes = amountRead - numWholeReadBlocksInBytes; - - if (numIgnoredBytes != 0) + // Read into our temporary input buffer, leaving enough room at the beginning for any existing data + // we have in _inputBuffer. + bytesRead = useAsync ? + await _stream.ReadAsync(new Memory(tempInputBuffer, _inputBufferIndex, numWholeBlocksInBytes - _inputBufferIndex), cancellationToken).ConfigureAwait(false) : + _stream.Read(tempInputBuffer, _inputBufferIndex, numWholeBlocksInBytes - _inputBufferIndex); + eof = bytesRead == 0; + + // If we got enough data to form at least one block, transform as much as we can. + int totalInput = _inputBufferIndex + bytesRead; + if (totalInput >= _inputBlockSize) { - _inputBufferIndex = numIgnoredBytes; - Buffer.BlockCopy(tempInputBuffer, numWholeReadBlocksInBytes, _inputBuffer, 0, numIgnoredBytes); - } - - // Use ArrayPool.Shared instead of CryptoPool because the array is passed out. - tempOutputBuffer = ArrayPool.Shared.Rent(numWholeReadBlocks * _outputBlockSize); - numOutputBytes = _transform.TransformBlock(tempInputBuffer, 0, numWholeReadBlocksInBytes, tempOutputBuffer, 0); - tempOutputBuffer.AsSpan(0, numOutputBytes).CopyTo(buffer.Span.Slice(currentOutputIndex)); - - // Clear what was written while we know how much that was - CryptographicOperations.ZeroMemory(new Span(tempOutputBuffer, 0, numOutputBytes)); - ArrayPool.Shared.Return(tempOutputBuffer); - tempOutputBuffer = null; + // Copy any held data into tempInputBuffer now that we know we're proceeding to handle + // decrypting all the received data. + Buffer.BlockCopy(_inputBuffer, 0, tempInputBuffer, 0, _inputBufferIndex); + CryptographicOperations.ZeroMemory(new Span(_inputBuffer, 0, _inputBufferIndex)); + bytesRead += _inputBufferIndex; + + // Determine how many entire blocks worth of data we read. + int numWholeReadBlocks = bytesRead / _inputBlockSize; + int numWholeReadBlocksInBytes = numWholeReadBlocks * _inputBlockSize; + + // If there's anything left over, copy that back into _inputBuffer for a later read. + _inputBufferIndex = bytesRead - numWholeReadBlocksInBytes; + if (_inputBufferIndex != 0) + { + Buffer.BlockCopy(tempInputBuffer, numWholeReadBlocksInBytes, _inputBuffer, 0, _inputBufferIndex); + } - bytesToDeliver -= numOutputBytes; - currentOutputIndex += numOutputBytes; - } + // Transform the read data into the caller's buffer. + int numOutputBytes; + if (MemoryMarshal.TryGetArray(buffer, out ArraySegment bufferArray)) + { + // Because TransformBlock is based on arrays, we can only write directly into the output + // buffer if it's backed by an array; otherwise, we need to rent from the pool. + numOutputBytes = _transform.TransformBlock(tempInputBuffer, 0, numWholeReadBlocksInBytes, bufferArray.Array!, bufferArray.Offset); + } + else + { + // Otherwise, we need to rent a temporary from the pool. + byte[] tempOutputBuffer = ArrayPool.Shared.Rent(numWholeReadBlocks * _outputBlockSize); + numOutputBytes = numWholeReadBlocks * _outputBlockSize; + try + { + numOutputBytes = _transform.TransformBlock(tempInputBuffer, 0, numWholeReadBlocksInBytes, tempOutputBuffer, 0); + tempOutputBuffer.AsSpan(0, numOutputBytes).CopyTo(buffer.Span); + } + finally + { + CryptographicOperations.ZeroMemory(new Span(tempOutputBuffer, 0, numOutputBytes)); + ArrayPool.Shared.Return(tempOutputBuffer); + } + } - CryptographicOperations.ZeroMemory(new Span(tempInputBuffer, 0, numWholeBlocksInBytes)); - ArrayPool.Shared.Return(tempInputBuffer); - tempInputBuffer = null; - } - catch - { - // If we rented and then an exception happened we don't know how much was written to, - // clear the whole thing and let it get reclaimed by the GC. - if (tempOutputBuffer != null) - { - CryptographicOperations.ZeroMemory(tempOutputBuffer); - tempOutputBuffer = null; + // Return anything we've got at this point. + if (numOutputBytes != 0) + { + return numOutputBytes; + } + } + else + { + // We have less than a block's worth of data. Copy the new data back into the _inputBuffer + // and fall back to using the single block code path. + Buffer.BlockCopy(tempInputBuffer, _inputBufferIndex, _inputBuffer, _inputBufferIndex, bytesRead); + _inputBufferIndex = totalInput; + } } - - // For the input buffer we know how much was written, so clear that. - // But still let it get reclaimed by the GC. - if (tempInputBuffer != null) + finally { CryptographicOperations.ZeroMemory(new Span(tempInputBuffer, 0, numWholeBlocksInBytes)); - tempInputBuffer = null; + ArrayPool.Shared.Return(tempInputBuffer); } - - throw; } - } - // try to fill _InputBuffer so we have something to transform - while (bytesToDeliver > 0) - { - while (_inputBufferIndex < _inputBlockSize) + // Read enough to fill one input block, as anything less won't be able to be transformed to produce output. + if (!eof) { - amountRead = useAsync ? - await _stream.ReadAsync(new Memory(_inputBuffer, _inputBufferIndex, _inputBlockSize - _inputBufferIndex), cancellationToken).ConfigureAwait(false) : - _stream.Read(_inputBuffer, _inputBufferIndex, _inputBlockSize - _inputBufferIndex); + while (_inputBufferIndex < _inputBlockSize) + { + bytesRead = useAsync ? + await _stream.ReadAsync(new Memory(_inputBuffer, _inputBufferIndex, _inputBlockSize - _inputBufferIndex), cancellationToken).ConfigureAwait(false) : + _stream.Read(_inputBuffer, _inputBufferIndex, _inputBlockSize - _inputBufferIndex); + if (bytesRead <= 0) + { + break; + } - // first, check to see if we're at the end of the input stream - if (amountRead == 0) goto ProcessFinalBlock; - _inputBufferIndex += amountRead; + _inputBufferIndex += bytesRead; + } } - numOutputBytes = _transform.TransformBlock(_inputBuffer, 0, _inputBlockSize, _outputBuffer, 0); - _inputBufferIndex = 0; - - if (bytesToDeliver >= numOutputBytes) + // Transform the received data. + if (bytesRead <= 0) { - _outputBuffer.AsSpan(0, numOutputBytes).CopyTo(buffer.Span.Slice(currentOutputIndex)); - CryptographicOperations.ZeroMemory(new Span(_outputBuffer, 0, numOutputBytes)); - currentOutputIndex += numOutputBytes; - bytesToDeliver -= numOutputBytes; + _outputBuffer = _transform.TransformFinalBlock(_inputBuffer, 0, _inputBufferIndex); + _outputBufferIndex = _outputBuffer.Length; + _finalBlockTransformed = true; } else { - _outputBuffer.AsSpan(0, bytesToDeliver).CopyTo(buffer.Span.Slice(currentOutputIndex)); - _outputBufferIndex = numOutputBytes - bytesToDeliver; - Buffer.BlockCopy(_outputBuffer, bytesToDeliver, _outputBuffer, 0, _outputBufferIndex); - int toClear = _outputBuffer.Length - _outputBufferIndex; - CryptographicOperations.ZeroMemory(new Span(_outputBuffer, _outputBufferIndex, toClear)); - return buffer.Length; + _outputBufferIndex = _transform.TransformBlock(_inputBuffer, 0, _inputBufferIndex, _outputBuffer, 0); } - } - return buffer.Length; - - ProcessFinalBlock: - // if so, then call TransformFinalBlock to get whatever is left - byte[] finalBytes = _transform.TransformFinalBlock(_inputBuffer, 0, _inputBufferIndex); - // now, since _OutputBufferIndex must be 0 if we're in the while loop at this point, - // reset it to be what we just got back - _outputBuffer = finalBytes; - _outputBufferIndex = finalBytes.Length; - // set the fact that we've transformed the final block - _finalBlockTransformed = true; - // now, return either everything we just got or just what's asked for, whichever is smaller - if (bytesToDeliver < _outputBufferIndex) - { - _outputBuffer.AsSpan(0, bytesToDeliver).CopyTo(buffer.Span.Slice(currentOutputIndex)); - _outputBufferIndex -= bytesToDeliver; - Buffer.BlockCopy(_outputBuffer, bytesToDeliver, _outputBuffer, 0, _outputBufferIndex); - int toClear = _outputBuffer.Length - _outputBufferIndex; - CryptographicOperations.ZeroMemory(new Span(_outputBuffer, _outputBufferIndex, toClear)); - return buffer.Length; - } - else - { - _outputBuffer.AsSpan(0, _outputBufferIndex).CopyTo(buffer.Span.Slice(currentOutputIndex)); - bytesToDeliver -= _outputBufferIndex; - _outputBufferIndex = 0; - CryptographicOperations.ZeroMemory(_outputBuffer); - return buffer.Length - bytesToDeliver; + + // All input data has been processed. + _inputBufferIndex = 0; } } @@ -807,8 +775,8 @@ namespace System.Security.Cryptography if (_outputBuffer != null) Array.Clear(_outputBuffer); - _inputBuffer = null; - _outputBuffer = null; + _inputBuffer = null!; + _outputBuffer = null!; _canRead = false; _canWrite = false; } @@ -858,30 +826,13 @@ namespace System.Security.Cryptography Array.Clear(_outputBuffer); } - _inputBuffer = null; - _outputBuffer = null; + _inputBuffer = null!; + _outputBuffer = null!; _canRead = false; _canWrite = false; } } - // Private methods - - private void InitializeBuffer() - { - if (_transform != null) - { - _inputBlockSize = _transform.InputBlockSize; - _inputBuffer = new byte[_inputBlockSize]; - _outputBlockSize = _transform.OutputBlockSize; - _outputBuffer = new byte[_outputBlockSize]; - } - else - { - throw new ArgumentNullException(nameof(_transform)); - } - } - [MemberNotNull(nameof(_lazyAsyncActiveSemaphore))] private SemaphoreSlim AsyncActiveSemaphore { diff --git a/src/libraries/System.Security.Cryptography.Primitives/tests/CryptoStream.cs b/src/libraries/System.Security.Cryptography.Primitives/tests/CryptoStream.cs index ea43dc8..86fa123 100644 --- a/src/libraries/System.Security.Cryptography.Primitives/tests/CryptoStream.cs +++ b/src/libraries/System.Security.Cryptography.Primitives/tests/CryptoStream.cs @@ -27,6 +27,7 @@ namespace System.Security.Cryptography.Encryption.Tests.Asymmetric } protected override Type UnsupportedConcurrentExceptionType => null; + protected override bool BlocksOnZeroByteReads => true; [ActiveIssue("https://github.com/dotnet/runtime/issues/45080")] [Theory] @@ -37,7 +38,7 @@ namespace System.Security.Cryptography.Encryption.Tests.Asymmetric public static void Ctor() { var transform = new IdentityTransform(1, 1, true); - AssertExtensions.Throws(null, () => new CryptoStream(new MemoryStream(), transform, (CryptoStreamMode)12345)); + AssertExtensions.Throws("mode", () => new CryptoStream(new MemoryStream(), transform, (CryptoStreamMode)12345)); AssertExtensions.Throws(null, "stream", () => new CryptoStream(new MemoryStream(new byte[0], writable: false), transform, CryptoStreamMode.Write)); AssertExtensions.Throws(null, "stream", () => new CryptoStream(new CryptoStream(new MemoryStream(new byte[0]), transform, CryptoStreamMode.Write), transform, CryptoStreamMode.Read)); }