[RELAY][TRANSFORM] Migrate buildmodule to transform (#3251)
authorZhi <5145158+zhiics@users.noreply.github.com>
Mon, 3 Jun 2019 17:40:38 +0000 (10:40 -0700)
committerTianqi Chen <tqchen@users.noreply.github.com>
Mon, 3 Jun 2019 17:40:38 +0000 (10:40 -0700)
24 files changed:
include/tvm/relay/module.h
include/tvm/relay/pass.h
include/tvm/relay/transform.h
python/tvm/relay/build_module.py
python/tvm/relay/transform.py
src/relay/backend/build_module.cc
src/relay/pass/alter_op_layout.cc
src/relay/pass/canonicalize_ops.cc
src/relay/pass/combine_parallel_conv2d.cc
src/relay/pass/dead_code.cc
src/relay/pass/device_annotation.cc
src/relay/pass/eliminate_common_subexpr.cc
src/relay/pass/fold_constant.cc
src/relay/pass/fold_scale_axis.cc
src/relay/pass/forward_rewrite.cc
src/relay/pass/fuse_ops.cc
src/relay/pass/partial_eval.cc
src/relay/pass/pass_manager.cc
src/relay/pass/simplify_inference.cc
src/relay/pass/to_a_normal_form.cc
src/relay/pass/to_graph_normal_form.cc
src/relay/pass/type_infer.cc
tests/cpp/relay_transform_sequential.cc [new file with mode: 0644]
tests/python/relay/test_pass_manager.py

index 6441fb3..3966a62 100644 (file)
@@ -87,14 +87,14 @@ class ModuleNode : public RelayNode {
    * \param update Controls whether you can replace a definition in the
    * environment.
    */
-  void Add(const GlobalVar& var, const Function& func, bool update = false);
+  TVM_DLL void Add(const GlobalVar& var, const Function& func, bool update = false);
 
   /*!
    * \brief Add a type-level definition to the global environment.
    * \param var The var of the global type definition.
    * \param type The type definition.
    */
-  void AddDef(const GlobalTypeVar& var, const TypeData& type);
+  TVM_DLL void AddDef(const GlobalTypeVar& var, const TypeData& type);
 
   /*!
    * \brief Add a function to the global environment.
@@ -103,69 +103,69 @@ class ModuleNode : public RelayNode {
    *
    * It does not do type inference as Add does.
    */
-  void AddUnchecked(const GlobalVar& var, const Function& func);
+  TVM_DLL void AddUnchecked(const GlobalVar& var, const Function& func);
 
   /*!
    * \brief Update a function in the global environment.
    * \param var The name of the global function to update.
    * \param func The new function.
    */
-  void Update(const GlobalVar& var, const Function& func);
+  TVM_DLL void Update(const GlobalVar& var, const Function& func);
 
   /*!
    * \brief Remove a function from the global environment.
    * \param var The name of the global function to update.
    */
-  void Remove(const GlobalVar& var);
+  TVM_DLL void Remove(const GlobalVar& var);
 
   /*!
    * \brief Lookup a global function by its variable.
    * \param str The unique string specifying the global variable.
    * \returns The global variable.
    */
-  GlobalVar GetGlobalVar(const std::string& str);
+  TVM_DLL GlobalVar GetGlobalVar(const std::string& str);
 
   /*!
    * \brief Look up a global function by its name.
    * \param str The unique string specifying the global variable.
    * \returns The global variable.
    */
-  GlobalTypeVar GetGlobalTypeVar(const std::string& str);
+  TVM_DLL GlobalTypeVar GetGlobalTypeVar(const std::string& str);
 
   /*!
    * \brief Lookup a global function by its variable.
    * \param var The global var to lookup.
    * \returns The function named by the variable argument.
    */
-  Function Lookup(const GlobalVar& var);
+  TVM_DLL Function Lookup(const GlobalVar& var);
 
   /*!
    * \brief Lookup a global function by its string name
    * \param name The name of the function.
    * \returns The function named by the argument.
    */
-  Function Lookup(const std::string& name);
+  TVM_DLL Function Lookup(const std::string& name);
 
   /*!
    * \brief Lookup a global type definition by its variable.
    * \param var The var of the global type definition.
    * \return The type definition.
    */
-  TypeData LookupDef(const GlobalTypeVar& var);
+  TVM_DLL TypeData LookupDef(const GlobalTypeVar& var);
 
   /*!
    * \brief Lookup a global type definition by its name.
    * \param var The name of the global type definition.
    * \return The type definition.
    */
-  TypeData LookupDef(const std::string& var);
+  TVM_DLL TypeData LookupDef(const std::string& var);
 
   /*!
    * \brief Update the functions inside this environment by
    *        functions in another environment.
    * \param other The other environment.
    */
-  void Update(const Module& other);
+  TVM_DLL void Update(const Module& other);
 
   /*! \brief Construct a module from a standalone expression.
    *
@@ -177,7 +177,7 @@ class ModuleNode : public RelayNode {
    *
    * \returns A module with expr set as the entry point.
    */
-  static Module FromExpr(
+  TVM_DLL static Module FromExpr(
     const Expr& expr,
     const tvm::Map<GlobalVar, Function>& global_funcs = {});
 
index 67cc5df..8158733 100644 (file)
@@ -359,6 +359,15 @@ TVM_DLL Expr RewriteAnnotatedOps(const Expr& expr, int fallback_device);
 TVM_DLL Map<Expr, Integer> CollectDeviceInfo(const Expr& expr);
 
 /*!
+ * \brief Collect the device anntation operators.
+ *
+ * \param expr The expression.
+ *
+ * \return The annotated expression to device type mapping for annotation ops.
+ */
+TVM_DLL Map<Expr, Integer> CollectDeviceAnnotationOps(const Expr& expr);
+
+/*!
  * \brief turn a dataflow graph into Administrative Normal Form, or A-Normal Form (ANF).
  *
  * It will turn an expression that is in a graph form (with sharing implicit),
@@ -403,6 +412,17 @@ TVM_DLL Expr ToGraphNormalForm(const Expr& e);
  */
 TVM_DLL Expr PartialEval(const Expr& e);
 
+/*!
+ * \brief Bind the free variables to a Relay expression.
+ *
+ * \param expr The expression.
+ * \param bind_map The variable to expression map that will be used to help the
+ *        binding.
+ *
+ * \return The updated expression.
+ */
+TVM_DLL Expr Bind(const Expr& expr, const tvm::Map<Var, Expr>& bind_map);
+
 /*! \brief A hashing structure in the style of std::hash. */
 struct StructuralHash {
   /*! \brief Hash a Relay type.
index 1c1b608..793bc98 100644 (file)
 
 #include <tvm/base.h>
 #include <tvm/packed_func_ext.h>
+#include <tvm/relay/attrs/transform.h>
 #include <tvm/relay/error.h>
 #include <tvm/relay/expr.h>
 #include <tvm/relay/module.h>
+#include <tvm/relay/op.h>
 #include <tvm/relay/op_attr_types.h>
 #include <string>
 #include <unordered_map>
@@ -292,9 +294,9 @@ class Sequential : public Pass {
    * \param passes The passes to apply.
    * \param pass_info The pass metadata.
    */
-  TVM_DLL Sequential(tvm::Array<Pass> passes,
-                     PassInfo pass_info);
-/*!
+  TVM_DLL Sequential(tvm::Array<Pass> passes, PassInfo pass_info);
+
+  /*!
    * \brief The constructor of `Sequential`.
    *
    * \param passes The passes to apply.
@@ -311,7 +313,6 @@ class Sequential : public Pass {
   using ContainerType = Sequential;
 };
 
-
 /*
  * \brief Create a module pass.
  *
@@ -339,7 +340,7 @@ Pass CreateModulePass(
  * \return The created function pass.
  */
 TVM_DLL Pass CreateFunctionPass(const runtime::TypedPackedFunc<
-                                  Function(Function, Module, PassContext)>& pass_func,
+                                Function(Function, Module, PassContext)>& pass_func,
                                 int opt_level,
                                 const std::string& name,
                                 const tvm::Array<tvm::Expr>& required);
@@ -451,6 +452,85 @@ TVM_DLL Pass ToGraphNormalForm();
  */
 TVM_DLL Pass PartialEval();
 
+/*!
+ * \brief Simplify certain operators during inference. For example, batch norm
+ * will be unpacked into a number of simplified operators.
+ *
+ * \return The Pass.
+ */
+TVM_DLL Pass SimplifyInference();
+
+/*!
+ * \brief Infer the type of an expression.
+ *
+ * The result of type checking is a new expression with unambigous
+ * type information filled in, as well as it's checked type field
+ * populated with the result type.
+ *
+ * \return The pass. 
+ */
+TVM_DLL Pass InferType();
+
+/*!
+ * \brief Search and eliminate common subexpression. For example, if there are
+ * two expressions evaluated to an identical value, a single variable is created
+ * and these two expressions are replaced by this variable.
+ *
+ * \param fskip The callback argument that allows to skip certain expressions.
+ *
+ * \return The pass.
+ */
+TVM_DLL Pass EliminateCommonSubexpr(PackedFunc fskip = nullptr);
+
+/*!
+ * \brief Combine parallel 2d convolutions into a single convolution if the
+ * number of branches of this conv2d operator is not less than
+ * `min_num_branch`.
+ *
+ * \param min_num_branches The minimun number of branches.
+ *
+ * \return The pass.
+ */
+TVM_DLL Pass CombineParallelConv2D(uint64_t min_num_branches = 3);
+
+/*!
+ * \brief Backward fold axis scaling into weights of conv/dense operators.
+ *
+ * \return The pass.
+ */
+TVM_DLL Pass BackwardFoldScaleAxis();
+
+/*!
+ * \brief Forward fold axis scaling into weights of conv/dense operators.
+ *
+ * \return The pass.
+ */
+TVM_DLL Pass ForwardFoldScaleAxis();
+
+/*!
+ * \brief A sequential pass that executes ForwardFoldScaleAxis and
+ * BackwardFoldScaleAxis passes.
+ *
+ * \return The pass.
+ */
+TVM_DLL Pass FoldScaleAxis();
+
+/*!
+ * \brief Canonicalize some operators to the simplified operators. For example,
+ * bias_add can be canonicalized to expand_dims and broadcast_add.
+ *
+ * \return The pass.
+ */
+TVM_DLL Pass CanonicalizeOps();
+
+/*!
+ * \brief Alternate the layouts of operators or replace primitive operators
+ * with other expressions.
+ *
+ * \return The pass.
+ */
+TVM_DLL Pass AlterOpLayout();
+
 }  // namespace transform
 }  // namespace relay
 }  // namespace tvm
index 6cee393..8f9b048 100644 (file)
@@ -20,7 +20,6 @@ from a Relay expression.
 """
 import numpy as np
 
-from tvm._ffi.runtime_ctypes import TVMContext
 from tvm import expr as tvm_expr
 from .. import nd as _nd, target as _target, autotvm
 from ..contrib import graph_runtime as _graph_rt
@@ -28,7 +27,6 @@ from . import _build_module
 from . import ir_pass
 from . import ty as _ty
 from . import expr as _expr
-from . import transform as _transform
 from .backend import interpreter as _interpreter
 from .backend.vm import VMExecutor
 
@@ -61,10 +59,6 @@ class BuildModule(object):
         self._get_graph_json = self.mod["get_graph_json"]
         self._get_module = self.mod["get_module"]
         self._build = self.mod["build"]
-        self._add_pass = self.mod["add_pass"]
-        self._disable_pass = self.mod["disable_pass"]
-        self._set_opt_level = self.mod["set_opt_level"]
-        self._set_fallback_device = self.mod["set_fallback_device"]
         self._set_params_func = self.mod["set_params"]
         self._get_params_func = self.mod["get_params"]
 
@@ -106,8 +100,9 @@ class BuildModule(object):
         """
         target = _update_target(target)
 
-        # Setup the build configurations passed in through `with build_config`.
-        self._setup_build_config(params)
+        # Setup the params.
+        if params:
+            self._set_params(params)
         # Build the function
         self._build(func, target, target_host)
         # Get artifacts
@@ -117,41 +112,6 @@ class BuildModule(object):
 
         return graph_json, mod, params
 
-    def _setup_build_config(self, params):
-        cfg = _transform.PassContext.current()
-
-        # Set opt_level.
-        self.set_opt_level(cfg.opt_level)
-
-        # Set fallback device if it is available.
-        if cfg.fallback_device:
-            self.set_fallback_device(cfg.fallback_device)
-
-        # Add required passes.
-        if cfg.required_pass:
-            passes = set()
-            if isinstance(cfg.required_pass, (list, tuple, set)):
-                passes = set(cfg.required_pass)
-            else:
-                raise TypeError("add_pass must be list, tuple, or set, but " +
-                                "got {}".format(type(cfg.required_pass)))
-            for pass_name in passes:
-                self.add_pass(pass_name)
-
-        # Add disabled passes.
-        if cfg.disabled_pass:
-            passes = set()
-            if isinstance(cfg.disabled_pass, (list, tuple, set)):
-                passes = set(cfg.disabled_pass)
-            else:
-                raise TypeError("disable_pass must be list, tuple, or set, " +
-                                "but got {}".format(type(cfg.disabled_pass)))
-            for pass_name in passes:
-                self.disable_pass(pass_name)
-
-        if params:
-            self._set_params(params)
-
     def _set_params(self, params):
         inputs = {}
         for name, param in params.items():
@@ -160,28 +120,6 @@ class BuildModule(object):
             inputs[name] = _expr.const(param)
         self._set_params_func(inputs)
 
-    def add_pass(self, pass_name):
-        """Add a pass to the pass list.
-
-        Parameters
-        ----------
-        pass_name : str
-            The name of the pass that will be added to the list of passes used
-            for optimizations.
-        """
-        self._add_pass(pass_name)
-
-    def disable_pass(self, pass_name):
-        """Add a pass to the disabled pass list.
-
-        Parameters
-        ----------
-        pass_name : str
-            The name of a pass. This pass will be added to the list of passes
-            that are disabled during optimization.
-        """
-        self._disable_pass(pass_name)
-
     def get_json(self):
         """Return the json file of the built program."""
         return self._get_graph_json()
@@ -198,32 +136,6 @@ class BuildModule(object):
             ret[key] = value.data
         return ret
 
-    def set_opt_level(self, level):
-        """Set the optimization level.
-
-        Parameters
-        ----------
-        level : int
-            The optimization level for build.
-        """
-        self._set_opt_level(level)
-
-    def set_fallback_device(self, fallback_device):
-        """Set the fallback device for heterogeneous execution.
-
-        Parameters
-        ----------
-        fallback_device : str or tvm.TVMContext
-            The fallback device used for heterogeneous execution.
-        """
-        if isinstance(fallback_device, (int, str)):
-            fallback_device = _nd.context(fallback_device)
-        if not isinstance(fallback_device, TVMContext):
-            raise TypeError("fallback_device is expected to be str, int, or " +
-                            "TVMContext but received: {}".format(type(fallback_device)))
-
-        self._set_fallback_device(fallback_device.device_type)
-
 
 def build(func, target=None, target_host=None, params=None):
     """Helper function that builds a Relay function to run on TVM graph
index a7887c6..38079b0 100644 (file)
@@ -16,6 +16,7 @@
 # under the License.
 # pylint: disable=no-else-return
 # pylint: disable=unidiomatic-typecheck
+# pylint: disable=invalid-name
 """
 This file contains the pass manager for Relay which exposes different
 granularity of interfaces for users to implement and use passes more
@@ -394,3 +395,201 @@ def function_pass(pass_func=None, opt_level=None, name=None, required=None):
     if pass_func:
         return create_function_pass(pass_func)
     return create_function_pass
+
+
+def InferType():
+    """Infer the type of an expr.
+
+    Returns
+    -------
+    ret : tvm.relay.Pass
+        The registered type inference pass.
+    """
+    return _transform.InferType()
+
+
+def FoldScaleAxis():
+    """Fold the scaling of axis into weights of conv2d/dense. This pass will
+    invoke both forward and backward scale folding.
+
+    Returns
+    -------
+    ret : tvm.relay.Pass
+        The registered pass to fold expressions.
+
+    Note
+    ----
+    Internally, we will call backward_fold_scale_axis before using
+    forward_fold_scale_axis. As backward folding targets common conv-bn
+    pattern.
+    """
+    return _transform.FoldScaleAxis()
+
+
+def SimplifyInference():
+    """Simplify the data-flow graph for inference phase. An simplified expression
+    which is semantically equal to the input expression will be returned.
+
+    Returns
+    -------
+    ret: tvm.relay.Pass
+        The registered to perform operator simplification.
+    """
+    return _transform.SimplifyInference()
+
+
+def CanonicalizeOps():
+    """ Canonicalize special operators to basic operators.
+    This can simplify followed analysis. (e.g. expanding bias_add to
+    expand_dims and broadcast_add.)
+
+    Returns
+    -------
+    ret: tvm.relay.Pass
+        The registered pass performing the canonicalization.
+    """
+    return _transform.CanonicalizeOps()
+
+
+def DeadCodeElimination():
+    """ Remove expressions which does not effect the program result (dead code).
+
+    Returns
+    -------
+    ret: tvm.relay.Pass
+        The registered pass that eliminates the dead code in a Relay program.
+    """
+    return _transform.DeadCodeElimination()
+
+
+def FoldConstant():
+    """Fold the constant expression in expr.
+
+    Returns
+    -------
+    ret : tvm.relay.Pass
+        The registered pass for constant folding.
+    """
+    return _transform.FoldConstant()
+
+
+def FuseOps(fuse_opt_level=-1):
+    """Fuse operators in an expr to a larger operator according to some rules.
+
+    Parameters
+    ----------
+    fuse_opt_level : int
+        The level of fuse optimization. -1 indicates that the level will be
+        inferred from pass context.
+
+    Returns
+    -------
+    ret : tvm.relay.Pass
+        The registered pass for operator fusion.
+    """
+    return _transform.FuseOps(fuse_opt_level)
+
+
+def CombineParallelConv2D(min_num_branches=3):
+    """Combine multiple conv2d operators into one.
+
+    Parameters
+    ----------
+    min_num_branches : int
+        The minimum number of required parallel branches for performing this
+        optimization.
+
+    Returns
+    -------
+    ret: tvm.relay.Pass
+        The registered pass that combines parallel conv2d operators.
+    """
+    return _transform.CombineParallelConv2D(min_num_branches)
+
+
+def AlterOpLayout():
+    """Alternate the layouts of operators or replace primitive operators with
+    other expressions.
+    This pass can be used for computing convolution in custom layouts or
+    other general weight pre-transformation.
+
+    Returns
+    -------
+    ret : tvm.relay.Pass
+        The registered pass that alters the layout of operators.
+    """
+    return _transform.AlterOpLayout()
+
+
+def RewriteAnnotatedOps(fallback_device):
+    """Rewrite the annotated program where annotation operators, e.g.
+    `on_deivce`, mark which device an expression should be scheduled to.
+    This pass helps heterogeneous execution where different operators may need
+    to be allocated on various devices.
+
+    Parameters
+    ----------
+    fallback_device : int
+        The fallback device type. It is also used as the default device for
+        operators with no annotated device.
+
+    Returns
+    -------
+    ret: tvm.relay.Pass
+        The registered pass that rewrites an expression with annotated
+        `on_device` operators.
+    """
+    return _transform.RewriteDeviceAnnotation(fallback_device)
+
+
+def ToANormalForm():
+    """Turn Graph Normal Form expression into A Normal Form Expression.
+    The scope of the root expression is the global scope.
+    The scope of any non root expression is the least common ancestor of all it's scope.
+    Values are ordered by post-DFS order in each scope.
+
+    Returns
+    -------
+    ret: tvm.relay.Pass
+        The registered pass that transforms an expression into A Normal Form.
+    """
+    return _transform.ToANormalForm()
+
+
+def ToGraphNormalForm():
+    """Turn A Normal Form expression into Graph Normal Form expression
+
+    Returns
+    -------
+    ret : tvm.relay.Pass
+        The registered pass that transforms an expression into Graph Normal Form.
+    """
+    return _transform.ToGraphNormalForm()
+
+
+def EliminateCommonSubexpr(fskip=None):
+    """Eliminate common subexpressions.
+
+    Parameters
+    ----------
+    fskip: Callable
+        The callback function that decides whether an expression should be
+        skipped.
+
+    Returns
+    -------
+    ret : tvm.relay.Pass
+        The registered pass that eliminates common subexpressions.
+    """
+    return _transform.EliminateCommonSubexpr(fskip)
+
+
+def PartialEvaluate():
+    """Evaluate the static fragment of the code.
+
+    Returns
+    -------
+    ret : tvm.relay.Pass
+        The registered pass that performs partial evaluation on an expression.
+    """
+    return _transform.PartialEvaluate()
index 57dc256..e0014e9 100644 (file)
  */
 #include <tvm/build_module.h>
 #include <tvm/runtime/device_api.h>
-#include <tvm/relay/op.h>
 #include <tvm/relay/expr.h>
-#include <tvm/relay/attrs/nn.h>
-#include <tvm/relay/attrs/transform.h>
-#include <vector>
-#include <string>
+#include <tvm/relay/transform.h>
 #include <memory>
 
 #include "utils.h"
@@ -38,39 +34,7 @@ namespace relay {
 namespace backend {
 
 using TargetsMap = Map<tvm::Integer, tvm::Target>;
-
-/*!
- * \brief A data structure to map the names of specific optimizations to
- *        numeric optimization levels
- *
- */
-struct OptPassLevel {
-  static const std::unordered_map<std::string, int> _data;
-  /*!
-   * \brief Get level for an optimization pass
-   *
-   * \param key pass name
-   * \return int level
-   */
-  int operator[](const std::string& key) const {
-    auto it = _data.find(key);
-    if (it == _data.end()) {
-      return -1;
-    }
-    return it->second;
-  }
-};
-
-const std::unordered_map<std::string, int> OptPassLevel::_data = {
-  {"SimplifyInference", 0},
-  {"OpFusion", 1},
-  {"FoldConstant", 2},
-  {"CombineParallelConv2D", 4},
-  {"FoldScaleAxis", 3},
-  {"AlterOpLayout", 3},
-  {"CanonicalizeOps", 3},
-  {"EliminateCommonSubexpr", 3}
-};
+using namespace tvm::relay::transform;
 
 /*!
  * \brief Output of building module
@@ -83,27 +47,6 @@ struct BuildOutput {
 };
 
 /*!
- * \brief Relay building config
- *
- */
-struct RelayBuildConfig {
-  int opt_level{2};
-  int fallback_device{static_cast<int>(kDLCPU)};
-  std::unordered_set<std::string> enabled_pass;
-  std::unordered_set<std::string> disabled_pass;
-  OptPassLevel OPT_PASS_LEVEL;
-  inline bool pass_enabled(const std::string& pass_name) const {
-    if (disabled_pass.count(pass_name)) {
-      return false;
-    }
-    if (enabled_pass.count(pass_name)) {
-      return true;
-    }
-    return opt_level >= OPT_PASS_LEVEL[pass_name];
-  }
-};
-
-/*!
  * \brief GraphCodegen module wrapper
  *
  */
@@ -156,18 +99,6 @@ struct GraphCodegen {
   }
 };
 
-template<typename R, typename ...Args>
-R CallPackedFunc(const std::string &name, Args... args) {
-  auto pf = GetPackedFunc(name);
-  return (*pf)(std::forward<Args>(args)...);
-}
-
-template<typename ...Args>
-Function CallPackedFunc(const std::string &name, Args... args) {
-  auto pf = GetPackedFunc(name);
-  return (*pf)(std::forward<Args>(args)...);
-}
-
 /*!
  * \brief Relay build module
  *
@@ -203,28 +134,6 @@ class RelayBuildModule : public runtime::ModuleNode {
       return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
         *rv = this->GetParams();
       });
-    } else if (name == "set_opt_level") {
-      return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
-        CHECK_EQ(args.num_args, 1);
-        int level = args[0];
-        this->SetOptLevel(level);
-      });
-    } else if (name == "set_fallback_device") {
-      return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
-        CHECK_EQ(args.num_args, 1);
-        int dev = args[0];
-        this->SetFallBackDev(dev);
-      });
-    } else if (name == "add_pass") {
-      return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
-        std::string pass_name = args[0];
-        this->AddPass(pass_name);
-      });
-    } else if (name == "disable_pass") {
-      return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
-        std::string pass_name = args[0];
-        this->DisablePass(pass_name);
-      });
     } else if (name == "set_params") {
       return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
         Map<std::string, Constant> params = args[0];
@@ -246,30 +155,7 @@ class RelayBuildModule : public runtime::ModuleNode {
   const std::string& GetGraphJSON() {
     return ret_.graph_json;
   }
-  /*!
-   * \brief Add extra pass into build cfg
-   *
-   * \param pass_name name of pass
-   */
-  void AddPass(const std::string& pass_name) {
-    cfg_.enabled_pass.insert(pass_name);
-  }
-  /*!
-   * \brief Disable a specific pass in cfg
-   *
-   * \param pass_name name of pass
-   */
-  void DisablePass(const std::string& pass_name) {
-    cfg_.disabled_pass.insert(pass_name);
-  }
-  /*!
-   * \brief Set the Fallback device
-   *
-   * \param device name
-   */
-  void SetFallBackDev(int dev) {
-    cfg_.fallback_device = dev;
-  }
+
   /*!
    * \brief Get the Module object
    *
@@ -316,15 +202,6 @@ class RelayBuildModule : public runtime::ModuleNode {
   }
 
   /*!
-   * \brief Set the optimization level
-   *
-   * \param level
-   */
-  void SetOptLevel(char level) {
-    cfg_.opt_level = level;
-  }
-
-  /*!
    * \brief type key
    *
    * \return const char*
@@ -345,7 +222,7 @@ class RelayBuildModule : public runtime::ModuleNode {
              const tvm::Target& target_host) {
     targets_ = targets;
     target_host_ = target_host;
-    BuildRelay(func, cfg_, params_);
+    BuildRelay(func, params_);
   }
 
  protected:
@@ -378,85 +255,81 @@ class RelayBuildModule : public runtime::ModuleNode {
       if (repeat_var.count(arg)) {
         LOG(FATAL) << "Multiple args in the function have name " << kv.first;
       }
-      auto e = CallPackedFunc<Expr>("relay._make.Constant", kv.second);
-      bind_dict[arg] = e;
+      bind_dict[arg] = ConstantNode::make(kv.second);
     }
-    return CallPackedFunc("relay._expr.Bind", func, tvm::Map<relay::Var, Expr>(bind_dict));
+    Expr bound_expr = relay::Bind(func, bind_dict);
+    Function ret = Downcast<Function>(bound_expr);
+    CHECK(ret.defined())
+        << "The returning type is expected to be a Relay Function."
+        << "\n";
+    return ret;
   }
 
   /*!
-   * \brief Optimize Relay function
+   * \brief Optimize a Relay module.
    *
-   * \param func Input function
-   * \param target target device
-   * \param cfg Relay build config
-   * \param params params dict
-   * \return relay::Function
+   * \param relay_module The input Relay module where optmization will be
+   *        applied on.
+   * \param targets The device type to `Target` mapping.
+   * \param params The param name to value mapping.
+   *
+   * \return relay::Module The updated Relay module after optimization.
    */
-  relay::Function Optimize(relay::Function func,
-                           const TargetsMap& targets,
-                           const RelayBuildConfig& cfg,
-                           const std::unordered_map<std::string, runtime::NDArray>& params) {
-    if (params.size()) {
-      func = BindParamsByName(func, params);
-    }
-    if (cfg.pass_enabled("SimplifyInference")) {
-      func = CallPackedFunc("relay._ir_pass.infer_type", func, nullptr);
-      func = CallPackedFunc("relay._ir_pass.simplify_inference", func);
-    }
-    if (cfg.pass_enabled("EliminateCommonSubexpr")) {
-      auto fskip = PackedFunc([](TVMArgs args, TVMRetValue* rv) {
-        Expr expr = args[0];
-        if (expr.as<CallNode>()) {
-          auto call_node = expr.as<CallNode>();
-          auto op_node = call_node->op.as<OpNode>();
-          if (op_node->name == "cast") {
-            auto attrs = call_node->attrs.as<CastAttrs>();
-            if (attrs->dtype == HalideIR::Int(32)) {
-              *rv = true;
-            }
+  relay::Module Optimize(
+      relay::Module relay_module,
+      const TargetsMap& targets,
+      const std::unordered_map<std::string, runtime::NDArray>& params) {
+    Array<Pass> pass_seqs;
+    pass_seqs.push_back(transform::SimplifyInference());
+    PackedFunc fskip = PackedFunc([](TVMArgs args, TVMRetValue* rv) {
+      Expr expr = args[0];
+      if (expr.as<CallNode>()) {
+        auto call_node = expr.as<CallNode>();
+        auto op_node = call_node->op.as<OpNode>();
+        if (op_node->name == "cast") {
+          auto attrs = call_node->attrs.as<CastAttrs>();
+          if (attrs->dtype == HalideIR::Int(32)) {
+            *rv = true;
           }
         }
-        *rv =  false;
-      });
-      func = CallPackedFunc("relay._ir_pass.infer_type", func, nullptr);
-      func = CallPackedFunc("relay._ir_pass.eliminate_common_subexpr", func, fskip);
-    }
-    if (cfg.pass_enabled("CombineParallelConv2D")) {
-      const int min_num_branches = 3;
-      func = CallPackedFunc("relay._ir_pass.infer_type", func, nullptr);
-      func = CallPackedFunc("relay._ir_pass.CombineParallelConv2D", func, min_num_branches);
-    }
-    if (cfg.pass_enabled("FoldConstant")) {
-      func = CallPackedFunc("relay._ir_pass.FoldConstant", func);
-    }
-    if (cfg.pass_enabled("FoldScaleAxis")) {
-      func = CallPackedFunc("relay._ir_pass.infer_type", func, nullptr);
-      func = CallPackedFunc("relay._ir_pass.backward_fold_scale_axis", func);
-      func = CallPackedFunc("relay._ir_pass.infer_type", func, nullptr);
-      func = CallPackedFunc("relay._ir_pass.forward_fold_scale_axis", func);
-      func = CallPackedFunc("relay._ir_pass.FoldConstant", func);
-    }
-    if (cfg.pass_enabled("CanonicalizeOps")) {
-      func = CallPackedFunc("relay._ir_pass.infer_type", func, nullptr);
-      func = CallPackedFunc("relay._ir_pass.canonicalize_ops", func);
+      }
+      *rv = false;
+    });
+    pass_seqs.push_back(transform::EliminateCommonSubexpr(fskip));
+    pass_seqs.push_back(transform::CombineParallelConv2D(3));
+    pass_seqs.push_back(transform::FoldConstant());
+    pass_seqs.push_back(transform::FoldScaleAxis());
+    pass_seqs.push_back(transform::CanonicalizeOps());
+
+    // Alter layout transformation is only applied to homogeneous execution yet.
+    if (targets.size() == 1) {
+      pass_seqs.push_back(transform::AlterOpLayout());
     }
-    if (cfg.pass_enabled("AlterOpLayout")) {
-      if (targets.size() == 1) {
-        func = CallPackedFunc("relay._ir_pass.infer_type", func, nullptr);
-        for (const auto& kv : targets) {
-          With<Target> tctx(kv.second);
-          func = CallPackedFunc("relay._ir_pass.AlterOpLayout", func);
-        }
-      } else {
-        LOG(WARNING) << "AlterOpLayout pass is not enabled for heterogeneous"
-                  << " execution yet.";
+    pass_seqs.push_back(transform::FoldConstant());
+
+    // Create a sequential pass and perform optimizations.
+    transform::Pass seq = transform::Sequential(pass_seqs);
+    if (targets.size() == 1) {
+      for (const auto& kv : targets) {
+        With<Target> tctx(kv.second);
+        relay_module = seq(relay_module);
       }
+    } else {
+      relay_module = seq(relay_module);
     }
-    if (cfg.pass_enabled("FoldConstant")) {
-      func = CallPackedFunc("relay._ir_pass.FoldConstant", func);
+
+    // Handle heterogeneous compilation.
+    transform::PassContext pass_ctx = PassContext::Current();
+    if (targets_.size() > 1) {
+      relay_module =
+          RunDeviceAnnotationPass(relay_module, pass_ctx->fallback_device);
     }
-    return func;
+
+    // Fuse the operations if it is needed.
+    relay_module = transform::FuseOps()(relay_module);
+    relay_module = transform::InferType()(relay_module);
+
+    return relay_module;
   }
 
   /*!
@@ -470,54 +343,58 @@ class RelayBuildModule : public runtime::ModuleNode {
     if (name == "gpu") return Target::Create("cuda");
     return Target::Create(name);
   }
+
   /*!
    * \brief Update the target and fallback device required for heterogeneous
    * compilation. CPU is used as the fallback device if it wasn't provided.
    * Meanwhile, a CPU device type and "llvm" pair will be added to the target
    * dictionary in this case.
    *
-   * \param targets dictionary
-   * \param cfg
-   * \return Map<tvm::Integer, tvm::Target>
+   * \param fallback_device The fallback device for heterogeneous execution.
    */
-  TargetsMap UpdateHeterogeneousInputs(const TargetsMap& targets,
-                                       const RelayBuildConfig& cfg) {
-    TargetsMap device_target = targets;
+  void UpdateHeterogeneousInputs(int fallback_device) {
     std::unordered_map<int64_t, tvm::Target> tmp_map;
-    for (const auto& kv : targets) {
+    for (const auto& kv : targets_) {
       tmp_map[kv.first->value] = kv.second;
     }
-    if (tmp_map.count(cfg.fallback_device) == 0) {
-      device_target.Set(
-          cfg.fallback_device,
-          CreateDefaultTarget(cfg.fallback_device));
+    if (tmp_map.count(fallback_device) == 0) {
+      targets_.Set(fallback_device, CreateDefaultTarget(fallback_device));
     }
-    return device_target;
   }
+
   /*!
    * \brief Execute the device annotation passes to update the input program and
    *        target information.
    *
-   * \param func
-   * \param cfg
-   * \param targets_map_ptr
-   * \return Function
+   * \param relay_module The input Relay module.
+   * \param fallback_device The fallback device for heterogeneous execution.
+   *
+   * \return updated_module The updated module after device annotation.
    */
-  Function RunDeviceAnnotationPass(Function func,
-                                   const RelayBuildConfig& cfg,
-                                   TargetsMap* targets_map_ptr) {
-    func = CallPackedFunc("relay._ir_pass.infer_type", func, nullptr);
-    func = CallPackedFunc("relay._ir_pass.RewriteDeviceAnnotation", func,
-                          cfg.fallback_device);
-    auto device_map = CallPackedFunc<Map<Expr, Integer> >(
-        "relay._ir_pass.CollectDeviceInfo", func, nullptr);
-    if (device_map.size() == 0) {
-      auto annotation_map = CallPackedFunc<Map<Expr, Integer> >(
-          "relay._ir_pass.CollectDeviceAnnotationOps", func, nullptr);
-      if (annotation_map.size() == 0) {
-        targets_map_ptr->Set(
-            0, CreateDefaultTarget(cfg.fallback_device));
+  relay::Module RunDeviceAnnotationPass(const relay::Module& relay_module,
+                                        int fallback_device) {
+    UpdateHeterogeneousInputs(fallback_device);
+    auto rewrite = transform::RewriteAnnotatedOps(fallback_device);
+    auto updated_module = rewrite(relay_module);
+    CHECK(updated_module.defined());
+
+    tvm::Map<Expr, Integer> device_map;
+    for (const auto& it : updated_module->functions) {
+      device_map = relay::CollectDeviceInfo(it.second);
+      if (!device_map.empty()) break;
+    }
+
+    if (device_map.empty()) {
+      tvm::Map<Expr, Integer> annotation_map;
+      for (const auto& it : relay_module->functions) {
+        annotation_map = relay::CollectDeviceAnnotationOps(it.second);
+        if (!annotation_map.empty()) break;
+      }
+      // None op is annotated but they are fallen back to the default device.
+      if (annotation_map.empty()) {
+        targets_.Set(0, CreateDefaultTarget(fallback_device));
       } else {
+        // All ops are annotated to the same device type.
         int64_t dev_type = -1;
         for (auto kv : annotation_map) {
           dev_type = kv.second->value;
@@ -531,47 +408,42 @@ class RelayBuildModule : public runtime::ModuleNode {
             << "found. Please check the "
             << "RewriteAnnotation pass.";
         }
-        targets_map_ptr->Set(0, CreateDefaultTarget(dev_type));
+        targets_.Set(0, CreateDefaultTarget(dev_type));
       }
     }
-    return func;
+    return updated_module;
   }
 
   /*!
    * \brief Build relay function to runtime module
    *
    * \param func Relay Function
-   * \param cfg Relay build config
    * \param params parameters
    */
-  void BuildRelay(Function func,
-                  const RelayBuildConfig& cfg,
-                  const std::unordered_map<std::string, tvm::runtime::NDArray> &params) {
-    // convert
-    tvm_cfg_ = BuildConfig::Create();
-    TargetsMap device_target;
-    if (targets_.size() > 1) {
-      device_target = UpdateHeterogeneousInputs(targets_, cfg);
-    } else {
-      device_target = targets_;
-    }
-    func = Optimize(func, targets_, cfg, params);
-    if (device_target.size() > 1) {
-      func = RunDeviceAnnotationPass(func, cfg, &device_target);
+  void BuildRelay(
+      Function func,
+      const std::unordered_map<std::string, tvm::runtime::NDArray>& params) {
+    if (params.size()) {
+      func = BindParamsByName(func, params);
     }
-    // TODO(@jroesch): use the passes directly.
-    func = CallPackedFunc("relay._ir_pass.infer_type", func, nullptr);
-    func = CallPackedFunc("relay._ir_pass.FuseOps", func, cfg.opt_level, nullptr);
-    func = CallPackedFunc("relay._ir_pass.infer_type", func, nullptr);
 
+    // Perform Module->Module optimizations.
+    relay::Module relay_module = relay::ModuleNode::FromExpr(func);
+    relay_module = Optimize(relay_module, targets_, params);
+    CHECK(relay_module.defined());
+    // Get the updated function.
+    func = relay_module->Lookup(relay_module->entry_func->name_hint);
+
+    // Generate code for the updated function.
     graph_codegen_ = std::unique_ptr<GraphCodegen>(new GraphCodegen());
-    graph_codegen_->Init(nullptr, device_target);
+    graph_codegen_->Init(nullptr, targets_);
     graph_codegen_->Codegen(func);
 
     ret_.graph_json = graph_codegen_->GetJSON();
     ret_.params = graph_codegen_->GetParams();
 
-    ret_.mod = tvm::build(graph_codegen_->GetLoweredFunc(), target_host_, tvm_cfg_);
+    ret_.mod = tvm::build(graph_codegen_->GetLoweredFunc(), target_host_,
+                          BuildConfig::Current());
   }
 
  protected:
@@ -580,14 +452,10 @@ class RelayBuildModule : public runtime::ModuleNode {
   TargetsMap targets_;
   /*! \brief target host device */
   tvm::Target target_host_;
-  /*! \brief frontend optimization configure */
-  RelayBuildConfig cfg_;
   /*! \brief parameters */
   std::unordered_map<std::string, runtime::NDArray> params_;
   /*! \brief building output */
   BuildOutput ret_;
-  /*! \brief tvm building cfg */
-  BuildConfig tvm_cfg_;
 };
 
 runtime::Module RelayBuildCreate() {
index f51c201..d623393 100644 (file)
@@ -27,6 +27,7 @@
 #include <tvm/relay/pass.h>
 #include <tvm/relay/op_attr_types.h>
 #include <tvm/relay/attrs/transform.h>
+#include <tvm/relay/transform.h>
 #include <tvm/tvm.h>
 #include <tuple>
 #include <vector>
@@ -338,17 +339,35 @@ Expr AlterOpLayoutRewrite(const Call &ref_call,
 // Limiations:
 // 1. the altered op should have the same number of arguments as the previous one
 // 2. do not support nested tuple arguments
-TVM_REGISTER_API("relay._ir_pass.AlterOpLayout")
-.set_body([](TVMArgs args, TVMRetValue *ret) {
+Expr AlterOpLayout(const Expr& expr) {
   TransformMemorizer transformMemorizer(make_node<TransformMemorizerNode>());
   auto fcontext = [&](const Call& call) -> NodeRef{
     return transformMemorizer;
   };
 
-  *ret = ForwardRewrite(args[0], AlterOpLayoutRewrite, fcontext);
-});
+  return ForwardRewrite(expr, AlterOpLayoutRewrite, fcontext);
+}
+
+TVM_REGISTER_API("relay._ir_pass.AlterOpLayout")
+.set_body_typed(AlterOpLayout);
 
 }  // namespace alter_op_layout
 
+namespace transform {
+
+Pass AlterOpLayout() {
+  runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
+    [=](Function f, Module m, PassContext pc) {
+      return Downcast<Function>(relay::alter_op_layout::AlterOpLayout(f));
+  };
+  return CreateFunctionPass(pass_func, 3, "AlterOpLayout",
+                            {ir::StringImm::make("InferType")});
+}
+
+TVM_REGISTER_API("relay._transform.AlterOpLayout")
+.set_body_typed(AlterOpLayout);
+
+}  // namespace transform
+
 }  // namespace relay
 }  // namespace tvm
index 9a46027..ff9e230 100644 (file)
@@ -26,6 +26,7 @@
 #include <tvm/relay/pass.h>
 #include <tvm/relay/expr_functor.h>
 #include <tvm/relay/attrs/nn.h>
+#include <tvm/relay/transform.h>
 #include "pattern_util.h"
 
 namespace tvm {
@@ -63,5 +64,21 @@ Expr CanonicalizeOps(const Expr& e) {
 TVM_REGISTER_API("relay._ir_pass.canonicalize_ops")
 .set_body_typed(CanonicalizeOps);
 
+namespace transform {
+
+Pass CanonicalizeOps() {
+  runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
+    [=](Function f, Module m, PassContext pc) {
+    return Downcast<Function>(CanonicalizeOps(f));
+  };
+  return CreateFunctionPass(pass_func, 3, "CanonicalizeOps",
+                            {ir::StringImm::make("InferType")});
+}
+
+TVM_REGISTER_API("relay._transform.CanonicalizeOps")
+.set_body_typed(CanonicalizeOps);
+
+}  // namespace transform
+
 }  // namespace relay
 }  // namespace tvm
index 7e76322..c95c1dd 100644 (file)
@@ -38,6 +38,7 @@
 #include <tvm/relay/attrs/nn.h>
 #include <tvm/relay/attrs/transform.h>
 #include <tvm/relay/op_attr_types.h>
+#include <tvm/relay/transform.h>
 #include <unordered_map>
 #include <unordered_set>
 #include "./expr_subst.h"
@@ -357,5 +358,21 @@ Expr CombineParallelConv2D(const Expr& expr, uint64_t min_num_branches) {
 TVM_REGISTER_API("relay._ir_pass.CombineParallelConv2D")
 .set_body_typed(CombineParallelConv2D);
 
+namespace transform {
+
+Pass CombineParallelConv2D(uint64_t min_num_branches) {
+  runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
+    [=](Function f, Module m, PassContext pc) {
+      return Downcast<Function>(CombineParallelConv2D(f, min_num_branches));
+  };
+  return CreateFunctionPass(pass_func, 4, "CombineParallelConv2d",
+                            {ir::StringImm::make("InferType")});
+}
+
+TVM_REGISTER_API("relay._transform.CombineParallelConv2D")
+.set_body_typed(CombineParallelConv2D);
+
+}  // namespace transform
+
 }  // namespace relay
 }  // namespace tvm
index dd1ed62..be67745 100644 (file)
@@ -158,9 +158,12 @@ Pass DeadCodeElimination() {
     [=](Function f, Module m, PassContext pc) {
     return Downcast<Function>(DeadCodeElimination(f));
   };
-  return CreateFunctionPass(pass_func, 1, "dead_code_elimination", {});
+  return CreateFunctionPass(pass_func, 1, "DeadCodeElimination", {});
 }
 
+TVM_REGISTER_API("relay._transform.DeadCodeElimination")
+.set_body_typed(DeadCodeElimination);
+
 }  // namespace transform
 
 }  // namespace relay
index e2d0761..02d6d9e 100644 (file)
@@ -35,6 +35,7 @@
 #include <tvm/relay/expr.h>
 #include <tvm/relay/expr_functor.h>
 #include <tvm/relay/pass.h>
+#include <tvm/relay/transform.h>
 
 #include <memory>
 #include <unordered_map>
@@ -564,11 +565,14 @@ Pass RewriteAnnotatedOps(int fallback_device) {
     [=](Function f, Module m, PassContext pc) {
     return Downcast<Function>(RewriteAnnotatedOps(f, fallback_device));
   };
-  return CreateFunctionPass(pass_func, 1, "rewrite_annotated_ops", {});
+  return CreateFunctionPass(pass_func, 1, "RewriteAnnotatedOps",
+                            {ir::StringImm::make("InferType")});
 }
 
+TVM_REGISTER_API("relay._transform.RewriteDeviceAnnotation")
+.set_body_typed(RewriteAnnotatedOps);
+
 }  // namespace transform
 
 }  // namespace relay
 }  // namespace tvm
-
index f8432f6..883681a 100644 (file)
@@ -29,6 +29,7 @@
  */
 #include <tvm/relay/pass.h>
 #include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
 #include <unordered_map>
 #include "./pattern_util.h"
 
@@ -87,5 +88,21 @@ Expr EliminateCommonSubexpr(const Expr& expr, PackedFunc callback) {
 TVM_REGISTER_API("relay._ir_pass.eliminate_common_subexpr")
 .set_body_typed<Expr(Expr, PackedFunc)>(EliminateCommonSubexpr);
 
+namespace transform {
+
+Pass EliminateCommonSubexpr(PackedFunc fskip) {
+  runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
+    [=](Function f, Module m, PassContext pc) {
+      return Downcast<Function>(EliminateCommonSubexpr(f, fskip));
+  };
+  return CreateFunctionPass(pass_func, 3, "EliminateCommonSubexpr",
+                            {ir::StringImm::make("InferType")});
+}
+
+TVM_REGISTER_API("relay._transform.EliminateCommonSubexpr")
+.set_body_typed(EliminateCommonSubexpr);
+
+}  // namespace transform
+
 }  // namespace relay
 }  // namespace tvm
index 286392a..8154070 100644 (file)
@@ -26,6 +26,7 @@
 #include <tvm/relay/op_attr_types.h>
 #include <tvm/relay/interpreter.h>
 #include <tvm/relay/attrs/transform.h>
+#include <tvm/relay/transform.h>
 
 namespace tvm {
 namespace relay {
@@ -220,11 +221,14 @@ namespace transform {
 Pass FoldConstant() {
   runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
     [=](Function f, Module m, PassContext pc) {
-    return Downcast<Function>(FoldConstant(f));
+      return Downcast<Function>(FoldConstant(f));
   };
-  return CreateFunctionPass(pass_func, 1, "fold_constant", {});
+  return CreateFunctionPass(pass_func, 2, "FoldConstant", {});
 }
 
+TVM_REGISTER_API("relay._transform.FoldConstant")
+.set_body_typed(FoldConstant);
+
 }  // namespace transform
 
 }  // namespace relay
index c738e3e..5308980 100644 (file)
@@ -29,6 +29,7 @@
 #include <tvm/relay/pass.h>
 #include <tvm/relay/attrs/nn.h>
 #include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
 #include "pattern_util.h"
 #include "pass_util.h"
 
@@ -530,7 +531,7 @@ RELAY_REGISTER_OP("nn.conv2d")
 .set_attr<FForwardRewrite>("FScaleAxisForwardRewrite", Conv2DForwardRewrite);
 
 
-Expr ForwardFoldScaleAxis(Expr data) {
+Expr ForwardFoldScaleAxis(const Expr& data) {
   auto message = ForwardPrep().Prepare(data);
   auto fcontext = [&](const Call& call) -> NodeRef{
     auto it = message.find(call.get());
@@ -942,7 +943,7 @@ RELAY_REGISTER_OP("nn.conv2d")
 RELAY_REGISTER_OP("nn.conv2d")
 .set_attr<FBackwardTransform>("FScaleAxisBackwardTransform", Conv2DBackwardTransform);
 
-Expr BackwardFoldScaleAxis(Expr data) {
+Expr BackwardFoldScaleAxis(const Expr& data) {
   return make_node<BackwardTransformerNode>()->Fold(data);
 }
 
@@ -950,5 +951,42 @@ TVM_REGISTER_API("relay._ir_pass.backward_fold_scale_axis")
 .set_body_typed<Expr(Expr)>(BackwardFoldScaleAxis);
 
 }  // namespace fold_scale_axis
+
+namespace transform {
+
+Pass ForwardFoldScaleAxis() {
+  runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
+    [=](Function f, Module m, PassContext pc) {
+      return Downcast<Function>(
+          relay::fold_scale_axis::ForwardFoldScaleAxis(f));
+  };
+  return CreateFunctionPass(pass_func, 3, "ForwardFoldScaleAxis",
+                            {ir::StringImm::make("InferType")});
+}
+
+Pass BackwardFoldScaleAxis() {
+  runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
+    [=](Function f, Module m, PassContext pc) {
+      return Downcast<Function>(
+          relay::fold_scale_axis::BackwardFoldScaleAxis(f));
+    };
+  return CreateFunctionPass(pass_func, 3, "BackwardFoldScaleAxis",
+                            {ir::StringImm::make("InferType")});
+}
+
+Pass FoldScaleAxis() {
+  // FoldScaleAxis pass contains the following three passes. Therefore, we can
+  // register it as a sequential pass.
+  Pass pass = Sequential(
+      {BackwardFoldScaleAxis(), ForwardFoldScaleAxis(), FoldConstant()},
+      "FoldScaleAxis");
+  return pass;
+}
+
+TVM_REGISTER_API("relay._transform.FoldScaleAxis")
+.set_body_typed(FoldScaleAxis);
+
+}  // namespace transform
+
 }  // namespace relay
 }  // namespace tvm
index 2a3aa16..8ad6127 100644 (file)
@@ -220,7 +220,7 @@ Pass ForwardRewrite(const std::string& rewrite_map_attr_name,
                                              fcontext,
                                              fmulti_ref_trigger));
   };
-  return CreateFunctionPass(pass_func, 1, "forward_rewrite", {});
+  return CreateFunctionPass(pass_func, 1, "ForwardRewrite", {});
 }
 
 Pass ForwardRewrite(const FForwardRewrite& rewrite_func,
@@ -233,7 +233,7 @@ Pass ForwardRewrite(const FForwardRewrite& rewrite_func,
                                              fcontext,
                                              fmulti_ref_trigger));
   };
-  return CreateFunctionPass(pass_func, 1, "forward_rewrite", {});
+  return CreateFunctionPass(pass_func, 1, "ForwardRewriteFunc", {});
 }
 
 }  // namespace transform
index 9277689..9f940e5 100644 (file)
@@ -29,6 +29,7 @@
 #include <tvm/relay/pass.h>
 #include <tvm/relay/expr_functor.h>
 #include <tvm/relay/op_attr_types.h>
+#include <tvm/relay/transform.h>
 #include "./pattern_util.h"
 #include "../../common/arena.h"
 
@@ -973,9 +974,13 @@ Pass FuseOps(int fuse_opt_level) {
     int opt_level = fuse_opt_level == -1 ? pc->opt_level : fuse_opt_level;
     return Downcast<Function>(FuseOps(f, opt_level, m));
   };
-  return CreateFunctionPass(pass_func, 1, "fuse_ops", {});
+  return CreateFunctionPass(pass_func, 1, "FuseOps",
+                            {ir::StringImm::make("InferType")});
 }
 
+TVM_REGISTER_API("relay._transform.FuseOps")
+.set_body_typed(FuseOps);
+
 }  // namespace transform
 
 }  // namespace relay
index 3f42c6f..71ba7cd 100644 (file)
@@ -797,9 +797,7 @@ Expr PartialEval(const Expr& e) {
 }
 
 TVM_REGISTER_API("relay._ir_pass.partial_evaluate")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
-    *ret = PartialEval(args[0]);
-  });
+.set_body_typed(PartialEval);
 
 namespace transform {
 
@@ -808,9 +806,12 @@ Pass PartialEval() {
     [=](Function f, Module m, PassContext pc) {
     return Downcast<Function>(PartialEval(f));
   };
-  return CreateFunctionPass(pass_func, 1, "partial_eval", {});
+  return CreateFunctionPass(pass_func, 1, "PartialEvaluate", {});
 }
 
+TVM_REGISTER_API("relay._transform.PartialEvaluate")
+.set_body_typed(PartialEval);
+
 }  // namespace transform
 
 }  // namespace relay
index a9c671a..13e908d 100644 (file)
@@ -37,42 +37,46 @@ namespace transform {
 
 using tvm::IRPrinter;
 
-/*!
- * \brief A data structure to map the names of specific optimizations to
- *        numeric optimization levels
- */
-class OptPassLevel {
- public:
-  /*!
-   * \brief Get level for an optimization pass
-   *
-   * \param key pass name
-   * \return int level
-   */
-  int operator[](const std::string& key) const {
-    const auto data = CreateMap();
-    auto it = data.find(key);
-    if (it == data.end()) {
-      return -1;
-    }
-    return it->second;
+namespace {
+
+// TODO(zhiics) Maybe we can use PackedFunc here so that parameters can be
+// handled because we need to register the pass for Python invocation anyway.
+Pass GetPass(const std::string& pass_name) {
+  if (pass_name == "InferType") {
+    return InferType();
+  } else if (pass_name == "AlterOpLayout") {
+    return AlterOpLayout();
+  } else if (pass_name == "CanonicalizeOps") {
+    return CanonicalizeOps();
+  } else if (pass_name == "CombineParallelConv2d") {
+    return CombineParallelConv2D();
+  } else if (pass_name == "DeadCodeElimination") {
+    return DeadCodeElimination();
+  } else if (pass_name == "EliminateCommonSubexpr") {
+    return DeadCodeElimination();
+  } else if (pass_name == "FoldConstant") {
+    return FoldConstant();
+  } else if (pass_name == "BackwardFoldScaleAxis") {
+    return FoldScaleAxis();
+  } else if (pass_name == "ForwardFoldScaleAxis") {
+    return FoldScaleAxis();
+  } else if (pass_name == "FoldScaleAxis") {
+    return FoldScaleAxis();
+  } else if (pass_name == "PartialEvaluate") {
+    return SimplifyInference();
+  } else if (pass_name == "SimplifyInference") {
+    return SimplifyInference();
+  } else if (pass_name == "ToANormalForm") {
+    return ToANormalForm();
+  } else if (pass_name == "ToGraphNormalForm") {
+    return ToGraphNormalForm();
+  } else {
+    LOG(FATAL) << pass_name << " has not been registered yet." << "\n";
+    return Pass(nullptr);
   }
+}
 
- private:
-  static const std::unordered_map<std::string, int> CreateMap() {
-    const std::unordered_map<std::string, int> m = {
-      {"SimplifyInference", 0},
-      {"OpFusion", 1},
-      {"FoldConstant", 2},
-      {"CombineParallelConv2D", 3},
-      {"FoldScaleAxis", 3},
-      {"AlterOpLayout", 3},
-      {"CanonicalizeOps", 3},
-      {"EliminateCommonSubexpr", 3}
-    };
-    return m;
-  }
-};
+}  // namespace
 
 struct RelayPassContextThreadLocalEntry {
   /*! \brief The default pass context. */
@@ -246,12 +250,6 @@ class SequentialNode : public PassNode {
   /* \brief The pass meta data.*/
   PassInfo pass_info;
 
-  /*!
-   * \brief A helper struct to get the optimization pass name to opt level
-   * mapping.
-   */
-  OptPassLevel opt_pass_level;
-
   /*! \brief A list of passes that used to compose a sequential pass. */
   tvm::Array<Pass> passes;
   void VisitAttrs(tvm::AttrVisitor* v) final {
@@ -300,7 +298,7 @@ class SequentialNode : public PassNode {
       const Array<tvm::Expr>& disabled) const;
 
   std::unordered_set<std::string> RequiredPasses(
-      const Array<tvm::Expr>& disabled) const;
+      const Array<tvm::Expr>& required) const;
 
   /*!
    * \brief Perform optimizations on a series of passes. The aforementioned
@@ -338,14 +336,25 @@ ModulePass ModulePassNode::make(
 }
 
 // Module -> Module optimizations.
-// TODO(zhiics) Check and handle the required passes.
 Module ModulePassNode::operator()(const Module& mod,
                                   const PassContext& pass_ctx) const {
   PassInfo pass_info = Info();
   DLOG(INFO) << "Executing module pass : " << pass_info->name
              << " with opt level: " << pass_info->opt_level << "\n";
+
   CHECK(mod.defined());
-  auto updated_mod = pass_func(mod, pass_ctx);
+  Module updated_mod = mod;
+  // Execute the required passes in a DFS way.
+  // TODO(zhiics) We may need to pass validation to detect the cyclic
+  // dependency.
+  for (const auto& it : pass_info->required) {
+    const auto* name = it.as<tvm::ir::StringImm>();
+    CHECK(name);
+    auto pass = GetPass(name->value);
+    updated_mod = pass(updated_mod, pass_ctx);
+  }
+
+  updated_mod = pass_func(updated_mod, pass_ctx);
   CHECK(updated_mod.defined());
   return updated_mod;
 }
@@ -365,12 +374,26 @@ Module FunctionPassNode::operator()(const Module& mod,
                                     const PassContext& pass_ctx) const {
   PassInfo pass_info = Info();
   CHECK(mod.defined());
-  Module new_mod = ModuleNode::make({}, mod->type_definitions);
   DLOG(INFO) << "Executing module pass : " << pass_info->name
              << " with opt level: " << pass_info->opt_level << "\n";
+
+  Module updated_mod = mod;
+  // Execute the required passes in a DFS way.
+  // TODO(zhiics) We may need to pass validation to detect the cyclic
+  // dependency.
+  for (const auto& it : pass_info->required) {
+    const auto* name = it.as<tvm::ir::StringImm>();
+    CHECK(name);
+    auto pass = GetPass(name->value);
+    updated_mod = pass(updated_mod, pass_ctx);
+  }
+
+  Module new_mod = ModuleNode::make({}, mod->type_definitions);
   // Execute the pass function and return a new module.
   for (const auto& it : mod->functions) {
-    auto updated_func = SkipFunction(it.second) ? it.second : pass_func(it.second, mod, pass_ctx);
+    auto updated_func = SkipFunction(it.second)
+                            ? it.second
+                            : pass_func(it.second, updated_mod, pass_ctx);
     new_mod->Add(it.first, updated_func);
   }
 
@@ -418,7 +441,7 @@ std::unordered_set<std::string> SequentialNode::DisabledPasses(
   std::unordered_set<std::string> ret;
   for (const auto& it : disabled) {
     const auto* str = it.as<tvm::ir::StringImm>();
-    CHECK(str) << "disabled passes must be string.";
+    CHECK(str) << "Disabled pass name must be string.";
     ret.emplace(str->value);
   }
   return ret;
@@ -429,7 +452,7 @@ std::unordered_set<std::string> SequentialNode::RequiredPasses(
   std::unordered_set<std::string> ret;
   for (const auto& it : required) {
     const auto* str = it.as<tvm::ir::StringImm>();
-    CHECK(str) << "disabled passes must be string.";
+    CHECK(str) << "Required pass name must be string.";
     ret.emplace(str->value);
   }
   return ret;
@@ -439,7 +462,7 @@ bool SequentialNode::PassEnabled(const std::string& pass_name) const {
   PassContext ctx = PassContext::Current();
 
   auto required = RequiredPasses(ctx->required_pass);
-  auto disabled = DisabledPasses(ctx->required_pass);
+  auto disabled = DisabledPasses(ctx->disabled_pass);
 
   if (disabled.count(pass_name)) {
     return false;
@@ -448,29 +471,27 @@ bool SequentialNode::PassEnabled(const std::string& pass_name) const {
   if (required.count(pass_name)) {
     return true;
   }
-  return ctx->opt_level >= opt_pass_level[pass_name];
+
+  const Pass pass = GetPass(pass_name);
+  PassInfo info = pass->Info();
+  return ctx->opt_level >= info->opt_level;
 }
 
 // TODO(zhiics): we currenlty only sequentially execute each pass in
 // a Sequential without the consideration of their orders. The phase
-// ordering problem needed to be handled in the future.
+// ordering problem needs to be handled in the future.
 Module SequentialNode::operator()(const Module& module,
                                   const PassContext& pass_ctx) const {
-  int opt_level = pass_ctx->opt_level;
-  auto disabled = DisabledPasses(pass_ctx->disabled_pass);
   Module mod = module;
   for (const Pass& pass : passes) {
     CHECK(pass.defined()) << "Found undefined pass for optimization.";
+
     PassInfo info = pass->Info();
     const auto& pass_name = info->name;
-    const auto& pass_opt_level = info->opt_level;
-    // Skip the pass if its optimization level is higher that the  one of in the
-    // pass context or if this pass is disabled.
-    if (pass_opt_level > opt_level || disabled.count(pass_name)) {
-      continue;
+    // Execute the pass if it is enabled.
+    if (PassEnabled(pass_name)) {
+      mod = pass(mod, pass_ctx);
     }
-    const auto* pn = pass.operator->();
-    mod = (*pn)(mod, pass_ctx);
   }
   return mod;
 }
@@ -525,15 +546,17 @@ TVM_REGISTER_API("relay._transform.CreateModulePass")
 
 TVM_REGISTER_API("relay._transform.RunPass")
 .set_body([](TVMArgs args, TVMRetValue* ret) {
-  *ret = args[0].operator Pass()(args[1]);
+  Pass pass = args[0];
+  Module mod = args[1];
+  *ret = pass(mod);
 });
 
 TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
 .set_dispatch<ModulePassNode>([](const ModulePassNode* node,
                                  tvm::IRPrinter* p) {
-  const PassInfoNode* pn = node->Info().operator->();
-  p->stream << "Run Module pass: " << pn->name
-            << " at the optimization level " << pn->opt_level;
+  const PassInfo info = node->Info();
+  p->stream << "Run Module pass: " << info->name
+            << " at the optimization level " << info->opt_level;
 });
 
 TVM_REGISTER_NODE_TYPE(FunctionPassNode);
@@ -544,9 +567,9 @@ TVM_REGISTER_API("relay._transform.CreateFunctionPass")
 TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
 .set_dispatch<FunctionPassNode>([](const FunctionPassNode* node,
                                    tvm::IRPrinter* p) {
-  const PassInfoNode* pn = node->Info().operator->();
-  p->stream << "Run Function pass: " << pn->name
-            << " at the optimization level " << pn->opt_level;
+  const PassInfo info = node->Info();
+  p->stream << "Run Function pass: " << info->name
+            << " at the optimization level " << info->opt_level;
 });
 
 TVM_REGISTER_NODE_TYPE(SequentialNode);
@@ -564,14 +587,13 @@ TVM_REGISTER_API("relay._transform.Sequential")
 TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
 .set_dispatch<SequentialNode>([](const SequentialNode* node,
                                  tvm::IRPrinter* p) {
-  const PassInfoNode* seq_pn = node->Info().operator->();
-  p->stream << "Run Sequential pass: " << seq_pn->name
-            << " at the optimization level " << seq_pn->opt_level << ". ";
+  const PassInfo info = node->Info();
+  p->stream << "Run Sequential pass: " << info->name
+            << " at the optimization level " << info->opt_level << ". ";
   p->stream << "The passes will be executed are: [";
   for (const auto& it : node->passes) {
-    const PassNode* pn = it.operator->();
-    const PassInfoNode* pass_info_node = pn->Info().operator->();
-    p->stream << pass_info_node->name << " ";
+    const PassInfo pass_info = it->Info();
+    p->stream << pass_info->name << " ";
   }
   p->stream << "]";
 });
index 8dab0c3..6d6b24a 100644 (file)
@@ -24,6 +24,7 @@
 #include <tvm/relay/pass.h>
 #include <tvm/relay/expr_functor.h>
 #include <tvm/relay/attrs/nn.h>
+#include <tvm/relay/transform.h>
 #include "./pattern_util.h"
 
 namespace tvm {
@@ -105,5 +106,21 @@ Expr SimplifyInference(const Expr& e) {
 TVM_REGISTER_API("relay._ir_pass.simplify_inference")
 .set_body_typed(SimplifyInference);
 
+namespace transform {
+
+Pass SimplifyInference() {
+  runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
+    [=](Function f, Module m, PassContext pc) {
+    return Downcast<Function>(SimplifyInference(f));
+  };
+  return CreateFunctionPass(pass_func, 0, "SimplifyInference",
+                            {ir::StringImm::make("InferType")});
+}
+
+TVM_REGISTER_API("relay._transform.SimplifyInference")
+.set_body_typed(SimplifyInference);
+
+}  // namespace transform
+
 }  // namespace relay
 }  // namespace tvm
index f9d47f7..324eddd 100644 (file)
@@ -340,9 +340,12 @@ Pass ToANormalForm() {
     [=](Function f, Module m, PassContext pc) {
     return Downcast<Function>(ToANormalForm(f, m));
   };
-  return CreateFunctionPass(pass_func, 1, "to_a_normal_form", {});
+  return CreateFunctionPass(pass_func, 1, "ToANormalForm", {});
 }
 
+TVM_REGISTER_API("relay._transform.ToANormalForm")
+.set_body_typed(ToANormalForm);
+
 }  // namespace transform
 
 }  // namespace relay
index 50ebb70..9c166f9 100644 (file)
@@ -86,9 +86,12 @@ Pass ToGraphNormalForm() {
     [=](Function f, Module m, PassContext pc) {
     return Downcast<Function>(ToGraphNormalForm(f));
   };
-  return CreateFunctionPass(pass_func, 1, "to_graph_normal_form", {});
+  return CreateFunctionPass(pass_func, 1, "ToGraphNormalForm", {});
 }
 
+TVM_REGISTER_API("relay._transform.ToGraphNormalForm")
+.set_body_typed(ToGraphNormalForm);
+
 }  // namespace transform
 
 }  // namespace relay
index 482cef3..3fde3c7 100644 (file)
@@ -43,6 +43,7 @@
 #include <tvm/relay/expr_functor.h>
 #include <tvm/relay/pattern_functor.h>
 #include <tvm/relay/pass.h>
+#include <tvm/relay/transform.h>
 #include "./pass_util.h"
 #include "type_solver.h"
 #include "../ir/type_functor.h"
@@ -807,5 +808,23 @@ TVM_REGISTER_API("relay._ir_pass.infer_type")
 .set_body_typed<Expr(const Expr&, const Module&)>([](const Expr& expr, const Module& mod_ref) {
     return InferType(expr, mod_ref);
   });
+
+namespace transform {
+
+Pass InferType() {
+  runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
+    [=](Function f, Module m, PassContext pc) {
+      return Downcast<Function>(InferType(f, m));
+  };
+  return CreateFunctionPass(pass_func, 0, "InferType", {});
+}
+
+TVM_REGISTER_API("relay._transform.InferType")
+.set_body_typed<Pass()>([]() {
+  return InferType();
+});
+
+}  // namespace transform
+
 }  // namespace relay
 }  // namespace tvm
diff --git a/tests/cpp/relay_transform_sequential.cc b/tests/cpp/relay_transform_sequential.cc
new file mode 100644 (file)
index 0000000..b61a5cc
--- /dev/null
@@ -0,0 +1,111 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include <gtest/gtest.h>
+#include <topi/generic/injective.h>
+#include <tvm/build_module.h>
+#include <tvm/packed_func_ext.h>
+#include <tvm/relay/expr.h>
+#include <tvm/relay/module.h>
+#include <tvm/relay/pass.h>
+#include <tvm/relay/transform.h>
+#include <tvm/relay/type.h>
+#include <tvm/runtime/packed_func.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/tvm.h>
+
+TVM_REGISTER_GLOBAL("schedule")
+    .set_body([](tvm::TVMArgs args, tvm::TVMRetValue* rv) {
+      *rv = topi::generic::schedule_injective(args[0], args[1]);
+    });
+
+TEST(Relay, Sequential) {
+  using namespace tvm;
+  auto tensor_type = relay::TensorTypeNode::make({1, 2, 3}, ::tvm::Float(32));
+  auto c_data =
+      tvm::runtime::NDArray::Empty({1, 2, 3}, {kDLFloat, 32, 1}, {kDLCPU, 0});
+
+  // Create a function for optimization.
+  auto c = relay::ConstantNode::make(c_data);
+  auto a = relay::VarNode::make("a", tensor_type);
+  auto x = relay::VarNode::make("x", tensor_type);
+  auto add_op = relay::Op::Get("add");
+  auto y = relay::CallNode::make(add_op, {c, c});
+  y = relay::CallNode::make(add_op, {x, y});
+  auto z = relay::CallNode::make(add_op, {y, c});
+  auto z1 = relay::CallNode::make(add_op, {y, c});
+  auto z2 = relay::CallNode::make(add_op, {z, z1});
+  // Let expression and varaible a should be dead-code eliminated.
+  auto z3 = relay::LetNode::make(a, c, z2);
+  relay::Function func =
+      relay::FunctionNode::make(relay::FreeVars(z3), z3, relay::Type(), {});
+
+  // Get schedule
+  auto reg = tvm::runtime::Registry::Get("relay.op._Register");
+  auto sch = tvm::runtime::Registry::Get("schedule");
+  if (!reg || !sch) {
+    LOG(FATAL) << "Register/schedule is not defined.";
+  }
+
+  (*reg)("add", "FTVMSchedule", *sch, 10);
+
+  // Run sequential passes.
+  tvm::Array<relay::transform::Pass> pass_seqs{
+      relay::transform::InferType(),
+      relay::transform::DeadCodeElimination(),
+      relay::transform::EliminateCommonSubexpr(),
+      relay::transform::AlterOpLayout()
+  };
+  relay::transform::Pass seq = relay::transform::Sequential(pass_seqs);
+  auto mod = relay::ModuleNode::FromExpr(func);
+  auto pass_ctx = relay::transform::PassContext::Create();
+  pass_ctx->opt_level = 3;
+  pass_ctx->fallback_device = 1;
+  {
+    tvm::With<relay::transform::PassContext> ctx_scope(pass_ctx);
+    tvm::With<tvm::Target> tctx(tvm::Target::Create("llvm"));
+    mod = seq(mod);
+  }
+
+  CHECK(mod.defined());
+  auto entry_func = mod->entry_func;
+  CHECK(entry_func.defined());
+  relay::Function f = mod->Lookup(entry_func->name_hint);
+  CHECK(f.defined());
+
+  // Expected function
+  auto c1 = relay::ConstantNode::make(c_data);
+  auto x1 = relay::VarNode::make("x", tensor_type);
+  auto y1 = relay::CallNode::make(add_op, {c1, c1});
+  y1 = relay::CallNode::make(add_op, {x1, y1});
+  auto zz = relay::CallNode::make(add_op, {y1, c1});
+  zz = relay::CallNode::make(add_op, {zz, zz});
+  relay::Function expected_func =
+      relay::FunctionNode::make(relay::FreeVars(zz), zz, relay::Type(), {});
+
+  // Infer type for the expected function.
+  auto expected = relay::InferType(expected_func, relay::Module(nullptr));
+  CHECK(relay::AlphaEqual(f, expected));
+}
+
+int main(int argc, char** argv) {
+  testing::InitGoogleTest(&argc, argv);
+  testing::FLAGS_gtest_death_test_style = "threadsafe";
+  return RUN_ALL_TESTS();
+}
index 2703e5c..7fdef3f 100644 (file)
@@ -327,7 +327,8 @@ def test_sequential_pass():
     def test_only_module_pass():
         passes = [module_pass]
         sequential = _transform.Sequential(opt_level=1, passes=passes)
-        ret_mod = sequential(mod)
+        with relay.build_config(required_pass=["mod_transform"]):
+            ret_mod = sequential(mod)
         # Check the subtract function.
         sub_var, new_sub = extract_var_func(ret_mod, v_sub.name_hint)
         check_func(new_sub, sub)
@@ -341,7 +342,8 @@ def test_sequential_pass():
         # Check the subtract function.
         passes = [function_pass]
         sequential = _transform.Sequential(opt_level=1, passes=passes)
-        ret_mod = sequential(mod)
+        with relay.build_config(required_pass=["func_transform"]):
+            ret_mod = sequential(mod)
         _, new_sub = extract_var_func(ret_mod, v_sub.name_hint)
         check_func(new_sub, get_ref_sub())
 
@@ -355,7 +357,9 @@ def test_sequential_pass():
         mod = relay.Module({v_sub: sub, v_log: log})
         passes = [module_pass, function_pass]
         sequential = _transform.Sequential(opt_level=1, passes=passes)
-        ret_mod = sequential(mod)
+        required = ["mod_transform", "func_transform"]
+        with relay.build_config(required_pass=required):
+            ret_mod = sequential(mod)
 
         # Check the abs function is added.
         abs_var, abs_func = get_var_func()
@@ -400,7 +404,48 @@ def test_sequential_pass():
     test_multiple_passes()
 
 
+def test_sequential_with_scoping():
+    shape = (1, 2, 3)
+    c_data = np.array(shape).astype("float32")
+    tp = relay.TensorType(shape, "float32")
+    def before():
+        c = relay.const(c_data)
+        x = relay.var("x", tp)
+        y = relay.add(c, c)
+        y = relay.multiply(y, relay.const(2, "float32"))
+        y = relay.add(x, y)
+        z = relay.add(y, c)
+        z1 = relay.add(y, c)
+        z2 = relay.add(z, z1)
+        return relay.Function([x], z2)
+
+    def expected():
+        x = relay.var("x", tp)
+        c_folded = (c_data + c_data) * 2
+        y = relay.add(x, relay.const(c_folded))
+        z = relay.add(y, relay.const(c_data))
+        z1 = relay.add(z, z)
+        return relay.Function([x], z1)
+
+    seq = _transform.Sequential([
+        relay.transform.InferType(),
+        relay.transform.FoldConstant(),
+        relay.transform.EliminateCommonSubexpr(),
+        relay.transform.AlterOpLayout()
+    ])
+
+    mod = relay.Module({"main": before()})
+    with relay.build_config(opt_level=3):
+        with tvm.target.create("llvm"):
+            mod = seq(mod)
+
+    zz = mod["main"]
+    zexpected = ir_pass.infer_type(expected())
+    assert relay.ir_pass.alpha_equal(zz, zexpected)
+
+
 if __name__ == "__main__":
     test_module_pass()
     test_function_pass()
     test_sequential_pass()
+    test_sequential_with_scoping()