[mlir] Fix circular dialect initialization
authorMatthias Springer <springerm@google.com>
Thu, 27 Oct 2022 09:45:15 +0000 (11:45 +0200)
committerMatthias Springer <springerm@google.com>
Thu, 27 Oct 2022 09:50:37 +0000 (11:50 +0200)
This change fixes a bug where a dialect is initialized multiple times. This triggers an assertion when the ops of the dialect are registered (`error: operation named ... is already registered`).

This bug can be triggered as follows:

1. Dialect A depends on dialect B (as per ADialect.td).

2. Somewhere there is an extension of dialect B that depends on dialect A (e.g., it defines external models create ops from dialect A). E.g.:
```
registry.addExtension(+[](MLIRContext *ctx, BDialect *dialect) {
  BDialectOp::attachInterface ...
  ctx->loadDialect<ADialect>();
});
```

3. When dialect A is loaded, its `initialize` function is called twice:

```
     ADialect::ADialect()
        |     |
        |     v
        |   ADialect::initialize()
        v
     getOrLoadDialect<BDialect>()
        |
        v
     (load extension of BDialect)
        |
        v
     ctx->loadDialect<ADialect>()  // user wrote this in the extension
        |
        v
     getOrLoadDialect<ADialect>()  // the dialect is not "fully" loaded yet
        |
        v
     ADialect::ADialect()
        |
        v
     ADialect::initialize()
```

An example of a dialect extension that depends on other dialects is `Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp`. That particular dialect extension does not trigger this bug. (It would trigger this bug if the SCF dialect would depend on the Tensor dialect.)

This change introduces a new dialect state: dialects that are currently being loaded. Same as dialects that were already fully loaded (and initialized), dialects that are in the process of being loaded are not loaded a second time.

Differential Revision: https://reviews.llvm.org/D136685

mlir/include/mlir/IR/MLIRContext.h
mlir/lib/IR/MLIRContext.cpp
mlir/tools/mlir-tblgen/DialectGen.cpp

index c162b00..b87dd27 100644 (file)
@@ -98,16 +98,22 @@ public:
         }));
   }
 
+  /// Return true if the given dialect is currently loading.
+  bool isDialectLoading(StringRef dialectNamespace);
+
   /// Load a dialect in the context.
   template <typename Dialect>
   void loadDialect() {
-    getOrLoadDialect<Dialect>();
+    // Do not load the dialect if it is currently loading. This can happen if a
+    // dialect initializer triggers loading the same dialect recursively.
+    if (!isDialectLoading(Dialect::getDialectNamespace()))
+      getOrLoadDialect<Dialect>();
   }
 
   /// Load a list dialects in the context.
   template <typename Dialect, typename OtherDialect, typename... MoreDialects>
   void loadDialect() {
-    getOrLoadDialect<Dialect>();
+    loadDialect<Dialect>();
     loadDialect<OtherDialect, MoreDialects...>();
   }
 
index 7ddcc2f..896938d 100644 (file)
@@ -429,9 +429,11 @@ MLIRContext::getOrLoadDialect(StringRef dialectNamespace, TypeID dialectID,
           ") while in a multi-threaded execution context (maybe "
           "the PassManager): this can indicate a "
           "missing `dependentDialects` in a pass for example.");
-#endif
-    std::unique_ptr<Dialect> &dialect =
-        impl.loadedDialects.insert({dialectNamespace, ctor()}).first->second;
+#endif // NDEBUG
+    // nullptr indicates that the dialect is currently being loaded.
+    impl.loadedDialects[dialectNamespace] = nullptr;
+    std::unique_ptr<Dialect> &dialect = impl.loadedDialects[dialectNamespace] =
+        ctor();
     assert(dialect && "dialect ctor failed");
 
     // Refresh all the identifiers dialect field, this catches cases where a
@@ -449,6 +451,14 @@ MLIRContext::getOrLoadDialect(StringRef dialectNamespace, TypeID dialectID,
     return dialect.get();
   }
 
+#ifndef NDEBUG
+  if (dialectIt->second == nullptr)
+    llvm::report_fatal_error(
+        "Loading (and getting) a dialect (" + dialectNamespace +
+        ") while the same dialect is still loading: use loadDialect instead "
+        "of getOrLoadDialect.");
+#endif // NDEBUG
+
   // Abort if dialect with namespace has already been registered.
   std::unique_ptr<Dialect> &dialect = dialectIt->second;
   if (dialect->getTypeID() != dialectID)
@@ -458,6 +468,12 @@ MLIRContext::getOrLoadDialect(StringRef dialectNamespace, TypeID dialectID,
   return dialect.get();
 }
 
+bool MLIRContext::isDialectLoading(StringRef dialectNamespace) {
+  auto it = getImpl().loadedDialects.find(dialectNamespace);
+  // nullptr indicates that the dialect is currently being loaded.
+  return it != getImpl().loadedDialects.end() && it->second == nullptr;
+}
+
 DynamicDialect *MLIRContext::getOrLoadDynamicDialect(
     StringRef dialectNamespace, function_ref<void(DynamicDialect *)> ctor) {
   auto &impl = getImpl();
index c7e42ac..1085407 100644 (file)
@@ -107,7 +107,7 @@ public:
 /// Registration for a single dependent dialect: to be inserted in the ctor
 /// above for each dependent dialect.
 const char *const dialectRegistrationTemplate = R"(
-    getContext()->getOrLoadDialect<{0}>();
+    getContext()->loadDialect<{0}>();
 )";
 
 /// The code block for the attribute parser/printer hooks.