Add object null checks in Memory<T> APIs to support default (#14816)
authorAhson Khan <ahkha@microsoft.com>
Tue, 14 Nov 2017 15:31:48 +0000 (07:31 -0800)
committerStephen Toub <stoub@microsoft.com>
Tue, 14 Nov 2017 15:31:48 +0000 (10:31 -0500)
* Add object null checks in Memory<T> APIs to support default

* Changing Empty property to return default value.

* Update use of object null checks

* Addressing PR feedback.

* Removing typeof char check.

* Fix typo

src/mscorlib/shared/System/Memory.cs
src/mscorlib/shared/System/ReadOnlyMemory.cs

index d83bd58..e66a61e 100644 (file)
@@ -20,7 +20,7 @@ namespace System
 
         // The highest order bit of _index is used to discern whether _object is an array/string or an owned memory
         // if (_index >> 31) == 1, object _object is an OwnedMemory<T>
-        // else, object _object is a T[] or a string.  It can only be a string if the Memory<T> was created by
+        // else, object _object is a T[] or a string. It can only be a string if the Memory<T> was created by
         // using unsafe / marshaling code to reinterpret a ReadOnlyMemory<char> wrapped around a string as
         // a Memory<T>.
         private readonly object _object;
@@ -122,7 +122,7 @@ namespace System
         /// <summary>
         /// Returns an empty <see cref="Memory{T}"/>
         /// </summary>
-        public static Memory<T> Empty { get; } = Array.Empty<T>();
+        public static Memory<T> Empty => default;
 
         /// <summary>
         /// The number of items in the memory.
@@ -187,15 +187,19 @@ namespace System
                 {
                     // This is dangerous, returning a writable span for a string that should be immutable.
                     // However, we need to handle the case where a ReadOnlyMemory<char> was created from a string
-                    // and then cast to a Memory<T>.  Such a cast can only be done with unsafe or marshaling code,
+                    // and then cast to a Memory<T>. Such a cast can only be done with unsafe or marshaling code,
                     // in which case that's the dangerous operation performed by the dev, and we're just following
                     // suit here to make it work as best as possible.
                     return new Span<T>(ref Unsafe.As<char, T>(ref s.GetRawStringData()), s.Length).Slice(_index, _length);
                 }
-                else
+                else if (_object != null)
                 {
                     return new Span<T>((T[])_object, _index, _length);
                 }
+                else
+                {
+                    return default;
+                }
             }
         }
 
@@ -224,7 +228,7 @@ namespace System
 
         public unsafe MemoryHandle Retain(bool pin = false)
         {
-            MemoryHandle memoryHandle;
+            MemoryHandle memoryHandle = default;
             if (pin)
             {
                 if (_index < 0)
@@ -243,9 +247,8 @@ namespace System
                     void* pointer = Unsafe.Add<T>(Unsafe.AsPointer(ref s.GetRawStringData()), _index);
                     memoryHandle = new MemoryHandle(null, pointer, handle);
                 }
-                else
+                else if (_object is T[] array)
                 {
-                    var array = (T[])_object;
                     var handle = GCHandle.Alloc(array, GCHandleType.Pinned);
                     void* pointer = Unsafe.Add<T>(Unsafe.AsPointer(ref array.GetRawSzArrayData()), _index);
                     memoryHandle = new MemoryHandle(null, pointer, handle);
@@ -258,10 +261,6 @@ namespace System
                     ((OwnedMemory<T>)_object).Retain();
                     memoryHandle = new MemoryHandle((OwnedMemory<T>)_object);
                 }
-                else
-                {
-                    memoryHandle = new MemoryHandle(null);
-                }
             }
             return memoryHandle;
         }
@@ -280,14 +279,10 @@ namespace System
                     return true;
                 }
             }
-            else
+            else if (_object is T[] arr)
             {
-                T[] arr = _object as T[];
-                if (typeof(T) != typeof(char) || arr != null)
-                {
-                    arraySegment = new ArraySegment<T>(arr, _index, _length);
-                    return true;
-                }
+                arraySegment = new ArraySegment<T>(arr, _index, _length);
+                return true;
             }
 
             arraySegment = default(ArraySegment<T>);
@@ -333,7 +328,7 @@ namespace System
         [EditorBrowsable(EditorBrowsableState.Never)]
         public override int GetHashCode()
         {
-            return CombineHashCodes(_object.GetHashCode(), _index.GetHashCode(), _length.GetHashCode());
+            return _object != null ? CombineHashCodes(_object.GetHashCode(), _index.GetHashCode(), _length.GetHashCode()) : 0;
         }
 
         private static int CombineHashCodes(int left, int right)
index 1b163dd..d68f622 100644 (file)
@@ -75,7 +75,7 @@ namespace System
             _length = length;
         }
 
-        /// <summary>Creates a new memory over the existing object, start, and length.  No validation is performed.</summary>
+        /// <summary>Creates a new memory over the existing object, start, and length. No validation is performed.</summary>
         /// <param name="obj">The target object.</param>
         /// <param name="start">The index at which to begin the memory.</param>
         /// <param name="length">The number of items in the memory.</param>
@@ -104,7 +104,7 @@ namespace System
         /// <summary>
         /// Returns an empty <see cref="ReadOnlyMemory{T}"/>
         /// </summary>
-        public static ReadOnlyMemory<T> Empty { get; } = Array.Empty<T>();
+        public static ReadOnlyMemory<T> Empty => default;
 
         /// <summary>
         /// The number of items in the memory.
@@ -169,10 +169,14 @@ namespace System
                 {
                     return new ReadOnlySpan<T>(ref Unsafe.As<char, T>(ref s.GetRawStringData()), s.Length).Slice(_index, _length);
                 }
-                else
+                else if (_object != null)
                 {
                     return new ReadOnlySpan<T>((T[])_object, _index, _length);
                 }
+                else
+                {
+                    return default;
+                }
             }
         }
 
@@ -206,7 +210,7 @@ namespace System
         /// </param>
         public unsafe MemoryHandle Retain(bool pin = false)
         {
-            MemoryHandle memoryHandle;
+            MemoryHandle memoryHandle = default;
             if (pin)
             {
                 if (_index < 0)
@@ -220,9 +224,8 @@ namespace System
                     void* pointer = Unsafe.Add<T>(Unsafe.AsPointer(ref s.GetRawStringData()), _index);
                     memoryHandle = new MemoryHandle(null, pointer, handle);
                 }
-                else
+                else if (_object is T[] array)
                 {
-                    var array = (T[])_object;
                     var handle = GCHandle.Alloc(array, GCHandleType.Pinned);
                     void* pointer = Unsafe.Add<T>(Unsafe.AsPointer(ref array.GetRawSzArrayData()), _index);
                     memoryHandle = new MemoryHandle(null, pointer, handle);
@@ -235,10 +238,6 @@ namespace System
                     ((OwnedMemory<T>)_object).Retain();
                     memoryHandle = new MemoryHandle((OwnedMemory<T>)_object);
                 }
-                else
-                {
-                    memoryHandle = new MemoryHandle(null);
-                }
             }
             return memoryHandle;
         }
@@ -258,14 +257,10 @@ namespace System
                     return true;
                 }
             }
-            else
+            else if (_object is T[] arr)
             {
-                T[] arr = _object as T[];
-                if (typeof(T) != typeof(char) || arr != null)
-                {
-                    arraySegment = new ArraySegment<T>(arr, _index, _length);
-                    return true;
-                }
+                arraySegment = new ArraySegment<T>(arr, _index, _length);
+                return true;
             }
 
             arraySegment = default;
@@ -313,7 +308,7 @@ namespace System
         [EditorBrowsable(EditorBrowsableState.Never)]
         public override int GetHashCode()
         {
-            return CombineHashCodes(_object.GetHashCode(), _index.GetHashCode(), _length.GetHashCode());
+            return _object != null ? CombineHashCodes(_object.GetHashCode(), _index.GetHashCode(), _length.GetHashCode()) : 0;
         }
         
         private static int CombineHashCodes(int left, int right)