Add new Span-based virtual sync Read/Write Stream methods (#13058)
authorStephen Toub <stoub@microsoft.com>
Thu, 27 Jul 2017 17:20:48 +0000 (13:20 -0400)
committerGitHub <noreply@github.com>
Thu, 27 Jul 2017 17:20:48 +0000 (13:20 -0400)
* Add virtual Stream.Read/Write Span-based APIs

* Override Span-based Read/Write on MemoryStream

* Override Span-based Read/Write on UnmanagedMemoryStream

* Address PR feedback

src/mscorlib/shared/System/IO/PinnedBufferMemoryStream.cs
src/mscorlib/shared/System/IO/UnmanagedMemoryStream.cs
src/mscorlib/shared/System/IO/UnmanagedMemoryStreamWrapper.cs
src/mscorlib/src/System/IO/MemoryStream.cs
src/mscorlib/src/System/IO/Stream.cs

index c8e720b..e8f74dd 100644 (file)
@@ -46,6 +46,10 @@ namespace System.IO
                 Initialize(ptr, len, len, FileAccess.Read);
         }
 
+        public override int Read(Span<byte> destination) => ReadCore(destination);
+
+        public override void Write(ReadOnlySpan<byte> source) => WriteCore(source);
+
         ~PinnedBufferMemoryStream()
         {
             Dispose(false);
index b78f50f..f808ab4 100644 (file)
@@ -361,8 +361,27 @@ namespace System.IO
                 throw new ArgumentOutOfRangeException(nameof(count), SR.ArgumentOutOfRange_NeedNonNegNum);
             if (buffer.Length - offset < count)
                 throw new ArgumentException(SR.Argument_InvalidOffLen);
-            Contract.EndContractBlock();  // Keep this in sync with contract validation in ReadAsync
 
+            return ReadCore(new Span<byte>(buffer, offset, count));
+        }
+
+        public override int Read(Span<byte> destination)
+        {
+            if (GetType() == typeof(UnmanagedMemoryStream))
+            {
+                return ReadCore(destination);
+            }
+            else
+            {
+                // UnmanagedMemoryStream is not sealed, and a derived type may have overridden Read(byte[], int, int) prior
+                // to this Read(Span<byte>) overload being introduced.  In that case, this Read(Span<byte>) overload
+                // should use the behavior of Read(byte[],int,int) overload.
+                return base.Read(destination);
+            }
+        }
+
+        internal int ReadCore(Span<byte> destination)
+        {
             if (!_isOpen) throw Error.GetStreamIsClosed();
             if (!CanRead) throw Error.GetReadNotSupported();
 
@@ -370,20 +389,22 @@ namespace System.IO
             // changes our position after we decide we can read some bytes.
             long pos = Interlocked.Read(ref _position);
             long len = Interlocked.Read(ref _length);
-            long n = len - pos;
-            if (n > count)
-                n = count;
+            long n = Math.Min(len - pos, destination.Length);
             if (n <= 0)
+            {
                 return 0;
+            }
 
             int nInt = (int)n; // Safe because n <= count, which is an Int32
             if (nInt < 0)
+            {
                 return 0;  // _position could be beyond EOF
+            }
             Debug.Assert(pos + nInt >= 0, "_position + n >= 0");  // len is less than 2^63 -1.
 
             unsafe
             {
-                fixed (byte* pBuffer = buffer)
+                fixed (byte* pBuffer = &destination.DangerousGetPinnableReference())
                 {
                     if (_buffer != null)
                     {
@@ -393,7 +414,7 @@ namespace System.IO
                         try
                         {
                             _buffer.AcquirePointer(ref pointer);
-                            Buffer.Memcpy(pBuffer + offset, pointer + pos + _offset, nInt);
+                            Buffer.Memcpy(pBuffer, pointer + pos + _offset, nInt);
                         }
                         finally
                         {
@@ -405,7 +426,7 @@ namespace System.IO
                     }
                     else
                     {
-                        Buffer.Memcpy(pBuffer + offset, _mem + pos, nInt);
+                        Buffer.Memcpy(pBuffer, _mem + pos, nInt);
                     }
                 }
             }
@@ -583,17 +604,38 @@ namespace System.IO
                 throw new ArgumentOutOfRangeException(nameof(count), SR.ArgumentOutOfRange_NeedNonNegNum);
             if (buffer.Length - offset < count)
                 throw new ArgumentException(SR.Argument_InvalidOffLen);
-            Contract.EndContractBlock();  // Keep contract validation in sync with WriteAsync(..)
 
+            WriteCore(new Span<byte>(buffer, offset, count));
+        }
+
+        public override void Write(ReadOnlySpan<byte> source)
+        {
+            if (GetType() == typeof(UnmanagedMemoryStream))
+            {
+                WriteCore(source);
+            }
+            else
+            {
+                // UnmanagedMemoryStream is not sealed, and a derived type may have overridden Write(byte[], int, int) prior
+                // to this Write(Span<byte>) overload being introduced.  In that case, this Write(Span<byte>) overload
+                // should use the behavior of Write(byte[],int,int) overload.
+                base.Write(source);
+            }
+        }
+
+        internal unsafe void WriteCore(ReadOnlySpan<byte> source)
+        {
             if (!_isOpen) throw Error.GetStreamIsClosed();
             if (!CanWrite) throw Error.GetWriteNotSupported();
 
             long pos = Interlocked.Read(ref _position);  // Use a local to avoid a race condition
             long len = Interlocked.Read(ref _length);
-            long n = pos + count;
+            long n = pos + source.Length;
             // Check for overflow
             if (n < 0)
+            {
                 throw new IOException(SR.IO_StreamTooLong);
+            }
 
             if (n > _capacity)
             {
@@ -606,10 +648,7 @@ namespace System.IO
                 // zero any memory in the middle.
                 if (pos > len)
                 {
-                    unsafe
-                    {
-                        Buffer.ZeroMemory(_mem + len, pos - len);
-                    }
+                    Buffer.ZeroMemory(_mem + len, pos - len);
                 }
 
                 // set length after zeroing memory to avoid race condition of accessing unzeroed memory
@@ -619,39 +658,37 @@ namespace System.IO
                 }
             }
 
-            unsafe
+            fixed (byte* pBuffer = &source.DangerousGetPinnableReference())
             {
-                fixed (byte* pBuffer = buffer)
+                if (_buffer != null)
                 {
-                    if (_buffer != null)
+                    long bytesLeft = _capacity - pos;
+                    if (bytesLeft < source.Length)
                     {
-                        long bytesLeft = _capacity - pos;
-                        if (bytesLeft < count)
-                        {
-                            throw new ArgumentException(SR.Arg_BufferTooSmall);
-                        }
+                        throw new ArgumentException(SR.Arg_BufferTooSmall);
+                    }
 
-                        byte* pointer = null;
-                        RuntimeHelpers.PrepareConstrainedRegions();
-                        try
-                        {
-                            _buffer.AcquirePointer(ref pointer);
-                            Buffer.Memcpy(pointer + pos + _offset, pBuffer + offset, count);
-                        }
-                        finally
-                        {
-                            if (pointer != null)
-                            {
-                                _buffer.ReleasePointer();
-                            }
-                        }
+                    byte* pointer = null;
+                    RuntimeHelpers.PrepareConstrainedRegions();
+                    try
+                    {
+                        _buffer.AcquirePointer(ref pointer);
+                        Buffer.Memcpy(pointer + pos + _offset, pBuffer, source.Length);
                     }
-                    else
+                    finally
                     {
-                        Buffer.Memcpy(_mem + pos, pBuffer + offset, count);
+                        if (pointer != null)
+                        {
+                            _buffer.ReleasePointer();
+                        }
                     }
                 }
+                else
+                {
+                    Buffer.Memcpy(_mem + pos, pBuffer, source.Length);
+                }
             }
+
             Interlocked.Exchange(ref _position, n);
             return;
         }
index d547e77..f3e743a 100644 (file)
@@ -114,6 +114,11 @@ namespace System.IO
             return _unmanagedStream.Read(buffer, offset, count);
         }
 
+        public override int Read(Span<byte> destination)
+        {
+            return _unmanagedStream.Read(destination);
+        }
+
         public override int ReadByte()
         {
             return _unmanagedStream.ReadByte();
@@ -136,6 +141,11 @@ namespace System.IO
             _unmanagedStream.Write(buffer, offset, count);
         }
 
+        public override void Write(ReadOnlySpan<byte> source)
+        {
+            _unmanagedStream.Write(source);
+        }
+
         public override void WriteByte(byte value)
         {
             _unmanagedStream.WriteByte(value);
index daf09d1..91662c5 100644 (file)
@@ -391,6 +391,37 @@ namespace System.IO
             return n;
         }
 
+        public override int Read(Span<byte> destination)
+        {
+            if (GetType() != typeof(MemoryStream))
+            {
+                // MemoryStream is not sealed, and a derived type may have overridden Read(byte[], int, int) prior
+                // to this Read(Span<byte>) overload being introduced.  In that case, this Read(Span<byte>) overload
+                // should use the behavior of Read(byte[],int,int) overload.
+                return base.Read(destination);
+            }
+
+            if (!_isOpen)
+            {
+                __Error.StreamIsClosed();
+            }
+
+            int n = Math.Min(_length - _position, destination.Length);
+            if (n <= 0)
+            {
+                return 0;
+            }
+
+            // TODO https://github.com/dotnet/corefx/issues/22388:
+            // Read(byte[], int, int) has an n <= 8 optimization, presumably based
+            // on benchmarking.  Determine if/where such a cut-off is here and add
+            // an equivalent optimization if necessary.
+            new Span<byte>(_buffer, _position, n).CopyTo(destination);
+
+            _position += n;
+            return n;
+        }
+
         public override Task<int> ReadAsync(Byte[] buffer, int offset, int count, CancellationToken cancellationToken)
         {
             if (buffer == null)
@@ -634,6 +665,52 @@ namespace System.IO
             _position = i;
         }
 
+        public override void Write(ReadOnlySpan<byte> source)
+        {
+            if (GetType() != typeof(MemoryStream))
+            {
+                // MemoryStream is not sealed, and a derived type may have overridden Write(byte[], int, int) prior
+                // to this Write(Span<byte>) overload being introduced.  In that case, this Write(Span<byte>) overload
+                // should use the behavior of Write(byte[],int,int) overload.
+                base.Write(source);
+                return;
+            }
+
+            if (!_isOpen)
+            {
+                __Error.StreamIsClosed();
+            }
+            EnsureWriteable();
+
+            // Check for overflow
+            int i = _position + source.Length;
+            if (i < 0)
+            {
+                throw new IOException(SR.IO_StreamTooLong);
+            }
+
+            if (i > _length)
+            {
+                bool mustZero = _position > _length;
+                if (i > _capacity)
+                {
+                    bool allocatedNewArray = EnsureCapacity(i);
+                    if (allocatedNewArray)
+                    {
+                        mustZero = false;
+                    }
+                }
+                if (mustZero)
+                {
+                    Array.Clear(_buffer, _length, i - _length);
+                }
+                _length = i;
+            }
+
+            source.CopyTo(new Span<byte>(_buffer, _position, source.Length));
+            _position = i;
+        }
+
         public override Task WriteAsync(Byte[] buffer, int offset, int count, CancellationToken cancellationToken)
         {
             if (buffer == null)
index 786dfed..de8226b 100644 (file)
@@ -734,6 +734,27 @@ namespace System.IO
 
         public abstract int Read([In, Out] byte[] buffer, int offset, int count);
 
+        public virtual int Read(Span<byte> destination)
+        {
+            if (destination.Length == 0)
+            {
+                return 0;
+            }
+
+            byte[] buffer = ArrayPool<byte>.Shared.Rent(destination.Length);
+            try
+            {
+                int numRead = Read(buffer, 0, destination.Length);
+                if ((uint)numRead > destination.Length)
+                {
+                    throw new IOException(SR.IO_StreamTooLong);
+                }
+                new Span<byte>(buffer, 0, numRead).CopyTo(destination);
+                return numRead;
+            }
+            finally { ArrayPool<byte>.Shared.Return(buffer); }
+        }
+
         // Reads one byte from the stream by calling Read(byte[], int, int). 
         // Will return an unsigned byte cast to an int or -1 on end of stream.
         // This implementation does not perform well because it allocates a new
@@ -754,6 +775,22 @@ namespace System.IO
 
         public abstract void Write(byte[] buffer, int offset, int count);
 
+        public virtual void Write(ReadOnlySpan<byte> source)
+        {
+            if (source.Length == 0)
+            {
+                return;
+            }
+
+            byte[] buffer = ArrayPool<byte>.Shared.Rent(source.Length);
+            try
+            {
+                source.CopyTo(buffer);
+                Write(buffer, 0, source.Length);
+            }
+            finally { ArrayPool<byte>.Shared.Return(buffer); }
+        }
+
         // Writes one byte from the stream by calling Write(byte[], int, int).
         // This implementation does not perform well because it allocates a new
         // byte[] each time you call it, and should be overridden by any 
@@ -957,6 +994,11 @@ namespace System.IO
                 return 0;
             }
 
+            public override int Read(Span<byte> destination)
+            {
+                return 0;
+            }
+
             public override Task<int> ReadAsync(Byte[] buffer, int offset, int count, CancellationToken cancellationToken)
             {
                 var nullReadTask = s_nullReadTask;
@@ -975,6 +1017,10 @@ namespace System.IO
             {
             }
 
+            public override void Write(ReadOnlySpan<byte> source)
+            {
+            }
+
             public override Task WriteAsync(Byte[] buffer, int offset, int count, CancellationToken cancellationToken)
             {
                 return cancellationToken.IsCancellationRequested ?
@@ -1229,6 +1275,12 @@ namespace System.IO
                     return _stream.Read(bytes, offset, count);
             }
 
+            public override int Read(Span<byte> destination)
+            {
+                lock (_stream)
+                    return _stream.Read(destination);
+            }
+
             public override int ReadByte()
             {
                 lock (_stream)
@@ -1282,6 +1334,12 @@ namespace System.IO
                     _stream.Write(bytes, offset, count);
             }
 
+            public override void Write(ReadOnlySpan<byte> source)
+            {
+                lock (_stream)
+                    _stream.Write(source);
+            }
+
             public override void WriteByte(byte b)
             {
                 lock (_stream)