Add ChannelReader<T>.{Can}Count (#312)
authorStephen Toub <stoub@microsoft.com>
Tue, 3 Dec 2019 00:30:00 +0000 (19:30 -0500)
committerGitHub <noreply@github.com>
Tue, 3 Dec 2019 00:30:00 +0000 (19:30 -0500)
src/libraries/System.Threading.Channels/ref/System.Threading.Channels.cs
src/libraries/System.Threading.Channels/src/System/Threading/Channels/BoundedChannel.cs
src/libraries/System.Threading.Channels/src/System/Threading/Channels/ChannelReader.cs
src/libraries/System.Threading.Channels/src/System/Threading/Channels/UnboundedChannel.cs
src/libraries/System.Threading.Channels/tests/BoundedChannelTests.cs
src/libraries/System.Threading.Channels/tests/ChannelTestBase.cs
src/libraries/System.Threading.Channels/tests/UnboundedChannelTests.cs

index 51612ad..9bdaa4c 100644 (file)
@@ -44,7 +44,9 @@ namespace System.Threading.Channels
     public abstract partial class ChannelReader<T>
     {
         protected ChannelReader() { }
+        public virtual bool CanCount { get { throw null; } }
         public virtual System.Threading.Tasks.Task Completion { get { throw null; } }
+        public virtual int Count { get { throw null; } }
         public virtual System.Threading.Tasks.ValueTask<T> ReadAsync(System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; }
         public abstract bool TryRead([System.Diagnostics.CodeAnalysis.MaybeNullWhen(false)] out T item);
         public abstract System.Threading.Tasks.ValueTask<bool> WaitToReadAsync(System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken));
index 135de5c..3f7a967 100644 (file)
@@ -69,6 +69,28 @@ namespace System.Threading.Channels
 
             public override Task Completion => _parent._completion.Task;
 
+            public override bool CanCount => true;
+
+            public override int Count
+            {
+                get
+                {
+                    BoundedChannel<T> parent = _parent;
+                    lock (parent.SyncObj)
+                    {
+                        parent.AssertInvariants();
+                        return parent._items.Count;
+                    }
+                }
+            }
+
+            /// <summary>Gets the number of items in the channel. This should only be used by the debugger.</summary>
+            /// <remarks>
+            /// Unlike <see cref="Count"/>, avoids locking so as to not block the debugger if another suspended thread is holding the lock.
+            /// Hence, this must only be used from the debugger in a serialized context.
+            /// </remarks>
+            private int ItemsCountForDebugger => _parent._items.Count;
+
             public override bool TryRead([MaybeNullWhen(false)] out T item)
             {
                 BoundedChannel<T> parent = _parent;
@@ -229,9 +251,6 @@ namespace System.Threading.Channels
                 return item;
             }
 
-            /// <summary>Gets the number of items in the channel. This should only be used by the debugger.</summary>
-            private int ItemsCountForDebugger => _parent._items.Count;
-
             /// <summary>Gets an enumerator the debugger can use to show the contents of the channel.</summary>
             IEnumerator<T> IDebugEnumerable<T>.GetEnumerator() => _parent._items.GetEnumerator();
         }
index c597426..834f56f 100644 (file)
@@ -19,6 +19,13 @@ namespace System.Threading.Channels
         /// </summary>
         public virtual Task Completion => ChannelUtilities.s_neverCompletingTask;
 
+        /// <summary>Gets whether <see cref="Count"/> is available for use on this <see cref="ChannelReader{T}"/> instance.</summary>
+        public virtual bool CanCount => false;
+
+        /// <summary>Gets the current number of items available from this channel reader.</summary>
+        /// <exception cref="NotSupportedException">Counting is not supported on this instance.</exception>
+        public virtual int Count => throw new NotSupportedException();
+
         /// <summary>Attempts to read an item from the channel.</summary>
         /// <param name="item">The read item, or a default value if no item could be read.</param>
         /// <returns>true if an item was read; otherwise, false if no item was read.</returns>
index e09128b..cd9c10b 100644 (file)
@@ -38,7 +38,7 @@ namespace System.Threading.Channels
             Writer = new UnboundedChannelWriter(this);
         }
 
-        [DebuggerDisplay("Items={ItemsCountForDebugger}")]
+        [DebuggerDisplay("Items={Count}")]
         [DebuggerTypeProxy(typeof(DebugEnumeratorDebugView<>))]
         private sealed class UnboundedChannelReader : ChannelReader<T>, IDebugEnumerable<T>
         {
@@ -55,6 +55,10 @@ namespace System.Threading.Channels
 
             public override Task Completion => _parent._completion.Task;
 
+            public override bool CanCount => true;
+
+            public override int Count => _parent._items.Count;
+
             public override ValueTask<T> ReadAsync(CancellationToken cancellationToken)
             {
                 if (cancellationToken.IsCancellationRequested)
@@ -179,9 +183,6 @@ namespace System.Threading.Channels
                 }
             }
 
-            /// <summary>Gets the number of items in the channel.  This should only be used by the debugger.</summary>
-            private int ItemsCountForDebugger => _parent._items.Count;
-
             /// <summary>Gets an enumerator the debugger can use to show the contents of the channel.</summary>
             IEnumerator<T> IDebugEnumerable<T>.GetEnumerator() => _parent._items.GetEnumerator();
         }
index 903d85c..ea27c14 100644 (file)
@@ -17,6 +17,41 @@ namespace System.Threading.Channels.Tests
             return c;
         }
 
+        [Fact]
+        public void Count_IncrementsDecrementsAsExpected()
+        {
+            const int Bound = 3;
+
+            Channel<int> c = Channel.CreateBounded<int>(Bound);
+            Assert.True(c.Reader.CanCount);
+
+            for (int iter = 0; iter < 2; iter++)
+            {
+                for (int i = 0; i < Bound; i++)
+                {
+                    Assert.Equal(i, c.Reader.Count);
+                    Assert.True(c.Writer.TryWrite(i));
+                    Assert.Equal(i + 1, c.Reader.Count);
+                }
+
+                Assert.False(c.Writer.TryWrite(42));
+                Assert.Equal(Bound, c.Reader.Count);
+
+                if (iter != 0)
+                {
+                    c.Writer.Complete();
+                }
+
+                for (int i = 0; i < Bound; i++)
+                {
+                    Assert.Equal(Bound - i, c.Reader.Count);
+                    Assert.True(c.Reader.TryRead(out int item));
+                    Assert.Equal(i, item);
+                    Assert.Equal(Bound - (i + 1), c.Reader.Count);
+                }
+            }
+        }
+
         [Theory]
         [InlineData(1)]
         [InlineData(10)]
index 60515f8..42721f1 100644 (file)
@@ -130,6 +130,16 @@ namespace System.Threading.Channels.Tests
         }
 
         [Fact]
+        public void Count_ThrowsIfUnsupported()
+        {
+            Channel<int> c = CreateChannel();
+            if (!c.Reader.CanCount)
+            {
+                Assert.Throws<NotSupportedException>(() => c.Reader.Count);
+            }
+        }
+
+        [Fact]
         public void SingleProducerConsumer_ConcurrentReadWrite_Success()
         {
             Channel<int> c = CreateChannel();
index 4688961..707ff15 100644 (file)
@@ -19,6 +19,43 @@ namespace System.Threading.Channels.Tests
         protected override Channel<T> CreateFullChannel<T>() => null;
 
         [Fact]
+        public void Count_IncrementsDecrementsAsExpected()
+        {
+            Channel<int> c = CreateChannel();
+
+            if (RequiresSingleReader)
+            {
+                Assert.False(c.Reader.CanCount);
+                return;
+            }
+
+            Assert.True(c.Reader.CanCount);
+
+            const int ItemsToAdd = 5;
+            for (int iter = 0; iter < 2; iter++)
+            {
+                for (int i = 0; i < ItemsToAdd; i++)
+                {
+                    Assert.Equal(i, c.Reader.Count);
+                    Assert.True(c.Writer.TryWrite(i));
+                    Assert.Equal(i + 1, c.Reader.Count);
+                }
+
+                if (iter != 0)
+                {
+                    c.Writer.Complete();
+                }
+
+                while (c.Reader.TryRead(out int item))
+                {
+                    Assert.Equal(ItemsToAdd - (item + 1), c.Reader.Count);
+                }
+
+                Assert.Equal(0, c.Reader.Count);
+            }
+        }
+
+        [Fact]
         public async Task Complete_BeforeEmpty_NoWaiters_TriggersCompletion()
         {
             Channel<int> c = CreateChannel();