/*!
* \file tvm/relay/pass.h
* \brief The set of Relay passes written in C++.
- *
- * This file also implements a pass manager. The pass manager manages a sequence
- * of Relay-to-Relay transformation passes over a particlar unit of AST. The
- * design is largely inspired from LLVM's pass manager and modern deep learning
- * frameworks that perform tensor->tensor transformations.
- *
- * The responsibilities of a traditional compiler pass manager usually involves:
- * - Organizing the execution order of optimization passes though not
- * necessarily in the optimal sequence.
- * - Collecting required analysis information and keep them up-to-date.
- * - Reducing the effort required to implement new passes for compiler
- * developers, etc.
- *
- * Similar to LLVM's pass manager, we designed the Relay pass manager to work
- * different granularity, i.e. module level, function level, and even sequential
- * passe that contains a host of passes.
- *
- * However, we also extend the functionality of the traditional pass manager
- * with the consideration of requirements/convention from deep learning
- * frameworks, such as Pytorch and Gluon, etc. Each pass in the Relay pass
- * manager performs the Relay.Module -> Relay.Module transformation. All
- * different types of passes, including the sequential-level pass object, are
- * essentially pass objects. This design, therefore, effectively provides users
- * a consistent and convenient interface, i.e. Pass, to play with. It offers a
- * means to ease the development and testing of Relay passes. For example, with
- * the pass manager, external users will be able to have custom passes correctly
- * scheduled without having to modify a single handcrafted pass order.
- *
- * In the future we need to describe constraints between passes. For example,
- * we may want to preserve dependencies between different passes and validate
- * them on the completion of a certain pass.
- *
- * We also need to store side information and import the error reporting system.
- */
+ */
#ifndef TVM_RELAY_PASS_H_
#define TVM_RELAY_PASS_H_
#include <tvm/ir.h>
#include <tvm/packed_func_ext.h>
-#include <tvm/relay/error.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/module.h>
#include <tvm/relay/op_attr_types.h>
namespace tvm {
namespace relay {
-namespace pass {
-
-/*
- * \brief The context of pass.
- */
-class PassContext;
-
-/*!
- * \brief PassContextNode contains the information that a pass can rely on, such as
- * analysis results.
- */
-class PassContextNode : public RelayNode {
- public:
- /*!
- * \brief The error reporter used to notify users why an optimization fails.
- */
- ErrorReporter err_reporter;
-
- PassContextNode() = default;
-
- void VisitAttrs(tvm::AttrVisitor* v) final {
- }
-
- TVM_DLL static PassContext make();
-
- static constexpr const char* _type_key = "relay.PassContext";
- TVM_DECLARE_NODE_TYPE_INFO(PassContextNode, RelayNode);
-};
-
-TVM_DEFINE_NODE_REF(PassContext, PassContextNode)
-
-/*
- * \brief The meta data of a pass.
- *
- * PassInfo can be extended conveniently in the future if more meta information
- * is needed.
- */
-class PassInfo;
-
-/*!
- * \brief PassInfoNode contains meta data that will be used to help optimization
- * and analysis.
- */
-class PassInfoNode : public RelayNode {
- public:
- /*! \brief The minimal optimization level that this pass will be enabled. */
- int opt_level;
-
- /*! \brief The name of an optimization/analysis pass. */
- std::string name;
-
- /*! \brief The passes that are required to perform the current pass. */
- tvm::Array<tvm::Expr> required;
-
- PassInfoNode() = default;
-
- void VisitAttrs(tvm::AttrVisitor* v) final {
- v->Visit("opt_level", &opt_level);
- v->Visit("name", &name);
- v->Visit("required", &required);
- }
-
- TVM_DLL static PassInfo make(int opt_level, std::string name,
- tvm::Array<tvm::Expr> required);
-
- static constexpr const char* _type_key = "relay.PassInfo";
- TVM_DECLARE_NODE_TYPE_INFO(PassInfoNode, RelayNode);
-};
-
-TVM_DEFINE_NODE_REF(PassInfo, PassInfoNode)
-
-class Pass;
-
-/*!
- * \brief PassNode is the base type of differnt types of optimization passes.
- * It is designed as a pure class and implemented by different pass subclasses
- * at different granularity of Relay nodes.
- */
-class PassNode : public RelayNode {
- public:
- /*
- * \brief Get the pass information/meta data. */
- virtual PassInfo Info() const = 0;
-
- /*!
- * \brief Set the context information for a pass.
- *
- * \param pass_ctx The context information for a certain pass.
- */
- virtual void SetContext(const PassContext& pass_ctx) = 0;
-
- /*!
- * \brief Execute the optimization pass using a functor.
- *
- * \param mod The module that an optimization pass runs on.
- *
- * \return The updated module.
- */
- virtual Module operator()(const Module& mod) const = 0;
-
- void VisitAttrs(tvm::AttrVisitor* v) override {}
-
- static constexpr const char* _type_key = "relay.Pass";
- TVM_DECLARE_BASE_NODE_INFO(PassNode, RelayNode);
-};
-
-class Pass : public NodeRef {
- public:
- Pass() = default;
- explicit Pass(NodePtr<tvm::Node> p) : NodeRef(p) {}
-
- PassNode* operator->() const {
- return static_cast<PassNode*>(this->node_.get());
- }
-
- using ContainerType = PassNode;
-};
-
-/*
- * \brief Create a module pass.
- *
- * \param pass_func The packed function that contains the optimization.
- * \param opt_level The optimization level of the module pass.
- * \param name The name of the module pass.
- * \param required The list of the passes that the module pass is dependent on.
- *
- * \return The created module pass.
- */
-Pass CreateModulePass(
- const runtime::TypedPackedFunc<Module(Module, PassContext)>& pass_func,
- int opt_level,
- const std::string& name,
- const tvm::Array<tvm::Expr>& required);
-
-/*
- * \brief Create a function pass.
- *
- * \param pass_func The packed function that contains the optimization.
- * \param opt_level The optimization level of the function pass.
- * \param name The name of the function pass.
- * \param required The list of the passes that the function pass is dependent on.
- *
- * \return The created function pass.
- */
-Pass CreateFunctionPass(
- const runtime::TypedPackedFunc<Function(Function, PassContext)>& pass_func,
- int opt_level,
- const std::string& name,
- const tvm::Array<tvm::Expr>& required);
-/*
- * \brief Create a sequential pass.
- *
- * \param passes The optimization passes will be performed.
- * \param opt_level The optimization level of the sequential pass.
- * \param name The name of the sequential pass.
- * \param required The list of the passes that the sequential pass is dependent on.
- * \param disabled The disabled passes.
- *
- * \return The created sequential pass.
- */
-Pass CreateSequentialPass(const tvm::Array<Pass>& passes,
- int opt_level,
- const std::string& name,
- const tvm::Array<tvm::Expr>& required,
- const tvm::Array<tvm::Expr>& disabled);
-
-} // namespace pass
-
/*!
* \brief Infer the type of an expression.
*
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/relay/transform.h
+ *
+ * This file implements a pass manager. The pass manager manages a sequence
+ * of Relay-to-Relay transformation passes over a particlar unit of AST. The
+ * design is largely inspired from LLVM's pass manager and modern deep learning
+ * frameworks that perform tensor->tensor transformations.
+ *
+ * The responsibilities of a traditional compiler pass manager usually involves:
+ * - Organizing the execution order of optimization passes though not
+ * necessarily in the optimal sequence.
+ * - Collecting required analysis information and keep them up-to-date.
+ * - Reducing the effort required to implement new passes for compiler
+ * developers, etc.
+ *
+ * Similar to LLVM's pass manager, we designed the Relay pass manager to work
+ * different granularity, i.e. module level, function level, and even sequential
+ * passe that contains a host of passes.
+ *
+ * However, we also extend the functionality of the traditional pass manager
+ * with the consideration of requirements/convention from deep learning
+ * frameworks, such as Pytorch and Gluon, etc. Each pass in the Relay pass
+ * manager performs the Relay.Module -> Relay.Module transformation. All
+ * different types of passes, including the sequential-level pass object, are
+ * essentially pass objects. This design, therefore, effectively provides users
+ * a consistent and convenient interface, i.e. Pass, to play with. It offers a
+ * means to ease the development and testing of Relay passes. For example, with
+ * the pass manager, external users will be able to have custom passes correctly
+ * scheduled without having to modify a single handcrafted pass order.
+ *
+ * In the future we need to describe constraints between passes. For example,
+ * we may want to preserve dependencies between different passes and validate
+ * them on the completion of a certain pass.
+ *
+ * We also need to store side information and import the error reporting system.
+ */
+#ifndef TVM_RELAY_TRANSFORM_H_
+#define TVM_RELAY_TRANSFORM_H_
+
+#include <tvm/packed_func_ext.h>
+#include <tvm/relay/error.h>
+#include <tvm/relay/expr.h>
+#include <tvm/relay/module.h>
+#include <string>
+#include <vector>
+
+namespace tvm {
+namespace relay {
+namespace transform {
+
+/*
+ * \brief The context of pass.
+ */
+class PassContext;
+
+/*!
+ * \brief PassContextNode contains the information that a pass can rely on, such as
+ * analysis results.
+ */
+class PassContextNode : public RelayNode {
+ public:
+ /*!
+ * \brief The error reporter used to notify users why an optimization fails.
+ */
+ ErrorReporter err_reporter;
+
+ PassContextNode() = default;
+
+ void VisitAttrs(tvm::AttrVisitor* v) final {
+ }
+
+ TVM_DLL static PassContext make();
+
+ static constexpr const char* _type_key = "relay.PassContext";
+ TVM_DECLARE_NODE_TYPE_INFO(PassContextNode, RelayNode);
+};
+
+TVM_DEFINE_NODE_REF(PassContext, PassContextNode)
+
+/*
+ * \brief The meta data of a pass.
+ *
+ * PassInfo can be extended conveniently in the future if more meta information
+ * is needed.
+ */
+class PassInfo;
+
+/*!
+ * \brief PassInfoNode contains meta data that will be used to help optimization
+ * and analysis.
+ */
+class PassInfoNode : public RelayNode {
+ public:
+ /*! \brief The minimal optimization level that this pass will be enabled. */
+ int opt_level;
+
+ /*! \brief The name of an optimization/analysis pass. */
+ std::string name;
+
+ /*! \brief The passes that are required to perform the current pass. */
+ tvm::Array<tvm::Expr> required;
+
+ PassInfoNode() = default;
+
+ void VisitAttrs(tvm::AttrVisitor* v) final {
+ v->Visit("opt_level", &opt_level);
+ v->Visit("name", &name);
+ v->Visit("required", &required);
+ }
+
+ TVM_DLL static PassInfo make(int opt_level, std::string name,
+ tvm::Array<tvm::Expr> required);
+
+ static constexpr const char* _type_key = "relay.PassInfo";
+ TVM_DECLARE_NODE_TYPE_INFO(PassInfoNode, RelayNode);
+};
+
+TVM_DEFINE_NODE_REF(PassInfo, PassInfoNode)
+
+class Pass;
+
+/*!
+ * \brief PassNode is the base type of differnt types of optimization passes.
+ * It is designed as a pure class and implemented by different pass subclasses
+ * at different granularity of Relay nodes.
+ */
+class PassNode : public RelayNode {
+ public:
+ /*
+ * \brief Get the pass information/meta data. */
+ virtual PassInfo Info() const = 0;
+
+ /*!
+ * \brief Set the context information for a pass.
+ *
+ * \param pass_ctx The context information for a certain pass.
+ */
+ virtual void SetContext(const PassContext& pass_ctx) = 0;
+
+ /*!
+ * \brief Execute the optimization pass using a functor.
+ *
+ * \param mod The module that an optimization pass runs on.
+ *
+ * \return The updated module.
+ */
+ virtual Module operator()(const Module& mod) const = 0;
+
+ void VisitAttrs(tvm::AttrVisitor* v) override {}
+
+ static constexpr const char* _type_key = "relay.Pass";
+ TVM_DECLARE_BASE_NODE_INFO(PassNode, RelayNode);
+};
+
+class Pass : public NodeRef {
+ public:
+ Pass() = default;
+ explicit Pass(NodePtr<tvm::Node> p) : NodeRef(p) {}
+
+ PassNode* operator->() const {
+ return static_cast<PassNode*>(this->node_.get());
+ }
+
+ using ContainerType = PassNode;
+};
+
+class SequentialNode;
+
+class Sequential : public Pass {
+ public:
+ /*!
+ * \brief The constructor of `Sequential`.
+ * \param passes The passes to apply.
+ * \param pass_info The pass metadata.
+ * \param disabled The passes that will not be applied.
+ */
+ TVM_DLL Sequential(tvm::Array<Pass> passes,
+ PassInfo pass_info,
+ tvm::Array<tvm::Expr> disabled);
+ Sequential() = default;
+ explicit Sequential(tvm::NodePtr<::tvm::Node> n) : Pass(n) {}
+
+ const SequentialNode* operator->() const;
+ using ContainerType = Sequential;
+};
+
+
+/*
+ * \brief Create a module pass.
+ *
+ * \param pass_func The packed function that contains the optimization.
+ * \param opt_level The optimization level of the module pass.
+ * \param name The name of the module pass.
+ * \param required The list of the passes that the module pass is dependent on.
+ *
+ * \return The created module pass.
+ */
+Pass CreateModulePass(
+ const runtime::TypedPackedFunc<Module(Module, PassContext)>& pass_func,
+ int opt_level,
+ const std::string& name,
+ const tvm::Array<tvm::Expr>& required);
+
+/*
+ * \brief Create a function pass.
+ *
+ * \param pass_func The packed function that contains the optimization.
+ * \param opt_level The optimization level of the function pass.
+ * \param name The name of the function pass.
+ * \param required The list of the passes that the function pass is dependent on.
+ *
+ * \return The created function pass.
+ */
+Pass CreateFunctionPass(
+ const runtime::TypedPackedFunc<Function(Function, PassContext)>& pass_func,
+ int opt_level,
+ const std::string& name,
+ const tvm::Array<tvm::Expr>& required);
+
+} // namespace transform
+} // namespace relay
+} // namespace tvm
+
+#endif // TVM_RELAY_TRANSFORM_H_
from . import module
from . import adt
from . import ir_pass
+from . import transform
from .build_module import build, build_config, create_executor
from . import prelude
from . import parser
var = expr.var
const = expr.const
bind = expr.bind
-module_pass = ir_pass.module_pass
-function_pass = ir_pass.function_pass
-sequential_pass = ir_pass.sequential_pass
+module_pass = transform.module_pass
+function_pass = transform.function_pass
# ExprFunctor
ExprFunctor = expr_functor.ExprFunctor
load_param_dict = param_dict.load_param_dict
# Pass manager
-PassInfo = ir_pass.PassInfo
-PassContext = ir_pass.PassContext
-Pass = ir_pass.Pass
-ModulePass = ir_pass.ModulePass
-FunctionPass = ir_pass.FunctionPass
-SequentialPass = ir_pass.SequentialPass
+PassInfo = transform.PassInfo
+PassContext = transform.PassContext
+Pass = transform.Pass
+ModulePass = transform.ModulePass
+FunctionPass = transform.FunctionPass
+Sequential = transform.Sequential
import tvm
from . import ir
-from .base import NodeBase
from .env import Module
-
-class PassContext(NodeBase):
- def __init__(self):
- ...
-
-class PassInfo(NodeBase):
- name = ... # type: str
- opt_level = ... # type: int
- required = ... # type: list
-
- def __init__(self, name, opt_level, required)
- # type: (str, int, list) -> None
-
-
-class Pass(NodeBase):
- def __init__(self):
- ...
-
-
-class ModulePass(Pass):
- name = ... # type: str
- opt_level = ... # type: int
- pass_func = ... # type: Callable
- required = ... # type: list
-
- def __init__(self, name, opt_level, pass_func, required):
- # type: (str, int, Callable, list) -> None
- ...
-
-
-class FunctionPass(Pass):
- name = ... # type: str
- opt_level = ... # type: int
- pass_func = ... # type: Callable
- required = ... # type: list
-
- def __init__(self, name, opt_level, pass_func, required):
- # type: (str, int, Callable, list) -> None
- ...
-
-
-class SequentialPass(Pass):
- name = ... # type: str
- opt_level = ... # type: int
- passes = ... # type: list
- required = ... # type: list
- disabled = ... # type: list
-
- def __init__(self, name, opt_level, passes, required, disabled):
- # type: (str, int, list, list, list) -> None
- ...
-
-
def check_expr(env: Module, expr: ir.Expr) -> ir.Type: ...
def generalize(env: Module, expr: ir.Expr) -> ir.Expr: ...
def _get_checked_type(expr: ir.Expr) -> ir.Type: ...
--- /dev/null
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""FFI exposing the Relay type inference and checking."""
+
+from tvm._ffi.function import _init_api
+
+_init_api("relay._transform", __name__)
# pylint: disable=no-else-return
# pylint: disable=unidiomatic-typecheck
"""
-This file contains:
-1. The set of passes for Relay, which exposes an interface for configuring the
- passes and scripting them in Python.
-
-2. The pass manager for Relay which exposes different granularity of interfaces
- for users to implement and use passes more conveniently.
+This file contains the set of passes for Relay, which exposes an interface for
+configuring the passes and scripting them in Python.
"""
-import types
-
from . import _ir_pass
from . import _make
from .expr import Expr
from .ty import Type
-from .base import RelayNode, register_relay_node
from .module import Module
-@register_relay_node
-class PassInfo(RelayNode):
- """The class that contains the meta data required by a pass. It is the
- container of information needed by running an optimization or analysis.
- This class can be extended by adding new members when more meta data is
- needed.
-
- Parameters
- ----------
- name : str
- The pass name.
-
- opt_level : int
- The optimization level of this pass.
-
- required : List[str]
- The list of passes that are required by a certain pass.
- """
-
- def __init__(self, name, opt_level, required=None):
- self.__init_handle_by_constructor__(_ir_pass.PassInfo, name, opt_level,
- required)
-
-
-@register_relay_node
-class PassContext(RelayNode):
- """The basis where a Relay optimization/analysis runs on.
- Each pass context contains a number of auxiliary information that is used
- to help an optimization pass. Such information includes the error reporter
- to record the errors of during the optimization, etc.
- """
-
- def __init__(self):
- self.__init_handle_by_constructor__(_ir_pass.PassContext)
-
-
-@register_relay_node
-class Pass(RelayNode):
- """The base class of all passes. All methods here are just simple wrappers
- that are implemented in the backend. They are defined for users to
- conveniently interact with the base class.
- """
-
- def set_pass_context(self, pass_ctx):
- """Setup the pass context for analysis and optimizations. This context
- could be shared by different passes for sequential passes.
-
- Parameters
- ----------
- pass_ctx : PassContext
- The context that is used to help perform a certain pass or a series
- of passes.
- """
- if not isinstance(pass_ctx, PassContext):
- raise TypeError("pass_ctx is expected to be the PassContext type")
- _ir_pass.SetContext(self, pass_ctx)
-
- @property
- def info(self):
- """Get the pass meta."""
- return _ir_pass.Info(self)
-
- def __call__(self, mod):
- """Execute the pass. Note that for sequential pass, the dependency among
- different passes will be resolved in the backend.
-
- Parameters
- ----------
- mod : tvm.relay.Module
- The module that a certain optimization is performed on.
-
- Returns
- -------
- mod : tvm.relay.Module
- The updated module after applying this pass.
- """
- return _ir_pass.RunPass(self, mod)
-
-
-@register_relay_node
-class ModulePass(Pass):
- """A pass that works on tvm.relay.Module. Users don't need to interact with
- this class directly. Instead, a module pass should be created through
- `module_pass`, because the design of the `module_pass` API is flexible
- enough to handle the creation of a module pass in different manners. In
- addition, all members of a module pass can be accessed from the base class.
- The same rule applies to FunctionPass and SequentialPass as well.
- """
-
-
-@register_relay_node
-class FunctionPass(Pass):
- """A pass that works on each tvm.relay.Function in a module. A function
- pass class should be created through `function_pass`.
- """
-
-
-@register_relay_node
-class SequentialPass(Pass):
- """A pass that works on a sequence of pass objects. A sequential pass class
- should be created through `sequential_pass`.
- """
-
-
-def module_pass(pass_func=None, opt_level=None, name=None, required=None):
- """Create a module pass. This function returns a callback when pass_func
- is provided. Otherwise, it returns the created module level pass using the
- given optimization function.
-
- Parameters
- ----------
- pass_func : Optional[Callable[(Module/Function, PassContext) ->
- Module/Function]]
- The implemented optimization pass.
-
- opt_level : int
- The optimization level of this module pass.
-
- name : Optional[str]
- The name of the module pass. The name could be empty. In this case, the
- name of the optimization function will be used as the pass name.
-
- required : Optional[List[str]]
- The list of passes that the module pass is dependent on.
-
- Returns
- -------
- create_module_pass : Union[Callable, ModulePass]
- The callable that will create a module pass is returned when
- pass_func is not passed in. Otherwise, a ModulePass object will be
- directly created.
-
- Examples
- --------
- The following code creates a module level pass and adds an abs function to
- the module.
-
- .. code-block:: python
-
- @relay.ir_pass.module_pass(opt_level=2)
- def transform(mod, ctx):
- tp = relay.TensorType((10,), "float32")
- x = relay.var("x", tp)
- gv = relay.GlobalVar("var")
- func = relay.Function([x], relay.abs(x))
- new_mod = relay.Module({gv: func})
- new_mod.update(mod)
- return new_mod
-
- module_pass = transform
- assert isinstance(module_pass, ir_pass.ModulePass)
- assert module_pass.info.opt_level == 2
-
- # Given a module m, the optimization could be invoked as the follwoing:
- updated_mod = module_pass(m)
- # Now a function abs should be added to the module m.
- """
-
- if opt_level is None:
- raise ValueError("Please provide opt_level for the module pass.")
-
- required = required if required else []
- if not isinstance(required, (list, tuple)):
- raise TypeError("Required is expected to be the type of " +
- "list/tuple.")
-
- def create_module_pass(pass_func):
- """Internal function that creates a module pass"""
- if not isinstance(pass_func, (types.FunctionType, types.LambdaType)):
- raise TypeError("pass_func must be a callable for Module pass")
-
- return _ir_pass.CreateModulePass(pass_func, opt_level,
- name if name else pass_func.__name__,
- required)
-
- if pass_func:
- return create_module_pass(pass_func)
- return create_module_pass
-
-
-def function_pass(pass_func=None, opt_level=None, name=None, required=None):
- """Create a function pass. This function returns a callback when pass_func
- is provided. Otherwise, it returns the created function pass using the
- given optimization function.
-
- Parameters
- ----------
- pass_func : Optional[Callable[(Module/Function, PassContext) ->
- Module/Function]]
- The implemented optimization pass.
-
- opt_level : int
- The optimization level of this module pass.
-
- name : Optional[str]
- The name of the function pass. The name could be empty. In this case, the
- name of the optimization function will be used as the pass name.
-
- required : Optional[List[str]]
- The list of passes that the module pass is dependent on.
-
- Returns
- -------
- create_function_pass : Union[Callable, FunctionPass]
- The callable that will create a function pass is returned when
- pass_func is not passed in. Otherwise, a FunctionPass object will be
- created.
-
- Examples
- --------
- The following code creates a function level pass that performs constant
- folding.
-
- .. code-block:: python
-
- @relay.ir_pass.function_pass(opt_level=2)
- def transform(func, ctx):
- return ir_pass.fold_constant(func)
-
- function_pass = transform
- assert isinstance(function_pass, ir_pass.FunctionPass)
- assert function_pass.info.opt_level == 2
-
- # Given a module m, the optimization could be invoked as the follwoing:
- updated_mod = function_pass(m)
- # Now constant folding should have been applied to every function in
- # the provided module m. And the updated module will be returned.
- """
-
- if opt_level is None:
- raise ValueError("Please provide opt_level for the funtion pass.")
-
- required = required if required else []
- if not isinstance(required, (list, tuple)):
- raise TypeError("Required is expected to be the type of " +
- "list/tuple.")
-
- def create_function_pass(pass_func):
- """Internal function that creates a function pass"""
- if not isinstance(pass_func, (types.FunctionType, types.LambdaType)):
- raise TypeError("pass_func must be a callable for Module pass")
-
- return _ir_pass.CreateFunctionPass(pass_func, opt_level,
- name if name else pass_func.__name__,
- required)
-
- if pass_func:
- return create_function_pass(pass_func)
- return create_function_pass
-
-
-def sequential_pass(passes=None, opt_level=2, name="sequential_pass",
- required=None, disabled=None):
- """Create a sequential pass using a defined optimization function from
- Python. Some typical usage of the sequential pass are:
- 1. Users provide a list of passes for optimization.
- 2. Only an optimization level is provided so that the backend system has
- to glob all passes at this level and below to perform the optimizations.
- Note that users can also provide a series of passes that they don't want to
- apply when running a sequential pass. Pass dependency will be resolved in
- the backend as well.
-
- Parameters
- ----------
- passes : Optional[List[Pass]]
- A sequence of passes candidate for optimization.
-
- opt_level : Optional[int]
- The optimization level of this sequential pass.
-
- name : Optional[str]
- The name of the sequential pass.
-
- required : Optional[List[str]]
- The list of passes that the sequential pass is dependent on.
-
- disabled : Optional[List[str]]
- A list of disabled passes.
-
- Returns
- -------
- ret : Pass
- A sequential pass built through pass_func.
- """
-
- passes = passes if passes else []
- if not isinstance(passes, (list, tuple)):
- raise TypeError("passes must be a list of Pass objects.")
-
- disabled = disabled if disabled else []
- if not isinstance(disabled, (list, tuple)):
- raise TypeError("disabled must be a list or tuple of pass names")
-
- required = required if required else []
- if not isinstance(required, (list, tuple)):
- raise TypeError("Required is expected to be the type of list/tuple.")
-
- return _ir_pass.CreateSequentialPass(passes, opt_level, name, required,
- disabled)
-
-
def post_order_visit(expr, fvisit):
"""Recursively visit the ir in post DFS order node,
apply fvisit. Each node is guaranteed to be visited
--- /dev/null
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=no-else-return
+# pylint: disable=unidiomatic-typecheck
+"""
+This file contains the pass manager for Relay which exposes different
+granularity of interfaces for users to implement and use passes more
+conveniently.
+"""
+import types
+
+from . import _transform
+from .base import RelayNode, register_relay_node
+
+
+@register_relay_node
+class PassInfo(RelayNode):
+ """The class that contains the meta data required by a pass. It is the
+ container of information needed by running an optimization or analysis.
+ This class can be extended by adding new members when more meta data is
+ needed.
+
+ Parameters
+ ----------
+ name : str
+ The pass name.
+
+ opt_level : int
+ The optimization level of this pass.
+
+ required : List[str]
+ The list of passes that are required by a certain pass.
+ """
+
+ def __init__(self, name, opt_level, required=None):
+ self.__init_handle_by_constructor__(_transform.PassInfo, name, opt_level,
+ required)
+
+
+@register_relay_node
+class PassContext(RelayNode):
+ """The basis where a Relay optimization/analysis runs on.
+ Each pass context contains a number of auxiliary information that is used
+ to help an optimization pass. Such information includes the error reporter
+ to record the errors of during the optimization, etc.
+ """
+
+ def __init__(self):
+ self.__init_handle_by_constructor__(_transform.PassContext)
+
+
+@register_relay_node
+class Pass(RelayNode):
+ """The base class of all passes. All methods here are just simple wrappers
+ that are implemented in the backend. They are defined for users to
+ conveniently interact with the base class.
+ """
+
+ def set_pass_context(self, pass_ctx):
+ """Setup the pass context for analysis and optimizations. This context
+ could be shared by different passes for sequential passes.
+
+ Parameters
+ ----------
+ pass_ctx : PassContext
+ The context that is used to help perform a certain pass or a series
+ of passes.
+ """
+ if not isinstance(pass_ctx, PassContext):
+ raise TypeError("pass_ctx is expected to be the PassContext type")
+ _transform.SetContext(self, pass_ctx)
+
+ @property
+ def info(self):
+ """Get the pass meta."""
+ return _transform.Info(self)
+
+ def __call__(self, mod):
+ """Execute the pass. Note that for sequential pass, the dependency among
+ different passes will be resolved in the backend.
+
+ Parameters
+ ----------
+ mod : tvm.relay.Module
+ The module that a certain optimization is performed on.
+
+ Returns
+ -------
+ mod : tvm.relay.Module
+ The updated module after applying this pass.
+ """
+ return _transform.RunPass(self, mod)
+
+
+@register_relay_node
+class ModulePass(Pass):
+ """A pass that works on tvm.relay.Module. Users don't need to interact with
+ this class directly. Instead, a module pass should be created through
+ `module_pass`, because the design of the `module_pass` API is flexible
+ enough to handle the creation of a module pass in different manners. In
+ addition, all members of a module pass can be accessed from the base class.
+ The same rule applies to FunctionPass and Sequential as well.
+ """
+
+
+@register_relay_node
+class FunctionPass(Pass):
+ """A pass that works on each tvm.relay.Function in a module. A function
+ pass class should be created through `function_pass`.
+ """
+
+
+@register_relay_node
+class Sequential(Pass):
+ """A pass that works on a sequence of pass objects. Multiple passes can be
+ executed sequentially using this class.
+
+ Some typical usage of the sequential pass are:
+ 1. Users provide a list of passes for optimization.
+ 2. Only an optimization level is provided so that the backend system has
+ to glob all passes at this level and below to perform the optimizations.
+ Note that users can also provide a series of passes that they don't want to
+ apply when running a sequential pass. Pass dependency will be resolved in
+ the backend as well.
+
+ Parameters
+ ----------
+ passes : Optional[List[Pass]]
+ A sequence of passes candidate for optimization.
+
+ opt_level : Optional[int]
+ The optimization level of this sequential pass.
+
+ name : Optional[str]
+ The name of the sequential pass.
+
+ required : Optional[List[str]]
+ The list of passes that the sequential pass is dependent on.
+
+ disabled : Optional[List[str]]
+ A list of disabled passes.
+ """
+
+ def __init__(self,
+ passes=None,
+ opt_level=2,
+ name="sequential",
+ required=None,
+ disabled=None):
+ passes = passes if passes else []
+ if not isinstance(passes, (list, tuple)):
+ raise TypeError("passes must be a list of Pass objects.")
+
+ disabled = disabled if disabled else []
+ if not isinstance(disabled, (list, tuple)):
+ raise TypeError("disabled must be a list or tuple of pass names")
+
+ required = required if required else []
+ if not isinstance(required, (list, tuple)):
+ raise TypeError("Required is expected to be the type of list/tuple.")
+
+ self.__init_handle_by_constructor__(_transform.Sequential,
+ passes, opt_level, name, required,
+ disabled)
+
+
+def module_pass(pass_func=None, opt_level=None, name=None, required=None):
+ """Create a module pass. This function returns a callback when pass_func
+ is provided. Otherwise, it returns the created module level pass using the
+ given optimization function.
+
+ Parameters
+ ----------
+ pass_func : Optional[Callable[(Module/Function, PassContext) ->
+ Module/Function]]
+ The implemented optimization pass.
+
+ opt_level : int
+ The optimization level of this module pass.
+
+ name : Optional[str]
+ The name of the module pass. The name could be empty. In this case, the
+ name of the optimization function will be used as the pass name.
+
+ required : Optional[List[str]]
+ The list of passes that the module pass is dependent on.
+
+ Returns
+ -------
+ create_module_pass : Union[Callable, ModulePass]
+ The callable that will create a module pass is returned when
+ pass_func is not passed in. Otherwise, a ModulePass object will be
+ directly created.
+
+ Examples
+ --------
+ The following code creates a module level pass and adds an abs function to
+ the module.
+
+ .. code-block:: python
+
+ @relay.transform.module_pass(opt_level=2)
+ def transform(mod, ctx):
+ tp = relay.TensorType((10,), "float32")
+ x = relay.var("x", tp)
+ gv = relay.GlobalVar("var")
+ func = relay.Function([x], relay.abs(x))
+ new_mod = relay.Module({gv: func})
+ new_mod.update(mod)
+ return new_mod
+
+ module_pass = transform
+ assert isinstance(module_pass, transform.ModulePass)
+ assert module_pass.info.opt_level == 2
+
+ # Given a module m, the optimization could be invoked as the follwoing:
+ updated_mod = module_pass(m)
+ # Now a function abs should be added to the module m.
+ """
+
+ if opt_level is None:
+ raise ValueError("Please provide opt_level for the module pass.")
+
+ required = required if required else []
+ if not isinstance(required, (list, tuple)):
+ raise TypeError("Required is expected to be the type of " +
+ "list/tuple.")
+
+ def create_module_pass(pass_func):
+ """Internal function that creates a module pass"""
+ if not isinstance(pass_func, (types.FunctionType, types.LambdaType)):
+ raise TypeError("pass_func must be a callable for Module pass")
+
+ return _transform.CreateModulePass(
+ pass_func, opt_level, name if name else pass_func.__name__,
+ required)
+
+ if pass_func:
+ return create_module_pass(pass_func)
+ return create_module_pass
+
+
+def function_pass(pass_func=None, opt_level=None, name=None, required=None):
+ """Create a function pass. This function returns a callback when pass_func
+ is provided. Otherwise, it returns the created function pass using the
+ given optimization function.
+
+ Parameters
+ ----------
+ pass_func : Optional[Callable[(Module/Function, PassContext) ->
+ Module/Function]]
+ The implemented optimization pass.
+
+ opt_level : int
+ The optimization level of this module pass.
+
+ name : Optional[str]
+ The name of the function pass. The name could be empty. In this case, the
+ name of the optimization function will be used as the pass name.
+
+ required : Optional[List[str]]
+ The list of passes that the module pass is dependent on.
+
+ Returns
+ -------
+ create_function_pass : Union[Callable, FunctionPass]
+ The callable that will create a function pass is returned when
+ pass_func is not passed in. Otherwise, a FunctionPass object will be
+ created.
+
+ Examples
+ --------
+ The following code creates a function level pass that performs constant
+ folding.
+
+ .. code-block:: python
+
+ @relay.transform.function_pass(opt_level=2)
+ def transform(func, ctx):
+ return ir_pass.fold_constant(func)
+
+ function_pass = transform
+ assert isinstance(function_pass, transform.FunctionPass)
+ assert function_pass.info.opt_level == 2
+
+ # Given a module m, the optimization could be invoked as the follwoing:
+ updated_mod = function_pass(m)
+ # Now constant folding should have been applied to every function in
+ # the provided module m. And the updated module will be returned.
+ """
+
+ if opt_level is None:
+ raise ValueError("Please provide opt_level for the funtion pass.")
+
+ required = required if required else []
+ if not isinstance(required, (list, tuple)):
+ raise TypeError("Required is expected to be the type of " +
+ "list/tuple.")
+
+ def create_function_pass(pass_func):
+ """Internal function that creates a function pass"""
+ if not isinstance(pass_func, (types.FunctionType, types.LambdaType)):
+ raise TypeError("pass_func must be a callable for Module pass")
+
+ return _transform.CreateFunctionPass(
+ pass_func, opt_level, name if name else pass_func.__name__,
+ required)
+
+ if pass_func:
+ return create_function_pass(pass_func)
+ return create_function_pass
--- /dev/null
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import tvm
+from .base import NodeBase
+
+
+class PassContext(NodeBase):
+ def __init__(self):
+ ...
+
+class PassInfo(NodeBase):
+ name = ... # type: str
+ opt_level = ... # type: int
+ required = ... # type: list
+
+ def __init__(self, name, opt_level, required)
+ # type: (str, int, list) -> None
+
+
+class Pass(NodeBase):
+ def __init__(self):
+ ...
+
+
+class ModulePass(Pass):
+ name = ... # type: str
+ opt_level = ... # type: int
+ pass_func = ... # type: Callable
+ required = ... # type: list
+
+ def __init__(self, name, opt_level, pass_func, required):
+ # type: (str, int, Callable, list) -> None
+ ...
+
+
+class FunctionPass(Pass):
+ name = ... # type: str
+ opt_level = ... # type: int
+ pass_func = ... # type: Callable
+ required = ... # type: list
+
+ def __init__(self, name, opt_level, pass_func, required):
+ # type: (str, int, Callable, list) -> None
+ ...
+
+
+class Sequential(Pass):
+ name = ... # type: str
+ opt_level = ... # type: int
+ passes = ... # type: list
+ required = ... # type: list
+ disabled = ... # type: list
+
+ def __init__(self, name, opt_level, passes, required, disabled):
+ # type: (str, int, list, list, list) -> None
+ ...
* \brief Relay pass manager implementation.
*/
#include <tvm/relay/expr_functor.h>
-#include <tvm/relay/pass.h>
+#include <tvm/relay/transform.h>
namespace tvm {
namespace relay {
-namespace pass {
+namespace transform {
using tvm::IRPrinter;
RELAY_DEFINE_NODE_REF(FunctionPass, FunctionPassNode, Pass);
-class SequentialPass;
-
/*!
- * \brief The SequentialPassNode contains a set of passes that transform Relay
+ * \brief The SequentialNode contains a set of passes that transform Relay
* programs from one AST to another semantically equivalent one.
*
* One example of this level of pass is that the pass manager needs to correctly
* perform a host of optimizations with a given optimization level and disabled
* passes.
*/
-class SequentialPassNode : public PassNode {
+class SequentialNode : public PassNode {
public:
/* \brief The pass meta data.*/
PassInfo pass_info;
passes.push_back(pass);
}
- TVM_DLL static SequentialPass make(tvm::Array<Pass> passes,
- PassInfo pass_info,
- tvm::Array<tvm::Expr> disabled);
-
/*!
* \brief Resolve the pass dependency. It globs all required passes by
* a given pass and executes them.
*/
void SetContext(const PassContext& pass_ctx) final;
- static constexpr const char* _type_key = "relay.SequentialPass";
- TVM_DECLARE_NODE_TYPE_INFO(SequentialPassNode, PassNode);
+ static constexpr const char* _type_key = "relay.Sequential";
+ TVM_DECLARE_NODE_TYPE_INFO(SequentialNode, PassNode);
private:
/*!
PassContext pass_ctx_;
};
-RELAY_DEFINE_NODE_REF(SequentialPass, SequentialPassNode, Pass);
-
PassInfo PassInfoNode::make(int opt_level, std::string name,
tvm::Array<tvm::Expr> required) {
auto pass_info = make_node<PassInfoNode>();
return pval && pval->value != 0;
}
-SequentialPass SequentialPassNode::make(tvm::Array<Pass> passes,
- PassInfo pass_info,
- tvm::Array<tvm::Expr> disabled) {
- auto n = make_node<SequentialPassNode>();
+Sequential::Sequential(tvm::Array<Pass> passes,
+ PassInfo pass_info,
+ tvm::Array<tvm::Expr> disabled) {
+ auto n = make_node<SequentialNode>();
n->passes = std::move(passes);
n->pass_info = std::move(pass_info);
n->disabled = std::move(disabled);
- return SequentialPass(n);
+ node_ = std::move(n);
+}
+
+const SequentialNode* Sequential::operator->() const {
+ return static_cast<const SequentialNode*>(this->node_.get());
}
// TODO(jroesch, zhiics): we currenlty only sequentially execute each pass in
-// a SequentialPass without the consideration of their orders. The phase
+// a Sequential without the consideration of their orders. The phase
// ordering problem needed to be handled in the future.
-Module SequentialPassNode::operator()(const Module& module) const {
+Module SequentialNode::operator()(const Module& module) const {
Module mod = module;
for (const Pass& pass : passes) {
CHECK(pass.defined()) << "Found undefined pass for optimization.";
return mod;
}
-void SequentialPassNode::ResolveDependency(const Module& mod) {
+void SequentialNode::ResolveDependency(const Module& mod) {
// TODO(zhiics) Implement it.
// 1. Consider the required passes for each pass.
// 2. Only resolve the enabled passes.
<< "\n";
}
-std::vector<std::string> SequentialPassNode::DisabledPasses() const {
+std::vector<std::string> SequentialNode::DisabledPasses() const {
std::vector<std::string> ret;
for (const auto& it : disabled) {
const auto* str = it.as<tvm::ir::StringImm>();
return ret;
}
-void SequentialPassNode::SetContext(const PassContext& pass_ctx) {
+void SequentialNode::SetContext(const PassContext& pass_ctx) {
pass_ctx_ = pass_ctx;
}
return FunctionPassNode::make(pass_func, pass_info);
}
-Pass CreateSequentialPass(const tvm::Array<Pass>& passes,
- int opt_level,
- const std::string& name,
- const tvm::Array<tvm::Expr>& required,
- const tvm::Array<tvm::Expr>& disabled) {
- PassInfo pass_info = PassInfoNode::make(opt_level, name, required);
- return SequentialPassNode::make(passes, pass_info, disabled);
-}
-
TVM_REGISTER_NODE_TYPE(PassInfoNode);
-TVM_REGISTER_API("relay._ir_pass.PassInfo")
+TVM_REGISTER_API("relay._transform.PassInfo")
.set_body_typed(PassInfoNode::make);
-TVM_REGISTER_API("relay._ir_pass.Info")
+TVM_REGISTER_API("relay._transform.Info")
.set_body([](TVMArgs args, TVMRetValue* ret) {
Pass pass = args[0];
*ret = pass->Info();
TVM_REGISTER_NODE_TYPE(ModulePassNode);
-TVM_REGISTER_API("relay._ir_pass.CreateModulePass")
+TVM_REGISTER_API("relay._transform.CreateModulePass")
.set_body_typed(CreateModulePass);
-TVM_REGISTER_API("relay._ir_pass.RunPass")
+TVM_REGISTER_API("relay._transform.RunPass")
.set_body([](TVMArgs args, TVMRetValue* ret) {
Pass pass = args[0];
Module mod = args[1];
TVM_REGISTER_NODE_TYPE(FunctionPassNode);
-TVM_REGISTER_API("relay._ir_pass.CreateFunctionPass")
+TVM_REGISTER_API("relay._transform.CreateFunctionPass")
.set_body_typed(CreateFunctionPass);
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
<< " at the optimization level " << pn->opt_level;
});
-TVM_REGISTER_NODE_TYPE(SequentialPassNode);
+TVM_REGISTER_NODE_TYPE(SequentialNode);
-TVM_REGISTER_API("relay._ir_pass.CreateSequentialPass")
+TVM_REGISTER_API("relay._transform.Sequential")
.set_body([](TVMArgs args, TVMRetValue* ret) {
tvm::Array<Pass> passes = args[0];
int opt_level = args[1];
tvm::Array<tvm::Expr> required = args[3];
tvm::Array<tvm::Expr> disabled = args[4];
PassInfo pass_info = PassInfoNode::make(opt_level, name, required);
- *ret = SequentialPassNode::make(passes, pass_info, disabled);
+ *ret = Sequential(passes, pass_info, disabled);
});
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
-.set_dispatch<SequentialPassNode>([](const SequentialPassNode* node,
- tvm::IRPrinter* p) {
+.set_dispatch<SequentialNode>([](const SequentialNode* node,
+ tvm::IRPrinter* p) {
const PassInfoNode* seq_pn = node->Info().operator->();
- p->stream << "Run SequentialPass pass: " << seq_pn->name
+ p->stream << "Run Sequential pass: " << seq_pn->name
<< " at the optimization level. " << seq_pn->opt_level;
p->stream << "The passes will be executed are: [";
for (const auto& it : node->passes) {
p->stream << "]";
});
-TVM_REGISTER_API("relay._ir_pass.SetContext")
+TVM_REGISTER_API("relay._transform.SetContext")
.set_body([](TVMArgs args, TVMRetValue* ret) {
Pass pass = args[0];
PassContext pass_ctx = args[1];
TVM_REGISTER_NODE_TYPE(PassContextNode);
-TVM_REGISTER_API("relay._ir_pass.PassContext")
+TVM_REGISTER_API("relay._transform.PassContext")
.set_body_typed(PassContextNode::make);
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
<< "\n";
});
-} // namespace pass
+} // namespace transform
} // namespace relay
} // namespace tvm
from tvm.relay import ExprFunctor
from tvm.relay import Function, Call
from tvm.relay import ir_pass
+from tvm.relay import transform as _transform
from tvm.relay.testing import ctx_list
opt_tester = OptTester(mod)
pass_ctx = None
- @ir_pass.module_pass(opt_level=opt_level, name=pass_name)
+ @_transform.module_pass(opt_level=opt_level, name=pass_name)
def transform(expr, ctx):
return opt_tester.transform(expr, ctx)
def test_pass_registration():
mod_pass = transform
- assert isinstance(mod_pass, ir_pass.ModulePass)
+ assert isinstance(mod_pass, _transform.ModulePass)
pass_info = mod_pass.info
assert pass_info.name == pass_name
assert pass_info.opt_level == opt_level
def test_pass_registration_no_decorator():
def direct_transform(expr, ctx):
return opt_tester.transform(expr, ctx)
- mod_pass = ir_pass.module_pass(direct_transform, opt_level=3)
- assert isinstance(mod_pass, ir_pass.ModulePass)
+ mod_pass = _transform.module_pass(direct_transform, opt_level=3)
+ assert isinstance(mod_pass, _transform.ModulePass)
pass_info = mod_pass.info
assert pass_info.name == "direct_transform"
assert pass_info.opt_level == 3
opt_tester = OptTester(mod)
pass_ctx = None
- @ir_pass.function_pass(opt_level=opt_level, name=pass_name)
+ @_transform.function_pass(opt_level=opt_level, name=pass_name)
def transform(expr, ctx):
return opt_tester.transform(expr, ctx)
def test_pass_registration():
function_pass = transform
- assert isinstance(function_pass, ir_pass.FunctionPass)
+ assert isinstance(function_pass, _transform.FunctionPass)
pass_info = function_pass.info
assert pass_info.name == pass_name
assert pass_info.opt_level == opt_level
def test_pass_registration_no_decorator():
def direct_transform(expr, ctx):
return opt_tester.transform(expr, ctx)
- mod_pass = ir_pass.function_pass(direct_transform, opt_level=0)
- assert isinstance(mod_pass, ir_pass.FunctionPass)
+ mod_pass = _transform.function_pass(direct_transform, opt_level=0)
+ assert isinstance(mod_pass, _transform.FunctionPass)
pass_info = mod_pass.info
assert pass_info.name == "direct_transform"
assert pass_info.opt_level == 0
opt_tester = OptTester(mod)
pass_ctx = None
- @ir_pass.module_pass(opt_level=1)
+ @_transform.module_pass(opt_level=1)
def mod_transform(expr, ctx):
return opt_tester.transform(expr, ctx)
module_pass = mod_transform
# Register a function pass.
- @ir_pass.function_pass(opt_level=1)
+ @_transform.function_pass(opt_level=1)
def func_transform(expr, ctx):
return opt_tester.transform(expr, ctx)
def test_pass_registration():
passes = [module_pass, function_pass]
opt_level = 2
- pass_name = "sequential_pass"
- sequential_pass = ir_pass.sequential_pass(passes=passes,
- opt_level=opt_level)
- assert isinstance(sequential_pass, ir_pass.SequentialPass)
- pass_info = sequential_pass.info
+ pass_name = "sequential"
+ sequential = _transform.Sequential(passes=passes, opt_level=opt_level)
+ pass_info = sequential.info
assert pass_info.name == pass_name
assert pass_info.opt_level == opt_level
def test_no_pass():
passes = []
- sequential_pass = ir_pass.sequential_pass(opt_level=1, passes=passes)
- ret_mod = sequential_pass(mod)
+ sequential = _transform.Sequential(opt_level=1, passes=passes)
+ ret_mod = sequential(mod)
mod_func = ret_mod[v_sub]
check_func(sub, mod_func)
def test_only_module_pass():
passes = [module_pass]
- sequential_pass = ir_pass.sequential_pass(opt_level=1, passes=passes)
- ret_mod = sequential_pass(mod)
+ sequential = _transform.Sequential(opt_level=1, passes=passes)
+ ret_mod = sequential(mod)
# Check the subtract function.
sub_var, new_sub = extract_var_func(ret_mod, v_sub.name_hint)
check_func(new_sub, sub)
def test_only_function_pass():
# Check the subtract function.
passes = [function_pass]
- sequential_pass = ir_pass.sequential_pass(opt_level=1, passes=passes)
- ret_mod = sequential_pass(mod)
+ sequential = _transform.Sequential(opt_level=1, passes=passes)
+ ret_mod = sequential(mod)
_, new_sub = extract_var_func(ret_mod, v_sub.name_hint)
check_func(new_sub, get_ref_sub())
# function pass.
mod = relay.Module({v_sub: sub, v_log: log})
passes = [module_pass, function_pass]
- sequential_pass = ir_pass.sequential_pass(opt_level=1, passes=passes)
- ret_mod = sequential_pass(mod)
+ sequential = _transform.Sequential(opt_level=1, passes=passes)
+ ret_mod = sequential(mod)
# Check the abs function is added.
abs_var, abs_func = get_var_func()