BigInteger parsing optimizations (#47842)
authorJoseph Da Silva <39675835+jfd16@users.noreply.github.com>
Sun, 9 May 2021 06:33:05 +0000 (23:33 -0700)
committerGitHub <noreply@github.com>
Sun, 9 May 2021 06:33:05 +0000 (23:33 -0700)
* Optimize BigInteger parsing

* Use ArrayPool<int> instead of ArrayPool<uint>.

* Additional changes

* Additional changes

src/libraries/System.Runtime.Numerics/src/System/Numerics/BigInteger.cs
src/libraries/System.Runtime.Numerics/src/System/Numerics/BigNumber.cs

index 5edffac..aa90a8b 100644 (file)
@@ -464,7 +464,7 @@ namespace System.Numerics
             AssertValid();
         }
 
-        private BigInteger(int n, uint[]? rgu)
+        internal BigInteger(int n, uint[]? rgu)
         {
             _sign = n;
             _bits = rgu;
@@ -703,7 +703,7 @@ namespace System.Numerics
 
         public static bool TryParse(ReadOnlySpan<char> value, out BigInteger result)
         {
-            return BigNumber.TryParseBigInteger(value, NumberStyles.Integer, NumberFormatInfo.CurrentInfo, out result);
+            return TryParse(value, NumberStyles.Integer, NumberFormatInfo.CurrentInfo, out result);
         }
 
         public static bool TryParse(ReadOnlySpan<char> value, NumberStyles style, IFormatProvider? provider, out BigInteger result)
index 5f4fb2a..957bd73 100644 (file)
@@ -274,6 +274,7 @@ using System.Buffers;
 using System.Diagnostics;
 using System.Diagnostics.CodeAnalysis;
 using System.Globalization;
+using System.Runtime.InteropServices;
 using System.Text;
 
 namespace System.Numerics
@@ -286,6 +287,8 @@ namespace System.Numerics
                                                            | NumberStyles.AllowThousands | NumberStyles.AllowExponent
                                                            | NumberStyles.AllowCurrencySymbol | NumberStyles.AllowHexSpecifier);
 
+        private static readonly uint[] s_uint32PowersOfTen = { 1, 10, 100, 1000, 10000, 100000, 1000000, 10000000, 100000000, 1000000000 };
+
         private struct BigNumberBuffer
         {
             public StringBuilder digits;
@@ -326,7 +329,7 @@ namespace System.Numerics
         {
             if (value == null)
             {
-                result = default(BigInteger);
+                result = default;
                 return false;
             }
 
@@ -335,32 +338,25 @@ namespace System.Numerics
 
         internal static bool TryParseBigInteger(ReadOnlySpan<char> value, NumberStyles style, NumberFormatInfo info, out BigInteger result)
         {
-            unsafe
+            if (!TryValidateParseStyleInteger(style, out ArgumentException? e))
             {
-                result = BigInteger.Zero;
-                ArgumentException? e;
-                if (!TryValidateParseStyleInteger(style, out e))
-                    throw e; // TryParse still throws ArgumentException on invalid NumberStyles
+                throw e; // TryParse still throws ArgumentException on invalid NumberStyles
+            }
 
-                BigNumberBuffer bignumber = BigNumberBuffer.Create();
-                if (!FormatProvider.TryStringToBigInteger(value, style, info, bignumber.digits, out bignumber.precision, out bignumber.scale, out bignumber.sign))
-                    return false;
+            BigNumberBuffer bigNumber = BigNumberBuffer.Create();
+            if (!FormatProvider.TryStringToBigInteger(value, style, info, bigNumber.digits, out bigNumber.precision, out bigNumber.scale, out bigNumber.sign))
+            {
+                result = default;
+                return false;
+            }
 
-                if ((style & NumberStyles.AllowHexSpecifier) != 0)
-                {
-                    if (!HexNumberToBigInteger(ref bignumber, ref result))
-                    {
-                        return false;
-                    }
-                }
-                else
-                {
-                    if (!NumberToBigInteger(ref bignumber, ref result))
-                    {
-                        return false;
-                    }
-                }
-                return true;
+            if ((style & NumberStyles.AllowHexSpecifier) != 0)
+            {
+                return HexNumberToBigInteger(ref bigNumber, out result);
+            }
+            else
+            {
+                return NumberToBigInteger(ref bigNumber, out result);
             }
         }
 
@@ -376,83 +372,296 @@ namespace System.Numerics
 
         internal static BigInteger ParseBigInteger(ReadOnlySpan<char> value, NumberStyles style, NumberFormatInfo info)
         {
-            ArgumentException? e;
-            if (!TryValidateParseStyleInteger(style, out e))
+            if (!TryValidateParseStyleInteger(style, out ArgumentException? e))
+            {
                 throw e;
-
-            BigInteger result = BigInteger.Zero;
-            if (!TryParseBigInteger(value, style, info, out result))
+            }
+            if (!TryParseBigInteger(value, style, info, out BigInteger result))
             {
                 throw new FormatException(SR.Overflow_ParseBigInteger);
             }
             return result;
         }
 
-        private static unsafe bool HexNumberToBigInteger(ref BigNumberBuffer number, ref BigInteger value)
+        private static bool HexNumberToBigInteger(ref BigNumberBuffer number, out BigInteger result)
         {
             if (number.digits == null || number.digits.Length == 0)
+            {
+                result = default;
                 return false;
+            }
+
+            const int DigitsPerBlock = 8;
+
+            int totalDigitCount = number.digits.Length - 1;   // Ignore trailing '\0'
+            int blockCount, partialDigitCount;
+
+            blockCount = Math.DivRem(totalDigitCount, DigitsPerBlock, out int remainder);
+            if (remainder == 0)
+            {
+                partialDigitCount = 0;
+            }
+            else
+            {
+                blockCount += 1;
+                partialDigitCount = DigitsPerBlock - remainder;
+            }
+
+            bool isNegative = HexConverter.FromChar(number.digits[0]) >= 8;
+            uint partialValue = (isNegative && partialDigitCount > 0) ? 0xFFFFFFFFu : 0;
+
+            int[]? arrayFromPool = null;
 
-            int len = number.digits.Length - 1; // Ignore trailing '\0'
-            byte[] bits = new byte[(len / 2) + (len % 2)];
+            Span<uint> bitsBuffer = (blockCount <= BigInteger.StackallocUInt32Limit)
+                ? stackalloc uint[blockCount]
+                : MemoryMarshal.Cast<int, uint>((arrayFromPool = ArrayPool<int>.Shared.Rent(blockCount)).AsSpan(0, blockCount));
 
-            bool shift = false;
-            bool isNegative = false;
-            int bitIndex = 0;
+            int bitsBufferPos = blockCount - 1;
 
-            // Parse the string into a little-endian two's complement byte array
-            // string value     : O F E B 7 \0
-            // string index (i) : 0 1 2 3 4 5 <--
-            // byte[] (bitIndex): 2 1 1 0 0 <--
-            //
-            for (int i = len - 1; i > -1; i--)
+            try
             {
-                char c = number.digits[i];
-                int b = HexConverter.FromChar(c);
-                Debug.Assert(b != 0xFF);
-                if (i == 0 && (b & 0x08) == 0x08)
-                    isNegative = true;
+                foreach (ReadOnlyMemory<char> digitsChunkMem in number.digits.GetChunks())
+                {
+                    ReadOnlySpan<char> chunkDigits = digitsChunkMem.Span;
+                    for (int i = 0; i < chunkDigits.Length; i++)
+                    {
+                        char digitChar = chunkDigits[i];
+                        if (digitChar == '\0')
+                        {
+                            break;
+                        }
+
+                        int hexValue = HexConverter.FromChar(digitChar);
+                        Debug.Assert(hexValue != 0xFF);
+
+                        partialValue = (partialValue << 4) | (uint)hexValue;
+                        partialDigitCount++;
+
+                        if (partialDigitCount == DigitsPerBlock)
+                        {
+                            bitsBuffer[bitsBufferPos] = partialValue;
+                            bitsBufferPos--;
+                            partialValue = 0;
+                            partialDigitCount = 0;
+                        }
+                    }
+                }
+
+                Debug.Assert(partialDigitCount == 0 && bitsBufferPos == -1);
+
+                // BigInteger requires leading zero blocks to be truncated.
+                bitsBuffer = bitsBuffer.TrimEnd(0u);
+
+                int sign;
+                uint[]? bits;
 
-                if (shift)
+                if (bitsBuffer.IsEmpty)
                 {
-                    bits[bitIndex] = (byte)(bits[bitIndex] | (b << 4));
-                    bitIndex++;
+                    sign = 0;
+                    bits = null;
+                }
+                else if (bitsBuffer.Length == 1)
+                {
+                    sign = (int)bitsBuffer[0];
+                    bits = null;
+
+                    if ((!isNegative && sign < 0) || sign == int.MinValue)
+                    {
+                        sign = isNegative ? -1 : 1;
+                        bits = new[] { (uint)sign };
+                    }
                 }
                 else
                 {
-                    bits[bitIndex] = (byte)(isNegative ? (b | 0xF0) : (b));
+                    sign = isNegative ? -1 : 1;
+                    bits = bitsBuffer.ToArray();
+
+                    if (isNegative)
+                    {
+                        NumericsHelpers.DangerousMakeTwosComplement(bits);
+                    }
                 }
-                shift = !shift;
-            }
 
-            value = new BigInteger(bits);
-            return true;
+                result = new BigInteger(sign, bits);
+                return true;
+            }
+            finally
+            {
+                if (arrayFromPool != null)
+                {
+                    ArrayPool<int>.Shared.Return(arrayFromPool);
+                }
+            }
         }
 
-        private static unsafe bool NumberToBigInteger(ref BigNumberBuffer number, ref BigInteger value)
+        private static bool NumberToBigInteger(ref BigNumberBuffer number, out BigInteger result)
         {
-            int i = number.scale;
-            int cur = 0;
+            Span<uint> stackBuffer = stackalloc uint[BigInteger.StackallocUInt32Limit];
+            Span<uint> currentBuffer = stackBuffer;
+            int currentBufferSize = 0;
+            int[]? arrayFromPool = null;
+
+            uint partialValue = 0;
+            int partialDigitCount = 0;
+            int totalDigitCount = 0;
+            int numberScale = number.scale;
 
-            BigInteger ten = 10;
-            value = 0;
-            while (--i >= 0)
+            const int MaxPartialDigits = 9;
+            const uint TenPowMaxPartial = 1000000000;
+
+            try
             {
-                value *= ten;
-                if (number.digits[cur] != '\0')
+                foreach (ReadOnlyMemory<char> digitsChunk in number.digits.GetChunks())
+                {
+                    if (!ProcessChunk(digitsChunk.Span, ref currentBuffer))
+                    {
+                        result = default;
+                        return false;
+                    }
+                }
+
+                if (partialDigitCount > 0)
+                {
+                    MultiplyAdd(ref currentBuffer, s_uint32PowersOfTen[partialDigitCount], partialValue);
+                }
+
+                int trailingZeroCount = numberScale - totalDigitCount;
+
+                while (trailingZeroCount >= MaxPartialDigits)
                 {
-                    value += number.digits[cur++] - '0';
+                    MultiplyAdd(ref currentBuffer, TenPowMaxPartial, 0);
+                    trailingZeroCount -= MaxPartialDigits;
                 }
+
+                if (trailingZeroCount > 0)
+                {
+                    MultiplyAdd(ref currentBuffer, s_uint32PowersOfTen[trailingZeroCount], 0);
+                }
+
+                int sign;
+                uint[]? bits;
+
+                if (currentBufferSize == 0)
+                {
+                    sign = 0;
+                    bits = null;
+                }
+                else if (currentBufferSize == 1 && currentBuffer[0] <= int.MaxValue)
+                {
+                    sign = (int)(number.sign ? -currentBuffer[0] : currentBuffer[0]);
+                    bits = null;
+                }
+                else
+                {
+                    sign = number.sign ? -1 : 1;
+                    bits = currentBuffer.Slice(0, currentBufferSize).ToArray();
+                }
+
+                result = new BigInteger(sign, bits);
+                return true;
             }
-            while (number.digits[cur] != '\0')
+            finally
             {
-                if (number.digits[cur++] != '0') return false; // Disallow non-zero trailing decimal places
+                if (arrayFromPool != null)
+                {
+                    ArrayPool<int>.Shared.Return(arrayFromPool);
+                }
+            }
+
+            bool ProcessChunk(ReadOnlySpan<char> chunkDigits, ref Span<uint> currentBuffer)
+            {
+                int remainingIntDigitCount = Math.Max(numberScale - totalDigitCount, 0);
+                ReadOnlySpan<char> intDigitsSpan = chunkDigits.Slice(0, Math.Min(remainingIntDigitCount, chunkDigits.Length));
+
+                bool endReached = false;
+
+                // Storing these captured variables in locals for faster access in the loop.
+                uint _partialValue = partialValue;
+                int _partialDigitCount = partialDigitCount;
+                int _totalDigitCount = totalDigitCount;
+
+                for (int i = 0; i < intDigitsSpan.Length; i++)
+                {
+                    char digitChar = chunkDigits[i];
+                    if (digitChar == '\0')
+                    {
+                        endReached = true;
+                        break;
+                    }
+
+                    _partialValue = _partialValue * 10 + (uint)(digitChar - '0');
+                    _partialDigitCount++;
+                    _totalDigitCount++;
+
+                    // Update the buffer when enough partial digits have been accumulated.
+                    if (_partialDigitCount == MaxPartialDigits)
+                    {
+                        MultiplyAdd(ref currentBuffer, TenPowMaxPartial, _partialValue);
+                        _partialValue = 0;
+                        _partialDigitCount = 0;
+                    }
+                }
+
+                // Check for nonzero digits after the decimal point.
+                if (!endReached)
+                {
+                    ReadOnlySpan<char> fracDigitsSpan = chunkDigits.Slice(intDigitsSpan.Length);
+                    for (int i = 0; i < fracDigitsSpan.Length; i++)
+                    {
+                        char digitChar = fracDigitsSpan[i];
+                        if (digitChar == '\0')
+                        {
+                            break;
+                        }
+                        if (digitChar != '0')
+                        {
+                            return false;
+                        }
+                    }
+                }
+
+                partialValue = _partialValue;
+                partialDigitCount = _partialDigitCount;
+                totalDigitCount = _totalDigitCount;
+
+                return true;
             }
-            if (number.sign)
+
+            void MultiplyAdd(ref Span<uint> currentBuffer, uint multiplier, uint addValue)
             {
-                value = -value;
+                Span<uint> curBits = currentBuffer.Slice(0, currentBufferSize);
+                uint carry = addValue;
+
+                for (int i = 0; i < curBits.Length; i++)
+                {
+                    ulong p = (ulong)multiplier * curBits[i] + carry;
+                    curBits[i] = (uint)p;
+                    carry = (uint)(p >> 32);
+                }
+
+                if (carry == 0)
+                {
+                    return;
+                }
+
+                if (currentBufferSize == currentBuffer.Length)
+                {
+                    int[]? arrayToReturn = arrayFromPool;
+
+                    arrayFromPool = ArrayPool<int>.Shared.Rent(checked(currentBufferSize * 2));
+                    Span<uint> newBuffer = MemoryMarshal.Cast<int, uint>(arrayFromPool);
+                    currentBuffer.CopyTo(newBuffer);
+                    currentBuffer = newBuffer;
+
+                    if (arrayToReturn != null)
+                    {
+                        ArrayPool<int>.Shared.Return(arrayToReturn);
+                    }
+                }
+
+                currentBuffer[currentBufferSize] = carry;
+                currentBufferSize++;
             }
-            return true;
         }
 
         // This function is consistent with VM\COMNumber.cpp!COMNumber::ParseFormatSpecifier