fix CopyToAsyncCore - cancellation and infinite loop (dotnet/corefx#37848)
authorVladislav Richter <shutdown256@gmail.com>
Thu, 6 Jun 2019 09:05:53 +0000 (11:05 +0200)
committerDavid Fowler <davidfowl@gmail.com>
Thu, 6 Jun 2019 09:05:53 +0000 (02:05 -0700)
ReadResult result = await ReadAsync(cancellationToken).ConfigureAwait(false); should be before the try/finally block because if you cancel the read operation the finally clause will try to advance reader that is not in reading state and instead of OperationCancelledException you will end up with InvalidOperationException.

There is a bug either in PipeReader.CopyToAsyncCore() or ReadOnlySequence.TryGet():

When ReadOnlySequence.TryGet() reaches final segment it will return true and this final memory but it will also set position as default(SequencePosition) - I don't know if this is by design but CopyToAsyncCore() method does not take it in consideration and it will copy the memory but not advance the reader - this will cause it to repeat this data indefinitely.

* Fixed test CopyToAsyncWorks() to make sure that multiple separate reads from the reader produce expected result.
Added test to verify that cancellation between reads from the Reader throws OperationCancelledException.

Commit migrated from https://github.com/dotnet/corefx/commit/71aec681dd1247979e7956372abae8bbc02c9546

src/libraries/System.IO.Pipelines/src/System/IO/Pipelines/PipeReader.cs
src/libraries/System.IO.Pipelines/tests/Infrastructure/WriteCheckStream.cs [new file with mode: 0644]
src/libraries/System.IO.Pipelines/tests/PipeReaderCopyToAsyncTests.cs
src/libraries/System.IO.Pipelines/tests/System.IO.Pipelines.Tests.csproj

index beb2e14..4cafd31 100644 (file)
@@ -113,9 +113,9 @@ namespace System.IO.Pipelines
             {
                 SequencePosition consumed = default;
 
+                ReadResult result = await ReadAsync(cancellationToken).ConfigureAwait(false);
                 try
                 {
-                    ReadResult result = await ReadAsync(cancellationToken).ConfigureAwait(false);
                     ReadOnlySequence<byte> buffer = result.Buffer;
                     SequencePosition position = buffer.Start;
 
@@ -131,6 +131,11 @@ namespace System.IO.Pipelines
                         consumed = position;
                     }
 
+                    if (consumed.Equals(default))
+                    {
+                        consumed = buffer.End;
+                    }
+
                     if (result.IsCompleted)
                     {
                         break;
diff --git a/src/libraries/System.IO.Pipelines/tests/Infrastructure/WriteCheckStream.cs b/src/libraries/System.IO.Pipelines/tests/Infrastructure/WriteCheckStream.cs
new file mode 100644 (file)
index 0000000..2cc77f8
--- /dev/null
@@ -0,0 +1,72 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+using System.Threading;
+using System.Threading.Tasks;
+
+namespace System.IO.Pipelines.Tests.Infrastructure
+{
+    public class WriteCheckMemoryStream : Stream
+    {
+        private readonly MemoryStream _ms = new MemoryStream();
+        private int _writeCnt = 0;
+        private int _waitCnt = 0;
+        private TaskCompletionSource<object> _waitSource = new TaskCompletionSource<object>(TaskCreationOptions.RunContinuationsAsynchronously);
+        public CancellationTokenSource MidWriteCancellation { get; set; }
+        private readonly object _lock = new object();
+
+        public override bool CanRead => _ms.CanRead;
+
+        public override bool CanSeek => _ms.CanSeek;
+
+        public override bool CanWrite => _ms.CanWrite;
+
+        public override long Length => _ms.Length;
+
+        public override long Position { get => _ms.Position; set => _ms.Position = value; }
+
+        public override void Flush() => _ms.Flush();
+        public override int Read(byte[] buffer, int offset, int count) => _ms.Read(buffer, offset, count);
+        public override long Seek(long offset, SeekOrigin origin) => _ms.Seek(offset, origin);
+        public override void SetLength(long value) => _ms.SetLength(value);
+
+        public override void Write(byte[] buffer, int offset, int count)
+        {
+            lock (_lock)
+            {
+                _ms.Write(buffer, offset, count);
+                MidWriteCancellation?.Cancel();
+                _writeCnt += count;
+                CheckWaitCount();
+            }
+        }
+
+        public byte[] ToArray() => _ms.ToArray();
+
+        public Task WaitForBytesWrittenAsync(int cnt)
+        {
+            if (cnt <= 0)
+                throw new ArgumentException($"{nameof(cnt)} must be greater than 0");
+            lock (_lock)
+            {
+                _waitCnt = cnt;
+                _waitSource.TrySetCanceled();
+                _waitSource = new TaskCompletionSource<object>(TaskCreationOptions.RunContinuationsAsynchronously);
+                CheckWaitCount();
+                return _waitSource.Task;
+            }
+        }
+        
+        private void CheckWaitCount() 
+        {
+            if (_waitCnt > 0 && _writeCnt >= _waitCnt)
+            {
+                _writeCnt = 0;
+                _waitCnt = 0;
+                _waitSource.TrySetResult(null);
+            }
+        }
+
+        public override string ToString() => Encoding.ASCII.GetString(_ms.ToArray());
+    }
+}
index a0284d1..6717cd8 100644 (file)
@@ -4,6 +4,8 @@
 
 using System;
 using System.Collections.Generic;
+using System.IO.Pipelines.Tests.Infrastructure;
+using System.Linq;
 using System.Text;
 using System.Threading;
 using System.Threading.Tasks;
@@ -33,17 +35,26 @@ namespace System.IO.Pipelines.Tests
         [Fact]
         public async Task CopyToAsyncWorks()
         {
-            var helloBytes = Encoding.UTF8.GetBytes("Hello World");
+            var messages = new List<byte[]>()
+            {
+                Encoding.UTF8.GetBytes("Hello World1"),
+                Encoding.UTF8.GetBytes("Hello World2"),
+                Encoding.UTF8.GetBytes("Hello World3"),
+            };
 
             var pipe = new Pipe(s_testOptions);
-            await pipe.Writer.WriteAsync(helloBytes);
-            pipe.Writer.Complete();
-
-            var stream = new MemoryStream();
-            await pipe.Reader.CopyToAsync(stream);
-            pipe.Reader.Complete();
+            var stream = new WriteCheckMemoryStream();
 
-            Assert.Equal(helloBytes, stream.ToArray());
+            Task task = pipe.Reader.CopyToAsync(stream);
+            foreach (var msg in messages)
+            {
+                await pipe.Writer.WriteAsync(msg);
+                await stream.WaitForBytesWrittenAsync(msg.Length);
+            }
+            pipe.Writer.Complete();
+            await task;
+            
+            Assert.Equal(messages.SelectMany(msg => msg).ToArray(), stream.ToArray());
         }
 
         [Fact]
@@ -129,6 +140,18 @@ namespace System.IO.Pipelines.Tests
         }
 
         [Fact]
+        public async Task CancelingBetweenReadsThrowsOperationCancelledException()
+        {
+            var pipe = new Pipe(s_testOptions);
+            var stream = new WriteCheckMemoryStream { MidWriteCancellation = new CancellationTokenSource() };
+            Task task = pipe.Reader.CopyToAsync(stream, stream.MidWriteCancellation.Token);
+            pipe.Writer.WriteEmpty(10);
+            await pipe.Writer.FlushAsync();
+
+            await Assert.ThrowsAsync<OperationCanceledException>(() => task);
+        }
+
+        [Fact]
         public async Task CancelingViaCancellationTokenThrowsOperationCancelledException()
         {
             var pipe = new Pipe(s_testOptions);
index 2c2b940..deb70c1 100644 (file)
@@ -13,6 +13,7 @@
     <Compile Include="Infrastructure\HeapBufferPool.cs" />
     <Compile Include="Infrastructure\DisposeTrackingBufferPool.cs" />
     <Compile Include="Infrastructure\ReadOnlyStream.cs" />
+    <Compile Include="Infrastructure\WriteCheckStream.cs" />
     <Compile Include="Infrastructure\WriteOnlyStream.cs" />
     <Compile Include="Infrastructure\CancelledWritesStream.cs" />
     <Compile Include="Infrastructure\CancelledReadsStream.cs" />