[Relay] Start porting pass to the pass manager (#3191)
author雾雨魔理沙 <lolisa@marisa.moe>
Fri, 24 May 2019 23:43:03 +0000 (16:43 -0700)
committerTianqi Chen <tqchen@users.noreply.github.com>
Fri, 24 May 2019 23:43:03 +0000 (16:43 -0700)
12 files changed:
include/tvm/relay/pass.h
include/tvm/relay/transform.h
src/relay/pass/dead_code.cc
src/relay/pass/device_annotation.cc
src/relay/pass/fold_constant.cc
src/relay/pass/forward_rewrite.cc
src/relay/pass/fuse_ops.cc
src/relay/pass/partial_eval.cc
src/relay/pass/pass_manager.cc
src/relay/pass/to_a_normal_form.cc
src/relay/pass/to_graph_normal_form.cc
tests/python/relay/test_pass_manager.py

index c84e3f9..67cc5df 100644 (file)
@@ -31,6 +31,7 @@
 #include <tvm/relay/op_attr_types.h>
 #include <tvm/relay/type.h>
 #include <tvm/relay/adt.h>
+#include <tvm/relay/transform.h>
 #include <tvm/runtime/vm.h>
 #include <string>
 #include <vector>
@@ -84,7 +85,8 @@ TVM_DLL Function InferType(const Function& f, const Module& mod,
  */
 TVM_DLL Kind KindCheck(const Type& t, const Module& mod);
 
-/*! \brief Compare two expressions for structural equivalence.
+/*!
+ * \brief Compare two expressions for structural equivalence.
  *
  * This comparison operator respects scoping and compares
  * expressions without regard to variable choice.
@@ -101,7 +103,8 @@ TVM_DLL Kind KindCheck(const Type& t, const Module& mod);
  */
 TVM_DLL bool AlphaEqual(const Expr& e1, const Expr& e2);
 
-/*! \brief Compare two types for structural equivalence.
+/*!
+ * \brief Compare two types for structural equivalence.
  *
  * This comparison operator respects scoping and compares
  * expressions without regard to variable choice.
@@ -119,7 +122,8 @@ TVM_DLL bool AlphaEqual(const Expr& e1, const Expr& e2);
  */
 TVM_DLL bool AlphaEqual(const Type& t1, const Type& t2);
 
-/*! \brief Add abstraction over a function
+/*!
+ * \brief Add abstraction over a function
  *
  * For example: `square` is transformed to
  * `fun x -> square x`.
@@ -135,7 +139,8 @@ TVM_DLL bool AlphaEqual(const Type& t1, const Type& t2);
  */
 TVM_DLL Expr EtaExpand(const Expr& e, const Module& mod);
 
-/*! \brief Check that each Var is only bound once.
+/*!
+ * \brief Check that each Var is only bound once.
  *
  * For example, the expression `let x = 1 in let x = 2 in 3` bound x twice.
  *
@@ -148,7 +153,8 @@ TVM_DLL Expr EtaExpand(const Expr& e, const Module& mod);
  */
 TVM_DLL bool WellFormed(const Expr& expr);
 
-/*! \brief Get all bound variables from expression expr.
+/*!
+ * \brief Get all bound variables from expression expr.
  *
  * Bound variables are all variables that are declared in the expr.
  * They only have meaning inside that expr, and can only be used in it.
@@ -159,7 +165,8 @@ TVM_DLL bool WellFormed(const Expr& expr);
  */
 TVM_DLL tvm::Array<Var> BoundVars(const Expr& expr);
 
-/*! \brief Get all bound variables from pattern pat.
+/*!
+ * \brief Get all bound variables from pattern pat.
  *
  * Bound variables are all variables that got bound by the pat.
  * They only have meaning inside that expr, and can only be used in it.
@@ -170,7 +177,8 @@ TVM_DLL tvm::Array<Var> BoundVars(const Expr& expr);
  */
 TVM_DLL tvm::Array<Var> BoundVars(const Pattern& pat);
 
-/*! \brief Get free type parameters from expression expr.
+/*!
+ * \brief Get free type parameters from expression expr.
  *
  * Free variables are variables that are not bound by a
  * let or a function parameter in the context.
@@ -181,7 +189,8 @@ TVM_DLL tvm::Array<Var> BoundVars(const Pattern& pat);
  */
 TVM_DLL tvm::Array<Var> FreeVars(const Expr& expr);
 
-/*! \brief Get all variables from expression expr.
+/*!
+ * \brief Get all variables from expression expr.
  *
  * \param expr the expression.
  *
@@ -189,7 +198,8 @@ TVM_DLL tvm::Array<Var> FreeVars(const Expr& expr);
  */
 TVM_DLL tvm::Array<Var> AllVars(const Expr& expr);
 
-/*! \brief Get free TypeVars from expression expr.
+/*!
+ * \brief Get free TypeVars from expression expr.
  *
  * Free type parameters are type parameters that are not bound by a function
  * type in the context.
@@ -201,7 +211,8 @@ TVM_DLL tvm::Array<Var> AllVars(const Expr& expr);
  */
 TVM_DLL tvm::Array<TypeVar> FreeTypeVars(const Expr& expr, const Module& mod);
 
-/*! \brief Get free TypeVars from type t.
+/*!
+ * \brief Get free TypeVars from type t.
  *
  * Free type parameters are type parameters that are not bound by a function
  * type in the context.
@@ -213,7 +224,8 @@ TVM_DLL tvm::Array<TypeVar> FreeTypeVars(const Expr& expr, const Module& mod);
  */
 TVM_DLL tvm::Array<TypeVar> FreeTypeVars(const Type& t, const Module& mod);
 
-/*! \brief Get all bound type variables from expression expr.
+/*!
+ * \brief Get all bound type variables from expression expr.
  *
  * Bound variables are all type variables that are declared in the expr.
  * They only have meaning inside that expr, and can only be used in it.
@@ -225,7 +237,8 @@ TVM_DLL tvm::Array<TypeVar> FreeTypeVars(const Type& t, const Module& mod);
  */
 TVM_DLL tvm::Array<TypeVar> BoundTypeVars(const Expr& expr, const Module& mod);
 
-/*! \brief Get all bound type variables from type t.
+/*!
+ * \brief Get all bound type variables from type t.
  *
  * Bound variables are all type variables that are declared in the type.
  * They only have meaning inside that type, and can only be used in it.
@@ -237,7 +250,8 @@ TVM_DLL tvm::Array<TypeVar> BoundTypeVars(const Expr& expr, const Module& mod);
  */
 TVM_DLL tvm::Array<TypeVar> BoundTypeVars(const Type& t, const Module& mod);
 
-/*! \brief Get all type variables in expression expr.
+/*!
+ * \brief Get all type variables in expression expr.
  *
  * \param expr the expression.
  * \param mod the module.
@@ -246,7 +260,8 @@ TVM_DLL tvm::Array<TypeVar> BoundTypeVars(const Type& t, const Module& mod);
  */
 TVM_DLL tvm::Array<TypeVar> AllTypeVars(const Expr& expr, const Module& mod);
 
-/*! \brief Get all type variables in type t.
+/*!
+ * \brief Get all type variables in type t.
  *
  * \param t the type.
  * \param mod the module.
@@ -273,22 +288,27 @@ TVM_DLL Expr DeadCodeElimination(const Expr& e);
 
 /*!
  * \brief Fold constant expressions.
+ *
  * \param expr the expression to be optimized.
+ *
  * \return The optimized expression.
  */
 TVM_DLL Expr FoldConstant(const Expr& expr);
 
 /*!
  * \brief Fuse operations into expr into seperate functions.
+ *
  * \param expr The expression.
  * \param fuse_opt_level Optimization level.
  * \param mod the module.
+ *
  * \return The optimized expression.
  */
 TVM_DLL Expr FuseOps(const Expr& expr, int fuse_opt_level, const Module& mod);
 
 /*!
  * \brief Apply rewrite rules to rewrite the expr in post DFS order.
+ *
  * \param expr The expression.
  * \param rewrite_map_attr_name The Op's attr name which corresponds to the rewrite
  *                              rule function.
@@ -298,84 +318,68 @@ TVM_DLL Expr FuseOps(const Expr& expr, int fuse_opt_level, const Module& mod);
  * \return The rewritten expression.
  */
 TVM_DLL Expr ForwardRewrite(const Expr& expr,
-                    const std::string& rewrite_map_attr_name,
-                    std::function<NodeRef(const Call&)> fcontext = nullptr,
-                    std::function<Expr(const Expr&)> fmulti_ref_trigger = nullptr);
+                            const std::string& rewrite_map_attr_name,
+                            std::function<NodeRef(const Call&)> fcontext = nullptr,
+                            std::function<Expr(const Expr&)> fmulti_ref_trigger = nullptr);
 
 /*!
  * \brief Apply rewrite rules to rewrite the expr in post DFS order.
+ *
  * \param expr The expression.
  * \param rewrite_func The rewrite func that will apply to all operators.
  * \param fcontext Additional callback to provide context argument for each call node.
  * \param fmulti_ref_trigger Transformation function to be called when
  *                           an Expr consumed by multiple callers.
+ *
  * \return The rewritten expression.
  */
 TVM_DLL Expr ForwardRewrite(const Expr& expr,
-                    const FForwardRewrite& rewrite_func,
-                    std::function<NodeRef(const Call&)> fcontext = nullptr,
-                    std::function<Expr(const Expr&)> fmulti_ref_trigger = nullptr);
+                            const FForwardRewrite& rewrite_func,
+                            std::function<NodeRef(const Call&)> fcontext = nullptr,
+                            std::function<Expr(const Expr&)> fmulti_ref_trigger = nullptr);
 
 /*!
  * \brief Rewrite the annotated program.
+ *
  * \param expr The expression.
  * \param fallback_device The fallback device which is the default device for
  *                        operators without annotation.
+ *
  * \return The updated program.
  */
 TVM_DLL Expr RewriteAnnotatedOps(const Expr& expr, int fallback_device);
 
 /*!
  * \brief Collect the device mapping information of each expression.
+ *
  * \param expr The expression.
+ *
  * \return The device mapping.
  */
 TVM_DLL Map<Expr, Integer> CollectDeviceInfo(const Expr& expr);
 
-/*! \brief A hashing structure in the style of std::hash. */
-struct StructuralHash {
-  /*! \brief Hash a Relay type.
-   *
-   * Implements structural hashing of a Relay type.
-   *
-   *  \param type the type to hash.
-   *
-   *  \return the hash value.
-   */
-  size_t operator()(const Type& type) const;
-
-  /*! \brief Hash a Relay expression.
-   *
-   * Implements structural hashing of a Relay expression.
-   *
-   * \param expr the expression to hash.
-   *
-   * \return the hash value.
-   */
-  size_t operator()(const Expr& expr) const;
-};
-
-/*! \brief turn a dataflow graph into Administrative Normal Form, or A-Normal Form (ANF).
+/*!
+ * \brief turn a dataflow graph into Administrative Normal Form, or A-Normal Form (ANF).
  *
  * It will turn an expression that is in a graph form (with sharing implicit),
  * to an expression with explicit sharing (A-Normal Form).
  *
  * The scope of the root expression is the global scope.
-
+ *
  * The scope of any non root expression is the least common ancestor of all it's scope.
  *
  * Values are ordered by post-DFS order in each scope.
  *
- * \param e the expression to observably share
- *
+ * \param e the expression to observably share.
  * \param mod The module used for referencing global functions, can be
  * None.
  *
- * \return expression in A-Normal Form
+ * \return expression in A-Normal Form.
  */
 TVM_DLL Expr ToANormalForm(const Expr& e, const Module& mod);
 
-/*! \brief Remove let binding and directly share via pointer instead.
+/*!
+ * \brief Remove let binding and directly share via pointer instead.
  *
  * It will remove all let binding,
  * and turn all of the variable bound by let into direct pointer reference.
@@ -386,18 +390,49 @@ TVM_DLL Expr ToANormalForm(const Expr& e, const Module& mod);
  */
 TVM_DLL Expr ToGraphNormalForm(const Expr& e);
 
-/*! \brief Aggressive constant propagation/constant folding/inlining.
+/*!
+ * \brief Aggressive constant propagation/constant folding/inlining.
+ *
  * It will do as much computation in compile time as possible.
  * It has two benefit: remove runtime overhead, and allow more optimization (typically fusion).
  * As a side effect, code size will explode.
+ *
+ * \param e the expression,
+ *
+ * \return the optimized expression.
  */
-Expr PartialEval(const Expr& e);
+TVM_DLL Expr PartialEval(const Expr& e);
+
+/*! \brief A hashing structure in the style of std::hash. */
+struct StructuralHash {
+  /*! \brief Hash a Relay type.
+   *
+   * Implements structural hashing of a Relay type.
+   *
+   * \param type the type to hash.
+   *
+   * \return the hash value.
+   */
+  size_t operator()(const Type& type) const;
+
+  /*! \brief Hash a Relay expression.
+   *
+   * Implements structural hashing of a Relay expression.
+   *
+   * \param expr the expression to hash.
+   *
+   * \return the hash value.
+   */
+  size_t operator()(const Expr& expr) const;
+};
 
 namespace vm {
 
-/*! \brief Compile a module, and construct the virtual machine.
+/*!
+ * \brief Compile a module, and construct the virtual machine.
  *
  * \param mod The module to compile.
+ *
  * \return The constructed virtual machine.
  */
 runtime::vm::VirtualMachine CompileModule(const Module& mod);
index 5123f3a..4d6921a 100644 (file)
@@ -61,6 +61,7 @@
 #include <tvm/relay/error.h>
 #include <tvm/relay/expr.h>
 #include <tvm/relay/module.h>
+#include <tvm/relay/op_attr_types.h>
 #include <string>
 #include <unordered_map>
 #include <vector>
@@ -198,7 +199,7 @@ class Pass;
  */
 class PassNode : public RelayNode {
  public:
-  /*
+  /*!
    * \brief Get the pass information/meta data. */
   virtual PassInfo Info() const = 0;
 
@@ -300,11 +301,118 @@ Pass CreateModulePass(
  *
  * \return The created function pass.
  */
-Pass CreateFunctionPass(
-    const runtime::TypedPackedFunc<Function(Function, PassContext)>& pass_func,
-    int opt_level,
-    const std::string& name,
-    const tvm::Array<tvm::Expr>& required);
+TVM_DLL Pass CreateFunctionPass(const runtime::TypedPackedFunc<
+                                  Function(Function, Module, PassContext)>& pass_func,
+                                int opt_level,
+                                const std::string& name,
+                                const tvm::Array<tvm::Expr>& required);
+
+/*! \brief Remove expressions which does not effect the program result.
+ *
+ * It will remove let bindings which are not referenced,
+ * and inline let bindings that are only used once.
+ *
+ * For example, this pass should turn `let a = 1 in 2` into `2`,
+ * as the value of the expression does not depend on a.
+ *
+ * As another example, `let a = 1 in a` will be optimized into 1.
+ *
+ * \return the pass.
+ */
+TVM_DLL Pass DeadCodeElimination();
+
+/*!
+ * \brief Fold constant expressions.
+ *
+ * \return The pass.
+ */
+TVM_DLL Pass FoldConstant();
+
+/*!
+ * \brief Fuse operations into expr into seperate functions.
+ *
+ * \param fuse_opt_level Optimization level. If it is -1 it will be inferred from pass context.
+ *
+ * \return The pass.
+ */
+TVM_DLL Pass FuseOps(int fuse_opt_level = -1);
+
+/*!
+ * \brief Apply rewrite rules to rewrite the expr in post DFS order.
+ *
+ * \param rewrite_map_attr_name The Op's attr name which corresponds to the rewrite
+ *                              rule function.
+ * \param fcontext Additional callback to provide context argument for each call node.
+ * \param fmulti_ref_trigger Transformation function to be called when
+ *                           an Expr consumed by multiple callers.
+ *
+ * \return The pass.
+ */
+TVM_DLL Pass ForwardRewrite(const std::string& rewrite_map_attr_name,
+                            std::function<NodeRef(const Call&)> fcontext = nullptr,
+                            std::function<Expr(const Expr&)>
+                            fmulti_ref_trigger = nullptr);
+
+/*!
+ * \brief Apply rewrite rules to rewrite the expr in post DFS order.
+ *
+ * \param rewrite_func The rewrite func that will apply to all operators.
+ * \param fcontext Additional callback to provide context argument for each call node.
+ * \param fmulti_ref_trigger Transformation function to be called when
+ *                           an Expr consumed by multiple callers.
+ *
+ * \return The pass.
+ */
+TVM_DLL Pass ForwardRewrite(const FForwardRewrite& rewrite_func,
+                            std::function<NodeRef(const Call&)> fcontext = nullptr,
+                            std::function<Expr(const Expr&)> fmulti_ref_trigger = nullptr);
+
+/*!
+ * \brief Rewrite the annotated program.
+ *
+ * \param fallback_device The fallback device which is the default device for
+ *                        operators without annotation.
+ *
+ * \return The pass.
+ */
+TVM_DLL Pass RewriteAnnotatedOps(int fallback_device);
+
+/*!
+ * \brief turn a dataflow graph into Administrative Normal Form, or A-Normal Form (ANF).
+ *
+ * It will turn an expression that is in a graph form (with sharing implicit),
+ * to an expression with explicit sharing (A-Normal Form).
+ *
+ * The scope of the root expression is the global scope.
+ *
+ * The scope of any non root expression is the least common ancestor of all it's scope.
+ *
+ * Values are ordered by post-DFS order in each scope.
+ *
+ * \return The pass.
+ */
+TVM_DLL Pass ToANormalForm();
+
+/*!
+ * \brief Remove let binding and directly share via pointer instead.
+ *
+ * It will remove all let binding,
+ * and turn all of the variable bound by let into direct pointer reference.
+ *
+ * \return the expression in graph normal form.
+ */
+TVM_DLL Pass ToGraphNormalForm();
+
+/*!
+ * \brief Aggressive constant propagation/constant folding/inlining.
+ *
+ * It will do as much computation in compile time as possible.
+ * It has two benefit: remove runtime overhead, and allow more optimization (typically fusion).
+ * As a side effect, code size will explode.
+ *
+ * \return the optimized expression.
+ */
+TVM_DLL Pass PartialEval();
 
 }  // namespace transform
 }  // namespace relay
index 533c214..dd1ed62 100644 (file)
@@ -151,5 +151,17 @@ Expr DeadCodeElimination(const Expr& e) {
 TVM_REGISTER_API("relay._ir_pass.dead_code_elimination")
 .set_body_typed(DeadCodeElimination);
 
+namespace transform {
+
+Pass DeadCodeElimination() {
+  runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
+    [=](Function f, Module m, PassContext pc) {
+    return Downcast<Function>(DeadCodeElimination(f));
+  };
+  return CreateFunctionPass(pass_func, 1, "dead_code_elimination", {});
+}
+
+}  // namespace transform
+
 }  // namespace relay
 }  // namespace tvm
index 8807f6d..fa656db 100644 (file)
@@ -550,6 +550,18 @@ TVM_REGISTER_API("relay._ir_pass.RewriteDeviceAnnotation")
 TVM_REGISTER_API("relay._ir_pass.CollectDeviceAnnotationOps")
 .set_body_typed(CollectDeviceAnnotationOps);
 
+namespace transform {
+
+Pass RewriteAnnotatedOps(int fallback_device) {
+  runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
+    [=](Function f, Module m, PassContext pc) {
+    return Downcast<Function>(RewriteAnnotatedOps(f, fallback_device));
+  };
+  return CreateFunctionPass(pass_func, 1, "rewrite_annotated_ops", {});
+}
+
+}  // namespace transform
+
 }  // namespace relay
 }  // namespace tvm
 
index c085d80..286392a 100644 (file)
@@ -215,5 +215,17 @@ Expr FoldConstant(const Expr& expr) {
 TVM_REGISTER_API("relay._ir_pass.FoldConstant")
 .set_body_typed(FoldConstant);
 
+namespace transform {
+
+Pass FoldConstant() {
+  runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
+    [=](Function f, Module m, PassContext pc) {
+    return Downcast<Function>(FoldConstant(f));
+  };
+  return CreateFunctionPass(pass_func, 1, "fold_constant", {});
+}
+
+}  // namespace transform
+
 }  // namespace relay
 }  // namespace tvm
index 88a2d66..2a3aa16 100644 (file)
@@ -206,6 +206,37 @@ Expr ForwardRewrite(const Expr& expr,
   return ForwardRewriter(&rewrite_func, fcontext, fmulti_ref_trigger).Rewrite(expr);
 }
 
+namespace transform {
+
+using std::function;
+
+Pass ForwardRewrite(const std::string& rewrite_map_attr_name,
+                    function<NodeRef(const Call&)> fcontext,
+                    function<Expr(const Expr&)> fmulti_ref_trigger) {
+  runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
+    [=](Function f, Module m, PassContext pc) {
+    return Downcast<Function>(ForwardRewrite(f,
+                                             rewrite_map_attr_name,
+                                             fcontext,
+                                             fmulti_ref_trigger));
+  };
+  return CreateFunctionPass(pass_func, 1, "forward_rewrite", {});
+}
+
+Pass ForwardRewrite(const FForwardRewrite& rewrite_func,
+                    function<NodeRef(const Call&)> fcontext,
+                    function<Expr(const Expr&)> fmulti_ref_trigger) {
+  runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
+    [=](Function f, Module m, PassContext pc) {
+    return Downcast<Function>(ForwardRewrite(f,
+                                             rewrite_func,
+                                             fcontext,
+                                             fmulti_ref_trigger));
+  };
+  return CreateFunctionPass(pass_func, 1, "forward_rewrite", {});
+}
+
+}  // namespace transform
 
 }  // namespace relay
 }  // namespace tvm
index d0d0cab..9277689 100644 (file)
@@ -964,5 +964,19 @@ Expr FuseOps(const Expr& expr, int fuse_opt_level, const Module& module) {
 
 TVM_REGISTER_API("relay._ir_pass.FuseOps")
 .set_body_typed(FuseOps);
+
+namespace transform {
+
+Pass FuseOps(int fuse_opt_level) {
+  runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
+    [=](Function f, Module m, PassContext pc) {
+    int opt_level = fuse_opt_level == -1 ? pc->opt_level : fuse_opt_level;
+    return Downcast<Function>(FuseOps(f, opt_level, m));
+  };
+  return CreateFunctionPass(pass_func, 1, "fuse_ops", {});
+}
+
+}  // namespace transform
+
 }  // namespace relay
 }  // namespace tvm
index ad86174..3f42c6f 100644 (file)
@@ -801,5 +801,17 @@ TVM_REGISTER_API("relay._ir_pass.partial_evaluate")
     *ret = PartialEval(args[0]);
   });
 
+namespace transform {
+
+Pass PartialEval() {
+  runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
+    [=](Function f, Module m, PassContext pc) {
+    return Downcast<Function>(PartialEval(f));
+  };
+  return CreateFunctionPass(pass_func, 1, "partial_eval", {});
+}
+
+}  // namespace transform
+
 }  // namespace relay
 }  // namespace tvm
index 4bcc0bb..ea4c976 100644 (file)
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -201,7 +201,7 @@ class FunctionPassNode : public PassNode {
    * `pass_func` and let it run on a given module. The same `pass_func` will
    * then be applied on each function in the module.
    */
-  runtime::TypedPackedFunc<Function(Function, PassContext)> pass_func;
+  runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func;
 
   FunctionPassNode() = default;
 
@@ -225,7 +225,7 @@ class FunctionPassNode : public PassNode {
   PassInfo Info() const { return pass_info; }
 
   TVM_DLL static FunctionPass make(
-      runtime::TypedPackedFunc<Function(Function, PassContext)> pass_func,
+      runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func,
       PassInfo pass_info);
 
   static constexpr const char* _type_key = "relay.FunctionPass";
@@ -363,7 +363,7 @@ Module ModulePassNode::operator()(const Module& mod,
 }
 
 FunctionPass FunctionPassNode::make(
-    runtime::TypedPackedFunc<Function(Function, PassContext)> pass_func,
+    runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func,
     PassInfo pass_info) {
   auto n = make_node<FunctionPassNode>();
   n->pass_func = std::move(pass_func);
@@ -383,8 +383,7 @@ Module FunctionPassNode::operator()(const Module& mod,
 
   // Execute the pass function and return a new module.
   for (const auto& it : mod->functions) {
-    auto updated_func =
-        SkipFunction(it.second) ? it.second : pass_func(it.second, pass_ctx);
+    auto updated_func = SkipFunction(it.second) ? it.second : pass_func(it.second, mod, pass_ctx);
     new_mod->Add(it.first, updated_func);
   }
 
@@ -501,7 +500,7 @@ Pass CreateModulePass(
 }
 
 Pass CreateFunctionPass(
-    const runtime::TypedPackedFunc<Function(Function, PassContext)>& pass_func,
+    const runtime::TypedPackedFunc<Function(Function, Module, PassContext)>& pass_func,
     int opt_level,
     const std::string& name,
     const tvm::Array<tvm::Expr>& required) {
@@ -589,7 +588,7 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
                                  tvm::IRPrinter* p) {
   const PassInfoNode* seq_pn = node->Info().operator->();
   p->stream << "Run Sequential pass: " << seq_pn->name
-            << " at the optimization level. " << seq_pn->opt_level;
+            << " at the optimization level " << seq_pn->opt_level << ". ";
   p->stream << "The passes will be executed are: [";
   for (const auto& it : node->passes) {
     const PassNode* pn = it.operator->();
index 913f8de..f9d47f7 100644 (file)
@@ -333,5 +333,17 @@ Expr ToANormalForm(const Expr& e, const Module& m) {
 TVM_REGISTER_API("relay._ir_pass.to_a_normal_form")
 .set_body_typed(static_cast<Expr (*)(const Expr&, const Module&)>(ToANormalForm));
 
+namespace transform {
+
+Pass ToANormalForm() {
+  runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
+    [=](Function f, Module m, PassContext pc) {
+    return Downcast<Function>(ToANormalForm(f, m));
+  };
+  return CreateFunctionPass(pass_func, 1, "to_a_normal_form", {});
+}
+
+}  // namespace transform
+
 }  // namespace relay
 }  // namespace tvm
index 490a80f..50ebb70 100644 (file)
@@ -79,5 +79,17 @@ Expr ToGraphNormalForm(const Expr& e) {
 TVM_REGISTER_API("relay._ir_pass.to_graph_normal_form")
 .set_body_typed(ToGraphNormalForm);
 
+namespace transform {
+
+Pass ToGraphNormalForm() {
+  runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
+    [=](Function f, Module m, PassContext pc) {
+    return Downcast<Function>(ToGraphNormalForm(f));
+  };
+  return CreateFunctionPass(pass_func, 1, "to_graph_normal_form", {});
+}
+
+}  // namespace transform
+
 }  // namespace relay
 }  // namespace tvm
index db346e7..2703e5c 100644 (file)
@@ -204,7 +204,7 @@ def test_function_pass():
     pass_ctx = None
 
     @_transform.function_pass(opt_level=opt_level, name=pass_name)
-    def transform(expr, ctx):
+    def transform(expr, mod, ctx):
         return opt_tester.transform(expr, ctx)
 
     def get_ref_log():
@@ -303,7 +303,7 @@ def test_sequential_pass():
 
     # Register a function pass.
     @_transform.function_pass(opt_level=1)
-    def func_transform(expr, ctx):
+    def func_transform(expr, mod, ctx):
         return opt_tester.transform(expr, ctx)
 
     function_pass = func_transform