Use dialect hook registration for constant folding hook.
authorTatiana Shpeisman <shpeisman@google.com>
Mon, 25 Feb 2019 16:37:28 +0000 (08:37 -0800)
committerjpienaar <jpienaar@google.com>
Fri, 29 Mar 2019 23:40:35 +0000 (16:40 -0700)
Deletes specialized mechanism for registering constant folding hook and uses dialect hooks registration mechanism instead.

PiperOrigin-RevId: 235535410

mlir/include/mlir/IR/Dialect.h
mlir/include/mlir/IR/DialectHooks.h
mlir/lib/IR/Dialect.cpp

index 9f7732e3776675a59956994f0af0348db9e75790..55b6f7efd3650e12eda2cd9aaa2d57d86047a97c 100644 (file)
@@ -178,26 +178,11 @@ private:
 };
 
 using DialectAllocatorFunction = std::function<void(MLIRContext *)>;
-using ConstantFoldHookAllocator = std::function<void(MLIRContext *)>;
 
 /// Registers a specific dialect creation function with the system, typically
 /// used through the DialectRegistration template.
 void registerDialectAllocator(const DialectAllocatorFunction &function);
 
-/// Registers a constant fold hook for one or multiple dialects. The
-/// ConstantFoldHookAllocator defines how the hook gets mapped to the targeted
-/// dialect(s) in the context.
-/// Exmaple:
-///      registerConstantFoldHook([&](MLIRContext *ctx) {
-///        auto dialects = ctx->getRegisteredDialects();
-///        // then iterate and select the target dialect from dialects, or
-///        // get one dialect directly by the prefix:
-///        auto dialect = ctx->getRegisteredDialect("TARGET_PREFIX")
-///
-///        dialect->constantFoldHook = MyConstantFoldHook;
-///      });
-void registerConstantFoldHook(const ConstantFoldHookAllocator &function);
-
 /// Registers all dialects with the specified MLIRContext.
 void registerAllDialects(MLIRContext *context);
 
index dbfb1ab33c700a31dc5c641d2f06530e703a2a8c..f368988b5b40e30d1591d10317e924db5de54215 100644 (file)
@@ -38,6 +38,8 @@ using DialectHooksSetter = std::function<void(MLIRContext *)>;
 /// The subclass should override DialectHook methods for supported hooks.
 class DialectHooks {
 public:
+  // Returns hook to constant fold an operation.
+  DialectConstantFoldHook getConstantFoldHook() { return nullptr; }
   // Returns hook to decode opaque constant tensor.
   DialectConstantDecodeHook getDecodeHook() { return nullptr; }
   // Returns hook to extract an element of an opaque constant tensor.
@@ -65,6 +67,8 @@ template <typename ConcreteHooks> struct DialectHooksRegistration {
       }
       // Set hooks.
       ConcreteHooks hooks;
+      if (auto h = hooks.getConstantFoldHook())
+        dialect->constantFoldHook = h;
       if (auto h = hooks.getDecodeHook())
         dialect->decodeHook = h;
       if (auto h = hooks.getExtractElementHook())
index 249c9d84c1f7cd054ae1e7a0002abad1a09e59e1..338c918c33965c7134a9da4325b9428405dc1f51 100644 (file)
@@ -25,10 +25,6 @@ using namespace mlir;
 static llvm::ManagedStatic<SmallVector<DialectAllocatorFunction, 8>>
     dialectRegistry;
 
-// Registry for dialect's constant fold hooks.
-static llvm::ManagedStatic<SmallVector<ConstantFoldHookAllocator, 8>>
-    constantFoldHookRegistry;
-
 // Registry for functions that set dialect hooks.
 static llvm::ManagedStatic<SmallVector<DialectHooksSetter, 8>>
     dialectHooksRegistry;
@@ -41,14 +37,6 @@ void mlir::registerDialectAllocator(const DialectAllocatorFunction &function) {
   dialectRegistry->push_back(function);
 }
 
-/// Registers a constant fold hook for a specific dialect with the system.
-void mlir::registerConstantFoldHook(const ConstantFoldHookAllocator &function) {
-  assert(
-      function &&
-      "Attempting to register an empty constant fold hook initialize function");
-  constantFoldHookRegistry->push_back(function);
-}
-
 /// Registers a function to set specific hooks for a specific dialect, typically
 /// used through the DialectHooksRegistreation template.
 void mlir::registerDialectHooksSetter(const DialectHooksSetter &function) {
@@ -64,8 +52,6 @@ void mlir::registerDialectHooksSetter(const DialectHooksSetter &function) {
 void mlir::registerAllDialects(MLIRContext *context) {
   for (const auto &fn : *dialectRegistry)
     fn(context);
-  for (const auto &fn : *constantFoldHookRegistry)
-    fn(context);
   for (const auto &fn : *dialectHooksRegistry) {
     fn(context);
   }