From 98dceed64bd061ef42272fb84eea8fd2b84083ac Mon Sep 17 00:00:00 2001 From: Mathieu Fehr Date: Tue, 20 Apr 2021 10:49:30 -0700 Subject: [PATCH] [mlir] Make some functions public to use custom TypeIDs Currently, it is only possible to register an operation or a type when the TypeID is defined at compile time. Same with InterfaceMaps which can only be defined with compile-time defined interfaces. With those changes, it is now possible to register types/operations with custom TypeIDs. This is necessary to define new operations/types at runtime. Differential Revision: https://reviews.llvm.org/D99084 --- mlir/include/mlir/IR/Dialect.h | 6 +++++- mlir/include/mlir/IR/OperationSupport.h | 5 ++++- mlir/include/mlir/IR/TypeSupport.h | 9 +++++++++ mlir/include/mlir/Support/InterfaceSupport.h | 16 ++++++++++------ 4 files changed, 28 insertions(+), 8 deletions(-) diff --git a/mlir/include/mlir/IR/Dialect.h b/mlir/include/mlir/IR/Dialect.h index d3af955..26ae5f9 100644 --- a/mlir/include/mlir/IR/Dialect.h +++ b/mlir/include/mlir/IR/Dialect.h @@ -193,6 +193,11 @@ protected: (void)std::initializer_list{0, (addType(), 0)...}; } + /// Register a type instance with this dialect. + /// The use of this method is in general discouraged in favor of + /// 'addTypes()'. + void addType(TypeID typeID, AbstractType &&typeInfo); + /// Register a set of attribute classes with this dialect. template void addAttributes() { (void)std::initializer_list{0, (addAttribute(), 0)...}; @@ -231,7 +236,6 @@ private: addType(T::getTypeID(), AbstractType::get(*this)); detail::TypeUniquer::registerType(context); } - void addType(TypeID typeID, AbstractType &&typeInfo); /// The namespace of this dialect. StringRef name; diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h index cb82ec9..7cfb979e 100644 --- a/mlir/include/mlir/IR/OperationSupport.h +++ b/mlir/include/mlir/IR/OperationSupport.h @@ -172,7 +172,9 @@ public: T::getHasTraitFn()); } -private: + /// Register a new operation in a Dialect object. + /// The use of this method is in general discouraged in favor of + /// 'insert(dialect)'. static void insert(StringRef name, Dialect &dialect, TypeID typeID, ParseAssemblyFn parseAssembly, PrintAssemblyFn printAssembly, @@ -180,6 +182,7 @@ private: GetCanonicalizationPatternsFn getCanonicalizationPatterns, detail::InterfaceMap &&interfaceMap, HasTraitFn hasTrait); +private: AbstractOperation(StringRef name, Dialect &dialect, TypeID typeID, ParseAssemblyFn parseAssembly, PrintAssemblyFn printAssembly, diff --git a/mlir/include/mlir/IR/TypeSupport.h b/mlir/include/mlir/IR/TypeSupport.h index c1de589..898c26d 100644 --- a/mlir/include/mlir/IR/TypeSupport.h +++ b/mlir/include/mlir/IR/TypeSupport.h @@ -39,6 +39,15 @@ public: return AbstractType(dialect, T::getInterfaceMap(), T::getTypeID()); } + /// This method is used by Dialect objects to register types with + /// custom TypeIDs. + /// The use of this method is in general discouraged in favor of + /// 'get(dialect)'; + static AbstractType get(Dialect &dialect, detail::InterfaceMap &&interfaceMap, + TypeID typeID) { + return AbstractType(dialect, std::move(interfaceMap), typeID); + } + /// Return the dialect this type was registered to. Dialect &getDialect() const { return const_cast(dialect); } diff --git a/mlir/include/mlir/Support/InterfaceSupport.h b/mlir/include/mlir/Support/InterfaceSupport.h index 6fc6117..6af36aa 100644 --- a/mlir/include/mlir/Support/InterfaceSupport.h +++ b/mlir/include/mlir/Support/InterfaceSupport.h @@ -188,6 +188,16 @@ public: /// Returns true if the interface map contains an interface for the given id. bool contains(TypeID interfaceID) const { return lookup(interfaceID); } + /// Create an InterfaceMap given with the implementation of the interfaces. + /// The use of this constructor is in general discouraged in favor of + /// 'InterfaceMap::get()'. + InterfaceMap(MutableArrayRef> elements) + : interfaces(elements.begin(), elements.end()) { + llvm::sort(interfaces, [](const auto &lhs, const auto &rhs) { + return compare(lhs.first, rhs.first); + }); + } + private: /// Compare two TypeID instances by comparing the underlying pointer. static bool compare(TypeID lhs, TypeID rhs) { @@ -195,12 +205,6 @@ private: } InterfaceMap() = default; - InterfaceMap(MutableArrayRef> elements) - : interfaces(elements.begin(), elements.end()) { - llvm::sort(interfaces, [](const auto &lhs, const auto &rhs) { - return compare(lhs.first, rhs.first); - }); - } template static InterfaceMap getImpl(std::tuple *) { -- 2.7.4