Fix CookieContainer memory leak (dotnet/corefx#34006)
authorCaesar Chen <caesar1995@users.noreply.github.com>
Thu, 10 Jan 2019 19:44:21 +0000 (11:44 -0800)
committerGitHub <noreply@github.com>
Thu, 10 Jan 2019 19:44:21 +0000 (11:44 -0800)
* fix CookieContainer memory leak

* disable test on netfx

* address feedback

Commit migrated from https://github.com/dotnet/corefx/commit/41e4a8c74f8cdb4843c6d11a404ad094b54ae992

src/libraries/System.Net.Primitives/src/System/Net/CookieContainer.cs
src/libraries/System.Net.Primitives/tests/FunctionalTests/CookieContainerTest.cs

index 8b50d80..46382b1 100644 (file)
@@ -3,6 +3,7 @@
 // See the LICENSE file in the project root for more information.
 
 using System.Collections;
+using System.Collections.Generic;
 using System.Diagnostics;
 using System.IO;
 using System.Net.NetworkInformation;
@@ -332,12 +333,18 @@ namespace System.Net
                         return; // Cannot age: reject new cookie
                     }
 
-                    // About to change the collection
+                    // About to change the collection.
                     lock (cookies)
                     {
                         m_count += cookies.InternalAdd(cookie, true);
                     }
                 }
+
+                // We don't want to cleanup m_domaintable/m_list too often. Add check to avoid overhead.
+                if (m_domainTable.Count > m_count || pathList.Count > m_maxCookiesPerDomain)
+                {
+                    DomainTableCleanup();
+                }
             }
             catch (OutOfMemoryException)
             {
@@ -505,6 +512,52 @@ namespace System.Net
             return true;
         }
 
+        private void DomainTableCleanup()
+        {
+            var removePathList = new List<object>();
+            var removeDomainList = new List<string>();
+
+            string currentDomain;
+            PathList pathList;
+
+            lock (m_domainTable.SyncRoot)
+            {
+                // Manual use of IDictionaryEnumerator instead of foreach to avoid DictionaryEntry box allocations.
+                IDictionaryEnumerator enumerator = m_domainTable.GetEnumerator();
+                while (enumerator.MoveNext())
+                {
+                    currentDomain = (string)enumerator.Key;
+                    pathList = (PathList)enumerator.Value;
+
+                    lock (pathList.SyncRoot)
+                    {
+                        IDictionaryEnumerator e = pathList.GetEnumerator();
+                        while (e.MoveNext())
+                        {
+                            CookieCollection cc = (CookieCollection)e.Value;
+                            if (cc.Count == 0)
+                            {
+                                removePathList.Add(e.Key);
+                            }
+                        }
+
+                        foreach (var key in removePathList)
+                        {
+                            pathList.Remove(key);
+                        }
+
+                        removePathList.Clear();
+                        if (pathList.Count == 0) removeDomainList.Add(currentDomain);
+                    }
+                }
+
+                foreach (var key in removeDomainList)
+                {
+                    m_domainTable.Remove(key);
+                }
+            }
+        }
+
         // Return number of cookies removed from the collection.
         private int ExpireCollection(CookieCollection cc)
         {
@@ -841,8 +894,7 @@ namespace System.Net
                     }
                 }
 
-                // Remove unused domain
-                // (This is the only place that does domain removal)
+                // Remove unused domain.
                 if (pathList.Count == 0)
                 {
                     lock (m_domainTable.SyncRoot)
@@ -1032,6 +1084,14 @@ namespace System.Net
             }
         }
 
+        internal void Remove(object key)
+        {
+            lock (SyncRoot)
+            {
+                m_list.Remove(key);
+            }
+        }
+
         internal object SyncRoot => m_list.SyncRoot;
 
         [Serializable]
index f0e2a35..1f2e00c 100644 (file)
@@ -2,7 +2,9 @@
 // The .NET Foundation licenses this file to you under the MIT license.
 // See the LICENSE file in the project root for more information.
 
+using System.Collections;
 using System.Collections.Generic;
+using System.Reflection;
 using System.Threading.Tasks;
 
 using Xunit;
@@ -164,5 +166,42 @@ namespace System.Net.Primitives.Functional.Tests
             Assert.Throws<ArgumentNullException>(() => cc.SetCookies(null, "")); // Null uri
             Assert.Throws<ArgumentNullException>(() => cc.SetCookies(u5, null)); // Null header
         }
+
+        [Theory]
+        [InlineData(true)]
+        [InlineData(false)]
+        [SkipOnTargetFramework(TargetFrameworkMonikers.NetFramework)] // .NET Framework will not perform domainTable clean up.
+        public static void AddCookies_CapacityReached_OldCookiesRemoved(bool isFromSameDomain)
+        {
+            const int Capacity = 10;
+            const int TotalCookieCount = 100;
+            var cookieContainer = new CookieContainer(Capacity);
+            Cookie cookie;
+
+            for (int i = 0; i < TotalCookieCount; i++)
+            {
+                if (isFromSameDomain)
+                {
+                    cookie = new Cookie("name1", "value1", $"/{i}", "test.com");
+                }
+                else
+                {
+                    cookie = new Cookie("name1", "value1", "/", $"test{i}.com");
+                }
+
+                cookieContainer.Add(cookie);
+            }
+
+            Assert.Equal(Capacity, cookieContainer.Count);
+
+            if (!isFromSameDomain)
+            {
+                FieldInfo domainTableField = typeof(CookieContainer).GetField("m_domainTable", BindingFlags.Instance | BindingFlags.NonPublic);
+                Assert.NotNull(domainTableField);
+                Hashtable domainTable = domainTableField.GetValue(cookieContainer) as Hashtable;
+                Assert.NotNull(domainTable);
+                Assert.Equal(Capacity, domainTable.Count);
+            }
+        }
     }
 }