// The .NET Foundation licenses this file to you under the MIT license.
using System.Collections.Generic;
+using System.Runtime.CompilerServices;
namespace System.Runtime.InteropServices
{
/// <remarks>Items should not be added to or removed from the <see cref="Dictionary{TKey, TValue}"/> while the ref <typeparamref name="TValue"/> is in use.</remarks>
public static ref TValue? GetValueRefOrAddDefault<TKey, TValue>(Dictionary<TKey, TValue> dictionary, TKey key, out bool exists) where TKey : notnull
=> ref Dictionary<TKey, TValue>.CollectionsMarshalHelper.GetValueRefOrAddDefault(dictionary, key, out exists);
+
+ /// <summary>
+ /// Sets the count of the <see cref="List{T}"/> to the specified value.
+ /// </summary>
+ /// <param name="list">The list to set the count of.</param>
+ /// <param name="count">The value to set the list's count to.</param>
+ /// <exception cref="NullReferenceException">
+ /// <paramref name="list"/> is <see langword="null"/>.
+ /// </exception>
+ /// <exception cref="ArgumentOutOfRangeException">
+ /// <paramref name="count"/> is negative.
+ /// </exception>
+ /// <remarks>
+ /// When increasing the count, uninitialized data is being exposed.
+ /// </remarks>
+ public static void SetCount<T>(List<T> list, int count)
+ {
+ if (count < 0)
+ {
+ ThrowHelper.ThrowArgumentOutOfRangeException_NeedNonNegNum(nameof(count));
+ }
+
+ list._version++;
+
+ if (count > list.Capacity)
+ {
+ list.Grow(count);
+ }
+ else if (count < list._size && RuntimeHelpers.IsReferenceOrContainsReferences<T>())
+ {
+ Array.Clear(list._items, count, list._size - count);
+ }
+
+ list._size = count;
+ }
}
}
public int Value;
public int Property { get; set; }
}
+
+ [Fact]
+ public void ListSetCount()
+ {
+ List<int> list = null!;
+ Assert.Throws<NullReferenceException>(() => CollectionsMarshal.SetCount(list, 3));
+
+ Assert.Throws<ArgumentOutOfRangeException>(() => CollectionsMarshal.SetCount(list, -1));
+
+ list = new();
+ Assert.Throws<ArgumentOutOfRangeException>(() => CollectionsMarshal.SetCount(list, -1));
+
+ CollectionsMarshal.SetCount(list, 5);
+ Assert.Equal(5, list.Count);
+
+ list = new() { 1, 2, 3, 4, 5 };
+ ref int intRef = ref MemoryMarshal.GetReference(CollectionsMarshal.AsSpan(list));
+ // make sure that size decrease preserves content
+ CollectionsMarshal.SetCount(list, 3);
+ Assert.Equal(3, list.Count);
+ Assert.Throws<ArgumentOutOfRangeException>(() => list[3]);
+ SequenceEquals<int>(CollectionsMarshal.AsSpan(list), new int[] { 1, 2, 3 });
+ Assert.True(Unsafe.AreSame(ref intRef, ref MemoryMarshal.GetReference(CollectionsMarshal.AsSpan(list))));
+
+ // make sure that size increase preserves content and doesn't clear
+ CollectionsMarshal.SetCount(list, 5);
+ SequenceEquals<int>(CollectionsMarshal.AsSpan(list), new int[] { 1, 2, 3, 4, 5 });
+ Assert.True(Unsafe.AreSame(ref intRef, ref MemoryMarshal.GetReference(CollectionsMarshal.AsSpan(list))));
+
+ // make sure that reallocations preserve content
+ int newCount = list.Capacity * 2;
+ CollectionsMarshal.SetCount(list, newCount);
+ Assert.Equal(newCount, list.Count);
+ SequenceEquals<int>(CollectionsMarshal.AsSpan(list)[..3], new int[] { 1, 2, 3 });
+ Assert.True(!Unsafe.AreSame(ref intRef, ref MemoryMarshal.GetReference(CollectionsMarshal.AsSpan(list))));
+
+ List<string> listReference = new() { "a", "b", "c", "d", "e" };
+ ref string stringRef = ref MemoryMarshal.GetReference(CollectionsMarshal.AsSpan(listReference));
+ CollectionsMarshal.SetCount(listReference, 3);
+ // verify that reference types aren't cleared
+ SequenceEquals<string>(CollectionsMarshal.AsSpan(listReference), new string[] { "a", "b", "c" });
+ Assert.True(Unsafe.AreSame(ref stringRef, ref MemoryMarshal.GetReference(CollectionsMarshal.AsSpan(listReference))));
+ CollectionsMarshal.SetCount(listReference, 5);
+ // verify that removed reference types are cleared
+ SequenceEquals<string>(CollectionsMarshal.AsSpan(listReference), new string[] { "a", "b", "c", null, null });
+ Assert.True(Unsafe.AreSame(ref stringRef, ref MemoryMarshal.GetReference(CollectionsMarshal.AsSpan(listReference))));
+
+ static void SequenceEquals<T>(ReadOnlySpan<T> actual, ReadOnlySpan<T> expected)
+ {
+ Assert.Equal(actual.Length, expected.Length);
+
+ for (int i = 0; i < actual.Length; i++)
+ {
+ Assert.Equal(actual[i], expected[i]);
+ }
+ }
+ }
}
}