From 383947f2956a8b8128269ec90bcf91da95fd359c Mon Sep 17 00:00:00 2001 From: Marc Gravell Date: Sat, 20 May 2023 22:36:24 +0100 Subject: [PATCH] PipeReader.CopyToAsync(Stream) should not perform zero-length writes (#85768) * add failing test for zero-length write * simplify test * apply fix --- .../src/System/IO/Pipelines/PipeReader.cs | 25 ++++++---- .../tests/PipeReaderCopyToAsyncTests.cs | 55 ++++++++++++++++++++++ 2 files changed, 72 insertions(+), 8 deletions(-) diff --git a/src/libraries/System.IO.Pipelines/src/System/IO/Pipelines/PipeReader.cs b/src/libraries/System.IO.Pipelines/src/System/IO/Pipelines/PipeReader.cs index 9c7803f..5d9135f 100644 --- a/src/libraries/System.IO.Pipelines/src/System/IO/Pipelines/PipeReader.cs +++ b/src/libraries/System.IO.Pipelines/src/System/IO/Pipelines/PipeReader.cs @@ -248,18 +248,27 @@ namespace System.IO.Pipelines while (buffer.TryGet(ref position, out ReadOnlyMemory memory)) { - FlushResult flushResult = await writeAsync(destination, memory, cancellationToken).ConfigureAwait(false); - - if (flushResult.IsCanceled) + if (memory.IsEmpty) { - ThrowHelper.ThrowOperationCanceledException_FlushCanceled(); + // advance tracking only (to account for any boundary scenarios) + consumed = position; } + else + { + // write and advance + FlushResult flushResult = await writeAsync(destination, memory, cancellationToken).ConfigureAwait(false); - consumed = position; + if (flushResult.IsCanceled) + { + ThrowHelper.ThrowOperationCanceledException_FlushCanceled(); + } - if (flushResult.IsCompleted) - { - return; + consumed = position; + + if (flushResult.IsCompleted) + { + return; + } } } diff --git a/src/libraries/System.IO.Pipelines/tests/PipeReaderCopyToAsyncTests.cs b/src/libraries/System.IO.Pipelines/tests/PipeReaderCopyToAsyncTests.cs index f6e87e2..6b8dc8b 100644 --- a/src/libraries/System.IO.Pipelines/tests/PipeReaderCopyToAsyncTests.cs +++ b/src/libraries/System.IO.Pipelines/tests/PipeReaderCopyToAsyncTests.cs @@ -322,5 +322,60 @@ namespace System.IO.Pipelines.Tests Assert.Equal(buffer.AsMemory(5).ToArray(), ms.ToArray()); } + + [Fact] + public async Task CopyToAsyncStreamDoesNotDoZeroLengthWrite() + { + using var ms = new LengthCheckStream(); + var incompleteCopy = Task.Run(() => PipeReader.CopyToAsync(ms)); + Pipe.Writer.Write(Array.Empty()); + await Pipe.Writer.FlushAsync(); + Pipe.Writer.Complete(null); + await incompleteCopy; + Assert.False(ms.ZeroLengthWriteDetected); + } + + class LengthCheckStream : MemoryStream + { + public bool ZeroLengthWriteDetected { get; private set; } + + private void Check(int count) + { + if (count == 0) ZeroLengthWriteDetected = true; + } + + public override void Write(byte[] buffer, int offset, int count) + { + Check(count); + base.Write(buffer, offset, count); + } +#if NETCOREAPP3_0_OR_GREATER + public override void Write(ReadOnlySpan buffer) + { + Check(buffer.Length); + base.Write(buffer); + } + public override ValueTask WriteAsync(ReadOnlyMemory buffer, CancellationToken cancellationToken = default) + { + Check(buffer.Length); + return base.WriteAsync(buffer, cancellationToken); + } +#endif + public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { + Check(count); + return base.WriteAsync(buffer, offset, count, cancellationToken); + } + public override IAsyncResult BeginWrite(byte[] buffer, int offset, int count, AsyncCallback? callback, object? state) + { + Check(count); + return base.BeginWrite(buffer, offset, count, callback, state); + } + public override void WriteByte(byte value) + { + Check(1); + base.WriteByte(value); + } + } } } -- 2.7.4