Avoid most of the (small) regressions for seeded and derived Random instances (#57530)
authorStephen Toub <stoub@microsoft.com>
Tue, 17 Aug 2021 15:07:31 +0000 (11:07 -0400)
committerGitHub <noreply@github.com>
Tue, 17 Aug 2021 15:07:31 +0000 (11:07 -0400)
* Split Random compat implementation to address most regressions

When we introduced the new Random algorithm, we did so by factoring the old algorithm out into an implementation strategy class that is instantiated for all use other than `new Random()`.  This ends up penalizing other uses (providing a seed and/or deriving from Random) by adding more virtual dispatch than is strictly necessary, in particular for `new Random(seed)`.  This PR negates most of that (expected) regression by splitting the compat implementation in two, one class for `new Random(seed)` and one for `new DerivedRandom()`/`new DerivedRandom(seed)`; the former no longer needs to make virtual calls back out to the parent type.  The former is also one that a consumer can't really do anything to improve, whereas in the derived case, the derivation may override to provide a more optimal implementation.

* Change NextUInt64 implementation to reduce number of calls

We haven't shipped this yet, so we can change its implementation to make 3 calls instead of 8 and to delegate to a different overload of Next.

* Consolidate NextBytes to use span

src/libraries/System.Private.CoreLib/src/System/Random.Net5CompatImpl.cs
src/libraries/System.Private.CoreLib/src/System/Random.cs
src/libraries/System.Runtime.Extensions/tests/System/Random.cs

index 250c84c..c49a7d2 100644 (file)
 
 using System.Diagnostics;
 using System.Numerics;
+using System.Runtime.CompilerServices;
 
 namespace System
 {
     public partial class Random
     {
         /// <summary>
-        /// Provides an implementation used for compatibility with cases where either a) the
-        /// sequence of numbers could be predicted based on the algorithm employed historically and
-        /// thus expected (e.g. a specific seed used in tests) or b) where a derived type may
-        /// reasonably expect its overrides to be called.  The algorithm is based on a modified version
-        /// of Knuth's subtractive random number generator algorithm.  See https://github.com/dotnet/runtime/issues/23198
-        /// for a discussion of some of the modifications / discrepancies.
+        /// Provides an implementation used for compatibility with cases where a seed is specified
+        /// and thus the sequence produced historically could have been relied upon.
         /// </summary>
-        private sealed class Net5CompatImpl : ImplBase
+        private sealed class Net5CompatSeedImpl : ImplBase
         {
-            /// <summary>Thread-static instance used to seed any legacy implementations created with the default ctor.</summary>
-            [ThreadStatic]
-            private static XoshiroImpl? t_seedGenerator;
+            private CompatPrng _prng; // mutable struct; do not make this readonly
 
-            /// <summary>Reference to the <see cref="Random"/> containing this implementation instance.</summary>
-            /// <remarks>Used to ensure that any calls to other virtual members are performed using the Random-derived instance, if one exists.</remarks>
-            private readonly Random _parent;
-            private readonly int[] _seedArray;
-            private int _inext;
-            private int _inextp;
+            public Net5CompatSeedImpl(int seed) =>
+                _prng = new CompatPrng(seed);
 
-            public Net5CompatImpl(Random parent) : this(parent, (t_seedGenerator ??= new()).Next())
-            {
-            }
+            public override double Sample() => _prng.Sample();
 
-            public Net5CompatImpl(Random parent, int Seed)
-            {
-                _parent = parent;
+            public override int Next() => _prng.InternalSample();
 
-                // Initialize seed array.
-                int[] seedArray = _seedArray = new int[56];
+            public override int Next(int maxValue) => (int)(_prng.Sample() * maxValue);
 
-                int subtraction = (Seed == int.MinValue) ? int.MaxValue : Math.Abs(Seed);
-                int mj = 161803398 - subtraction; // magic number based on Phi (golden ratio)
-                seedArray[55] = mj;
-                int mk = 1;
+            public override int Next(int minValue, int maxValue)
+            {
+                long range = (long)maxValue - minValue;
+                return range <= int.MaxValue ?
+                    (int)(_prng.Sample() * range) + minValue :
+                    (int)((long)(_prng.GetSampleForLargeRange() * range) + minValue);
+            }
 
-                int ii = 0;
-                for (int i = 1; i < 55; i++)
+            public override long NextInt64()
+            {
+                while (true)
                 {
-                    // The range [1..55] is special (Knuth) and so we're wasting the 0'th position.
-                    if ((ii += 21) >= 55)
+                    // Get top 63 bits to get a value in the range [0, long.MaxValue], but try again
+                    // if the value is actually long.MaxValue, as the method is defined to return a value
+                    // in the range [0, long.MaxValue).
+                    ulong result = NextUInt64() >> 1;
+                    if (result != long.MaxValue)
                     {
-                        ii -= 55;
+                        return (long)result;
                     }
+                }
+            }
 
-                    seedArray[ii] = mk;
-                    mk = mj - mk;
-                    if (mk < 0)
-                    {
-                        mk += int.MaxValue;
-                    }
+            public override long NextInt64(long maxValue) => NextInt64(0, maxValue);
 
-                    mj = seedArray[ii];
-                }
+            public override long NextInt64(long minValue, long maxValue)
+            {
+                ulong exclusiveRange = (ulong)(maxValue - minValue);
 
-                for (int k = 1; k < 5; k++)
+                if (exclusiveRange > 1)
                 {
-                    for (int i = 1; i < 56; i++)
+                    // Narrow down to the smallest range [0, 2^bits] that contains maxValue - minValue
+                    // Then repeatedly generate a value in that outer range until we get one within the inner range.
+                    int bits = BitOperations.Log2Ceiling(exclusiveRange);
+                    while (true)
                     {
-                        int n = i + 30;
-                        if (n >= 55)
-                        {
-                            n -= 55;
-                        }
-
-                        seedArray[i] -= seedArray[1 + n];
-                        if (seedArray[i] < 0)
+                        ulong result = NextUInt64() >> (sizeof(long) * 8 - bits);
+                        if (result < exclusiveRange)
                         {
-                            seedArray[i] += int.MaxValue;
+                            return (long)result + minValue;
                         }
                     }
                 }
 
-                _inextp = 21;
+                Debug.Assert(minValue == maxValue || minValue + 1 == maxValue);
+                return minValue;
             }
 
-            public override double Sample() =>
-                // Including the division at the end gives us significantly improved random number distribution.
-                InternalSample() * (1.0 / int.MaxValue);
+            /// <summary>Produces a value in the range [0, ulong.MaxValue].</summary>
+            private ulong NextUInt64() =>
+                 ((ulong)(uint)Next(1 << 22)) |
+                (((ulong)(uint)Next(1 << 22)) << 22) |
+                (((ulong)(uint)Next(1 << 20)) << 44);
+
+            public override double NextDouble() => _prng.Sample();
+
+            public override float NextSingle() => (float)_prng.Sample();
 
-            public override int Next() => InternalSample();
+            public override void NextBytes(byte[] buffer) => _prng.NextBytes(buffer);
+
+            public override void NextBytes(Span<byte> buffer) => _prng.NextBytes(buffer);
+        }
+
+        /// <summary>
+        /// Provides an implementation used for compatibility with cases where a derived type may
+        /// reasonably expect its overrides to be called.
+        /// </summary>
+        private sealed class Net5CompatDerivedImpl : ImplBase
+        {
+            /// <summary>Reference to the <see cref="Random"/> containing this implementation instance.</summary>
+            /// <remarks>Used to ensure that any calls to other virtual members are performed using the Random-derived instance, if one exists.</remarks>
+            private readonly Random _parent;
+            /// <summary>Potentially lazily-initialized algorithm backing this instance.</summary>
+            private CompatPrng _prng; // mutable struct; do not make this readonly
+
+            public Net5CompatDerivedImpl(Random parent) : this(parent, Shared.Next()) { }
+
+            public Net5CompatDerivedImpl(Random parent, int seed)
+            {
+                _parent = parent;
+                _prng = new CompatPrng(seed);
+            }
+
+            public override double Sample() => _prng.Sample();
+
+            public override int Next() => _prng.InternalSample();
 
             public override int Next(int maxValue) => (int)(_parent.Sample() * maxValue);
 
@@ -98,7 +120,7 @@ namespace System
                 long range = (long)maxValue - minValue;
                 return range <= int.MaxValue ?
                     (int)(_parent.Sample() * range) + minValue :
-                    (int)((long)(GetSampleForLargeRange() * range) + minValue);
+                    (int)((long)(_prng.GetSampleForLargeRange() * range) + minValue);
             }
 
             public override long NextInt64()
@@ -142,34 +164,102 @@ namespace System
             }
 
             /// <summary>Produces a value in the range [0, ulong.MaxValue].</summary>
-            private unsafe ulong NextUInt64()
-            {
-                Span<byte> resultBytes = stackalloc byte[8];
-                NextBytes(resultBytes);
-                return BitConverter.ToUInt64(resultBytes);
-            }
+            private unsafe ulong NextUInt64() =>
+                 ((ulong)(uint)_parent.Next(1 << 22)) |
+                (((ulong)(uint)_parent.Next(1 << 22)) << 22) |
+                (((ulong)(uint)_parent.Next(1 << 20)) << 44);
 
             public override double NextDouble() => _parent.Sample();
 
             public override float NextSingle() => (float)_parent.Sample();
 
-            public override void NextBytes(byte[] buffer)
+            public override void NextBytes(byte[] buffer) => _prng.NextBytes(buffer);
+
+            public override void NextBytes(Span<byte> buffer)
             {
                 for (int i = 0; i < buffer.Length; i++)
                 {
-                    buffer[i] = (byte)InternalSample();
+                    buffer[i] = (byte)_parent.Next();
                 }
             }
+        }
 
-            public override void NextBytes(Span<byte> buffer)
+        /// <summary>
+        /// Implementation used for compatibility with previous releases. The algorithm is based on a modified version
+        /// of Knuth's subtractive random number generator algorithm.  See https://github.com/dotnet/runtime/issues/23198
+        /// for a discussion of some of the modifications / discrepancies.
+        /// </summary>
+        private struct CompatPrng
+        {
+            private int[] _seedArray;
+            private int _inext;
+            private int _inextp;
+
+            public CompatPrng(int seed)
+            {
+                // Initialize seed array.
+                int[] seedArray = new int[56];
+
+                int subtraction = (seed == int.MinValue) ? int.MaxValue : Math.Abs(seed);
+                int mj = 161803398 - subtraction; // magic number based on Phi (golden ratio)
+                seedArray[55] = mj;
+                int mk = 1;
+
+                int ii = 0;
+                for (int i = 1; i < 55; i++)
+                {
+                    // The range [1..55] is special (Knuth) and so we're wasting the 0'th position.
+                    if ((ii += 21) >= 55)
+                    {
+                        ii -= 55;
+                    }
+
+                    seedArray[ii] = mk;
+                    mk = mj - mk;
+                    if (mk < 0)
+                    {
+                        mk += int.MaxValue;
+                    }
+
+                    mj = seedArray[ii];
+                }
+
+                for (int k = 1; k < 5; k++)
+                {
+                    for (int i = 1; i < 56; i++)
+                    {
+                        int n = i + 30;
+                        if (n >= 55)
+                        {
+                            n -= 55;
+                        }
+
+                        seedArray[i] -= seedArray[1 + n];
+                        if (seedArray[i] < 0)
+                        {
+                            seedArray[i] += int.MaxValue;
+                        }
+                    }
+                }
+
+                _seedArray = seedArray;
+                _inext = 0;
+                _inextp = 21;
+            }
+
+            internal double Sample() =>
+                // Including the division at the end gives us significantly improved random number distribution.
+                InternalSample() * (1.0 / int.MaxValue);
+
+            internal void NextBytes(Span<byte> buffer)
             {
                 for (int i = 0; i < buffer.Length; i++)
                 {
-                    buffer[i] = (byte)_parent.Next();
+                    buffer[i] = (byte)InternalSample();
                 }
             }
 
-            private int InternalSample()
+            internal int InternalSample()
             {
                 int locINext = _inext;
                 if (++locINext >= 56)
@@ -202,7 +292,7 @@ namespace System
                 return retVal;
             }
 
-            private double GetSampleForLargeRange()
+            internal double GetSampleForLargeRange()
             {
                 // The distribution of the double returned by Sample is not good enough for a large range.
                 // If we use Sample for a range [int.MinValue..int.MaxValue), we will end up getting even numbers only.
index 396af2e..05766c7 100644 (file)
@@ -27,7 +27,7 @@ namespace System
             // With no seed specified, if this is the base type, we can implement this however we like.
             // If it's a derived type, for compat we respect the previous implementation, so that overrides
             // are called as they were previously.
-            _impl = GetType() == typeof(Random) ? new XoshiroImpl() : new Net5CompatImpl(this);
+            _impl = GetType() == typeof(Random) ? new XoshiroImpl() : new Net5CompatDerivedImpl(this);
 
         /// <summary>Initializes a new instance of the Random class, using the specified seed value.</summary>
         /// <param name="Seed">
@@ -35,9 +35,10 @@ namespace System
         /// is specified, the absolute value of the number is used.
         /// </param>
         public Random(int Seed) =>
-            // With a custom seed, for compat we respect the previous implementation so that the same sequence
-            // previously output continues to be output.
-            _impl = new Net5CompatImpl(this, Seed);
+            // With a custom seed, if this is the base Random class, we still need to respect the same algorithm that's been
+            // used in the past, but we can do so without having to deal with calling the right overrides in a derived type.
+            // If this is a derived type, we need to handle always using the same overrides we've done previously.
+            _impl = GetType() == typeof(Random) ? new Net5CompatSeedImpl(Seed) : new Net5CompatDerivedImpl(this, Seed);
 
         /// <summary>Constructor used by <see cref="ThreadSafeRandom"/>.</summary>
         /// <param name="isThreadSafeRandom">Must be true.</param>
index 5a5d95b..55f47da 100644 (file)
@@ -259,6 +259,40 @@ namespace System.Tests
             }
         }
 
+        [Fact]
+        public void ExpectedValues_DerivedSeededMatchesBaseSeeded()
+        {
+            for (int i = 0; i < 10; i++)
+            {
+                int seed = Random.Shared.Next();
+
+                var baseSeeded = new Random(seed);
+                var derivedSeeded = new SubRandom(seed);
+
+                byte[] baseBuffer = new byte[42];
+                byte[] derivedBuffer = new byte[42];
+
+                Assert.Equal(baseSeeded.Next(), derivedSeeded.Next());
+                Assert.Equal(baseSeeded.Next(42), derivedSeeded.Next(42));
+                Assert.Equal(baseSeeded.Next(1, 42), derivedSeeded.Next(1, 42));
+
+                Assert.Equal(baseSeeded.NextInt64(), derivedSeeded.NextInt64());
+                Assert.Equal(baseSeeded.NextInt64(42), derivedSeeded.NextInt64(42));
+                Assert.Equal(baseSeeded.NextInt64(int.MaxValue, long.MaxValue), derivedSeeded.NextInt64(int.MaxValue, long.MaxValue));
+
+                Assert.Equal(baseSeeded.NextDouble(), derivedSeeded.NextDouble());
+                Assert.Equal(baseSeeded.NextSingle(), derivedSeeded.NextSingle());
+
+                baseSeeded.NextBytes(baseBuffer);
+                derivedSeeded.NextBytes(derivedBuffer);
+                AssertExtensions.SequenceEqual(baseBuffer, derivedBuffer);
+
+                baseSeeded.NextBytes((Span<byte>)baseBuffer);
+                derivedSeeded.NextBytes((Span<byte>)derivedBuffer);
+                AssertExtensions.SequenceEqual(baseBuffer, derivedBuffer);
+            }
+        }
+
         [Theory]
         [InlineData(false)]
         [InlineData(true)]
@@ -304,6 +338,45 @@ namespace System.Tests
         [Theory]
         [InlineData(false)]
         [InlineData(true)]
+        public void ExpectedValues_Next64(bool derived)
+        {
+            long[][] expectedValues = new long[][]
+            {
+                new long[] { 7083764782846131554, 5154802594370149492, 9016307100457696812, 4310211293818176047, 9058748489721462462, 9180463484486351044, 7534648089071954807, 300923067154768701, 8614801378171577836, 748103725286293756, },
+                new long[] { 4307412542716114199, 3991716777541808379, 934021439863608487, 2287661736829161214, 6291475812869357143, 5675567160283690199, 8760158359188310330, 3523056573073259785, 7321359148939577545, 8176239586367362256, },
+                new long[] { 1531069098683313357, 2828622164618347905, 2075099020033370914, 265112179838049230, 3524203136017251824, 2170679632171954410, 762296592449890045, 6745207671181989590, 6027925715800599462, 6380994614502729892, },
+                new long[] { 7978097691503191170, 1665536347787909639, 3216194192387080605, 7465934659703810205, 756930459165146506, 7889164140914994429, 1987815658659267775, 743986732433846434, 4734483486568599171, 4585758438729022584, },
+                new long[] { 5201754247470390327, 502450530957471373, 4357271772556843032, 5443376306619676012, 7213038615260839203, 4384276612803258641, 3213334724868645506, 3966129034447456878, 3441041257334501728, 2790513466864390220, },
+                new long[] { 2425402007342470124, 8562736750981808916, 5498366944908455571, 3420835545721586235, 4445757142315711677, 879397880784545060, 4438844994985001029, 7188280132554089531, 2147607824197620797, 995277291092780064, },
+                new long[] { 8872430600164445090, 7399650934153467802, 6639444525078217998, 1398277192639549194, 1678493261558725718, 6597873593434562871, 5664364061194378760, 1187050397712924167, 854156798872598297, 8423404356080826356, },
+                new long[] { 6096087156129547095, 6236565117320932384, 7780530901338905481, 8599090876412287961, 8134592621559299055, 3092986065322827082, 6889874331310734282, 4409192699726534611, 8784086606495373814, 6628168180309216200, },
+                new long[] { 3319734916003724044, 5073479300492591271, 8921617277601690116, 6576541319421175976, 5367319944707193737, 8811470574065867102, 8115384601427089805, 7631343797835264416, 7490653173356395731, 4832923208444583836, },
+                new long[] { 543391471966728898, 3910384687569130797, 839331617009698943, 4553991762430063991, 2600047267857185570, 5306583045956228465, 117531630781691728, 1630122859085024108, 6197210944124395440, 3037678236577854320, },
+                new long[] { 6990420064788703863, 2747298870740789683, 1980409197177364218, 2531442205441049158, 9056155423952878267, 1801695517842395524, 1343041900898047250, 4852273957193753912, 4903777510985417357, 1242442060806244164, },
+                new long[] { 4214067824660783660, 1584213053910351417, 3121504369533171061, 508883852356914965, 6288873951005653589, 7520180026587532695, 2568560967107424981, 8074416259207364357, 3610335281753417066, 8670577921889409817, },
+                new long[] { 1437724380627982818, 421127237079913152, 4262581949700836336, 7709706332222675941, 3521610070248667630, 4015301294568819115, 3794080033314705560, 2073186524366198993, 2316893052521416775, 6875332950022680301, },
+                new long[] { 7884752973447860631, 8481404661011228486, 5403668325963620971, 5687156775231563956, 754337393396562311, 510404970364061118, 5019590303431061083, 5295337622472831645, 1023459619382438692, 5080087978158047937, },
+                new long[] { 5108409529412962636, 7318318844180790220, 6544754702222211302, 3664607218240451971, 7210445549492255009, 6228889479107101137, 6245109369640438813, 8517479924486442090, 8953389427003117057, 3284851802384340629, },
+                new long[] { 2332057289287139585, 6155241823445471314, 7685841078484995937, 1642048865158414930, 4443164076547127482, 2724001950995365348, 7470610843663772128, 2516258985740396086, 7659947197771116765, 1489615626614827625, },
+                new long[] { 8779085882107017399, 4992147210522010841, 8826927454745683420, 8842862548929056545, 1675900195788044372, 8442486459738405368, 8696129909873149859, 5738410083847028739, 6366513764632138682, 8917742691602873917, },
+                new long[] { 6002742438074216556, 3829061393691572575, 744641794153692247, 6820321788033063920, 8131999555790714861, 4937598931626669579, 698276939227751781, 8960552385858542031, 5073062739307116183, 7122497719736144401, },
+                new long[] { 3226390197944199201, 2665966780768112101, 1885719374323454674, 4797763434951026879, 5364726878938609542, 1432711403514933790, 1923787209344107304, 2959322651019473819, 3779620510077213044, 5327261543966631397, },
+                new long[] { 450046753911398359, 1502889760032793195, 3026814546677164365, 2775213877957817742, 2597454202086504224, 7151195912260070961, 3149306275553485035, 6181473749126106472, 2486187076938234961, 3532016572099901881, },
+            };
+
+            for (int seed = 0; seed < expectedValues.Length; seed++)
+            {
+                Random r = derived ? new SubRandom(seed) : new Random(seed);
+                for (int i = 0; i < expectedValues[seed].Length; i++)
+                {
+                    Assert.Equal(expectedValues[seed][i], r.NextInt64());
+                }
+            }
+        }
+
+        [Theory]
+        [InlineData(false)]
+        [InlineData(true)]
         public void ExpectedValues_NextBytes(bool derived)
         {
             byte[][] expectedValues = new byte[][]
@@ -367,10 +440,12 @@ namespace System.Tests
         [InlineData(true)]
         public void SampleOrNext_DerivedOverrideCalledWhereExpected(bool seeded)
         {
-            SubRandom r;
-
-            r = seeded ? new SubRandom(42) : new SubRandom();
+            // Validate our test Called state starts as false
+            SubRandom r = seeded ? new SubRandom(42) : new SubRandom();
             Assert.False(r.SampleCalled);
+            Assert.False(r.NextCalled);
+
+            // Validate the right Called is true where expected
 
             foreach (int maxValue in new[] { 0, 1, 42 })
             {
@@ -390,14 +465,14 @@ namespace System.Tests
             {
                 r = seeded ? new SubRandom(42) : new SubRandom();
                 r.NextInt64(maxValue);
-                Assert.True(r.NextCalled);
+                Assert.True(r.SampleCalled);
             }
 
             foreach ((long minValue, long maxValue) in new[] { (42L, 47L), ((long)int.MaxValue + 1, long.MaxValue) })
             {
                 r = seeded ? new SubRandom(42) : new SubRandom();
                 r.NextInt64(minValue, maxValue);
-                Assert.True(r.NextCalled);
+                Assert.True(r.SampleCalled);
             }
 
             r = seeded ? new SubRandom(42) : new SubRandom();
@@ -413,17 +488,9 @@ namespace System.Tests
             Assert.True(r.NextCalled);
 
             // Next was changed to not call Sample in .NET Framework 2.0.
-            // NextBytes(byte[]) just uses Next.
-            // And NextInt64 uses NextBytes(byte[]).
-            // NextInt64(long{, long}) will use Next, but not if the range is such that the min will always be returned.
-
+            // NextBytes((Span<byte>)) just uses Next().
             r = seeded ? new SubRandom(42) : new SubRandom();
             r.Next();
-            r.NextInt64();
-            r.NextInt64(0);
-            r.NextInt64(1);
-            r.NextInt64(0, 0);
-            r.NextInt64(0, 1);
             r.NextBytes((Span<byte>)new byte[1]);
             Assert.False(r.SampleCalled);
         }