Revert "Fix tsan problem where the per-thread shared_ptr() can be locked right before...
authorMitch Phillips <31459023+hctim@users.noreply.github.com>
Wed, 1 Feb 2023 18:35:56 +0000 (10:35 -0800)
committerMitch Phillips <31459023+hctim@users.noreply.github.com>
Wed, 1 Feb 2023 18:35:56 +0000 (10:35 -0800)
This reverts commit bcc10817d5569172ee065015747e226280e9b698.

Reason: Broke the aarch64-asan bot. More information available in the
Phabricator review: https://reviews.llvm.org/D140931

mlir/include/mlir/Support/ThreadLocalCache.h

index 1be94ca..e98fae6 100644 (file)
@@ -25,40 +25,12 @@ namespace mlir {
 /// cache has very large lock contention.
 template <typename ValueT>
 class ThreadLocalCache {
-  // Keep a separate shared_ptr protected state that can be acquired atomically
-  // instead of using shared_ptr's for each value. This avoids a problem
-  // where the instance shared_ptr is locked() successfully, and then the
-  // ThreadLocalCache gets destroyed before remove() can be called successfully.
-  struct PerInstanceState {
-    /// Remove the given value entry. This is generally called when a thread
-    /// local cache is destructing.
-    void remove(ValueT *value) {
-      // Erase the found value directly, because it is guaranteed to be in the
-      // list.
-      llvm::sys::SmartScopedLock<true> threadInstanceLock(instanceMutex);
-      auto it =
-          llvm::find_if(instances, [&](std::unique_ptr<ValueT> &instance) {
-            return instance.get() == value;
-          });
-      assert(it != instances.end() && "expected value to exist in cache");
-      instances.erase(it);
-    }
-
-    /// Owning pointers to all of the values that have been constructed for this
-    /// object in the static cache.
-    SmallVector<std::unique_ptr<ValueT>, 1> instances;
-
-    /// A mutex used when a new thread instance has been added to the cache for
-    /// this object.
-    llvm::sys::SmartMutex<true> instanceMutex;
-  };
-
   /// The type used for the static thread_local cache. This is a map between an
   /// instance of the non-static cache and a weak reference to an instance of
   /// ValueT. We use a weak reference here so that the object can be destroyed
   /// without needing to lock access to the cache itself.
-  struct CacheType
-      : public llvm::SmallDenseMap<PerInstanceState *, std::weak_ptr<ValueT>> {
+  struct CacheType : public llvm::SmallDenseMap<ThreadLocalCache<ValueT> *,
+                                                std::weak_ptr<ValueT>> {
     ~CacheType() {
       // Remove the values of this cache that haven't already expired.
       for (auto &it : *this)
@@ -88,16 +60,15 @@ public:
   ValueT &get() {
     // Check for an already existing instance for this thread.
     CacheType &staticCache = getStaticCache();
-    std::weak_ptr<ValueT> &threadInstance = staticCache[perInstanceState.get()];
+    std::weak_ptr<ValueT> &threadInstance = staticCache[this];
     if (std::shared_ptr<ValueT> value = threadInstance.lock())
       return *value;
 
     // Otherwise, create a new instance for this thread.
-    llvm::sys::SmartScopedLock<true> threadInstanceLock(
-        perInstanceState->instanceMutex);
-    perInstanceState->instances.push_back(std::make_unique<ValueT>());
-    ValueT *instance = perInstanceState->instances.back().get();
-    threadInstance = std::shared_ptr<ValueT>(perInstanceState, instance);
+    llvm::sys::SmartScopedLock<true> threadInstanceLock(instanceMutex);
+    instances.push_back(std::make_shared<ValueT>());
+    std::shared_ptr<ValueT> &instance = instances.back();
+    threadInstance = instance;
 
     // Before returning the new instance, take the chance to clear out any used
     // entries in the static map. The cache is only cleared within the same
@@ -119,8 +90,26 @@ private:
     return cache;
   }
 
-  std::shared_ptr<PerInstanceState> perInstanceState =
-      std::make_shared<PerInstanceState>();
+  /// Remove the given value entry. This is generally called when a thread local
+  /// cache is destructing.
+  void remove(ValueT *value) {
+    // Erase the found value directly, because it is guaranteed to be in the
+    // list.
+    llvm::sys::SmartScopedLock<true> threadInstanceLock(instanceMutex);
+    auto it = llvm::find_if(instances, [&](std::shared_ptr<ValueT> &instance) {
+      return instance.get() == value;
+    });
+    assert(it != instances.end() && "expected value to exist in cache");
+    instances.erase(it);
+  }
+
+  /// Owning pointers to all of the values that have been constructed for this
+  /// object in the static cache.
+  SmallVector<std::shared_ptr<ValueT>, 1> instances;
+
+  /// A mutex used when a new thread instance has been added to the cache for
+  /// this object.
+  llvm::sys::SmartMutex<true> instanceMutex;
 };
 } // namespace mlir