From 5d6b5bda85f9755a6716b44e517fe2bc6fc93de3 Mon Sep 17 00:00:00 2001 From: Stephen Toub Date: Mon, 2 Dec 2019 19:30:00 -0500 Subject: [PATCH] Add ChannelReader.{Can}Count (#312) --- .../ref/System.Threading.Channels.cs | 2 ++ .../System/Threading/Channels/BoundedChannel.cs | 25 +++++++++++++-- .../src/System/Threading/Channels/ChannelReader.cs | 7 ++++ .../System/Threading/Channels/UnboundedChannel.cs | 9 +++--- .../tests/BoundedChannelTests.cs | 35 ++++++++++++++++++++ .../tests/ChannelTestBase.cs | 10 ++++++ .../tests/UnboundedChannelTests.cs | 37 ++++++++++++++++++++++ 7 files changed, 118 insertions(+), 7 deletions(-) diff --git a/src/libraries/System.Threading.Channels/ref/System.Threading.Channels.cs b/src/libraries/System.Threading.Channels/ref/System.Threading.Channels.cs index 51612ad..9bdaa4c 100644 --- a/src/libraries/System.Threading.Channels/ref/System.Threading.Channels.cs +++ b/src/libraries/System.Threading.Channels/ref/System.Threading.Channels.cs @@ -44,7 +44,9 @@ namespace System.Threading.Channels public abstract partial class ChannelReader { protected ChannelReader() { } + public virtual bool CanCount { get { throw null; } } public virtual System.Threading.Tasks.Task Completion { get { throw null; } } + public virtual int Count { get { throw null; } } public virtual System.Threading.Tasks.ValueTask ReadAsync(System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } public abstract bool TryRead([System.Diagnostics.CodeAnalysis.MaybeNullWhen(false)] out T item); public abstract System.Threading.Tasks.ValueTask WaitToReadAsync(System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)); diff --git a/src/libraries/System.Threading.Channels/src/System/Threading/Channels/BoundedChannel.cs b/src/libraries/System.Threading.Channels/src/System/Threading/Channels/BoundedChannel.cs index 135de5c..3f7a967 100644 --- a/src/libraries/System.Threading.Channels/src/System/Threading/Channels/BoundedChannel.cs +++ b/src/libraries/System.Threading.Channels/src/System/Threading/Channels/BoundedChannel.cs @@ -69,6 +69,28 @@ namespace System.Threading.Channels public override Task Completion => _parent._completion.Task; + public override bool CanCount => true; + + public override int Count + { + get + { + BoundedChannel parent = _parent; + lock (parent.SyncObj) + { + parent.AssertInvariants(); + return parent._items.Count; + } + } + } + + /// Gets the number of items in the channel. This should only be used by the debugger. + /// + /// Unlike , avoids locking so as to not block the debugger if another suspended thread is holding the lock. + /// Hence, this must only be used from the debugger in a serialized context. + /// + private int ItemsCountForDebugger => _parent._items.Count; + public override bool TryRead([MaybeNullWhen(false)] out T item) { BoundedChannel parent = _parent; @@ -229,9 +251,6 @@ namespace System.Threading.Channels return item; } - /// Gets the number of items in the channel. This should only be used by the debugger. - private int ItemsCountForDebugger => _parent._items.Count; - /// Gets an enumerator the debugger can use to show the contents of the channel. IEnumerator IDebugEnumerable.GetEnumerator() => _parent._items.GetEnumerator(); } diff --git a/src/libraries/System.Threading.Channels/src/System/Threading/Channels/ChannelReader.cs b/src/libraries/System.Threading.Channels/src/System/Threading/Channels/ChannelReader.cs index c597426..834f56f 100644 --- a/src/libraries/System.Threading.Channels/src/System/Threading/Channels/ChannelReader.cs +++ b/src/libraries/System.Threading.Channels/src/System/Threading/Channels/ChannelReader.cs @@ -19,6 +19,13 @@ namespace System.Threading.Channels /// public virtual Task Completion => ChannelUtilities.s_neverCompletingTask; + /// Gets whether is available for use on this instance. + public virtual bool CanCount => false; + + /// Gets the current number of items available from this channel reader. + /// Counting is not supported on this instance. + public virtual int Count => throw new NotSupportedException(); + /// Attempts to read an item from the channel. /// The read item, or a default value if no item could be read. /// true if an item was read; otherwise, false if no item was read. diff --git a/src/libraries/System.Threading.Channels/src/System/Threading/Channels/UnboundedChannel.cs b/src/libraries/System.Threading.Channels/src/System/Threading/Channels/UnboundedChannel.cs index e09128b..cd9c10b 100644 --- a/src/libraries/System.Threading.Channels/src/System/Threading/Channels/UnboundedChannel.cs +++ b/src/libraries/System.Threading.Channels/src/System/Threading/Channels/UnboundedChannel.cs @@ -38,7 +38,7 @@ namespace System.Threading.Channels Writer = new UnboundedChannelWriter(this); } - [DebuggerDisplay("Items={ItemsCountForDebugger}")] + [DebuggerDisplay("Items={Count}")] [DebuggerTypeProxy(typeof(DebugEnumeratorDebugView<>))] private sealed class UnboundedChannelReader : ChannelReader, IDebugEnumerable { @@ -55,6 +55,10 @@ namespace System.Threading.Channels public override Task Completion => _parent._completion.Task; + public override bool CanCount => true; + + public override int Count => _parent._items.Count; + public override ValueTask ReadAsync(CancellationToken cancellationToken) { if (cancellationToken.IsCancellationRequested) @@ -179,9 +183,6 @@ namespace System.Threading.Channels } } - /// Gets the number of items in the channel. This should only be used by the debugger. - private int ItemsCountForDebugger => _parent._items.Count; - /// Gets an enumerator the debugger can use to show the contents of the channel. IEnumerator IDebugEnumerable.GetEnumerator() => _parent._items.GetEnumerator(); } diff --git a/src/libraries/System.Threading.Channels/tests/BoundedChannelTests.cs b/src/libraries/System.Threading.Channels/tests/BoundedChannelTests.cs index 903d85c..ea27c14 100644 --- a/src/libraries/System.Threading.Channels/tests/BoundedChannelTests.cs +++ b/src/libraries/System.Threading.Channels/tests/BoundedChannelTests.cs @@ -17,6 +17,41 @@ namespace System.Threading.Channels.Tests return c; } + [Fact] + public void Count_IncrementsDecrementsAsExpected() + { + const int Bound = 3; + + Channel c = Channel.CreateBounded(Bound); + Assert.True(c.Reader.CanCount); + + for (int iter = 0; iter < 2; iter++) + { + for (int i = 0; i < Bound; i++) + { + Assert.Equal(i, c.Reader.Count); + Assert.True(c.Writer.TryWrite(i)); + Assert.Equal(i + 1, c.Reader.Count); + } + + Assert.False(c.Writer.TryWrite(42)); + Assert.Equal(Bound, c.Reader.Count); + + if (iter != 0) + { + c.Writer.Complete(); + } + + for (int i = 0; i < Bound; i++) + { + Assert.Equal(Bound - i, c.Reader.Count); + Assert.True(c.Reader.TryRead(out int item)); + Assert.Equal(i, item); + Assert.Equal(Bound - (i + 1), c.Reader.Count); + } + } + } + [Theory] [InlineData(1)] [InlineData(10)] diff --git a/src/libraries/System.Threading.Channels/tests/ChannelTestBase.cs b/src/libraries/System.Threading.Channels/tests/ChannelTestBase.cs index 60515f8..42721f1 100644 --- a/src/libraries/System.Threading.Channels/tests/ChannelTestBase.cs +++ b/src/libraries/System.Threading.Channels/tests/ChannelTestBase.cs @@ -130,6 +130,16 @@ namespace System.Threading.Channels.Tests } [Fact] + public void Count_ThrowsIfUnsupported() + { + Channel c = CreateChannel(); + if (!c.Reader.CanCount) + { + Assert.Throws(() => c.Reader.Count); + } + } + + [Fact] public void SingleProducerConsumer_ConcurrentReadWrite_Success() { Channel c = CreateChannel(); diff --git a/src/libraries/System.Threading.Channels/tests/UnboundedChannelTests.cs b/src/libraries/System.Threading.Channels/tests/UnboundedChannelTests.cs index 4688961..707ff15 100644 --- a/src/libraries/System.Threading.Channels/tests/UnboundedChannelTests.cs +++ b/src/libraries/System.Threading.Channels/tests/UnboundedChannelTests.cs @@ -19,6 +19,43 @@ namespace System.Threading.Channels.Tests protected override Channel CreateFullChannel() => null; [Fact] + public void Count_IncrementsDecrementsAsExpected() + { + Channel c = CreateChannel(); + + if (RequiresSingleReader) + { + Assert.False(c.Reader.CanCount); + return; + } + + Assert.True(c.Reader.CanCount); + + const int ItemsToAdd = 5; + for (int iter = 0; iter < 2; iter++) + { + for (int i = 0; i < ItemsToAdd; i++) + { + Assert.Equal(i, c.Reader.Count); + Assert.True(c.Writer.TryWrite(i)); + Assert.Equal(i + 1, c.Reader.Count); + } + + if (iter != 0) + { + c.Writer.Complete(); + } + + while (c.Reader.TryRead(out int item)) + { + Assert.Equal(ItemsToAdd - (item + 1), c.Reader.Count); + } + + Assert.Equal(0, c.Reader.Count); + } + } + + [Fact] public async Task Complete_BeforeEmpty_NoWaiters_TriggersCompletion() { Channel c = CreateChannel(); -- 2.7.4