Address bugs in BigInteger (dotnet/coreclr#27280)
authorts2do <tsdodo@gmail.com>
Mon, 11 Nov 2019 23:35:28 +0000 (17:35 -0600)
committerAnton Lapounov <antonl@microsoft.com>
Mon, 11 Nov 2019 23:35:28 +0000 (15:35 -0800)
* Method Add(ref BigInteger lhs, uint value, ref BigInteger result) would store most of the result blocks into lhs instead of result.
* Method ShiftLeft(ulong input, uint shift, ref BigInteger output) with a shift argument exceeding 32 would generally compute the higher blocks incorrectly.
* Multiply(ref BigInteger lhs, uint value, ref BigInteger result) would not set result._length in some cases.
* IsZero() would incorrectly return false for non-canonical zeros with _length > 0.

Fix:
* Inline Add(ref BigInteger, uint, ref BigInteger) into Add(uint).
* Inline ShiftLeft(ulong, uint, ref BigInteger) into Pow2.
* Inline ExtendBlock and ExtendBlocks into Pow2.
* Properly handle 0 in SetUInt32 and SetUInt64.

Commit migrated from https://github.com/dotnet/coreclr/commit/65a7947491685a6a3479525334e4178f8eab4b75

src/libraries/System.Private.CoreLib/src/System/Number.BigInteger.cs
src/libraries/System.Private.CoreLib/src/System/Number.Dragon4.cs

index d4058c3..8c19c99 100644 (file)
@@ -332,41 +332,6 @@ namespace System
                 _length = (upper == 0) ? 1 : 2;
             }
 
-            public static void Add(ref BigInteger lhs, uint value, ref BigInteger result)
-            {
-                if (lhs.IsZero())
-                {
-                    result.SetUInt32(value);
-                    return;
-                }
-
-                if (value == 0)
-                {
-                    result.SetValue(ref lhs);
-                    return;
-                }
-
-                int lhsLength = lhs._length;
-                int index = 0;
-                uint carry = value;
-
-                while (index < lhsLength)
-                {
-                    ulong sum = (ulong)(lhs._blocks[index]) + carry;
-                    lhs._blocks[index] = (uint)(sum);
-                    carry = (uint)(sum >> 32);
-
-                    index++;
-                }
-
-                if (carry != 0)
-                {
-                    Debug.Assert(unchecked((uint)(lhsLength)) + 1 <= MaxBlockCount);
-                    result._blocks[index] = carry;
-                    result._length = (lhsLength + 1);
-                }
-            }
-
             public static void Add(ref BigInteger lhs, ref BigInteger rhs, out BigInteger result)
             {
                 // determine which operand has the smaller length
@@ -771,6 +736,10 @@ namespace System
                     result._blocks[index] = carry;
                     result._length = (lhsLength + 1);
                 }
+                else
+                {
+                    result._length = lhsLength;
+                }
             }
 
             public static void Multiply(ref BigInteger lhs, ref BigInteger rhs, ref BigInteger result)
@@ -851,8 +820,14 @@ namespace System
 
             public static void Pow2(uint exponent, out BigInteger result)
             {
-                result = new BigInteger(0);
-                ShiftLeft(1, exponent, ref result);
+                uint blocksToShift = DivRem32(exponent, out uint remainingBitsToShift);
+                result._length = (int)blocksToShift + 1;
+                Debug.Assert(unchecked((uint)result._length) <= MaxBlockCount);
+                if (blocksToShift > 0)
+                {
+                    Buffer.ZeroMemory((byte*)result.GetBlocksPointer(), blocksToShift * sizeof(uint));
+                }
+                result._blocks[blocksToShift] = 1U << (int)(remainingBitsToShift);
             }
 
             public static void Pow10(uint exponent, out BigInteger result)
@@ -923,57 +898,6 @@ namespace System
                 result.SetValue(ref lhs);
             }
 
-            public static void ShiftLeft(ulong input, uint shift, ref BigInteger output)
-            {
-                if (shift == 0)
-                {
-                    return;
-                }
-
-                uint blocksToShift = Math.DivRem(shift, 32, out uint remainingBitsToShift);
-
-                if (blocksToShift > 0)
-                {
-                    // If blocks shifted, we should fill the corresponding blocks with zero.
-                    output.ExtendBlocks(0, blocksToShift);
-                }
-
-                if (remainingBitsToShift == 0)
-                {
-                    // We shift 32 * n (n >= 1) bits. No remaining bits.
-                    output.ExtendBlock((uint)(input));
-
-                    uint highBits = (uint)(input >> 32);
-
-                    if (highBits != 0)
-                    {
-                        output.ExtendBlock(highBits);
-                    }
-                }
-                else
-                {
-                    // Extract the high position bits which would be shifted out of range.
-                    uint highPositionBits = (uint)(input) >> (int)(64 - remainingBitsToShift);
-
-                    // Shift the input. The result should be stored to current block.
-                    ulong shiftedInput = input << (int)(remainingBitsToShift);
-                    output.ExtendBlock((uint)(shiftedInput));
-
-                    uint highBits = (uint)(input >> 32);
-
-                    if (highBits != 0)
-                    {
-                        output.ExtendBlock(highBits);
-                    }
-
-                    if (highPositionBits != 0)
-                    {
-                        // If the high position bits is not 0, we should store them to next block.
-                        output.ExtendBlock(highPositionBits);
-                    }
-                }
-            }
-
             private static uint AddDivisor(ref BigInteger lhs, int lhsStartIndex, ref BigInteger rhs)
             {
                 int lhsLength = lhs._length;
@@ -1065,29 +989,33 @@ namespace System
 
             public void Add(uint value)
             {
-                Add(ref this, value, ref this);
-            }
-
-            public void ExtendBlock(uint blockValue)
-            {
-                _blocks[_length] = blockValue;
-                _length++;
-            }
-
-            public void ExtendBlocks(uint blockValue, uint blockCount)
-            {
-                Debug.Assert(blockCount > 0);
-
-                if (blockCount == 1)
+                int length = _length;
+                if (length == 0)
                 {
-                    ExtendBlock(blockValue);
+                    SetUInt32(value);
+                    return;
+                }
 
+                _blocks[0] += value;
+                if (_blocks[0] >= value)
+                {
+                    // No carry
                     return;
                 }
 
-                Buffer.ZeroMemory((byte*)(GetBlocksPointer() + _length), (blockCount - 1) * sizeof(uint));
-                _length += (int)(blockCount);
-                _blocks[_length - 1] = blockValue;
+                for (int index = 1; index < length; index++)
+                {
+                    _blocks[index]++;
+                    if (_blocks[index] > 0)
+                    {
+                        // No carry
+                        return;
+                    }
+                }
+
+                Debug.Assert(unchecked((uint)(length)) + 1 <= MaxBlockCount);
+                _blocks[length] = 1;
+                _length = length + 1;
             }
 
             public uint GetBlock(uint index)
@@ -1119,11 +1047,9 @@ namespace System
 
             public void Multiply(ref BigInteger value)
             {
-                var result = new BigInteger(0);
-                Multiply(ref this, ref value, ref result);
-
-                Buffer.Memcpy((byte*)GetBlocksPointer(), (byte*)result.GetBlocksPointer(), result._length * sizeof(uint));
-                _length = result._length;
+                BigInteger temp = new BigInteger(0);
+                temp.SetValue(ref this);
+                Multiply(ref temp, ref value, ref this);
             }
 
             public void Multiply10()
@@ -1157,6 +1083,11 @@ namespace System
 
             public void MultiplyPow10(uint exponent)
             {
+                if (IsZero())
+                {
+                    return;
+                }
+
                 Pow10(exponent, out BigInteger poweredValue);
 
                 if (poweredValue._length == 1)
@@ -1171,19 +1102,30 @@ namespace System
 
             public void SetUInt32(uint value)
             {
-                _blocks[0] = value;
-                _length = 1;
+                if (value == 0)
+                {
+                    SetZero();
+                }
+                else
+                {
+                    _blocks[0] = value;
+                    _length = 1;
+                }
             }
 
             public void SetUInt64(ulong value)
             {
-                uint lower = (uint)(value);
-                uint upper = (uint)(value >> 32);
-
-                _blocks[0] = lower;
-                _blocks[1] = upper;
+                if (value <= uint.MaxValue)
+                {
+                    SetUInt32((uint)(value));
+                }
+                else
+                {
+                    _blocks[0] = (uint)(value);
+                    _blocks[1] = (uint)(value >> 32);
 
-                _length = (upper == 0) ? 1 : 2;
+                    _length = 2;
+                }
             }
 
             public void SetValue(ref BigInteger rhs)
@@ -1208,7 +1150,7 @@ namespace System
                     return;
                 }
 
-                uint blocksToShift = Math.DivRem(shift, 32, out uint remainingBitsToShift);
+                uint blocksToShift = DivRem32(shift, out uint remainingBitsToShift);
 
                 // Copy blocks from high to low
                 int readIndex = (length - 1);
@@ -1292,6 +1234,12 @@ namespace System
                 // This is safe to do since we are a ref struct
                 return (uint*)(Unsafe.AsPointer(ref _blocks[0]));
             }
+
+            private static uint DivRem32(uint value, out uint remainder)
+            {
+                remainder = value & 31;
+                return value >> 5;
+            }
         }
     }
 }
index ebd9019..d065726 100644 (file)
@@ -420,7 +420,7 @@ namespace System
                 //      compare(value, 0.5)
                 //      compare(scale * value, scale * 0.5)
                 //      compare(2 * scale * value, scale)
-                scaledValue.Multiply(2);
+                scaledValue.ShiftLeft(1); // Multiply by 2
                 int compare = BigInteger.Compare(ref scaledValue, ref scale);
                 roundDown = compare < 0;