Add CollectionsMarshal.SetCount(list, count) (#82146)
authorMichał Petryka <35800402+MichalPetryka@users.noreply.github.com>
Tue, 25 Apr 2023 01:14:10 +0000 (03:14 +0200)
committerGitHub <noreply@github.com>
Tue, 25 Apr 2023 01:14:10 +0000 (21:14 -0400)
* Add CollectionsMarshal.SetCount(list, count)

Adds the ability to resize lists, exposed in
CollectionsMarshal due to potentially risky
behaviour caused by the lack of element initialization.

Supersedes #77794.

Fixes #55217.

* Update XML doc

* Add missing using

* Fix test

* Update CollectionsMarshalTests.cs

* Update CollectionsMarshal.cs

* Update CollectionsMarshalTests.cs

* Update CollectionsMarshalTests.cs

src/libraries/System.Private.CoreLib/src/System/Runtime/InteropServices/CollectionsMarshal.cs
src/libraries/System.Runtime.InteropServices/ref/System.Runtime.InteropServices.cs
src/libraries/System.Runtime.InteropServices/tests/System.Runtime.InteropServices.UnitTests/System/Runtime/InteropServices/CollectionsMarshalTests.cs

index 6a60224305a6b11272f71912e825ff34baecfcf3..a54dc8405f1ab6ba5cf725203c98e817e39fcc59 100644 (file)
@@ -2,6 +2,7 @@
 // The .NET Foundation licenses this file to you under the MIT license.
 
 using System.Collections.Generic;
+using System.Runtime.CompilerServices;
 
 namespace System.Runtime.InteropServices
 {
@@ -39,5 +40,40 @@ namespace System.Runtime.InteropServices
         /// <remarks>Items should not be added to or removed from the <see cref="Dictionary{TKey, TValue}"/> while the ref <typeparamref name="TValue"/> is in use.</remarks>
         public static ref TValue? GetValueRefOrAddDefault<TKey, TValue>(Dictionary<TKey, TValue> dictionary, TKey key, out bool exists) where TKey : notnull
             => ref Dictionary<TKey, TValue>.CollectionsMarshalHelper.GetValueRefOrAddDefault(dictionary, key, out exists);
+
+        /// <summary>
+        /// Sets the count of the <see cref="List{T}"/> to the specified value.
+        /// </summary>
+        /// <param name="list">The list to set the count of.</param>
+        /// <param name="count">The value to set the list's count to.</param>
+        /// <exception cref="NullReferenceException">
+        /// <paramref name="list"/> is <see langword="null"/>.
+        /// </exception>
+        /// <exception cref="ArgumentOutOfRangeException">
+        /// <paramref name="count"/> is negative.
+        /// </exception>
+        /// <remarks>
+        /// When increasing the count, uninitialized data is being exposed.
+        /// </remarks>
+        public static void SetCount<T>(List<T> list, int count)
+        {
+            if (count < 0)
+            {
+                ThrowHelper.ThrowArgumentOutOfRangeException_NeedNonNegNum(nameof(count));
+            }
+
+            list._version++;
+
+            if (count > list.Capacity)
+            {
+                list.Grow(count);
+            }
+            else if (count < list._size && RuntimeHelpers.IsReferenceOrContainsReferences<T>())
+            {
+                Array.Clear(list._items, count, list._size - count);
+            }
+
+            list._size = count;
+        }
     }
 }
index 374d5ed8b0709e47a9501a278e4d98ae393c65d5..5f82435829d892bdcdf428f470ce8b52874fe3a8 100644 (file)
@@ -623,6 +623,7 @@ namespace System.Runtime.InteropServices
         public static System.Span<T> AsSpan<T>(System.Collections.Generic.List<T>? list) { throw null; }
         public static ref TValue GetValueRefOrNullRef<TKey, TValue>(System.Collections.Generic.Dictionary<TKey, TValue> dictionary, TKey key) where TKey : notnull { throw null; }
         public static ref TValue? GetValueRefOrAddDefault<TKey, TValue>(System.Collections.Generic.Dictionary<TKey, TValue> dictionary, TKey key, out bool exists) where TKey : notnull { throw null; }
+        public static void SetCount<T>(System.Collections.Generic.List<T> list, int count) { throw null; }
     }
     [System.AttributeUsageAttribute(System.AttributeTargets.Class, Inherited=false)]
     public sealed partial class ComDefaultInterfaceAttribute : System.Attribute
index 876c0681bc664879099169711d6e626c4c9f7867..8a3ec2207da4864f6f37887f518022e4f8f7505c 100644 (file)
@@ -505,5 +505,62 @@ namespace System.Runtime.InteropServices.Tests
             public int Value;
             public int Property { get; set; }
         }
+
+        [Fact]
+        public void ListSetCount()
+        {
+            List<int> list = null!;
+            Assert.Throws<NullReferenceException>(() => CollectionsMarshal.SetCount(list, 3));
+
+            Assert.Throws<ArgumentOutOfRangeException>(() => CollectionsMarshal.SetCount(list, -1));
+
+            list = new();
+            Assert.Throws<ArgumentOutOfRangeException>(() => CollectionsMarshal.SetCount(list, -1));
+
+            CollectionsMarshal.SetCount(list, 5);
+            Assert.Equal(5, list.Count);
+
+            list = new() { 1, 2, 3, 4, 5 };
+            ref int intRef = ref MemoryMarshal.GetReference(CollectionsMarshal.AsSpan(list));
+            // make sure that size decrease preserves content
+            CollectionsMarshal.SetCount(list, 3);
+            Assert.Equal(3, list.Count);
+            Assert.Throws<ArgumentOutOfRangeException>(() => list[3]);
+            SequenceEquals<int>(CollectionsMarshal.AsSpan(list), new int[] { 1, 2, 3 });
+            Assert.True(Unsafe.AreSame(ref intRef, ref MemoryMarshal.GetReference(CollectionsMarshal.AsSpan(list))));
+
+            // make sure that size increase preserves content and doesn't clear
+            CollectionsMarshal.SetCount(list, 5);
+            SequenceEquals<int>(CollectionsMarshal.AsSpan(list), new int[] { 1, 2, 3, 4, 5 });
+            Assert.True(Unsafe.AreSame(ref intRef, ref MemoryMarshal.GetReference(CollectionsMarshal.AsSpan(list))));
+
+            // make sure that reallocations preserve content
+            int newCount = list.Capacity * 2;
+            CollectionsMarshal.SetCount(list, newCount);
+            Assert.Equal(newCount, list.Count);
+            SequenceEquals<int>(CollectionsMarshal.AsSpan(list)[..3], new int[] { 1, 2, 3 });
+            Assert.True(!Unsafe.AreSame(ref intRef, ref MemoryMarshal.GetReference(CollectionsMarshal.AsSpan(list))));
+
+            List<string> listReference = new() { "a", "b", "c", "d", "e" };
+            ref string stringRef = ref MemoryMarshal.GetReference(CollectionsMarshal.AsSpan(listReference));
+            CollectionsMarshal.SetCount(listReference, 3);
+            // verify that reference types aren't cleared
+            SequenceEquals<string>(CollectionsMarshal.AsSpan(listReference), new string[] { "a", "b", "c" });
+            Assert.True(Unsafe.AreSame(ref stringRef, ref MemoryMarshal.GetReference(CollectionsMarshal.AsSpan(listReference))));
+            CollectionsMarshal.SetCount(listReference, 5);
+            // verify that removed reference types are cleared
+            SequenceEquals<string>(CollectionsMarshal.AsSpan(listReference), new string[] { "a", "b", "c", null, null });
+            Assert.True(Unsafe.AreSame(ref stringRef, ref MemoryMarshal.GetReference(CollectionsMarshal.AsSpan(listReference))));
+
+            static void SequenceEquals<T>(ReadOnlySpan<T> actual, ReadOnlySpan<T> expected)
+            {
+                Assert.Equal(actual.Length, expected.Length);
+
+                for (int i = 0; i < actual.Length; i++)
+                {
+                    Assert.Equal(actual[i], expected[i]);
+                }
+            }
+        }
     }
 }