From: Emmanuel André <2341261+manandre@users.noreply.github.com> Date: Sat, 7 Dec 2019 01:09:01 +0000 (+0100) Subject: Test span-based CopyTo/Async method (#390) X-Git-Tag: submit/tizen/20210909.063632~10757 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=9af8a4e93bc86006925414aa1f91852d5bbb29f1;p=platform%2Fupstream%2Fdotnet%2Fruntime.git Test span-based CopyTo/Async method (#390) * Add tests for span-based CopyTo/Async methods * Add tests for cross sync streams A custom stream implementation can override CopyTo(Stream) methods with calls to WriteAsync and vice versa * Add new tests using CallTrackingStream * Apply suggestions from code review Co-Authored-By: Stephen Toub * Apply additional suggestions from code review --- diff --git a/src/libraries/System.IO/tests/Stream/Stream.CopyToSpanTests.cs b/src/libraries/System.IO/tests/Stream/Stream.CopyToSpanTests.cs new file mode 100644 index 0000000..a9b57ad --- /dev/null +++ b/src/libraries/System.IO/tests/Stream/Stream.CopyToSpanTests.cs @@ -0,0 +1,535 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.Buffers; +using System.Collections.Generic; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; +using Xunit; + +namespace System.IO.Tests +{ + public partial class StreamCopyToSpanTests + { + [Fact] + public void CopyTo_InvalidArgsThrows() + { + using Stream s = new MemoryStream(); + + AssertExtensions.Throws("callback", () => s.CopyTo(null, null, 0)); + AssertExtensions.Throws("bufferSize", () => s.CopyTo((_, __) => { }, null, 0)); + + AssertExtensions.Throws("callback", () => s.CopyToAsync(null, null, 0, default)); + AssertExtensions.Throws("bufferSize", () => s.CopyToAsync((_, __, ___) => default, null, 0, default)); + } + + [Fact] + public void CopyToAsync_PrecanceledToken_Cancels() + { + using var src = new MemoryStream(); + Assert.Equal(TaskStatus.Canceled, src.CopyToAsync((_, __, ___) => default, null, 4096, new CancellationToken(true)).Status); + } + + [Theory] + [MemberData(nameof(CopyTo_TestData))] + public async Task CopyToAsync_CancellationToken_Propagated(MemoryStream input) + { + using var src = input; + src.WriteByte(0); + src.Position = 0; + + CancellationToken cancellationToken = new CancellationTokenSource().Token; + CancellationToken expectedToken = (input is CustomMemoryStream cms && cms.Sync) ? default(CancellationToken) : cancellationToken; + await src.CopyToAsync( + (_, __, token) => new ValueTask(Task.Run(() => Assert.Equal(expectedToken, token))), + null, + 4096, + cancellationToken + ); + } + + [Theory] + [MemberData(nameof(CopyTo_TestData))] + public async Task CopyToAsync_State_Propagated(MemoryStream input) + { + using var src = input; + src.WriteByte(0); + src.Position = 0; + + const int expected = 42; + await src.CopyToAsync( + (_, state, __) => new ValueTask(Task.Run(() => Assert.Equal(expected, state))), + expected, + 4096, + default + ); + } + + [Theory] + [InlineData(0)] + [InlineData(42)] + [InlineData(100000)] + public void CopyToAsync_StreamToken_ExpectedBufferSizePropagated(int length) + { + using var src = new CustomMemoryStream_BufferSize(); + src.Write(new byte[length], 0, length); + src.Position = 0; + + Assert.Equal(length, ((Task)src.CopyToAsync((_, __, ___) => default, null, length, default(CancellationToken))).Result); + } + + private sealed class CustomMemoryStream_BufferSize : MemoryStream + { + public override Task CopyToAsync(Stream destination, int bufferSize, CancellationToken cancellationToken) => + Task.FromResult(bufferSize); + } + + [Theory] + [MemberData(nameof(CopyTo_TestData))] + public void CopyTo_AllDataCopied(MemoryStream input) + { + using var src = input; + src.Write(Enumerable.Range(0, 10000).Select(i => (byte)i).ToArray(), 0, 256); + src.Position = 0; + + using var dst = new MemoryStream(); + src.CopyTo((span, _) => dst.Write(span), null, 4096); + + Assert.Equal(src.ToArray(), dst.ToArray()); + } + + [Theory] + [MemberData(nameof(CopyTo_TestData))] + public async Task CopyToAsync_AllDataCopied(MemoryStream input) + { + using var src = input; + src.Write(Enumerable.Range(0, 10000).Select(i => (byte)i).ToArray(), 0, 256); + src.Position = 0; + + using var dst = new MemoryStream(); + await src.CopyToAsync((memory, _, ___) => dst.WriteAsync(memory), null, 4096, default); + + Assert.Equal(src.ToArray(), dst.ToArray()); + } + + private sealed class CustomMemoryStream : MemoryStream + { + private readonly bool _spanCopy; + private readonly bool _sync; + + public bool Sync => _sync; + + public CustomMemoryStream(bool spanCopy, bool sync) + : base() + { + _spanCopy = spanCopy; + _sync = sync; + } + + public override void CopyTo(Stream destination, int bufferSize) + { + if (_sync) + { + CopyToInternal(destination, bufferSize); + } + else + { + CopyToAsyncInternal(destination, bufferSize, CancellationToken.None).GetAwaiter().GetResult(); + } + } + + public override Task CopyToAsync(Stream destination, int bufferSize, CancellationToken cancellationToken) + { + if (!_sync) + { + return CopyToAsyncInternal(destination, bufferSize, cancellationToken); + } + else + { + try + { + if (cancellationToken.IsCancellationRequested) + { + return Task.FromCanceled(cancellationToken); + } + + CopyToInternal(destination, bufferSize); + return Task.CompletedTask; + } + catch (Exception e) + { + return Task.FromException(e); + } + } + } + + private void CopyToInternal(Stream destination, int bufferSize) + { + byte[] buffer = ArrayPool.Shared.Rent(bufferSize); + try + { + int read; + while ((read = Read(buffer, 0, buffer.Length)) != 0) + { + if (_spanCopy) + destination.Write(new ReadOnlySpan(buffer, 0, read)); + else + destination.Write(buffer, 0, read); + } + } + finally + { + ArrayPool.Shared.Return(buffer); + } + } + + private async Task CopyToAsyncInternal(Stream destination, int bufferSize, CancellationToken cancellationToken) + { + byte[] buffer = ArrayPool.Shared.Rent(bufferSize); + try + { + while (true) + { + int bytesRead = await ReadAsync(new Memory(buffer), cancellationToken).ConfigureAwait(false); + if (bytesRead == 0) break; + if (_spanCopy) + await destination.WriteAsync(new ReadOnlyMemory(buffer, 0, bytesRead), cancellationToken).ConfigureAwait(false); + else + await destination.WriteAsync(buffer, 0, bytesRead, cancellationToken).ConfigureAwait(false); + } + } + finally + { + ArrayPool.Shared.Return(buffer); + } + } + } + + public static IEnumerable CopyTo_TestData() + { + foreach (var sync in new[] { false, true }) + foreach (var spanCopy in new[] { false, true }) + yield return new object[] { new CustomMemoryStream(spanCopy, sync) }; + + yield return new object[] { new MemoryStream() }; + } + + [Fact] + public void IfCanSeekIsFalseLengthAndPositionShouldNotBeCalled() + { + var baseStream = new DelegateStream( + canReadFunc: () => true, + canSeekFunc: () => false, + readFunc: (buffer, offset, count) => 0); + var trackingStream = new CallTrackingStream(baseStream); + + trackingStream.CopyTo((_, __) => { }, null, 1); + + Assert.InRange(trackingStream.TimesCalled(nameof(trackingStream.CanSeek)), 0, 1); + Assert.Equal(0, trackingStream.TimesCalled(nameof(trackingStream.Length))); + Assert.Equal(0, trackingStream.TimesCalled(nameof(trackingStream.Position))); + // We can't override CopyTo since it's not virtual, so checking TimesCalled + // for CopyTo will result in 0. Instead, we check that Read was called, + // and validate the parameters passed there. + Assert.Equal(1, trackingStream.TimesCalled(nameof(trackingStream.Read))); + + byte[] outerBuffer = trackingStream.ReadBuffer; + int outerOffset = trackingStream.ReadOffset; + int outerCount = trackingStream.ReadCount; + + Assert.NotNull(outerBuffer); + Assert.InRange(outerOffset, 0, outerBuffer.Length - outerCount); + Assert.InRange(outerCount, 1, int.MaxValue); // the buffer can't be size 0 + } + + [Fact] + public async Task AsyncIfCanSeekIsFalseLengthAndPositionShouldNotBeCalled() + { + var baseStream = new DelegateStream( + canReadFunc: () => true, + canSeekFunc: () => false, + readFunc: (buffer, offset, count) => 0); + var trackingStream = new CallTrackingStream(baseStream); + + await trackingStream.CopyToAsync((_, __, ___) => default, null, 1, default(CancellationToken)); + + Assert.InRange(trackingStream.TimesCalled(nameof(trackingStream.CanSeek)), 0, 1); + Assert.Equal(0, trackingStream.TimesCalled(nameof(trackingStream.Length))); + Assert.Equal(0, trackingStream.TimesCalled(nameof(trackingStream.Position))); + Assert.Equal(1, trackingStream.TimesCalled(nameof(trackingStream.CopyToAsync))); + + Assert.InRange(trackingStream.CopyToAsyncBufferSize, 1, int.MaxValue); + Assert.Equal(default(CancellationToken), trackingStream.CopyToAsyncCancellationToken); + } + + [Fact] + public void IfCanSeekIsTrueLengthAndPositionShouldOnlyBeCalledOnce() + { + var baseStream = new DelegateStream( + canReadFunc: () => true, + canSeekFunc: () => true, + readFunc: (buffer, offset, count) => 0, + lengthFunc: () => 0L, + positionGetFunc: () => 0L); + var trackingStream = new CallTrackingStream(baseStream); + + trackingStream.CopyTo((_, __) => { }, null, 1); + + Assert.InRange(trackingStream.TimesCalled(nameof(trackingStream.Length)), 0, 1); + Assert.InRange(trackingStream.TimesCalled(nameof(trackingStream.Position)), 0, 1); + } + + [Fact] + public async Task AsyncIfCanSeekIsTrueLengthAndPositionShouldOnlyBeCalledOnce() + { + var baseStream = new DelegateStream( + canReadFunc: () => true, + canSeekFunc: () => true, + readFunc: (buffer, offset, count) => 0, + lengthFunc: () => 0L, + positionGetFunc: () => 0L); + var trackingStream = new CallTrackingStream(baseStream); + + await trackingStream.CopyToAsync((_, __, ___) => default, null, 1, default(CancellationToken)); + + Assert.InRange(trackingStream.TimesCalled(nameof(trackingStream.Length)), 0, 1); + Assert.InRange(trackingStream.TimesCalled(nameof(trackingStream.Position)), 0, 1); + } + + [Theory] + [MemberData(nameof(LengthIsLessThanOrEqualToPosition))] + public void IfLengthIsLessThanOrEqualToPositionCopyToShouldStillBeCalledWithAPositiveBufferSize(long length, long position) + { + // Streams with their Lengths <= their Positions, e.g. + // new MemoryStream { Position = 3 }.SetLength(1) + // should still be called CopyTo{Async} on with a + // bufferSize of at least 1. + + var baseStream = new DelegateStream( + canReadFunc: () => true, + canSeekFunc: () => true, + lengthFunc: () => length, + positionGetFunc: () => position, + readFunc: (buffer, offset, count) => 0); + var trackingStream = new CallTrackingStream(baseStream); + + trackingStream.CopyTo((_, __) => { }, null, 1); + + // CopyTo is not virtual, so we can't override it in + // CallTrackingStream and record the arguments directly. + // Instead, validate the arguments passed to Read. + + Assert.Equal(1, trackingStream.TimesCalled(nameof(trackingStream.Read))); + + byte[] outerBuffer = trackingStream.ReadBuffer; + int outerOffset = trackingStream.ReadOffset; + int outerCount = trackingStream.ReadCount; + + Assert.NotNull(outerBuffer); + Assert.InRange(outerOffset, 0, outerBuffer.Length - outerCount); + Assert.InRange(outerCount, 1, int.MaxValue); + } + + [Theory] + [MemberData(nameof(LengthIsLessThanOrEqualToPosition))] + public async Task AsyncIfLengthIsLessThanOrEqualToPositionCopyToShouldStillBeCalledWithAPositiveBufferSize(long length, long position) + { + var baseStream = new DelegateStream( + canReadFunc: () => true, + canSeekFunc: () => true, + lengthFunc: () => length, + positionGetFunc: () => position, + readFunc: (buffer, offset, count) => 0); + var trackingStream = new CallTrackingStream(baseStream); + + await trackingStream.CopyToAsync((_, __, ___) => default, null, 1, default(CancellationToken)); + + Assert.InRange(trackingStream.CopyToAsyncBufferSize, 1, int.MaxValue); + Assert.Equal(default(CancellationToken), trackingStream.CopyToAsyncCancellationToken); + } + + [Theory] + [MemberData(nameof(LengthMinusPositionPositiveOverflows))] + public void IfLengthMinusPositionPositiveOverflowsBufferSizeShouldStillBePositive(long length, long position) + { + // The new implementation of Stream.CopyTo calculates the bytes left + // in the Stream by calling Length - Position. This can overflow to a + // negative number, so this tests that if that happens we don't send + // in a negative bufferSize. + + var baseStream = new DelegateStream( + canReadFunc: () => true, + canSeekFunc: () => true, + lengthFunc: () => length, + positionGetFunc: () => position, + readFunc: (buffer, offset, count) => 0); + var trackingStream = new CallTrackingStream(baseStream); + + trackingStream.CopyTo((_, __) => { }, null, 1); + + // CopyTo is not virtual, so we can't override it in + // CallTrackingStream and record the arguments directly. + // Instead, validate the arguments passed to Read. + + Assert.Equal(1, trackingStream.TimesCalled(nameof(trackingStream.Read))); + + byte[] outerBuffer = trackingStream.ReadBuffer; + int outerOffset = trackingStream.ReadOffset; + int outerCount = trackingStream.ReadCount; + + Assert.NotNull(outerBuffer); + Assert.InRange(outerOffset, 0, outerBuffer.Length - outerCount); + Assert.InRange(outerCount, 1, int.MaxValue); + } + + [Theory] + [MemberData(nameof(LengthMinusPositionPositiveOverflows))] + public async Task AsyncIfLengthMinusPositionPositiveOverflowsBufferSizeShouldStillBePositive(long length, long position) + { + var baseStream = new DelegateStream( + canReadFunc: () => true, + canSeekFunc: () => true, + lengthFunc: () => length, + positionGetFunc: () => position, + readFunc: (buffer, offset, count) => 0); + var trackingStream = new CallTrackingStream(baseStream); + + await trackingStream.CopyToAsync((_, __, ___) => default, null, 1, default(CancellationToken)); + + // Note: We can't check how many times ReadAsync was called + // here, since trackingStream overrides CopyToAsync and forwards + // to the inner (non-tracking) stream for the implementation + + Assert.InRange(trackingStream.CopyToAsyncBufferSize, 1, int.MaxValue); + Assert.Equal(default(CancellationToken), trackingStream.CopyToAsyncCancellationToken); + } + + [Theory] + [MemberData(nameof(LengthIsGreaterThanPositionAndDoesNotOverflow))] + public void IfLengthIsGreaterThanPositionAndDoesNotOverflowEverythingShouldGoNormally(long length, long position) + { + const int ReadLimit = 7; + + // Lambda state + byte[] outerBuffer = null; + int? outerOffset = null; + int? outerCount = null; + int readsLeft = ReadLimit; + + var srcBase = new DelegateStream( + canReadFunc: () => true, + canSeekFunc: () => true, + lengthFunc: () => length, + positionGetFunc: () => position, + readFunc: (buffer, offset, count) => + { + Assert.NotNull(buffer); + Assert.InRange(offset, 0, buffer.Length - count); + Assert.InRange(count, 1, int.MaxValue); + + // CopyTo should always pass in the same buffer/offset/count + + if (outerBuffer != null) Assert.Same(outerBuffer, buffer); + else outerBuffer = buffer; + + if (outerOffset != null) Assert.Equal(outerOffset, offset); + else outerOffset = offset; + + if (outerCount != null) Assert.Equal(outerCount, count); + else outerCount = count; + + return --readsLeft; // CopyTo will call Read on this ReadLimit times before stopping + }); + + var src = new CallTrackingStream(srcBase); + + int timesCalled = 0; + src.CopyTo((_, __) => { timesCalled++; }, null, 1); + + Assert.Equal(ReadLimit, src.TimesCalled(nameof(src.Read))); + Assert.Equal(ReadLimit - 1, timesCalled); + } + + [Theory] + [MemberData(nameof(LengthIsGreaterThanPositionAndDoesNotOverflow))] + public async Task AsyncIfLengthIsGreaterThanPositionAndDoesNotOverflowEverythingShouldGoNormally(long length, long position) + { + const int ReadLimit = 7; + + // Lambda state + byte[] outerBuffer = null; + int? outerOffset = null; + int? outerCount = null; + int readsLeft = ReadLimit; + + var srcBase = new DelegateStream( + canReadFunc: () => true, + canSeekFunc: () => true, + lengthFunc: () => length, + positionGetFunc: () => position, + readFunc: (buffer, offset, count) => + { + Assert.NotNull(buffer); + Assert.InRange(offset, 0, buffer.Length - count); + Assert.InRange(count, 1, int.MaxValue); + + // CopyTo should always pass in the same buffer/offset/count + + if (outerBuffer != null) Assert.Same(outerBuffer, buffer); + else outerBuffer = buffer; + + if (outerOffset != null) Assert.Equal(outerOffset, offset); + else outerOffset = offset; + + if (outerCount != null) Assert.Equal(outerCount, count); + else outerCount = count; + + return --readsLeft; // CopyTo will call Read on this ReadLimit times before stopping + }); + + var src = new CallTrackingStream(srcBase); + + int timesCalled = 0; + await src.CopyToAsync((_, __, ___) => { timesCalled++; return default; }, null, 1, default(CancellationToken)); + + // Since we override CopyToAsync in CallTrackingStream, + // src.Read will actually not get called ReadLimit + // times, src.Inner.Read will. So, we just assert that + // CopyToAsync was called once for src. + + Assert.Equal(1, src.TimesCalled(nameof(src.CopyToAsync))); + Assert.Equal(ReadLimit - 1, timesCalled); // dest.WriteAsync will still get called repeatedly + } + + // Member data + + public static IEnumerable LengthIsLessThanOrEqualToPosition() + { + yield return new object[] { 5L, 5L }; // same number + yield return new object[] { 3L, 5L }; // length is less than position + yield return new object[] { -1L, -1L }; // negative numbers + yield return new object[] { 0L, 0L }; // both zero + yield return new object[] { -500L, 0L }; // negative number and zero + yield return new object[] { 0L, 500L }; // zero and positive number + yield return new object[] { -500L, 500L }; // negative and positive number + yield return new object[] { long.MinValue, long.MaxValue }; // length - position <= 0 will fail (overflow), but length <= position won't + } + + public static IEnumerable LengthMinusPositionPositiveOverflows() + { + yield return new object[] { long.MaxValue, long.MinValue }; // length - position will be -1 + yield return new object[] { 1L, -long.MaxValue }; + } + + public static IEnumerable LengthIsGreaterThanPositionAndDoesNotOverflow() + { + yield return new object[] { 5L, 3L }; + yield return new object[] { -3L, -6L }; + yield return new object[] { 0L, -3L }; + yield return new object[] { long.MaxValue, 0 }; // should not overflow or OOM + yield return new object[] { 85000, 123 }; // at least in the current implementation, we max out the bufferSize at 81920 + } + } +} diff --git a/src/libraries/System.IO/tests/System.IO.Tests.csproj b/src/libraries/System.IO/tests/System.IO.Tests.csproj index a76da86..22f066e 100644 --- a/src/libraries/System.IO/tests/System.IO.Tests.csproj +++ b/src/libraries/System.IO/tests/System.IO.Tests.csproj @@ -39,6 +39,7 @@ + diff --git a/src/libraries/System.Runtime/ref/System.Runtime.cs b/src/libraries/System.Runtime/ref/System.Runtime.cs index d203dcc..c1ce252 100644 --- a/src/libraries/System.Runtime/ref/System.Runtime.cs +++ b/src/libraries/System.Runtime/ref/System.Runtime.cs @@ -5310,7 +5310,7 @@ namespace System.IO public System.Threading.Tasks.Task CopyToAsync(System.IO.Stream destination, int bufferSize) { throw null; } public virtual System.Threading.Tasks.Task CopyToAsync(System.IO.Stream destination, int bufferSize, System.Threading.CancellationToken cancellationToken) { throw null; } public System.Threading.Tasks.Task CopyToAsync(System.IO.Stream destination, System.Threading.CancellationToken cancellationToken) { throw null; } - public virtual System.Threading.Tasks.Task CopyToAsync(Func, object?, System.Threading.CancellationToken, System.Threading.Tasks.ValueTask> callback, object? state, int bufferSize, System.Threading.CancellationToken cancellationToken) { throw null; } + public virtual System.Threading.Tasks.Task CopyToAsync(System.Func, object?, System.Threading.CancellationToken, System.Threading.Tasks.ValueTask> callback, object? state, int bufferSize, System.Threading.CancellationToken cancellationToken) { throw null; } [System.ObsoleteAttribute("CreateWaitHandle will be removed eventually. Please use \"new ManualResetEvent(false)\" instead.")] protected virtual System.Threading.WaitHandle CreateWaitHandle() { throw null; } public void Dispose() { }