Improve set_position to reuse buffer. (#54991)
authorLateApexEarlySpeed <72254037+lateapexearlyspeed@users.noreply.github.com>
Wed, 11 Aug 2021 15:16:07 +0000 (23:16 +0800)
committerGitHub <noreply@github.com>
Wed, 11 Aug 2021 15:16:07 +0000 (17:16 +0200)
Co-authored-by: Jeff Handley <jeff.handley@microsoft.com>
Co-authored-by: Adam Sitnik <adam.sitnik@gmail.com>
src/libraries/System.IO/tests/BufferedStream/BufferedStreamTests.cs
src/libraries/System.Private.CoreLib/src/System/IO/BufferedStream.cs

index 8395a64..0826875 100644 (file)
@@ -128,6 +128,94 @@ namespace System.IO.Tests
             Assert.Equal(0, wrapper.TimesCalled(nameof(wrapper.FlushAsync)));
         }
 
+        [Theory]
+        [MemberData(nameof(SetPosMethods))]
+        public void SetPositionInsideBufferRange_Read_WillNotReadUnderlyingStreamAgain(int sharedBufSize, Action<Stream, long> setPos)
+        {
+            var trackingStream = new CallTrackingStream(new MemoryStream());
+            var bufferedStream = new BufferedStream(trackingStream, sharedBufSize);
+            bufferedStream.Write(Enumerable.Range(0, sharedBufSize * 2).Select(i => (byte)i).ToArray(), 0, sharedBufSize * 2);
+            setPos(bufferedStream, 0);
+
+            var readBuf = new byte[sharedBufSize - 1];
+
+            // First half part verification
+            byte[] expectedReadBuf = Enumerable.Range(0, sharedBufSize - 1).Select(i => (byte)i).ToArray();
+
+            // Call Read() to fill shared read buffer
+            int readBytes = bufferedStream.Read(readBuf, 0, readBuf.Length);
+            Assert.Equal(readBuf.Length, readBytes);
+            Assert.Equal(sharedBufSize - 1, bufferedStream.Position);
+            Assert.Equal(expectedReadBuf, readBuf);
+            Assert.Equal(1, trackingStream.TimesCalled(nameof(trackingStream.Read)));
+
+            // Set position inside range of shared read buffer
+            for (int pos = 0; pos < sharedBufSize - 1; pos++)
+            {
+                setPos(bufferedStream, pos);
+
+                readBytes = bufferedStream.Read(readBuf, pos, readBuf.Length - pos);
+                Assert.Equal(readBuf.Length - pos, readBytes);
+                Assert.Equal(sharedBufSize - 1, bufferedStream.Position);
+                Assert.Equal(expectedReadBuf, readBuf);
+                // Should not trigger underlying stream's Read()
+                Assert.Equal(1, trackingStream.TimesCalled(nameof(trackingStream.Read)));
+            }
+
+            Assert.Equal(sharedBufSize - 1, bufferedStream.ReadByte());
+            Assert.Equal(sharedBufSize, bufferedStream.Position);
+            // Should not trigger underlying stream's Read()
+            Assert.Equal(1, trackingStream.TimesCalled(nameof(trackingStream.Read)));
+
+            // Second half part verification
+            expectedReadBuf = Enumerable.Range(sharedBufSize, sharedBufSize - 1).Select(i => (byte)i).ToArray();
+            // Call Read() to fill shared read buffer
+            readBytes = bufferedStream.Read(readBuf, 0, readBuf.Length);
+            Assert.Equal(readBuf.Length, readBytes);
+            Assert.Equal(sharedBufSize * 2 - 1, bufferedStream.Position);
+            Assert.Equal(expectedReadBuf, readBuf);
+            Assert.Equal(2, trackingStream.TimesCalled(nameof(trackingStream.Read)));
+
+            // Set position inside range of shared read buffer
+            for (int pos = 0; pos < sharedBufSize - 1; pos++)
+            {
+                setPos(bufferedStream, sharedBufSize + pos);
+
+                readBytes = bufferedStream.Read(readBuf, pos, readBuf.Length - pos);
+                Assert.Equal(readBuf.Length - pos, readBytes);
+                Assert.Equal(sharedBufSize * 2 - 1, bufferedStream.Position);
+                Assert.Equal(expectedReadBuf, readBuf);
+                // Should not trigger underlying stream's Read()
+                Assert.Equal(2, trackingStream.TimesCalled(nameof(trackingStream.Read)));
+            }
+
+            Assert.Equal(sharedBufSize * 2 - 1, bufferedStream.ReadByte());
+            Assert.Equal(sharedBufSize * 2, bufferedStream.Position);
+            // Should not trigger underlying stream's Read()
+            Assert.Equal(2, trackingStream.TimesCalled(nameof(trackingStream.Read)));
+        }
+
+        public static IEnumerable<object[]> SetPosMethods
+        {
+            get
+            {
+                var setByPosition = (Action<Stream, long>)((stream, pos) => stream.Position = pos);
+                var seekFromBegin = (Action<Stream, long>)((stream, pos) => stream.Seek(pos, SeekOrigin.Begin));
+                var seekFromCurrent = (Action<Stream, long>)((stream, pos) => stream.Seek(pos - stream.Position, SeekOrigin.Current));
+                var seekFromEnd = (Action<Stream, long>)((stream, pos) => stream.Seek(pos - stream.Length, SeekOrigin.End));
+
+                yield return new object[] { 3, setByPosition };
+                yield return new object[] { 3, seekFromBegin };
+                yield return new object[] { 3, seekFromCurrent };
+                yield return new object[] { 3, seekFromEnd };
+
+                yield return new object[] { 10, setByPosition };
+                yield return new object[] { 10, seekFromBegin };
+                yield return new object[] { 10, seekFromCurrent };
+                yield return new object[] { 10, seekFromEnd };
+            }
+        }
+
         [ConditionalFact(typeof(PlatformDetection), nameof(PlatformDetection.IsThreadingSupported))]
         public async Task ConcurrentOperationsAreSerialized()
         {
index 1b1f5ad..bcdbc29 100644 (file)
@@ -205,15 +205,7 @@ namespace System.IO
                 if (value < 0)
                     ThrowHelper.ThrowArgumentOutOfRangeException(ExceptionArgument.value, ExceptionResource.ArgumentOutOfRange_NeedNonNegNum);
 
-                EnsureNotClosed();
-                EnsureCanSeek();
-
-                if (_writePos > 0)
-                    FlushWrite();
-
-                _readPos = 0;
-                _readLen = 0;
-                _stream!.Seek(value, SeekOrigin.Begin);
+                Seek(value, SeekOrigin.Begin);
             }
         }