Add DataflowBlock.ReceiveAllAsync extension method (#37876)
authorEmmanuel André <2341261+manandre@users.noreply.github.com>
Sat, 6 Mar 2021 11:56:24 +0000 (12:56 +0100)
committerGitHub <noreply@github.com>
Sat, 6 Mar 2021 11:56:24 +0000 (06:56 -0500)
src/libraries/System.Threading.Tasks.Dataflow/ref/System.Threading.Tasks.Dataflow.csproj
src/libraries/System.Threading.Tasks.Dataflow/ref/System.Threading.Tasks.Dataflow.netstandard21.cs [new file with mode: 0644]
src/libraries/System.Threading.Tasks.Dataflow/src/Base/DataflowBlock.IAsyncEnumerable.cs [new file with mode: 0644]
src/libraries/System.Threading.Tasks.Dataflow/src/Base/DataflowBlock.cs
src/libraries/System.Threading.Tasks.Dataflow/src/System.Threading.Tasks.Dataflow.csproj
src/libraries/System.Threading.Tasks.Dataflow/tests/Dataflow/DataflowBlockExtensionTests.IAsyncEnumerable.cs [new file with mode: 0644]
src/libraries/System.Threading.Tasks.Dataflow/tests/Dataflow/DataflowBlockExtensionTests.cs
src/libraries/System.Threading.Tasks.Dataflow/tests/System.Threading.Tasks.Dataflow.Tests.csproj

index 7c65f89..d24314b 100644 (file)
@@ -1,9 +1,12 @@
 <Project Sdk="Microsoft.NET.Sdk">
   <PropertyGroup>
-    <TargetFrameworks>netstandard2.0;netstandard1.0;netstandard1.1;net461</TargetFrameworks>
+    <TargetFrameworks>netstandard2.1;netstandard2.0;netstandard1.0;netstandard1.1;net461</TargetFrameworks>
     <Nullable>enable</Nullable>
   </PropertyGroup>
   <ItemGroup>
     <Compile Include="System.Threading.Tasks.Dataflow.cs" />
   </ItemGroup>
-</Project>
\ No newline at end of file
+  <ItemGroup Condition="'$(TargetFramework)' == 'netstandard2.1'">
+    <Compile Include="System.Threading.Tasks.Dataflow.netstandard21.cs" />
+  </ItemGroup>
+</Project>
diff --git a/src/libraries/System.Threading.Tasks.Dataflow/ref/System.Threading.Tasks.Dataflow.netstandard21.cs b/src/libraries/System.Threading.Tasks.Dataflow/ref/System.Threading.Tasks.Dataflow.netstandard21.cs
new file mode 100644 (file)
index 0000000..ed78964
--- /dev/null
@@ -0,0 +1,13 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// ------------------------------------------------------------------------------
+// Changes to this file must follow the https://aka.ms/api-review process.
+// ------------------------------------------------------------------------------
+
+namespace System.Threading.Tasks.Dataflow
+{
+    public static partial class DataflowBlock
+    {
+        public static System.Collections.Generic.IAsyncEnumerable<TOutput> ReceiveAllAsync<TOutput>(this System.Threading.Tasks.Dataflow.IReceivableSourceBlock<TOutput> source, System.Threading.CancellationToken cancellationToken = default) { throw null; }
+    }
+}
diff --git a/src/libraries/System.Threading.Tasks.Dataflow/src/Base/DataflowBlock.IAsyncEnumerable.cs b/src/libraries/System.Threading.Tasks.Dataflow/src/Base/DataflowBlock.IAsyncEnumerable.cs
new file mode 100644 (file)
index 0000000..972a5c9
--- /dev/null
@@ -0,0 +1,38 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using System.Collections.Generic;
+using System.Runtime.CompilerServices;
+
+namespace System.Threading.Tasks.Dataflow
+{
+    public static partial class DataflowBlock
+    {
+        /// <summary>Creates an <see cref="IAsyncEnumerable{TOutput}"/> that enables receiving all of the data from the source.</summary>
+        /// <typeparam name="TOutput">Specifies the type of data contained in the source.</typeparam>
+        /// <param name="source">The source from which to asynchronously receive.</param>
+        /// <param name="cancellationToken">The <see cref="System.Threading.CancellationToken"/> which may be used to cancel the receive operation.</param>
+        /// <returns>The created async enumerable.</returns>
+        /// <exception cref="System.ArgumentNullException">The <paramref name="source"/> is null (Nothing in Visual Basic).</exception>
+        public static IAsyncEnumerable<TOutput> ReceiveAllAsync<TOutput>(this IReceivableSourceBlock<TOutput> source, CancellationToken cancellationToken = default)
+        {
+            if (source == null)
+            {
+                throw new ArgumentNullException(nameof(source));
+            }
+
+            return ReceiveAllAsyncCore(source, cancellationToken);
+
+            static async IAsyncEnumerable<TOutput> ReceiveAllAsyncCore(IReceivableSourceBlock<TOutput> source, [EnumeratorCancellation] CancellationToken cancellationToken)
+            {
+                while (await source.OutputAvailableAsync(cancellationToken).ConfigureAwait(false))
+                {
+                    while (source.TryReceive(out TOutput? item))
+                    {
+                        yield return item;
+                    }
+                }
+            }
+        }
+    }
+}
index e3393d1..cfc2ebf 100644 (file)
@@ -25,7 +25,7 @@ namespace System.Threading.Tasks.Dataflow
     /// <summary>
     /// Provides a set of static (Shared in Visual Basic) methods for working with dataflow blocks.
     /// </summary>
-    public static class DataflowBlock
+    public static partial class DataflowBlock
     {
         #region LinkTo
         /// <summary>Links the <see cref="ISourceBlock{TOutput}"/> to the specified <see cref="ITargetBlock{TOutput}"/>.</summary>
index 7b324e8..fa21b7c 100644 (file)
@@ -1,6 +1,6 @@
 <Project Sdk="Microsoft.NET.Sdk">
   <PropertyGroup>
-    <TargetFrameworks>$(NetCoreAppCurrent);netstandard2.0;netstandard1.0;netstandard1.1;net461</TargetFrameworks>
+    <TargetFrameworks>$(NetCoreAppCurrent);netstandard2.1;netstandard2.0;netstandard1.0;netstandard1.1;net461</TargetFrameworks>
     <ExcludeCurrentNetCoreAppFromPackage>true</ExcludeCurrentNetCoreAppFromPackage>
     <Nullable>enable</Nullable>
   </PropertyGroup>
@@ -55,6 +55,9 @@
     <Compile Include="Internal\ConcurrentQueue.cs" />
     <Compile Include="Internal\IProducerConsumerCollection.cs" />
   </ItemGroup>
+  <ItemGroup Condition="'$(TargetFramework)' == 'netstandard2.1' or '$(TargetFramework)' == '$(NetCoreAppCurrent)'">
+    <Compile Include="Base\DataflowBlock.IAsyncEnumerable.cs" />
+  </ItemGroup>
   <ItemGroup>
     <None Include="XmlDocs\CommonXmlDocComments.xml" />
   </ItemGroup>
diff --git a/src/libraries/System.Threading.Tasks.Dataflow/tests/Dataflow/DataflowBlockExtensionTests.IAsyncEnumerable.cs b/src/libraries/System.Threading.Tasks.Dataflow/tests/Dataflow/DataflowBlockExtensionTests.IAsyncEnumerable.cs
new file mode 100644 (file)
index 0000000..2990371
--- /dev/null
@@ -0,0 +1,264 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using System.Collections.Generic;
+using Xunit;
+
+namespace System.Threading.Tasks.Dataflow.Tests
+{
+    public partial class DataflowBlockExtensionTests
+    {
+        [Fact]
+        public void ReceiveAllAsync_ArgumentValidation()
+        {
+            AssertExtensions.Throws<ArgumentNullException>("source", () => ((IReceivableSourceBlock<int>)null).ReceiveAllAsync());
+            AssertExtensions.Throws<ArgumentNullException>("source", () => ((IReceivableSourceBlock<int>)null).ReceiveAllAsync(new CancellationToken(true)));
+        }
+
+        [Fact]
+        public void ReceiveAllAsync_NotIdempotent()
+        {
+            var source = new BufferBlock<int>();
+            IAsyncEnumerable<int> e = source.ReceiveAllAsync();
+            Assert.NotNull(e);
+            Assert.NotSame(e, source.ReceiveAllAsync());
+        }
+
+        [Fact]
+        public async Task ReceiveAllAsync_UseMoveNextAsyncAfterCompleted_ReturnsFalse()
+        {
+            var source = new BufferBlock<int>();
+            IAsyncEnumerator<int> e = source.ReceiveAllAsync().GetAsyncEnumerator();
+
+            ValueTask<bool> vt = e.MoveNextAsync();
+            Assert.False(vt.IsCompleted);
+            source.Complete();
+            Assert.False(await vt);
+
+            vt = e.MoveNextAsync();
+            Assert.True(vt.IsCompletedSuccessfully);
+            Assert.False(vt.Result);
+        }
+
+        [Fact]
+        public void ReceiveAllAsync_AvailableDataCompletesSynchronously()
+        {
+            var source = new BufferBlock<int>();
+
+            IAsyncEnumerator<int> e = source.ReceiveAllAsync().GetAsyncEnumerator();
+            try
+            {
+                for (int i = 100; i < 110; i++)
+                {
+                    Assert.True(source.Post(i));
+                    ValueTask<bool> vt = e.MoveNextAsync();
+                    Assert.True(vt.IsCompletedSuccessfully);
+                    Assert.True(vt.Result);
+                    Assert.Equal(i, e.Current);
+                }
+            }
+            finally
+            {
+                ValueTask vt = e.DisposeAsync();
+                Assert.True(vt.IsCompletedSuccessfully);
+                vt.GetAwaiter().GetResult();
+            }
+        }
+
+        [Fact]
+        public async Task ReceiveAllAsync_UnavailableDataCompletesAsynchronously()
+        {
+            var source = new BufferBlock<int>();
+
+            IAsyncEnumerator<int> e = source.ReceiveAllAsync().GetAsyncEnumerator();
+            try
+            {
+                for (int i = 100; i < 110; i++)
+                {
+                    ValueTask<bool> vt = e.MoveNextAsync();
+                    Assert.False(vt.IsCompleted);
+                    Task producer = Task.Run(() => source.Post(i));
+                    Assert.True(await vt);
+                    await producer;
+                    Assert.Equal(i, e.Current);
+                }
+            }
+            finally
+            {
+                ValueTask vt = e.DisposeAsync();
+                Assert.True(vt.IsCompletedSuccessfully);
+                vt.GetAwaiter().GetResult();
+            }
+        }
+
+        [Theory]
+        [InlineData(0)]
+        [InlineData(1)]
+        [InlineData(128)]
+        public async Task ReceiveAllAsync_ProducerConsumer_ConsumesAllData(int items)
+        {
+            var source = new BufferBlock<int>();
+
+            int producedTotal = 0, consumedTotal = 0;
+            await Task.WhenAll(
+                Task.Run(async () =>
+                {
+                    for (int i = 0; i < items; i++)
+                    {
+                        await source.SendAsync(i);
+                        producedTotal += i;
+                    }
+                    source.Complete();
+                }),
+                Task.Run(async () =>
+                {
+                    IAsyncEnumerator<int> e = source.ReceiveAllAsync().GetAsyncEnumerator();
+                    try
+                    {
+                        while (await e.MoveNextAsync())
+                        {
+                            consumedTotal += e.Current;
+                        }
+                    }
+                    finally
+                    {
+                        await e.DisposeAsync();
+                    }
+                }));
+
+            Assert.Equal(producedTotal, consumedTotal);
+        }
+
+        [Fact]
+        public async Task ReceiveAllAsync_MultipleEnumerationsToEnd()
+        {
+            var source = new BufferBlock<int>();
+
+            Assert.True(source.Post(42));
+            source.Complete();
+
+            IAsyncEnumerable<int> enumerable = source.ReceiveAllAsync();
+            IAsyncEnumerator<int> e = enumerable.GetAsyncEnumerator();
+
+            Assert.True(await e.MoveNextAsync());
+            Assert.Equal(42, e.Current);
+
+            Assert.False(await e.MoveNextAsync());
+            Assert.False(await e.MoveNextAsync());
+
+            await e.DisposeAsync();
+
+            e = enumerable.GetAsyncEnumerator();
+
+            Assert.False(await e.MoveNextAsync());
+            Assert.False(await e.MoveNextAsync());
+        }
+
+        [Theory]
+        [InlineData(false, false)]
+        [InlineData(false, true)]
+        [InlineData(true, false)]
+        [InlineData(true, true)]
+        public void ReceiveAllAsync_MultipleSingleElementEnumerations_AllItemsEnumerated(bool sameEnumerable, bool dispose)
+        {
+            var source = new BufferBlock<int>();
+            IAsyncEnumerable<int> enumerable = source.ReceiveAllAsync();
+
+            for (int i = 0; i < 10; i++)
+            {
+                Assert.True(source.Post(i));
+                IAsyncEnumerator<int> e = (sameEnumerable ? enumerable : source.ReceiveAllAsync()).GetAsyncEnumerator();
+                ValueTask<bool> vt = e.MoveNextAsync();
+                Assert.True(vt.IsCompletedSuccessfully);
+                Assert.True(vt.Result);
+                Assert.Equal(i, e.Current);
+                if (dispose)
+                {
+                    ValueTask dvt = e.DisposeAsync();
+                    Assert.True(dvt.IsCompletedSuccessfully);
+                    dvt.GetAwaiter().GetResult();
+                }
+            }
+        }
+
+        [Theory]
+        [InlineData(false)]
+        [InlineData(true)]
+        public async Task ReceiveAllAsync_DualConcurrentEnumeration_AllItemsEnumerated(bool sameEnumerable)
+        {
+            var source = new BufferBlock<int>();
+
+            IAsyncEnumerable<int> enumerable = source.ReceiveAllAsync();
+
+            IAsyncEnumerator<int> e1 = enumerable.GetAsyncEnumerator();
+            IAsyncEnumerator<int> e2 = (sameEnumerable ? enumerable : source.ReceiveAllAsync()).GetAsyncEnumerator();
+            Assert.NotSame(e1, e2);
+
+            ValueTask<bool> vt1, vt2;
+            int producerTotal = 0, consumerTotal = 0;
+            for (int i = 0; i < 10; i++)
+            {
+                vt1 = e1.MoveNextAsync();
+                vt2 = e2.MoveNextAsync();
+
+                await source.SendAsync(i);
+                producerTotal += i;
+                await source.SendAsync(i * 2);
+                producerTotal += i * 2;
+
+                Assert.True(await vt1);
+                Assert.True(await vt2);
+                consumerTotal += e1.Current;
+                consumerTotal += e2.Current;
+            }
+
+            vt1 = e1.MoveNextAsync();
+            vt2 = e2.MoveNextAsync();
+            source.Complete();
+            Assert.False(await vt1);
+            Assert.False(await vt2);
+
+            Assert.Equal(producerTotal, consumerTotal);
+        }
+
+        [Theory]
+        [InlineData(false)]
+        [InlineData(true)]
+        public async Task ReceiveAllAsync_CanceledBeforeMoveNextAsync_Throws(bool dataAvailable)
+        {
+            var source = new BufferBlock<int>();
+            if (dataAvailable)
+            {
+                Assert.True(source.Post(42));
+            }
+
+            using var cts = new CancellationTokenSource();
+            cts.Cancel();
+
+            IAsyncEnumerator<int> e = source.ReceiveAllAsync(cts.Token).GetAsyncEnumerator();
+            ValueTask<bool> vt = e.MoveNextAsync();
+            Assert.True(vt.IsCompleted);
+            Assert.False(vt.IsCompletedSuccessfully);
+            OperationCanceledException oce = await Assert.ThrowsAnyAsync<OperationCanceledException>(async () => await vt);
+            Assert.Equal(cts.Token, oce.CancellationToken);
+        }
+
+        [Fact]
+        public async Task ReceiveAllAsync_CanceledAfterMoveNextAsync_Throws()
+        {
+            var source = new BufferBlock<int>();
+            using var cts = new CancellationTokenSource();
+
+            IAsyncEnumerator<int> e = source.ReceiveAllAsync(cts.Token).GetAsyncEnumerator();
+            ValueTask<bool> vt = e.MoveNextAsync();
+            Assert.False(vt.IsCompleted);
+
+            cts.Cancel();
+            OperationCanceledException oce = await Assert.ThrowsAnyAsync<OperationCanceledException>(async () => await vt);
+
+            vt = e.MoveNextAsync();
+            Assert.True(vt.IsCompletedSuccessfully);
+            Assert.False(vt.Result);
+        }
+    }
+}
index be3d006..d761c6b 100644 (file)
@@ -8,7 +8,7 @@ using Xunit;
 
 namespace System.Threading.Tasks.Dataflow.Tests
 {
-    public class DataflowBlockExtensionsTests
+    public partial class DataflowBlockExtensionsTests
     {
         [Fact]
         public void TestDataflowMessageHeader()
index f817822..d782247 100644 (file)
@@ -26,6 +26,9 @@
     <Compile Include="$(CommonTestPath)System\Diagnostics\DebuggerAttributes.cs"
              Link="Common\System\Diagnostics\DebuggerAttributes.cs" />
   </ItemGroup>
+  <ItemGroup Condition="'$(TargetFramework)' == '$(NetCoreAppCurrent)'">
+    <Compile Include="Dataflow\DataflowBlockExtensionTests.IAsyncEnumerable.cs" />
+  </ItemGroup>
   <ItemGroup Condition="'$(TargetFramework)' == 'net461'">
     <ProjectReference Include="..\src\System.Threading.Tasks.Dataflow.csproj" />
   </ItemGroup>