while (buffer.TryGet(ref position, out ReadOnlyMemory<byte> memory))
{
- FlushResult flushResult = await writeAsync(destination, memory, cancellationToken).ConfigureAwait(false);
-
- if (flushResult.IsCanceled)
+ if (memory.IsEmpty)
{
- ThrowHelper.ThrowOperationCanceledException_FlushCanceled();
+ // advance tracking only (to account for any boundary scenarios)
+ consumed = position;
}
+ else
+ {
+ // write and advance
+ FlushResult flushResult = await writeAsync(destination, memory, cancellationToken).ConfigureAwait(false);
- consumed = position;
+ if (flushResult.IsCanceled)
+ {
+ ThrowHelper.ThrowOperationCanceledException_FlushCanceled();
+ }
- if (flushResult.IsCompleted)
- {
- return;
+ consumed = position;
+
+ if (flushResult.IsCompleted)
+ {
+ return;
+ }
}
}
Assert.Equal(buffer.AsMemory(5).ToArray(), ms.ToArray());
}
+
+ [Fact]
+ public async Task CopyToAsyncStreamDoesNotDoZeroLengthWrite()
+ {
+ using var ms = new LengthCheckStream();
+ var incompleteCopy = Task.Run(() => PipeReader.CopyToAsync(ms));
+ Pipe.Writer.Write(Array.Empty<byte>());
+ await Pipe.Writer.FlushAsync();
+ Pipe.Writer.Complete(null);
+ await incompleteCopy;
+ Assert.False(ms.ZeroLengthWriteDetected);
+ }
+
+ class LengthCheckStream : MemoryStream
+ {
+ public bool ZeroLengthWriteDetected { get; private set; }
+
+ private void Check(int count)
+ {
+ if (count == 0) ZeroLengthWriteDetected = true;
+ }
+
+ public override void Write(byte[] buffer, int offset, int count)
+ {
+ Check(count);
+ base.Write(buffer, offset, count);
+ }
+#if NETCOREAPP3_0_OR_GREATER
+ public override void Write(ReadOnlySpan<byte> buffer)
+ {
+ Check(buffer.Length);
+ base.Write(buffer);
+ }
+ public override ValueTask WriteAsync(ReadOnlyMemory<byte> buffer, CancellationToken cancellationToken = default)
+ {
+ Check(buffer.Length);
+ return base.WriteAsync(buffer, cancellationToken);
+ }
+#endif
+ public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
+ {
+ Check(count);
+ return base.WriteAsync(buffer, offset, count, cancellationToken);
+ }
+ public override IAsyncResult BeginWrite(byte[] buffer, int offset, int count, AsyncCallback? callback, object? state)
+ {
+ Check(count);
+ return base.BeginWrite(buffer, offset, count, callback, state);
+ }
+ public override void WriteByte(byte value)
+ {
+ Check(1);
+ base.WriteByte(value);
+ }
+ }
}
}