[mlir] Make some functions public to use custom TypeIDs
authorMathieu Fehr <mathieu.fehr@gmail.com>
Tue, 20 Apr 2021 17:49:30 +0000 (10:49 -0700)
committerRiver Riddle <riddleriver@gmail.com>
Tue, 20 Apr 2021 17:56:00 +0000 (10:56 -0700)
Currently, it is only possible to register an operation or a type
when the TypeID is defined at compile time. Same with InterfaceMaps
which can only be defined with compile-time defined interfaces.

With those changes, it is now possible to register types/operations
with custom TypeIDs. This is necessary to define new operations/types
at runtime.

Differential Revision: https://reviews.llvm.org/D99084

mlir/include/mlir/IR/Dialect.h
mlir/include/mlir/IR/OperationSupport.h
mlir/include/mlir/IR/TypeSupport.h
mlir/include/mlir/Support/InterfaceSupport.h

index d3af955..26ae5f9 100644 (file)
@@ -193,6 +193,11 @@ protected:
     (void)std::initializer_list<int>{0, (addType<Args>(), 0)...};
   }
 
+  /// Register a type instance with this dialect.
+  /// The use of this method is in general discouraged in favor of
+  /// 'addTypes<CustomType>()'.
+  void addType(TypeID typeID, AbstractType &&typeInfo);
+
   /// Register a set of attribute classes with this dialect.
   template <typename... Args> void addAttributes() {
     (void)std::initializer_list<int>{0, (addAttribute<Args>(), 0)...};
@@ -231,7 +236,6 @@ private:
     addType(T::getTypeID(), AbstractType::get<T>(*this));
     detail::TypeUniquer::registerType<T>(context);
   }
-  void addType(TypeID typeID, AbstractType &&typeInfo);
 
   /// The namespace of this dialect.
   StringRef name;
index cb82ec9..7cfb979 100644 (file)
@@ -172,7 +172,9 @@ public:
            T::getHasTraitFn());
   }
 
-private:
+  /// Register a new operation in a Dialect object.
+  /// The use of this method is in general discouraged in favor of
+  /// 'insert<CustomOp>(dialect)'.
   static void insert(StringRef name, Dialect &dialect, TypeID typeID,
                      ParseAssemblyFn parseAssembly,
                      PrintAssemblyFn printAssembly,
@@ -180,6 +182,7 @@ private:
                      GetCanonicalizationPatternsFn getCanonicalizationPatterns,
                      detail::InterfaceMap &&interfaceMap, HasTraitFn hasTrait);
 
+private:
   AbstractOperation(StringRef name, Dialect &dialect, TypeID typeID,
                     ParseAssemblyFn parseAssembly,
                     PrintAssemblyFn printAssembly,
index c1de589..898c26d 100644 (file)
@@ -39,6 +39,15 @@ public:
     return AbstractType(dialect, T::getInterfaceMap(), T::getTypeID());
   }
 
+  /// This method is used by Dialect objects to register types with
+  /// custom TypeIDs.
+  /// The use of this method is in general discouraged in favor of
+  /// 'get<CustomType>(dialect)';
+  static AbstractType get(Dialect &dialect, detail::InterfaceMap &&interfaceMap,
+                          TypeID typeID) {
+    return AbstractType(dialect, std::move(interfaceMap), typeID);
+  }
+
   /// Return the dialect this type was registered to.
   Dialect &getDialect() const { return const_cast<Dialect &>(dialect); }
 
index 6fc6117..6af36aa 100644 (file)
@@ -188,6 +188,16 @@ public:
   /// Returns true if the interface map contains an interface for the given id.
   bool contains(TypeID interfaceID) const { return lookup(interfaceID); }
 
+  /// Create an InterfaceMap given with the implementation of the interfaces.
+  /// The use of this constructor is in general discouraged in favor of
+  /// 'InterfaceMap::get<InterfaceA, ...>()'.
+  InterfaceMap(MutableArrayRef<std::pair<TypeID, void *>> elements)
+      : interfaces(elements.begin(), elements.end()) {
+    llvm::sort(interfaces, [](const auto &lhs, const auto &rhs) {
+      return compare(lhs.first, rhs.first);
+    });
+  }
+
 private:
   /// Compare two TypeID instances by comparing the underlying pointer.
   static bool compare(TypeID lhs, TypeID rhs) {
@@ -195,12 +205,6 @@ private:
   }
 
   InterfaceMap() = default;
-  InterfaceMap(MutableArrayRef<std::pair<TypeID, void *>> elements)
-      : interfaces(elements.begin(), elements.end()) {
-    llvm::sort(interfaces, [](const auto &lhs, const auto &rhs) {
-      return compare(lhs.first, rhs.first);
-    });
-  }
 
   template <typename... Ts>
   static InterfaceMap getImpl(std::tuple<Ts...> *) {