Expose/test WithCancellation for async enumerables
authorStephen Toub <stoub@microsoft.com>
Thu, 10 Jan 2019 20:31:51 +0000 (15:31 -0500)
committerStephen Toub <stoub@microsoft.com>
Sat, 12 Jan 2019 14:39:07 +0000 (09:39 -0500)
Commit migrated from https://github.com/dotnet/corefx/commit/74a04273fb5219210c69d6d117fc29af26017721

src/libraries/System.Threading.Tasks/ref/System.Threading.Tasks.cs
src/libraries/System.Threading.Tasks/tests/System.Runtime.CompilerServices/ConfiguredAsyncEnumerableTests.netcoreapp.cs [deleted file]
src/libraries/System.Threading.Tasks/tests/System.Runtime.CompilerServices/ConfiguredCancelableAsyncEnumerableTests.netcoreapp.cs [new file with mode: 0644]
src/libraries/System.Threading.Tasks/tests/System.Threading.Tasks.Tests.csproj

index 6e2b94e29baaea71cd1adae2ff1626934f4984a8..c0f458d61f9eaeb6576ee93d208d45134a36698b 100644 (file)
@@ -87,10 +87,12 @@ namespace System.Runtime.CompilerServices
         public void SetStateMachine(System.Runtime.CompilerServices.IAsyncStateMachine stateMachine) { }
         public void Start<TStateMachine>(ref TStateMachine stateMachine) where TStateMachine : System.Runtime.CompilerServices.IAsyncStateMachine { }
     }
-    public readonly struct ConfiguredAsyncEnumerable<T>
+    public readonly struct ConfiguredCancelableAsyncEnumerable<T>
     {
         private readonly object _dummy;
-        public Enumerator GetAsyncEnumerator(System.Threading.CancellationToken cancellationToken = default) { throw null; }
+        public Enumerator GetAsyncEnumerator() { throw null; }
+        public ConfiguredCancelableAsyncEnumerable<T> ConfigureAwait(bool continueOnCapturedContext) { throw null; }
+        public ConfiguredCancelableAsyncEnumerable<T> WithCancellation(System.Threading.CancellationToken cancellationToken) { throw null; }
         public readonly struct Enumerator
         {
             private readonly object _dummy;
@@ -161,7 +163,8 @@ namespace System.Threading.Tasks
     }
     public static partial class TaskExtensions
     {
-        public static System.Runtime.CompilerServices.ConfiguredAsyncEnumerable<T> ConfigureAwait<T>(this System.Collections.Generic.IAsyncEnumerable<T> source, bool continueOnCapturedContext) { throw null; }
+        public static System.Runtime.CompilerServices.ConfiguredCancelableAsyncEnumerable<T> ConfigureAwait<T>(this System.Collections.Generic.IAsyncEnumerable<T> source, bool continueOnCapturedContext) { throw null; }
+        public static System.Runtime.CompilerServices.ConfiguredCancelableAsyncEnumerable<T> WithCancellation<T>(this System.Collections.Generic.IAsyncEnumerable<T> source, CancellationToken cancellationToken) { throw null; }
         public static System.Threading.Tasks.Task Unwrap(this System.Threading.Tasks.Task<System.Threading.Tasks.Task> task) { throw null; }
         public static System.Threading.Tasks.Task<TResult> Unwrap<TResult>(this System.Threading.Tasks.Task<System.Threading.Tasks.Task<TResult>> task) { throw null; }
     }
diff --git a/src/libraries/System.Threading.Tasks/tests/System.Runtime.CompilerServices/ConfiguredAsyncEnumerableTests.netcoreapp.cs b/src/libraries/System.Threading.Tasks/tests/System.Runtime.CompilerServices/ConfiguredAsyncEnumerableTests.netcoreapp.cs
deleted file mode 100644 (file)
index 0d124e9..0000000
+++ /dev/null
@@ -1,148 +0,0 @@
-// Licensed to the .NET Foundation under one or more agreements.
-// The .NET Foundation licenses this file to you under the MIT license.
-// See the LICENSE file in the project root for more information.
-
-using System.Collections.Generic;
-using System.Linq;
-using System.Threading;
-using System.Threading.Tasks;
-using System.Threading.Tasks.Sources;
-using Xunit;
-
-namespace System.Runtime.CompilerServices.Tests
-{
-    public class ConfiguredAsyncEnumerableTests
-    {
-        [Fact]
-        public void ConfigureAwait_GetAsyncEnumerator_Default_Throws()
-        {
-            ConfiguredAsyncEnumerable<int> e = default;
-            Assert.Throws<NullReferenceException>(() => e.GetAsyncEnumerator());
-
-            e = ((IAsyncEnumerable<int>)null).ConfigureAwait(false);
-            Assert.Throws<NullReferenceException>(() => e.GetAsyncEnumerator());
-        }
-
-        [Fact]
-        public void ConfigureAwait_EnumeratorMembers_Default_Throws()
-        {
-            ConfiguredAsyncEnumerable<int>.Enumerator e = default;
-            Assert.Throws<NullReferenceException>(() => e.MoveNextAsync());
-            Assert.Throws<NullReferenceException>(() => e.Current);
-            Assert.Throws<NullReferenceException>(() => e.DisposeAsync());
-        }
-
-        [Fact]
-        public void ConfigureAwait_GetAsyncEnumerator_CancellationTokenPassedthrough()
-        {
-            var enumerable = new TrackFlagsAsyncEnumerable() { Flags = 0 };
-            var cts = new CancellationTokenSource();
-            ConfiguredAsyncEnumerable<int>.Enumerator enumerator = enumerable.ConfigureAwait(false).GetAsyncEnumerator(cts.Token);
-            Assert.Equal(cts.Token, enumerable.CancellationToken);
-        }
-
-        [Theory]
-        [InlineData(false)]
-        [InlineData(true)]
-        public void ConfigureAwait_AwaitMoveNextAsync_FlagsSetAppropriately(bool continueOnCapturedContext)
-        {
-            var enumerable = new TrackFlagsAsyncEnumerable() { Flags = 0 };
-            ConfiguredAsyncEnumerable<int>.Enumerator enumerator = enumerable.ConfigureAwait(continueOnCapturedContext).GetAsyncEnumerator();
-            ConfiguredValueTaskAwaitable<bool>.ConfiguredValueTaskAwaiter moveNextAwaiter = enumerator.MoveNextAsync().GetAwaiter();
-            moveNextAwaiter.UnsafeOnCompleted(() => { });
-            Assert.Equal(
-                continueOnCapturedContext ? ValueTaskSourceOnCompletedFlags.UseSchedulingContext : ValueTaskSourceOnCompletedFlags.None,
-                enumerable.Flags);
-        }
-
-        [Theory]
-        [InlineData(false)]
-        [InlineData(true)]
-        public void ConfigureAwait_AwaitDisposeAsync_FlagsSetAppropriately(bool continueOnCapturedContext)
-        {
-            var enumerable = new TrackFlagsAsyncEnumerable() { Flags = 0 };
-            ConfiguredAsyncEnumerable<int>.Enumerator enumerator = enumerable.ConfigureAwait(continueOnCapturedContext).GetAsyncEnumerator();
-            ConfiguredValueTaskAwaitable.ConfiguredValueTaskAwaiter disposeAwaiter = enumerator.DisposeAsync().GetAwaiter();
-            disposeAwaiter.UnsafeOnCompleted(() => { });
-            Assert.Equal(
-                continueOnCapturedContext ? ValueTaskSourceOnCompletedFlags.UseSchedulingContext : ValueTaskSourceOnCompletedFlags.None,
-                enumerable.Flags);
-        }
-
-        [Fact]
-        public async Task ConfigureAwait_CanBeEnumeratedWithStandardPattern()
-        {
-            IAsyncEnumerable<int> asyncEnumerable = new EnumerableWithDelayToAsyncEnumerable<int>(Enumerable.Range(1, 10), 1);
-            int sum = 0;
-
-            ConfiguredAsyncEnumerable<int>.Enumerator e = asyncEnumerable.ConfigureAwait(false).GetAsyncEnumerator();
-            try
-            {
-                while (await e.MoveNextAsync())
-                {
-                    sum += e.Current;
-                }
-            }
-            finally
-            {
-                await e.DisposeAsync();
-            }
-
-            Assert.Equal(55, sum);
-        }
-
-        private sealed class TrackFlagsAsyncEnumerable : IAsyncEnumerable<int>, IAsyncEnumerator<int>, IValueTaskSource<bool>, IValueTaskSource
-        {
-            public ValueTaskSourceOnCompletedFlags Flags;
-            public CancellationToken CancellationToken;
-
-            public IAsyncEnumerator<int> GetAsyncEnumerator(CancellationToken cancellationToken = default)
-            {
-                CancellationToken = cancellationToken;
-                return this;
-            }
-
-            public ValueTask<bool> MoveNextAsync() => new ValueTask<bool>(this, 0);
-            public int Current => throw new NotImplementedException();
-            public ValueTask DisposeAsync() => new ValueTask(this, 0);
-
-            public void OnCompleted(Action<object> continuation, object state, short token, ValueTaskSourceOnCompletedFlags flags) => Flags = flags;
-            public ValueTaskSourceStatus GetStatus(short token) => ValueTaskSourceStatus.Pending;
-            public bool GetResult(short token) => throw new NotImplementedException();
-            void IValueTaskSource.GetResult(short token) => throw new NotImplementedException();
-        }
-
-        private sealed class EnumerableWithDelayToAsyncEnumerable<T> : IAsyncEnumerable<T>, IAsyncEnumerator<T>
-        {
-            private readonly int _delayMs;
-            private readonly IEnumerable<T> _enumerable;
-            private IEnumerator<T> _enumerator;
-
-            public EnumerableWithDelayToAsyncEnumerable(IEnumerable<T> enumerable, int delayMs)
-            {
-                _enumerable = enumerable;
-                _delayMs = delayMs;
-            }
-
-            public IAsyncEnumerator<T> GetAsyncEnumerator(CancellationToken cancellationToken = default)
-            {
-                _enumerator = _enumerable.GetEnumerator();
-                return this;
-            }
-
-            public async ValueTask<bool> MoveNextAsync()
-            {
-                await Task.Delay(_delayMs);
-                return _enumerator.MoveNext();
-            }
-
-            public T Current => _enumerator.Current;
-
-            public async ValueTask DisposeAsync()
-            {
-                await Task.Delay(_delayMs);
-                _enumerator.Dispose();
-            }
-        }
-    }
-}
diff --git a/src/libraries/System.Threading.Tasks/tests/System.Runtime.CompilerServices/ConfiguredCancelableAsyncEnumerableTests.netcoreapp.cs b/src/libraries/System.Threading.Tasks/tests/System.Runtime.CompilerServices/ConfiguredCancelableAsyncEnumerableTests.netcoreapp.cs
new file mode 100644 (file)
index 0000000..1345543
--- /dev/null
@@ -0,0 +1,216 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System.Collections.Generic;
+using System.Linq;
+using System.Threading;
+using System.Threading.Tasks;
+using System.Threading.Tasks.Sources;
+using Xunit;
+
+namespace System.Runtime.CompilerServices.Tests
+{
+    public class ConfiguredCancelableAsyncEnumerableTests
+    {
+        [Fact]
+        public void Default_GetAsyncEnumerator_Throws()
+        {
+            ConfiguredCancelableAsyncEnumerable<int> e = default;
+            Assert.Throws<NullReferenceException>(() => e.GetAsyncEnumerator());
+
+            e = ((IAsyncEnumerable<int>)null).ConfigureAwait(false);
+            Assert.Throws<NullReferenceException>(() => e.GetAsyncEnumerator());
+        }
+
+        [Fact]
+        public void Default_EnumeratorMembers_Throws()
+        {
+            ConfiguredCancelableAsyncEnumerable<int>.Enumerator e = default;
+            Assert.Throws<NullReferenceException>(() => e.MoveNextAsync());
+            Assert.Throws<NullReferenceException>(() => e.Current);
+            Assert.Throws<NullReferenceException>(() => e.DisposeAsync());
+        }
+
+        [Fact]
+        public void Default_WithCancellation_ConfigureAwait_NoThrow()
+        {
+            ConfiguredCancelableAsyncEnumerable<int> e = TaskExtensions.WithCancellation((IAsyncEnumerable<int>)null, default);
+            e = e.ConfigureAwait(false);
+            e = e.WithCancellation(default);
+            Assert.Throws<NullReferenceException>(() => e.GetAsyncEnumerator());
+        }
+
+        [Fact]
+        public void Default_ConfigureAwait_WithCancellation_NoThrow()
+        {
+            ConfiguredCancelableAsyncEnumerable<int> e = TaskExtensions.ConfigureAwait((IAsyncEnumerable<int>)null, false);
+            e = e.WithCancellation(default);
+            e = e.ConfigureAwait(false);
+            Assert.Throws<NullReferenceException>(() => e.GetAsyncEnumerator());
+        }
+
+        [Theory]
+        [InlineData(false)]
+        [InlineData(true)]
+        public void ConfigureAwait_AwaitMoveNextAsync_FlagsSetAppropriately(bool continueOnCapturedContext)
+        {
+            TrackFlagsAsyncEnumerable enumerable;
+            CancellationToken token = new CancellationTokenSource().Token;
+
+            // Single ConfigureAwait call
+            enumerable = new TrackFlagsAsyncEnumerable() { Flags = 0 };
+            enumerable.ConfigureAwait(continueOnCapturedContext).GetAsyncEnumerator().MoveNextAsync().GetAwaiter().UnsafeOnCompleted(() => { });
+            Assert.Equal(
+                continueOnCapturedContext ? ValueTaskSourceOnCompletedFlags.UseSchedulingContext : ValueTaskSourceOnCompletedFlags.None,
+                enumerable.Flags);
+
+            // Unnecessary multiple calls; only last one is used
+            enumerable = new TrackFlagsAsyncEnumerable() { Flags = 0 };
+            enumerable.ConfigureAwait(!continueOnCapturedContext).ConfigureAwait(continueOnCapturedContext).GetAsyncEnumerator().MoveNextAsync().GetAwaiter().UnsafeOnCompleted(() => { });
+            Assert.Equal(
+                continueOnCapturedContext ? ValueTaskSourceOnCompletedFlags.UseSchedulingContext : ValueTaskSourceOnCompletedFlags.None,
+                enumerable.Flags);
+
+            // After WithCancellation
+            enumerable = new TrackFlagsAsyncEnumerable() { Flags = 0 };
+            enumerable.WithCancellation(token).ConfigureAwait(continueOnCapturedContext).GetAsyncEnumerator().MoveNextAsync().GetAwaiter().UnsafeOnCompleted(() => { });
+            Assert.Equal(
+                continueOnCapturedContext ? ValueTaskSourceOnCompletedFlags.UseSchedulingContext : ValueTaskSourceOnCompletedFlags.None,
+                enumerable.Flags);
+
+            // Before WithCancellation
+            enumerable = new TrackFlagsAsyncEnumerable() { Flags = 0 };
+            enumerable.ConfigureAwait(continueOnCapturedContext).WithCancellation(token).GetAsyncEnumerator().MoveNextAsync().GetAwaiter().UnsafeOnCompleted(() => { });
+            Assert.Equal(
+                continueOnCapturedContext ? ValueTaskSourceOnCompletedFlags.UseSchedulingContext : ValueTaskSourceOnCompletedFlags.None,
+                enumerable.Flags);
+        }
+
+        [Theory]
+        [InlineData(false)]
+        [InlineData(true)]
+        public void ConfigureAwait_AwaitDisposeAsync_FlagsSetAppropriately(bool continueOnCapturedContext)
+        {
+            var enumerable = new TrackFlagsAsyncEnumerable() { Flags = 0 };
+            enumerable.ConfigureAwait(continueOnCapturedContext).GetAsyncEnumerator().DisposeAsync().GetAwaiter().UnsafeOnCompleted(() => { });
+            Assert.Equal(
+                continueOnCapturedContext ? ValueTaskSourceOnCompletedFlags.UseSchedulingContext : ValueTaskSourceOnCompletedFlags.None,
+                enumerable.Flags);
+        }
+
+        [Fact]
+        public async Task CanBeEnumeratedWithStandardPattern()
+        {
+            IAsyncEnumerable<int> asyncEnumerable = new EnumerableWithDelayToAsyncEnumerable<int>(Enumerable.Range(1, 10), 1);
+            int sum = 0;
+
+            ConfiguredCancelableAsyncEnumerable<int>.Enumerator e = asyncEnumerable.ConfigureAwait(false).WithCancellation(new CancellationTokenSource().Token).GetAsyncEnumerator();
+            try
+            {
+                while (await e.MoveNextAsync())
+                {
+                    sum += e.Current;
+                }
+            }
+            finally
+            {
+                await e.DisposeAsync();
+            }
+
+            Assert.Equal(55, sum);
+        }
+
+        [Fact]
+        public void WithCancellation_TokenPassedThrough()
+        {
+            TrackFlagsAsyncEnumerable enumerable;
+            CancellationToken token1 = new CancellationTokenSource().Token;
+            CancellationToken token2 = new CancellationTokenSource().Token;
+
+            // No WithCancellation call
+            enumerable = new TrackFlagsAsyncEnumerable();
+            enumerable.GetAsyncEnumerator();
+            Assert.Equal(CancellationToken.None, enumerable.CancellationToken);
+
+            // Only ConfigureAwait call
+            enumerable = new TrackFlagsAsyncEnumerable();
+            enumerable.ConfigureAwait(false).GetAsyncEnumerator();
+            Assert.Equal(CancellationToken.None, enumerable.CancellationToken);
+
+            // Single WithCancellation call
+            enumerable = new TrackFlagsAsyncEnumerable();
+            enumerable.WithCancellation(token1).GetAsyncEnumerator();
+            Assert.Equal(token1, enumerable.CancellationToken);
+
+            // Unnecessary multiple calls; only last one is used
+            enumerable = new TrackFlagsAsyncEnumerable();
+            enumerable.WithCancellation(token1).WithCancellation(token2).GetAsyncEnumerator();
+            Assert.Equal(token2, enumerable.CancellationToken);
+
+            // Before ConfigureAwait
+            enumerable = new TrackFlagsAsyncEnumerable();
+            enumerable.WithCancellation(token1).ConfigureAwait(false).GetAsyncEnumerator();
+            Assert.Equal(token1, enumerable.CancellationToken);
+
+            // After ConfigureAwait
+            enumerable = new TrackFlagsAsyncEnumerable();
+            enumerable.ConfigureAwait(false).WithCancellation(token1).GetAsyncEnumerator();
+            Assert.Equal(token1, enumerable.CancellationToken);
+        }
+
+        private sealed class TrackFlagsAsyncEnumerable : IAsyncEnumerable<int>, IAsyncEnumerator<int>, IValueTaskSource<bool>, IValueTaskSource
+        {
+            public ValueTaskSourceOnCompletedFlags Flags;
+            public CancellationToken CancellationToken;
+
+            public IAsyncEnumerator<int> GetAsyncEnumerator(CancellationToken cancellationToken = default)
+            {
+                CancellationToken = cancellationToken;
+                return this;
+            }
+
+            public ValueTask<bool> MoveNextAsync() => new ValueTask<bool>(this, 0);
+            public int Current => throw new NotImplementedException();
+            public ValueTask DisposeAsync() => new ValueTask(this, 0);
+
+            public void OnCompleted(Action<object> continuation, object state, short token, ValueTaskSourceOnCompletedFlags flags) => Flags = flags;
+            public ValueTaskSourceStatus GetStatus(short token) => ValueTaskSourceStatus.Pending;
+            public bool GetResult(short token) => throw new NotImplementedException();
+            void IValueTaskSource.GetResult(short token) => throw new NotImplementedException();
+        }
+
+        private sealed class EnumerableWithDelayToAsyncEnumerable<T> : IAsyncEnumerable<T>, IAsyncEnumerator<T>
+        {
+            private readonly int _delayMs;
+            private readonly IEnumerable<T> _enumerable;
+            private IEnumerator<T> _enumerator;
+
+            public EnumerableWithDelayToAsyncEnumerable(IEnumerable<T> enumerable, int delayMs)
+            {
+                _enumerable = enumerable;
+                _delayMs = delayMs;
+            }
+
+            public IAsyncEnumerator<T> GetAsyncEnumerator(CancellationToken cancellationToken = default)
+            {
+                _enumerator = _enumerable.GetEnumerator();
+                return this;
+            }
+
+            public async ValueTask<bool> MoveNextAsync()
+            {
+                await Task.Delay(_delayMs);
+                return _enumerator.MoveNext();
+            }
+
+            public T Current => _enumerator.Current;
+
+            public async ValueTask DisposeAsync()
+            {
+                await Task.Delay(_delayMs);
+                _enumerator.Dispose();
+            }
+        }
+    }
+}
index b3e868d1c2303802e63fc6aee5687baf9c86feec..e43398d4dc8238b2b653b777672657a475576984 100644 (file)
@@ -58,7 +58,7 @@
     <Compile Include="Task\TaskCanceledExceptionTests.netcoreapp.cs" />
     <Compile Include="Task\TaskStatusTest.netcoreapp.cs" />
     <Compile Include="System.Runtime.CompilerServices\AsyncTaskMethodBuilderTests.netcoreapp.cs" />
-    <Compile Include="System.Runtime.CompilerServices\ConfiguredAsyncEnumerableTests.netcoreapp.cs" />
+    <Compile Include="System.Runtime.CompilerServices\ConfiguredCancelableAsyncEnumerableTests.netcoreapp.cs" />
     <Compile Include="$(CommonTestPath)\System\Diagnostics\Tracing\TestEventListener.cs">
       <Link>Common\System\Diagnostics\Tracing\TestEventListener.cs</Link>
     </Compile>