internal static partial class EnumerableHelpers
{
/// <summary>
- /// Tries to get the count of the enumerable cheaply.
- /// </summary>
- /// <typeparam name="T">The element type of the source enumerable.</typeparam>
- /// <param name="source">The enumerable to count.</param>
- /// <param name="count">The count of the enumerable, if it could be obtained cheaply.</param>
- /// <returns><c>true</c> if the enumerable could be counted cheaply; otherwise, <c>false</c>.</returns>
- internal static bool TryGetCount<T>(IEnumerable<T> source, out int count)
- {
- Debug.Assert(source != null);
-
- if (source is ICollection<T> collection)
- {
- count = collection.Count;
- return true;
- }
-
- if (source is IIListProvider<T> provider)
- {
- return (count = provider.GetCount(onlyIfCheap: true)) >= 0;
- }
-
- count = -1;
- return false;
- }
-
- /// <summary>
/// Copies items from an enumerable to an array.
/// </summary>
/// <typeparam name="T">The element type of the enumerable.</typeparam>
public bool ReserveOrAdd(IEnumerable<T> items)
{
int itemCount;
- if (EnumerableHelpers.TryGetCount(items, out itemCount))
+ if (System.Linq.Enumerable.TryGetNonEnumeratedCount(items, out itemCount))
{
if (itemCount > 0)
{
typeof(Enumerable),
typeof(Queryable),
new [] {
- "ToLookup",
- "ToDictionary",
- "ToArray",
- "AsEnumerable",
- "ToList",
+ nameof(Enumerable.ToLookup),
+ nameof(Enumerable.ToDictionary),
+ nameof(Enumerable.ToArray),
+ nameof(Enumerable.AsEnumerable),
+ nameof(Enumerable.ToList),
+ nameof(Enumerable.Append),
+ nameof(Enumerable.Prepend),
+ nameof(Enumerable.ToHashSet),
+ nameof(Enumerable.TryGetNonEnumeratedCount),
"Fold",
"LeftJoin",
- "Append",
- "Prepend",
- "ToHashSet"
}
);
typeof(Queryable),
typeof(Enumerable),
new [] {
- "AsQueryable"
+ nameof(Queryable.AsQueryable)
}
);
<PropertyGroup>
<StrongNameKeyId>Microsoft</StrongNameKeyId>
</PropertyGroup>
-</Project>
\ No newline at end of file
+</Project>
public static System.Linq.ILookup<TKey, TSource> ToLookup<TSource, TKey>(this System.Collections.Generic.IEnumerable<TSource> source, System.Func<TSource, TKey> keySelector, System.Collections.Generic.IEqualityComparer<TKey>? comparer) { throw null; }
public static System.Linq.ILookup<TKey, TElement> ToLookup<TSource, TKey, TElement>(this System.Collections.Generic.IEnumerable<TSource> source, System.Func<TSource, TKey> keySelector, System.Func<TSource, TElement> elementSelector) { throw null; }
public static System.Linq.ILookup<TKey, TElement> ToLookup<TSource, TKey, TElement>(this System.Collections.Generic.IEnumerable<TSource> source, System.Func<TSource, TKey> keySelector, System.Func<TSource, TElement> elementSelector, System.Collections.Generic.IEqualityComparer<TKey>? comparer) { throw null; }
+ public static bool TryGetNonEnumeratedCount<TSource>(this System.Collections.Generic.IEnumerable<TSource> source, out int count) { throw null; }
public static System.Collections.Generic.IEnumerable<TSource> Union<TSource>(this System.Collections.Generic.IEnumerable<TSource> first, System.Collections.Generic.IEnumerable<TSource> second) { throw null; }
public static System.Collections.Generic.IEnumerable<TSource> Union<TSource>(this System.Collections.Generic.IEnumerable<TSource> first, System.Collections.Generic.IEnumerable<TSource> second, System.Collections.Generic.IEqualityComparer<TSource>? comparer) { throw null; }
public static System.Collections.Generic.IEnumerable<TSource> Where<TSource>(this System.Collections.Generic.IEnumerable<TSource> source, System.Func<TSource, bool> predicate) { throw null; }
public override int GetCount(bool onlyIfCheap)
{
int firstCount, secondCount;
- if (!EnumerableHelpers.TryGetCount(_first, out firstCount))
+ if (!_first.TryGetNonEnumeratedCount(out firstCount))
{
if (onlyIfCheap)
{
firstCount = _first.Count();
}
- if (!EnumerableHelpers.TryGetCount(_second, out secondCount))
+ if (!_second.TryGetNonEnumeratedCount(out secondCount))
{
if (onlyIfCheap)
{
return count;
}
+ /// <summary>
+ /// Attempts to determine the number of elements in a sequence without forcing an enumeration.
+ /// </summary>
+ /// <typeparam name="TSource">The type of the elements of <paramref name="source" />.</typeparam>
+ /// <param name="source">A sequence that contains elements to be counted.</param>
+ /// <param name="count">
+ /// When this method returns, contains the count of <paramref name="source" /> if successful,
+ /// or zero if the method failed to determine the count.</param>
+ /// <returns>
+ /// <see langword="true" /> if the count of <paramref name="source"/> can be determined without enumeration;
+ /// otherwise, <see langword="false" />.
+ /// </returns>
+ /// <remarks>
+ /// The method performs a series of type tests, identifying common subtypes whose
+ /// count can be determined without enumerating; this includes <see cref="ICollection{T}"/>,
+ /// <see cref="ICollection"/> as well as internal types used in the LINQ implementation.
+ ///
+ /// The method is typically a constant-time operation, but ultimately this depends on the complexity
+ /// characteristics of the underlying collection implementation.
+ /// </remarks>
+ public static bool TryGetNonEnumeratedCount<TSource>(this IEnumerable<TSource> source, out int count)
+ {
+ if (source == null)
+ {
+ ThrowHelper.ThrowArgumentNullException(ExceptionArgument.source);
+ }
+
+ if (source is ICollection<TSource> collectionoft)
+ {
+ count = collectionoft.Count;
+ return true;
+ }
+
+ if (source is IIListProvider<TSource> listProv)
+ {
+ int c = listProv.GetCount(onlyIfCheap: true);
+ if (c >= 0)
+ {
+ count = c;
+ return true;
+ }
+ }
+
+ if (source is ICollection collection)
+ {
+ count = collection.Count;
+ return true;
+ }
+
+ count = 0;
+ return false;
+ }
+
public static long LongCount<TSource>(this IEnumerable<TSource> source)
{
if (source == null)
nameof(Enumerable.ToArray),
nameof(Enumerable.AsEnumerable),
nameof(Enumerable.ToList),
+ nameof(Enumerable.ToHashSet),
+ nameof(Enumerable.TryGetNonEnumeratedCount),
"Fold",
"LeftJoin",
- "ToHashSet"
};
return result;
Func<int, bool> predicate = null;
AssertExtensions.Throws<ArgumentNullException>("predicate", () => Enumerable.Range(0, 3).Count(predicate));
}
+
+ [Fact]
+ public void NonEnumeratingCount_NullSource_ThrowsArgumentNullException()
+ {
+ AssertExtensions.Throws<ArgumentNullException>("source", () => ((IEnumerable<int>)null).TryGetNonEnumeratedCount(out _));
+ }
+
+ [Theory]
+ [MemberData(nameof(NonEnumeratingCount_SupportedEnumerables))]
+ public void NonEnumeratingCount_SupportedEnumerables_ShouldReturnExpectedCount<T>(int expectedCount, IEnumerable<T> source)
+ {
+ Assert.True(source.TryGetNonEnumeratedCount(out int actualCount));
+ Assert.Equal(expectedCount, actualCount);
+ }
+
+ [Theory]
+ [MemberData(nameof(NonEnumeratingCount_UnsupportedEnumerables))]
+ public void NonEnumeratingCount_UnsupportedEnumerables_ShouldReturnFalse<T>(IEnumerable<T> source)
+ {
+ Assert.False(source.TryGetNonEnumeratedCount(out int actualCount));
+ Assert.Equal(0, actualCount);
+ }
+
+ [Fact]
+ public void NonEnumeratingCount_ShouldNotEnumerateSource()
+ {
+ bool isEnumerated = false;
+ Assert.False(Source().TryGetNonEnumeratedCount(out int count));
+ Assert.Equal(0, count);
+ Assert.False(isEnumerated);
+
+ IEnumerable<int> Source()
+ {
+ isEnumerated = true;
+ yield return 42;
+ }
+ }
+
+ public static IEnumerable<object[]> NonEnumeratingCount_SupportedEnumerables()
+ {
+ yield return WrapArgs(4, new int[]{ 1, 2, 3, 4 });
+ yield return WrapArgs(4, new List<int>(new int[] { 1, 2, 3, 4 }));
+ yield return WrapArgs(4, new Stack<int>(new int[] { 1, 2, 3, 4 }));
+
+ yield return WrapArgs(0, Enumerable.Empty<string>());
+
+ if (PlatformDetection.IsSpeedOptimized)
+ {
+ yield return WrapArgs(100, Enumerable.Range(1, 100));
+ yield return WrapArgs(80, Enumerable.Repeat(1, 80));
+ yield return WrapArgs(50, Enumerable.Range(1, 50).Select(x => x + 1));
+ yield return WrapArgs(4, new int[] { 1, 2, 3, 4 }.Select(x => x + 1));
+ yield return WrapArgs(50, Enumerable.Range(1, 50).Select(x => x + 1).Select(x => x - 1));
+ yield return WrapArgs(7, Enumerable.Range(1, 20).ToLookup(x => x % 7));
+ yield return WrapArgs(20, Enumerable.Range(1, 20).Reverse());
+ yield return WrapArgs(20, Enumerable.Range(1, 20).OrderBy(x => -x));
+ yield return WrapArgs(20, Enumerable.Range(1, 10).Concat(Enumerable.Range(11, 10)));
+ }
+
+ static object[] WrapArgs<T>(int expectedCount, IEnumerable<T> source) => new object[] { expectedCount, source };
+ }
+
+ public static IEnumerable<object[]> NonEnumeratingCount_UnsupportedEnumerables()
+ {
+ yield return WrapArgs(Enumerable.Range(1, 100).Where(x => x % 2 == 0));
+ yield return WrapArgs(Enumerable.Range(1, 100).GroupBy(x => x % 2 == 0));
+ yield return WrapArgs(new Stack<int>(new int[] { 1, 2, 3, 4 }).Select(x => x + 1));
+ yield return WrapArgs(Enumerable.Range(1, 100).Distinct());
+
+ if (!PlatformDetection.IsSpeedOptimized)
+ {
+ yield return WrapArgs(Enumerable.Range(1, 100));
+ yield return WrapArgs(Enumerable.Repeat(1, 80));
+ yield return WrapArgs(Enumerable.Range(1, 50).Select(x => x + 1));
+ yield return WrapArgs(new int[] { 1, 2, 3, 4 }.Select(x => x + 1));
+ yield return WrapArgs(Enumerable.Range(1, 50).Select(x => x + 1).Select(x => x - 1));
+ yield return WrapArgs(Enumerable.Range(1, 20).ToLookup(x => x % 7));
+ yield return WrapArgs(Enumerable.Range(1, 20).Reverse());
+ yield return WrapArgs(Enumerable.Range(1, 20).OrderBy(x => -x));
+ yield return WrapArgs(Enumerable.Range(1, 10).Concat(Enumerable.Range(11, 10)));
+ }
+
+ static object[] WrapArgs<T>(IEnumerable<T> source) => new object[] { source };
+ }
}
}