PipeReader.CopyToAsync(Stream) should not perform zero-length writes (#85768)
authorMarc Gravell <marc.gravell@gmail.com>
Sat, 20 May 2023 21:36:24 +0000 (22:36 +0100)
committerGitHub <noreply@github.com>
Sat, 20 May 2023 21:36:24 +0000 (17:36 -0400)
* add failing test for zero-length write

* simplify test

* apply fix

src/libraries/System.IO.Pipelines/src/System/IO/Pipelines/PipeReader.cs
src/libraries/System.IO.Pipelines/tests/PipeReaderCopyToAsyncTests.cs

index 9c7803f..5d9135f 100644 (file)
@@ -248,18 +248,27 @@ namespace System.IO.Pipelines
 
                     while (buffer.TryGet(ref position, out ReadOnlyMemory<byte> memory))
                     {
-                        FlushResult flushResult = await writeAsync(destination, memory, cancellationToken).ConfigureAwait(false);
-
-                        if (flushResult.IsCanceled)
+                        if (memory.IsEmpty)
                         {
-                            ThrowHelper.ThrowOperationCanceledException_FlushCanceled();
+                            // advance tracking only (to account for any boundary scenarios)
+                            consumed = position;
                         }
+                        else
+                        {
+                            // write and advance
+                            FlushResult flushResult = await writeAsync(destination, memory, cancellationToken).ConfigureAwait(false);
 
-                        consumed = position;
+                            if (flushResult.IsCanceled)
+                            {
+                                ThrowHelper.ThrowOperationCanceledException_FlushCanceled();
+                            }
 
-                        if (flushResult.IsCompleted)
-                        {
-                            return;
+                            consumed = position;
+
+                            if (flushResult.IsCompleted)
+                            {
+                                return;
+                            }
                         }
                     }
 
index f6e87e2..6b8dc8b 100644 (file)
@@ -322,5 +322,60 @@ namespace System.IO.Pipelines.Tests
 
             Assert.Equal(buffer.AsMemory(5).ToArray(), ms.ToArray());
         }
+
+        [Fact]
+        public async Task CopyToAsyncStreamDoesNotDoZeroLengthWrite()
+        {
+            using var ms = new LengthCheckStream();
+            var incompleteCopy = Task.Run(() => PipeReader.CopyToAsync(ms));
+            Pipe.Writer.Write(Array.Empty<byte>());
+            await Pipe.Writer.FlushAsync();
+            Pipe.Writer.Complete(null);
+            await incompleteCopy;
+            Assert.False(ms.ZeroLengthWriteDetected);
+        }
+
+        class LengthCheckStream : MemoryStream
+        {
+            public bool ZeroLengthWriteDetected { get; private set; }
+
+            private void Check(int count)
+            {
+                if (count == 0) ZeroLengthWriteDetected = true;
+            }
+
+            public override void Write(byte[] buffer, int offset, int count)
+            {
+                Check(count);
+                base.Write(buffer, offset, count);
+            }
+#if NETCOREAPP3_0_OR_GREATER
+            public override void Write(ReadOnlySpan<byte> buffer)
+            {
+                Check(buffer.Length);
+                base.Write(buffer);
+            }
+            public override ValueTask WriteAsync(ReadOnlyMemory<byte> buffer, CancellationToken cancellationToken = default)
+            {
+                Check(buffer.Length);
+                return base.WriteAsync(buffer, cancellationToken);
+            }
+#endif
+            public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
+            {
+                Check(count);
+                return base.WriteAsync(buffer, offset, count, cancellationToken);
+            }
+            public override IAsyncResult BeginWrite(byte[] buffer, int offset, int count, AsyncCallback? callback, object? state)
+            {
+                Check(count);
+                return base.BeginWrite(buffer, offset, count, callback, state);
+            }
+            public override void WriteByte(byte value)
+            {
+                Check(1);
+                base.WriteByte(value);
+            }
+        }
     }
 }