Add HashAlgorithm.ComputeHashAsync
authorJeremy Barton <jbarton@microsoft.com>
Tue, 12 Nov 2019 23:28:59 +0000 (15:28 -0800)
committerGitHub <noreply@github.com>
Tue, 12 Nov 2019 23:28:59 +0000 (15:28 -0800)
Commit migrated from https://github.com/dotnet/corefx/commit/ac02476cfd31e0aed74e5ea4c9dd0af3d7b1ef7f

src/libraries/System.Security.Cryptography.Primitives/ref/System.Security.Cryptography.Primitives.cs
src/libraries/System.Security.Cryptography.Primitives/src/System/Security/Cryptography/HashAlgorithm.cs
src/libraries/System.Security.Cryptography.Primitives/tests/HashAlgorithmTest.cs

index 6b1b4a4..c207fe0 100644 (file)
@@ -109,6 +109,7 @@ namespace System.Security.Cryptography
         public byte[] ComputeHash(byte[] buffer) { throw null; }
         public byte[] ComputeHash(byte[] buffer, int offset, int count) { throw null; }
         public byte[] ComputeHash(System.IO.Stream inputStream) { throw null; }
+        public System.Threading.Tasks.Task<byte[]> ComputeHashAsync(System.IO.Stream inputStream, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; }
         public static System.Security.Cryptography.HashAlgorithm Create() { throw null; }
         public static System.Security.Cryptography.HashAlgorithm? Create(string hashName) { throw null; }
         public void Dispose() { }
index 27bc168..d722009 100644 (file)
@@ -4,6 +4,8 @@
 
 using System.Buffers;
 using System.IO;
+using System.Threading;
+using System.Threading.Tasks;
 
 namespace System.Security.Cryptography
 {
@@ -101,12 +103,57 @@ namespace System.Security.Cryptography
             byte[] buffer = ArrayPool<byte>.Shared.Rent(4096);
 
             int bytesRead;
+            int clearLimit = 0;
+
             while ((bytesRead = inputStream.Read(buffer, 0, buffer.Length)) > 0)
             {
+                if (bytesRead > clearLimit)
+                {
+                    clearLimit = bytesRead;
+                }
+
                 HashCore(buffer, 0, bytesRead);
             }
 
-            ArrayPool<byte>.Shared.Return(buffer, clearArray: true);
+            CryptographicOperations.ZeroMemory(buffer.AsSpan(0, clearLimit));
+            ArrayPool<byte>.Shared.Return(buffer, clearArray: false);
+            return CaptureHashCodeAndReinitialize();
+        }
+
+        public Task<byte[]> ComputeHashAsync(
+            Stream inputStream,
+            CancellationToken cancellationToken = default)
+        {
+            if (inputStream == null)
+                throw new ArgumentNullException(nameof(inputStream));
+            if (_disposed)
+                throw new ObjectDisposedException(null);
+
+            return ComputeHashAsyncCore(inputStream, cancellationToken);
+        }
+
+        private async Task<byte[]> ComputeHashAsyncCore(
+            Stream inputStream,
+            CancellationToken cancellationToken)
+        {
+            // Use ArrayPool.Shared instead of CryptoPool because the array is passed out.
+            byte[] rented = ArrayPool<byte>.Shared.Rent(4096);
+            Memory<byte> buffer = rented;
+            int clearLimit = 0;
+            int bytesRead;
+
+            while ((bytesRead = await inputStream.ReadAsync(buffer, cancellationToken).ConfigureAwait(false)) > 0)
+            {
+                if (bytesRead > clearLimit)
+                {
+                    clearLimit = bytesRead;
+                }
+
+                HashCore(rented, 0, bytesRead);
+            }
+
+            CryptographicOperations.ZeroMemory(rented.AsSpan(0, clearLimit));
+            ArrayPool<byte>.Shared.Return(rented, clearArray: false);
             return CaptureHashCodeAndReinitialize();
         }
 
index 515e51b..a934703 100644 (file)
@@ -3,6 +3,9 @@
 // See the LICENSE file in the project root for more information.
 
 using System.Linq;
+using System.Threading;
+using System.Threading.Tasks;
+using Test.IO.Streams;
 using Xunit;
 
 namespace System.Security.Cryptography.Hashing.Tests
@@ -29,6 +32,77 @@ namespace System.Security.Cryptography.Hashing.Tests
             Assert.Equal(input.Sum(b => (long)b), BitConverter.ToInt64(output, 0));
         }
 
+        [Theory]
+        [InlineData(0)]
+        [InlineData(10)]
+        [InlineData(4096)]
+        [InlineData(4097)]
+        [InlineData(10000)]
+        public async Task VerifyComputeHashAsync(int size)
+        {
+            int fullCycles = size / 256;
+            int partial = size % 256;
+            // SUM(0..255) is 32640
+            const long CycleSum = 32640L;
+
+            // The formula for additive sum IS n*(n+1)/2, but size is a count and the first n is 0,
+            // which happens to turn it into (n-1) * n / 2, aka n * (n - 1) / 2.
+            long expectedSum = CycleSum * fullCycles + (partial * (partial - 1) / 2);
+
+            using (PositionValueStream stream = new PositionValueStream(size))
+            using (HashAlgorithm hash = new SummingTestHashAlgorithm())
+            {
+                byte[] result = await hash.ComputeHashAsync(stream);
+                byte[] expected = BitConverter.GetBytes(expectedSum);
+
+                Assert.Equal(expected, result);
+            }
+        }
+
+        [Fact]
+        public async Task ComputeHashAsync_SupportsCancellation()
+        {
+            using (CancellationTokenSource cancellationSource = new CancellationTokenSource(100))
+            using (PositionValueStream stream = new SlowPositionValueStream(10000))
+            using (HashAlgorithm hash = new SummingTestHashAlgorithm())
+            {
+                await Assert.ThrowsAnyAsync<OperationCanceledException>(
+                    () => hash.ComputeHashAsync(stream, cancellationSource.Token));
+            }
+        }
+
+        [Fact]
+        public void ComputeHashAsync_Disposed()
+        {
+            using (PositionValueStream stream = new SlowPositionValueStream(10000))
+            using (HashAlgorithm hash = new SummingTestHashAlgorithm())
+            {
+                hash.Dispose();
+
+                Assert.Throws<ObjectDisposedException>(
+                    () =>
+                    {
+                        // Not returning or awaiting the Task, it never got created.
+                        hash.ComputeHashAsync(stream);
+                    });
+            }
+        }
+
+        [Fact]
+        public void ComputeHashAsync_RequiresStream()
+        {
+            using (HashAlgorithm hash = new SummingTestHashAlgorithm())
+            {
+                AssertExtensions.Throws<ArgumentNullException>(
+                    "inputStream",
+                    () =>
+                    {
+                        // Not returning or awaiting the Task, it never got created.
+                        hash.ComputeHashAsync(null);
+                    });
+            }
+        }
+
         private sealed class SummingTestHashAlgorithm : HashAlgorithm
         {
             private long _sum;
@@ -48,5 +122,18 @@ namespace System.Security.Cryptography.Hashing.Tests
             // test verifies that calling the base implementations invokes the array
             // implementations by verifying the right value is produced.
         }
+
+        private class SlowPositionValueStream : PositionValueStream
+        {
+            public SlowPositionValueStream(int totalCount) : base(totalCount)
+            {
+            }
+
+            public override int Read(byte[] buffer, int offset, int count)
+            {
+                System.Threading.Thread.Sleep(1000);
+                return base.Read(buffer, offset, count);
+            }
+        }
     }
 }