From 3da51522fb4f72b7d4619f2dfd454bb3073ab460 Mon Sep 17 00:00:00 2001 From: Alex Zinenko Date: Wed, 10 Feb 2021 10:11:40 +0100 Subject: [PATCH] [mlir] enable delayed registration of dialect interfaces This introduces a mechanism to register interfaces for a dialect without making the dialect itself depend on the interface. The registration request happens on DialectRegistry and, if the dialect has not been loaded yet, the actual registration is delayed until the dialect is loaded. It requires DialectRegistry to become aware of the context that contains it and the context to expose methods for querying if a dialect is loaded. This mechanism will enable a simple extension mechanism for dialects that can have interfaces defined outside of the dialect code. It is particularly helpful for, e.g., translation to LLVM IR where we don't want the dialect itself to depend on LLVM IR libraries. Reviewed By: mehdi_amini Differential Revision: https://reviews.llvm.org/D96137 --- mlir/include/mlir/IR/Dialect.h | 53 +++++++++++++++++++++++++++++++++-- mlir/lib/IR/Dialect.cpp | 32 +++++++++++++++++++++ mlir/lib/IR/MLIRContext.cpp | 11 +++++--- mlir/lib/Support/MlirOptMain.cpp | 10 +++---- mlir/unittests/IR/DialectTest.cpp | 58 +++++++++++++++++++++++++++++++++++++++ 5 files changed, 151 insertions(+), 13 deletions(-) diff --git a/mlir/include/mlir/IR/Dialect.h b/mlir/include/mlir/IR/Dialect.h index cd64d38..978531f 100644 --- a/mlir/include/mlir/IR/Dialect.h +++ b/mlir/include/mlir/IR/Dialect.h @@ -26,6 +26,8 @@ class OpBuilder; class Type; using DialectAllocatorFunction = std::function; +using InterfaceAllocatorFunction = + std::function(Dialect *)>; /// Dialects are groups of MLIR operations, types and attributes, as well as /// behavior associated with the entire group. For example, hooks into other @@ -222,6 +224,7 @@ private: /// A collection of registered dialect interfaces. DenseMap> registeredInterfaces; + friend class DialectRegistry; friend void registerDialect(); friend class MLIRContext; }; @@ -234,8 +237,13 @@ private: class DialectRegistry { using MapTy = std::map>; + using InterfaceMapTy = + DenseMap>; public: + explicit DialectRegistry(MLIRContext *context = nullptr) + : owningContext(context) {} + template void insert() { insert(TypeID::get(), @@ -254,7 +262,9 @@ public: insert(); } - /// Add a new dialect constructor to the registry. + /// Add a new dialect constructor to the registry. The constructor must be + /// calling MLIRContext::getOrLoadDialect in order for the context to take + /// ownership of the dialect and for delayed interface registration to happen. void insert(TypeID typeID, StringRef name, DialectAllocatorFunction ctor); /// Load a dialect for this namespace in the provided context. @@ -267,6 +277,7 @@ public: destination.insert(nameAndRegistrationIt.second.first, nameAndRegistrationIt.first, nameAndRegistrationIt.second.second); + destination.interfaces.insert(interfaces.begin(), interfaces.end()); } // Load all dialects available in the registry in the provided context. void loadAll(MLIRContext *context) { @@ -274,11 +285,47 @@ public: nameAndRegistrationIt.second.second(context); } - MapTy::const_iterator begin() const { return registry.begin(); } - MapTy::const_iterator end() const { return registry.end(); } + /// Return the names of dialects known to this registry. + auto getDialectNames() { + return llvm::map_range( + registry, [](const MapTy::value_type &item) { return item.first; }); + } + + /// Add an interface constructed with the given allocation function to the + /// dialect provided as template parameter. The dialect must be present in + /// the registry, but may or may not be loaded. If it is not loaded, the + /// interface registration is delayed until the loading. + template + void addDialectInterface(InterfaceAllocatorFunction allocator) { + addDialectInterface(DialectTy::getDialectNamespace(), allocator); + } + + /// Add an interface to the dialect, both provided as template parameter. The + /// dialect must be present in the registry, but may or may not be loaded. If + /// it is not loaded, the interface registration is delayed until the loading. + template + void addDialectInterface() { + addDialectInterface([](Dialect *dialect) { + return std::make_unique(dialect); + }); + } + + /// Register any interfaces required for the given dialect (based on its + /// TypeID). Users are not expected to call this directly. + void registerDelayedInterfaces(Dialect *dialect); private: + /// Add an interface constructed with the given allocation function to the + /// dialect identified by its namespace. + void addDialectInterface(StringRef dialectName, + InterfaceAllocatorFunction allocator); + MapTy registry; + InterfaceMapTy interfaces; + + /// If this registry belongs to a context, this points back to the context. + /// Useful for checking if a dialect is loaded in the context. + MLIRContext *owningContext; }; } // namespace mlir diff --git a/mlir/lib/IR/Dialect.cpp b/mlir/lib/IR/Dialect.cpp index beabd48..01f8ec1 100644 --- a/mlir/lib/IR/Dialect.cpp +++ b/mlir/lib/IR/Dialect.cpp @@ -22,6 +22,29 @@ using namespace detail; DialectAsmParser::~DialectAsmParser() {} +//===----------------------------------------------------------------------===// +// DialectRegistry +//===----------------------------------------------------------------------===// + +void DialectRegistry::addDialectInterface( + StringRef dialectName, InterfaceAllocatorFunction allocator) { + assert(allocator && "unexpected null interface allocation function"); + + // If the dialect is already loaded, directly add the interface. + if (Dialect *dialect = owningContext + ? owningContext->getLoadedDialect(dialectName) + : nullptr) { + dialect->addInterface(allocator(dialect)); + return; + } + + // Otherwise, store it in the interface map for delayed registration. + auto it = registry.find(dialectName.str()); + assert(it != registry.end() && + "adding an interface for an unregistered dialect"); + interfaces[it->second.first].push_back(allocator); +} + Dialect *DialectRegistry::loadByName(StringRef name, MLIRContext *context) { auto it = registry.find(name.str()); if (it == registry.end()) @@ -40,6 +63,15 @@ void DialectRegistry::insert(TypeID typeID, StringRef name, } } +void DialectRegistry::registerDelayedInterfaces(Dialect *dialect) { + auto it = interfaces.find(dialect->getTypeID()); + if (it == interfaces.end()) + return; + + for (const InterfaceAllocatorFunction &createInterface : it->second) + dialect->addInterface(createInterface(dialect)); +} + //===----------------------------------------------------------------------===// // Dialect //===----------------------------------------------------------------------===// diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp index f56eb75..832eea7 100644 --- a/mlir/lib/IR/MLIRContext.cpp +++ b/mlir/lib/IR/MLIRContext.cpp @@ -326,7 +326,8 @@ public: DictionaryAttr emptyDictionaryAttr; public: - MLIRContextImpl() : identifiers(identifierAllocator) {} + MLIRContextImpl(MLIRContext *ctx) + : dialectsRegistry(ctx), identifiers(identifierAllocator) {} ~MLIRContextImpl() { for (auto typeMapping : registeredTypes) typeMapping.second->~AbstractType(); @@ -336,7 +337,7 @@ public: }; } // end namespace mlir -MLIRContext::MLIRContext() : impl(new MLIRContextImpl()) { +MLIRContext::MLIRContext() : impl(new MLIRContextImpl(this)) { // Initialize values based on the command line flags if they were provided. if (clOptions.isConstructed()) { disableMultithreading(clOptions->disableThreading); @@ -441,8 +442,8 @@ std::vector MLIRContext::getLoadedDialects() { } std::vector MLIRContext::getAvailableDialects() { std::vector result; - for (auto &dialect : impl->dialectsRegistry) - result.push_back(dialect.first); + for (auto dialect : impl->dialectsRegistry.getDialectNames()) + result.push_back(dialect); return result; } @@ -493,6 +494,8 @@ MLIRContext::getOrLoadDialect(StringRef dialectNamespace, TypeID dialectID, identifierEntry.first().startswith(dialectNamespace)) identifierEntry.second = dialect.get(); + // Actually register the interfaces with delayed registration. + impl.dialectsRegistry.registerDelayedInterfaces(dialect.get()); return dialect.get(); } diff --git a/mlir/lib/Support/MlirOptMain.cpp b/mlir/lib/Support/MlirOptMain.cpp index 85891fd..2796851 100644 --- a/mlir/lib/Support/MlirOptMain.cpp +++ b/mlir/lib/Support/MlirOptMain.cpp @@ -201,10 +201,8 @@ LogicalResult mlir::MlirOptMain(int argc, char **argv, llvm::StringRef toolName, { llvm::raw_string_ostream os(helpHeader); MLIRContext context; - interleaveComma(registry, os, [&](auto ®istryEntry) { - StringRef name = registryEntry.first; - os << name; - }); + interleaveComma(registry.getDialectNames(), os, + [&](auto name) { os << name; }); } // Parse pass names in main to ensure static initialization completed. cl::ParseCommandLineOptions(argc, argv, helpHeader); @@ -212,8 +210,8 @@ LogicalResult mlir::MlirOptMain(int argc, char **argv, llvm::StringRef toolName, if (showDialects) { llvm::outs() << "Available Dialects:\n"; interleave( - registry, llvm::outs(), - [](auto ®istryEntry) { llvm::outs() << registryEntry.first; }, "\n"); + registry.getDialectNames(), llvm::outs(), + [](auto name) { llvm::outs() << name; }, "\n"); return success(); } diff --git a/mlir/unittests/IR/DialectTest.cpp b/mlir/unittests/IR/DialectTest.cpp index 2410be0..ed19558 100644 --- a/mlir/unittests/IR/DialectTest.cpp +++ b/mlir/unittests/IR/DialectTest.cpp @@ -7,6 +7,7 @@ //===----------------------------------------------------------------------===// #include "mlir/IR/Dialect.h" +#include "mlir/IR/DialectInterface.h" #include "gtest/gtest.h" using namespace mlir; @@ -34,4 +35,61 @@ TEST(DialectDeathTest, MultipleDialectsWithSameNamespace) { ASSERT_DEATH(context.loadDialect(), ""); } +struct SecondTestDialect : public Dialect { + static StringRef getDialectNamespace() { return "test2"; } + SecondTestDialect(MLIRContext *context) + : Dialect(getDialectNamespace(), context, + TypeID::get()) {} +}; + +struct TestDialectInterfaceBase + : public DialectInterface::Base { + TestDialectInterfaceBase(Dialect *dialect) : Base(dialect) {} + virtual int function() const { return 42; } +}; + +struct TestDialectInterface : public TestDialectInterfaceBase { + using TestDialectInterfaceBase::TestDialectInterfaceBase; + int function() const final { return 56; } +}; + +struct SecondTestDialectInterface : public TestDialectInterfaceBase { + using TestDialectInterfaceBase::TestDialectInterfaceBase; + int function() const final { return 78; } +}; + +TEST(Dialect, DelayedInterfaceRegistration) { + DialectRegistry registry; + registry.insert(); + + // Delayed registration of an interface for TestDialect. + registry.addDialectInterface(); + + MLIRContext context; + registry.appendTo(context.getDialectRegistry()); + + // Load the TestDialect and check that the interface got registered for it. + auto *testDialect = context.getOrLoadDialect(); + ASSERT_TRUE(testDialect != nullptr); + auto *testDialectInterface = + testDialect->getRegisteredInterface(); + EXPECT_TRUE(testDialectInterface != nullptr); + + // Load the SecondTestDialect and check that the interface is not registered + // for it. + auto *secondTestDialect = context.getOrLoadDialect(); + ASSERT_TRUE(secondTestDialect != nullptr); + auto *secondTestDialectInterface = + secondTestDialect->getRegisteredInterface(); + EXPECT_TRUE(secondTestDialectInterface == nullptr); + + // Use the same mechanism as for delayed registration but for an already + // loaded dialect and check that the interface is now registered. + context.getDialectRegistry() + .addDialectInterface(); + secondTestDialectInterface = + secondTestDialect->getRegisteredInterface(); + EXPECT_TRUE(secondTestDialectInterface != nullptr); +} + } // end namespace -- 2.7.4