Revisit Dialect registration: require and store a TypeID on dialects
authorMehdi Amini <joker.eph@gmail.com>
Fri, 7 Aug 2020 02:41:44 +0000 (02:41 +0000)
committerMehdi Amini <joker.eph@gmail.com>
Fri, 7 Aug 2020 15:57:08 +0000 (15:57 +0000)
This patch moves the registration to a method in the MLIRContext: getOrCreateDialect<ConcreteDialect>()

This method requires dialect to provide a static getDialectNamespace()
and store a TypeID on the Dialect itself, which allows to lazyily
create a dialect when not yet loaded in the context.
As a side effect, it means that duplicated registration of the same
dialect is not an issue anymore.

To limit the boilerplate, TableGen dialect generation is modified to
emit the constructor entirely and invoke separately a "init()" method
that the user implements.

Differential Revision: https://reviews.llvm.org/D85495

31 files changed:
mlir/examples/standalone/lib/Standalone/StandaloneDialect.cpp
mlir/examples/toy/Ch2/mlir/Dialect.cpp
mlir/examples/toy/Ch3/mlir/Dialect.cpp
mlir/examples/toy/Ch4/mlir/Dialect.cpp
mlir/examples/toy/Ch5/mlir/Dialect.cpp
mlir/examples/toy/Ch6/mlir/Dialect.cpp
mlir/examples/toy/Ch7/mlir/Dialect.cpp
mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
mlir/include/mlir/Dialect/SDBM/SDBMDialect.h
mlir/include/mlir/IR/Dialect.h
mlir/include/mlir/IR/MLIRContext.h
mlir/lib/Dialect/AVX512/IR/AVX512Dialect.cpp
mlir/lib/Dialect/Affine/IR/AffineOps.cpp
mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
mlir/lib/Dialect/LLVMIR/IR/LLVMAVX512Dialect.cpp
mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp
mlir/lib/Dialect/Linalg/IR/LinalgTypes.cpp
mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
mlir/lib/Dialect/Quant/IR/QuantOps.cpp
mlir/lib/Dialect/SCF/SCF.cpp
mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp
mlir/lib/Dialect/Shape/IR/Shape.cpp
mlir/lib/Dialect/StandardOps/IR/Ops.cpp
mlir/lib/Dialect/Vector/VectorOps.cpp
mlir/lib/IR/Dialect.cpp
mlir/lib/IR/MLIRContext.cpp
mlir/test/lib/Dialect/Test/TestDialect.cpp
mlir/tools/mlir-tblgen/DialectGen.cpp
mlir/unittests/IR/DialectTest.cpp

index 3a253f3..acdf88a 100644 (file)
@@ -16,8 +16,7 @@ using namespace mlir::standalone;
 // Standalone dialect.
 //===----------------------------------------------------------------------===//
 
-StandaloneDialect::StandaloneDialect(mlir::MLIRContext *context)
-    : Dialect(getDialectNamespace(), context) {
+void StandaloneDialect::initialize() {
   addOperations<
 #define GET_OP_LIST
 #include "Standalone/StandaloneOps.cpp.inc"
index 4be6bdc..86d6383 100644 (file)
@@ -26,7 +26,8 @@ using namespace mlir::toy;
 
 /// Dialect creation, the instance will be owned by the context. This is the
 /// point of registration of custom types and operations for the dialect.
-ToyDialect::ToyDialect(mlir::MLIRContext *ctx) : mlir::Dialect("toy", ctx) {
+ToyDialect::ToyDialect(mlir::MLIRContext *ctx)
+    : mlir::Dialect(getDialectNamespace(), ctx, TypeID::get<ToyDialect>()) {
   addOperations<
 #define GET_OP_LIST
 #include "toy/Ops.cpp.inc"
index 4be6bdc..86d6383 100644 (file)
@@ -26,7 +26,8 @@ using namespace mlir::toy;
 
 /// Dialect creation, the instance will be owned by the context. This is the
 /// point of registration of custom types and operations for the dialect.
-ToyDialect::ToyDialect(mlir::MLIRContext *ctx) : mlir::Dialect("toy", ctx) {
+ToyDialect::ToyDialect(mlir::MLIRContext *ctx)
+    : mlir::Dialect(getDialectNamespace(), ctx, TypeID::get<ToyDialect>()) {
   addOperations<
 #define GET_OP_LIST
 #include "toy/Ops.cpp.inc"
index 97c97b0..ca568a5 100644 (file)
@@ -75,7 +75,8 @@ struct ToyInlinerInterface : public DialectInlinerInterface {
 
 /// Dialect creation, the instance will be owned by the context. This is the
 /// point of registration of custom types and operations for the dialect.
-ToyDialect::ToyDialect(mlir::MLIRContext *ctx) : mlir::Dialect("toy", ctx) {
+ToyDialect::ToyDialect(mlir::MLIRContext *ctx)
+    : mlir::Dialect(getDialectNamespace(), ctx, TypeID::get<ToyDialect>()) {
   addOperations<
 #define GET_OP_LIST
 #include "toy/Ops.cpp.inc"
index 3f7dafa..d1a518e 100644 (file)
@@ -75,7 +75,8 @@ struct ToyInlinerInterface : public DialectInlinerInterface {
 
 /// Dialect creation, the instance will be owned by the context. This is the
 /// point of registration of custom types and operations for the dialect.
-ToyDialect::ToyDialect(mlir::MLIRContext *ctx) : mlir::Dialect("toy", ctx) {
+ToyDialect::ToyDialect(mlir::MLIRContext *ctx)
+    : mlir::Dialect(getDialectNamespace(), ctx, TypeID::get<ToyDialect>()) {
   addOperations<
 #define GET_OP_LIST
 #include "toy/Ops.cpp.inc"
index 3f7dafa..d1a518e 100644 (file)
@@ -75,7 +75,8 @@ struct ToyInlinerInterface : public DialectInlinerInterface {
 
 /// Dialect creation, the instance will be owned by the context. This is the
 /// point of registration of custom types and operations for the dialect.
-ToyDialect::ToyDialect(mlir::MLIRContext *ctx) : mlir::Dialect("toy", ctx) {
+ToyDialect::ToyDialect(mlir::MLIRContext *ctx)
+    : mlir::Dialect(getDialectNamespace(), ctx, TypeID::get<ToyDialect>()) {
   addOperations<
 #define GET_OP_LIST
 #include "toy/Ops.cpp.inc"
index fc7bf2a..e233a55 100644 (file)
@@ -76,7 +76,8 @@ struct ToyInlinerInterface : public DialectInlinerInterface {
 
 /// Dialect creation, the instance will be owned by the context. This is the
 /// point of registration of custom types and operations for the dialect.
-ToyDialect::ToyDialect(mlir::MLIRContext *ctx) : mlir::Dialect("toy", ctx) {
+ToyDialect::ToyDialect(mlir::MLIRContext *ctx)
+    : mlir::Dialect(getDialectNamespace(), ctx, TypeID::get<ToyDialect>()) {
   addOperations<
 #define GET_OP_LIST
 #include "toy/Ops.cpp.inc"
index 9b708fe..d21f5bc 100644 (file)
@@ -27,7 +27,9 @@ def LLVM_Dialect : Dialect {
   private:
     friend LLVMType;
 
-    std::unique_ptr<detail::LLVMDialectImpl> impl;
+    // This can't be a unique_ptr because the ctor is generated inline
+    // in the class definition at the moment.
+    detail::LLVMDialectImpl *impl;
   }];
 }
 
index 0993b43..901ada9 100644 (file)
@@ -17,7 +17,8 @@ class MLIRContext;
 
 class SDBMDialect : public Dialect {
 public:
-  SDBMDialect(MLIRContext *context) : Dialect(getDialectNamespace(), context) {}
+  SDBMDialect(MLIRContext *context)
+      : Dialect(getDialectNamespace(), context, TypeID::get<SDBMDialect>()) {}
 
   /// Since there are no other virtual methods in this derived class, override
   /// the destructor so that key methods get defined in the corresponding
index 043fde9..bd9f3c1 100644 (file)
@@ -14,6 +14,7 @@
 #define MLIR_IR_DIALECT_H
 
 #include "mlir/IR/OperationSupport.h"
+#include "mlir/Support/TypeID.h"
 
 namespace mlir {
 class DialectAsmParser;
@@ -49,6 +50,9 @@ public:
 
   StringRef getNamespace() const { return name; }
 
+  /// Returns the unique identifier that corresponds to this dialect.
+  TypeID getTypeID() const { return dialectID; }
+
   /// Returns true if this dialect allows for unregistered operations, i.e.
   /// operations prefixed with the dialect namespace but not registered with
   /// addOperation.
@@ -177,7 +181,7 @@ protected:
   ///       with the namespace followed by '.'.
   /// Example:
   ///       - "tf" for the TensorFlow ops like "tf.add".
-  Dialect(StringRef name, MLIRContext *context);
+  Dialect(StringRef name, MLIRContext *context, TypeID id);
 
   /// This method is used by derived classes to add their operations to the set.
   ///
@@ -223,13 +227,13 @@ private:
   Dialect(const Dialect &) = delete;
   void operator=(Dialect &) = delete;
 
-  /// Register this dialect object with the specified context.  The context
-  /// takes ownership of the heap allocated dialect.
-  void registerDialect(MLIRContext *context);
-
   /// The namespace of this dialect.
   StringRef name;
 
+  /// The unique identifier of the derived Op class, this is used in the context
+  /// to allow registering multiple times the same dialect.
+  TypeID dialectID;
+
   /// This is the context that owns this Dialect object.
   MLIRContext *context;
 
@@ -255,7 +259,9 @@ private:
                            const DialectAllocatorFunction &function);
   template <typename ConcreteDialect>
   friend void registerDialect();
+  friend class MLIRContext;
 };
+
 /// Registers all dialects and hooks from the global registries with the
 /// specified MLIRContext.
 /// Note: This method is not thread-safe.
@@ -265,12 +271,9 @@ void registerAllDialects(MLIRContext *context);
 /// global registry by calling registerDialect<MyDialect>();
 /// Note: This method is not thread-safe.
 template <typename ConcreteDialect> void registerDialect() {
-  Dialect::registerDialectAllocator(TypeID::get<ConcreteDialect>(),
-                                    [](MLIRContext *ctx) {
-                                      // Just allocate the dialect, the context
-                                      // takes ownership of it.
-                                      new ConcreteDialect(ctx);
-                                    });
+  Dialect::registerDialectAllocator(
+      TypeID::get<ConcreteDialect>(),
+      [](MLIRContext *ctx) { ctx->getOrCreateDialect<ConcreteDialect>(); });
 }
 
 /// DialectRegistration provides a global initializer that registers a Dialect
@@ -291,7 +294,7 @@ namespace llvm {
 template <typename T>
 struct isa_impl<T, ::mlir::Dialect> {
   static inline bool doit(const ::mlir::Dialect &dialect) {
-    return T::getDialectNamespace() == dialect.getNamespace();
+    return mlir::TypeID::get<T>() == dialect.getTypeID();
   }
 };
 } // namespace llvm
index 8e75bb6..0192a8a 100644 (file)
@@ -10,6 +10,7 @@
 #define MLIR_IR_MLIRCONTEXT_H
 
 #include "mlir/Support/LLVM.h"
+#include "mlir/Support/TypeID.h"
 #include <functional>
 #include <memory>
 #include <vector>
@@ -49,6 +50,18 @@ public:
     return static_cast<T *>(getRegisteredDialect(T::getDialectNamespace()));
   }
 
+  /// Get (or create) a dialect for the given derived dialect type. The derived
+  /// type must provide a static 'getDialectNamespace' method.
+  template <typename T>
+  T *getOrCreateDialect() {
+    return static_cast<T *>(getOrCreateDialect(
+        T::getDialectNamespace(), TypeID::get<T>(), [this]() {
+          std::unique_ptr<T> dialect(new T(this));
+          dialect->dialectID = TypeID::get<T>();
+          return dialect;
+        }));
+  }
+
   /// Return true if we allow to create operation for unregistered dialects.
   bool allowsUnregisteredDialects();
 
@@ -109,6 +122,12 @@ public:
 private:
   const std::unique_ptr<MLIRContextImpl> impl;
 
+  /// Get a dialect for the provided namespace and TypeID: abort the program if
+  /// a dialect exist for this namespace with different TypeID. Returns a
+  /// pointer to the dialect owned by the context.
+  Dialect *getOrCreateDialect(StringRef dialectNamespace, TypeID dialectID,
+                              function_ref<std::unique_ptr<Dialect>()> ctor);
+
   MLIRContext(const MLIRContext &) = delete;
   void operator=(const MLIRContext &) = delete;
 };
index aade931..3595970 100644 (file)
@@ -18,8 +18,7 @@
 
 using namespace mlir;
 
-avx512::AVX512Dialect::AVX512Dialect(MLIRContext *context)
-    : Dialect(getDialectNamespace(), context) {
+void avx512::AVX512Dialect::initialize() {
   addOperations<
 #define GET_OP_LIST
 #include "mlir/Dialect/AVX512/AVX512.cpp.inc"
index 5c9dc8f..fa98f63 100644 (file)
@@ -68,8 +68,7 @@ struct AffineInlinerInterface : public DialectInlinerInterface {
 // AffineDialect
 //===----------------------------------------------------------------------===//
 
-AffineDialect::AffineDialect(MLIRContext *context)
-    : Dialect(getDialectNamespace(), context) {
+void AffineDialect::initialize() {
   addOperations<AffineDmaStartOp, AffineDmaWaitOp,
 #define GET_OP_LIST
 #include "mlir/Dialect/Affine/IR/AffineOps.cpp.inc"
index dd8200d..58f9480 100644 (file)
@@ -35,8 +35,7 @@ bool GPUDialect::isKernel(Operation *op) {
   return static_cast<bool>(isKernelAttr);
 }
 
-GPUDialect::GPUDialect(MLIRContext *context)
-    : Dialect(getDialectNamespace(), context) {
+void GPUDialect::initialize() {
   addOperations<
 #define GET_OP_LIST
 #include "mlir/Dialect/GPU/GPUOps.cpp.inc"
index bde8114..9f7e66b 100644 (file)
@@ -20,8 +20,7 @@
 
 using namespace mlir;
 
-LLVM::LLVMAVX512Dialect::LLVMAVX512Dialect(MLIRContext *context)
-    : Dialect(getDialectNamespace(), context) {
+void LLVM::LLVMAVX512Dialect::initialize() {
   addOperations<
 #define GET_OP_LIST
 #include "mlir/Dialect/LLVMIR/LLVMAVX512.cpp.inc"
index e03d025..515e120 100644 (file)
@@ -1683,9 +1683,8 @@ struct LLVMDialectImpl {
 } // end namespace LLVM
 } // end namespace mlir
 
-LLVMDialect::LLVMDialect(MLIRContext *context)
-    : Dialect(getDialectNamespace(), context),
-      impl(new detail::LLVMDialectImpl()) {
+void LLVMDialect::initialize() {
+  impl = new detail::LLVMDialectImpl();
   // clang-format off
   addTypes<LLVMVoidType,
            LLVMHalfType,
@@ -1716,7 +1715,7 @@ LLVMDialect::LLVMDialect(MLIRContext *context)
   allowUnknownOperations();
 }
 
-LLVMDialect::~LLVMDialect() {}
+LLVMDialect::~LLVMDialect() { delete impl; }
 
 #define GET_OP_CLASSES
 #include "mlir/Dialect/LLVMIR/LLVMOps.cpp.inc"
index 9a09488..cc809b5 100644 (file)
@@ -136,7 +136,7 @@ static LogicalResult verify(MmaOp op) {
 //===----------------------------------------------------------------------===//
 
 // TODO: This should be the llvm.nvvm dialect once this is supported.
-NVVMDialect::NVVMDialect(MLIRContext *context) : Dialect("nvvm", context) {
+void NVVMDialect::initialize() {
   addOperations<
 #define GET_OP_LIST
 #include "mlir/Dialect/LLVMIR/NVVMOps.cpp.inc"
index 47089b9..70c3558 100644 (file)
@@ -81,7 +81,7 @@ static ParseResult parseROCDLMubufStoreOp(OpAsmParser &parser,
 //===----------------------------------------------------------------------===//
 
 // TODO: This should be the llvm.rocdl dialect once this is supported.
-ROCDLDialect::ROCDLDialect(MLIRContext *context) : Dialect("rocdl", context) {
+void ROCDLDialect::initialize() {
   addOperations<
 #define GET_OP_LIST
 #include "mlir/Dialect/LLVMIR/ROCDLOps.cpp.inc"
index a55d467..50924f7 100644 (file)
@@ -24,8 +24,7 @@
 using namespace mlir;
 using namespace mlir::linalg;
 
-mlir::linalg::LinalgDialect::LinalgDialect(MLIRContext *context)
-    : Dialect(getDialectNamespace(), context) {
+void mlir::linalg::LinalgDialect::initialize() {
   addTypes<RangeType>();
   addOperations<
 #define GET_OP_LIST
index 4467d33..9159e87 100644 (file)
@@ -26,8 +26,7 @@
 using namespace mlir;
 using namespace mlir::omp;
 
-OpenMPDialect::OpenMPDialect(MLIRContext *context)
-    : Dialect(getDialectNamespace(), context) {
+void OpenMPDialect::initialize() {
   addOperations<
 #define GET_OP_LIST
 #include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
index 07f881f..e7df59a 100644 (file)
@@ -23,8 +23,7 @@ using namespace mlir;
 using namespace mlir::quant;
 using namespace mlir::quant::detail;
 
-QuantizationDialect::QuantizationDialect(MLIRContext *context)
-    : Dialect(/*name=*/"quant", context) {
+void QuantizationDialect::initialize() {
   addTypes<AnyQuantizedType, UniformQuantizedType,
            UniformQuantizedPerAxisType>();
   addOperations<
index d0958e5..6f3f1e4 100644 (file)
@@ -53,8 +53,7 @@ struct SCFInlinerInterface : public DialectInlinerInterface {
 // SCFDialect
 //===----------------------------------------------------------------------===//
 
-SCFDialect::SCFDialect(MLIRContext *context)
-    : Dialect(getDialectNamespace(), context) {
+void SCFDialect::initialize() {
   addOperations<
 #define GET_OP_LIST
 #include "mlir/Dialect/SCF/SCFOps.cpp.inc"
index a2659d6..01c3057 100644 (file)
@@ -112,8 +112,7 @@ struct SPIRVInlinerInterface : public DialectInlinerInterface {
 // SPIR-V Dialect
 //===----------------------------------------------------------------------===//
 
-SPIRVDialect::SPIRVDialect(MLIRContext *context)
-    : Dialect(getDialectNamespace(), context) {
+void SPIRVDialect::initialize() {
   addTypes<ArrayType, CooperativeMatrixNVType, ImageType, MatrixType,
            PointerType, RuntimeArrayType, StructType>();
 
index be4c3c7..47c592e 100644 (file)
@@ -59,8 +59,7 @@ static LogicalResult verifyShapeOrExtentTensorOp(Operation *op) {
   return success();
 }
 
-ShapeDialect::ShapeDialect(MLIRContext *context)
-    : Dialect(getDialectNamespace(), context) {
+void ShapeDialect::initialize() {
   addOperations<
 #define GET_OP_LIST
 #include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc"
index 74e1e20..a19d579 100644 (file)
@@ -145,8 +145,7 @@ static LogicalResult verifyCastOp(T op) {
   return success();
 }
 
-StandardOpsDialect::StandardOpsDialect(MLIRContext *context)
-    : Dialect(getDialectNamespace(), context) {
+void StandardOpsDialect::initialize() {
   addOperations<DmaStartOp, DmaWaitOp,
 #define GET_OP_LIST
 #include "mlir/Dialect/StandardOps/IR/Ops.cpp.inc"
index e04091e..7c715bf 100644 (file)
@@ -34,8 +34,7 @@ using namespace mlir::vector;
 // VectorDialect
 //===----------------------------------------------------------------------===//
 
-VectorDialect::VectorDialect(MLIRContext *context)
-    : Dialect(getDialectNamespace(), context) {
+void VectorDialect::initialize() {
   addOperations<
 #define GET_OP_LIST
 #include "mlir/Dialect/Vector/VectorOps.cpp.inc"
index 501cdda..02448b3 100644 (file)
@@ -66,10 +66,9 @@ void mlir::registerAllDialects(MLIRContext *context) {
 // Dialect
 //===----------------------------------------------------------------------===//
 
-Dialect::Dialect(StringRef name, MLIRContext *context)
-    : name(name), context(context) {
+Dialect::Dialect(StringRef name, MLIRContext *context, TypeID id)
+    : name(name), dialectID(id), context(context) {
   assert(isValidNamespace(name) && "invalid dialect namespace");
-  registerDialect(context);
 }
 
 Dialect::~Dialect() {}
index a4e833c..32ee60f 100644 (file)
@@ -85,7 +85,8 @@ namespace {
 /// A builtin dialect to define types/etc that are necessary for the validity of
 /// the IR.
 struct BuiltinDialect : public Dialect {
-  BuiltinDialect(MLIRContext *context) : Dialect(/*name=*/"", context) {
+  BuiltinDialect(MLIRContext *context)
+      : Dialect(/*name=*/"", context, TypeID::get<BuiltinDialect>()) {
     addAttributes<AffineMapAttr, ArrayAttr, DenseIntOrFPElementsAttr,
                   DenseStringElementsAttr, DictionaryAttr, FloatAttr,
                   SymbolRefAttr, IntegerAttr, IntegerSetAttr, OpaqueAttr,
@@ -102,6 +103,7 @@ struct BuiltinDialect : public Dialect {
     // have been fully decoupled from the core.
     addOperations<FuncOp, ModuleOp, ModuleTerminatorOp>();
   }
+  static StringRef getDialectNamespace() { return ""; }
 };
 } // end anonymous namespace.
 
@@ -349,7 +351,7 @@ MLIRContext::MLIRContext() : impl(new MLIRContextImpl()) {
   }
 
   // Register dialects with this context.
-  new BuiltinDialect(this);
+  getOrCreateDialect<BuiltinDialect>();
   registerAllDialects(this);
 
   // Initialize several common attributes and types to avoid the need to lock
@@ -446,25 +448,33 @@ Dialect *MLIRContext::getRegisteredDialect(StringRef name) {
              : nullptr;
 }
 
-/// Register this dialect object with the specified context.  The context
-/// takes ownership of the heap allocated dialect.
-void Dialect::registerDialect(MLIRContext *context) {
-  auto &impl = context->getImpl();
-  std::unique_ptr<Dialect> dialect(this);
-
+/// Get a dialect for the provided namespace and TypeID: abort the program if a
+/// dialect exist for this namespace with different TypeID. Returns a pointer to
+/// the dialect owned by the context.
+Dialect *
+MLIRContext::getOrCreateDialect(StringRef dialectNamespace, TypeID dialectID,
+                                function_ref<std::unique_ptr<Dialect>()> ctor) {
+  auto &impl = getImpl();
   // Get the correct insertion position sorted by namespace.
-  auto insertPt = llvm::lower_bound(
-      impl.dialects, dialect, [](const auto &lhs, const auto &rhs) {
-        return lhs->getNamespace() < rhs->getNamespace();
-      });
+  auto insertPt =
+      llvm::lower_bound(impl.dialects, nullptr,
+                        [&](const std::unique_ptr<Dialect> &lhs,
+                            const std::unique_ptr<Dialect> &rhs) {
+                          if (!lhs)
+                            return dialectNamespace < rhs->getNamespace();
+                          return lhs->getNamespace() < dialectNamespace;
+                        });
 
   // Abort if dialect with namespace has already been registered.
   if (insertPt != impl.dialects.end() &&
-      (*insertPt)->getNamespace() == getNamespace()) {
-    llvm::report_fatal_error("a dialect with namespace '" + getNamespace() +
+      (*insertPt)->getNamespace() == dialectNamespace) {
+    if ((*insertPt)->getTypeID() == dialectID)
+      return insertPt->get();
+    llvm::report_fatal_error("a dialect with namespace '" + dialectNamespace +
                              "' has already been registered");
   }
-  impl.dialects.insert(insertPt, std::move(dialect));
+  auto it = impl.dialects.insert(insertPt, ctor());
+  return &**it;
 }
 
 bool MLIRContext::allowsUnregisteredDialects() {
index cdbf974..c9cfdc5 100644 (file)
@@ -130,8 +130,7 @@ struct TestInlinerInterface : public DialectInlinerInterface {
 // TestDialect
 //===----------------------------------------------------------------------===//
 
-TestDialect::TestDialect(MLIRContext *context)
-    : Dialect(getDialectNamespace(), context) {
+void TestDialect::initialize() {
   addOperations<
 #define GET_OP_LIST
 #include "TestOps.cpp.inc"
index 4a9109d..13421c4 100644 (file)
@@ -63,8 +63,14 @@ filterForDialect(ArrayRef<llvm::Record *> records, Dialect &dialect) {
 /// {1}: The dialect namespace.
 static const char *const dialectDeclBeginStr = R"(
 class {0} : public ::mlir::Dialect {
+  explicit {0}(::mlir::MLIRContext *context)
+    : ::mlir::Dialect(getDialectNamespace(), context,
+      ::mlir::TypeID::get<{0}>()) {{
+    initialize();
+  }
+  void initialize();
+  friend class ::mlir::MLIRContext;
 public:
-  explicit {0}(::mlir::MLIRContext *context);
   static ::llvm::StringRef getDialectNamespace() { return "{1}"; }
 )";
 
index 49d2e27..bc389ce 100644 (file)
@@ -14,7 +14,15 @@ using namespace mlir::detail;
 
 namespace {
 struct TestDialect : public Dialect {
-  TestDialect(MLIRContext *context) : Dialect(/*name=*/"test", context) {}
+  static StringRef getDialectNamespace() { return "test"; };
+  TestDialect(MLIRContext *context)
+      : Dialect(getDialectNamespace(), context, TypeID::get<TestDialect>()) {}
+};
+struct AnotherTestDialect : public Dialect {
+  static StringRef getDialectNamespace() { return "test"; };
+  AnotherTestDialect(MLIRContext *context)
+      : Dialect(getDialectNamespace(), context,
+                TypeID::get<AnotherTestDialect>()) {}
 };
 
 TEST(DialectDeathTest, MultipleDialectsWithSameNamespace) {
@@ -22,8 +30,8 @@ TEST(DialectDeathTest, MultipleDialectsWithSameNamespace) {
 
   // Registering a dialect with the same namespace twice should result in a
   // failure.
-  new TestDialect(&context);
-  ASSERT_DEATH(new TestDialect(&context), "");
+  context.getOrCreateDialect<TestDialect>();
+  ASSERT_DEATH(context.getOrCreateDialect<AnotherTestDialect>(), "");
 }
 
 } // end namespace