[mlir][StorageUniquer] Use allocators per thread instead of per shard
authorRiver Riddle <riddleriver@gmail.com>
Tue, 14 Mar 2023 00:32:10 +0000 (17:32 -0700)
committerRiver Riddle <riddleriver@gmail.com>
Thu, 16 Mar 2023 21:56:22 +0000 (14:56 -0700)
This greatly reduces the number of allocators we create, while still
retaining thread safety. Reducing the number of allocators is much
better for locality and memory usage; this revision drops memory
usage for some MLIR heavy workloads (with lots of attributes/types)
by >=5%. This is due to the observation that the number of threads
is effectively always smaller than the number of parametric attributes/types.

Differential Revision: https://reviews.llvm.org/D145991

mlir/lib/Support/StorageUniquer.cpp

index a33eb53..f670318 100644 (file)
@@ -82,9 +82,6 @@ private:
     /// The set containing the allocated storage instances.
     StorageTypeSet instances;
 
-    /// Allocator to use when constructing derived instances.
-    StorageAllocator allocator;
-
 #if LLVM_ENABLE_THREADS != 0
     /// A mutex to keep uniquing thread-safe.
     llvm::sys::SmartRWMutex<true> mutex;
@@ -93,13 +90,12 @@ private:
 
   /// Get or create an instance of a param derived type in an thread-unsafe
   /// fashion.
-  BaseStorage *
-  getOrCreateUnsafe(Shard &shard, LookupKey &key,
-                    function_ref<BaseStorage *(StorageAllocator &)> ctorFn) {
+  BaseStorage *getOrCreateUnsafe(Shard &shard, LookupKey &key,
+                                 function_ref<BaseStorage *()> ctorFn) {
     auto existing = shard.instances.insert_as({key.hashValue}, key);
     BaseStorage *&storage = existing.first->storage;
     if (existing.second)
-      storage = ctorFn(shard.allocator);
+      storage = ctorFn();
     return storage;
   }
 
@@ -135,10 +131,9 @@ public:
     }
   }
   /// Get or create an instance of a parametric type.
-  BaseStorage *
-  getOrCreate(bool threadingIsEnabled, unsigned hashValue,
-              function_ref<bool(const BaseStorage *)> isEqual,
-              function_ref<BaseStorage *(StorageAllocator &)> ctorFn) {
+  BaseStorage *getOrCreate(bool threadingIsEnabled, unsigned hashValue,
+                           function_ref<bool(const BaseStorage *)> isEqual,
+                           function_ref<BaseStorage *()> ctorFn) {
     Shard &shard = getShard(hashValue);
     ParametricStorageUniquer::LookupKey lookupKey{hashValue, isEqual};
     if (!threadingIsEnabled)
@@ -163,17 +158,20 @@ public:
     llvm::sys::SmartScopedWriter<true> typeLock(shard.mutex);
     return localInst = getOrCreateUnsafe(shard, lookupKey, ctorFn);
   }
+
   /// Run a mutation function on the provided storage object in a thread-safe
   /// way.
-  LogicalResult
-  mutate(bool threadingIsEnabled, BaseStorage *storage,
-         function_ref<LogicalResult(StorageAllocator &)> mutationFn) {
-    Shard &shard = getShardFor(storage);
+  LogicalResult mutate(bool threadingIsEnabled, BaseStorage *storage,
+                       function_ref<LogicalResult()> mutationFn) {
     if (!threadingIsEnabled)
-      return mutationFn(shard.allocator);
+      return mutationFn();
 
+    // Get a shard to use for mutating this storage instance. It doesn't need to
+    // be the same shard as the original allocation, but does need to be
+    // deterministic.
+    Shard &shard = getShard(llvm::hash_value(storage));
     llvm::sys::SmartScopedWriter<true> lock(shard.mutex);
-    return mutationFn(shard.allocator);
+    return mutationFn();
   }
 
 private:
@@ -197,18 +195,6 @@ private:
     return *shard;
   }
 
-  /// Return the shard that allocated the provided storage object.
-  Shard &getShardFor(BaseStorage *storage) {
-    for (size_t i = 0; i != numShards; ++i) {
-      if (Shard *shard = shards[i].load(std::memory_order_acquire)) {
-        llvm::sys::SmartScopedReader<true> lock(shard->mutex);
-        if (shard->allocator.allocated(storage))
-          return *shard;
-      }
-    }
-    llvm_unreachable("expected storage object to have a valid shard");
-  }
-
   /// A thread local cache for storage objects. This helps to reduce the lock
   /// contention when an object already existing in the cache.
   ThreadLocalCache<StorageTypeSet> localCache;
@@ -281,8 +267,9 @@ struct StorageUniquerImpl {
     assert(parametricUniquers.count(id) &&
            "creating unregistered storage instance");
     ParametricStorageUniquer &storageUniquer = *parametricUniquers[id];
-    return storageUniquer.getOrCreate(threadingIsEnabled, hashValue, isEqual,
-                                      ctorFn);
+    return storageUniquer.getOrCreate(
+        threadingIsEnabled, hashValue, isEqual,
+        [&] { return ctorFn(getThreadSafeAllocator()); });
   }
 
   /// Run a mutation function on the provided storage object in a thread-safe
@@ -293,7 +280,34 @@ struct StorageUniquerImpl {
     assert(parametricUniquers.count(id) &&
            "mutating unregistered storage instance");
     ParametricStorageUniquer &storageUniquer = *parametricUniquers[id];
-    return storageUniquer.mutate(threadingIsEnabled, storage, mutationFn);
+    return storageUniquer.mutate(threadingIsEnabled, storage, [&] {
+      return mutationFn(getThreadSafeAllocator());
+    });
+  }
+
+  /// Return an allocator that can be used to safely allocate instances on the
+  /// current thread.
+  StorageAllocator &getThreadSafeAllocator() {
+#if LLVM_ENABLE_THREADS != 0
+    if (!threadingIsEnabled)
+      return allocator;
+
+    // If the allocator has not been initialized, create a new one.
+    StorageAllocator *&threadAllocator = threadSafeAllocator.get();
+    if (!threadAllocator) {
+      threadAllocator = new StorageAllocator();
+
+      // Record this allocator, given that we don't want it to be destroyed when
+      // the thread dies.
+      llvm::sys::SmartScopedLock<true> lock(threadAllocatorMutex);
+      threadAllocators.push_back(
+          std::unique_ptr<StorageAllocator>(threadAllocator));
+    }
+
+    return *threadAllocator;
+#else
+    return allocator;
+#endif
   }
 
   //===--------------------------------------------------------------------===//
@@ -314,6 +328,22 @@ struct StorageUniquerImpl {
   // Instance Storage
   //===--------------------------------------------------------------------===//
 
+#if LLVM_ENABLE_THREADS != 0
+  /// A thread local set of allocators used for uniquing parametric instances,
+  /// or other data allocated in thread volatile situations.
+  ThreadLocalCache<StorageAllocator *> threadSafeAllocator;
+
+  /// All of the allocators that have been created for thread based allocation.
+  std::vector<std::unique_ptr<StorageAllocator>> threadAllocators;
+
+  /// A mutex used for safely adding a new thread allocator.
+  llvm::sys::SmartMutex<true> threadAllocatorMutex;
+#endif
+
+  /// Main allocator used for uniquing singleton instances, and other state when
+  /// thread safety is guaranteed.
+  StorageAllocator allocator;
+
   /// Map of type ids to the storage uniquer to use for registered objects.
   DenseMap<TypeID, std::unique_ptr<ParametricStorageUniquer>>
       parametricUniquers;
@@ -322,9 +352,6 @@ struct StorageUniquerImpl {
   /// singleton.
   DenseMap<TypeID, BaseStorage *> singletonInstances;
 
-  /// Allocator used for uniquing singleton instances.
-  StorageAllocator singletonAllocator;
-
   /// Flag specifying if multi-threading is enabled within the uniquer.
   bool threadingIsEnabled = true;
 };
@@ -378,7 +405,7 @@ void StorageUniquer::registerSingletonImpl(
     TypeID id, function_ref<BaseStorage *(StorageAllocator &)> ctorFn) {
   assert(!impl->singletonInstances.count(id) &&
          "storage class already registered");
-  impl->singletonInstances.try_emplace(id, ctorFn(impl->singletonAllocator));
+  impl->singletonInstances.try_emplace(id, ctorFn(impl->allocator));
 }
 
 /// Implementation for mutating an instance of a derived storage.