[mlir] Use unique_function in AbstractOperation fields
authorMathieu Fehr <mathieu.fehr@gmail.com>
Tue, 25 May 2021 18:36:04 +0000 (11:36 -0700)
committerRiver Riddle <riddleriver@gmail.com>
Tue, 25 May 2021 18:36:12 +0000 (11:36 -0700)
Currently, AbstractOperation fields are function pointers.
Modifying them to unique_function allow them to contain
runtime information.

For instance, this allows operations to be defined at runtime.

Differential Revision: https://reviews.llvm.org/D103031

mlir/include/mlir/IR/OpDefinition.h
mlir/include/mlir/IR/OperationSupport.h
mlir/lib/IR/MLIRContext.cpp
mlir/lib/Parser/Parser.cpp

index 735d597..b2e5032 100644 (file)
@@ -1671,7 +1671,10 @@ private:
                               detect_has_single_result_fold<ConcreteOpT>::value,
                           AbstractOperation::FoldHookFn>
   getFoldHookFnImpl() {
-    return &foldSingleResultHook<ConcreteOpT>;
+    return [](Operation *op, ArrayRef<Attribute> operands,
+              SmallVectorImpl<OpFoldResult> &results) {
+      return foldSingleResultHook<ConcreteOpT>(op, operands, results);
+    };
   }
   /// The internal implementation of `getFoldHookFn` above that is invoked if
   /// the operation is not single result and defines a `fold` method.
@@ -1681,7 +1684,10 @@ private:
                               detect_has_fold<ConcreteOpT>::value,
                           AbstractOperation::FoldHookFn>
   getFoldHookFnImpl() {
-    return &foldHook<ConcreteOpT>;
+    return [](Operation *op, ArrayRef<Attribute> operands,
+              SmallVectorImpl<OpFoldResult> &results) {
+      return foldHook<ConcreteOpT>(op, operands, results);
+    };
   }
   /// The internal implementation of `getFoldHookFn` above that is invoked if
   /// the operation does not define a `fold` method.
@@ -1690,8 +1696,12 @@ private:
                               !detect_has_fold<ConcreteOpT>::value,
                           AbstractOperation::FoldHookFn>
   getFoldHookFnImpl() {
-    // In this case, we only need to fold the traits of the operation.
-    return &op_definition_impl::foldTraits<FoldableTraitsTupleT>;
+    return [](Operation *op, ArrayRef<Attribute> operands,
+              SmallVectorImpl<OpFoldResult> &results) {
+      // In this case, we only need to fold the traits of the operation.
+      return op_definition_impl::foldTraits<FoldableTraitsTupleT>(op, operands,
+                                                                  results);
+    };
   }
   /// Return the result of folding a single result operation that defines a
   /// `fold` method.
@@ -1735,7 +1745,8 @@ private:
   }
   /// Implementation of `GetHasTraitFn`
   static AbstractOperation::HasTraitFn getHasTraitFn() {
-    return &op_definition_impl::hasTrait<Traits...>;
+    return
+        [](TypeID id) { return op_definition_impl::hasTrait<Traits...>(id); };
   }
   /// Implementation of `ParseAssemblyFn` AbstractOperation hook.
   static AbstractOperation::ParseAssemblyFn getParseAssemblyFn() {
@@ -1751,7 +1762,9 @@ private:
   static std::enable_if_t<!detect_has_print<ConcreteOpT>::value,
                           AbstractOperation::PrintAssemblyFn>
   getPrintAssemblyFnImpl() {
-    return &OpState::print;
+    return [](Operation *op, OpAsmPrinter &parser) {
+      return OpState::print(op, parser);
+    };
   }
   /// The internal implementation of `getPrintAssemblyFn` that is invoked when
   /// the concrete operation defines a `print` method.
index e7bcacf..20d73cc 100644 (file)
@@ -67,14 +67,17 @@ using OwningRewritePatternList = RewritePatternSet;
 /// the concrete operation types.
 class AbstractOperation {
 public:
-  using GetCanonicalizationPatternsFn = void (*)(RewritePatternSet &,
-                                                 MLIRContext *);
-  using FoldHookFn = LogicalResult (*)(Operation *, ArrayRef<Attribute>,
-                                       SmallVectorImpl<OpFoldResult> &);
-  using HasTraitFn = bool (*)(TypeID);
-  using ParseAssemblyFn = ParseResult (*)(OpAsmParser &, OperationState &);
-  using PrintAssemblyFn = void (*)(Operation *, OpAsmPrinter &);
-  using VerifyInvariantsFn = LogicalResult (*)(Operation *);
+  using GetCanonicalizationPatternsFn =
+      llvm::unique_function<void(RewritePatternSet &, MLIRContext *) const>;
+  using FoldHookFn = llvm::unique_function<LogicalResult(
+      Operation *, ArrayRef<Attribute>, SmallVectorImpl<OpFoldResult> &) const>;
+  using HasTraitFn = llvm::unique_function<bool(TypeID) const>;
+  using ParseAssemblyFn =
+      llvm::unique_function<ParseResult(OpAsmParser &, OperationState &) const>;
+  using PrintAssemblyFn =
+      llvm::unique_function<void(Operation *, OpAsmPrinter &) const>;
+  using VerifyInvariantsFn =
+      llvm::unique_function<LogicalResult(Operation *) const>;
 
   /// This is the name of the operation.
   const Identifier name;
@@ -89,7 +92,7 @@ public:
   ParseResult parseAssembly(OpAsmParser &parser, OperationState &result) const;
 
   /// Return the static hook for parsing this operation assembly.
-  ParseAssemblyFn getParseAssemblyFn() const { return parseAssemblyFn; }
+  const ParseAssemblyFn &getParseAssemblyFn() const { return parseAssemblyFn; }
 
   /// This hook implements the AsmPrinter for this operation.
   void printAssembly(Operation *op, OpAsmPrinter &p) const {
@@ -175,20 +178,21 @@ public:
   /// Register a new operation in a Dialect object.
   /// The use of this method is in general discouraged in favor of
   /// 'insert<CustomOp>(dialect)'.
-  static void insert(StringRef name, Dialect &dialect, TypeID typeID,
-                     ParseAssemblyFn parseAssembly,
-                     PrintAssemblyFn printAssembly,
-                     VerifyInvariantsFn verifyInvariants, FoldHookFn foldHook,
-                     GetCanonicalizationPatternsFn getCanonicalizationPatterns,
-                     detail::InterfaceMap &&interfaceMap, HasTraitFn hasTrait);
+  static void
+  insert(StringRef name, Dialect &dialect, TypeID typeID,
+         ParseAssemblyFn &&parseAssembly, PrintAssemblyFn &&printAssembly,
+         VerifyInvariantsFn &&verifyInvariants, FoldHookFn &&foldHook,
+         GetCanonicalizationPatternsFn &&getCanonicalizationPatterns,
+         detail::InterfaceMap &&interfaceMap, HasTraitFn &&hasTrait);
 
 private:
   AbstractOperation(StringRef name, Dialect &dialect, TypeID typeID,
-                    ParseAssemblyFn parseAssembly,
-                    PrintAssemblyFn printAssembly,
-                    VerifyInvariantsFn verifyInvariants, FoldHookFn foldHook,
-                    GetCanonicalizationPatternsFn getCanonicalizationPatterns,
-                    detail::InterfaceMap &&interfaceMap, HasTraitFn hasTrait);
+                    ParseAssemblyFn &&parseAssembly,
+                    PrintAssemblyFn &&printAssembly,
+                    VerifyInvariantsFn &&verifyInvariants,
+                    FoldHookFn &&foldHook,
+                    GetCanonicalizationPatternsFn &&getCanonicalizationPatterns,
+                    detail::InterfaceMap &&interfaceMap, HasTraitFn &&hasTrait);
 
   /// A map of interfaces that were registered to this operation.
   detail::InterfaceMap interfaceMap;
index f1825a4..f438c00 100644 (file)
@@ -696,13 +696,15 @@ const AbstractOperation *AbstractOperation::lookup(StringRef opName,
 
 void AbstractOperation::insert(
     StringRef name, Dialect &dialect, TypeID typeID,
-    ParseAssemblyFn parseAssembly, PrintAssemblyFn printAssembly,
-    VerifyInvariantsFn verifyInvariants, FoldHookFn foldHook,
-    GetCanonicalizationPatternsFn getCanonicalizationPatterns,
-    detail::InterfaceMap &&interfaceMap, HasTraitFn hasTrait) {
-  AbstractOperation opInfo(
-      name, dialect, typeID, parseAssembly, printAssembly, verifyInvariants,
-      foldHook, getCanonicalizationPatterns, std::move(interfaceMap), hasTrait);
+    ParseAssemblyFn &&parseAssembly, PrintAssemblyFn &&printAssembly,
+    VerifyInvariantsFn &&verifyInvariants, FoldHookFn &&foldHook,
+    GetCanonicalizationPatternsFn &&getCanonicalizationPatterns,
+    detail::InterfaceMap &&interfaceMap, HasTraitFn &&hasTrait) {
+  AbstractOperation opInfo(name, dialect, typeID, std::move(parseAssembly),
+                           std::move(printAssembly),
+                           std::move(verifyInvariants), std::move(foldHook),
+                           std::move(getCanonicalizationPatterns),
+                           std::move(interfaceMap), std::move(hasTrait));
 
   auto &impl = dialect.getContext()->getImpl();
   assert(impl.multiThreadedExecutionContext == 0 &&
@@ -717,16 +719,18 @@ void AbstractOperation::insert(
 
 AbstractOperation::AbstractOperation(
     StringRef name, Dialect &dialect, TypeID typeID,
-    ParseAssemblyFn parseAssembly, PrintAssemblyFn printAssembly,
-    VerifyInvariantsFn verifyInvariants, FoldHookFn foldHook,
-    GetCanonicalizationPatternsFn getCanonicalizationPatterns,
-    detail::InterfaceMap &&interfaceMap, HasTraitFn hasTrait)
+    ParseAssemblyFn &&parseAssembly, PrintAssemblyFn &&printAssembly,
+    VerifyInvariantsFn &&verifyInvariants, FoldHookFn &&foldHook,
+    GetCanonicalizationPatternsFn &&getCanonicalizationPatterns,
+    detail::InterfaceMap &&interfaceMap, HasTraitFn &&hasTrait)
     : name(Identifier::get(name, dialect.getContext())), dialect(dialect),
       typeID(typeID), interfaceMap(std::move(interfaceMap)),
-      foldHookFn(foldHook),
-      getCanonicalizationPatternsFn(getCanonicalizationPatterns),
-      hasTraitFn(hasTrait), parseAssemblyFn(parseAssembly),
-      printAssemblyFn(printAssembly), verifyInvariantsFn(verifyInvariants) {}
+      foldHookFn(std::move(foldHook)),
+      getCanonicalizationPatternsFn(std::move(getCanonicalizationPatterns)),
+      hasTraitFn(std::move(hasTrait)),
+      parseAssemblyFn(std::move(parseAssembly)),
+      printAssemblyFn(std::move(printAssembly)),
+      verifyInvariantsFn(std::move(verifyInvariants)) {}
 
 //===----------------------------------------------------------------------===//
 // AbstractType
index f88f94c..aa0d12f 100644 (file)
@@ -1830,7 +1830,7 @@ OperationParser::parseCustomOperation(ArrayRef<ResultRecord> resultIDs) {
   // This is the actual hook for the custom op parsing, usually implemented by
   // the op itself (`Op::parse()`). We retrieve it either from the
   // AbstractOperation or from the Dialect.
-  std::function<ParseResult(OpAsmParser &, OperationState &)> parseAssemblyFn;
+  function_ref<ParseResult(OpAsmParser &, OperationState &)> parseAssemblyFn;
   bool isIsolatedFromAbove = false;
 
   if (opDefinition) {