Adds ConstantFoldHook registry in MLIRContext
authorFeng Liu <fengliuai@google.com>
Tue, 20 Nov 2018 22:47:10 +0000 (14:47 -0800)
committerjpienaar <jpienaar@google.com>
Fri, 29 Mar 2019 21:04:34 +0000 (14:04 -0700)
This reverts the previous method which needs to create a new dialect with the
constant fold hook from TensorFlow. This new method uses a function object in
dialect to store the constant fold hook. Once a hook is registered to the
dialect, this function object will be assigned when the dialect is added to the
MLIRContext.

For the operations which are not registered, a new method getRegisteredDialects
is added to the MLIRContext to query the dialects which matches their op name
prefixes.

PiperOrigin-RevId: 222310149

mlir/include/mlir/IR/Dialect.h
mlir/include/mlir/IR/MLIRContext.h
mlir/lib/IR/Dialect.cpp
mlir/lib/IR/MLIRContext.cpp
mlir/lib/IR/Operation.cpp

index 76783b82d3930a7a695ef00b551f7d9f89230e77..72ea426070bd7b3fa926fd2b8b4a0a4236d2742f 100644 (file)
@@ -26,6 +26,9 @@
 
 namespace mlir {
 
+using DialectConstantFoldHook = std::function<bool(
+    const Operation *, ArrayRef<Attribute>, SmallVectorImpl<Attribute> &)>;
+
 /// Dialects are groups of MLIR operations and behavior associated with the
 /// entire group.  For example, hooks into other systems for constant folding,
 /// default named types for asm printing, etc.
@@ -39,19 +42,16 @@ public:
 
   StringRef getOperationPrefix() const { return opPrefix; }
 
-  /// Dialect implementations can implement this hook. It should attempt to
-  /// constant fold this operation with the specified constant operand values -
-  /// the elements in "operands" will correspond directly to the operands of the
-  /// operation, but may be null if non-constant.  If constant folding is
-  /// successful, this returns false and fills in the `results` vector.  If not,
-  /// this returns true and `results` is unspecified.
-  ///
-  /// If not overridden, this fallback implementation always fails to fold.
-  ///
-  virtual bool constantFold(const Operation *op, ArrayRef<Attribute> operands,
-                            SmallVectorImpl<Attribute> &results) const {
-    return true;
-  }
+  /// Registered fallback constant fold hook for the dialect. Like the constant
+  /// fold hook of each operation, it attempts to constant fold the operation
+  /// with the specified constant operand values - the elements in "operands"
+  /// will correspond directly to the operands of the operation, but may be null
+  /// if non-constant.  If constant folding is successful, this returns false
+  /// and fills in the `results` vector.  If not, this returns true and
+  /// `results` is unspecified.
+  DialectConstantFoldHook constantFoldHook =
+      [](const Operation *op, ArrayRef<Attribute> operands,
+         SmallVectorImpl<Attribute> &results) { return true; };
 
   // TODO: Hook to return the list of named types that are known.
 
@@ -108,11 +108,26 @@ private:
 };
 
 using DialectAllocatorFunction = std::function<void(MLIRContext *)>;
+using ConstantFoldHookAllocator = std::function<void(MLIRContext *)>;
 
-/// Register a specific dialect creation function with the system, typically
+/// Registers a specific dialect creation function with the system, typically
 /// used through the DialectRegistration template.
 void registerDialectAllocator(const DialectAllocatorFunction &function);
 
+/// Registers a constant fold hook for one or multiple dialects. The
+/// ConstantFoldHookAllocator defines how the hook gets mapped to the targeted
+/// dialect(s) in the context.
+/// Exmaple:
+///      registerConstantFoldHook([&](MLIRContext *ctx) {
+///        auto dialects = ctx->getRegisteredDialects();
+///        // then iterate and select the target dialect from dialects, or
+///        // get one dialect directly by the prefix:
+///        auto dialect = ctx->getRegisteredDialect("TARGET_PREFIX")
+///
+///        dialect->constantFoldHook = MyConstantFoldHook;
+///      });
+void registerConstantFoldHook(const ConstantFoldHookAllocator &function);
+
 /// Registers all dialects with the specified MLIRContext.
 void registerAllDialects(MLIRContext *context);
 
index 6fd703dc2873d25ca38c6e139d19b30955079305..5532c4db89b0671caeedf78bd72cf17f8f078005 100644 (file)
@@ -45,6 +45,10 @@ public:
   /// Return information about all registered IR dialects.
   std::vector<Dialect *> getRegisteredDialects() const;
 
+  /// Get registered IR dialect which has the longest matching with the given
+  /// prefix. If none is found, returns nullptr.
+  Dialect *getRegisteredDialect(StringRef prefix) const;
+
   /// Return information about all registered operations.  This isn't very
   /// efficient: typically you should ask the operations about their properties
   /// directly.
@@ -97,4 +101,4 @@ private:
 };
 } // end namespace mlir
 
-#endif  // MLIR_IR_MLIRCONTEXT_H
+#endif // MLIR_IR_MLIRCONTEXT_H
index 79b0a9ce0762e8032a9a4dffa8abab89e6ac6575..7640ce93c94491b009a031ac187c18ad5d46789f 100644 (file)
@@ -24,17 +24,33 @@ using namespace mlir;
 static llvm::ManagedStatic<SmallVector<DialectAllocatorFunction, 8>>
     dialectRegistry;
 
-/// Register a specific dialect creation function with the system, typically
+// Registry for dialect's constant fold hooks.
+static llvm::ManagedStatic<SmallVector<ConstantFoldHookAllocator, 8>>
+    constantFoldHookRegistry;
+
+/// Registers a specific dialect creation function with the system, typically
 /// used through the DialectRegistration template.
 void mlir::registerDialectAllocator(const DialectAllocatorFunction &function) {
-  assert(function && "Attempting to register an empty op initialize function");
+  assert(function &&
+         "Attempting to register an empty dialect initialize function");
   dialectRegistry->push_back(function);
 }
 
-/// Registers all dialects with the specified MLIRContext.
+/// Registers a constant fold hook for a specific dialect with the system.
+void mlir::registerConstantFoldHook(const ConstantFoldHookAllocator &function) {
+  assert(
+      function &&
+      "Attempting to register an empty constant fold hook initialize function");
+  constantFoldHookRegistry->push_back(function);
+}
+
+/// Registers all dialects and their const folding hooks with the specified
+/// MLIRContext.
 void mlir::registerAllDialects(MLIRContext *context) {
   for (const auto &fn : *dialectRegistry)
     fn(context);
+  for (const auto &fn : *constantFoldHookRegistry)
+    fn(context);
 }
 
 Dialect::Dialect(StringRef opPrefix, MLIRContext *context)
index 9c7824e9088aadd911b9d050fb662cc7e50270b9..423d2cdda7102feccee3db2096bb38183d6f0ed8 100644 (file)
@@ -516,6 +516,19 @@ std::vector<Dialect *> MLIRContext::getRegisteredDialects() const {
   return result;
 }
 
+/// Get registered IR dialect which has the longest matching with the given
+/// prefix. If none is found, returns nullptr.
+Dialect *MLIRContext::getRegisteredDialect(StringRef prefix) const {
+  Dialect *result = nullptr;
+  for (auto &dialect : getImpl().dialects) {
+    if (prefix.startswith(dialect->getOperationPrefix()))
+      if (!result || result->getOperationPrefix().size() <
+                         dialect->getOperationPrefix().size())
+        result = dialect.get();
+  }
+  return result;
+}
+
 /// Register this dialect object with the specified context.  The context
 /// takes ownership of the heap allocated dialect.
 void Dialect::registerDialect(MLIRContext *context) {
index 728d2362c99d17b010a24205f1958f49d9dce2c0..aa878b1f0ae8baac38232133c913fb5bebb1a9c8 100644 (file)
@@ -304,9 +304,16 @@ bool Operation::constantFold(ArrayRef<Attribute> operands,
       return false;
 
     // Otherwise, fall back on the dialect hook to handle it.
-    Dialect &dialect = abstractOp->dialect;
-    return dialect.constantFold(this, operands, results);
+    return abstractOp->dialect.constantFoldHook(this, operands, results);
   }
+
+  // If this operation hasn't been registered or doesn't have abstract
+  // operation, fall back to a dialect which matches the prefix.
+  auto opName = getName().getStringRef();
+  if (auto *dialect = getContext()->getRegisteredDialect(opName)) {
+    return dialect->constantFoldHook(this, operands, results);
+  }
+
   return true;
 }