HKDF implementation (dotnet/corefx#42567)
authorKrzysztof Wicher <mordotymoja@gmail.com>
Thu, 14 Nov 2019 00:23:19 +0000 (16:23 -0800)
committerGitHub <noreply@github.com>
Thu, 14 Nov 2019 00:23:19 +0000 (16:23 -0800)
* HKDF implementation

* Fix CreateMacProvider on OSX

* apply review feedback

* improve error message in case of test failure

Commit migrated from https://github.com/dotnet/corefx/commit/c14fc5636bdb9141f69eaeaf0e5812b80af525b1

13 files changed:
src/libraries/Common/src/Internal/Cryptography/HashProviderCng.cs
src/libraries/Common/src/Interop/Windows/BCrypt/Interop.BCryptCreateHash.cs
src/libraries/System.Security.Cryptography.Algorithms/ref/System.Security.Cryptography.Algorithms.cs
src/libraries/System.Security.Cryptography.Algorithms/src/Internal/Cryptography/HMACCommon.cs
src/libraries/System.Security.Cryptography.Algorithms/src/Internal/Cryptography/HashProviderDispenser.OSX.cs
src/libraries/System.Security.Cryptography.Algorithms/src/Internal/Cryptography/HashProviderDispenser.Unix.cs
src/libraries/System.Security.Cryptography.Algorithms/src/Internal/Cryptography/HashProviderDispenser.Windows.cs
src/libraries/System.Security.Cryptography.Algorithms/src/Resources/Strings.resx
src/libraries/System.Security.Cryptography.Algorithms/src/System.Security.Cryptography.Algorithms.csproj
src/libraries/System.Security.Cryptography.Algorithms/src/System/Security/Cryptography/HKDF.cs [new file with mode: 0644]
src/libraries/System.Security.Cryptography.Algorithms/src/System/Security/Cryptography/IncrementalHash.cs
src/libraries/System.Security.Cryptography.Algorithms/tests/HKDFTests.cs [new file with mode: 0644]
src/libraries/System.Security.Cryptography.Algorithms/tests/System.Security.Cryptography.Algorithms.Tests.csproj

index d94bc09..ce4e1f5 100644 (file)
@@ -21,12 +21,16 @@ namespace Internal.Cryptography
         //
         //   - "key" activates MAC hashing if present. If null, this HashProvider performs a regular old hash.
         //
-        public HashProviderCng(string hashAlgId, byte[] key)
+        public HashProviderCng(string hashAlgId, byte[] key) : this(hashAlgId, key, isHmac: key != null)
+        {
+        }
+
+        internal HashProviderCng(string hashAlgId, ReadOnlySpan<byte> key, bool isHmac)
         {
             BCryptOpenAlgorithmProviderFlags dwFlags = BCryptOpenAlgorithmProviderFlags.None;
-            if (key != null)
+            if (isHmac)
             {
-                _key = key.CloneByteArray();
+                _key = key.ToArray();
                 dwFlags |= BCryptOpenAlgorithmProviderFlags.BCRYPT_ALG_HANDLE_HMAC_FLAG;
             }
 
@@ -63,7 +67,6 @@ namespace Internal.Cryptography
                     throw Interop.BCrypt.CreateCryptographicException(ntStatus);
                 _hashSize = hashSize;
             }
-            return;
         }
 
         public sealed override unsafe void AppendHashData(ReadOnlySpan<byte> source)
index 09d5987..6a2fc0d 100644 (file)
@@ -12,8 +12,13 @@ internal partial class Interop
 {
     internal partial class BCrypt
     {
+        internal static NTSTATUS BCryptCreateHash(SafeBCryptAlgorithmHandle hAlgorithm, out SafeBCryptHashHandle phHash, IntPtr pbHashObject, int cbHashObject, ReadOnlySpan<byte> secret, int cbSecret, BCryptCreateHashFlags dwFlags)
+        {
+            return BCryptCreateHash(hAlgorithm, out phHash, pbHashObject, cbHashObject, ref MemoryMarshal.GetReference(secret), cbSecret, dwFlags);
+        }
+
         [DllImport(Libraries.BCrypt, CharSet = CharSet.Unicode)]
-        internal static extern NTSTATUS BCryptCreateHash(SafeBCryptAlgorithmHandle hAlgorithm, out SafeBCryptHashHandle phHash, IntPtr pbHashObject, int cbHashObject, [In, Out] byte[] pbSecret, int cbSecret, BCryptCreateHashFlags dwFlags);
+        private static extern NTSTATUS BCryptCreateHash(SafeBCryptAlgorithmHandle hAlgorithm, out SafeBCryptHashHandle phHash, IntPtr pbHashObject, int cbHashObject, ref byte pbSecret, int cbSecret, BCryptCreateHashFlags dwFlags);
 
         [Flags]
         internal enum BCryptCreateHashFlags : int
index 06fec22..f60d0ed 100644 (file)
@@ -332,6 +332,15 @@ namespace System.Security.Cryptography
         public byte[] X;
         public byte[] Y;
     }
+    public static class HKDF
+    {
+        public static byte[] Extract(HashAlgorithmName hashAlgorithmName, byte[] ikm, byte[] salt = null) { throw null; }
+        public static int Extract(HashAlgorithmName hashAlgorithmName, ReadOnlySpan<byte> ikm, ReadOnlySpan<byte> salt, Span<byte> prk) { throw null; }
+        public static byte[] Expand(HashAlgorithmName hashAlgorithmName, byte[] prk, int outputLength, byte[] info = null) { throw null; }
+        public static void Expand(HashAlgorithmName hashAlgorithmName, ReadOnlySpan<byte> prk, Span<byte> output, ReadOnlySpan<byte> info) { throw null; }
+        public static byte[] DeriveKey(HashAlgorithmName hashAlgorithmName, byte[] ikm, int outputLength, byte[] salt = null, byte[] info = null) { throw null; }
+        public static void DeriveKey(HashAlgorithmName hashAlgorithmName, ReadOnlySpan<byte> ikm, Span<byte> output, ReadOnlySpan<byte> salt, ReadOnlySpan<byte> info) { throw null; }
+    }
     public partial class HMACMD5 : System.Security.Cryptography.HMAC
     {
         public HMACMD5() { }
index 8a6d400..63f3b8b 100644 (file)
@@ -17,20 +17,45 @@ namespace Internal.Cryptography
     //
     internal sealed class HMACCommon
     {
-        public HMACCommon(string hashAlgorithmId, byte[] key, int blockSize)
+        public HMACCommon(string hashAlgorithmId, byte[] key, int blockSize) : this(hashAlgorithmId, blockSize)
+        {
+            ChangeKey(key);
+        }
+
+        internal HMACCommon(string hashAlgorithmId, ReadOnlySpan<byte> key, int blockSize) : this(hashAlgorithmId, blockSize)
+        {
+            // note: will not set ActualKey if key size is smaller or equal than blockSize
+            //       this is to avoid extra allocation. ActualKey can still be used if key is generated.
+            //       Otherwise the ReadOnlySpan overload would actually be slower than byte array overload.
+            ChangeKey(key);
+        }
+
+        private HMACCommon(string hashAlgorithmId, int blockSize)
         {
             Debug.Assert(!string.IsNullOrEmpty(hashAlgorithmId));
             Debug.Assert(blockSize > 0 || blockSize == -1);
 
             _hashAlgorithmId = hashAlgorithmId;
             _blockSize = blockSize;
-            ChangeKey(key);
         }
 
         public int HashSizeInBits => _hMacProvider.HashSizeInBytes * 8;
 
         public void ChangeKey(byte[] key)
         {
+            ActualKey = ChangeKeyImpl(key) ?? key;
+        }
+
+        internal void ChangeKey(ReadOnlySpan<byte> key)
+        {
+            // note: does not set key when it's smaller than blockSize
+            ActualKey = ChangeKeyImpl(key);
+        }
+
+        private byte[] ChangeKeyImpl(ReadOnlySpan<byte> key)
+        {
+            byte[] modifiedKey = null;
+
             // If _blockSize is -1 the key isn't going to be extractable by the object holder,
             // so there's no point in recalculating it in managed code.
             if (key.Length > _blockSize && _blockSize > 0)
@@ -40,8 +65,8 @@ namespace Internal.Cryptography
                 {
                     _lazyHashProvider = HashProviderDispenser.CreateHashProvider(_hashAlgorithmId);
                 }
-                _lazyHashProvider.AppendHashData(key, 0, key.Length);
-                key = _lazyHashProvider.FinalizeHashAndReset();
+                _lazyHashProvider.AppendHashData(key);
+                modifiedKey = _lazyHashProvider.FinalizeHashAndReset();
             }
 
             HashProvider oldHashProvider = _hMacProvider;
@@ -49,7 +74,7 @@ namespace Internal.Cryptography
             oldHashProvider?.Dispose(true);
             _hMacProvider = HashProviderDispenser.CreateMacProvider(_hashAlgorithmId, key);
 
-            ActualKey = key;
+            return modifiedKey;
         }
 
         // The actual key used for hashing. This will not be the same as the original key passed to ChangeKey() if the original key exceeded the
index 4768666..fb6025f 100644 (file)
@@ -30,7 +30,7 @@ namespace Internal.Cryptography
             throw new CryptographicException(SR.Format(SR.Cryptography_UnknownHashAlgorithm, hashAlgorithmId));
         }
 
-        public static HashProvider CreateMacProvider(string hashAlgorithmId, byte[] key)
+        public static HashProvider CreateMacProvider(string hashAlgorithmId, ReadOnlySpan<byte> key)
         {
             switch (hashAlgorithmId)
             {
@@ -62,9 +62,9 @@ namespace Internal.Cryptography
 
             public override int HashSizeInBytes { get; }
 
-            internal AppleHmacProvider(Interop.AppleCrypto.PAL_HashAlgorithm algorithm, byte[] key)
+            internal AppleHmacProvider(Interop.AppleCrypto.PAL_HashAlgorithm algorithm, ReadOnlySpan<byte> key)
             {
-                _key = key.CloneByteArray();
+                _key = key.ToArray();
                 int hashSizeInBytes = 0;
                 _ctx = Interop.AppleCrypto.HmacCreate(algorithm, ref hashSizeInBytes);
 
index ff8e91e..1c85f79 100644 (file)
@@ -30,7 +30,7 @@ namespace Internal.Cryptography
             throw new CryptographicException(SR.Format(SR.Cryptography_UnknownHashAlgorithm, hashAlgorithmId));
         }
 
-        public static unsafe HashProvider CreateMacProvider(string hashAlgorithmId, byte[] key)
+        public static unsafe HashProvider CreateMacProvider(string hashAlgorithmId, ReadOnlySpan<byte> key)
         {
             switch (hashAlgorithmId)
             {
@@ -121,10 +121,9 @@ namespace Internal.Cryptography
             private readonly int _hashSize;
             private SafeHmacCtxHandle _hmacCtx;
 
-            public HmacHashProvider(IntPtr algorithmEvp, byte[] key)
+            public HmacHashProvider(IntPtr algorithmEvp, ReadOnlySpan<byte> key)
             {
                 Debug.Assert(algorithmEvp != IntPtr.Zero);
-                Debug.Assert(key != null);
 
                 _hashSize = Interop.Crypto.EvpMdSize(algorithmEvp);
                 if (_hashSize <= 0 || _hashSize > Interop.Crypto.EVP_MAX_MD_SIZE)
@@ -132,7 +131,7 @@ namespace Internal.Cryptography
                     throw new CryptographicException();
                 }
 
-                _hmacCtx = Interop.Crypto.HmacCreate(ref MemoryMarshal.GetReference(new Span<byte>(key)), key.Length, algorithmEvp);
+                _hmacCtx = Interop.Crypto.HmacCreate(ref MemoryMarshal.GetReference(key), key.Length, algorithmEvp);
                 Interop.Crypto.CheckValidOpenSslHandle(_hmacCtx);
             }
 
index 0121303..e5ddba7 100644 (file)
@@ -18,9 +18,9 @@ namespace Internal.Cryptography
             return new HashProviderCng(hashAlgorithmId, null);
         }
 
-        public static HashProvider CreateMacProvider(string hashAlgorithmId, byte[] key)
+        public static HashProvider CreateMacProvider(string hashAlgorithmId, ReadOnlySpan<byte> key)
         {
-            return new HashProviderCng(hashAlgorithmId, key);
+            return new HashProviderCng(hashAlgorithmId, key, isHmac: true);
         }
     }
 }
index 5b97a59..593d64c 100644 (file)
   <data name="Cryptography_WriteEncodedValue_OneValueAtATime" xml:space="preserve">
     <value>The input to WriteEncodedValue must represent a single encoded value with no trailing data.</value>
   </data>
+  <data name="Cryptography_Prk_TooSmall" xml:space="preserve">
+    <value>The pseudo-random key length must be {0} bytes.</value>
+  </data>
+  <data name="Cryptography_Okm_TooLarge" xml:space="preserve">
+    <value>Output keying material length can be at most {0} bytes (255 * hash length).</value>
+  </data>
 </root>
index d46e378..820621a 100644 (file)
@@ -1,4 +1,4 @@
-<Project Sdk="Microsoft.NET.Sdk">
+<Project Sdk="Microsoft.NET.Sdk">
   <PropertyGroup>
     <AllowUnsafeBlocks>true</AllowUnsafeBlocks>
     <DefineConstants>$(DefineConstants);INTERNAL_ASYMMETRIC_IMPLEMENTATIONS</DefineConstants>
@@ -45,6 +45,7 @@
     <Compile Include="System\Security\Cryptography\ECDsa.Xml.cs" />
     <Compile Include="System\Security\Cryptography\ECParameters.cs" />
     <Compile Include="System\Security\Cryptography\ECPoint.cs" />
+    <Compile Include="System\Security\Cryptography\HKDF.cs" />
     <Compile Include="System\Security\Cryptography\MaskGenerationMethod.cs" />
     <Compile Include="System\Security\Cryptography\MD5.cs" />
     <Compile Include="System\Security\Cryptography\SHA1.cs" />
     <Reference Include="System.Runtime.Numerics" />
   </ItemGroup>
   <ItemGroup>
-    <None Include="@(AsnXml)" /> 
+    <None Include="@(AsnXml)" />
   </ItemGroup>
 </Project>
diff --git a/src/libraries/System.Security.Cryptography.Algorithms/src/System/Security/Cryptography/HKDF.cs b/src/libraries/System.Security.Cryptography.Algorithms/src/System/Security/Cryptography/HKDF.cs
new file mode 100644 (file)
index 0000000..23edddc
--- /dev/null
@@ -0,0 +1,262 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System.Diagnostics;
+
+namespace System.Security.Cryptography
+{
+    /// <summary>
+    /// RFC5869  HMAC-based Extract-and-Expand Key Derivation (HKDF)
+    /// </summary>
+    /// <remarks>
+    /// In situations where the input key material is already a uniformly random bitstring, the HKDF standard allows the Extract
+    /// phase to be skipped, and the master key to be used directly as the pseudorandom key.
+    /// See <a href="https://tools.ietf.org/html/rfc5869">RFC5869</a> for more information.
+    /// </remarks>
+    public static class HKDF
+    {
+        /// <summary>
+        /// Performs the HKDF-Extract function.
+        /// See section 2.2 of <a href="https://tools.ietf.org/html/rfc5869#section-2.2">RFC5869</a>
+        /// </summary>
+        /// <param name="hashAlgorithmName">The hash algorithm used for HMAC operations.</param>
+        /// <param name="ikm">The input keying material.</param>
+        /// <param name="salt">The optional salt value (a non-secret random value). If not provided it defaults to a byte array of <see cref="HashLength"/> zeros.</param>
+        /// <returns>The pseudo random key (prk).</returns>
+        public static byte[] Extract(HashAlgorithmName hashAlgorithmName, byte[] ikm, byte[] salt = null)
+        {
+            if (ikm == null)
+                throw new ArgumentNullException(nameof(ikm));
+
+            int hashLength = HashLength(hashAlgorithmName);
+            byte[] prk = new byte[hashLength];
+
+            Extract(hashAlgorithmName, hashLength, ikm, salt, prk);
+            return prk;
+        }
+
+        /// <summary>
+        /// Performs the HKDF-Extract function.
+        /// See section 2.2 of <a href="https://tools.ietf.org/html/rfc5869#section-2.2">RFC5869</a>
+        /// </summary>
+        /// <param name="hashAlgorithmName">The hash algorithm used for HMAC operations.</param>
+        /// <param name="ikm">The input keying material.</param>
+        /// <param name="salt">The salt value (a non-secret random value).</param>
+        /// <param name="prk">The destination buffer to receive the pseudo-random key (prk).</param>
+        /// <returns>The number of bytes written to the <paramref name="prk"/> buffer.</returns>
+        public static int Extract(HashAlgorithmName hashAlgorithmName, ReadOnlySpan<byte> ikm, ReadOnlySpan<byte> salt, Span<byte> prk)
+        {
+            int hashLength = HashLength(hashAlgorithmName);
+
+            if (prk.Length < hashLength)
+            {
+                throw new ArgumentException(SR.Format(SR.Cryptography_Prk_TooSmall, hashLength), nameof(prk));
+            }
+
+            if (prk.Length > hashLength)
+            {
+                prk = prk.Slice(0, hashLength);
+            }
+
+            Extract(hashAlgorithmName, hashLength, ikm, salt, prk);
+            return hashLength;
+        }
+
+        private static void Extract(HashAlgorithmName hashAlgorithmName, int hashLength, ReadOnlySpan<byte> ikm, ReadOnlySpan<byte> salt, Span<byte> prk)
+        {
+            Debug.Assert(HashLength(hashAlgorithmName) == hashLength);
+
+            using (IncrementalHash hmac = IncrementalHash.CreateHMAC(hashAlgorithmName, salt))
+            {
+                hmac.AppendData(ikm);
+                GetHashAndReset(hmac, prk);
+            }
+        }
+
+        /// <summary>
+        /// Performs the HKDF-Expand function
+        /// See section 2.3 of <a href="https://tools.ietf.org/html/rfc5869#section-2.3">RFC5869</a>
+        /// </summary>
+        /// <param name="hashAlgorithmName">The hash algorithm used for HMAC operations.</param>
+        /// <param name="prk">The pseudorandom key of at least <see cref="HashLength"/> bytes (usually the output from Expand step).</param>
+        /// <param name="outputLength">The length of the output keying material.</param>
+        /// <param name="info">The optional context and application specific information.</param>
+        /// <returns>The output keying material.</returns>
+        public static byte[] Expand(HashAlgorithmName hashAlgorithmName, byte[] prk, int outputLength, byte[] info = null)
+        {
+            if (prk == null)
+                throw new ArgumentNullException(nameof(prk));
+
+            int hashLength = HashLength(hashAlgorithmName);
+
+            // Constant comes from section 2.3 (the constraint on L in the Inputs section)
+            int maxOkmLength = 255 * hashLength;
+            if (outputLength <= 0 || outputLength > maxOkmLength)
+                throw new ArgumentOutOfRangeException(nameof(outputLength), SR.Format(SR.Cryptography_Okm_TooLarge, maxOkmLength));
+
+            byte[] result = new byte[outputLength];
+            Expand(hashAlgorithmName, hashLength, prk, result, info);
+
+            return result;
+        }
+
+        /// <summary>
+        /// Performs the HKDF-Expand function
+        /// See section 2.3 of <a href="https://tools.ietf.org/html/rfc5869#section-2.3">RFC5869</a>
+        /// </summary>
+        /// <param name="hashAlgorithmName">The hash algorithm used for HMAC operations.</param>
+        /// <param name="prk">The pseudorandom key of at least <see cref="HashLength"/> bytes (usually the output from Expand step).</param>
+        /// <param name="output">The destination buffer to receive the output keying material.</param>
+        /// <param name="info">The context and application specific information (can be an empty span).</param>
+        public static void Expand(HashAlgorithmName hashAlgorithmName, ReadOnlySpan<byte> prk, Span<byte> output, ReadOnlySpan<byte> info)
+        {
+            int hashLength = HashLength(hashAlgorithmName);
+
+            // Constant comes from section 2.3 (the constraint on L in the Inputs section)
+            int maxOkmLength = 255 * hashLength;
+            if (output.Length > maxOkmLength)
+                throw new ArgumentException(SR.Format(SR.Cryptography_Okm_TooLarge, maxOkmLength), nameof(output));
+
+            Expand(hashAlgorithmName, hashLength, prk, output, info);
+        }
+
+        private static void Expand(HashAlgorithmName hashAlgorithmName, int hashLength, ReadOnlySpan<byte> prk, Span<byte> output, ReadOnlySpan<byte> info)
+        {
+            Debug.Assert(HashLength(hashAlgorithmName) == hashLength);
+
+            if (prk.Length < hashLength)
+                throw new ArgumentException(SR.Format(SR.Cryptography_Prk_TooSmall, hashLength), nameof(prk));
+
+            Span<byte> counterSpan = stackalloc byte[1];
+            ref byte counter = ref counterSpan[0];
+            Span<byte> t = Span<byte>.Empty;
+            Span<byte> remainingOutput = output;
+
+            using (IncrementalHash hmac = IncrementalHash.CreateHMAC(hashAlgorithmName, prk))
+            {
+                for (int i = 1; ; i++)
+                {
+                    hmac.AppendData(t);
+                    hmac.AppendData(info);
+                    counter = (byte)i;
+                    hmac.AppendData(counterSpan);
+
+                    if (remainingOutput.Length >= hashLength)
+                    {
+                        t = remainingOutput.Slice(0, hashLength);
+                        remainingOutput = remainingOutput.Slice(hashLength);
+                        GetHashAndReset(hmac, t);
+                    }
+                    else
+                    {
+                        if (remainingOutput.Length > 0)
+                        {
+                            Debug.Assert(hashLength <= 512 / 8, "hashLength is larger than expected, consider increasing this value or using regular allocation");
+                            Span<byte> lastChunk = stackalloc byte[hashLength];
+                            GetHashAndReset(hmac, lastChunk);
+                            lastChunk.Slice(0, remainingOutput.Length).CopyTo(remainingOutput);
+                        }
+
+                        break;
+                    }
+                }
+            }
+        }
+
+        /// <summary>
+        /// Performs the key derivation HKDF Expand and Extract functions
+        /// </summary>
+        /// <param name="hashAlgorithmName">The hash algorithm used for HMAC operations.</param>
+        /// <param name="ikm">The input keying material.</param>
+        /// <param name="outputLength">The length of the output keying material.</param>
+        /// <param name="salt">The optional salt value (a non-secret random value). If not provided it defaults to a byte array of <see cref="HashLength"/> zeros.</param>
+        /// <param name="info">The optional context and application specific information.</param>
+        /// <returns>The output keying material.</returns>
+        public static byte[] DeriveKey(HashAlgorithmName hashAlgorithmName, byte[] ikm, int outputLength, byte[] salt = null, byte[] info = null)
+        {
+            if (ikm == null)
+                throw new ArgumentNullException(nameof(ikm));
+
+            int hashLength = HashLength(hashAlgorithmName);
+            Debug.Assert(hashLength <= 512 / 8, "hashLength is larger than expected, consider increasing this value or using regular allocation");
+
+            // Constant comes from section 2.3 (the constraint on L in the Inputs section)
+            int maxOkmLength = 255 * hashLength;
+            if (outputLength > maxOkmLength)
+                throw new ArgumentOutOfRangeException(nameof(outputLength), SR.Format(SR.Cryptography_Okm_TooLarge, maxOkmLength));
+
+            Span<byte> prk = stackalloc byte[hashLength];
+
+            Extract(hashAlgorithmName, hashLength, ikm, salt, prk);
+
+            byte[] result = new byte[outputLength];
+            Expand(hashAlgorithmName, hashLength, prk, result, info);
+
+            return result;
+        }
+
+        /// <summary>
+        /// Performs the key derivation HKDF Expand and Extract functions
+        /// </summary>
+        /// <param name="hashAlgorithmName">The hash algorithm used for HMAC operations.</param>
+        /// <param name="ikm">The input keying material.</param>
+        /// <param name="output">The output buffer representing output keying material.</param>
+        /// <param name="salt">The salt value (a non-secret random value).</param>
+        /// <param name="info">The context and application specific information (can be an empty span).</param>
+        public static void DeriveKey(HashAlgorithmName hashAlgorithmName, ReadOnlySpan<byte> ikm, Span<byte> output, ReadOnlySpan<byte> salt, ReadOnlySpan<byte> info)
+        {
+            int hashLength = HashLength(hashAlgorithmName);
+
+            // Constant comes from section 2.3 (the constraint on L in the Inputs section)
+            int maxOkmLength = 255 * hashLength;
+            if (output.Length > maxOkmLength)
+                throw new ArgumentException(SR.Format(SR.Cryptography_Okm_TooLarge, maxOkmLength), nameof(output));
+
+            Debug.Assert(hashLength <= 512 / 8, "hashLength is larger than expected, consider increasing this value or using regular allocation");
+            Span<byte> prk = stackalloc byte[hashLength];
+
+            Extract(hashAlgorithmName, hashLength, ikm, salt, prk);
+            Expand(hashAlgorithmName, hashLength, prk, output, info);
+        }
+
+        private static void GetHashAndReset(IncrementalHash hmac, Span<byte> output)
+        {
+            if (!hmac.TryGetHashAndReset(output, out int bytesWritten))
+            {
+                Debug.Assert(false, "HMAC operation failed unexpectedly");
+                throw new CryptographicException(SR.Arg_CryptographyException);
+            }
+
+            Debug.Assert(bytesWritten == output.Length, $"Bytes written is {bytesWritten} bytes which does not match output length ({output.Length} bytes)");
+        }
+
+        private static int HashLength(HashAlgorithmName hashAlgorithmName)
+        {
+            if (hashAlgorithmName == HashAlgorithmName.SHA1)
+            {
+                return 160 / 8;
+            }
+            else if (hashAlgorithmName == HashAlgorithmName.SHA256)
+            {
+                return 256 / 8;
+            }
+            else if (hashAlgorithmName == HashAlgorithmName.SHA384)
+            {
+                return 384 / 8;
+            }
+            else if (hashAlgorithmName == HashAlgorithmName.SHA512)
+            {
+                return 512 / 8;
+            }
+            else if (hashAlgorithmName == HashAlgorithmName.MD5)
+            {
+                return 128 / 8;
+            }
+            else
+            {
+                throw new ArgumentOutOfRangeException(nameof(hashAlgorithmName));
+            }
+        }
+    }
+}
index a3c1553..709c560 100644 (file)
@@ -209,6 +209,12 @@ namespace System.Security.Cryptography
         {
             if (key == null)
                 throw new ArgumentNullException(nameof(key));
+
+            return CreateHMAC(hashAlgorithm, (ReadOnlySpan<byte>)key);
+        }
+
+        internal static IncrementalHash CreateHMAC(HashAlgorithmName hashAlgorithm, ReadOnlySpan<byte> key)
+        {
             if (string.IsNullOrEmpty(hashAlgorithm.Name))
                 throw new ArgumentException(SR.Cryptography_HashAlgorithmNameNullOrEmpty, nameof(hashAlgorithm));
 
diff --git a/src/libraries/System.Security.Cryptography.Algorithms/tests/HKDFTests.cs b/src/libraries/System.Security.Cryptography.Algorithms/tests/HKDFTests.cs
new file mode 100644 (file)
index 0000000..afd7b42
--- /dev/null
@@ -0,0 +1,554 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System.Collections.Generic;
+using System.Linq;
+using Microsoft.DotNet.XUnitExtensions;
+using Test.Cryptography;
+using Xunit;
+
+namespace System.Security.Cryptography.Algorithms.Tests
+{
+    public abstract class HKDFTests
+    {
+        protected abstract byte[] Extract(HashAlgorithmName hash, int prkLength, byte[] ikm, byte[] salt);
+        protected abstract byte[] Expand(HashAlgorithmName hash, byte[] prk, int outputLength, byte[] info);
+        protected abstract byte[] DeriveKey(HashAlgorithmName hash, byte[] ikm, int outputLength, byte[] salt, byte[] info);
+
+        [Theory]
+        [MemberData(nameof(GetRfc5869TestCases))]
+        public void Rfc5869ExtractTests(Rfc5869TestCase test)
+        {
+            byte[] prk = Extract(test.Hash, test.Prk.Length, test.Ikm, test.Salt);
+            Assert.Equal(test.Prk, prk);
+        }
+
+        [Theory]
+        [MemberData(nameof(GetRfc5869TestCases))]
+        public void Rfc5869ExtractTamperHashTests(Rfc5869TestCase test)
+        {
+            byte[] prk = Extract(HashAlgorithmName.MD5, 128 / 8, test.Ikm, test.Salt);
+            Assert.NotEqual(test.Prk, prk);
+        }
+
+        [Theory]
+        [MemberData(nameof(GetRfc5869TestCases))]
+        public void Rfc5869ExtractTamperIkmTests(Rfc5869TestCase test)
+        {
+            byte[] ikm = test.Ikm.ToArray();
+            ikm[0] ^= 1;
+            byte[] prk = Extract(test.Hash, test.Prk.Length, ikm, test.Salt);
+            Assert.NotEqual(test.Prk, prk);
+        }
+
+        [Theory]
+        [MemberData(nameof(GetRfc5869TestCasesWithNonEmptySalt))]
+        public void Rfc5869ExtractTamperSaltTests(Rfc5869TestCase test)
+        {
+            byte[] salt = test.Salt.ToArray();
+            salt[0] ^= 1;
+            byte[] prk = Extract(test.Hash, test.Prk.Length, test.Ikm, salt);
+            Assert.NotEqual(test.Prk, prk);
+        }
+
+        [Fact]
+        public void Rfc5869ExtractDefaultHash()
+        {
+            byte[] ikm = new byte[20];
+            byte[] salt = new byte[20];
+            AssertExtensions.Throws<ArgumentOutOfRangeException>(
+                "hashAlgorithmName",
+                () => Extract(default(HashAlgorithmName), 20, ikm, salt));
+        }
+
+        [Fact]
+        public void Rfc5869ExtractNonsensicalHash()
+        {
+            byte[] ikm = new byte[20];
+            byte[] salt = new byte[20];
+            AssertExtensions.Throws<ArgumentOutOfRangeException>(
+                "hashAlgorithmName",
+                () => Extract(new HashAlgorithmName("foo"), 20, ikm, salt));
+        }
+
+        [Fact]
+        public void Rfc5869ExtractEmptyIkm()
+        {
+            byte[] salt = new byte[20];
+            byte[] ikm = Array.Empty<byte>();
+
+            // Ensure does not throw
+            byte[] prk = Extract(HashAlgorithmName.SHA1, 20, ikm, salt);
+            Assert.Equal("FBDB1D1B18AA6C08324B7D64B71FB76370690E1D", prk.ByteArrayToHex());
+        }
+
+        [Fact]
+        public void Rfc5869ExtractEmptySalt()
+        {
+            byte[] ikm = new byte[20];
+            byte[] salt = Array.Empty<byte>();
+            byte[] prk = Extract(HashAlgorithmName.SHA1, 20, ikm, salt);
+            Assert.Equal("A3CBF4A40F51A53E046F07397E52DF9286AE93A2", prk.ByteArrayToHex());
+        }
+
+        [Theory]
+        [MemberData(nameof(GetRfc5869TestCases))]
+        public void Rfc5869ExpandTests(Rfc5869TestCase test)
+        {
+            byte[] okm = Expand(test.Hash, test.Prk, test.Okm.Length, test.Info);
+            Assert.Equal(test.Okm, okm);
+        }
+
+        [Fact]
+        public void Rfc5869ExpandDefaultHash()
+        {
+            byte[] prk = new byte[20];
+            AssertExtensions.Throws<ArgumentOutOfRangeException>(
+                "hashAlgorithmName",
+                () => Expand(default(HashAlgorithmName), prk, 20, null));
+        }
+
+        [Fact]
+        public void Rfc5869ExpandNonsensicalHash()
+        {
+            byte[] prk = new byte[20];
+            AssertExtensions.Throws<ArgumentOutOfRangeException>(
+                "hashAlgorithmName",
+                () => Expand(new HashAlgorithmName("foo"), prk, 20, null));
+        }
+
+        [Theory]
+        [MemberData(nameof(GetRfc5869TestCases))]
+        public void Rfc5869ExpandTamperPrkTests(Rfc5869TestCase test)
+        {
+            byte[] prk = test.Prk.ToArray();
+            prk[0] ^= 1;
+            byte[] okm = Expand(test.Hash, prk, test.Okm.Length, test.Info);
+            Assert.NotEqual(test.Okm, okm);
+        }
+
+        [Theory]
+        [MemberData(nameof(GetPrkTooShortTestCases))]
+        public void Rfc5869ExpandPrkTooShort(HashAlgorithmName hash, int prkSize)
+        {
+            byte[] prk = new byte[prkSize];
+            AssertExtensions.Throws<ArgumentException>(
+                "prk",
+                () => Expand(hash, prk, 17, Array.Empty<byte>()));
+        }
+
+        [Fact]
+        public void Rfc5869ExpandOkmMaxSize()
+        {
+            byte[] prk = new byte[20];
+
+            // Does not throw
+            byte[] okm = Expand(HashAlgorithmName.SHA1, prk, 20 * 255, Array.Empty<byte>());
+            Assert.Equal(20 * 255, okm.Length);
+        }
+
+        [Theory]
+        [MemberData(nameof(GetRfc5869TestCases))]
+        public void Rfc5869DeriveKeyTests(Rfc5869TestCase test)
+        {
+            byte[] okm = DeriveKey(test.Hash, test.Ikm, test.Okm.Length, test.Salt, test.Info);
+            Assert.Equal(test.Okm, okm);
+        }
+
+        [Fact]
+        public void Rfc5869DeriveKeyDefaultHash()
+        {
+            byte[] ikm = new byte[20];
+            AssertExtensions.Throws<ArgumentOutOfRangeException>(
+                "hashAlgorithmName",
+                () => DeriveKey(default(HashAlgorithmName), ikm, 20, Array.Empty<byte>(), Array.Empty<byte>()));
+        }
+
+        [Fact]
+        public void Rfc5869DeriveKeyNonSensicalHash()
+        {
+            byte[] ikm = new byte[20];
+            AssertExtensions.Throws<ArgumentOutOfRangeException>(
+                "hashAlgorithmName",
+                () => DeriveKey(new HashAlgorithmName("foo"), ikm, 20, Array.Empty<byte>(), Array.Empty<byte>()));
+        }
+
+        [Theory]
+        [MemberData(nameof(GetRfc5869TestCases))]
+        public void Rfc5869DeriveKeyTamperIkmTests(Rfc5869TestCase test)
+        {
+            byte[] ikm = test.Ikm.ToArray();
+            ikm[0] ^= 1;
+            byte[] okm = DeriveKey(test.Hash, ikm, test.Okm.Length, test.Salt, test.Info);
+            Assert.NotEqual(test.Okm, okm);
+        }
+
+        [Theory]
+        [MemberData(nameof(GetRfc5869TestCasesWithNonEmptySalt))]
+        public void Rfc5869DeriveKeyTamperSaltTests(Rfc5869TestCase test)
+        {
+            byte[] salt = test.Salt.ToArray();
+            salt[0] ^= 1;
+            byte[] okm = DeriveKey(test.Hash, test.Ikm, test.Okm.Length, salt, test.Info);
+            Assert.NotEqual(test.Okm, okm);
+        }
+
+        [Theory]
+        [MemberData(nameof(GetRfc5869TestCasesWithNonEmptyInfo))]
+        public void Rfc5869DeriveKeyTamperInfoTests(Rfc5869TestCase test)
+        {
+            byte[] info = test.Info.ToArray();
+            info[0] ^= 1;
+            byte[] okm = DeriveKey(test.Hash, test.Ikm, test.Okm.Length, test.Salt, info);
+            Assert.NotEqual(test.Okm, okm);
+        }
+
+        public static IEnumerable<object[]> GetRfc5869TestCases()
+        {
+            foreach (Rfc5869TestCase test in Rfc5869TestCases)
+            {
+                yield return new object[] { test };
+            }
+        }
+
+        public static IEnumerable<object[]> GetRfc5869TestCasesWithNonEmptySalt()
+        {
+            foreach (Rfc5869TestCase test in Rfc5869TestCases)
+            {
+                if (test.Salt != null && test.Salt.Length != 0)
+                {
+                    yield return new object[] { test };
+                }
+            }
+        }
+
+        public static IEnumerable<object[]> GetRfc5869TestCasesWithNonEmptyInfo()
+        {
+            foreach (Rfc5869TestCase test in Rfc5869TestCases)
+            {
+                if (test.Info != null && test.Info.Length != 0)
+                {
+                    yield return new object[] { test };
+                }
+            }
+        }
+
+        public static IEnumerable<object[]> GetPrkTooShortTestCases()
+        {
+            yield return new object[] { HashAlgorithmName.SHA1, 0 };
+            yield return new object[] { HashAlgorithmName.SHA1, 1 };
+            yield return new object[] { HashAlgorithmName.SHA1, 160 / 8 - 1 };
+            yield return new object[] { HashAlgorithmName.SHA256, 256 / 8 - 1 };
+            yield return new object[] { HashAlgorithmName.SHA512, 512 / 8 - 1 };
+            yield return new object[] { HashAlgorithmName.MD5, 128 / 8 - 1 };
+        }
+
+        private static Rfc5869TestCase[] Rfc5869TestCases { get; } = new Rfc5869TestCase[7]
+        {
+            new Rfc5869TestCase()
+            {
+                Name = "Basic test case with SHA-256",
+                Hash = HashAlgorithmName.SHA256,
+                Ikm = "0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b".HexToByteArray(),
+                Salt = "000102030405060708090a0b0c".HexToByteArray(),
+                Info = "f0f1f2f3f4f5f6f7f8f9".HexToByteArray(),
+                Prk = (
+                    "077709362c2e32df0ddc3f0dc47bba63" +
+                    "90b6c73bb50f9c3122ec844ad7c2b3e5").HexToByteArray(),
+                Okm = (
+                    "3cb25f25faacd57a90434f64d0362f2a" +
+                    "2d2d0a90cf1a5a4c5db02d56ecc4c5bf" +
+                    "34007208d5b887185865").HexToByteArray(),
+            },
+            new Rfc5869TestCase()
+            {
+                Name = "Test with SHA-256 and longer inputs/outputs",
+                Hash = HashAlgorithmName.SHA256,
+                Ikm = (
+                    "000102030405060708090a0b0c0d0e0f" +
+                    "101112131415161718191a1b1c1d1e1f" +
+                    "202122232425262728292a2b2c2d2e2f" +
+                    "303132333435363738393a3b3c3d3e3f" +
+                    "404142434445464748494a4b4c4d4e4f").HexToByteArray(),
+                Salt = (
+                    "606162636465666768696a6b6c6d6e6f" +
+                    "707172737475767778797a7b7c7d7e7f" +
+                    "808182838485868788898a8b8c8d8e8f" +
+                    "909192939495969798999a9b9c9d9e9f" +
+                    "a0a1a2a3a4a5a6a7a8a9aaabacadaeaf").HexToByteArray(),
+                Info = (
+                    "b0b1b2b3b4b5b6b7b8b9babbbcbdbebf" +
+                    "c0c1c2c3c4c5c6c7c8c9cacbcccdcecf" +
+                    "d0d1d2d3d4d5d6d7d8d9dadbdcdddedf" +
+                    "e0e1e2e3e4e5e6e7e8e9eaebecedeeef" +
+                    "f0f1f2f3f4f5f6f7f8f9fafbfcfdfeff").HexToByteArray(),
+                Prk = (
+                    "06a6b88c5853361a06104c9ceb35b45c" +
+                    "ef760014904671014a193f40c15fc244").HexToByteArray(),
+                Okm = (
+                    "b11e398dc80327a1c8e7f78c596a4934" +
+                    "4f012eda2d4efad8a050cc4c19afa97c" +
+                    "59045a99cac7827271cb41c65e590e09" +
+                    "da3275600c2f09b8367793a9aca3db71" +
+                    "cc30c58179ec3e87c14c01d5c1f3434f" +
+                    "1d87").HexToByteArray(),
+            },
+            new Rfc5869TestCase()
+            {
+                Name = "Test with SHA-256 and zero-length salt/info",
+                Hash = HashAlgorithmName.SHA256,
+                Ikm = "0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b".HexToByteArray(),
+                Salt = Array.Empty<byte>(),
+                Info = Array.Empty<byte>(),
+                Prk = (
+                    "19ef24a32c717b167f33a91d6f648bdf" +
+                    "96596776afdb6377ac434c1c293ccb04").HexToByteArray(),
+                Okm = (
+                    "8da4e775a563c18f715f802a063c5a31" +
+                    "b8a11f5c5ee1879ec3454e5f3c738d2d" +
+                    "9d201395faa4b61a96c8").HexToByteArray(),
+            },
+            new Rfc5869TestCase()
+            {
+                Name = "Basic test case with SHA-1",
+                Hash = HashAlgorithmName.SHA1,
+                Ikm = "0b0b0b0b0b0b0b0b0b0b0b".HexToByteArray(),
+                Salt = "000102030405060708090a0b0c".HexToByteArray(),
+                Info = "f0f1f2f3f4f5f6f7f8f9".HexToByteArray(),
+                Prk = "9b6c18c432a7bf8f0e71c8eb88f4b30baa2ba243".HexToByteArray(),
+                Okm = (
+                    "085a01ea1b10f36933068b56efa5ad81" +
+                    "a4f14b822f5b091568a9cdd4f155fda2" +
+                    "c22e422478d305f3f896").HexToByteArray(),
+            },
+            new Rfc5869TestCase()
+            {
+                Name = "Test with SHA-1 and longer inputs/outputs",
+                Hash = HashAlgorithmName.SHA1,
+                Ikm = (
+                    "000102030405060708090a0b0c0d0e0f" +
+                    "101112131415161718191a1b1c1d1e1f" +
+                    "202122232425262728292a2b2c2d2e2f" +
+                    "303132333435363738393a3b3c3d3e3f" +
+                    "404142434445464748494a4b4c4d4e4f").HexToByteArray(),
+                Salt = (
+                    "606162636465666768696a6b6c6d6e6f" +
+                    "707172737475767778797a7b7c7d7e7f" +
+                    "808182838485868788898a8b8c8d8e8f" +
+                    "909192939495969798999a9b9c9d9e9f" +
+                    "a0a1a2a3a4a5a6a7a8a9aaabacadaeaf").HexToByteArray(),
+                Info = (
+                    "b0b1b2b3b4b5b6b7b8b9babbbcbdbebf" +
+                    "c0c1c2c3c4c5c6c7c8c9cacbcccdcecf" +
+                    "d0d1d2d3d4d5d6d7d8d9dadbdcdddedf" +
+                    "e0e1e2e3e4e5e6e7e8e9eaebecedeeef" +
+                    "f0f1f2f3f4f5f6f7f8f9fafbfcfdfeff").HexToByteArray(),
+                Prk = "8adae09a2a307059478d309b26c4115a224cfaf6".HexToByteArray(),
+                Okm = (
+                    "0bd770a74d1160f7c9f12cd5912a06eb" +
+                    "ff6adcae899d92191fe4305673ba2ffe" +
+                    "8fa3f1a4e5ad79f3f334b3b202b2173c" +
+                    "486ea37ce3d397ed034c7f9dfeb15c5e" +
+                    "927336d0441f4c4300e2cff0d0900b52" +
+                    "d3b4").HexToByteArray(),
+            },
+            new Rfc5869TestCase()
+            {
+                Name = "Test with SHA-1 and zero-length salt/info",
+                Hash = HashAlgorithmName.SHA1,
+                Ikm = "0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b".HexToByteArray(),
+                Salt = Array.Empty<byte>(),
+                Info = Array.Empty<byte>(),
+                Prk = "da8c8a73c7fa77288ec6f5e7c297786aa0d32d01".HexToByteArray(),
+                Okm = (
+                    "0ac1af7002b3d761d1e55298da9d0506" +
+                    "b9ae52057220a306e07b6b87e8df21d0" +
+                    "ea00033de03984d34918").HexToByteArray(),
+            },
+            new Rfc5869TestCase()
+            {
+                Name = "Test with SHA-1, salt not provided (defaults to HashLen zero octets), zero-length info",
+                Hash = HashAlgorithmName.SHA1,
+                Ikm = "0c0c0c0c0c0c0c0c0c0c0c0c0c0c0c0c0c0c0c0c0c0c".HexToByteArray(),
+                Salt = null,
+                Info = Array.Empty<byte>(),
+                Prk = "2adccada18779e7c2077ad2eb19d3f3e731385dd".HexToByteArray(),
+                Okm = (
+                    "2c91117204d745f3500d636a62f64f0a" +
+                    "b3bae548aa53d423b0d1f27ebba6f5e5" +
+                    "673a081d70cce7acfc48").HexToByteArray(),
+            },
+        };
+
+        public struct Rfc5869TestCase
+        {
+            public string Name { get; set; }
+            public HashAlgorithmName Hash { get; set; }
+            public byte[] Ikm { get; set; }
+            public byte[] Salt { get; set; }
+            public byte[] Info { get; set; }
+            public byte[] Prk { get; set; }
+            public byte[] Okm { get; set; }
+
+            public override string ToString() => Name;
+        }
+
+        public class HkdfByteArrayTests : HKDFTests
+        {
+            protected override byte[] Extract(HashAlgorithmName hash, int prkLength, byte[] ikm, byte[] salt)
+            {
+                return HKDF.Extract(hash, ikm, salt);
+            }
+
+            protected override byte[] Expand(HashAlgorithmName hash, byte[] prk, int outputLength, byte[] info)
+            {
+                return HKDF.Expand(hash, prk, outputLength, info);
+            }
+
+            protected override byte[] DeriveKey(HashAlgorithmName hash, byte[] ikm, int outputLength, byte[] salt, byte[] info)
+            {
+                return HKDF.DeriveKey(hash, ikm, outputLength, salt, info);
+            }
+
+            [Fact]
+            public void Rfc5869ExtractNullIkm()
+            {
+                byte[] salt = new byte[20];
+                AssertExtensions.Throws<ArgumentNullException>(
+                    "ikm",
+                    () => HKDF.Extract(HashAlgorithmName.SHA1, null, salt));
+            }
+
+            [Fact]
+            public void Rfc5869ExpandOkmMaxSizePlusOne()
+            {
+                byte[] prk = new byte[20];
+                AssertExtensions.Throws<ArgumentOutOfRangeException>(
+                    "outputLength",
+                    () => HKDF.Expand(HashAlgorithmName.SHA1, prk, 20 * 255 + 1, Array.Empty<byte>()));
+            }
+
+            [Fact]
+            public void Rfc5869ExpandOkmPotentiallyOverflowingValue()
+            {
+                byte[] prk = new byte[20];
+                AssertExtensions.Throws<ArgumentOutOfRangeException>(
+                    "outputLength",
+                    () => HKDF.Expand(HashAlgorithmName.SHA1, prk, 8421505, Array.Empty<byte>()));
+            }
+
+            [Fact]
+            public void Rfc5869DeriveKeyNullIkm()
+            {
+                AssertExtensions.Throws<ArgumentNullException>(
+                    "ikm",
+                    () => HKDF.DeriveKey(HashAlgorithmName.SHA1, null, 20, Array.Empty<byte>(), Array.Empty<byte>()));
+            }
+
+            [Fact]
+            public void Rfc5869DeriveKeyOkmMaxSizePlusOne()
+            {
+                byte[] ikm = new byte[20];
+                AssertExtensions.Throws<ArgumentOutOfRangeException>(
+                    "outputLength",
+                    () => HKDF.DeriveKey(HashAlgorithmName.SHA1, ikm, 20 * 255 + 1, Array.Empty<byte>(), Array.Empty<byte>()));
+            }
+
+            [Fact]
+            public void Rfc5869DeriveKeyOkmPotentiallyOverflowingValue()
+            {
+                byte[] ikm = new byte[20];
+                AssertExtensions.Throws<ArgumentOutOfRangeException>(
+                    "outputLength",
+                    () => HKDF.DeriveKey(HashAlgorithmName.SHA1, ikm, 8421505, Array.Empty<byte>(), Array.Empty<byte>()));
+            }
+        }
+
+        public class HkdfSpanTests : HKDFTests
+        {
+            protected override byte[] Extract(HashAlgorithmName hash, int prkLength, byte[] ikm, byte[] salt)
+            {
+                byte[] prk = new byte[prkLength];
+                Assert.Equal(prkLength, HKDF.Extract(hash, ikm, salt, prk));
+                return prk;
+            }
+
+            protected override byte[] Expand(HashAlgorithmName hash, byte[] prk, int outputLength, byte[] info)
+            {
+                byte[] output = new byte[outputLength];
+                HKDF.Expand(hash, prk, output, info);
+                return output;
+            }
+
+            protected override byte[] DeriveKey(HashAlgorithmName hash, byte[] ikm, int outputLength, byte[] salt, byte[] info)
+            {
+                byte[] output = new byte[outputLength];
+                HKDF.DeriveKey(hash, ikm, output, salt, info);
+                return output;
+            }
+
+            [Fact]
+            public void Rfc5869ExtractPrkTooLong()
+            {
+                byte[] prk = new byte[24];
+
+                for (int i = 0; i < 4; i++)
+                {
+                    prk[20 + i] = (byte)(i + 5);
+                }
+
+                byte[] ikm = new byte[20];
+                byte[] salt = new byte[20];
+                Assert.Equal(20, HKDF.Extract(HashAlgorithmName.SHA1, ikm, salt, prk));
+                Assert.Equal("A3CBF4A40F51A53E046F07397E52DF9286AE93A2", prk.AsSpan(0, 20).ByteArrayToHex());
+
+                for (int i = 0; i < 4; i++)
+                {
+                    // ensure we didn't modify anything further
+                    Assert.Equal((byte)(i + 5), prk[20 + i]);
+                }
+            }
+
+            [Fact]
+            public void Rfc5869OkmMaxSizePlusOne()
+            {
+                byte[] prk = new byte[20];
+                byte[] okm = new byte[20 * 255 + 1];
+                AssertExtensions.Throws<ArgumentException>(
+                    "output",
+                    () => HKDF.Expand(HashAlgorithmName.SHA1, prk, okm, Array.Empty<byte>()));
+            }
+
+            [Fact]
+            public void Rfc5869OkmMaxSizePotentiallyOverflowingValue()
+            {
+                byte[] prk = new byte[20];
+                byte[] okm = new byte[8421505];
+                AssertExtensions.Throws<ArgumentException>(
+                    "output",
+                    () => HKDF.Expand(HashAlgorithmName.SHA1, prk, okm, Array.Empty<byte>()));
+            }
+
+            [Fact]
+            public void Rfc5869DeriveKeySpanOkmMaxSizePlusOne()
+            {
+                byte[] ikm = new byte[20];
+                byte[] okm = new byte[20 * 255 + 1];
+                AssertExtensions.Throws<ArgumentException>(
+                    "output",
+                    () => HKDF.DeriveKey(HashAlgorithmName.SHA1, ikm, okm, Array.Empty<byte>(), Array.Empty<byte>()));
+            }
+
+            [Fact]
+            public void Rfc5869DeriveKeySpanOkmPotentiallyOverflowingValue()
+            {
+                byte[] ikm = new byte[20];
+                byte[] okm = new byte[8421505];
+                AssertExtensions.Throws<ArgumentException>(
+                    "output",
+                    () => HKDF.DeriveKey(HashAlgorithmName.SHA1, ikm, okm, Array.Empty<byte>(), Array.Empty<byte>()));
+            }
+        }
+    }
+}