Override Write(Span) for CryptoStream to clear arrays before returning to pool (...
authorKevin Jones <kevin@vcsjones.com>
Tue, 14 Mar 2023 14:33:56 +0000 (10:33 -0400)
committerGitHub <noreply@github.com>
Tue, 14 Mar 2023 14:33:56 +0000 (10:33 -0400)
src/libraries/System.Security.Cryptography/src/System/Security/Cryptography/CryptoStream.cs
src/libraries/System.Security.Cryptography/tests/CryptoStream.cs

index f85b396..3afce1c 100644 (file)
@@ -499,6 +499,29 @@ namespace System.Security.Cryptography
             WriteAsyncCore(buffer.AsMemory(offset, count), default, useAsync: false).AsTask().GetAwaiter().GetResult();
         }
 
+        /// <inheritdoc/>
+        public override void Write(ReadOnlySpan<byte> buffer)
+        {
+            // Logically this is doing the same thing as the base Stream, however CryptoStream clears arrays before
+            // returning them to the pool, whereas the base Stream does not.
+            // Use ArrayPool.Shared instead of CryptoPool because the array is passed out.
+            byte[] sharedBuffer = ArrayPool<byte>.Shared.Rent(buffer.Length);
+
+            try
+            {
+                buffer.CopyTo(sharedBuffer);
+
+                // We want to keep calling the virtual Write(byte[]...) so that derived CryptoStream types continue
+                // to get the array overload called from the span one.
+                Write(sharedBuffer, 0, buffer.Length);
+            }
+            finally
+            {
+                CryptographicOperations.ZeroMemory(sharedBuffer.AsSpan(0, buffer.Length));
+                ArrayPool<byte>.Shared.Return(sharedBuffer);
+            }
+        }
+
         private void CheckWriteArguments(byte[] buffer, int offset, int count)
         {
             ValidateBufferArguments(buffer, offset, count);
index fb5fca6..9b6a368 100644 (file)
@@ -55,6 +55,7 @@ namespace System.Security.Cryptography.Tests
         [InlineData(64, 64, false)]
         public static void Roundtrip(int inputBlockSize, int outputBlockSize, bool canTransformMultipleBlocks)
         {
+            const string ExpectedString = LoremText + LoremText + LoremText + LoremText + LoremText;
             ICryptoTransform encryptor = new IdentityTransform(inputBlockSize, outputBlockSize, canTransformMultipleBlocks);
             ICryptoTransform decryptor = new IdentityTransform(inputBlockSize, outputBlockSize, canTransformMultipleBlocks);
 
@@ -88,6 +89,10 @@ namespace System.Security.Cryptography.Tests
                 encryptStream.WriteAsync(toWrite, 0, toWrite.Length).GetAwaiter().GetResult();
                 Assert.False(encryptStream.HasFlushedFinalBlock);
 
+                // Write span
+                encryptStream.Write(toWrite.AsSpan());
+                Assert.False(encryptStream.HasFlushedFinalBlock);
+
                 // Flush (nops)
                 encryptStream.Flush();
                 encryptStream.FlushAsync().GetAwaiter().GetResult();
@@ -110,9 +115,7 @@ namespace System.Security.Cryptography.Tests
 
                 using (StreamReader reader = new StreamReader(decryptStream))
                 {
-                    Assert.Equal(
-                        LoremText + LoremText + LoremText + LoremText,
-                        reader.ReadToEnd());
+                    Assert.Equal(ExpectedString, reader.ReadToEnd());
                 }
             }
 
@@ -121,9 +124,7 @@ namespace System.Security.Cryptography.Tests
             using (CryptoStream decryptStream = new CryptoStream(stream, decryptor, CryptoStreamMode.Read))
             using (StreamReader reader = new StreamReader(decryptStream))
             {
-                Assert.Equal(
-                    LoremText + LoremText + LoremText + LoremText,
-                    reader.ReadToEndAsync().GetAwaiter().GetResult());
+                Assert.Equal(ExpectedString, reader.ReadToEndAsync().GetAwaiter().GetResult());
             }
 
             // Read/decrypt using a small buffer to force multiple calls to Read
@@ -131,17 +132,14 @@ namespace System.Security.Cryptography.Tests
             using (CryptoStream decryptStream = new CryptoStream(stream, decryptor, CryptoStreamMode.Read))
             using (StreamReader reader = new StreamReader(decryptStream, Encoding.UTF8, true, bufferSize: 10))
             {
-                Assert.Equal(
-                    LoremText + LoremText + LoremText + LoremText,
-                    reader.ReadToEndAsync().GetAwaiter().GetResult());
+                Assert.Equal(ExpectedString, reader.ReadToEndAsync().GetAwaiter().GetResult());
             }
 
             // Read/decrypt one byte at a time with ReadByte
             stream = new MemoryStream(stream.ToArray()); // CryptoStream.Dispose disposes the stream
             using (CryptoStream decryptStream = new CryptoStream(stream, decryptor, CryptoStreamMode.Read))
             {
-                string expectedStr = LoremText + LoremText + LoremText + LoremText;
-                foreach (char c in expectedStr)
+                foreach (char c in ExpectedString)
                 {
                     Assert.Equal(c, decryptStream.ReadByte()); // relies on LoremText being ASCII
                 }