This patch moves the registration to a method in the MLIRContext: getOrCreateDialect<ConcreteDialect>()
This method requires dialect to provide a static getDialectNamespace()
and store a TypeID on the Dialect itself, which allows to lazyily
create a dialect when not yet loaded in the context.
As a side effect, it means that duplicated registration of the same
dialect is not an issue anymore.
To limit the boilerplate, TableGen dialect generation is modified to
emit the constructor entirely and invoke separately a "init()" method
that the user implements.
Differential Revision: https://reviews.llvm.org/D85495
// Standalone dialect.
//===----------------------------------------------------------------------===//
-StandaloneDialect::StandaloneDialect(mlir::MLIRContext *context)
- : Dialect(getDialectNamespace(), context) {
+void StandaloneDialect::initialize() {
addOperations<
#define GET_OP_LIST
#include "Standalone/StandaloneOps.cpp.inc"
/// Dialect creation, the instance will be owned by the context. This is the
/// point of registration of custom types and operations for the dialect.
-ToyDialect::ToyDialect(mlir::MLIRContext *ctx) : mlir::Dialect("toy", ctx) {
+ToyDialect::ToyDialect(mlir::MLIRContext *ctx)
+ : mlir::Dialect(getDialectNamespace(), ctx, TypeID::get<ToyDialect>()) {
addOperations<
#define GET_OP_LIST
#include "toy/Ops.cpp.inc"
/// Dialect creation, the instance will be owned by the context. This is the
/// point of registration of custom types and operations for the dialect.
-ToyDialect::ToyDialect(mlir::MLIRContext *ctx) : mlir::Dialect("toy", ctx) {
+ToyDialect::ToyDialect(mlir::MLIRContext *ctx)
+ : mlir::Dialect(getDialectNamespace(), ctx, TypeID::get<ToyDialect>()) {
addOperations<
#define GET_OP_LIST
#include "toy/Ops.cpp.inc"
/// Dialect creation, the instance will be owned by the context. This is the
/// point of registration of custom types and operations for the dialect.
-ToyDialect::ToyDialect(mlir::MLIRContext *ctx) : mlir::Dialect("toy", ctx) {
+ToyDialect::ToyDialect(mlir::MLIRContext *ctx)
+ : mlir::Dialect(getDialectNamespace(), ctx, TypeID::get<ToyDialect>()) {
addOperations<
#define GET_OP_LIST
#include "toy/Ops.cpp.inc"
/// Dialect creation, the instance will be owned by the context. This is the
/// point of registration of custom types and operations for the dialect.
-ToyDialect::ToyDialect(mlir::MLIRContext *ctx) : mlir::Dialect("toy", ctx) {
+ToyDialect::ToyDialect(mlir::MLIRContext *ctx)
+ : mlir::Dialect(getDialectNamespace(), ctx, TypeID::get<ToyDialect>()) {
addOperations<
#define GET_OP_LIST
#include "toy/Ops.cpp.inc"
/// Dialect creation, the instance will be owned by the context. This is the
/// point of registration of custom types and operations for the dialect.
-ToyDialect::ToyDialect(mlir::MLIRContext *ctx) : mlir::Dialect("toy", ctx) {
+ToyDialect::ToyDialect(mlir::MLIRContext *ctx)
+ : mlir::Dialect(getDialectNamespace(), ctx, TypeID::get<ToyDialect>()) {
addOperations<
#define GET_OP_LIST
#include "toy/Ops.cpp.inc"
/// Dialect creation, the instance will be owned by the context. This is the
/// point of registration of custom types and operations for the dialect.
-ToyDialect::ToyDialect(mlir::MLIRContext *ctx) : mlir::Dialect("toy", ctx) {
+ToyDialect::ToyDialect(mlir::MLIRContext *ctx)
+ : mlir::Dialect(getDialectNamespace(), ctx, TypeID::get<ToyDialect>()) {
addOperations<
#define GET_OP_LIST
#include "toy/Ops.cpp.inc"
private:
friend LLVMType;
- std::unique_ptr<detail::LLVMDialectImpl> impl;
+ // This can't be a unique_ptr because the ctor is generated inline
+ // in the class definition at the moment.
+ detail::LLVMDialectImpl *impl;
}];
}
class SDBMDialect : public Dialect {
public:
- SDBMDialect(MLIRContext *context) : Dialect(getDialectNamespace(), context) {}
+ SDBMDialect(MLIRContext *context)
+ : Dialect(getDialectNamespace(), context, TypeID::get<SDBMDialect>()) {}
/// Since there are no other virtual methods in this derived class, override
/// the destructor so that key methods get defined in the corresponding
#define MLIR_IR_DIALECT_H
#include "mlir/IR/OperationSupport.h"
+#include "mlir/Support/TypeID.h"
namespace mlir {
class DialectAsmParser;
StringRef getNamespace() const { return name; }
+ /// Returns the unique identifier that corresponds to this dialect.
+ TypeID getTypeID() const { return dialectID; }
+
/// Returns true if this dialect allows for unregistered operations, i.e.
/// operations prefixed with the dialect namespace but not registered with
/// addOperation.
/// with the namespace followed by '.'.
/// Example:
/// - "tf" for the TensorFlow ops like "tf.add".
- Dialect(StringRef name, MLIRContext *context);
+ Dialect(StringRef name, MLIRContext *context, TypeID id);
/// This method is used by derived classes to add their operations to the set.
///
Dialect(const Dialect &) = delete;
void operator=(Dialect &) = delete;
- /// Register this dialect object with the specified context. The context
- /// takes ownership of the heap allocated dialect.
- void registerDialect(MLIRContext *context);
-
/// The namespace of this dialect.
StringRef name;
+ /// The unique identifier of the derived Op class, this is used in the context
+ /// to allow registering multiple times the same dialect.
+ TypeID dialectID;
+
/// This is the context that owns this Dialect object.
MLIRContext *context;
const DialectAllocatorFunction &function);
template <typename ConcreteDialect>
friend void registerDialect();
+ friend class MLIRContext;
};
+
/// Registers all dialects and hooks from the global registries with the
/// specified MLIRContext.
/// Note: This method is not thread-safe.
/// global registry by calling registerDialect<MyDialect>();
/// Note: This method is not thread-safe.
template <typename ConcreteDialect> void registerDialect() {
- Dialect::registerDialectAllocator(TypeID::get<ConcreteDialect>(),
- [](MLIRContext *ctx) {
- // Just allocate the dialect, the context
- // takes ownership of it.
- new ConcreteDialect(ctx);
- });
+ Dialect::registerDialectAllocator(
+ TypeID::get<ConcreteDialect>(),
+ [](MLIRContext *ctx) { ctx->getOrCreateDialect<ConcreteDialect>(); });
}
/// DialectRegistration provides a global initializer that registers a Dialect
template <typename T>
struct isa_impl<T, ::mlir::Dialect> {
static inline bool doit(const ::mlir::Dialect &dialect) {
- return T::getDialectNamespace() == dialect.getNamespace();
+ return mlir::TypeID::get<T>() == dialect.getTypeID();
}
};
} // namespace llvm
#define MLIR_IR_MLIRCONTEXT_H
#include "mlir/Support/LLVM.h"
+#include "mlir/Support/TypeID.h"
#include <functional>
#include <memory>
#include <vector>
return static_cast<T *>(getRegisteredDialect(T::getDialectNamespace()));
}
+ /// Get (or create) a dialect for the given derived dialect type. The derived
+ /// type must provide a static 'getDialectNamespace' method.
+ template <typename T>
+ T *getOrCreateDialect() {
+ return static_cast<T *>(getOrCreateDialect(
+ T::getDialectNamespace(), TypeID::get<T>(), [this]() {
+ std::unique_ptr<T> dialect(new T(this));
+ dialect->dialectID = TypeID::get<T>();
+ return dialect;
+ }));
+ }
+
/// Return true if we allow to create operation for unregistered dialects.
bool allowsUnregisteredDialects();
private:
const std::unique_ptr<MLIRContextImpl> impl;
+ /// Get a dialect for the provided namespace and TypeID: abort the program if
+ /// a dialect exist for this namespace with different TypeID. Returns a
+ /// pointer to the dialect owned by the context.
+ Dialect *getOrCreateDialect(StringRef dialectNamespace, TypeID dialectID,
+ function_ref<std::unique_ptr<Dialect>()> ctor);
+
MLIRContext(const MLIRContext &) = delete;
void operator=(const MLIRContext &) = delete;
};
using namespace mlir;
-avx512::AVX512Dialect::AVX512Dialect(MLIRContext *context)
- : Dialect(getDialectNamespace(), context) {
+void avx512::AVX512Dialect::initialize() {
addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/AVX512/AVX512.cpp.inc"
// AffineDialect
//===----------------------------------------------------------------------===//
-AffineDialect::AffineDialect(MLIRContext *context)
- : Dialect(getDialectNamespace(), context) {
+void AffineDialect::initialize() {
addOperations<AffineDmaStartOp, AffineDmaWaitOp,
#define GET_OP_LIST
#include "mlir/Dialect/Affine/IR/AffineOps.cpp.inc"
return static_cast<bool>(isKernelAttr);
}
-GPUDialect::GPUDialect(MLIRContext *context)
- : Dialect(getDialectNamespace(), context) {
+void GPUDialect::initialize() {
addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/GPU/GPUOps.cpp.inc"
using namespace mlir;
-LLVM::LLVMAVX512Dialect::LLVMAVX512Dialect(MLIRContext *context)
- : Dialect(getDialectNamespace(), context) {
+void LLVM::LLVMAVX512Dialect::initialize() {
addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/LLVMIR/LLVMAVX512.cpp.inc"
} // end namespace LLVM
} // end namespace mlir
-LLVMDialect::LLVMDialect(MLIRContext *context)
- : Dialect(getDialectNamespace(), context),
- impl(new detail::LLVMDialectImpl()) {
+void LLVMDialect::initialize() {
+ impl = new detail::LLVMDialectImpl();
// clang-format off
addTypes<LLVMVoidType,
LLVMHalfType,
allowUnknownOperations();
}
-LLVMDialect::~LLVMDialect() {}
+LLVMDialect::~LLVMDialect() { delete impl; }
#define GET_OP_CLASSES
#include "mlir/Dialect/LLVMIR/LLVMOps.cpp.inc"
//===----------------------------------------------------------------------===//
// TODO: This should be the llvm.nvvm dialect once this is supported.
-NVVMDialect::NVVMDialect(MLIRContext *context) : Dialect("nvvm", context) {
+void NVVMDialect::initialize() {
addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/LLVMIR/NVVMOps.cpp.inc"
//===----------------------------------------------------------------------===//
// TODO: This should be the llvm.rocdl dialect once this is supported.
-ROCDLDialect::ROCDLDialect(MLIRContext *context) : Dialect("rocdl", context) {
+void ROCDLDialect::initialize() {
addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/LLVMIR/ROCDLOps.cpp.inc"
using namespace mlir;
using namespace mlir::linalg;
-mlir::linalg::LinalgDialect::LinalgDialect(MLIRContext *context)
- : Dialect(getDialectNamespace(), context) {
+void mlir::linalg::LinalgDialect::initialize() {
addTypes<RangeType>();
addOperations<
#define GET_OP_LIST
using namespace mlir;
using namespace mlir::omp;
-OpenMPDialect::OpenMPDialect(MLIRContext *context)
- : Dialect(getDialectNamespace(), context) {
+void OpenMPDialect::initialize() {
addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
using namespace mlir::quant;
using namespace mlir::quant::detail;
-QuantizationDialect::QuantizationDialect(MLIRContext *context)
- : Dialect(/*name=*/"quant", context) {
+void QuantizationDialect::initialize() {
addTypes<AnyQuantizedType, UniformQuantizedType,
UniformQuantizedPerAxisType>();
addOperations<
// SCFDialect
//===----------------------------------------------------------------------===//
-SCFDialect::SCFDialect(MLIRContext *context)
- : Dialect(getDialectNamespace(), context) {
+void SCFDialect::initialize() {
addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/SCF/SCFOps.cpp.inc"
// SPIR-V Dialect
//===----------------------------------------------------------------------===//
-SPIRVDialect::SPIRVDialect(MLIRContext *context)
- : Dialect(getDialectNamespace(), context) {
+void SPIRVDialect::initialize() {
addTypes<ArrayType, CooperativeMatrixNVType, ImageType, MatrixType,
PointerType, RuntimeArrayType, StructType>();
return success();
}
-ShapeDialect::ShapeDialect(MLIRContext *context)
- : Dialect(getDialectNamespace(), context) {
+void ShapeDialect::initialize() {
addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc"
return success();
}
-StandardOpsDialect::StandardOpsDialect(MLIRContext *context)
- : Dialect(getDialectNamespace(), context) {
+void StandardOpsDialect::initialize() {
addOperations<DmaStartOp, DmaWaitOp,
#define GET_OP_LIST
#include "mlir/Dialect/StandardOps/IR/Ops.cpp.inc"
// VectorDialect
//===----------------------------------------------------------------------===//
-VectorDialect::VectorDialect(MLIRContext *context)
- : Dialect(getDialectNamespace(), context) {
+void VectorDialect::initialize() {
addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Vector/VectorOps.cpp.inc"
// Dialect
//===----------------------------------------------------------------------===//
-Dialect::Dialect(StringRef name, MLIRContext *context)
- : name(name), context(context) {
+Dialect::Dialect(StringRef name, MLIRContext *context, TypeID id)
+ : name(name), dialectID(id), context(context) {
assert(isValidNamespace(name) && "invalid dialect namespace");
- registerDialect(context);
}
Dialect::~Dialect() {}
/// A builtin dialect to define types/etc that are necessary for the validity of
/// the IR.
struct BuiltinDialect : public Dialect {
- BuiltinDialect(MLIRContext *context) : Dialect(/*name=*/"", context) {
+ BuiltinDialect(MLIRContext *context)
+ : Dialect(/*name=*/"", context, TypeID::get<BuiltinDialect>()) {
addAttributes<AffineMapAttr, ArrayAttr, DenseIntOrFPElementsAttr,
DenseStringElementsAttr, DictionaryAttr, FloatAttr,
SymbolRefAttr, IntegerAttr, IntegerSetAttr, OpaqueAttr,
// have been fully decoupled from the core.
addOperations<FuncOp, ModuleOp, ModuleTerminatorOp>();
}
+ static StringRef getDialectNamespace() { return ""; }
};
} // end anonymous namespace.
}
// Register dialects with this context.
- new BuiltinDialect(this);
+ getOrCreateDialect<BuiltinDialect>();
registerAllDialects(this);
// Initialize several common attributes and types to avoid the need to lock
: nullptr;
}
-/// Register this dialect object with the specified context. The context
-/// takes ownership of the heap allocated dialect.
-void Dialect::registerDialect(MLIRContext *context) {
- auto &impl = context->getImpl();
- std::unique_ptr<Dialect> dialect(this);
-
+/// Get a dialect for the provided namespace and TypeID: abort the program if a
+/// dialect exist for this namespace with different TypeID. Returns a pointer to
+/// the dialect owned by the context.
+Dialect *
+MLIRContext::getOrCreateDialect(StringRef dialectNamespace, TypeID dialectID,
+ function_ref<std::unique_ptr<Dialect>()> ctor) {
+ auto &impl = getImpl();
// Get the correct insertion position sorted by namespace.
- auto insertPt = llvm::lower_bound(
- impl.dialects, dialect, [](const auto &lhs, const auto &rhs) {
- return lhs->getNamespace() < rhs->getNamespace();
- });
+ auto insertPt =
+ llvm::lower_bound(impl.dialects, nullptr,
+ [&](const std::unique_ptr<Dialect> &lhs,
+ const std::unique_ptr<Dialect> &rhs) {
+ if (!lhs)
+ return dialectNamespace < rhs->getNamespace();
+ return lhs->getNamespace() < dialectNamespace;
+ });
// Abort if dialect with namespace has already been registered.
if (insertPt != impl.dialects.end() &&
- (*insertPt)->getNamespace() == getNamespace()) {
- llvm::report_fatal_error("a dialect with namespace '" + getNamespace() +
+ (*insertPt)->getNamespace() == dialectNamespace) {
+ if ((*insertPt)->getTypeID() == dialectID)
+ return insertPt->get();
+ llvm::report_fatal_error("a dialect with namespace '" + dialectNamespace +
"' has already been registered");
}
- impl.dialects.insert(insertPt, std::move(dialect));
+ auto it = impl.dialects.insert(insertPt, ctor());
+ return &**it;
}
bool MLIRContext::allowsUnregisteredDialects() {
// TestDialect
//===----------------------------------------------------------------------===//
-TestDialect::TestDialect(MLIRContext *context)
- : Dialect(getDialectNamespace(), context) {
+void TestDialect::initialize() {
addOperations<
#define GET_OP_LIST
#include "TestOps.cpp.inc"
/// {1}: The dialect namespace.
static const char *const dialectDeclBeginStr = R"(
class {0} : public ::mlir::Dialect {
+ explicit {0}(::mlir::MLIRContext *context)
+ : ::mlir::Dialect(getDialectNamespace(), context,
+ ::mlir::TypeID::get<{0}>()) {{
+ initialize();
+ }
+ void initialize();
+ friend class ::mlir::MLIRContext;
public:
- explicit {0}(::mlir::MLIRContext *context);
static ::llvm::StringRef getDialectNamespace() { return "{1}"; }
)";
namespace {
struct TestDialect : public Dialect {
- TestDialect(MLIRContext *context) : Dialect(/*name=*/"test", context) {}
+ static StringRef getDialectNamespace() { return "test"; };
+ TestDialect(MLIRContext *context)
+ : Dialect(getDialectNamespace(), context, TypeID::get<TestDialect>()) {}
+};
+struct AnotherTestDialect : public Dialect {
+ static StringRef getDialectNamespace() { return "test"; };
+ AnotherTestDialect(MLIRContext *context)
+ : Dialect(getDialectNamespace(), context,
+ TypeID::get<AnotherTestDialect>()) {}
};
TEST(DialectDeathTest, MultipleDialectsWithSameNamespace) {
// Registering a dialect with the same namespace twice should result in a
// failure.
- new TestDialect(&context);
- ASSERT_DEATH(new TestDialect(&context), "");
+ context.getOrCreateDialect<TestDialect>();
+ ASSERT_DEATH(context.getOrCreateDialect<AnotherTestDialect>(), "");
}
} // end namespace