Add ImmutableInterlocked.Update overloads for ImmutableArray (dotnet/corefx#41600)
authorStephen Toub <stoub@microsoft.com>
Tue, 8 Oct 2019 18:46:26 +0000 (14:46 -0400)
committerGitHub <noreply@github.com>
Tue, 8 Oct 2019 18:46:26 +0000 (14:46 -0400)
* Add ImmutableInterlocked.Update overloads for ImmutableArray

The implementations and tests just copy the exist Update overloads, tweaked to work with ImmutableArray.

* Address PR feedback

Commit migrated from https://github.com/dotnet/corefx/commit/9de37814fe28f497c8679b52772f2e5dd7f73306

src/libraries/System.Collections.Immutable/ref/System.Collections.Immutable.cs
src/libraries/System.Collections.Immutable/src/System/Collections/Immutable/ImmutableInterlocked.cs
src/libraries/System.Collections.Immutable/tests/ImmutableInterlockedTests.cs

index 0fd48ec..2aa625a 100644 (file)
@@ -515,6 +515,8 @@ namespace System.Collections.Immutable
         public static bool TryUpdate<TKey, TValue>(ref System.Collections.Immutable.ImmutableDictionary<TKey, TValue> location, TKey key, TValue newValue, TValue comparisonValue) { throw null; }
         public static bool Update<T>(ref T location, System.Func<T, T> transformer) where T : class { throw null; }
         public static bool Update<T, TArg>(ref T location, System.Func<T, TArg, T> transformer, TArg transformerArgument) where T : class { throw null; }
+        public static bool Update<T>(ref System.Collections.Immutable.ImmutableArray<T> location, Func<System.Collections.Immutable.ImmutableArray<T>, System.Collections.Immutable.ImmutableArray<T>> transformer) { throw null; }
+        public static bool Update<T, TArg>(ref System.Collections.Immutable.ImmutableArray<T> location, Func<System.Collections.Immutable.ImmutableArray<T>, TArg, System.Collections.Immutable.ImmutableArray<T>> transformer, TArg transformerArgument) { throw null; }
     }
     public static partial class ImmutableList
     {
index 5e64564..d668b29 100644 (file)
@@ -98,6 +98,92 @@ namespace System.Collections.Immutable
             return true;
         }
 
+        /// <summary>
+        /// Mutates an immutable array in-place with optimistic locking transaction semantics
+        /// via a specified transformation function.
+        /// The transformation is retried as many times as necessary to win the optimistic locking race.
+        /// </summary>
+        /// <typeparam name="T">The type of data in the immutable array.</typeparam>
+        /// <param name="location">
+        /// The immutable array to be changed.
+        /// </param>
+        /// <param name="transformer">
+        /// A function that produces the new array from the old. This function should be side-effect free,
+        /// as it may run multiple times when races occur with other threads.</param>
+        /// <returns>
+        /// <c>true</c> if the location's value is changed by applying the result of the
+        /// <paramref name="transformer"/> function;
+        /// <c>false</c> if the location's value remained the same because the last
+        /// invocation of <paramref name="transformer"/> returned the existing value.
+        /// </returns>
+        public static bool Update<T>(ref ImmutableArray<T> location, Func<ImmutableArray<T>, ImmutableArray<T>> transformer)
+        {
+            Requires.NotNull(transformer, nameof(transformer));
+
+            bool successful;
+            T[] oldArray = Volatile.Read(ref location.array);
+            do
+            {
+                ImmutableArray<T> newImmutableArray = transformer(new ImmutableArray<T>(oldArray));
+                if (ReferenceEquals(oldArray, newImmutableArray.array))
+                {
+                    // No change was actually required.
+                    return false;
+                }
+
+                T[] interlockedResult = Interlocked.CompareExchange(ref location.array, newImmutableArray.array, oldArray);
+                successful = ReferenceEquals(oldArray, interlockedResult);
+                oldArray = interlockedResult; // we already have a volatile read that we can reuse for the next loop
+            }
+            while (!successful);
+
+            return true;
+        }
+
+        /// <summary>
+        /// Mutates an immutable array in-place with optimistic locking transaction semantics
+        /// via a specified transformation function.
+        /// The transformation is retried as many times as necessary to win the optimistic locking race.
+        /// </summary>
+        /// <typeparam name="T">The type of data in the immutable array.</typeparam>
+        /// <typeparam name="TArg">The type of argument passed to the <paramref name="transformer"/>.</typeparam>
+        /// <param name="location">
+        /// The immutable array to be changed.
+        /// </param>
+        /// <param name="transformer">
+        /// A function that produces the new array from the old. This function should be side-effect free,
+        /// as it may run multiple times when races occur with other threads.</param>
+        /// <param name="transformerArgument">The argument to pass to <paramref name="transformer"/>.</param>
+        /// <returns>
+        /// <c>true</c> if the location's value is changed by applying the result of the
+        /// <paramref name="transformer"/> function;
+        /// <c>false</c> if the location's value remained the same because the last
+        /// invocation of <paramref name="transformer"/> returned the existing value.
+        /// </returns>
+        public static bool Update<T, TArg>(ref ImmutableArray<T> location, Func<ImmutableArray<T>, TArg, ImmutableArray<T>> transformer, TArg transformerArgument)
+        {
+            Requires.NotNull(transformer, nameof(transformer));
+
+            bool successful;
+            T[] oldArray = Volatile.Read(ref location.array);
+            do
+            {
+                ImmutableArray<T> newImmutableArray = transformer(new ImmutableArray<T>(oldArray), transformerArgument);
+                if (ReferenceEquals(oldArray, newImmutableArray.array))
+                {
+                    // No change was actually required.
+                    return false;
+                }
+
+                T[] interlockedResult = Interlocked.CompareExchange(ref location.array, newImmutableArray.array, oldArray);
+                successful = ReferenceEquals(oldArray, interlockedResult);
+                oldArray = interlockedResult; // we already have a volatile read that we can reuse for the next loop
+            }
+            while (!successful);
+
+            return true;
+        }
+
         #region ImmutableArray<T> members
 
         /// <summary>
index e305023..90c3e2c 100644 (file)
@@ -2,7 +2,6 @@
 // The .NET Foundation licenses this file to you under the MIT license.
 // See the LICENSE file in the project root for more information.
 
-using System.Linq;
 using System.Threading;
 using System.Threading.Tasks;
 using Xunit;
@@ -12,6 +11,7 @@ namespace System.Collections.Immutable.Tests
     public class ImmutableInterlockedTests
     {
         private delegate bool UpdateDelegate<T>(ref T location, Func<T, T> transformer);
+        private delegate bool UpdateArrayDelegate<T>(ref ImmutableArray<T> location, Func<ImmutableArray<T>, ImmutableArray<T>> transformer);
 
         [Fact]
         public void Update_StartWithNull()
@@ -26,6 +26,30 @@ namespace System.Collections.Immutable.Tests
         }
 
         [Fact]
+        public void UpdateArray_StartWithDefault()
+        {
+            UpdateArrayHelper<int>(func =>
+            {
+                ImmutableArray<int> array = default;
+                Assert.True(func(ref array, l => { Assert.Equal(default, l); return ImmutableArray.Create(1); }));
+                Assert.Equal(1, array.Length);
+                Assert.Equal(1, array[0]);
+            });
+        }
+
+        [Fact]
+        public void UpdateArray_StartWithEmpty()
+        {
+            UpdateArrayHelper<int>(func =>
+            {
+                ImmutableArray<int> array = ImmutableArray<int>.Empty;
+                Assert.True(func(ref array, l => { Assert.Equal(0, l.Length); return ImmutableArray.Create(1); }));
+                Assert.Equal(1, array.Length);
+                Assert.Equal(1, array[0]);
+            });
+        }
+
+        [Fact]
         public void Update_IncrementalUpdate()
         {
             UpdateHelper<ImmutableList<int>>(func =>
@@ -39,6 +63,19 @@ namespace System.Collections.Immutable.Tests
         }
 
         [Fact]
+        public void UpdateArray_IncrementalUpdate()
+        {
+            UpdateArrayHelper<int>(func =>
+            {
+                ImmutableArray<int> array = ImmutableArray.Create(1);
+                Assert.True(func(ref array, l => l.Add(2)));
+                Assert.Equal(2, array.Length);
+                Assert.Equal(1, array[0]);
+                Assert.Equal(2, array[1]);
+            });
+        }
+
+        [Fact]
         public void Update_FuncThrowsThrough()
         {
             UpdateHelper<ImmutableList<int>>(func =>
@@ -49,6 +86,16 @@ namespace System.Collections.Immutable.Tests
         }
 
         [Fact]
+        public void UpdateArray_FuncThrowsThrough()
+        {
+            UpdateArrayHelper<int>(func =>
+            {
+                ImmutableArray<int> array = ImmutableArray.Create(42);
+                Assert.Throws<InvalidOperationException>(() => func(ref array, l => throw new InvalidOperationException()));
+            });
+        }
+
+        [Fact]
         public void Update_NoEffectualChange()
         {
             UpdateHelper<ImmutableList<int>>(func =>
@@ -59,6 +106,16 @@ namespace System.Collections.Immutable.Tests
         }
 
         [Fact]
+        public void UpdateArray_NoEffectualChange()
+        {
+            UpdateArrayHelper<int>(func =>
+            {
+                ImmutableArray<int> array = ImmutableArray.Create(42);
+                Assert.False(func(ref array, l => l));
+            });
+        }
+
+        [Fact]
         public void Update_HighConcurrency()
         {
             UpdateHelper<ImmutableList<int>>(func =>
@@ -70,7 +127,7 @@ namespace System.Collections.Immutable.Tests
                 var barrier = new Barrier(tasks.Length);
                 for (int i = 0; i < tasks.Length; i++)
                 {
-                    tasks[i] = Task.Run(delegate
+                    tasks[i] = Task.Factory.StartNew(delegate
                     {
                         // Maximize concurrency by blocking this thread until all the other threads are ready to go as well.
                         barrier.SignalAndWait();
@@ -79,7 +136,7 @@ namespace System.Collections.Immutable.Tests
                         {
                             Assert.True(func(ref list, l => l.Add(l.Count)));
                         }
-                    });
+                    }, CancellationToken.None, TaskCreationOptions.LongRunning, TaskScheduler.Default);
                 }
 
                 Task.WaitAll(tasks);
@@ -92,6 +149,39 @@ namespace System.Collections.Immutable.Tests
         }
 
         [Fact]
+        public void UpdateArray_HighConcurrency()
+        {
+            UpdateArrayHelper<int>(func =>
+            {
+                ImmutableArray<int> array = ImmutableArray.Create<int>();
+                int concurrencyLevel = Environment.ProcessorCount;
+                int iterations = 500;
+                Task[] tasks = new Task[concurrencyLevel];
+                var barrier = new Barrier(tasks.Length);
+                for (int i = 0; i < tasks.Length; i++)
+                {
+                    tasks[i] = Task.Factory.StartNew(delegate
+                    {
+                        // Maximize concurrency by blocking this thread until all the other threads are ready to go as well.
+                        barrier.SignalAndWait();
+
+                        for (int j = 0; j < iterations; j++)
+                        {
+                            Assert.True(func(ref array, l => l.Add(l.Length)));
+                        }
+                    }, CancellationToken.None, TaskCreationOptions.LongRunning, TaskScheduler.Default);
+                }
+
+                Task.WaitAll(tasks);
+                Assert.Equal(concurrencyLevel * iterations, array.Length);
+                for (int i = 0; i < array.Length; i++)
+                {
+                    Assert.Equal(i, array[i]);
+                }
+            });
+        }
+
+        [Fact]
         public void Update_CarefullyScheduled()
         {
             UpdateHelper<ImmutableHashSet<int>>(func =>
@@ -151,6 +241,65 @@ namespace System.Collections.Immutable.Tests
         }
 
         [Fact]
+        public void UpdateArray_CarefullyScheduled()
+        {
+            UpdateArrayHelper<int>(func =>
+            {
+                var array = ImmutableArray.Create<int>();
+                var task2TransformEntered = new AutoResetEvent(false);
+                var task1TransformExited = new AutoResetEvent(false);
+
+                var task1 = Task.Run(delegate
+                {
+                    int transform1ExecutionCounter = 0;
+                    func(
+                        ref array,
+                        s =>
+                        {
+                            Assert.Equal(1, ++transform1ExecutionCounter);
+                            task2TransformEntered.WaitOne();
+                            return s.Add(1);
+                        });
+                    task1TransformExited.Set();
+                    Assert.Equal(1, transform1ExecutionCounter);
+                });
+
+                var task2 = Task.Run(delegate
+                {
+                    int transform2ExecutionCounter = 0;
+                    func(
+                        ref array,
+                        s =>
+                        {
+                            switch (++transform2ExecutionCounter)
+                            {
+                                case 1:
+                                    task2TransformEntered.Set();
+                                    task1TransformExited.WaitOne();
+                                    Assert.True(s.IsEmpty);
+                                    break;
+                                case 2:
+                                    Assert.True(s.Contains(1));
+                                    Assert.Equal(1, s.Length);
+                                    break;
+                            }
+
+                            return s.Add(2);
+                        });
+
+                    // Verify that this transform had to execute twice.
+                    Assert.Equal(2, transform2ExecutionCounter);
+                });
+
+                // Wait for all tasks and rethrow any exceptions.
+                Task.WaitAll(task1, task2);
+                Assert.Equal(2, array.Length);
+                Assert.True(array.Contains(1));
+                Assert.True(array.Contains(2));
+            });
+        }
+
+        [Fact]
         public void InterlockedExchangeArrayDefault()
         {
             ImmutableArray<int> array = default(ImmutableArray<int>);
@@ -439,6 +588,22 @@ namespace System.Collections.Immutable.Tests
         }
 
         /// <summary>
+        /// Executes a test against both <see cref="ImmutableInterlocked.Update{T}(ref ImmutableArray{T}, Func{ImmutableArray{T}, ImmutableArray{T}})"/>
+        /// and <see cref="ImmutableInterlocked.Update{ImmutableArray{T}, TArg}(ref ImmutableArray{T}, Func{ImmutableArray{T}, TArg, ImmutableArray{T}}, TArg)"/>.
+        /// </summary>
+        /// <typeparam name="T">The type of value under test.</typeparam>
+        /// <param name="test">
+        /// The test to execute. Invoke the parameter instead of calling
+        /// the ImmutableInterlocked method so that the delegate can test both overloads
+        /// by being executed twice.
+        /// </param>
+        private static void UpdateArrayHelper<T>(Action<UpdateArrayDelegate<T>> test)
+        {
+            test(ImmutableInterlocked.Update<T>);
+            test(UpdateArrayWrapper<T>);
+        }
+
+        /// <summary>
         /// A wrapper that makes one overload look like another so the same test delegate can execute against both.
         /// </summary>
         /// <typeparam name="T">The type of value being changed.</typeparam>
@@ -457,5 +622,24 @@ namespace System.Collections.Immutable.Tests
                 },
                 1);
         }
+
+        /// <summary>
+        /// A wrapper that makes one overload look like another so the same test delegate can execute against both.
+        /// </summary>
+        /// <typeparam name="T">The type of value being changed.</typeparam>
+        /// <param name="location">The variable or field to be changed.</param>
+        /// <param name="transformer">The function that transforms the value.</param>
+        /// <returns>The result of the replacement function.</returns>
+        private static bool UpdateArrayWrapper<T>(ref ImmutableArray<T> location, Func<ImmutableArray<T>, ImmutableArray<T>> transformer)
+        {
+            return ImmutableInterlocked.Update<T, int>(
+                ref location,
+                (t, arg) =>
+                {
+                    Assert.Equal(1, arg);
+                    return transformer(t);
+                },
+                1);
+        }
     }
 }