Add Enumerable.TryGetNonEnumeratedCount (Implements #27183) (#48239)
authorEirik Tsarpalis <eirik.tsarpalis@gmail.com>
Mon, 15 Feb 2021 17:15:15 +0000 (17:15 +0000)
committerGitHub <noreply@github.com>
Mon, 15 Feb 2021 17:15:15 +0000 (17:15 +0000)
* implement Enumerable.TryGetEnumeratingCount

* address feedback

* update consistency tests

* Replace EnumerableHelpers.TryGetCount with new method

* Rename to method name as approved

* make method is renamed in all projects

src/libraries/Common/src/System/Collections/Generic/EnumerableHelpers.Linq.cs
src/libraries/Common/src/System/Collections/Generic/SparseArrayBuilder.cs
src/libraries/System.Linq.Queryable/tests/Queryable.cs
src/libraries/System.Linq/Directory.Build.props
src/libraries/System.Linq/ref/System.Linq.cs
src/libraries/System.Linq/src/System/Linq/Concat.SpeedOpt.cs
src/libraries/System.Linq/src/System/Linq/Count.cs
src/libraries/System.Linq/tests/ConsistencyTests.cs
src/libraries/System.Linq/tests/CountTests.cs

index 7efd10d..7be694f 100644 (file)
@@ -12,32 +12,6 @@ namespace System.Collections.Generic
     internal static partial class EnumerableHelpers
     {
         /// <summary>
-        /// Tries to get the count of the enumerable cheaply.
-        /// </summary>
-        /// <typeparam name="T">The element type of the source enumerable.</typeparam>
-        /// <param name="source">The enumerable to count.</param>
-        /// <param name="count">The count of the enumerable, if it could be obtained cheaply.</param>
-        /// <returns><c>true</c> if the enumerable could be counted cheaply; otherwise, <c>false</c>.</returns>
-        internal static bool TryGetCount<T>(IEnumerable<T> source, out int count)
-        {
-            Debug.Assert(source != null);
-
-            if (source is ICollection<T> collection)
-            {
-                count = collection.Count;
-                return true;
-            }
-
-            if (source is IIListProvider<T> provider)
-            {
-                return (count = provider.GetCount(onlyIfCheap: true)) >= 0;
-            }
-
-            count = -1;
-            return false;
-        }
-
-        /// <summary>
         /// Copies items from an enumerable to an array.
         /// </summary>
         /// <typeparam name="T">The element type of the enumerable.</typeparam>
index 72bffe2..c0ff03e 100644 (file)
@@ -190,7 +190,7 @@ namespace System.Collections.Generic
         public bool ReserveOrAdd(IEnumerable<T> items)
         {
             int itemCount;
-            if (EnumerableHelpers.TryGetCount(items, out itemCount))
+            if (System.Linq.Enumerable.TryGetNonEnumeratedCount(items, out itemCount))
             {
                 if (itemCount > 0)
                 {
index 53284a3..9766b25 100644 (file)
@@ -121,16 +121,17 @@ namespace System.Linq.Tests
                 typeof(Enumerable),
                 typeof(Queryable),
                  new [] {
-                     "ToLookup",
-                     "ToDictionary",
-                     "ToArray",
-                     "AsEnumerable",
-                     "ToList",
+                     nameof(Enumerable.ToLookup),
+                     nameof(Enumerable.ToDictionary),
+                     nameof(Enumerable.ToArray),
+                     nameof(Enumerable.AsEnumerable),
+                     nameof(Enumerable.ToList),
+                     nameof(Enumerable.Append),
+                     nameof(Enumerable.Prepend),
+                     nameof(Enumerable.ToHashSet),
+                     nameof(Enumerable.TryGetNonEnumeratedCount),
                      "Fold",
                      "LeftJoin",
-                     "Append",
-                     "Prepend",
-                     "ToHashSet"
                  }
                 );
 
@@ -140,7 +141,7 @@ namespace System.Linq.Tests
                 typeof(Queryable),
                 typeof(Enumerable),
                  new [] {
-                     "AsQueryable"
+                     nameof(Queryable.AsQueryable)
                  }
                 );
 
index 63f02a0..e8d6554 100644 (file)
@@ -3,4 +3,4 @@
   <PropertyGroup>
     <StrongNameKeyId>Microsoft</StrongNameKeyId>
   </PropertyGroup>
-</Project>
\ No newline at end of file
+</Project>
index 4c8007c..f401b5e 100644 (file)
@@ -189,6 +189,7 @@ namespace System.Linq
         public static System.Linq.ILookup<TKey, TSource> ToLookup<TSource, TKey>(this System.Collections.Generic.IEnumerable<TSource> source, System.Func<TSource, TKey> keySelector, System.Collections.Generic.IEqualityComparer<TKey>? comparer) { throw null; }
         public static System.Linq.ILookup<TKey, TElement> ToLookup<TSource, TKey, TElement>(this System.Collections.Generic.IEnumerable<TSource> source, System.Func<TSource, TKey> keySelector, System.Func<TSource, TElement> elementSelector) { throw null; }
         public static System.Linq.ILookup<TKey, TElement> ToLookup<TSource, TKey, TElement>(this System.Collections.Generic.IEnumerable<TSource> source, System.Func<TSource, TKey> keySelector, System.Func<TSource, TElement> elementSelector, System.Collections.Generic.IEqualityComparer<TKey>? comparer) { throw null; }
+        public static bool TryGetNonEnumeratedCount<TSource>(this System.Collections.Generic.IEnumerable<TSource> source, out int count) { throw null; }
         public static System.Collections.Generic.IEnumerable<TSource> Union<TSource>(this System.Collections.Generic.IEnumerable<TSource> first, System.Collections.Generic.IEnumerable<TSource> second) { throw null; }
         public static System.Collections.Generic.IEnumerable<TSource> Union<TSource>(this System.Collections.Generic.IEnumerable<TSource> first, System.Collections.Generic.IEnumerable<TSource> second, System.Collections.Generic.IEqualityComparer<TSource>? comparer) { throw null; }
         public static System.Collections.Generic.IEnumerable<TSource> Where<TSource>(this System.Collections.Generic.IEnumerable<TSource> source, System.Func<TSource, bool> predicate) { throw null; }
index 93aa590..83176f0 100644 (file)
@@ -13,7 +13,7 @@ namespace System.Linq
             public override int GetCount(bool onlyIfCheap)
             {
                 int firstCount, secondCount;
-                if (!EnumerableHelpers.TryGetCount(_first, out firstCount))
+                if (!_first.TryGetNonEnumeratedCount(out firstCount))
                 {
                     if (onlyIfCheap)
                     {
@@ -23,7 +23,7 @@ namespace System.Linq
                     firstCount = _first.Count();
                 }
 
-                if (!EnumerableHelpers.TryGetCount(_second, out secondCount))
+                if (!_second.TryGetNonEnumeratedCount(out secondCount))
                 {
                     if (onlyIfCheap)
                     {
index c34b9af..14f3d45 100644 (file)
@@ -72,6 +72,59 @@ namespace System.Linq
             return count;
         }
 
+        /// <summary>
+        ///   Attempts to determine the number of elements in a sequence without forcing an enumeration.
+        /// </summary>
+        /// <typeparam name="TSource">The type of the elements of <paramref name="source" />.</typeparam>
+        /// <param name="source">A sequence that contains elements to be counted.</param>
+        /// <param name="count">
+        ///     When this method returns, contains the count of <paramref name="source" /> if successful,
+        ///     or zero if the method failed to determine the count.</param>
+        /// <returns>
+        ///   <see langword="true" /> if the count of <paramref name="source"/> can be determined without enumeration;
+        ///   otherwise, <see langword="false" />.
+        /// </returns>
+        /// <remarks>
+        ///   The method performs a series of type tests, identifying common subtypes whose
+        ///   count can be determined without enumerating; this includes <see cref="ICollection{T}"/>,
+        ///   <see cref="ICollection"/> as well as internal types used in the LINQ implementation.
+        ///
+        ///   The method is typically a constant-time operation, but ultimately this depends on the complexity
+        ///   characteristics of the underlying collection implementation.
+        /// </remarks>
+        public static bool TryGetNonEnumeratedCount<TSource>(this IEnumerable<TSource> source, out int count)
+        {
+            if (source == null)
+            {
+                ThrowHelper.ThrowArgumentNullException(ExceptionArgument.source);
+            }
+
+            if (source is ICollection<TSource> collectionoft)
+            {
+                count = collectionoft.Count;
+                return true;
+            }
+
+            if (source is IIListProvider<TSource> listProv)
+            {
+                int c = listProv.GetCount(onlyIfCheap: true);
+                if (c >= 0)
+                {
+                    count = c;
+                    return true;
+                }
+            }
+
+            if (source is ICollection collection)
+            {
+                count = collection.Count;
+                return true;
+            }
+
+            count = 0;
+            return false;
+        }
+
         public static long LongCount<TSource>(this IEnumerable<TSource> source)
         {
             if (source == null)
index 35e59f3..7ad8084 100644 (file)
@@ -41,9 +41,10 @@ namespace System.Linq.Tests
                 nameof(Enumerable.ToArray),
                 nameof(Enumerable.AsEnumerable),
                 nameof(Enumerable.ToList),
+                nameof(Enumerable.ToHashSet),
+                nameof(Enumerable.TryGetNonEnumeratedCount),
                 "Fold",
                 "LeftJoin",
-                "ToHashSet"
             };
 
             return result;
index ea6c449..aee562d 100644 (file)
@@ -126,5 +126,89 @@ namespace System.Linq.Tests
             Func<int, bool> predicate = null;
             AssertExtensions.Throws<ArgumentNullException>("predicate", () => Enumerable.Range(0, 3).Count(predicate));
         }
+
+        [Fact]
+        public void NonEnumeratingCount_NullSource_ThrowsArgumentNullException()
+        {
+            AssertExtensions.Throws<ArgumentNullException>("source", () => ((IEnumerable<int>)null).TryGetNonEnumeratedCount(out _));
+        }
+
+        [Theory]
+        [MemberData(nameof(NonEnumeratingCount_SupportedEnumerables))]
+        public void NonEnumeratingCount_SupportedEnumerables_ShouldReturnExpectedCount<T>(int expectedCount, IEnumerable<T> source)
+        {
+            Assert.True(source.TryGetNonEnumeratedCount(out int actualCount));
+            Assert.Equal(expectedCount, actualCount);
+        }
+
+        [Theory]
+        [MemberData(nameof(NonEnumeratingCount_UnsupportedEnumerables))]
+        public void NonEnumeratingCount_UnsupportedEnumerables_ShouldReturnFalse<T>(IEnumerable<T> source)
+        {
+            Assert.False(source.TryGetNonEnumeratedCount(out int actualCount));
+            Assert.Equal(0, actualCount);
+        }
+
+        [Fact]
+        public void NonEnumeratingCount_ShouldNotEnumerateSource()
+        {
+            bool isEnumerated = false;
+            Assert.False(Source().TryGetNonEnumeratedCount(out int count));
+            Assert.Equal(0, count);
+            Assert.False(isEnumerated);
+
+            IEnumerable<int> Source()
+            {
+                isEnumerated = true;
+                yield return 42;
+            }
+        }
+
+        public static IEnumerable<object[]> NonEnumeratingCount_SupportedEnumerables()
+        {
+            yield return WrapArgs(4, new int[]{ 1, 2, 3, 4 });
+            yield return WrapArgs(4, new List<int>(new int[] { 1, 2, 3, 4 }));
+            yield return WrapArgs(4, new Stack<int>(new int[] { 1, 2, 3, 4 }));
+
+            yield return WrapArgs(0, Enumerable.Empty<string>());
+
+            if (PlatformDetection.IsSpeedOptimized)
+            {
+                yield return WrapArgs(100, Enumerable.Range(1, 100));
+                yield return WrapArgs(80, Enumerable.Repeat(1, 80));
+                yield return WrapArgs(50, Enumerable.Range(1, 50).Select(x => x + 1));
+                yield return WrapArgs(4, new int[] { 1, 2, 3, 4 }.Select(x => x + 1));
+                yield return WrapArgs(50, Enumerable.Range(1, 50).Select(x => x + 1).Select(x => x - 1));
+                yield return WrapArgs(7, Enumerable.Range(1, 20).ToLookup(x => x % 7));
+                yield return WrapArgs(20, Enumerable.Range(1, 20).Reverse());
+                yield return WrapArgs(20, Enumerable.Range(1, 20).OrderBy(x => -x));
+                yield return WrapArgs(20, Enumerable.Range(1, 10).Concat(Enumerable.Range(11, 10)));
+            }
+
+            static object[] WrapArgs<T>(int expectedCount, IEnumerable<T> source) => new object[] { expectedCount, source };
+        }
+
+        public static IEnumerable<object[]> NonEnumeratingCount_UnsupportedEnumerables()
+        {
+            yield return WrapArgs(Enumerable.Range(1, 100).Where(x => x % 2 == 0));
+            yield return WrapArgs(Enumerable.Range(1, 100).GroupBy(x => x % 2 == 0));
+            yield return WrapArgs(new Stack<int>(new int[] { 1, 2, 3, 4 }).Select(x => x + 1));
+            yield return WrapArgs(Enumerable.Range(1, 100).Distinct());
+
+            if (!PlatformDetection.IsSpeedOptimized)
+            {
+                yield return WrapArgs(Enumerable.Range(1, 100));
+                yield return WrapArgs(Enumerable.Repeat(1, 80));
+                yield return WrapArgs(Enumerable.Range(1, 50).Select(x => x + 1));
+                yield return WrapArgs(new int[] { 1, 2, 3, 4 }.Select(x => x + 1));            
+                yield return WrapArgs(Enumerable.Range(1, 50).Select(x => x + 1).Select(x => x - 1));
+                yield return WrapArgs(Enumerable.Range(1, 20).ToLookup(x => x % 7));
+                yield return WrapArgs(Enumerable.Range(1, 20).Reverse());
+                yield return WrapArgs(Enumerable.Range(1, 20).OrderBy(x => -x));
+                yield return WrapArgs(Enumerable.Range(1, 10).Concat(Enumerable.Range(11, 10)));
+            }
+
+            static object[] WrapArgs<T>(IEnumerable<T> source) => new object[] { source };
+        }
     }
 }