[relay][pass manager] Open transform namespace (#3226)
authorZhi <5145158+zhiics@users.noreply.github.com>
Wed, 22 May 2019 20:52:52 +0000 (13:52 -0700)
committerTianqi Chen <tqchen@users.noreply.github.com>
Wed, 22 May 2019 20:52:52 +0000 (13:52 -0700)
include/tvm/relay/pass.h
include/tvm/relay/transform.h [new file with mode: 0644]
python/tvm/relay/__init__.py
python/tvm/relay/_ir_pass.pyi
python/tvm/relay/_transform.py [new file with mode: 0644]
python/tvm/relay/ir_pass.py
python/tvm/relay/transform.py [new file with mode: 0644]
python/tvm/relay/transform.pyi [new file with mode: 0644]
src/relay/pass/pass_manager.cc
tests/python/relay/test_pass_manager.py

index 3106792..c84e3f9 100644 (file)
 /*!
  * \file tvm/relay/pass.h
  * \brief The set of Relay passes written in C++.
- *
- * This file also implements a pass manager. The pass manager manages a sequence
- * of Relay-to-Relay transformation passes over a particlar unit of AST. The
- * design is largely inspired from LLVM's pass manager and modern deep learning
- * frameworks that perform tensor->tensor transformations.
- *
- * The responsibilities of a traditional compiler pass manager usually involves:
- *  - Organizing the execution order of optimization passes though not
- * necessarily in the optimal sequence.
- *  - Collecting required analysis information and keep them up-to-date.
- *  - Reducing the effort required to implement new passes for compiler
- * developers, etc.
- *
- * Similar to LLVM's pass manager, we designed the Relay pass manager to work
- * different granularity, i.e. module level, function level, and even sequential
- * passe that contains a host of passes.
- *
- * However, we also extend the functionality of the traditional pass manager
- * with the consideration of requirements/convention from deep learning
- * frameworks, such as Pytorch and Gluon, etc. Each pass in the Relay pass
- * manager performs the Relay.Module -> Relay.Module transformation. All
- * different types of passes, including the sequential-level pass object, are
- * essentially pass objects. This design, therefore, effectively provides users
- * a consistent and convenient interface, i.e. Pass, to play with. It offers a
- * means to ease the development and testing of Relay passes. For example, with
- * the pass manager, external users will be able to have custom passes correctly
- * scheduled without having to modify a single handcrafted pass order.
- *
- * In the future we need to describe constraints between passes. For example,
- * we may want to preserve dependencies between different passes and validate
- * them on the completion of a certain pass.
- *
- * We also need to store side information and import the error reporting system.
- */
+  */
 #ifndef TVM_RELAY_PASS_H_
 #define TVM_RELAY_PASS_H_
 
 #include <tvm/ir.h>
 #include <tvm/packed_func_ext.h>
-#include <tvm/relay/error.h>
 #include <tvm/relay/expr.h>
 #include <tvm/relay/module.h>
 #include <tvm/relay/op_attr_types.h>
 namespace tvm {
 namespace relay {
 
-namespace pass {
-
-/*
- * \brief The context of pass.
- */
-class PassContext;
-
-/*!
- * \brief PassContextNode contains the information that a pass can rely on, such as
- * analysis results.
- */
-class PassContextNode : public RelayNode {
- public:
-  /*!
-   * \brief The error reporter used to notify users why an optimization fails.
-   */
-  ErrorReporter err_reporter;
-
-  PassContextNode() = default;
-
-  void VisitAttrs(tvm::AttrVisitor* v) final {
-  }
-
-  TVM_DLL static PassContext make();
-
-  static constexpr const char* _type_key = "relay.PassContext";
-  TVM_DECLARE_NODE_TYPE_INFO(PassContextNode, RelayNode);
-};
-
-TVM_DEFINE_NODE_REF(PassContext, PassContextNode)
-
-/*
- * \brief The meta data of a pass.
- *
- * PassInfo can be extended conveniently in the future if more meta information
- * is needed.
- */
-class PassInfo;
-
-/*!
- * \brief PassInfoNode contains meta data that will be used to help optimization
- * and analysis.
- */
-class PassInfoNode : public RelayNode {
- public:
-  /*! \brief The minimal optimization level that this pass will be enabled. */
-  int opt_level;
-
-  /*! \brief The name of an optimization/analysis pass. */
-  std::string name;
-
-  /*! \brief The passes that are required to perform the current pass. */
-  tvm::Array<tvm::Expr> required;
-
-  PassInfoNode() = default;
-
-  void VisitAttrs(tvm::AttrVisitor* v) final {
-    v->Visit("opt_level", &opt_level);
-    v->Visit("name", &name);
-    v->Visit("required", &required);
-  }
-
-  TVM_DLL static PassInfo make(int opt_level, std::string name,
-                               tvm::Array<tvm::Expr> required);
-
-  static constexpr const char* _type_key = "relay.PassInfo";
-  TVM_DECLARE_NODE_TYPE_INFO(PassInfoNode, RelayNode);
-};
-
-TVM_DEFINE_NODE_REF(PassInfo, PassInfoNode)
-
-class Pass;
-
-/*!
- * \brief PassNode is the base type of differnt types of optimization passes.
- * It is designed as a pure class and implemented by different pass subclasses
- * at different granularity of Relay nodes.
- */
-class PassNode : public RelayNode {
- public:
-  /*
-   * \brief Get the pass information/meta data. */
-  virtual PassInfo Info() const = 0;
-
-  /*!
-   * \brief Set the context information for a pass.
-   *
-   * \param pass_ctx The context information for a certain pass.
-   */
-  virtual void SetContext(const PassContext& pass_ctx) = 0;
-
-  /*!
-   * \brief Execute the optimization pass using a functor.
-   *
-   * \param mod The module that an optimization pass runs on.
-   *
-   * \return The updated module.
-   */
-  virtual Module operator()(const Module& mod) const = 0;
-
-  void VisitAttrs(tvm::AttrVisitor* v) override {}
-
-  static constexpr const char* _type_key = "relay.Pass";
-  TVM_DECLARE_BASE_NODE_INFO(PassNode, RelayNode);
-};
-
-class Pass : public NodeRef {
- public:
-  Pass() = default;
-  explicit Pass(NodePtr<tvm::Node> p) : NodeRef(p) {}
-
-  PassNode* operator->() const {
-    return static_cast<PassNode*>(this->node_.get());
-  }
-
-  using ContainerType = PassNode;
-};
-
-/*
- * \brief Create a module pass.
- *
- * \param pass_func The packed function that contains the optimization.
- * \param opt_level The optimization level of the module pass.
- * \param name The name of the module pass.
- * \param required The list of the passes that the module pass is dependent on.
- *
- * \return The created module pass.
- */
-Pass CreateModulePass(
-    const runtime::TypedPackedFunc<Module(Module, PassContext)>& pass_func,
-    int opt_level,
-    const std::string& name,
-    const tvm::Array<tvm::Expr>& required);
-
-/*
- * \brief Create a function pass.
- *
- * \param pass_func The packed function that contains the optimization.
- * \param opt_level The optimization level of the function pass.
- * \param name The name of the function pass.
- * \param required The list of the passes that the function pass is dependent on.
- *
- * \return The created function pass.
- */
-Pass CreateFunctionPass(
-    const runtime::TypedPackedFunc<Function(Function, PassContext)>& pass_func,
-    int opt_level,
-    const std::string& name,
-    const tvm::Array<tvm::Expr>& required);
-/*
- * \brief Create a sequential pass.
- *
- * \param passes The optimization passes will be performed.
- * \param opt_level The optimization level of the sequential pass.
- * \param name The name of the sequential pass.
- * \param required The list of the passes that the sequential pass is dependent on.
- * \param disabled The disabled passes.
- *
- * \return The created sequential pass.
- */
-Pass CreateSequentialPass(const tvm::Array<Pass>& passes,
-                          int opt_level,
-                          const std::string& name,
-                          const tvm::Array<tvm::Expr>& required,
-                          const tvm::Array<tvm::Expr>& disabled);
-
-}  // namespace pass
-
 /*!
  * \brief Infer the type of an expression.
  *
diff --git a/include/tvm/relay/transform.h b/include/tvm/relay/transform.h
new file mode 100644 (file)
index 0000000..ba25483
--- /dev/null
@@ -0,0 +1,243 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/relay/transform.h
+ *
+ * This file implements a pass manager. The pass manager manages a sequence
+ * of Relay-to-Relay transformation passes over a particlar unit of AST. The
+ * design is largely inspired from LLVM's pass manager and modern deep learning
+ * frameworks that perform tensor->tensor transformations.
+ *
+ * The responsibilities of a traditional compiler pass manager usually involves:
+ *  - Organizing the execution order of optimization passes though not
+ * necessarily in the optimal sequence.
+ *  - Collecting required analysis information and keep them up-to-date.
+ *  - Reducing the effort required to implement new passes for compiler
+ * developers, etc.
+ *
+ * Similar to LLVM's pass manager, we designed the Relay pass manager to work
+ * different granularity, i.e. module level, function level, and even sequential
+ * passe that contains a host of passes.
+ *
+ * However, we also extend the functionality of the traditional pass manager
+ * with the consideration of requirements/convention from deep learning
+ * frameworks, such as Pytorch and Gluon, etc. Each pass in the Relay pass
+ * manager performs the Relay.Module -> Relay.Module transformation. All
+ * different types of passes, including the sequential-level pass object, are
+ * essentially pass objects. This design, therefore, effectively provides users
+ * a consistent and convenient interface, i.e. Pass, to play with. It offers a
+ * means to ease the development and testing of Relay passes. For example, with
+ * the pass manager, external users will be able to have custom passes correctly
+ * scheduled without having to modify a single handcrafted pass order.
+ *
+ * In the future we need to describe constraints between passes. For example,
+ * we may want to preserve dependencies between different passes and validate
+ * them on the completion of a certain pass.
+ *
+ * We also need to store side information and import the error reporting system.
+ */
+#ifndef TVM_RELAY_TRANSFORM_H_
+#define TVM_RELAY_TRANSFORM_H_
+
+#include <tvm/packed_func_ext.h>
+#include <tvm/relay/error.h>
+#include <tvm/relay/expr.h>
+#include <tvm/relay/module.h>
+#include <string>
+#include <vector>
+
+namespace tvm {
+namespace relay {
+namespace transform {
+
+/*
+ * \brief The context of pass.
+ */
+class PassContext;
+
+/*!
+ * \brief PassContextNode contains the information that a pass can rely on, such as
+ * analysis results.
+ */
+class PassContextNode : public RelayNode {
+ public:
+  /*!
+   * \brief The error reporter used to notify users why an optimization fails.
+   */
+  ErrorReporter err_reporter;
+
+  PassContextNode() = default;
+
+  void VisitAttrs(tvm::AttrVisitor* v) final {
+  }
+
+  TVM_DLL static PassContext make();
+
+  static constexpr const char* _type_key = "relay.PassContext";
+  TVM_DECLARE_NODE_TYPE_INFO(PassContextNode, RelayNode);
+};
+
+TVM_DEFINE_NODE_REF(PassContext, PassContextNode)
+
+/*
+ * \brief The meta data of a pass.
+ *
+ * PassInfo can be extended conveniently in the future if more meta information
+ * is needed.
+ */
+class PassInfo;
+
+/*!
+ * \brief PassInfoNode contains meta data that will be used to help optimization
+ * and analysis.
+ */
+class PassInfoNode : public RelayNode {
+ public:
+  /*! \brief The minimal optimization level that this pass will be enabled. */
+  int opt_level;
+
+  /*! \brief The name of an optimization/analysis pass. */
+  std::string name;
+
+  /*! \brief The passes that are required to perform the current pass. */
+  tvm::Array<tvm::Expr> required;
+
+  PassInfoNode() = default;
+
+  void VisitAttrs(tvm::AttrVisitor* v) final {
+    v->Visit("opt_level", &opt_level);
+    v->Visit("name", &name);
+    v->Visit("required", &required);
+  }
+
+  TVM_DLL static PassInfo make(int opt_level, std::string name,
+                               tvm::Array<tvm::Expr> required);
+
+  static constexpr const char* _type_key = "relay.PassInfo";
+  TVM_DECLARE_NODE_TYPE_INFO(PassInfoNode, RelayNode);
+};
+
+TVM_DEFINE_NODE_REF(PassInfo, PassInfoNode)
+
+class Pass;
+
+/*!
+ * \brief PassNode is the base type of differnt types of optimization passes.
+ * It is designed as a pure class and implemented by different pass subclasses
+ * at different granularity of Relay nodes.
+ */
+class PassNode : public RelayNode {
+ public:
+  /*
+   * \brief Get the pass information/meta data. */
+  virtual PassInfo Info() const = 0;
+
+  /*!
+   * \brief Set the context information for a pass.
+   *
+   * \param pass_ctx The context information for a certain pass.
+   */
+  virtual void SetContext(const PassContext& pass_ctx) = 0;
+
+  /*!
+   * \brief Execute the optimization pass using a functor.
+   *
+   * \param mod The module that an optimization pass runs on.
+   *
+   * \return The updated module.
+   */
+  virtual Module operator()(const Module& mod) const = 0;
+
+  void VisitAttrs(tvm::AttrVisitor* v) override {}
+
+  static constexpr const char* _type_key = "relay.Pass";
+  TVM_DECLARE_BASE_NODE_INFO(PassNode, RelayNode);
+};
+
+class Pass : public NodeRef {
+ public:
+  Pass() = default;
+  explicit Pass(NodePtr<tvm::Node> p) : NodeRef(p) {}
+
+  PassNode* operator->() const {
+    return static_cast<PassNode*>(this->node_.get());
+  }
+
+  using ContainerType = PassNode;
+};
+
+class SequentialNode;
+
+class Sequential : public Pass {
+ public:
+  /*!
+   * \brief The constructor of `Sequential`.
+   * \param passes The passes to apply.
+   * \param pass_info The pass metadata.
+   * \param disabled The passes that will not be applied.
+   */
+  TVM_DLL Sequential(tvm::Array<Pass> passes,
+                     PassInfo pass_info,
+                     tvm::Array<tvm::Expr> disabled);
+  Sequential() = default;
+  explicit Sequential(tvm::NodePtr<::tvm::Node> n) : Pass(n) {}
+
+  const SequentialNode* operator->() const;
+  using ContainerType = Sequential;
+};
+
+
+/*
+ * \brief Create a module pass.
+ *
+ * \param pass_func The packed function that contains the optimization.
+ * \param opt_level The optimization level of the module pass.
+ * \param name The name of the module pass.
+ * \param required The list of the passes that the module pass is dependent on.
+ *
+ * \return The created module pass.
+ */
+Pass CreateModulePass(
+    const runtime::TypedPackedFunc<Module(Module, PassContext)>& pass_func,
+    int opt_level,
+    const std::string& name,
+    const tvm::Array<tvm::Expr>& required);
+
+/*
+ * \brief Create a function pass.
+ *
+ * \param pass_func The packed function that contains the optimization.
+ * \param opt_level The optimization level of the function pass.
+ * \param name The name of the function pass.
+ * \param required The list of the passes that the function pass is dependent on.
+ *
+ * \return The created function pass.
+ */
+Pass CreateFunctionPass(
+    const runtime::TypedPackedFunc<Function(Function, PassContext)>& pass_func,
+    int opt_level,
+    const std::string& name,
+    const tvm::Array<tvm::Expr>& required);
+
+}  // namespace transform
+}  // namespace relay
+}  // namespace tvm
+
+#endif  // TVM_RELAY_TRANSFORM_H_
index 1f1e4a6..d832c89 100644 (file)
@@ -25,6 +25,7 @@ from . import expr_functor
 from . import module
 from . import adt
 from . import ir_pass
+from . import transform
 from .build_module import build, build_config, create_executor
 from . import prelude
 from . import parser
@@ -97,9 +98,8 @@ Match = adt.Match
 var = expr.var
 const = expr.const
 bind = expr.bind
-module_pass = ir_pass.module_pass
-function_pass = ir_pass.function_pass
-sequential_pass = ir_pass.sequential_pass
+module_pass = transform.module_pass
+function_pass = transform.function_pass
 
 # ExprFunctor
 ExprFunctor = expr_functor.ExprFunctor
@@ -114,9 +114,9 @@ save_param_dict = param_dict.save_param_dict
 load_param_dict = param_dict.load_param_dict
 
 # Pass manager
-PassInfo = ir_pass.PassInfo
-PassContext = ir_pass.PassContext
-Pass = ir_pass.Pass
-ModulePass = ir_pass.ModulePass
-FunctionPass = ir_pass.FunctionPass
-SequentialPass = ir_pass.SequentialPass
+PassInfo = transform.PassInfo
+PassContext = transform.PassContext
+Pass = transform.Pass
+ModulePass = transform.ModulePass
+FunctionPass = transform.FunctionPass
+Sequential = transform.Sequential
index 6aedb52..13035bb 100644 (file)
 
 import tvm
 from . import ir
-from .base import NodeBase
 from .env import Module
 
-
-class PassContext(NodeBase):
-    def __init__(self):
-        ...
-
-class PassInfo(NodeBase):
-    name = ...  # type: str
-    opt_level = ... # type: int
-    required = ... # type: list
-
-    def __init__(self, name, opt_level, required)
-        # type: (str, int, list) -> None
-
-
-class Pass(NodeBase):
-    def __init__(self):
-        ...
-
-
-class ModulePass(Pass):
-    name = ...  # type: str
-    opt_level = ...  # type: int
-    pass_func = ...  # type: Callable
-    required = ...  # type: list
-
-    def __init__(self, name, opt_level, pass_func, required):
-        # type: (str, int, Callable, list) -> None
-        ...
-
-
-class FunctionPass(Pass):
-    name = ...  # type: str
-    opt_level = ...  # type: int
-    pass_func = ...  # type: Callable
-    required = ...  # type: list
-
-    def __init__(self, name, opt_level, pass_func, required):
-        # type: (str, int, Callable, list) -> None
-        ...
-
-
-class SequentialPass(Pass):
-    name = ...  # type: str
-    opt_level = ...  # type: int
-    passes = ...  # type: list
-    required = ...  # type: list
-    disabled = ... # type: list
-
-    def __init__(self, name, opt_level, passes, required, disabled):
-        # type: (str, int, list, list, list) -> None
-        ...
-
-
 def check_expr(env: Module, expr: ir.Expr) -> ir.Type: ...
 def generalize(env: Module, expr: ir.Expr) -> ir.Expr: ...
 def _get_checked_type(expr: ir.Expr) -> ir.Type: ...
diff --git a/python/tvm/relay/_transform.py b/python/tvm/relay/_transform.py
new file mode 100644 (file)
index 0000000..273d97e
--- /dev/null
@@ -0,0 +1,21 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""FFI exposing the Relay type inference and checking."""
+
+from tvm._ffi.function import _init_api
+
+_init_api("relay._transform", __name__)
index 5f23e14..ea34c6b 100644 (file)
 # pylint: disable=no-else-return
 # pylint: disable=unidiomatic-typecheck
 """
-This file contains:
-1. The set of passes for Relay, which exposes an interface for configuring the
-   passes and scripting them in Python.
-
-2. The pass manager for Relay which exposes different granularity of interfaces
-   for users to implement and use passes more conveniently.
+This file contains the set of passes for Relay, which exposes an interface for
+configuring the passes and scripting them in Python.
 """
-import types
-
 from . import _ir_pass
 from . import _make
 from .expr import Expr
 from .ty import Type
-from .base import RelayNode, register_relay_node
 from .module import Module
 
 
-@register_relay_node
-class PassInfo(RelayNode):
-    """The class that contains the meta data required by a pass. It is the
-    container of information needed by running an optimization or analysis.
-    This class can be extended by adding new members when more meta data is
-    needed.
-
-    Parameters
-    ----------
-    name : str
-        The pass name.
-
-    opt_level : int
-        The optimization level of this pass.
-
-    required : List[str]
-        The list of passes that are required by a certain pass.
-    """
-
-    def __init__(self, name, opt_level, required=None):
-        self.__init_handle_by_constructor__(_ir_pass.PassInfo, name, opt_level,
-                                            required)
-
-
-@register_relay_node
-class PassContext(RelayNode):
-    """The basis where a Relay optimization/analysis runs on.
-    Each pass context contains a number of auxiliary information that is used
-    to help an optimization pass. Such information includes the error reporter
-    to record the errors of during the optimization, etc.
-    """
-
-    def __init__(self):
-        self.__init_handle_by_constructor__(_ir_pass.PassContext)
-
-
-@register_relay_node
-class Pass(RelayNode):
-    """The base class of all passes. All methods here are just simple wrappers
-    that are implemented in the backend. They are defined for users to
-    conveniently interact with the base class.
-    """
-
-    def set_pass_context(self, pass_ctx):
-        """Setup the pass context for analysis and optimizations. This context
-        could be shared by different passes for sequential passes.
-
-        Parameters
-        ----------
-        pass_ctx : PassContext
-            The context that is used to help perform a certain pass or a series
-            of passes.
-        """
-        if not isinstance(pass_ctx, PassContext):
-            raise TypeError("pass_ctx is expected to be the PassContext type")
-        _ir_pass.SetContext(self, pass_ctx)
-
-    @property
-    def info(self):
-        """Get the pass meta."""
-        return _ir_pass.Info(self)
-
-    def __call__(self, mod):
-        """Execute the pass. Note that for sequential pass, the dependency among
-        different passes will be resolved in the backend.
-
-        Parameters
-        ----------
-        mod : tvm.relay.Module
-            The module that a certain optimization is performed on.
-
-        Returns
-        -------
-        mod : tvm.relay.Module
-            The updated module after applying this pass.
-        """
-        return _ir_pass.RunPass(self, mod)
-
-
-@register_relay_node
-class ModulePass(Pass):
-    """A pass that works on tvm.relay.Module. Users don't need to interact with
-    this class directly. Instead, a module pass should be created through
-    `module_pass`, because the design of the `module_pass` API is flexible
-    enough to handle the creation of a module pass in different manners. In
-    addition, all members of a module pass can be accessed from the base class.
-    The same rule applies to FunctionPass and SequentialPass as well.
-    """
-
-
-@register_relay_node
-class FunctionPass(Pass):
-    """A pass that works on each tvm.relay.Function in a module. A function
-    pass class should be created through `function_pass`.
-    """
-
-
-@register_relay_node
-class SequentialPass(Pass):
-    """A pass that works on a sequence of pass objects. A sequential pass class
-    should be created through `sequential_pass`.
-    """
-
-
-def module_pass(pass_func=None, opt_level=None, name=None, required=None):
-    """Create a module pass. This function returns a callback when pass_func
-    is provided. Otherwise, it returns the created module level pass using the
-    given optimization function.
-
-    Parameters
-    ----------
-    pass_func : Optional[Callable[(Module/Function, PassContext) ->
-                Module/Function]]
-        The implemented optimization pass.
-
-    opt_level : int
-        The optimization level of this module pass.
-
-    name : Optional[str]
-        The name of the module pass. The name could be empty. In this case, the
-        name of the optimization function will be used as the pass name.
-
-    required : Optional[List[str]]
-        The list of passes that the module pass is dependent on.
-
-    Returns
-    -------
-    create_module_pass : Union[Callable, ModulePass]
-        The callable that will create a module pass is returned when
-        pass_func is not passed in. Otherwise, a ModulePass object will be
-        directly created.
-
-       Examples
-    --------
-    The following code creates a module level pass and adds an abs function to
-    the module.
-
-    .. code-block:: python
-
-        @relay.ir_pass.module_pass(opt_level=2)
-        def transform(mod, ctx):
-            tp = relay.TensorType((10,), "float32")
-            x = relay.var("x", tp)
-            gv = relay.GlobalVar("var")
-            func = relay.Function([x], relay.abs(x))
-            new_mod = relay.Module({gv: func})
-            new_mod.update(mod)
-            return new_mod
-
-        module_pass = transform
-        assert isinstance(module_pass, ir_pass.ModulePass)
-        assert module_pass.info.opt_level == 2
-
-        # Given a module m, the optimization could be invoked as the follwoing:
-        updated_mod = module_pass(m)
-        # Now a function abs should be added to the module m.
-    """
-
-    if opt_level is None:
-        raise ValueError("Please provide opt_level for the module pass.")
-
-    required = required if required else []
-    if not isinstance(required, (list, tuple)):
-        raise TypeError("Required is expected to be the type of " +
-                        "list/tuple.")
-
-    def create_module_pass(pass_func):
-        """Internal function that creates a module pass"""
-        if not isinstance(pass_func, (types.FunctionType, types.LambdaType)):
-            raise TypeError("pass_func must be a callable for Module pass")
-
-        return _ir_pass.CreateModulePass(pass_func, opt_level,
-                                         name if name else pass_func.__name__,
-                                         required)
-
-    if pass_func:
-        return create_module_pass(pass_func)
-    return create_module_pass
-
-
-def function_pass(pass_func=None, opt_level=None, name=None, required=None):
-    """Create a function pass. This function returns a callback when pass_func
-    is provided. Otherwise, it returns the created function pass using the
-    given optimization function.
-
-    Parameters
-    ----------
-    pass_func : Optional[Callable[(Module/Function, PassContext) ->
-                Module/Function]]
-        The implemented optimization pass.
-
-    opt_level : int
-        The optimization level of this module pass.
-
-    name : Optional[str]
-        The name of the function pass. The name could be empty. In this case, the
-        name of the optimization function will be used as the pass name.
-
-    required : Optional[List[str]]
-        The list of passes that the module pass is dependent on.
-
-    Returns
-    -------
-    create_function_pass : Union[Callable, FunctionPass]
-        The callable that will create a function pass is returned when
-        pass_func is not passed in. Otherwise, a FunctionPass object will be
-        created.
-
-    Examples
-    --------
-    The following code creates a function level pass that performs constant
-    folding.
-
-    .. code-block:: python
-
-        @relay.ir_pass.function_pass(opt_level=2)
-        def transform(func, ctx):
-            return ir_pass.fold_constant(func)
-
-        function_pass = transform
-        assert isinstance(function_pass, ir_pass.FunctionPass)
-        assert function_pass.info.opt_level == 2
-
-        # Given a module m, the optimization could be invoked as the follwoing:
-        updated_mod = function_pass(m)
-        # Now constant folding should have been applied to every function in
-        # the provided module m. And the updated module will be returned.
-    """
-
-    if opt_level is None:
-        raise ValueError("Please provide opt_level for the funtion pass.")
-
-    required = required if required else []
-    if not isinstance(required, (list, tuple)):
-        raise TypeError("Required is expected to be the type of " +
-                        "list/tuple.")
-
-    def create_function_pass(pass_func):
-        """Internal function that creates a function pass"""
-        if not isinstance(pass_func, (types.FunctionType, types.LambdaType)):
-            raise TypeError("pass_func must be a callable for Module pass")
-
-        return _ir_pass.CreateFunctionPass(pass_func, opt_level,
-                                           name if name else pass_func.__name__,
-                                           required)
-
-    if pass_func:
-        return create_function_pass(pass_func)
-    return create_function_pass
-
-
-def sequential_pass(passes=None, opt_level=2, name="sequential_pass",
-                    required=None, disabled=None):
-    """Create a sequential pass using a defined optimization function from
-    Python. Some typical usage of the sequential pass are:
-    1. Users provide a list of passes for optimization.
-    2. Only an optimization level is provided so that the backend system has
-       to glob all passes at this level and below to perform the optimizations.
-    Note that users can also provide a series of passes that they don't want to
-    apply when running a sequential pass. Pass dependency will be resolved in
-    the backend as well.
-
-    Parameters
-    ----------
-    passes : Optional[List[Pass]]
-        A sequence of passes candidate for optimization.
-
-    opt_level : Optional[int]
-        The optimization level of this sequential pass.
-
-    name : Optional[str]
-        The name of the sequential pass.
-
-    required : Optional[List[str]]
-        The list of passes that the sequential pass is dependent on.
-
-    disabled : Optional[List[str]]
-        A list of disabled passes.
-
-    Returns
-    -------
-    ret : Pass
-        A sequential pass built through pass_func.
-    """
-
-    passes = passes if passes else []
-    if not isinstance(passes, (list, tuple)):
-        raise TypeError("passes must be a list of Pass objects.")
-
-    disabled = disabled if disabled else []
-    if not isinstance(disabled, (list, tuple)):
-        raise TypeError("disabled must be a list or tuple of pass names")
-
-    required = required if required else []
-    if not isinstance(required, (list, tuple)):
-        raise TypeError("Required is expected to be the type of list/tuple.")
-
-    return _ir_pass.CreateSequentialPass(passes, opt_level, name, required,
-                                         disabled)
-
-
 def post_order_visit(expr, fvisit):
     """Recursively visit the ir in post DFS order node,
     apply fvisit. Each node is guaranteed to be visited
diff --git a/python/tvm/relay/transform.py b/python/tvm/relay/transform.py
new file mode 100644 (file)
index 0000000..877538a
--- /dev/null
@@ -0,0 +1,325 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=no-else-return
+# pylint: disable=unidiomatic-typecheck
+"""
+This file contains the pass manager for Relay which exposes different
+granularity of interfaces for users to implement and use passes more
+conveniently.
+"""
+import types
+
+from . import _transform
+from .base import RelayNode, register_relay_node
+
+
+@register_relay_node
+class PassInfo(RelayNode):
+    """The class that contains the meta data required by a pass. It is the
+    container of information needed by running an optimization or analysis.
+    This class can be extended by adding new members when more meta data is
+    needed.
+
+    Parameters
+    ----------
+    name : str
+        The pass name.
+
+    opt_level : int
+        The optimization level of this pass.
+
+    required : List[str]
+        The list of passes that are required by a certain pass.
+    """
+
+    def __init__(self, name, opt_level, required=None):
+        self.__init_handle_by_constructor__(_transform.PassInfo, name, opt_level,
+                                            required)
+
+
+@register_relay_node
+class PassContext(RelayNode):
+    """The basis where a Relay optimization/analysis runs on.
+    Each pass context contains a number of auxiliary information that is used
+    to help an optimization pass. Such information includes the error reporter
+    to record the errors of during the optimization, etc.
+    """
+
+    def __init__(self):
+        self.__init_handle_by_constructor__(_transform.PassContext)
+
+
+@register_relay_node
+class Pass(RelayNode):
+    """The base class of all passes. All methods here are just simple wrappers
+    that are implemented in the backend. They are defined for users to
+    conveniently interact with the base class.
+    """
+
+    def set_pass_context(self, pass_ctx):
+        """Setup the pass context for analysis and optimizations. This context
+        could be shared by different passes for sequential passes.
+
+        Parameters
+        ----------
+        pass_ctx : PassContext
+            The context that is used to help perform a certain pass or a series
+            of passes.
+        """
+        if not isinstance(pass_ctx, PassContext):
+            raise TypeError("pass_ctx is expected to be the PassContext type")
+        _transform.SetContext(self, pass_ctx)
+
+    @property
+    def info(self):
+        """Get the pass meta."""
+        return _transform.Info(self)
+
+    def __call__(self, mod):
+        """Execute the pass. Note that for sequential pass, the dependency among
+        different passes will be resolved in the backend.
+
+        Parameters
+        ----------
+        mod : tvm.relay.Module
+            The module that a certain optimization is performed on.
+
+        Returns
+        -------
+        mod : tvm.relay.Module
+            The updated module after applying this pass.
+        """
+        return _transform.RunPass(self, mod)
+
+
+@register_relay_node
+class ModulePass(Pass):
+    """A pass that works on tvm.relay.Module. Users don't need to interact with
+    this class directly. Instead, a module pass should be created through
+    `module_pass`, because the design of the `module_pass` API is flexible
+    enough to handle the creation of a module pass in different manners. In
+    addition, all members of a module pass can be accessed from the base class.
+    The same rule applies to FunctionPass and Sequential as well.
+    """
+
+
+@register_relay_node
+class FunctionPass(Pass):
+    """A pass that works on each tvm.relay.Function in a module. A function
+    pass class should be created through `function_pass`.
+    """
+
+
+@register_relay_node
+class Sequential(Pass):
+    """A pass that works on a sequence of pass objects. Multiple passes can be
+    executed sequentially using this class.
+
+    Some typical usage of the sequential pass are:
+    1. Users provide a list of passes for optimization.
+    2. Only an optimization level is provided so that the backend system has
+       to glob all passes at this level and below to perform the optimizations.
+    Note that users can also provide a series of passes that they don't want to
+    apply when running a sequential pass. Pass dependency will be resolved in
+    the backend as well.
+
+    Parameters
+    ----------
+    passes : Optional[List[Pass]]
+        A sequence of passes candidate for optimization.
+
+    opt_level : Optional[int]
+        The optimization level of this sequential pass.
+
+    name : Optional[str]
+        The name of the sequential pass.
+
+    required : Optional[List[str]]
+        The list of passes that the sequential pass is dependent on.
+
+    disabled : Optional[List[str]]
+        A list of disabled passes.
+    """
+
+    def __init__(self,
+                 passes=None,
+                 opt_level=2,
+                 name="sequential",
+                 required=None,
+                 disabled=None):
+        passes = passes if passes else []
+        if not isinstance(passes, (list, tuple)):
+            raise TypeError("passes must be a list of Pass objects.")
+
+        disabled = disabled if disabled else []
+        if not isinstance(disabled, (list, tuple)):
+            raise TypeError("disabled must be a list or tuple of pass names")
+
+        required = required if required else []
+        if not isinstance(required, (list, tuple)):
+            raise TypeError("Required is expected to be the type of list/tuple.")
+
+        self.__init_handle_by_constructor__(_transform.Sequential,
+                                            passes, opt_level, name, required,
+                                            disabled)
+
+
+def module_pass(pass_func=None, opt_level=None, name=None, required=None):
+    """Create a module pass. This function returns a callback when pass_func
+    is provided. Otherwise, it returns the created module level pass using the
+    given optimization function.
+
+    Parameters
+    ----------
+    pass_func : Optional[Callable[(Module/Function, PassContext) ->
+                Module/Function]]
+        The implemented optimization pass.
+
+    opt_level : int
+        The optimization level of this module pass.
+
+    name : Optional[str]
+        The name of the module pass. The name could be empty. In this case, the
+        name of the optimization function will be used as the pass name.
+
+    required : Optional[List[str]]
+        The list of passes that the module pass is dependent on.
+
+    Returns
+    -------
+    create_module_pass : Union[Callable, ModulePass]
+        The callable that will create a module pass is returned when
+        pass_func is not passed in. Otherwise, a ModulePass object will be
+        directly created.
+
+       Examples
+    --------
+    The following code creates a module level pass and adds an abs function to
+    the module.
+
+    .. code-block:: python
+
+        @relay.transform.module_pass(opt_level=2)
+        def transform(mod, ctx):
+            tp = relay.TensorType((10,), "float32")
+            x = relay.var("x", tp)
+            gv = relay.GlobalVar("var")
+            func = relay.Function([x], relay.abs(x))
+            new_mod = relay.Module({gv: func})
+            new_mod.update(mod)
+            return new_mod
+
+        module_pass = transform
+        assert isinstance(module_pass, transform.ModulePass)
+        assert module_pass.info.opt_level == 2
+
+        # Given a module m, the optimization could be invoked as the follwoing:
+        updated_mod = module_pass(m)
+        # Now a function abs should be added to the module m.
+    """
+
+    if opt_level is None:
+        raise ValueError("Please provide opt_level for the module pass.")
+
+    required = required if required else []
+    if not isinstance(required, (list, tuple)):
+        raise TypeError("Required is expected to be the type of " +
+                        "list/tuple.")
+
+    def create_module_pass(pass_func):
+        """Internal function that creates a module pass"""
+        if not isinstance(pass_func, (types.FunctionType, types.LambdaType)):
+            raise TypeError("pass_func must be a callable for Module pass")
+
+        return _transform.CreateModulePass(
+            pass_func, opt_level, name if name else pass_func.__name__,
+            required)
+
+    if pass_func:
+        return create_module_pass(pass_func)
+    return create_module_pass
+
+
+def function_pass(pass_func=None, opt_level=None, name=None, required=None):
+    """Create a function pass. This function returns a callback when pass_func
+    is provided. Otherwise, it returns the created function pass using the
+    given optimization function.
+
+    Parameters
+    ----------
+    pass_func : Optional[Callable[(Module/Function, PassContext) ->
+                Module/Function]]
+        The implemented optimization pass.
+
+    opt_level : int
+        The optimization level of this module pass.
+
+    name : Optional[str]
+        The name of the function pass. The name could be empty. In this case, the
+        name of the optimization function will be used as the pass name.
+
+    required : Optional[List[str]]
+        The list of passes that the module pass is dependent on.
+
+    Returns
+    -------
+    create_function_pass : Union[Callable, FunctionPass]
+        The callable that will create a function pass is returned when
+        pass_func is not passed in. Otherwise, a FunctionPass object will be
+        created.
+
+    Examples
+    --------
+    The following code creates a function level pass that performs constant
+    folding.
+
+    .. code-block:: python
+
+        @relay.transform.function_pass(opt_level=2)
+        def transform(func, ctx):
+            return ir_pass.fold_constant(func)
+
+        function_pass = transform
+        assert isinstance(function_pass, transform.FunctionPass)
+        assert function_pass.info.opt_level == 2
+
+        # Given a module m, the optimization could be invoked as the follwoing:
+        updated_mod = function_pass(m)
+        # Now constant folding should have been applied to every function in
+        # the provided module m. And the updated module will be returned.
+    """
+
+    if opt_level is None:
+        raise ValueError("Please provide opt_level for the funtion pass.")
+
+    required = required if required else []
+    if not isinstance(required, (list, tuple)):
+        raise TypeError("Required is expected to be the type of " +
+                        "list/tuple.")
+
+    def create_function_pass(pass_func):
+        """Internal function that creates a function pass"""
+        if not isinstance(pass_func, (types.FunctionType, types.LambdaType)):
+            raise TypeError("pass_func must be a callable for Module pass")
+
+        return _transform.CreateFunctionPass(
+            pass_func, opt_level, name if name else pass_func.__name__,
+            required)
+
+    if pass_func:
+        return create_function_pass(pass_func)
+    return create_function_pass
diff --git a/python/tvm/relay/transform.pyi b/python/tvm/relay/transform.pyi
new file mode 100644 (file)
index 0000000..343e899
--- /dev/null
@@ -0,0 +1,71 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import tvm
+from .base import NodeBase
+
+
+class PassContext(NodeBase):
+    def __init__(self):
+        ...
+
+class PassInfo(NodeBase):
+    name = ...  # type: str
+    opt_level = ... # type: int
+    required = ... # type: list
+
+    def __init__(self, name, opt_level, required)
+        # type: (str, int, list) -> None
+
+
+class Pass(NodeBase):
+    def __init__(self):
+        ...
+
+
+class ModulePass(Pass):
+    name = ...  # type: str
+    opt_level = ...  # type: int
+    pass_func = ...  # type: Callable
+    required = ...  # type: list
+
+    def __init__(self, name, opt_level, pass_func, required):
+        # type: (str, int, Callable, list) -> None
+        ...
+
+
+class FunctionPass(Pass):
+    name = ...  # type: str
+    opt_level = ...  # type: int
+    pass_func = ...  # type: Callable
+    required = ...  # type: list
+
+    def __init__(self, name, opt_level, pass_func, required):
+        # type: (str, int, Callable, list) -> None
+        ...
+
+
+class Sequential(Pass):
+    name = ...  # type: str
+    opt_level = ...  # type: int
+    passes = ...  # type: list
+    required = ...  # type: list
+    disabled = ... # type: list
+
+    def __init__(self, name, opt_level, passes, required, disabled):
+        # type: (str, int, list, list, list) -> None
+        ...
index d607247..a105b69 100644 (file)
  * \brief Relay pass manager implementation.
  */
 #include <tvm/relay/expr_functor.h>
-#include <tvm/relay/pass.h>
+#include <tvm/relay/transform.h>
 
 namespace tvm {
 namespace relay {
-namespace pass {
+namespace transform {
 
 using tvm::IRPrinter;
 
@@ -169,17 +169,15 @@ class FunctionPassNode : public PassNode {
 
 RELAY_DEFINE_NODE_REF(FunctionPass, FunctionPassNode, Pass);
 
-class SequentialPass;
-
 /*!
- * \brief The SequentialPassNode contains a set of passes that transform Relay
+ * \brief The SequentialNode contains a set of passes that transform Relay
  * programs from one AST to another semantically equivalent one.
  *
  * One example of this level of pass is that the pass manager needs to correctly
  * perform a host of optimizations with a given optimization level and disabled
  * passes.
  */
-class SequentialPassNode : public PassNode {
+class SequentialNode : public PassNode {
  public:
   /* \brief The pass meta data.*/
   PassInfo pass_info;
@@ -212,10 +210,6 @@ class SequentialPassNode : public PassNode {
     passes.push_back(pass);
   }
 
-  TVM_DLL static SequentialPass make(tvm::Array<Pass> passes,
-                                     PassInfo pass_info,
-                                     tvm::Array<tvm::Expr> disabled);
-
   /*!
    * \brief Resolve the pass dependency. It globs all required passes by
    *        a given pass and executes them.
@@ -251,8 +245,8 @@ class SequentialPassNode : public PassNode {
    */
   void SetContext(const PassContext& pass_ctx) final;
 
-  static constexpr const char* _type_key = "relay.SequentialPass";
-  TVM_DECLARE_NODE_TYPE_INFO(SequentialPassNode, PassNode);
+  static constexpr const char* _type_key = "relay.Sequential";
+  TVM_DECLARE_NODE_TYPE_INFO(SequentialNode, PassNode);
 
  private:
   /*!
@@ -261,8 +255,6 @@ class SequentialPassNode : public PassNode {
   PassContext pass_ctx_;
 };
 
-RELAY_DEFINE_NODE_REF(SequentialPass, SequentialPassNode, Pass);
-
 PassInfo PassInfoNode::make(int opt_level, std::string name,
                             tvm::Array<tvm::Expr> required) {
   auto pass_info = make_node<PassInfoNode>();
@@ -350,20 +342,24 @@ bool FunctionPassNode::SkipFunction(const Function& func) const {
   return pval && pval->value != 0;
 }
 
-SequentialPass SequentialPassNode::make(tvm::Array<Pass> passes,
-                                        PassInfo pass_info,
-                                        tvm::Array<tvm::Expr> disabled) {
-  auto n = make_node<SequentialPassNode>();
+Sequential::Sequential(tvm::Array<Pass> passes,
+                       PassInfo pass_info,
+                       tvm::Array<tvm::Expr> disabled) {
+  auto n = make_node<SequentialNode>();
   n->passes = std::move(passes);
   n->pass_info = std::move(pass_info);
   n->disabled = std::move(disabled);
-  return SequentialPass(n);
+  node_ = std::move(n);
+}
+
+const SequentialNode* Sequential::operator->() const {
+  return static_cast<const SequentialNode*>(this->node_.get());
 }
 
 // TODO(jroesch, zhiics): we currenlty only sequentially execute each pass in
-// a SequentialPass without the consideration of their orders. The phase
+// a Sequential without the consideration of their orders. The phase
 // ordering problem needed to be handled in the future.
-Module SequentialPassNode::operator()(const Module& module) const {
+Module SequentialNode::operator()(const Module& module) const {
   Module mod = module;
   for (const Pass& pass : passes) {
     CHECK(pass.defined()) << "Found undefined pass for optimization.";
@@ -373,7 +369,7 @@ Module SequentialPassNode::operator()(const Module& module) const {
   return mod;
 }
 
-void SequentialPassNode::ResolveDependency(const Module& mod) {
+void SequentialNode::ResolveDependency(const Module& mod) {
   // TODO(zhiics) Implement it.
   // 1. Consider the required passes for each pass.
   // 2. Only resolve the enabled passes.
@@ -382,7 +378,7 @@ void SequentialPassNode::ResolveDependency(const Module& mod) {
              << "\n";
 }
 
-std::vector<std::string> SequentialPassNode::DisabledPasses() const {
+std::vector<std::string> SequentialNode::DisabledPasses() const {
   std::vector<std::string> ret;
   for (const auto& it : disabled) {
     const auto* str = it.as<tvm::ir::StringImm>();
@@ -392,7 +388,7 @@ std::vector<std::string> SequentialPassNode::DisabledPasses() const {
   return ret;
 }
 
-void SequentialPassNode::SetContext(const PassContext& pass_ctx) {
+void SequentialNode::SetContext(const PassContext& pass_ctx) {
   pass_ctx_ = pass_ctx;
 }
 
@@ -414,21 +410,12 @@ Pass CreateFunctionPass(
   return FunctionPassNode::make(pass_func, pass_info);
 }
 
-Pass CreateSequentialPass(const tvm::Array<Pass>& passes,
-                          int opt_level,
-                          const std::string& name,
-                          const tvm::Array<tvm::Expr>& required,
-                          const tvm::Array<tvm::Expr>& disabled) {
-  PassInfo pass_info = PassInfoNode::make(opt_level, name, required);
-  return SequentialPassNode::make(passes, pass_info, disabled);
-}
-
 TVM_REGISTER_NODE_TYPE(PassInfoNode);
 
-TVM_REGISTER_API("relay._ir_pass.PassInfo")
+TVM_REGISTER_API("relay._transform.PassInfo")
 .set_body_typed(PassInfoNode::make);
 
-TVM_REGISTER_API("relay._ir_pass.Info")
+TVM_REGISTER_API("relay._transform.Info")
 .set_body([](TVMArgs args, TVMRetValue* ret) {
   Pass pass = args[0];
   *ret = pass->Info();
@@ -450,10 +437,10 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
 
 TVM_REGISTER_NODE_TYPE(ModulePassNode);
 
-TVM_REGISTER_API("relay._ir_pass.CreateModulePass")
+TVM_REGISTER_API("relay._transform.CreateModulePass")
 .set_body_typed(CreateModulePass);
 
-TVM_REGISTER_API("relay._ir_pass.RunPass")
+TVM_REGISTER_API("relay._transform.RunPass")
 .set_body([](TVMArgs args, TVMRetValue* ret) {
   Pass pass = args[0];
   Module mod = args[1];
@@ -475,7 +462,7 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
 
 TVM_REGISTER_NODE_TYPE(FunctionPassNode);
 
-TVM_REGISTER_API("relay._ir_pass.CreateFunctionPass")
+TVM_REGISTER_API("relay._transform.CreateFunctionPass")
 .set_body_typed(CreateFunctionPass);
 
 TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
@@ -486,9 +473,9 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
             << " at the optimization level " << pn->opt_level;
 });
 
-TVM_REGISTER_NODE_TYPE(SequentialPassNode);
+TVM_REGISTER_NODE_TYPE(SequentialNode);
 
-TVM_REGISTER_API("relay._ir_pass.CreateSequentialPass")
+TVM_REGISTER_API("relay._transform.Sequential")
 .set_body([](TVMArgs args, TVMRetValue* ret) {
   tvm::Array<Pass> passes = args[0];
   int opt_level = args[1];
@@ -496,14 +483,14 @@ TVM_REGISTER_API("relay._ir_pass.CreateSequentialPass")
   tvm::Array<tvm::Expr> required = args[3];
   tvm::Array<tvm::Expr> disabled = args[4];
   PassInfo pass_info = PassInfoNode::make(opt_level, name, required);
-  *ret = SequentialPassNode::make(passes, pass_info, disabled);
+  *ret = Sequential(passes, pass_info, disabled);
 });
 
 TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
-.set_dispatch<SequentialPassNode>([](const SequentialPassNode* node,
-                                     tvm::IRPrinter* p) {
+.set_dispatch<SequentialNode>([](const SequentialNode* node,
+                                 tvm::IRPrinter* p) {
   const PassInfoNode* seq_pn = node->Info().operator->();
-  p->stream << "Run SequentialPass pass: " << seq_pn->name
+  p->stream << "Run Sequential pass: " << seq_pn->name
             << " at the optimization level. " << seq_pn->opt_level;
   p->stream << "The passes will be executed are: [";
   for (const auto& it : node->passes) {
@@ -514,7 +501,7 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
   p->stream << "]";
 });
 
-TVM_REGISTER_API("relay._ir_pass.SetContext")
+TVM_REGISTER_API("relay._transform.SetContext")
 .set_body([](TVMArgs args, TVMRetValue* ret) {
   Pass pass = args[0];
   PassContext pass_ctx = args[1];
@@ -523,7 +510,7 @@ TVM_REGISTER_API("relay._ir_pass.SetContext")
 
 TVM_REGISTER_NODE_TYPE(PassContextNode);
 
-TVM_REGISTER_API("relay._ir_pass.PassContext")
+TVM_REGISTER_API("relay._transform.PassContext")
 .set_body_typed(PassContextNode::make);
 
 TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
@@ -534,6 +521,6 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
                << "\n";
 });
 
-}  // namespace pass
+}  // namespace transform
 }  // namespace relay
 }  // namespace tvm
index b821677..db346e7 100644 (file)
@@ -22,6 +22,7 @@ from tvm import relay
 from tvm.relay import ExprFunctor
 from tvm.relay import Function, Call
 from tvm.relay import ir_pass
+from tvm.relay import transform as _transform
 from tvm.relay.testing import ctx_list
 
 
@@ -126,13 +127,13 @@ def test_module_pass():
     opt_tester = OptTester(mod)
     pass_ctx = None
 
-    @ir_pass.module_pass(opt_level=opt_level, name=pass_name)
+    @_transform.module_pass(opt_level=opt_level, name=pass_name)
     def transform(expr, ctx):
         return opt_tester.transform(expr, ctx)
 
     def test_pass_registration():
         mod_pass = transform
-        assert isinstance(mod_pass, ir_pass.ModulePass)
+        assert isinstance(mod_pass, _transform.ModulePass)
         pass_info = mod_pass.info
         assert pass_info.name == pass_name
         assert pass_info.opt_level == opt_level
@@ -140,8 +141,8 @@ def test_module_pass():
     def test_pass_registration_no_decorator():
         def direct_transform(expr, ctx):
             return opt_tester.transform(expr, ctx)
-        mod_pass = ir_pass.module_pass(direct_transform, opt_level=3)
-        assert isinstance(mod_pass, ir_pass.ModulePass)
+        mod_pass = _transform.module_pass(direct_transform, opt_level=3)
+        assert isinstance(mod_pass, _transform.ModulePass)
         pass_info = mod_pass.info
         assert pass_info.name == "direct_transform"
         assert pass_info.opt_level == 3
@@ -202,7 +203,7 @@ def test_function_pass():
     opt_tester = OptTester(mod)
     pass_ctx = None
 
-    @ir_pass.function_pass(opt_level=opt_level, name=pass_name)
+    @_transform.function_pass(opt_level=opt_level, name=pass_name)
     def transform(expr, ctx):
         return opt_tester.transform(expr, ctx)
 
@@ -212,7 +213,7 @@ def test_function_pass():
 
     def test_pass_registration():
         function_pass = transform
-        assert isinstance(function_pass, ir_pass.FunctionPass)
+        assert isinstance(function_pass, _transform.FunctionPass)
         pass_info = function_pass.info
         assert pass_info.name == pass_name
         assert pass_info.opt_level == opt_level
@@ -220,8 +221,8 @@ def test_function_pass():
     def test_pass_registration_no_decorator():
         def direct_transform(expr, ctx):
             return opt_tester.transform(expr, ctx)
-        mod_pass = ir_pass.function_pass(direct_transform, opt_level=0)
-        assert isinstance(mod_pass, ir_pass.FunctionPass)
+        mod_pass = _transform.function_pass(direct_transform, opt_level=0)
+        assert isinstance(mod_pass, _transform.FunctionPass)
         pass_info = mod_pass.info
         assert pass_info.name == "direct_transform"
         assert pass_info.opt_level == 0
@@ -294,14 +295,14 @@ def test_sequential_pass():
     opt_tester = OptTester(mod)
     pass_ctx = None
 
-    @ir_pass.module_pass(opt_level=1)
+    @_transform.module_pass(opt_level=1)
     def mod_transform(expr, ctx):
         return opt_tester.transform(expr, ctx)
 
     module_pass = mod_transform
 
     # Register a function pass.
-    @ir_pass.function_pass(opt_level=1)
+    @_transform.function_pass(opt_level=1)
     def func_transform(expr, ctx):
         return opt_tester.transform(expr, ctx)
 
@@ -310,25 +311,23 @@ def test_sequential_pass():
     def test_pass_registration():
         passes = [module_pass, function_pass]
         opt_level = 2
-        pass_name = "sequential_pass"
-        sequential_pass = ir_pass.sequential_pass(passes=passes,
-                                                  opt_level=opt_level)
-        assert isinstance(sequential_pass, ir_pass.SequentialPass)
-        pass_info = sequential_pass.info
+        pass_name = "sequential"
+        sequential = _transform.Sequential(passes=passes, opt_level=opt_level)
+        pass_info = sequential.info
         assert pass_info.name == pass_name
         assert pass_info.opt_level == opt_level
 
     def test_no_pass():
         passes = []
-        sequential_pass = ir_pass.sequential_pass(opt_level=1, passes=passes)
-        ret_mod = sequential_pass(mod)
+        sequential = _transform.Sequential(opt_level=1, passes=passes)
+        ret_mod = sequential(mod)
         mod_func = ret_mod[v_sub]
         check_func(sub, mod_func)
 
     def test_only_module_pass():
         passes = [module_pass]
-        sequential_pass = ir_pass.sequential_pass(opt_level=1, passes=passes)
-        ret_mod = sequential_pass(mod)
+        sequential = _transform.Sequential(opt_level=1, passes=passes)
+        ret_mod = sequential(mod)
         # Check the subtract function.
         sub_var, new_sub = extract_var_func(ret_mod, v_sub.name_hint)
         check_func(new_sub, sub)
@@ -341,8 +340,8 @@ def test_sequential_pass():
     def test_only_function_pass():
         # Check the subtract function.
         passes = [function_pass]
-        sequential_pass = ir_pass.sequential_pass(opt_level=1, passes=passes)
-        ret_mod = sequential_pass(mod)
+        sequential = _transform.Sequential(opt_level=1, passes=passes)
+        ret_mod = sequential(mod)
         _, new_sub = extract_var_func(ret_mod, v_sub.name_hint)
         check_func(new_sub, get_ref_sub())
 
@@ -355,8 +354,8 @@ def test_sequential_pass():
         # function pass.
         mod = relay.Module({v_sub: sub, v_log: log})
         passes = [module_pass, function_pass]
-        sequential_pass = ir_pass.sequential_pass(opt_level=1, passes=passes)
-        ret_mod = sequential_pass(mod)
+        sequential = _transform.Sequential(opt_level=1, passes=passes)
+        ret_mod = sequential(mod)
 
         # Check the abs function is added.
         abs_var, abs_func = get_var_func()