[mlir] Remove locking for dialect/operation registration.
authorRiver Riddle <riddleriver@gmail.com>
Tue, 30 Jun 2020 22:43:03 +0000 (15:43 -0700)
committerRiver Riddle <riddleriver@gmail.com>
Tue, 30 Jun 2020 22:52:33 +0000 (15:52 -0700)
Moving forward dialects should only be registered in a thread safe context. This matches the existing usage we have today, but it allows for removing quite a bit of expensive locking from the context.

This led to ~.5 a second compile time improvement when running one conversion pass on a very large .mlir file(hundreds of thousands of operations).

Differential Revision: https://reviews.llvm.org/D82595

mlir/include/mlir/IR/Dialect.h
mlir/lib/IR/MLIRContext.cpp

index c017c7b..043fde9 100644 (file)
@@ -258,10 +258,12 @@ private:
 };
 /// Registers all dialects and hooks from the global registries with the
 /// specified MLIRContext.
+/// Note: This method is not thread-safe.
 void registerAllDialects(MLIRContext *context);
 
 /// Utility to register a dialect. Client can register their dialect with the
 /// global registry by calling registerDialect<MyDialect>();
+/// Note: This method is not thread-safe.
 template <typename ConcreteDialect> void registerDialect() {
   Dialect::registerDialectAllocator(TypeID::get<ConcreteDialect>(),
                                     [](MLIRContext *ctx) {
index 1e53525..4c31ef3 100644 (file)
@@ -270,10 +270,6 @@ public:
   // Other
   //===--------------------------------------------------------------------===//
 
-  /// A general purpose mutex to lock access to parts of the context that do not
-  /// have a more specific mutex, e.g. registry operations.
-  llvm::sys::SmartRWMutex<true> contextMutex;
-
   /// This is a list of dialects that are created referring to this context.
   /// The MLIRContext owns the objects.
   std::vector<std::unique_ptr<Dialect>> dialects;
@@ -425,8 +421,6 @@ DiagnosticEngine &MLIRContext::getDiagEngine() { return getImpl().diagEngine; }
 
 /// Return information about all registered IR dialects.
 std::vector<Dialect *> MLIRContext::getRegisteredDialects() {
-  // Lock access to the context registry.
-  ScopedReaderLock registryLock(impl->contextMutex, impl->threadingIsEnabled);
   std::vector<Dialect *> result;
   result.reserve(impl->dialects.size());
   for (auto &dialect : impl->dialects)
@@ -437,9 +431,6 @@ std::vector<Dialect *> MLIRContext::getRegisteredDialects() {
 /// Get a registered IR dialect with the given namespace. If none is found,
 /// then return nullptr.
 Dialect *MLIRContext::getRegisteredDialect(StringRef name) {
-  // Lock access to the context registry.
-  ScopedReaderLock registryLock(impl->contextMutex, impl->threadingIsEnabled);
-
   // Dialects are sorted by name, so we can use binary search for lookup.
   auto it = llvm::lower_bound(
       impl->dialects, name,
@@ -455,9 +446,6 @@ void Dialect::registerDialect(MLIRContext *context) {
   auto &impl = context->getImpl();
   std::unique_ptr<Dialect> dialect(this);
 
-  // Lock access to the context registry.
-  ScopedWriterLock registryLock(impl.contextMutex, impl.threadingIsEnabled);
-
   // Get the correct insertion position sorted by namespace.
   auto insertPt = llvm::lower_bound(
       impl.dialects, dialect, [](const auto &lhs, const auto &rhs) {
@@ -524,35 +512,26 @@ void MLIRContext::printStackTraceOnDiagnostic(bool enable) {
 /// efficient, typically you should ask the operations about their properties
 /// directly.
 std::vector<AbstractOperation *> MLIRContext::getRegisteredOperations() {
-  std::vector<std::pair<StringRef, AbstractOperation *>> opsToSort;
-
-  { // Lock access to the context registry.
-    ScopedReaderLock registryLock(impl->contextMutex, impl->threadingIsEnabled);
-
-    // We just have the operations in a non-deterministic hash table order. Dump
-    // into a temporary array, then sort it by operation name to get a stable
-    // ordering.
-    llvm::StringMap<AbstractOperation> &registeredOps =
-        impl->registeredOperations;
-
-    opsToSort.reserve(registeredOps.size());
-    for (auto &elt : registeredOps)
-      opsToSort.push_back({elt.first(), &elt.second});
-  }
-
-  llvm::array_pod_sort(opsToSort.begin(), opsToSort.end());
+  // We just have the operations in a non-deterministic hash table order. Dump
+  // into a temporary array, then sort it by operation name to get a stable
+  // ordering.
+  llvm::StringMap<AbstractOperation> &registeredOps =
+      impl->registeredOperations;
 
   std::vector<AbstractOperation *> result;
-  result.reserve(opsToSort.size());
-  for (auto &elt : opsToSort)
-    result.push_back(elt.second);
+  result.reserve(registeredOps.size());
+  for (auto &elt : registeredOps)
+    result.push_back(&elt.second);
+  llvm::array_pod_sort(
+      result.begin(), result.end(),
+      [](AbstractOperation *const *lhs, AbstractOperation *const *rhs) {
+        return (*lhs)->name.compare((*rhs)->name);
+      });
+
   return result;
 }
 
 bool MLIRContext::isOperationRegistered(StringRef name) {
-  // Lock access to the context registry.
-  ScopedReaderLock registryLock(impl->contextMutex, impl->threadingIsEnabled);
-
   return impl->registeredOperations.count(name);
 }
 
@@ -561,12 +540,9 @@ void Dialect::addOperation(AbstractOperation opInfo) {
          "op name doesn't start with dialect namespace");
   assert(&opInfo.dialect == this && "Dialect object mismatch");
   auto &impl = context->getImpl();
-
-  // Lock access to the context registry.
   StringRef opName = opInfo.name;
-  ScopedWriterLock registryLock(impl.contextMutex, impl.threadingIsEnabled);
   if (!impl.registeredOperations.insert({opName, std::move(opInfo)}).second) {
-    llvm::errs() << "error: operation named '" << opName
+    llvm::errs() << "error: operation named '" << opInfo.name
                  << "' is already registered.\n";
     abort();
   }
@@ -574,9 +550,6 @@ void Dialect::addOperation(AbstractOperation opInfo) {
 
 void Dialect::addType(TypeID typeID, AbstractType &&typeInfo) {
   auto &impl = context->getImpl();
-
-  // Lock access to the context registry.
-  ScopedWriterLock registryLock(impl.contextMutex, impl.threadingIsEnabled);
   auto *newInfo =
       new (impl.abstractDialectSymbolAllocator.Allocate<AbstractType>())
           AbstractType(std::move(typeInfo));
@@ -586,9 +559,6 @@ void Dialect::addType(TypeID typeID, AbstractType &&typeInfo) {
 
 void Dialect::addAttribute(TypeID typeID, AbstractAttribute &&attrInfo) {
   auto &impl = context->getImpl();
-
-  // Lock access to the context registry.
-  ScopedWriterLock registryLock(impl.contextMutex, impl.threadingIsEnabled);
   auto *newInfo =
       new (impl.abstractDialectSymbolAllocator.Allocate<AbstractAttribute>())
           AbstractAttribute(std::move(attrInfo));
@@ -612,9 +582,6 @@ const AbstractAttribute &AbstractAttribute::lookup(TypeID typeID,
 const AbstractOperation *AbstractOperation::lookup(StringRef opName,
                                                    MLIRContext *context) {
   auto &impl = context->getImpl();
-
-  // Lock access to the context registry.
-  ScopedReaderLock registryLock(impl.contextMutex, impl.threadingIsEnabled);
   auto it = impl.registeredOperations.find(opName);
   if (it != impl.registeredOperations.end())
     return &it->second;