From 4433e52e69b1ce19b1d3c756e6d3262170ad4a30 Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Thu, 27 Oct 2022 11:45:15 +0200 Subject: [PATCH] [mlir] Fix circular dialect initialization This change fixes a bug where a dialect is initialized multiple times. This triggers an assertion when the ops of the dialect are registered (`error: operation named ... is already registered`). This bug can be triggered as follows: 1. Dialect A depends on dialect B (as per ADialect.td). 2. Somewhere there is an extension of dialect B that depends on dialect A (e.g., it defines external models create ops from dialect A). E.g.: ``` registry.addExtension(+[](MLIRContext *ctx, BDialect *dialect) { BDialectOp::attachInterface ... ctx->loadDialect(); }); ``` 3. When dialect A is loaded, its `initialize` function is called twice: ``` ADialect::ADialect() | | | v | ADialect::initialize() v getOrLoadDialect() | v (load extension of BDialect) | v ctx->loadDialect() // user wrote this in the extension | v getOrLoadDialect() // the dialect is not "fully" loaded yet | v ADialect::ADialect() | v ADialect::initialize() ``` An example of a dialect extension that depends on other dialects is `Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp`. That particular dialect extension does not trigger this bug. (It would trigger this bug if the SCF dialect would depend on the Tensor dialect.) This change introduces a new dialect state: dialects that are currently being loaded. Same as dialects that were already fully loaded (and initialized), dialects that are in the process of being loaded are not loaded a second time. Differential Revision: https://reviews.llvm.org/D136685 --- mlir/include/mlir/IR/MLIRContext.h | 10 ++++++++-- mlir/lib/IR/MLIRContext.cpp | 22 +++++++++++++++++++--- mlir/tools/mlir-tblgen/DialectGen.cpp | 2 +- 3 files changed, 28 insertions(+), 6 deletions(-) diff --git a/mlir/include/mlir/IR/MLIRContext.h b/mlir/include/mlir/IR/MLIRContext.h index c162b00..b87dd27 100644 --- a/mlir/include/mlir/IR/MLIRContext.h +++ b/mlir/include/mlir/IR/MLIRContext.h @@ -98,16 +98,22 @@ public: })); } + /// Return true if the given dialect is currently loading. + bool isDialectLoading(StringRef dialectNamespace); + /// Load a dialect in the context. template void loadDialect() { - getOrLoadDialect(); + // Do not load the dialect if it is currently loading. This can happen if a + // dialect initializer triggers loading the same dialect recursively. + if (!isDialectLoading(Dialect::getDialectNamespace())) + getOrLoadDialect(); } /// Load a list dialects in the context. template void loadDialect() { - getOrLoadDialect(); + loadDialect(); loadDialect(); } diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp index 7ddcc2f..896938d 100644 --- a/mlir/lib/IR/MLIRContext.cpp +++ b/mlir/lib/IR/MLIRContext.cpp @@ -429,9 +429,11 @@ MLIRContext::getOrLoadDialect(StringRef dialectNamespace, TypeID dialectID, ") while in a multi-threaded execution context (maybe " "the PassManager): this can indicate a " "missing `dependentDialects` in a pass for example."); -#endif - std::unique_ptr &dialect = - impl.loadedDialects.insert({dialectNamespace, ctor()}).first->second; +#endif // NDEBUG + // nullptr indicates that the dialect is currently being loaded. + impl.loadedDialects[dialectNamespace] = nullptr; + std::unique_ptr &dialect = impl.loadedDialects[dialectNamespace] = + ctor(); assert(dialect && "dialect ctor failed"); // Refresh all the identifiers dialect field, this catches cases where a @@ -449,6 +451,14 @@ MLIRContext::getOrLoadDialect(StringRef dialectNamespace, TypeID dialectID, return dialect.get(); } +#ifndef NDEBUG + if (dialectIt->second == nullptr) + llvm::report_fatal_error( + "Loading (and getting) a dialect (" + dialectNamespace + + ") while the same dialect is still loading: use loadDialect instead " + "of getOrLoadDialect."); +#endif // NDEBUG + // Abort if dialect with namespace has already been registered. std::unique_ptr &dialect = dialectIt->second; if (dialect->getTypeID() != dialectID) @@ -458,6 +468,12 @@ MLIRContext::getOrLoadDialect(StringRef dialectNamespace, TypeID dialectID, return dialect.get(); } +bool MLIRContext::isDialectLoading(StringRef dialectNamespace) { + auto it = getImpl().loadedDialects.find(dialectNamespace); + // nullptr indicates that the dialect is currently being loaded. + return it != getImpl().loadedDialects.end() && it->second == nullptr; +} + DynamicDialect *MLIRContext::getOrLoadDynamicDialect( StringRef dialectNamespace, function_ref ctor) { auto &impl = getImpl(); diff --git a/mlir/tools/mlir-tblgen/DialectGen.cpp b/mlir/tools/mlir-tblgen/DialectGen.cpp index c7e42ac..1085407 100644 --- a/mlir/tools/mlir-tblgen/DialectGen.cpp +++ b/mlir/tools/mlir-tblgen/DialectGen.cpp @@ -107,7 +107,7 @@ public: /// Registration for a single dependent dialect: to be inserted in the ctor /// above for each dependent dialect. const char *const dialectRegistrationTemplate = R"( - getContext()->getOrLoadDialect<{0}>(); + getContext()->loadDialect<{0}>(); )"; /// The code block for the attribute parser/printer hooks. -- 2.7.4