Add Base64.IsValid and allow Base64.DecodeXx methods to skip whitespace (#85938)
author Stephen Toub <stoub@microsoft.com>
Tue, 9 May 2023 12:32:08 +0000 (08:32 -0400)
committer GitHub <noreply@github.com>
Tue, 9 May 2023 12:32:08 +0000 (08:32 -0400)
* Allow Base64Decoder to ignore space chars, add IsValid methods and tests

* Some cleanup of Base64.IsValid changes

This includes making FromBase64Transform significantly faster via SearchValues.

* Address PR feedback and some more cleanup

---------

Co-authored-by: Heath Baron-Morgan <heathbm@outlook.com>
src/libraries/System.Memory/tests/Base64/Base64DecoderUnitTests.cs
src/libraries/System.Memory/tests/Base64/Base64TestBase.cs [new file with mode: 0644]
src/libraries/System.Memory/tests/Base64/Base64TestHelper.cs
src/libraries/System.Memory/tests/Base64/Base64ValidationUnitTests.cs [new file with mode: 0644]
src/libraries/System.Memory/tests/System.Memory.Tests.csproj
src/libraries/System.Private.CoreLib/src/System.Private.CoreLib.Shared.projitems
src/libraries/System.Private.CoreLib/src/System/Buffers/Text/Base64Decoder.cs
src/libraries/System.Private.CoreLib/src/System/Buffers/Text/Base64Encoder.cs
src/libraries/System.Private.CoreLib/src/System/Buffers/Text/Base64Validator.cs [new file with mode: 0644]
src/libraries/System.Runtime/ref/System.Runtime.cs
src/libraries/System.Security.Cryptography/src/System/Security/Cryptography/Base64Transforms.cs

index 0d792b2..bbd28fb 100644 (file)
@@ -1,12 +1,15 @@
 // Licensed to the .NET Foundation under one or more agreements.
 // The .NET Foundation licenses this file to you under the MIT license.
 
+using System.Collections;
+using System.Collections.Generic;
+using System.Linq;
 using System.Text;
 using Xunit;
 
 namespace System.Buffers.Text.Tests
 {
-    public class Base64DecoderUnitTests
+    public class Base64DecoderUnitTests : Base64TestBase
     {
         [Fact]
         public void BasicDecoding()
@@ -157,7 +160,7 @@ namespace System.Buffers.Text.Tests
 
                 Span<byte> decodedBytes = new byte[3];
                 int consumed, written;
-                if (numBytes % 4 == 0)
+                if (numBytes >= 8)
                 {
                     Assert.True(OperationStatus.DestinationTooSmall ==
                         Base64.DecodeFromUtf8(source, decodedBytes, out consumed, out written), "Number of Input Bytes: " + numBytes);
@@ -373,8 +376,12 @@ namespace System.Buffers.Text.Tests
                 for (int i = 0; i < invalidBytes.Length; i++)
                 {
                     // Don't test padding (byte 61 i.e. '='), which is tested in DecodingInvalidBytesPadding
-                    if (invalidBytes[i] == Base64TestHelper.EncodingPad)
+                    // Don't test chars to be ignored (whitespace: 9, 10, 13, 32 i.e. '\t', '\n', '\r', ' ')
+                    if (invalidBytes[i] == Base64TestHelper.EncodingPad ||
+                        Base64TestHelper.IsByteToBeIgnored(invalidBytes[i]))
+                    {
                         continue;
+                    }
 
                     // replace one byte with an invalid input
                     source[j] = invalidBytes[i];
@@ -568,8 +575,12 @@ namespace System.Buffers.Text.Tests
                     Span<byte> buffer = "2222PPPP"u8.ToArray(); // valid input
 
                     // Don't test padding (byte 61 i.e. '='), which is tested in DecodeInPlaceInvalidBytesPadding
-                    if (invalidBytes[i] == Base64TestHelper.EncodingPad)
+                    // Don't test chars to be ignored (whitespace: 9, 10, 13, 32 i.e. '\t', '\n', '\r', ' ')
+                    if (invalidBytes[i] == Base64TestHelper.EncodingPad ||
+                        Base64TestHelper.IsByteToBeIgnored(invalidBytes[i]))
+                    {
                         continue;
+                    }
 
                     // replace one byte with an invalid input
                     buffer[j] = invalidBytes[i];
@@ -594,7 +605,7 @@ namespace System.Buffers.Text.Tests
             {
                 Span<byte> buffer = "2222PPP"u8.ToArray(); // incomplete input
                 Assert.Equal(OperationStatus.InvalidData, Base64.DecodeFromUtf8InPlace(buffer, out int bytesWritten));
-                Assert.Equal(0, bytesWritten);
+                Assert.Equal(3, bytesWritten);
             }
         }
 
@@ -667,5 +678,90 @@ namespace System.Buffers.Text.Tests
             }
         }
 
+        [Theory]
+        [MemberData(nameof(ValidBase64Strings_WithCharsThatMustBeIgnored))]
+        public void BasicDecodingIgnoresCharsToBeIgnoredAsConvertToBase64Does(string utf8WithCharsToBeIgnored, byte[] expectedBytes)
+        {
+            byte[] utf8BytesWithByteToBeIgnored = UTF8Encoding.UTF8.GetBytes(utf8WithCharsToBeIgnored);
+            byte[] resultBytes = new byte[5];
+            OperationStatus result = Base64.DecodeFromUtf8(utf8BytesWithByteToBeIgnored, resultBytes, out int bytesConsumed, out int bytesWritten);
+
+            // Control value from Convert.FromBase64String
+            byte[] stringBytes = Convert.FromBase64String(utf8WithCharsToBeIgnored);
+
+            Assert.Equal(OperationStatus.Done, result);
+            Assert.Equal(utf8WithCharsToBeIgnored.Length, bytesConsumed);
+            Assert.Equal(expectedBytes.Length, bytesWritten);
+            Assert.True(expectedBytes.SequenceEqual(resultBytes));
+            Assert.True(stringBytes.SequenceEqual(resultBytes));
+        }
+
+        [Theory]
+        [MemberData(nameof(ValidBase64Strings_WithCharsThatMustBeIgnored))]
+        public void DecodeInPlaceIgnoresCharsToBeIgnoredAsConvertToBase64Does(string utf8WithCharsToBeIgnored, byte[] expectedBytes)
+        {
+            Span<byte> utf8BytesWithByteToBeIgnored = UTF8Encoding.UTF8.GetBytes(utf8WithCharsToBeIgnored);
+            OperationStatus result = Base64.DecodeFromUtf8InPlace(utf8BytesWithByteToBeIgnored, out int bytesWritten);
+            Span<byte> bytesOverwritten = utf8BytesWithByteToBeIgnored.Slice(0, bytesWritten);
+            byte[] resultBytesArray = bytesOverwritten.ToArray();
+
+            // Control value from Convert.FromBase64String
+            byte[] stringBytes = Convert.FromBase64String(utf8WithCharsToBeIgnored);
+
+            Assert.Equal(OperationStatus.Done, result);
+            Assert.Equal(expectedBytes.Length, bytesWritten);
+            Assert.True(expectedBytes.SequenceEqual(resultBytesArray));
+            Assert.True(stringBytes.SequenceEqual(resultBytesArray));
+        }
+
+        [Theory]
+        [MemberData(nameof(StringsOnlyWithCharsToBeIgnored))]
+        public void BasicDecodingWithOnlyCharsToBeIgnored(string utf8WithCharsToBeIgnored)
+        {
+            byte[] utf8BytesWithByteToBeIgnored = UTF8Encoding.UTF8.GetBytes(utf8WithCharsToBeIgnored);
+            byte[] resultBytes = new byte[5];
+            OperationStatus result = Base64.DecodeFromUtf8(utf8BytesWithByteToBeIgnored, resultBytes, out int bytesConsumed, out int bytesWritten);
+
+            Assert.Equal(OperationStatus.Done, result);
+            Assert.Equal(0, bytesWritten);
+        }
+
+        [Theory]
+        [MemberData(nameof(StringsOnlyWithCharsToBeIgnored))]
+        public void DecodingInPlaceWithOnlyCharsToBeIgnored(string utf8WithCharsToBeIgnored)
+        {
+            Span<byte> utf8BytesWithByteToBeIgnored = UTF8Encoding.UTF8.GetBytes(utf8WithCharsToBeIgnored);
+            OperationStatus result = Base64.DecodeFromUtf8InPlace(utf8BytesWithByteToBeIgnored, out int bytesWritten);
+
+            Assert.Equal(OperationStatus.Done, result);
+            Assert.Equal(0, bytesWritten);
+        }
+
+        [Theory]
+        [MemberData(nameof(BasicDecodingWithExtraWhitespaceShouldBeCountedInConsumedBytes_MemberData))]
+        public void BasicDecodingWithExtraWhitespaceShouldBeCountedInConsumedBytes(string inputString, int expectedConsumed, int expectedWritten)
+        {
+            Span<byte> source = Encoding.ASCII.GetBytes(inputString);
+            Span<byte> decodedBytes = new byte[Base64.GetMaxDecodedFromUtf8Length(source.Length)];
+
+            Assert.Equal(OperationStatus.Done, Base64.DecodeFromUtf8(source, decodedBytes, out int consumed, out int decodedByteCount));
+            Assert.Equal(expectedConsumed, consumed);
+            Assert.Equal(expectedWritten, decodedByteCount);
+            Assert.True(Base64TestHelper.VerifyDecodingCorrectness(expectedConsumed, expectedWritten, source, decodedBytes));
+        }
+
+        public static IEnumerable<object[]> BasicDecodingWithExtraWhitespaceShouldBeCountedInConsumedBytes_MemberData()
+        {
+            var r = new Random(42);
+            for (int i = 0; i < 5; i++)
+            {
+                yield return new object[] { "AQ==" + new string(r.GetItems<char>(" \n\t\r", i)), 4 + i, 1 };
+            }
+
+            foreach (string s in new[] { "MTIz", "M TIz", "MT Iz", "MTI z", "MTIz ", "M    TI   z", "M T I Z " })
+            {
+                yield return new object[] { s + s + s + s, s.Length * 4, 12 };
+            }
+        }
     }
 }
diff --git a/src/libraries/System.Memory/tests/Base64/Base64TestBase.cs b/src/libraries/System.Memory/tests/Base64/Base64TestBase.cs
new file mode 100644 (file)
index 0000000..882db30
--- /dev/null
@@ -0,0 +1,111 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using System.Collections.Generic;
+using System.Text;
+
+namespace System.Buffers.Text.Tests
+{
+    public class Base64TestBase
+    {
+        public static IEnumerable<object[]> ValidBase64Strings_WithCharsThatMustBeIgnored()
+        {
+            // Create a Base64 string
+            string text = "a b c";
+            byte[] utf8Bytes = Encoding.UTF8.GetBytes(text);
+            string base64Utf8String = Convert.ToBase64String(utf8Bytes);
+
+            // Split the base64 string in half
+            int stringLength = base64Utf8String.Length / 2;
+            string firstSegment = base64Utf8String.Substring(0, stringLength);
+            string secondSegment = base64Utf8String.Substring(stringLength, stringLength);
+
+            // Insert ignored chars between the base 64 string
+            // One will have 1 char, another will have 3
+
+            // Horizontal tab (char 9)
+            yield return new object[] { GetBase64StringWithPassedCharInsertedInTheMiddle(Convert.ToChar(9), 1), utf8Bytes };
+            yield return new object[] { GetBase64StringWithPassedCharInsertedInTheMiddle(Convert.ToChar(9), 3), utf8Bytes };
+
+            // Line feed (char 10)
+            yield return new object[] { GetBase64StringWithPassedCharInsertedInTheMiddle(Convert.ToChar(10), 1), utf8Bytes };
+            yield return new object[] { GetBase64StringWithPassedCharInsertedInTheMiddle(Convert.ToChar(10), 3), utf8Bytes };
+
+            // Carriage return (char 13)
+            yield return new object[] { GetBase64StringWithPassedCharInsertedInTheMiddle(Convert.ToChar(13), 1), utf8Bytes };
+            yield return new object[] { GetBase64StringWithPassedCharInsertedInTheMiddle(Convert.ToChar(13), 3), utf8Bytes };
+
+            // Space (char 32)
+            yield return new object[] { GetBase64StringWithPassedCharInsertedInTheMiddle(Convert.ToChar(32), 1), utf8Bytes };
+            yield return new object[] { GetBase64StringWithPassedCharInsertedInTheMiddle(Convert.ToChar(32), 3), utf8Bytes };
+
+            string GetBase64StringWithPassedCharInsertedInTheMiddle(char charToInsert, int numberOfTimesToInsert) => $"{firstSegment}{new string(charToInsert, numberOfTimesToInsert)}{secondSegment}";
+
+            // Insert ignored chars at the start of the base 64 string
+            // One will have 1 char, another will have 3
+
+            // Horizontal tab (char 9)
+            yield return new object[] { GetBase64StringWithPassedCharInsertedAtTheStart(Convert.ToChar(9), 1), utf8Bytes };
+            yield return new object[] { GetBase64StringWithPassedCharInsertedAtTheStart(Convert.ToChar(9), 3), utf8Bytes };
+
+            // Line feed (char 10)
+            yield return new object[] { GetBase64StringWithPassedCharInsertedAtTheStart(Convert.ToChar(10), 1), utf8Bytes };
+            yield return new object[] { GetBase64StringWithPassedCharInsertedAtTheStart(Convert.ToChar(10), 3), utf8Bytes };
+
+            // Carriage return (char 13)
+            yield return new object[] { GetBase64StringWithPassedCharInsertedAtTheStart(Convert.ToChar(13), 1), utf8Bytes };
+            yield return new object[] { GetBase64StringWithPassedCharInsertedAtTheStart(Convert.ToChar(13), 3), utf8Bytes };
+
+            // Space (char 32)
+            yield return new object[] { GetBase64StringWithPassedCharInsertedAtTheStart(Convert.ToChar(32), 1), utf8Bytes };
+            yield return new object[] { GetBase64StringWithPassedCharInsertedAtTheStart(Convert.ToChar(32), 3), utf8Bytes };
+
+            string GetBase64StringWithPassedCharInsertedAtTheStart(char charToInsert, int numberOfTimesToInsert) => $"{new string(charToInsert, numberOfTimesToInsert)}{firstSegment}{secondSegment}";
+
+            // Insert ignored chars at the end of the base 64 string
+            // One will have 1 char, another will have 3
+            // Whitespace after end/padding is still counted in consumed bytes (decoder tests assert consumed == input length)
+
+            // Horizontal tab (char 9)
+            yield return new object[] { GetBase64StringWithPassedCharInsertedAtTheEnd(Convert.ToChar(9), 1), utf8Bytes };
+            yield return new object[] { GetBase64StringWithPassedCharInsertedAtTheEnd(Convert.ToChar(9), 3), utf8Bytes };
+
+            // Line feed (char 10)
+            yield return new object[] { GetBase64StringWithPassedCharInsertedAtTheEnd(Convert.ToChar(10), 1), utf8Bytes };
+            yield return new object[] { GetBase64StringWithPassedCharInsertedAtTheEnd(Convert.ToChar(10), 3), utf8Bytes };
+
+            // Carriage return (char 13)
+            yield return new object[] { GetBase64StringWithPassedCharInsertedAtTheEnd(Convert.ToChar(13), 1), utf8Bytes };
+            yield return new object[] { GetBase64StringWithPassedCharInsertedAtTheEnd(Convert.ToChar(13), 3), utf8Bytes };
+
+            // Space (char 32)
+            yield return new object[] { GetBase64StringWithPassedCharInsertedAtTheEnd(Convert.ToChar(32), 1), utf8Bytes };
+            yield return new object[] { GetBase64StringWithPassedCharInsertedAtTheEnd(Convert.ToChar(32), 3), utf8Bytes };
+
+            string GetBase64StringWithPassedCharInsertedAtTheEnd(char charToInsert, int numberOfTimesToInsert) => $"{firstSegment}{secondSegment}{new string(charToInsert, numberOfTimesToInsert)}";
+        }
+
+        public static IEnumerable<object[]> StringsOnlyWithCharsToBeIgnored()
+        {
+            // One will have 1 char, another will have 3
+
+            // Horizontal tab (char 9)
+            yield return new object[] { GetRepeatedChar(Convert.ToChar(9), 1) };
+            yield return new object[] { GetRepeatedChar(Convert.ToChar(9), 3) };
+
+            // Line feed (char 10)
+            yield return new object[] { GetRepeatedChar(Convert.ToChar(10), 1) };
+            yield return new object[] { GetRepeatedChar(Convert.ToChar(10), 3) };
+
+            // Carriage return (char 13)
+            yield return new object[] { GetRepeatedChar(Convert.ToChar(13), 1) };
+            yield return new object[] { GetRepeatedChar(Convert.ToChar(13), 3) };
+
+            // Space (char 32)
+            yield return new object[] { GetRepeatedChar(Convert.ToChar(32), 1) };
+            yield return new object[] { GetRepeatedChar(Convert.ToChar(32), 3) };
+
+            string GetRepeatedChar(char charToInsert, int numberOfTimesToInsert) => new string(charToInsert, numberOfTimesToInsert);
+        }
+    }
+}
index 7715f6b..1ccc8e0 100644 (file)
@@ -44,6 +44,8 @@ namespace System.Buffers.Text.Tests
             -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
         };
 
+        public static bool IsByteToBeIgnored(byte charByte) => charByte is (byte)' ' or (byte)'\t' or (byte)'\r' or (byte)'\n';
+
         public const byte EncodingPad = (byte)'=';      // '=', for padding
         public const sbyte InvalidByte = -1;            // Designating -1 for invalid bytes in the decoding map
 
diff --git a/src/libraries/System.Memory/tests/Base64/Base64ValidationUnitTests.cs b/src/libraries/System.Memory/tests/Base64/Base64ValidationUnitTests.cs
new file mode 100644 (file)
index 0000000..c7f164a
--- /dev/null
@@ -0,0 +1,339 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using System.Linq;
+using System.Text;
+using Xunit;
+
+namespace System.Buffers.Text.Tests
+{
+    public class Base64ValidationUnitTests : Base64TestBase
+    {
+        [Fact]
+        public void BasicValidationBytes()
+        {
+            var rnd = new Random(42);
+            for (int i = 0; i < 10; i++)
+            {
+                int numBytes;
+                do
+                {
+                    numBytes = rnd.Next(100, 1000 * 1000);
+                } while (numBytes % 4 != 0);    // ensure we have a valid length
+
+                Span<byte> source = new byte[numBytes];
+                Base64TestHelper.InitializeDecodableBytes(source, numBytes);
+
+                Assert.True(Base64.IsValid(source));
+                Assert.True(Base64.IsValid(source, out int decodedLength));
+                Assert.True(decodedLength > 0);
+            }
+        }
+
+        [Fact]
+        public void BasicValidationChars()
+        {
+            var rnd = new Random(42);
+            for (int i = 0; i < 10; i++)
+            {
+                int numBytes;
+                do
+                {
+                    numBytes = rnd.Next(100, 1000 * 1000);
+                } while (numBytes % 4 != 0);    // ensure we have a valid length
+
+                Span<byte> source = new byte[numBytes];
+                Base64TestHelper.InitializeDecodableBytes(source, numBytes);
+                Span<char> chars = source
+                    .ToArray()
+                    .Select(Convert.ToChar)
+                    .ToArray()
+                    .AsSpan();
+
+                Assert.True(Base64.IsValid(chars));
+                Assert.True(Base64.IsValid(chars, out int decodedLength));
+                Assert.True(decodedLength > 0);
+            }
+        }
+
+        [Fact]
+        public void BasicValidationInvalidInputLengthBytes()
+        {
+            var rnd = new Random(42);
+            for (int i = 0; i < 10; i++)
+            {
+                int numBytes;
+                do
+                {
+                    numBytes = rnd.Next(100, 1000 * 1000);
+                } while (numBytes % 4 == 0);    // ensure we have a invalid length
+
+                Span<byte> source = new byte[numBytes];
+
+                Assert.False(Base64.IsValid(source));
+                Assert.False(Base64.IsValid(source, out int decodedLength));
+                Assert.Equal(0, decodedLength);
+            }
+        }
+
+        [Fact]
+        public void BasicValidationInvalidInputLengthChars()
+        {
+            var rnd = new Random(42);
+            for (int i = 0; i < 10; i++)
+            {
+                int numBytes;
+                do
+                {
+                    numBytes = rnd.Next(100, 1000 * 1000);
+                } while (numBytes % 4 == 0);    // ensure we have a invalid length
+
+                Span<char> source = new char[numBytes];
+
+                Assert.False(Base64.IsValid(source));
+                Assert.False(Base64.IsValid(source, out int decodedLength));
+                Assert.Equal(0, decodedLength);
+            }
+        }
+
+        [Fact]
+        public void ValidateEmptySpanBytes()
+        {
+            Span<byte> source = Span<byte>.Empty;
+
+            Assert.True(Base64.IsValid(source));
+            Assert.True(Base64.IsValid(source, out int decodedLength));
+            Assert.Equal(0, decodedLength);
+        }
+
+        [Fact]
+        public void ValidateEmptySpanChars()
+        {
+            Span<char> source = Span<char>.Empty;
+
+            Assert.True(Base64.IsValid(source));
+            Assert.True(Base64.IsValid(source, out int decodedLength));
+            Assert.Equal(0, decodedLength);
+        }
+
+        [Fact]
+        public void ValidateGuidBytes()
+        {
+            Span<byte> source = new byte[24];
+            Span<byte> decodedBytes = Guid.NewGuid().ToByteArray();
+            Base64.EncodeToUtf8(decodedBytes, source, out int _, out int _);
+
+            Assert.True(Base64.IsValid(source));
+            Assert.True(Base64.IsValid(source, out int decodedLength));
+            Assert.True(decodedLength > 0);
+        }
+
+        [Fact]
+        public void ValidateGuidChars()
+        {
+            Span<byte> source = new byte[24];
+            Span<byte> decodedBytes = Guid.NewGuid().ToByteArray();
+            Base64.EncodeToUtf8(decodedBytes, source, out int _, out int _);
+            Span<char> chars = source
+                .ToArray()
+                .Select(Convert.ToChar)
+                .ToArray()
+                .AsSpan();
+
+            Assert.True(Base64.IsValid(chars));
+            Assert.True(Base64.IsValid(chars, out int decodedLength));
+            Assert.True(decodedLength > 0);
+        }
+
+        [Theory]
+        [MemberData(nameof(ValidBase64Strings_WithCharsThatMustBeIgnored))]
+        public void ValidateBytesIgnoresCharsToBeIgnoredBytes(string utf8WithByteToBeIgnored, byte[] expectedBytes)
+        {
+            byte[] utf8BytesWithByteToBeIgnored = UTF8Encoding.UTF8.GetBytes(utf8WithByteToBeIgnored);
+
+            Assert.True(Base64.IsValid(utf8BytesWithByteToBeIgnored));
+            Assert.True(Base64.IsValid(utf8BytesWithByteToBeIgnored, out int decodedLength));
+            Assert.Equal(expectedBytes.Length, decodedLength);
+        }
+
+        [Theory]
+        [MemberData(nameof(ValidBase64Strings_WithCharsThatMustBeIgnored))]
+        public void ValidateBytesIgnoresCharsToBeIgnoredChars(string utf8WithByteToBeIgnored, byte[] expectedBytes)
+        {
+            ReadOnlySpan<char> utf8BytesWithByteToBeIgnored = utf8WithByteToBeIgnored.ToArray();
+
+            Assert.True(Base64.IsValid(utf8BytesWithByteToBeIgnored));
+            Assert.True(Base64.IsValid(utf8BytesWithByteToBeIgnored, out int decodedLength));
+            Assert.Equal(expectedBytes.Length, decodedLength);
+        }
+
+        [Theory]
+        [MemberData(nameof(StringsOnlyWithCharsToBeIgnored))]
+        public void ValidateWithOnlyCharsToBeIgnoredBytes(string utf8WithByteToBeIgnored)
+        {
+            byte[] utf8BytesWithByteToBeIgnored = UTF8Encoding.UTF8.GetBytes(utf8WithByteToBeIgnored);
+
+            Assert.True(Base64.IsValid(utf8BytesWithByteToBeIgnored));
+            Assert.True(Base64.IsValid(utf8BytesWithByteToBeIgnored, out int decodedLength));
+            Assert.Equal(0, decodedLength);
+        }
+
+        [Theory]
+        [MemberData(nameof(StringsOnlyWithCharsToBeIgnored))]
+        public void ValidateWithOnlyCharsToBeIgnoredChars(string utf8WithByteToBeIgnored)
+        {
+            ReadOnlySpan<char> utf8BytesWithByteToBeIgnored = utf8WithByteToBeIgnored.ToArray();
+
+            Assert.True(Base64.IsValid(utf8BytesWithByteToBeIgnored));
+            Assert.True(Base64.IsValid(utf8BytesWithByteToBeIgnored, out int decodedLength));
+            Assert.Equal(0, decodedLength);
+        }
+
+        [Theory]
+        [InlineData("YQ==", 1)]
+        [InlineData("YWI=", 2)]
+        [InlineData("YWJj", 3)]
+        [InlineData(" YWI=", 2)]
+        [InlineData("Y WI=", 2)]
+        [InlineData("YW I=", 2)]
+        [InlineData("YWI =", 2)]
+        [InlineData("YWI= ", 2)]
+        [InlineData(" YQ==", 1)]
+        [InlineData("Y Q==", 1)]
+        [InlineData("YQ ==", 1)]
+        [InlineData("YQ= =", 1)]
+        [InlineData("YQ== ", 1)]
+        public void ValidateWithPaddingReturnsCorrectCountBytes(string utf8WithByteToBeIgnored, int expectedLength)
+        {
+            byte[] utf8BytesWithByteToBeIgnored = UTF8Encoding.UTF8.GetBytes(utf8WithByteToBeIgnored);
+
+            Assert.True(Base64.IsValid(utf8BytesWithByteToBeIgnored));
+            Assert.True(Base64.IsValid(utf8BytesWithByteToBeIgnored, out int decodedLength));
+            Assert.Equal(expectedLength, decodedLength);
+        }
+
+        [Theory]
+        [InlineData("YQ==", 1)]
+        [InlineData("YWI=", 2)]
+        [InlineData("YWJj", 3)]
+        [InlineData(" YWI=", 2)]
+        [InlineData("Y WI=", 2)]
+        [InlineData("YW I=", 2)]
+        [InlineData("YWI =", 2)]
+        [InlineData("YWI= ", 2)]
+        [InlineData(" YQ==", 1)]
+        [InlineData("Y Q==", 1)]
+        [InlineData("YQ ==", 1)]
+        [InlineData("YQ= =", 1)]
+        [InlineData("YQ== ", 1)]
+        public void ValidateWithPaddingReturnsCorrectCountChars(string utf8WithByteToBeIgnored, int expectedLength)
+        {
+            ReadOnlySpan<char> utf8BytesWithByteToBeIgnored = utf8WithByteToBeIgnored.ToArray();
+
+            Assert.True(Base64.IsValid(utf8BytesWithByteToBeIgnored));
+            Assert.True(Base64.IsValid(utf8BytesWithByteToBeIgnored, out int decodedLength));
+            Assert.Equal(expectedLength, decodedLength);
+        }
+
+        [Theory]
+        [InlineData("YQ==", 1)]
+        [InlineData("YWI=", 2)]
+        [InlineData("YWJj", 3)]
+        public void DecodeEmptySpan(string utf8WithByteToBeIgnored, int expectedLength)
+        {
+            ReadOnlySpan<char> utf8BytesWithByteToBeIgnored = utf8WithByteToBeIgnored.ToArray();
+
+            Assert.True(Base64.IsValid(utf8BytesWithByteToBeIgnored));
+            Assert.True(Base64.IsValid(utf8BytesWithByteToBeIgnored, out int decodedLength));
+            Assert.Equal(expectedLength, decodedLength);
+        }
+
+        [Theory]
+        [InlineData("YWJ")]
+        [InlineData("YW")]
+        [InlineData("Y")]
+        public void InvalidSizeBytes(string utf8WithByteToBeIgnored)
+        {
+            byte[] utf8BytesWithByteToBeIgnored = UTF8Encoding.UTF8.GetBytes(utf8WithByteToBeIgnored);
+
+            Assert.False(Base64.IsValid(utf8BytesWithByteToBeIgnored));
+            Assert.False(Base64.IsValid(utf8BytesWithByteToBeIgnored, out int decodedLength));
+            Assert.Equal(0, decodedLength);
+        }
+
+        [Theory]
+        [InlineData("YWJ")]
+        [InlineData("YW")]
+        [InlineData("Y")]
+        public void InvalidSizeChars(string utf8WithByteToBeIgnored)
+        {
+            // Use a char span so the ReadOnlySpan<char> overloads are exercised; InvalidSizeBytes covers the byte overloads.
+            ReadOnlySpan<char> utf8BytesWithByteToBeIgnored = utf8WithByteToBeIgnored.ToArray();
+            Assert.False(Base64.IsValid(utf8BytesWithByteToBeIgnored));
+            Assert.False(Base64.IsValid(utf8BytesWithByteToBeIgnored, out int decodedLength));
+            Assert.Equal(0, decodedLength);
+        }
+
+        [Theory]
+        [InlineData("YQ===")]
+        [InlineData("YQ=a=")]
+        [InlineData("YWI=a")]
+        [InlineData(" aYWI=a")]
+        [InlineData("a YWI=a")]
+        [InlineData("aY WI=a")]
+        [InlineData("aYW I=a")]
+        [InlineData("aYWI =a")]
+        [InlineData("aYWI= a")]
+        [InlineData("a YQ==a")]
+        [InlineData("aY Q==a")]
+        [InlineData("aYQ ==a")]
+        [InlineData("aYQ= =a")]
+        [InlineData("aYQ== a")]
+        [InlineData("aYQ==a ")]
+        public void InvalidBase64Bytes(string utf8WithByteToBeIgnored)
+        {
+            byte[] utf8BytesWithByteToBeIgnored = UTF8Encoding.UTF8.GetBytes(utf8WithByteToBeIgnored);
+
+            Assert.False(Base64.IsValid(utf8BytesWithByteToBeIgnored));
+            Assert.False(Base64.IsValid(utf8BytesWithByteToBeIgnored, out int decodedLength));
+            Assert.Equal(0, decodedLength);
+        }
+
+        [Theory]
+        [InlineData("YQ===")]
+        [InlineData("YQ=a=")]
+        [InlineData("YWI=a")]
+        [InlineData("a YWI=a")]
+        [InlineData("aY WI=a")]
+        [InlineData("aYW I=a")]
+        [InlineData("aYWI =a")]
+        [InlineData("aYWI= a")]
+        [InlineData("a YQ==a")]
+        [InlineData("aY Q==a")]
+        [InlineData("aYQ ==a")]
+        [InlineData("aYQ= =a")]
+        [InlineData("aYQ== a")]
+        [InlineData("aYQ==a ")]
+        [InlineData("a")]
+        [InlineData(" a")]
+        [InlineData("  a")]
+        [InlineData("   a")]
+        [InlineData("    a")]
+        [InlineData("a ")]
+        [InlineData("a  ")]
+        [InlineData("a   ")]
+        [InlineData("a    ")]
+        [InlineData(" a ")]
+        [InlineData("  a  ")]
+        [InlineData("   a   ")]
+        [InlineData("    a    ")]
+        public void InvalidBase64Chars(string utf8WithByteToBeIgnored)
+        {
+            // Use a char span so the ReadOnlySpan<char> overloads are exercised; InvalidBase64Bytes covers the byte overloads.
+            ReadOnlySpan<char> utf8BytesWithByteToBeIgnored = utf8WithByteToBeIgnored.ToArray();
+            Assert.False(Base64.IsValid(utf8BytesWithByteToBeIgnored));
+            Assert.False(Base64.IsValid(utf8BytesWithByteToBeIgnored, out int decodedLength));
+            Assert.Equal(0, decodedLength);
+        }
+    }
+}
index ff1fe29..ca10d25 100644 (file)
     <Compile Include="Base64\Base64DecoderUnitTests.cs" />
     <Compile Include="Base64\Base64EncoderUnitTests.cs" />
     <Compile Include="Base64\Base64TestHelper.cs" />
+    <Compile Include="Base64\Base64TestBase.cs" />
+    <Compile Include="Base64\Base64ValidationUnitTests.cs" />
   </ItemGroup>
   <ItemGroup>
     <Compile Include="$(CommonTestPath)System\Buffers\NativeMemoryManager.cs"
index f4cec1a..4a1ed34 100644 (file)
     <Compile Include="$(MSBuildThisFileDirectory)System\Buffers\Text\Base64.cs" />
     <Compile Include="$(MSBuildThisFileDirectory)System\Buffers\Text\Base64Encoder.cs" />
     <Compile Include="$(MSBuildThisFileDirectory)System\Buffers\Text\Base64Decoder.cs" />
+    <Compile Include="$(MSBuildThisFileDirectory)System\Buffers\Text\Base64Validator.cs" />
     <Compile Include="$(MSBuildThisFileDirectory)System\Buffers\Text\FormattingHelpers.CountDigits.cs" />
     <Compile Include="$(MSBuildThisFileDirectory)System\Buffers\Text\Utf8Constants.cs" />
     <Compile Include="$(MSBuildThisFileDirectory)System\Buffers\Text\Utf8Formatter\FormattingHelpers.cs" />
index 2899c7d..5f49e54 100644 (file)
@@ -35,15 +35,66 @@ namespace System.Buffers.Text
         /// - InvalidData - if the input contains bytes outside of the expected base64 range, or if it contains invalid/more than two padding characters,
         ///   or if the input is incomplete (i.e. not a multiple of 4) and <paramref name="isFinalBlock"/> is <see langword="true"/>.
         /// </returns>
-        public static unsafe OperationStatus DecodeFromUtf8(ReadOnlySpan<byte> utf8, Span<byte> bytes, out int bytesConsumed, out int bytesWritten, bool isFinalBlock = true)
+        public static OperationStatus DecodeFromUtf8(ReadOnlySpan<byte> utf8, Span<byte> bytes, out int bytesConsumed, out int bytesWritten, bool isFinalBlock = true)
         {
-            if (utf8.IsEmpty)
+            OperationStatus status = OperationStatus.Done;
+            bytesConsumed = 0;
+            bytesWritten = 0;
+
+            while (!utf8.IsEmpty)
             {
-                bytesConsumed = 0;
-                bytesWritten = 0;
-                return OperationStatus.Done;
+                status = DecodeFromUtf8Core(utf8, bytes, out int localConsumed, out int localWritten, isFinalBlock);
+                bytesConsumed += localConsumed;
+                bytesWritten += localWritten;
+
+                if (status is not OperationStatus.InvalidData) // Only InvalidData triggers the whitespace handling below; any other status is returned as-is.
+                {
+                    break;
+                }
+
+                utf8 = utf8.Slice(localConsumed);
+                bytes = bytes.Slice(localWritten);
+
+                if (utf8.IsEmpty) // Nothing is left to inspect, so the InvalidData status from the core is returned.
+                {
+                    break;
+                }
+
+                localConsumed = IndexOfAnyExceptWhiteSpace(utf8);
+                if (localConsumed < 0)
+                {
+                    // The remainder of the input is all whitespace. Mark it all as having been consumed,
+                    // and mark the operation as being done.
+                    bytesConsumed += utf8.Length;
+                    status = OperationStatus.Done;
+                    break;
+                }
+
+                if (localConsumed == 0)
+                {
+                    // Non-whitespace was found at the beginning of the input. Since it wasn't consumed
+                    // by the previous call to DecodeFromUtf8Core, it must be part of a Base64 sequence
+                    // that was interrupted by whitespace or something else considered invalid.
+                    // Fall back to block-wise decoding. This is very slow, but it's also very non-standard
+                    // formatting of the input; whitespace is typically only found between blocks, such as
+                    // when Convert.ToBase64String inserts a line break every 76 output characters.
+                    return DecodeWithWhiteSpaceBlockwise(utf8, bytes, ref bytesConsumed, ref bytesWritten, isFinalBlock);
+                }
+
+                // Skip over the leading whitespace, count it as consumed, and resume decoding at the first non-whitespace byte.
+                bytesConsumed += localConsumed;
+                utf8 = utf8.Slice(localConsumed);
             }
 
+            return status;
+        }
+
+        /// <summary>
+        /// Core logic for decoding UTF-8 encoded text in base 64 into binary data.
+        /// </summary>
+        [MethodImpl(MethodImplOptions.AggressiveInlining)]
+        private static unsafe OperationStatus DecodeFromUtf8Core(ReadOnlySpan<byte> utf8, Span<byte> bytes, out int bytesConsumed, out int bytesWritten, bool isFinalBlock = true)
+        {
             fixed (byte* srcBytes = &MemoryMarshal.GetReference(utf8))
             fixed (byte* destBytes = &MemoryMarshal.GetReference(bytes))
             {
@@ -72,7 +123,9 @@ namespace System.Buffers.Text
                         Avx2Decode(ref src, ref dest, end, maxSrcLength, destLength, srcBytes, destBytes);
 
                         if (src == srcEnd)
+                        {
                             goto DoneExit;
+                        }
                     }
 
                     end = srcMax - 24;
@@ -81,7 +134,9 @@ namespace System.Buffers.Text
                         Vector128Decode(ref src, ref dest, end, maxSrcLength, destLength, srcBytes, destBytes);
 
                         if (src == srcEnd)
+                        {
                             goto DoneExit;
+                        }
                     }
                 }
 
@@ -109,7 +164,9 @@ namespace System.Buffers.Text
                     int result = Decode(src, ref decodingMap);
 
                     if (result < 0)
+                    {
                         goto InvalidDataExit;
+                    }
 
                     WriteThreeLowOrderBytes(dest, result);
                     src += 4;
@@ -117,17 +174,23 @@ namespace System.Buffers.Text
                 }
 
                 if (maxSrcLength != srcLength - skipLastChunk)
+                {
                     goto DestinationTooSmallExit;
+                }
 
                 // If input is less than 4 bytes, srcLength == sourceIndex == 0
                 // If input is not a multiple of 4, sourceIndex == srcLength != 0
                 if (src == srcEnd)
                 {
                     if (isFinalBlock)
+                    {
                         goto InvalidDataExit;
+                    }
 
                     if (src == srcBytes + utf8.Length)
+                    {
                         goto DoneExit;
+                    }
 
                     goto NeedMoreDataExit;
                 }
@@ -161,9 +224,13 @@ namespace System.Buffers.Text
                     i0 |= i2;
 
                     if (i0 < 0)
+                    {
                         goto InvalidDataExit;
+                    }
                     if (dest + 3 > destMax)
+                    {
                         goto DestinationTooSmallExit;
+                    }
 
                     WriteThreeLowOrderBytes(dest, i0);
                     dest += 3;
@@ -177,9 +244,13 @@ namespace System.Buffers.Text
                     i0 |= i2;
 
                     if (i0 < 0)
+                    {
                         goto InvalidDataExit;
+                    }
                     if (dest + 2 > destMax)
+                    {
                         goto DestinationTooSmallExit;
+                    }
 
                     dest[0] = (byte)(i0 >> 16);
                     dest[1] = (byte)(i0 >> 8);
@@ -188,9 +259,13 @@ namespace System.Buffers.Text
                 else
                 {
                     if (i0 < 0)
+                    {
                         goto InvalidDataExit;
+                    }
                     if (dest + 1 > destMax)
+                    {
                         goto DestinationTooSmallExit;
+                    }
 
                     dest[0] = (byte)(i0 >> 16);
                     dest += 1;
@@ -199,7 +274,9 @@ namespace System.Buffers.Text
                 src += 4;
 
                 if (srcLength != utf8.Length)
+                {
                     goto InvalidDataExit;
+                }
 
             DoneExit:
                 bytesConsumed = (int)(src - srcBytes);
@@ -208,7 +285,9 @@ namespace System.Buffers.Text
 
             DestinationTooSmallExit:
                 if (srcLength != utf8.Length && isFinalBlock)
+                {
                     goto InvalidDataExit; // if input is not a multiple of 4, and there is no more data, return invalid data instead
+                }
 
                 bytesConsumed = (int)(src - srcBytes);
                 bytesWritten = (int)(dest - destBytes);
@@ -227,7 +306,7 @@ namespace System.Buffers.Text
         }
 
         /// <summary>
-        /// Returns the maximum length (in bytes) of the result if you were to deocde base 64 encoded text within a byte span of size "length".
+        /// Returns the maximum length (in bytes) of the result if you were to decode base 64 encoded text within a byte span of size "length".
         /// </summary>
         /// <exception cref="System.ArgumentOutOfRangeException">
         /// Thrown when the specified <paramref name="length"/> is less than 0.
@@ -236,7 +315,9 @@ namespace System.Buffers.Text
         public static int GetMaxDecodedFromUtf8Length(int length)
         {
             if (length < 0)
+            {
                 ThrowHelper.ThrowArgumentOutOfRangeException(ExceptionArgument.length);
+            }
 
             return (length >> 2) * 3;
         }
@@ -256,25 +337,138 @@ namespace System.Buffers.Text
         /// It does not return NeedMoreData since this method tramples the data in the buffer and
         /// hence can only be called once with all the data in the buffer.
         /// </returns>
-        public static unsafe OperationStatus DecodeFromUtf8InPlace(Span<byte> buffer, out int bytesWritten)
+        public static OperationStatus DecodeFromUtf8InPlace(Span<byte> buffer, out int bytesWritten)
         {
-            if (buffer.IsEmpty)
+            OperationStatus status = DecodeFromUtf8InPlaceCore(buffer, out bytesWritten, out uint sourceIndex);
+            Debug.Assert(status is OperationStatus.Done or OperationStatus.InvalidData, "These are the only statuses the method is coded to return.");
+
+            if (status != OperationStatus.Done)
             {
-                bytesWritten = 0;
-                return OperationStatus.Done;
+                // The input may contain whitespace; retry from the source index reported by the core, this time skipping whitespace.
+                status = DecodeWithWhiteSpaceFromUtf8InPlace(buffer, ref bytesWritten, (int)sourceIndex);
+            }
+
+            return status;
+        }
+
+        private static OperationStatus DecodeWithWhiteSpaceBlockwise(ReadOnlySpan<byte> utf8, Span<byte> bytes, ref int bytesConsumed, ref int bytesWritten, bool isFinalBlock = true)
+        {
+            const int BlockSize = 4; // Input is gathered and decoded in groups of four non-whitespace bytes.
+            Span<byte> buffer = stackalloc byte[BlockSize]; // Scratch space holding one whitespace-free block.
+            OperationStatus status = OperationStatus.Done;
+
+            while (!utf8.IsEmpty)
+            {
+                int encodedIdx = 0;
+                int bufferIdx = 0;
+                int skipped = 0;
+
+                for (; encodedIdx < utf8.Length && (uint)bufferIdx < (uint)buffer.Length; ++encodedIdx) // Gather up to BlockSize non-whitespace bytes, counting the whitespace passed over.
+                {
+                    if (IsWhiteSpace(utf8[encodedIdx]))
+                    {
+                        skipped++;
+                    }
+                    else
+                    {
+                        buffer[bufferIdx] = utf8[encodedIdx];
+                        bufferIdx++;
+                    }
+                }
+
+                utf8 = utf8.Slice(encodedIdx);
+                bytesConsumed += skipped; // Skipped whitespace is reported as consumed input.
+
+                if (bufferIdx == 0) // Only whitespace remained; utf8 is now empty, so the loop condition ends the loop.
+                {
+                    continue;
+                }
+
+                bool hasAnotherBlock = utf8.Length >= BlockSize && bufferIdx == BlockSize;
+                bool localIsFinalBlock = !hasAnotherBlock;
+
+                // If this block contains padding, it has to be the last one: only whitespace may follow it for the input to be valid.
+                if (hasAnotherBlock)
+                {
+                    int paddingCount = GetPaddingCount(ref buffer[^1]);
+                    if (paddingCount > 0)
+                    {
+                        hasAnotherBlock = false;
+                        localIsFinalBlock = true;
+                    }
+                }
+
+                if (localIsFinalBlock && !isFinalBlock) // The caller said this is not the final block, so never decode with final-block semantics.
+                {
+                    localIsFinalBlock = false;
+                }
+
+                status = DecodeFromUtf8Core(buffer.Slice(0, bufferIdx), bytes, out int localConsumed, out int localWritten, localIsFinalBlock);
+                bytesConsumed += localConsumed;
+                bytesWritten += localWritten;
+
+                if (status != OperationStatus.Done)
+                {
+                    return status;
+                }
+
+                // The remaining data must all be whitespace in order to be valid.
+                if (!hasAnotherBlock)
+                {
+                    for (int i = 0; i < utf8.Length; ++i)
+                    {
+                        if (!IsWhiteSpace(utf8[i]))
+                        {
+                            // Revert the counts added for the block just decoded, since invalid data followed it.
+                            bytesConsumed -= localConsumed;
+                            bytesWritten -= localWritten;
+
+                            return OperationStatus.InvalidData;
+                        }
+
+                        bytesConsumed++; // Each trailing whitespace byte is reported as consumed.
+                    }
+
+                    break;
+                }
+
+                bytes = bytes.Slice(localWritten);
             }
 
+            return status;
+        }
+
+        /// <summary>
+        /// Counts how many of the final two bytes of a block are the padding character.
+        /// </summary>
+        /// <param name="ptrToLastElement">
+        /// Reference to the last byte of the block. The byte immediately before it is read as well,
+        /// so the reference must not point at the first element of the underlying memory.
+        /// </param>
+        /// <returns>0, 1 or 2.</returns>
+        [MethodImpl(MethodImplOptions.AggressiveInlining)]
+        private static int GetPaddingCount(ref byte ptrToLastElement)
+        {
+            // The two positions are tested independently of one another.
+            int lastIsPad = ptrToLastElement == EncodingPad ? 1 : 0;
+            int secondToLastIsPad = Unsafe.Subtract(ref ptrToLastElement, 1) == EncodingPad ? 1 : 0;
+
+            return lastIsPad + secondToLastIsPad;
+        }
+
+        /// <summary>
+        /// Core logic for decoding UTF-8 encoded text in base 64 into binary data in place.
+        /// </summary>
+        private static unsafe OperationStatus DecodeFromUtf8InPlaceCore(Span<byte> buffer, out int bytesWritten, out uint sourceIndex)
+        {
             fixed (byte* bufferBytes = &MemoryMarshal.GetReference(buffer))
             {
                 int bufferLength = buffer.Length;
-                uint sourceIndex = 0;
+                sourceIndex = 0;
                 uint destIndex = 0;
 
                 // only decode input if it is a multiple of 4
                 if (bufferLength != ((bufferLength >> 2) * 4))
+                {
                     goto InvalidExit;
+                }
                 if (bufferLength == 0)
+                {
                     goto DoneExit;
+                }
 
                 ref sbyte decodingMap = ref MemoryMarshal.GetReference(DecodingMap);
 
@@ -282,7 +476,10 @@ namespace System.Buffers.Text
                 {
                     int result = Decode(bufferBytes + sourceIndex, ref decodingMap);
                     if (result < 0)
+                    {
                         goto InvalidExit;
+                    }
+
                     WriteThreeLowOrderBytes(bufferBytes + destIndex, result);
                     destIndex += 3;
                     sourceIndex += 4;
@@ -312,7 +509,9 @@ namespace System.Buffers.Text
                     i0 |= i2;
 
                     if (i0 < 0)
+                    {
                         goto InvalidExit;
+                    }
 
                     WriteThreeLowOrderBytes(bufferBytes + destIndex, i0);
                     destIndex += 3;
@@ -326,7 +525,9 @@ namespace System.Buffers.Text
                     i0 |= i2;
 
                     if (i0 < 0)
+                    {
                         goto InvalidExit;
+                    }
 
                     bufferBytes[destIndex] = (byte)(i0 >> 16);
                     bufferBytes[destIndex + 1] = (byte)(i0 >> 8);
@@ -335,7 +536,9 @@ namespace System.Buffers.Text
                 else
                 {
                     if (i0 < 0)
+                    {
                         goto InvalidExit;
+                    }
 
                     bufferBytes[destIndex] = (byte)(i0 >> 16);
                     destIndex += 1;
@@ -351,6 +554,76 @@ namespace System.Buffers.Text
             }
         }
 
+        /// <summary>
+        /// Whitespace-tolerant in-place decoding: gathers four non-whitespace bytes at a time into a
+        /// scratch block, decodes the block, and writes the decoded bytes back into <paramref name="utf8"/>.
+        /// </summary>
+        /// <param name="utf8">The buffer holding the encoded input; decoded bytes overwrite its beginning.</param>
+        /// <param name="destIndex">On entry, the position at which to start writing decoded bytes; on exit, the position after the last decoded byte.</param>
+        /// <param name="sourceIndex">The position in <paramref name="utf8"/> at which to start reading encoded bytes.</param>
+        /// <returns>The status of the last block processed, or <see cref="OperationStatus.InvalidData"/> for an incomplete block or a block following padding.</returns>
+        private static OperationStatus DecodeWithWhiteSpaceFromUtf8InPlace(Span<byte> utf8, ref int destIndex, int sourceIndex)
+        {
+            const int BlockSize = 4;
+            Span<byte> block = stackalloc byte[BlockSize];
+
+            OperationStatus status = OperationStatus.Done;
+            int writeIndex = destIndex;
+            int lastBlockWritten = 0;
+            bool sawPadding = false;
+
+            while ((uint)sourceIndex < (uint)utf8.Length)
+            {
+                // Collect the next block, stepping over any whitespace.
+                int filled = 0;
+
+                while (filled < BlockSize)
+                {
+                    if ((uint)sourceIndex >= (uint)utf8.Length) // TODO https://github.com/dotnet/runtime/issues/83349: move into the while condition once fixed
+                    {
+                        break;
+                    }
+
+                    byte current = utf8[sourceIndex];
+                    sourceIndex++;
+
+                    if (!IsWhiteSpace(current))
+                    {
+                        block[filled] = current;
+                        filled++;
+                    }
+                }
+
+                if (filled == 0)
+                {
+                    // Only whitespace remained; the outer loop condition now terminates the loop.
+                    continue;
+                }
+
+                if (filled != BlockSize)
+                {
+                    // The input ended in the middle of a block.
+                    status = OperationStatus.InvalidData;
+                    break;
+                }
+
+                if (sawPadding)
+                {
+                    // A further block after a padded one is invalid; undo the padded block's contribution.
+                    writeIndex -= lastBlockWritten;
+                    status = OperationStatus.InvalidData;
+                    break;
+                }
+
+                status = DecodeFromUtf8InPlaceCore(block, out lastBlockWritten, out _);
+                writeIndex += lastBlockWritten;
+
+                // A block that yields fewer than three bytes is treated as having contained padding.
+                sawPadding = lastBlockWritten < 3;
+
+                if (status != OperationStatus.Done)
+                {
+                    break;
+                }
+
+                // Copy the decoded bytes from the scratch block back into the caller's buffer.
+                int blockStart = writeIndex - lastBlockWritten;
+                for (int i = 0; i < lastBlockWritten; i++)
+                {
+                    utf8[blockStart + i] = block[i];
+                }
+            }
+
+            destIndex = writeIndex;
+            return status;
+        }
+
         [BypassReadyToRun]
         [MethodImpl(MethodImplOptions.AggressiveInlining)]
         private static unsafe void Avx2Decode(ref byte* srcBytes, ref byte* destBytes, byte* srcEnd, int sourceLength, int destLength, byte* srcStart, byte* destStart)
@@ -433,7 +706,9 @@ namespace System.Buffers.Text
                 Vector256<sbyte> lo = Avx2.Shuffle(lutLo, loNibbles);
 
                 if (!Avx.TestZ(lo, hi))
+                {
                     break;
+                }
 
                 Vector256<sbyte> eq2F = Avx2.CompareEqual(str, mask2F);
                 Vector256<sbyte> shift = Avx2.Shuffle(lutShift, Avx2.Add(eq2F, hiNibbles));
@@ -598,7 +873,9 @@ namespace System.Buffers.Text
                 // Check for invalid input: if any "and" values from lo and hi are not zero,
                 // fall back on bytewise code to do error checking and reporting:
                 if ((lo & hi) != Vector128<byte>.Zero)
+                {
                     break;
+                }
 
                 Vector128<byte> eq2F = Vector128.Equals(str, mask2F);
                 Vector128<byte> shift = SimdShuffle(lutShift.AsByte(), (eq2F + hiNibbles), mask8F);
@@ -692,16 +969,54 @@ namespace System.Buffers.Text
         {
             destination[0] = (byte)(value >> 16);
             destination[1] = (byte)(value >> 8);
-            destination[2] = (byte)(value);
+            destination[2] = (byte)value;
+        }
+
+        /// <summary>
+        /// Finds the position of the first byte in <paramref name="span"/> that is not whitespace.
+        /// </summary>
+        /// <param name="span">The bytes to search.</param>
+        /// <returns>The index of the first non-whitespace byte, or -1 if the span is empty or contains only whitespace.</returns>
+        [MethodImpl(MethodImplOptions.AggressiveInlining)]
+        private static int IndexOfAnyExceptWhiteSpace(ReadOnlySpan<byte> span)
+        {
+            int index = 0;
+
+            // Advance past the leading run of whitespace.
+            while (index < span.Length && IsWhiteSpace(span[index]))
+            {
+                index++;
+            }
+
+            return index < span.Length ? index : -1;
+        }
+
+        /// <summary>
+        /// Determines whether <paramref name="value"/> is one of the four characters treated as
+        /// whitespace: tab, line feed, carriage return or space.
+        /// </summary>
+        /// <param name="value">The character or byte value to classify.</param>
+        /// <returns><see langword="true"/> if <paramref name="value"/> is '\t', '\n', '\r' or ' '; otherwise, <see langword="false"/>.</returns>
+        [MethodImpl(MethodImplOptions.AggressiveInlining)]
+        internal static bool IsWhiteSpace(int value)
+        {
+            if (Environment.Is64BitProcess)
+            {
+                // For description see https://github.com/dotnet/runtime/blob/48e74187cb15386c29eedaa046a5ee2c7ddef161/src/libraries/Common/src/System/HexConverter.cs#L314-L330
+                // Lookup bit mask for "\t\n\r ".
+                const ulong MagicConstant = 0xC800010000000000UL;
+                ulong i = (uint)value - '\t';
+                ulong shift = MagicConstant << (int)i;
+                ulong mask = i - 64;
+                return (long)(shift & mask) < 0;
+            }
+
+            // Compare as unsigned so that negative values are excluded as well. A 32-bit shift only uses
+            // the low five bits of its count, so a negative value such as -23 would otherwise select the
+            // same bit as '\t' and be misreported as whitespace, unlike on the 64-bit path above.
+            if ((uint)value < 32)
+            {
+                const int BitMask = (1 << (int)'\t') | (1 << (int)'\n') | (1 << (int)'\r');
+                return ((1 << value) & BitMask) != 0;
+            }
+
+            return value == 32;
         }
 
         // Pre-computing this table using a custom string(s_characters) and GenerateDecodingMapAndVerify (found in tests)
-        private static ReadOnlySpan<sbyte> DecodingMap => new sbyte[] {
+        private static ReadOnlySpan<sbyte> DecodingMap => new sbyte[]
+        {
             -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
             -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
             -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 62, -1, -1, -1, 63,         //62 is placed at index 43 (for +), 63 at index 47 (for /)
             52, 53, 54, 55, 56, 57, 58, 59, 60, 61, -1, -1, -1, -1, -1, -1,         //52-61 are placed at index 48-57 (for 0-9), 64 at index 61 (for =)
-            -1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14,
+            -1,  0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14,
             15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, -1, -1, -1, -1, -1,         //0-25 are placed at index 65-90 (for A-Z)
             -1, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40,
             41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, -1, -1, -1, -1, -1,         //26-51 are placed at index 97-122 (for a-z)
index 83fe75a..e051bbe 100644 (file)
@@ -585,10 +585,10 @@ namespace System.Buffers.Text
             }
         }
 
-        private const uint EncodingPad = '='; // '=', for padding
+        internal const uint EncodingPad = '='; // '=', for padding
 
         private const int MaximumEncodeLength = (int.MaxValue / 4) * 3; // 1610612733
 
-        private static ReadOnlySpan<byte> EncodingMap => "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"u8;
+        internal static ReadOnlySpan<byte> EncodingMap => "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"u8;
     }
 }
diff --git a/src/libraries/System.Private.CoreLib/src/System/Buffers/Text/Base64Validator.cs b/src/libraries/System.Private.CoreLib/src/System/Buffers/Text/Base64Validator.cs
new file mode 100644 (file)
index 0000000..403377f
--- /dev/null
@@ -0,0 +1,159 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+namespace System.Buffers.Text
+{
+    public static partial class Base64
+    {
+        /// <summary>Determines whether the specified span of text consists of valid base-64 encoded data.</summary>
+        /// <param name="base64Text">The text to validate.</param>
+        /// <returns><see langword="true"/> if <paramref name="base64Text"/> is a valid, decodable sequence of base-64 encoded data; otherwise, <see langword="false"/>.</returns>
+        /// <remarks>
+        /// A <see langword="true"/> result means the same text would decode successfully when passed to
+        /// <see cref="Convert.FromBase64String(string)"/> or <see cref="Convert.TryFromBase64Chars"/> (for the latter,
+        /// provided the destination is large enough). Whitespace, defined as the characters ' ', '\t', '\r', or '\n',
+        /// may appear anywhere in the input, in any amount.
+        /// </remarks>
+        public static bool IsValid(ReadOnlySpan<char> base64Text)
+        {
+            // The decoded length is not needed by this overload.
+            return IsValid<char, Base64CharValidatable>(base64Text, out _);
+        }
+
+        /// <summary>Determines whether the specified span of text consists of valid base-64 encoded data.</summary>
+        /// <param name="base64Text">The text to validate.</param>
+        /// <param name="decodedLength">When the method returns <see langword="true"/>, the number of bytes that decoding the input text produces; otherwise, 0.</param>
+        /// <returns><see langword="true"/> if <paramref name="base64Text"/> is a valid, decodable sequence of base-64 encoded data; otherwise, <see langword="false"/>.</returns>
+        /// <remarks>
+        /// A <see langword="true"/> result means the same text would decode successfully when passed to
+        /// <see cref="Convert.FromBase64String(string)"/> or <see cref="Convert.TryFromBase64Chars"/> (for the latter,
+        /// provided the destination is large enough). Whitespace, defined as the characters ' ', '\t', '\r', or '\n',
+        /// may appear anywhere in the input, in any amount.
+        /// </remarks>
+        public static bool IsValid(ReadOnlySpan<char> base64Text, out int decodedLength)
+        {
+            bool isValid = IsValid<char, Base64CharValidatable>(base64Text, out decodedLength);
+            return isValid;
+        }
+
+        /// <summary>Determines whether the specified span of UTF8 text consists of valid base-64 encoded data.</summary>
+        /// <param name="base64TextUtf8">The UTF8 text to validate.</param>
+        /// <returns><see langword="true"/> if <paramref name="base64TextUtf8"/> is a valid, decodable sequence of base-64 encoded data; otherwise, <see langword="false"/>.</returns>
+        /// <remarks>
+        /// A <see langword="true"/> result means the same text would decode successfully when passed to
+        /// <see cref="DecodeFromUtf8"/> or <see cref="DecodeFromUtf8InPlace"/>. Whitespace, defined as the
+        /// characters ' ', '\t', '\r', or '\n' (as bytes), may appear anywhere in the input, in any amount.
+        /// </remarks>
+        public static bool IsValid(ReadOnlySpan<byte> base64TextUtf8)
+        {
+            // The decoded length is not needed by this overload.
+            return IsValid<byte, Base64ByteValidatable>(base64TextUtf8, out _);
+        }
+
+        /// <summary>Determines whether the specified span of UTF8 text consists of valid base-64 encoded data.</summary>
+        /// <param name="base64TextUtf8">The UTF8 text to validate.</param>
+        /// <param name="decodedLength">When the method returns <see langword="true"/>, the number of bytes that decoding the input UTF8 text produces; otherwise, 0.</param>
+        /// <returns><see langword="true"/> if <paramref name="base64TextUtf8"/> is a valid, decodable sequence of base-64 encoded data; otherwise, <see langword="false"/>.</returns>
+        /// <remarks>
+        /// A <see langword="true"/> result means the same text would decode successfully when passed to
+        /// <see cref="DecodeFromUtf8"/> or <see cref="DecodeFromUtf8InPlace"/>. Whitespace, defined as the
+        /// characters ' ', '\t', '\r', or '\n' (as bytes), may appear anywhere in the input, in any amount.
+        /// </remarks>
+        public static bool IsValid(ReadOnlySpan<byte> base64TextUtf8, out int decodedLength)
+        {
+            bool isValid = IsValid<byte, Base64ByteValidatable>(base64TextUtf8, out decodedLength);
+            return isValid;
+        }
+
+        /// <summary>Validation logic shared by the char-based and byte-based <c>IsValid</c> overloads.</summary>
+        /// <typeparam name="T">The element type of the input.</typeparam>
+        /// <typeparam name="TBase64Validatable">Supplies the element-specific search and classification helpers.</typeparam>
+        /// <param name="base64Text">The input to validate.</param>
+        /// <param name="decodedLength">When the method returns <see langword="true"/>, the number of bytes the input decodes to; otherwise, 0.</param>
+        /// <returns><see langword="true"/> if the input is valid base-64, ignoring whitespace; otherwise, <see langword="false"/>.</returns>
+        private static bool IsValid<T, TBase64Validatable>(ReadOnlySpan<T> base64Text, out int decodedLength)
+            where TBase64Validatable : IBase64Validatable<T>
+        {
+            // An empty input is valid and decodes to nothing.
+            if (base64Text.IsEmpty)
+            {
+                decodedLength = 0;
+                return true;
+            }
+
+            int encodedLength = 0; // Elements that count toward the encoded length; whitespace is excluded.
+            int padding = 0;
+
+            while (true)
+            {
+                int stop = TBase64Validatable.IndexOfAnyExcept(base64Text);
+                if ((uint)stop >= (uint)base64Text.Length)
+                {
+                    // The search found nothing to stop at, so everything that remains counts.
+                    encodedLength += base64Text.Length;
+                    break;
+                }
+
+                encodedLength += stop;
+
+                T current = base64Text[stop];
+                base64Text = base64Text.Slice(stop + 1);
+
+                if (TBase64Validatable.IsWhiteSpace(current))
+                {
+                    // Whitespace often comes in runs (e.g. \r\n), so drain the whole run before searching again.
+                    while (!base64Text.IsEmpty && TBase64Validatable.IsWhiteSpace(base64Text[0]))
+                    {
+                        base64Text = base64Text.Slice(1);
+                    }
+
+                    continue;
+                }
+
+                if (!TBase64Validatable.IsEncodingPad(current))
+                {
+                    // Neither whitespace nor padding: an invalid element.
+                    decodedLength = 0;
+                    return false;
+                }
+
+                // First padding element found. The rest of the input may only contain whitespace
+                // and at most one more padding element.
+                padding = 1;
+                for (int i = 0; i < base64Text.Length; i++)
+                {
+                    T trailing = base64Text[i];
+
+                    if (TBase64Validatable.IsEncodingPad(trailing))
+                    {
+                        if (padding >= 2)
+                        {
+                            decodedLength = 0;
+                            return false;
+                        }
+
+                        padding++;
+                    }
+                    else if (!TBase64Validatable.IsWhiteSpace(trailing))
+                    {
+                        decodedLength = 0;
+                        return false;
+                    }
+                }
+
+                encodedLength += padding;
+                break;
+            }
+
+            // The encoded length, padding included, must be a whole number of four-element groups.
+            if (encodedLength % 4 != 0)
+            {
+                decodedLength = 0;
+                return false;
+            }
+
+            // Three bytes per group, less one byte for each padding element.
+            decodedLength = (int)((uint)encodedLength / 4 * 3) - padding;
+            return true;
+        }
+
+        /// <summary>Element-type-specific operations used by the generic <c>IsValid</c> implementation.</summary>
+        /// <typeparam name="T">The element type of the input being validated.</typeparam>
+        private interface IBase64Validatable<T>
+        {
+            /// <summary>Determines whether <paramref name="value"/> is treated as whitespace.</summary>
+            static abstract bool IsWhiteSpace(T value);
+
+            /// <summary>Determines whether <paramref name="value"/> is the padding element.</summary>
+            static abstract bool IsEncodingPad(T value);
+
+            /// <summary>Returns the index of the first element of <paramref name="span"/> at which the validation scan must stop.</summary>
+            static abstract int IndexOfAnyExcept(ReadOnlySpan<T> span);
+        }
+
+        /// <summary>Validation helpers for input supplied as <see cref="char"/> elements.</summary>
+        private readonly struct Base64CharValidatable : IBase64Validatable<char>
+        {
+            /// <summary>The 64 characters of the base-64 alphabet; padding and whitespace are not included.</summary>
+            private static readonly SearchValues<char> s_validBase64Chars = SearchValues.Create("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/");
+
+            /// <summary>Determines whether <paramref name="value"/> is treated as whitespace.</summary>
+            public static bool IsWhiteSpace(char value)
+            {
+                return Base64.IsWhiteSpace(value);
+            }
+
+            /// <summary>Determines whether <paramref name="value"/> is the padding character.</summary>
+            public static bool IsEncodingPad(char value)
+            {
+                return value == EncodingPad;
+            }
+
+            /// <summary>Returns the index of the first character that is not in the base-64 alphabet, or -1 if there is none.</summary>
+            public static int IndexOfAnyExcept(ReadOnlySpan<char> span)
+            {
+                return span.IndexOfAnyExcept(s_validBase64Chars);
+            }
+        }
+
+        /// <summary>Validation helpers for input supplied as <see cref="byte"/> elements.</summary>
+        private readonly struct Base64ByteValidatable : IBase64Validatable<byte>
+        {
+            /// <summary>The bytes of the base-64 alphabet, taken from <c>EncodingMap</c>; padding and whitespace are not included.</summary>
+            private static readonly SearchValues<byte> s_validBase64Chars = SearchValues.Create(EncodingMap);
+
+            /// <summary>Determines whether <paramref name="value"/> is treated as whitespace.</summary>
+            public static bool IsWhiteSpace(byte value)
+            {
+                return Base64.IsWhiteSpace(value);
+            }
+
+            /// <summary>Determines whether <paramref name="value"/> is the padding byte.</summary>
+            public static bool IsEncodingPad(byte value)
+            {
+                return value == EncodingPad;
+            }
+
+            /// <summary>Returns the index of the first byte that is not in the base-64 alphabet, or -1 if there is none.</summary>
+            public static int IndexOfAnyExcept(ReadOnlySpan<byte> span)
+            {
+                return span.IndexOfAnyExcept(s_validBase64Chars);
+            }
+        }
+    }
+}
index 5e3e26c..f5a8ae2 100644 (file)
@@ -7342,6 +7342,10 @@ namespace System.Buffers.Text
         public static System.Buffers.OperationStatus EncodeToUtf8InPlace(System.Span<byte> buffer, int dataLength, out int bytesWritten) { throw null; }
         public static int GetMaxDecodedFromUtf8Length(int length) { throw null; }
         public static int GetMaxEncodedToUtf8Length(int length) { throw null; }
+        public static bool IsValid(System.ReadOnlySpan<char> base64Text) { throw null; }
+        public static bool IsValid(System.ReadOnlySpan<char> base64Text, out int decodedLength) { throw null; }
+        public static bool IsValid(System.ReadOnlySpan<byte> base64TextUtf8) { throw null; }
+        public static bool IsValid(System.ReadOnlySpan<byte> base64TextUtf8, out int decodedLength) { throw null; }
     }
 }
 namespace System.CodeDom.Compiler
index 544eb36..275a708 100644 (file)
@@ -108,9 +108,21 @@ namespace System.Security.Cryptography
 
     public class FromBase64Transform : ICryptoTransform
     {
+        /// <summary>ASCII bytes that <see cref="AppendInputBuffers"/> treats as whitespace.</summary>
+        /// <remarks>
+        /// We assume ASCII encoded data. If there is any non-ASCII char, it is invalid
+        /// Base64 and will be caught during decoding. Members of the set (name, decimal value):
+        /// SPACE        32
+        /// TAB           9
+        /// LF           10
+        /// VTAB         11
+        /// FORM FEED    12
+        /// CR           13
+        /// </remarks>
+        private static readonly SearchValues<byte> s_whiteSpace = SearchValues.Create(" \t\n\v\f\r"u8);
+        private readonly FromBase64TransformMode _whitespaces; // Compared against DoNotIgnoreWhiteSpaces in AppendInputBuffers to choose between rejecting and stripping whitespace.
         private byte[] _inputBuffer = new byte[4];
         private int _inputIndex;
-        private readonly FromBase64TransformMode _whitespaces;
 
         public FromBase64Transform() : this(FromBase64TransformMode.IgnoreWhiteSpaces) { }
         public FromBase64Transform(FromBase64TransformMode whitespaces)
@@ -223,41 +235,35 @@ namespace System.Security.Cryptography
 
         private Span<byte> AppendInputBuffers(ReadOnlySpan<byte> inputBuffer, Span<byte> transformBuffer)
         {
-            _inputBuffer.AsSpan(0, _inputIndex).CopyTo(transformBuffer);
+            int index = _inputIndex; // Next write position in transformBuffer; starts just past the bytes carried over from _inputBuffer.
+            _inputBuffer.AsSpan(0, index).CopyTo(transformBuffer);
 
             if (_whitespaces == FromBase64TransformMode.DoNotIgnoreWhiteSpaces)
             {
-                inputBuffer.CopyTo(transformBuffer.Slice(_inputIndex));
-                return transformBuffer.Slice(0, _inputIndex + inputBuffer.Length);
+                if (inputBuffer.IndexOfAny(s_whiteSpace) >= 0) // In this mode a single whitespace byte anywhere in the input is an error.
+                {
+                    ThrowHelper.ThrowBase64FormatException();
+                }
             }
             else
             {
-                int count = _inputIndex;
-                for (int i = 0; i < inputBuffer.Length; i++)
+                int whitespaceIndex;
+                while ((whitespaceIndex = inputBuffer.IndexOfAny(s_whiteSpace)) >= 0) // Each pass copies one whitespace-free run, then skips the whitespace that follows it.
                 {
-                    if (!IsWhitespace(inputBuffer[i]))
+                    inputBuffer.Slice(0, whitespaceIndex).CopyTo(transformBuffer.Slice(index));
+                    index += whitespaceIndex;
+                    inputBuffer = inputBuffer.Slice(whitespaceIndex); // inputBuffer[0] is now the whitespace byte that was found.
+
+                    do
                     {
-                        transformBuffer[count++] = inputBuffer[i];
+                        inputBuffer = inputBuffer.Slice(1); // Drop one whitespace byte; do-while because the first one is already known to be whitespace.
                     }
+                    while (!inputBuffer.IsEmpty && s_whiteSpace.Contains(inputBuffer[0]));
                 }
-
-                return transformBuffer.Slice(0, count);
             }
-        }
-
-        private static bool IsWhitespace(byte value)
-        {
-            // We assume ASCII encoded data. If there is any non-ASCII char, it is invalid
-            // Base64 and will be caught during decoding.
-
-            // SPACE        32
-            // TAB           9
-            // LF           10
-            // VTAB         11
-            // FORM FEED    12
-            // CR           13
 
-            return value == 32 || ((uint)value - 9 <= (13 - 9));
+            inputBuffer.CopyTo(transformBuffer.Slice(index)); // Copy the tail: whatever follows the last skipped whitespace run, or the whole input if nothing was skipped.
+            return transformBuffer.Slice(0, index + inputBuffer.Length);
         }
 
         [MethodImpl(MethodImplOptions.AggressiveInlining)]