Reimplement System.Net.Http's ObjectCollection<T> to reduce allocation / interface...
authorStephen Toub <stoub@microsoft.com>
Mon, 13 Apr 2020 21:19:41 +0000 (17:19 -0400)
committerGitHub <noreply@github.com>
Mon, 13 Apr 2020 21:19:41 +0000 (17:19 -0400)
The current implementation of `ObjectCollection<T>` wraps a `List<T>` and derives from `Collection<T>`.  This means that every `ObjectCollection<T>` allocated also involves an extra `List<T>` object (and its array if items are added to it) as well as multiple levels of indirection on each operation.

We can instead just implement `ObjectCollection<T>` directly.  Since most uses end up with just a single object contained (e.g. for a `MedaTypeHeaderValue`'s `CharSet`), and since it only ever stores `T`s that are non-null classes, we can use the items field to be either a `T` or a `T[]`, optimizing for the case where a single element is stored.  In implementing it directly, we then also avoid the extra `List<T>` object allocation, as well as the interface dispatch that results from going through the `IList<T>` interface.

src/libraries/System.Net.Http/src/System/Net/Http/Headers/CacheControlHeaderValue.cs
src/libraries/System.Net.Http/src/System/Net/Http/Headers/ContentDispositionHeaderValue.cs
src/libraries/System.Net.Http/src/System/Net/Http/Headers/MediaTypeHeaderValue.cs
src/libraries/System.Net.Http/src/System/Net/Http/Headers/NameValueWithParametersHeaderValue.cs
src/libraries/System.Net.Http/src/System/Net/Http/Headers/ObjectCollection.cs
src/libraries/System.Net.Http/src/System/Net/Http/Headers/RangeHeaderValue.cs
src/libraries/System.Net.Http/src/System/Net/Http/Headers/TransferCodingHeaderValue.cs
src/libraries/System.Net.Http/tests/UnitTests/Headers/ObjectCollectionTest.cs

index e73adeb..f5bf793 100644 (file)
@@ -52,17 +52,7 @@ namespace System.Net.Http.Headers
             set { _noCache = value; }
         }
 
-        public ICollection<string> NoCacheHeaders
-        {
-            get
-            {
-                if (_noCacheHeaders == null)
-                {
-                    _noCacheHeaders = new ObjectCollection<string>(s_checkIsValidToken);
-                }
-                return _noCacheHeaders;
-            }
-        }
+        public ICollection<string> NoCacheHeaders => _noCacheHeaders ??= new ObjectCollection<string>(s_checkIsValidToken);
 
         public bool NoStore
         {
@@ -124,17 +114,7 @@ namespace System.Net.Http.Headers
             set { _privateField = value; }
         }
 
-        public ICollection<string> PrivateHeaders
-        {
-            get
-            {
-                if (_privateHeaders == null)
-                {
-                    _privateHeaders = new ObjectCollection<string>(s_checkIsValidToken);
-                }
-                return _privateHeaders;
-            }
-        }
+        public ICollection<string> PrivateHeaders => _privateHeaders ??= new ObjectCollection<string>(s_checkIsValidToken);
 
         public bool MustRevalidate
         {
@@ -148,17 +128,7 @@ namespace System.Net.Http.Headers
             set { _proxyRevalidate = value; }
         }
 
-        public ICollection<NameValueHeaderValue> Extensions
-        {
-            get
-            {
-                if (_extensions == null)
-                {
-                    _extensions = new ObjectCollection<NameValueHeaderValue>();
-                }
-                return _extensions;
-            }
-        }
+        public ICollection<NameValueHeaderValue> Extensions => _extensions ??= new ObjectCollection<NameValueHeaderValue>();
 
         public CacheControlHeaderValue()
         {
@@ -604,11 +574,7 @@ namespace System.Net.Http.Headers
                     return false;
                 }
 
-                if (destination == null)
-                {
-                    destination = new ObjectCollection<string>(s_checkIsValidToken);
-                }
-
+                destination ??= new ObjectCollection<string>(s_checkIsValidToken);
                 destination.Add(valueString.Substring(current, tokenLength));
 
                 current = current + tokenLength;
index 5eb614c..57a4bd8 100644 (file)
@@ -40,17 +40,7 @@ namespace System.Net.Http.Headers
             }
         }
 
-        public ICollection<NameValueHeaderValue> Parameters
-        {
-            get
-            {
-                if (_parameters == null)
-                {
-                    _parameters = new ObjectCollection<NameValueHeaderValue>();
-                }
-                return _parameters;
-            }
-        }
+        public ICollection<NameValueHeaderValue> Parameters => _parameters ??= new ObjectCollection<NameValueHeaderValue>();
 
         public string? Name
         {
index 260e302..36651d6 100644 (file)
@@ -55,17 +55,7 @@ namespace System.Net.Http.Headers
             }
         }
 
-        public ICollection<NameValueHeaderValue> Parameters
-        {
-            get
-            {
-                if (_parameters == null)
-                {
-                    _parameters = new ObjectCollection<NameValueHeaderValue>();
-                }
-                return _parameters;
-            }
-        }
+        public ICollection<NameValueHeaderValue> Parameters => _parameters ??= new ObjectCollection<NameValueHeaderValue>();
 
         [DisallowNull]
         public string? MediaType
index e7f58ea..3c624ab 100644 (file)
@@ -18,17 +18,7 @@ namespace System.Net.Http.Headers
 
         private ObjectCollection<NameValueHeaderValue>? _parameters;
 
-        public ICollection<NameValueHeaderValue> Parameters
-        {
-            get
-            {
-                if (_parameters == null)
-                {
-                    _parameters = new ObjectCollection<NameValueHeaderValue>();
-                }
-                return _parameters;
-            }
-        }
+        public ICollection<NameValueHeaderValue> Parameters => _parameters ??= new ObjectCollection<NameValueHeaderValue>();
 
         public NameValueWithParametersHeaderValue(string name)
             : base(name)
index bb6ef3b..5809b7c 100644 (file)
 // The .NET Foundation licenses this file to you under the MIT license.
 // See the LICENSE file in the project root for more information.
 
+using System.Collections;
 using System.Collections.Generic;
-using System.Collections.ObjectModel;
 using System.Diagnostics;
 
 namespace System.Net.Http.Headers
 {
-    // We need to prevent 'null' values in the collection. Since List<T> allows them, we will create
-    // a custom collection class. It is less efficient than List<T> but only used for small collections.
-    internal sealed class ObjectCollection<T> : Collection<T> where T : class
+    /// <summary>An <see cref="ICollection{T}"/> list that prohibits null elements and that is optimized for a small number of elements.</summary>
+    [DebuggerDisplay("Count = {Count}")]
+    [DebuggerTypeProxy(nameof(DebugView))]
+    internal sealed class ObjectCollection<T> : ICollection<T> where T : class
     {
-        private static readonly Action<T> s_defaultValidator = CheckNotNull;
+        private const int DefaultSize = 4;
 
-        private readonly Action<T> _validator;
+        /// <summary>Optional delegate used to validate added items.</summary>
+        private readonly Action<T>? _validator;
+        /// <summary>null, a T, or a T[].</summary>
+        internal object? _items;
+        /// <summary>Number of elements stored in the collection.</summary>
+        internal int _size;
 
-        public ObjectCollection()
-            : this(s_defaultValidator)
+        public ObjectCollection() { }
+
+        public ObjectCollection(Action<T> validator) => _validator = validator;
+
+        public int Count => _size;
+
+        public bool IsReadOnly => false;
+
+        public void Add(T item)
         {
+            // Validate the item, either just by checking it for null, or using a custom validator,
+            // which should also check for null.
+            if (_validator is null)
+            {
+                if (item is null)
+                {
+                    throw new ArgumentNullException(nameof(item));
+                }
+            }
+            else
+            {
+                _validator.Invoke(item);
+                Debug.Assert(item != null);
+            }
+
+            if (_items is null)
+            {
+                // The collection is empty. Just store the new item directly.
+                _items = item;
+                _size = 1;
+            }
+            else if (_items is T existingItem)
+            {
+                // The collection has a single item stored directly.  Upgrade to
+                // an array, and store both the existing and new items.
+                Debug.Assert(_size == 1);
+                T[] items = new T[DefaultSize];
+                items[0] = existingItem;
+                items[1] = item;
+                _items = items;
+                _size = 2;
+            }
+            else
+            {
+                T[] array = (T[])_items;
+                int size = _size;
+                if ((uint)size < (uint)array.Length)
+                {
+                    // There's room in the existing array.  Add the item.
+                    array[size] = item;
+                }
+                else
+                {
+                    // We need to grow the array.  Do so, and store the new item.
+                    Debug.Assert(_size > 0);
+                    Debug.Assert(_size == array.Length);
+
+                    var newItems = new T[array.Length * 2];
+                    Array.Copy(array, newItems, size);
+                    _items = newItems;
+                    newItems[size] = item;
+                }
+                _size = size + 1;
+            }
         }
 
-        public ObjectCollection(Action<T> validator)
-            : base(new List<T>())
+        public void Clear()
         {
-            Debug.Assert(validator != null, $"{nameof(validator)} must not be null.");
-            _validator = validator;
+            _items = null;
+            _size = 0;
         }
 
-        // This is only used internally to enumerate the collection
-        // without the enumerator allocation.
-        public new List<T>.Enumerator GetEnumerator()
+        public bool Contains(T item) =>
+            ReferenceEquals(item, _items) ||
+            (_size != 0 && _items is T[] items && Array.IndexOf(items, item, 0, _size) != -1);
+
+        public void CopyTo(T[] array, int arrayIndex)
         {
-            return ((List<T>)Items).GetEnumerator();
+            if (_items is T[] items)
+            {
+                Array.Copy(items, 0, array, arrayIndex, _size);
+            }
+            else
+            {
+                Debug.Assert(_size == 0 || _size == 1);
+                if (array is null || _size > array.Length - arrayIndex)
+                {
+                    // Use Array.CopyTo to throw the right exceptions.
+                    new T[] { (T)_items! }.CopyTo(array!, arrayIndex);
+                }
+                else if (_size == 1)
+                {
+                    array[arrayIndex] = (T)_items!;
+                }
+            }
         }
 
-        protected override void InsertItem(int index, T item)
+        public bool Remove(T item)
         {
-            _validator(item);
-            base.InsertItem(index, item);
+            if (ReferenceEquals(_items, item))
+            {
+                _items = null;
+                _size = 0;
+                return true;
+            }
+
+            if (_items is T[] items)
+            {
+                int index = Array.IndexOf(items, item, 0, _size);
+                if (index != -1)
+                {
+                    _size--;
+                    if (index < _size)
+                    {
+                        Array.Copy(items, index + 1, items, index, _size - index);
+                    }
+                    items[_size] = null!;
+
+                    return true;
+                }
+            }
+
+            return false;
         }
 
-        protected override void SetItem(int index, T item)
+        public Enumerator GetEnumerator() => new Enumerator(this);
+        IEnumerator<T> IEnumerable<T>.GetEnumerator() => GetEnumerator();
+        IEnumerator IEnumerable.GetEnumerator() => GetEnumerator();
+
+        public struct Enumerator : IEnumerator<T>
         {
-            _validator(item);
-            base.SetItem(index, item);
+            private readonly ObjectCollection<T> _list;
+            private int _index;
+            private T _current;
+
+            internal Enumerator(ObjectCollection<T> list)
+            {
+                _list = list;
+                _index = 0;
+                _current = default!;
+            }
+
+            public void Dispose() { }
+
+            public bool MoveNext()
+            {
+                ObjectCollection<T> list = _list;
+
+                if ((uint)_index < (uint)list._size)
+                {
+                    _current = list._items is T[] items ? items[_index] : (T)list._items!;
+                    _index++;
+                    return true;
+                }
+
+                _index = _list._size + 1;
+                _current = default!;
+                return false;
+            }
+
+            public T Current => _current!;
+
+            object? IEnumerator.Current => _current;
+
+            void IEnumerator.Reset()
+            {
+                _index = 0;
+                _current = default!;
+            }
         }
 
-        private static void CheckNotNull(T item)
+        internal sealed class DebugView
         {
-            // Null values cannot be added to the collection.
-            if (item == null)
+            private readonly ObjectCollection<T> _collection;
+
+            public DebugView(ObjectCollection<T> collection) => _collection = collection ?? throw new ArgumentNullException(nameof(collection));
+
+            [DebuggerBrowsable(DebuggerBrowsableState.RootHidden)]
+            public T[] Items
             {
-                throw new ArgumentNullException(nameof(item));
+                get
+                {
+                    T[] items = new T[_collection.Count];
+                    _collection.CopyTo(items, 0);
+                    return items;
+                }
             }
         }
     }
index b952875..a6b80ac 100644 (file)
@@ -25,17 +25,7 @@ namespace System.Net.Http.Headers
             }
         }
 
-        public ICollection<RangeItemHeaderValue> Ranges
-        {
-            get
-            {
-                if (_ranges == null)
-                {
-                    _ranges = new ObjectCollection<RangeItemHeaderValue>();
-                }
-                return _ranges;
-            }
-        }
+        public ICollection<RangeItemHeaderValue> Ranges => _ranges ??= new ObjectCollection<RangeItemHeaderValue>();
 
         public RangeHeaderValue()
         {
index 22daaea..c833b89 100644 (file)
@@ -21,17 +21,7 @@ namespace System.Net.Http.Headers
             get { return _value; }
         }
 
-        public ICollection<NameValueHeaderValue> Parameters
-        {
-            get
-            {
-                if (_parameters == null)
-                {
-                    _parameters = new ObjectCollection<NameValueHeaderValue>();
-                }
-                return _parameters;
-            }
-        }
+        public ICollection<NameValueHeaderValue> Parameters => _parameters ??= new ObjectCollection<NameValueHeaderValue>();
 
         internal TransferCodingHeaderValue()
         {
index 2b9a1ac..a48ed76 100644 (file)
@@ -2,7 +2,6 @@
 // The .NET Foundation licenses this file to you under the MIT license.
 // See the LICENSE file in the project root for more information.
 
-using System;
 using System.Net.Http.Headers;
 
 using Xunit;
@@ -18,14 +17,13 @@ namespace System.Net.Http.Tests
             ObjectCollection<string> c = new ObjectCollection<string>();
 
             c.Add("value1");
-            c.Insert(0, "value2");
+            c.Add("value2");
 
             Assert.Throws<ArgumentNullException>(() => { c.Add(null); });
-            Assert.Throws<ArgumentNullException>(() => { c[0] = null; });
 
             Assert.Equal(2, c.Count);
-            Assert.Equal("value2", c[0]);
-            Assert.Equal("value1", c[1]);
+            Assert.True(c.Contains("value2"));
+            Assert.True(c.Contains("value1"));
 
             // Use custom validator
             c = new ObjectCollection<string>(item =>
@@ -37,13 +35,11 @@ namespace System.Net.Http.Tests
             });
 
             c.Add("value1");
-            c[0] = "value2";
 
             Assert.Throws<InvalidOperationException>(() => { c.Add(null); });
-            Assert.Throws<InvalidOperationException>(() => { c[0] = null; });
 
             Assert.Equal(1, c.Count);
-            Assert.Equal("value2", c[0]);
+            Assert.True(c.Contains("value1"));
         }
     }
 }