Fix ThreadLocal tracking behavior (#56956)
authorDavid Wrighton <davidwr@microsoft.com>
Tue, 10 Aug 2021 08:04:09 +0000 (01:04 -0700)
committerGitHub <noreply@github.com>
Tue, 10 Aug 2021 08:04:09 +0000 (01:04 -0700)
- Before this change the trackAllValues behavior for ThreadLocal<SomeParticularType> was defined by the first instance of thread local to have its value set on the thread
  - This could lead to unpredictable memory leaks (where the value was improperly tracked even though it wasn't supposed to be) This reproduces as a memory leak with no other observable behavior
  - Or data loss, where the Values collection was missing entries.
- Change the model so that ThreadLocal<T> trackAllValues behavior is properly defined by the exact ThreadLocal<T> instance in use
- Implement by keeping track of the track all changes behavior within the IdManager

Fixes #55796

src/libraries/System.Private.CoreLib/src/System/Threading/ThreadLocal.cs
src/libraries/System.Threading/tests/ThreadLocalTests.cs

index 539ac4f..7b73e8d 100644 (file)
@@ -123,7 +123,7 @@ namespace System.Threading
             _trackAllValues = trackAllValues;
 
             // Assign the ID and mark the instance as initialized.
-             _idComplement = ~s_idManager.GetId();
+             _idComplement = ~s_idManager.GetId(trackAllValues);
 
             // As the last step, mark the instance as fully initialized. (Otherwise, if _initialized=false, we know that an exception
             // occurred in the constructor.)
@@ -201,7 +201,7 @@ namespace System.Threading
                 }
             }
             _linkedSlot = null;
-            s_idManager.ReturnId(id);
+            s_idManager.ReturnId(id, _trackAllValues);
         }
 
         #endregion
@@ -346,7 +346,7 @@ namespace System.Threading
             if (slotArray == null)
             {
                 slotArray = new LinkedSlotVolatile[GetNewTableSize(id + 1)];
-                ts_finalizationHelper = new FinalizationHelper(slotArray, _trackAllValues);
+                ts_finalizationHelper = new FinalizationHelper(slotArray);
                 ts_slotArray = slotArray;
             }
 
@@ -675,42 +675,66 @@ namespace System.Threading
         {
             // The next ID to try
             private int _nextIdToTry;
+            // Keep track of the count of non-TrackAllValues ids in use. A count of 0 leads to more efficient thread cleanup
+            private volatile int _idsThatDoNotTrackAllValues;
 
-            // Stores whether each ID is free or not. Additionally, the object is also used as a lock for the IdManager.
-            private readonly List<bool> _freeIds = new List<bool>();
+            private const byte IdFree = 0;
+            private const byte TrackAllValuesAllocated = 1;
+            private const byte DoNotTrackAllValuesAllocated = 2;
 
-            internal int GetId()
+            // Stores whether each ID is free or not, and if it tracksAllValues or not. Additionally, the object is also used as a lock for the IdManager.
+            private readonly List<byte> _ids = new List<byte>();
+
+            internal int GetId(bool trackAllValues)
             {
-                lock (_freeIds)
+                lock (_ids)
                 {
                     int availableId = _nextIdToTry;
-                    while (availableId < _freeIds.Count)
+                    while (availableId < _ids.Count)
                     {
-                        if (_freeIds[availableId]) { break; }
+                        if (_ids[availableId] == IdFree) { break; }
                         availableId++;
                     }
 
-                    if (availableId == _freeIds.Count)
+                    byte allocatedFlag = trackAllValues ? TrackAllValuesAllocated : DoNotTrackAllValuesAllocated;
+                    if (availableId == _ids.Count)
                     {
-                        _freeIds.Add(false);
+                        _ids.Add(allocatedFlag);
                     }
                     else
                     {
-                        _freeIds[availableId] = false;
+                        _ids[availableId] = allocatedFlag;
                     }
 
+                    if (!trackAllValues)
+                        _idsThatDoNotTrackAllValues++;
+
                     _nextIdToTry = availableId + 1;
 
                     return availableId;
                 }
             }
 
+            // Identify if an allocated id tracks all values or not
+            internal bool IdTracksAllValues(int id)
+            {
+                lock (_ids)
+                {
+                    return _ids[id] == TrackAllValuesAllocated;
+                }
+            }
+
+            internal int IdsThatDoNotTrackValuesCount => _idsThatDoNotTrackAllValues;
+
             // Return an ID to the pool
-            internal void ReturnId(int id)
+            internal void ReturnId(int id, bool idTracksAllValues)
             {
-                lock (_freeIds)
+                lock (_ids)
                 {
-                    _freeIds[id] = true;
+                    if (!idTracksAllValues)
+                        _idsThatDoNotTrackAllValues--;
+
+                    _ids[id] = IdFree;
                     if (id < _nextIdToTry) _nextIdToTry = id;
                 }
             }
@@ -731,18 +755,17 @@ namespace System.Threading
         private sealed class FinalizationHelper
         {
             internal LinkedSlotVolatile[] SlotArray;
-            private readonly bool _trackAllValues;
 
-            internal FinalizationHelper(LinkedSlotVolatile[] slotArray, bool trackAllValues)
+            internal FinalizationHelper(LinkedSlotVolatile[] slotArray)
             {
                 SlotArray = slotArray;
-                _trackAllValues = trackAllValues;
             }
 
             ~FinalizationHelper()
             {
                 LinkedSlotVolatile[] slotArray = SlotArray;
                 Debug.Assert(slotArray != null);
+                int idsThatDoNotTrackAllValuesCountRemaining = s_idManager.IdsThatDoNotTrackValuesCount;
 
                 for (int i = 0; i < slotArray.Length; i++)
                 {
@@ -753,7 +776,10 @@ namespace System.Threading
                         continue;
                     }
 
-                    if (_trackAllValues)
+                    // If there are no ids that do not TrackAllValues, we don't need to call the IdTracksAllValues function.
+                    // This is an improvement as that function requires taking a lock.
+                    if (idsThatDoNotTrackAllValuesCountRemaining == 0 ||
+                        s_idManager.IdTracksAllValues(i))
                     {
                         // Set the SlotArray field to null to release the slot array.
                         linkedSlot._slotArray = null;
@@ -764,6 +790,13 @@ namespace System.Threading
                         // the table will be have been removed, and so the table can get GC'd.
                         lock (s_idManager)
                         {
+                            // If the slot wasn't disposed between reading it above and entering the lock
+                            // decrement idsThatDoNotTrackAllValuesCountRemaining
+                            if (slotArray[i].Value != null)
+                            {
+                                idsThatDoNotTrackAllValuesCountRemaining--;
+                            }
+
                             if (linkedSlot._next != null)
                             {
                                 linkedSlot._next._previous = linkedSlot._previous;
index 54cdf6d..7a1ff77 100644 (file)
@@ -435,6 +435,37 @@ namespace System.Threading.Tests
             Assert.False(failed);
         }
 
+        private enum UniqueEnumUsedOnlyWithNonInterferenceTest { True, False }
+
+        [ConditionalFact(typeof(PlatformDetection), nameof(PlatformDetection.IsThreadingSupported))]
+        public static void TestUnrelatedThreadLocalDoesNotInterfereWithTrackAllValues()
+        {
+            ThreadLocal<UniqueEnumUsedOnlyWithNonInterferenceTest> localThatDoesNotTrackValues = new ThreadLocal<UniqueEnumUsedOnlyWithNonInterferenceTest>(false);
+            ThreadLocal<UniqueEnumUsedOnlyWithNonInterferenceTest> localThatDoesTrackValues = new ThreadLocal<UniqueEnumUsedOnlyWithNonInterferenceTest>(true);
+
+            for (int i = 0; i < 10; i++)
+            {
+                Thread t = new Thread(Work);
+                t.Start();
+                t.Join();
+            }
+            GC.Collect();
+            GC.WaitForPendingFinalizers();
+            int count = 0;
+            foreach (var x in localThatDoesTrackValues.Values)
+            {
+                if (x == UniqueEnumUsedOnlyWithNonInterferenceTest.True)
+                    count++;
+            }
+
+            Assert.Equal(10, count);
+            void Work()
+            {
+                localThatDoesNotTrackValues.Value = UniqueEnumUsedOnlyWithNonInterferenceTest.True;
+                localThatDoesTrackValues.Value = UniqueEnumUsedOnlyWithNonInterferenceTest.True;
+            }
+        }
+
         private class SetMreOnFinalize
         {
             private ManualResetEventSlim _mres;