* \param update Controls whether you can replace a definition in the
* environment.
*/
- void Add(const GlobalVar& var, const Function& func, bool update = false);
+ TVM_DLL void Add(const GlobalVar& var, const Function& func, bool update = false);
/*!
* \brief Add a type-level definition to the global environment.
* \param var The var of the global type definition.
* \param type The type definition.
*/
- void AddDef(const GlobalTypeVar& var, const TypeData& type);
+ TVM_DLL void AddDef(const GlobalTypeVar& var, const TypeData& type);
/*!
* \brief Add a function to the global environment.
*
* It does not do type inference as Add does.
*/
- void AddUnchecked(const GlobalVar& var, const Function& func);
+ TVM_DLL void AddUnchecked(const GlobalVar& var, const Function& func);
/*!
* \brief Update a function in the global environment.
* \param var The name of the global function to update.
* \param func The new function.
*/
- void Update(const GlobalVar& var, const Function& func);
+ TVM_DLL void Update(const GlobalVar& var, const Function& func);
/*!
* \brief Remove a function from the global environment.
* \param var The name of the global function to remove.
*/
- void Remove(const GlobalVar& var);
+ TVM_DLL void Remove(const GlobalVar& var);
/*!
* \brief Look up a global variable by its name.
* \param str The unique string specifying the global variable.
* \returns The global variable.
*/
- GlobalVar GetGlobalVar(const std::string& str);
+ TVM_DLL GlobalVar GetGlobalVar(const std::string& str);
/*!
* \brief Look up a global type variable by its name.
* \param str The unique string specifying the global type variable.
* \returns The global type variable.
*/
- GlobalTypeVar GetGlobalTypeVar(const std::string& str);
+ TVM_DLL GlobalTypeVar GetGlobalTypeVar(const std::string& str);
/*!
* \brief Lookup a global function by its variable.
* \param var The global var to lookup.
* \returns The function named by the variable argument.
*/
- Function Lookup(const GlobalVar& var);
+ TVM_DLL Function Lookup(const GlobalVar& var);
/*!
* \brief Lookup a global function by its string name
* \param name The name of the function.
* \returns The function named by the argument.
*/
- Function Lookup(const std::string& name);
+ TVM_DLL Function Lookup(const std::string& name);
/*!
* \brief Lookup a global type definition by its variable.
* \param var The var of the global type definition.
* \return The type definition.
*/
- TypeData LookupDef(const GlobalTypeVar& var);
+ TVM_DLL TypeData LookupDef(const GlobalTypeVar& var);
/*!
* \brief Lookup a global type definition by its name.
* \param var The name of the global type definition.
* \return The type definition.
*/
- TypeData LookupDef(const std::string& var);
+ TVM_DLL TypeData LookupDef(const std::string& var);
/*!
* \brief Update the functions inside this environment by
* functions in another environment.
* \param other The other environment.
*/
- void Update(const Module& other);
+ TVM_DLL void Update(const Module& other);
/*! \brief Construct a module from a standalone expression.
*
* \param expr The expression to set as the entry point.
* \param global_funcs Global functions that will be added to the module.
*
* \returns A module with expr set as the entry point.
*/
- static Module FromExpr(
+ TVM_DLL static Module FromExpr(
const Expr& expr,
const tvm::Map<GlobalVar, Function>& global_funcs = {});
TVM_DLL Map<Expr, Integer> CollectDeviceInfo(const Expr& expr);
/*!
+ * \brief Collect the device annotation operators.
+ *
+ * \param expr The expression.
+ *
+ * \return The mapping from annotated expressions to device types for the annotation ops.
+ */
+TVM_DLL Map<Expr, Integer> CollectDeviceAnnotationOps(const Expr& expr);
+
+/*!
* \brief Turn a dataflow graph into Administrative Normal Form, or A-Normal Form (ANF).
*
* It will turn an expression that is in a graph form (with sharing implicit),
*/
TVM_DLL Expr PartialEval(const Expr& e);
+/*!
+ * \brief Bind the free variables to a Relay expression.
+ *
+ * \param expr The expression.
+ * \param bind_map The variable-to-expression map that will be used for the
+ * binding.
+ *
+ * \return The updated expression.
+ */
+TVM_DLL Expr Bind(const Expr& expr, const tvm::Map<Var, Expr>& bind_map);
+
/*! \brief A hashing structure in the style of std::hash. */
struct StructuralHash {
/*! \brief Hash a Relay type.
#include <tvm/base.h>
#include <tvm/packed_func_ext.h>
+#include <tvm/relay/attrs/transform.h>
#include <tvm/relay/error.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/module.h>
+#include <tvm/relay/op.h>
#include <tvm/relay/op_attr_types.h>
#include <string>
#include <unordered_map>
* \param passes The passes to apply.
* \param pass_info The pass metadata.
*/
- TVM_DLL Sequential(tvm::Array<Pass> passes,
- PassInfo pass_info);
-/*!
+ TVM_DLL Sequential(tvm::Array<Pass> passes, PassInfo pass_info);
+
+ /*!
* \brief The constructor of `Sequential`.
*
* \param passes The passes to apply.
using ContainerType = Sequential;
};
-
/*
* \brief Create a module pass.
*
* \return The created function pass.
*/
TVM_DLL Pass CreateFunctionPass(const runtime::TypedPackedFunc<
- Function(Function, Module, PassContext)>& pass_func,
+ Function(Function, Module, PassContext)>& pass_func,
int opt_level,
const std::string& name,
const tvm::Array<tvm::Expr>& required);
*/
TVM_DLL Pass PartialEval();
+/*!
+ * \brief Simplify certain operators during inference. For example, batch norm
+ * will be unpacked into a number of simplified operators.
+ *
+ * \return The Pass.
+ */
+TVM_DLL Pass SimplifyInference();
+
+/*!
+ * \brief Infer the type of an expression.
+ *
+ * The result of type checking is a new expression with unambiguous
+ * type information filled in, as well as its checked type field
+ * populated with the result type.
+ *
+ * \return The pass.
+ */
+TVM_DLL Pass InferType();
+
+/*!
+ * \brief Search for and eliminate common subexpressions. For example, if there are
+ * two expressions that evaluate to an identical value, a single variable is created
+ * and these two expressions are replaced by this variable.
+ *
+ * \param fskip The callback that allows certain expressions to be skipped.
+ *
+ * \return The pass.
+ */
+TVM_DLL Pass EliminateCommonSubexpr(PackedFunc fskip = nullptr);
+
+/*!
+ * \brief Combine parallel 2d convolutions into a single convolution if the
+ * number of branches of this conv2d operator is not less than
+ * `min_num_branches`.
+ *
+ * \param min_num_branches The minimum number of branches.
+ *
+ * \return The pass.
+ */
+TVM_DLL Pass CombineParallelConv2D(uint64_t min_num_branches = 3);
+
+/*!
+ * \brief Backward fold axis scaling into weights of conv/dense operators.
+ *
+ * \return The pass.
+ */
+TVM_DLL Pass BackwardFoldScaleAxis();
+
+/*!
+ * \brief Forward fold axis scaling into weights of conv/dense operators.
+ *
+ * \return The pass.
+ */
+TVM_DLL Pass ForwardFoldScaleAxis();
+
+/*!
+ * \brief A sequential pass that executes ForwardFoldScaleAxis and
+ * BackwardFoldScaleAxis passes.
+ *
+ * \return The pass.
+ */
+TVM_DLL Pass FoldScaleAxis();
+
+/*!
+ * \brief Canonicalize some operators into simplified operators. For example,
+ * bias_add can be canonicalized to expand_dims and broadcast_add.
+ *
+ * \return The pass.
+ */
+TVM_DLL Pass CanonicalizeOps();
+
+/*!
+ * \brief Alter the layouts of operators or replace primitive operators
+ * with other expressions.
+ *
+ * \return The pass.
+ */
+TVM_DLL Pass AlterOpLayout();
+
} // namespace transform
} // namespace relay
} // namespace tvm
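The declarations above back the Python `relay.transform` bindings introduced later in this change. A minimal usage sketch, assuming those bindings; the shape and the particular pass selection are illustrative only:

from tvm import relay

# Build a tiny module; keying by "main" mirrors the test added in this change.
x = relay.var("x", shape=(1, 16, 7, 7))
mod = relay.Module({"main": relay.Function([x], relay.add(x, x))})

# Compose a few of the passes declared above and run them under a PassContext.
seq = relay.transform.Sequential([
    relay.transform.InferType(),
    relay.transform.SimplifyInference(),
    relay.transform.FoldConstant(),
])
with relay.build_config(opt_level=3):
    mod = seq(mod)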
"""
import numpy as np
-from tvm._ffi.runtime_ctypes import TVMContext
from tvm import expr as tvm_expr
from .. import nd as _nd, target as _target, autotvm
from ..contrib import graph_runtime as _graph_rt
from . import ir_pass
from . import ty as _ty
from . import expr as _expr
-from . import transform as _transform
from .backend import interpreter as _interpreter
from .backend.vm import VMExecutor
self._get_graph_json = self.mod["get_graph_json"]
self._get_module = self.mod["get_module"]
self._build = self.mod["build"]
- self._add_pass = self.mod["add_pass"]
- self._disable_pass = self.mod["disable_pass"]
- self._set_opt_level = self.mod["set_opt_level"]
- self._set_fallback_device = self.mod["set_fallback_device"]
self._set_params_func = self.mod["set_params"]
self._get_params_func = self.mod["get_params"]
"""
target = _update_target(target)
- # Setup the build configurations passed in through `with build_config`.
- self._setup_build_config(params)
+ # Setup the params.
+ if params:
+ self._set_params(params)
# Build the function
self._build(func, target, target_host)
# Get artifacts
return graph_json, mod, params
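With the per-builder setters removed, pass behavior is controlled entirely by the enclosing `relay.build_config`/PassContext. A hedged end-to-end sketch; the network, the parameter name "w", and the disabled pass are illustrative only:

import numpy as np
import tvm
from tvm import relay

x = relay.var("x", shape=(1, 4))
w = relay.var("w", shape=(3, 4))
func = relay.Function([x, w], relay.nn.dense(x, w))
params = {"w": tvm.nd.array(np.random.rand(3, 4).astype("float32"))}

# opt_level/required_pass/disabled_pass replace the removed set_opt_level,
# add_pass and disable_pass methods.
with relay.build_config(opt_level=3, disabled_pass=["AlterOpLayout"]):
    graph_json, lib, params = relay.build(func, target="llvm", params=params)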
- def _setup_build_config(self, params):
- cfg = _transform.PassContext.current()
-
- # Set opt_level.
- self.set_opt_level(cfg.opt_level)
-
- # Set fallback device if it is available.
- if cfg.fallback_device:
- self.set_fallback_device(cfg.fallback_device)
-
- # Add required passes.
- if cfg.required_pass:
- passes = set()
- if isinstance(cfg.required_pass, (list, tuple, set)):
- passes = set(cfg.required_pass)
- else:
- raise TypeError("add_pass must be list, tuple, or set, but " +
- "got {}".format(type(cfg.required_pass)))
- for pass_name in passes:
- self.add_pass(pass_name)
-
- # Add disabled passes.
- if cfg.disabled_pass:
- passes = set()
- if isinstance(cfg.disabled_pass, (list, tuple, set)):
- passes = set(cfg.disabled_pass)
- else:
- raise TypeError("disable_pass must be list, tuple, or set, " +
- "but got {}".format(type(cfg.disabled_pass)))
- for pass_name in passes:
- self.disable_pass(pass_name)
-
- if params:
- self._set_params(params)
-
def _set_params(self, params):
inputs = {}
for name, param in params.items():
inputs[name] = _expr.const(param)
self._set_params_func(inputs)
- def add_pass(self, pass_name):
- """Add a pass to the pass list.
-
- Parameters
- ----------
- pass_name : str
- The name of the pass that will be added to the list of passes used
- for optimizations.
- """
- self._add_pass(pass_name)
-
- def disable_pass(self, pass_name):
- """Add a pass to the disabled pass list.
-
- Parameters
- ----------
- pass_name : str
- The name of a pass. This pass will be added to the list of passes
- that are disabled during optimization.
- """
- self._disable_pass(pass_name)
-
def get_json(self):
"""Return the json file of the built program."""
return self._get_graph_json()
ret[key] = value.data
return ret
- def set_opt_level(self, level):
- """Set the optimization level.
-
- Parameters
- ----------
- level : int
- The optimization level for build.
- """
- self._set_opt_level(level)
-
- def set_fallback_device(self, fallback_device):
- """Set the fallback device for heterogeneous execution.
-
- Parameters
- ----------
- fallback_device : str or tvm.TVMContext
- The fallback device used for heterogeneous execution.
- """
- if isinstance(fallback_device, (int, str)):
- fallback_device = _nd.context(fallback_device)
- if not isinstance(fallback_device, TVMContext):
- raise TypeError("fallback_device is expected to be str, int, or " +
- "TVMContext but received: {}".format(type(fallback_device)))
-
- self._set_fallback_device(fallback_device.device_type)
-
def build(func, target=None, target_host=None, params=None):
"""Helper function that builds a Relay function to run on TVM graph
# under the License.
# pylint: disable=no-else-return
# pylint: disable=unidiomatic-typecheck
+# pylint: disable=invalid-name
"""
This file contains the pass manager for Relay which exposes different
granularity of interfaces for users to implement and use passes more
if pass_func:
return create_function_pass(pass_func)
return create_function_pass
+
+
+def InferType():
+ """Infer the type of an expr.
+
+ Returns
+ -------
+ ret : tvm.relay.Pass
+ The registered type inference pass.
+ """
+ return _transform.InferType()
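A small usage sketch for this pass (shapes are illustrative); type inference fills in the checked types of the module's functions:

from tvm import relay

x = relay.var("x", shape=(3, 2), dtype="float32")
mod = relay.Module({"main": relay.Function([x], relay.add(x, x))})
mod = relay.transform.InferType()(mod)
# mod["main"].body.checked_type now holds the inferred type,
# e.g. a (3, 2) float32 tensor type.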
+
+
+def FoldScaleAxis():
+ """Fold the scaling of axis into weights of conv2d/dense. This pass will
+ invoke both forward and backward scale folding.
+
+ Returns
+ -------
+ ret : tvm.relay.Pass
+ The registered pass to fold expressions.
+
+ Note
+ ----
+ Internally, we will call backward_fold_scale_axis before using
+ forward_fold_scale_axis, as backward folding targets the common conv-bn
+ pattern.
+ """
+ return _transform.FoldScaleAxis()
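A usage sketch, assuming an illustrative conv2d whose input is multiplied by a per-channel constant scale; the pass folds the scale into the weights (and runs the FoldConstant it bundles). The underlying folding passes are registered at opt_level 3, so they are enabled via build_config here:

import numpy as np
from tvm import relay

x = relay.var("x", shape=(1, 4, 8, 8))
w = relay.var("w", shape=(2, 4, 3, 3))
scale = relay.const(np.full((4, 1, 1), 2.0, dtype="float32"))
y = relay.nn.conv2d(relay.multiply(x, scale), w,
                    channels=2, kernel_size=(3, 3), padding=(1, 1))
mod = relay.Module({"main": relay.Function([x, w], y)})

with relay.build_config(opt_level=3):
    mod = relay.transform.FoldScaleAxis()(mod)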
+
+
+def SimplifyInference():
+ """Simplify the data-flow graph for inference phase. An simplified expression
+ which is semantically equal to the input expression will be returned.
+
+ Returns
+ -------
+ ret: tvm.relay.Pass
+ The registered pass to perform operator simplification.
+ """
+ return _transform.SimplifyInference()
+
+
+def CanonicalizeOps():
+ """ Canonicalize special operators to basic operators.
+ This can simplify followed analysis. (e.g. expanding bias_add to
+ expand_dims and broadcast_add.)
+
+ Returns
+ -------
+ ret: tvm.relay.Pass
+ The registered pass performing the canonicalization.
+ """
+ return _transform.CanonicalizeOps()
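A brief sketch of the canonicalization, using the bias_add example from the docstring:

from tvm import relay

x = relay.var("x", shape=(1, 16, 7, 7))
b = relay.var("b", shape=(16,))
mod = relay.Module({"main": relay.Function([x, b], relay.nn.bias_add(x, b))})
# After the pass, bias_add is expressed as expand_dims followed by a
# broadcast add (the required InferType pass runs automatically first).
mod = relay.transform.CanonicalizeOps()(mod)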
+
+
+def DeadCodeElimination():
+ """ Remove expressions which does not effect the program result (dead code).
+
+ Returns
+ -------
+ ret: tvm.relay.Pass
+ The registered pass that eliminates the dead code in a Relay program.
+ """
+ return _transform.DeadCodeElimination()
+
+
+def FoldConstant():
+ """Fold the constant expression in expr.
+
+ Returns
+ -------
+ ret : tvm.relay.Pass
+ The registered pass for constant folding.
+ """
+ return _transform.FoldConstant()
+
+
+def FuseOps(fuse_opt_level=-1):
+ """Fuse operators in an expr to a larger operator according to some rules.
+
+ Parameters
+ ----------
+ fuse_opt_level : int
+ The level of fuse optimization. -1 indicates that the level will be
+ inferred from pass context.
+
+ Returns
+ -------
+ ret : tvm.relay.Pass
+ The registered pass for operator fusion.
+ """
+ return _transform.FuseOps(fuse_opt_level)
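A usage sketch with an illustrative elementwise chain; with the default fuse_opt_level of -1 the fusion level comes from the enclosing PassContext:

from tvm import relay

x = relay.var("x", shape=(10, 20))
mod = relay.Module({"main": relay.Function([x], relay.exp(relay.add(x, relay.const(1.0))))})
with relay.build_config(opt_level=2):
    # add and exp can be fused into a single primitive function.
    mod = relay.transform.FuseOps()(mod)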
+
+
+def CombineParallelConv2D(min_num_branches=3):
+ """Combine multiple conv2d operators into one.
+
+ Parameters
+ ----------
+ min_num_branches : int
+ The minimum number of required parallel branches for performing this
+ optimization.
+
+ Returns
+ -------
+ ret: tvm.relay.Pass
+ The registered pass that combines parallel conv2d operators.
+ """
+ return _transform.CombineParallelConv2D(min_num_branches)
+
+
+def AlterOpLayout():
+ """Alternate the layouts of operators or replace primitive operators with
+ other expressions.
+ This pass can be used for computing convolution in custom layouts or
+ other general weight pre-transformation.
+
+ Returns
+ -------
+ ret : tvm.relay.Pass
+ The registered pass that alters the layout of operators.
+ """
+ return _transform.AlterOpLayout()
+
+
+def RewriteAnnotatedOps(fallback_device):
+ """Rewrite the annotated program where annotation operators, e.g.
+ `on_device`, mark which device an expression should be scheduled to.
+ This pass helps heterogeneous execution where different operators may need
+ to be allocated on various devices.
+
+ Parameters
+ ----------
+ fallback_device : int
+ The fallback device type. It is also used as the default device for
+ operators with no annotated device.
+
+ Returns
+ -------
+ ret: tvm.relay.Pass
+ The registered pass that rewrites an expression with annotated
+ `on_device` operators.
+ """
+ return _transform.RewriteDeviceAnnotation(fallback_device)
+
+
+def ToANormalForm():
+ """Turn Graph Normal Form expression into A Normal Form Expression.
+ The scope of the root expression is the global scope.
+ The scope of any non-root expression is the least common ancestor of all the scopes in which it is used.
+ Values are ordered by post-DFS order in each scope.
+
+ Returns
+ -------
+ ret: tvm.relay.Pass
+ The registered pass that transforms an expression into A Normal Form.
+ """
+ return _transform.ToANormalForm()
+
+
+def ToGraphNormalForm():
+ """Turn A Normal Form expression into Graph Normal Form expression
+
+ Returns
+ -------
+ ret : tvm.relay.Pass
+ The registered pass that transforms an expression into Graph Normal Form.
+ """
+ return _transform.ToGraphNormalForm()
+
+
+def EliminateCommonSubexpr(fskip=None):
+ """Eliminate common subexpressions.
+
+ Parameters
+ ----------
+ fskip: Callable
+ The callback function that decides whether an expression should be
+ skipped.
+
+ Returns
+ -------
+ ret : tvm.relay.Pass
+ The registered pass that eliminates common subexpressions.
+ """
+ return _transform.EliminateCommonSubexpr(fskip)
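A sketch of using the skip callback; the policy below is hypothetical and simply mirrors the int32-cast exemption used by the C++ build pipeline in this change:

from tvm import relay

x = relay.var("x", shape=(1, 16), dtype="float32")
y1 = relay.add(x, relay.const(1.0))
y2 = relay.add(x, relay.const(1.0))      # common subexpression of y1
mod = relay.Module({"main": relay.Function([x], relay.add(y1, y2))})

def fskip(expr):
    # Hypothetical policy: never deduplicate casts to int32.
    return isinstance(expr, relay.expr.Call) and \
        getattr(expr.op, "name", "") == "cast" and \
        str(getattr(expr.attrs, "dtype", "")) == "int32"

mod = relay.transform.EliminateCommonSubexpr(fskip)(mod)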
+
+
+def PartialEvaluate():
+ """Evaluate the static fragment of the code.
+
+ Returns
+ -------
+ ret : tvm.relay.Pass
+ The registered pass that performs partial evaluation on an expression.
+ """
+ return _transform.PartialEvaluate()
*/
#include <tvm/build_module.h>
#include <tvm/runtime/device_api.h>
-#include <tvm/relay/op.h>
#include <tvm/relay/expr.h>
-#include <tvm/relay/attrs/nn.h>
-#include <tvm/relay/attrs/transform.h>
-#include <vector>
-#include <string>
+#include <tvm/relay/transform.h>
#include <memory>
#include "utils.h"
namespace backend {
using TargetsMap = Map<tvm::Integer, tvm::Target>;
-
-/*!
- * \brief A data structure to map the names of specific optimizations to
- * numeric optimization levels
- *
- */
-struct OptPassLevel {
- static const std::unordered_map<std::string, int> _data;
- /*!
- * \brief Get level for an optimization pass
- *
- * \param key pass name
- * \return int level
- */
- int operator[](const std::string& key) const {
- auto it = _data.find(key);
- if (it == _data.end()) {
- return -1;
- }
- return it->second;
- }
-};
-
-const std::unordered_map<std::string, int> OptPassLevel::_data = {
- {"SimplifyInference", 0},
- {"OpFusion", 1},
- {"FoldConstant", 2},
- {"CombineParallelConv2D", 4},
- {"FoldScaleAxis", 3},
- {"AlterOpLayout", 3},
- {"CanonicalizeOps", 3},
- {"EliminateCommonSubexpr", 3}
-};
+using namespace tvm::relay::transform;
/*!
* \brief Output of building module
};
/*!
- * \brief Relay building config
- *
- */
-struct RelayBuildConfig {
- int opt_level{2};
- int fallback_device{static_cast<int>(kDLCPU)};
- std::unordered_set<std::string> enabled_pass;
- std::unordered_set<std::string> disabled_pass;
- OptPassLevel OPT_PASS_LEVEL;
- inline bool pass_enabled(const std::string& pass_name) const {
- if (disabled_pass.count(pass_name)) {
- return false;
- }
- if (enabled_pass.count(pass_name)) {
- return true;
- }
- return opt_level >= OPT_PASS_LEVEL[pass_name];
- }
-};
-
-/*!
* \brief GraphCodegen module wrapper
*
*/
}
};
-template<typename R, typename ...Args>
-R CallPackedFunc(const std::string &name, Args... args) {
- auto pf = GetPackedFunc(name);
- return (*pf)(std::forward<Args>(args)...);
-}
-
-template<typename ...Args>
-Function CallPackedFunc(const std::string &name, Args... args) {
- auto pf = GetPackedFunc(name);
- return (*pf)(std::forward<Args>(args)...);
-}
-
/*!
* \brief Relay build module
*
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
*rv = this->GetParams();
});
- } else if (name == "set_opt_level") {
- return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
- CHECK_EQ(args.num_args, 1);
- int level = args[0];
- this->SetOptLevel(level);
- });
- } else if (name == "set_fallback_device") {
- return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
- CHECK_EQ(args.num_args, 1);
- int dev = args[0];
- this->SetFallBackDev(dev);
- });
- } else if (name == "add_pass") {
- return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
- std::string pass_name = args[0];
- this->AddPass(pass_name);
- });
- } else if (name == "disable_pass") {
- return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
- std::string pass_name = args[0];
- this->DisablePass(pass_name);
- });
} else if (name == "set_params") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
Map<std::string, Constant> params = args[0];
const std::string& GetGraphJSON() {
return ret_.graph_json;
}
- /*!
- * \brief Add extra pass into build cfg
- *
- * \param pass_name name of pass
- */
- void AddPass(const std::string& pass_name) {
- cfg_.enabled_pass.insert(pass_name);
- }
- /*!
- * \brief Disable a specific pass in cfg
- *
- * \param pass_name name of pass
- */
- void DisablePass(const std::string& pass_name) {
- cfg_.disabled_pass.insert(pass_name);
- }
- /*!
- * \brief Set the Fallback device
- *
- * \param device name
- */
- void SetFallBackDev(int dev) {
- cfg_.fallback_device = dev;
- }
+
/*!
* \brief Get the Module object
*
}
/*!
- * \brief Set the optimization level
- *
- * \param level
- */
- void SetOptLevel(char level) {
- cfg_.opt_level = level;
- }
-
- /*!
* \brief type key
*
* \return const char*
const tvm::Target& target_host) {
targets_ = targets;
target_host_ = target_host;
- BuildRelay(func, cfg_, params_);
+ BuildRelay(func, params_);
}
protected:
if (repeat_var.count(arg)) {
LOG(FATAL) << "Multiple args in the function have name " << kv.first;
}
- auto e = CallPackedFunc<Expr>("relay._make.Constant", kv.second);
- bind_dict[arg] = e;
+ bind_dict[arg] = ConstantNode::make(kv.second);
}
- return CallPackedFunc("relay._expr.Bind", func, tvm::Map<relay::Var, Expr>(bind_dict));
+ Expr bound_expr = relay::Bind(func, bind_dict);
+ Function ret = Downcast<Function>(bound_expr);
+ CHECK(ret.defined())
+ << "The returning type is expected to be a Relay Function."
+ << "\n";
+ return ret;
}
/*!
- * \brief Optimize Relay function
+ * \brief Optimize a Relay module.
*
- * \param func Input function
- * \param target target device
- * \param cfg Relay build config
- * \param params params dict
- * \return relay::Function
+ * \param relay_module The input Relay module on which the optimizations will be
+ * applied.
+ * \param targets The device type to `Target` mapping.
+ * \param params The param name to value mapping.
+ *
+ * \return relay::Module The updated Relay module after optimization.
*/
- relay::Function Optimize(relay::Function func,
- const TargetsMap& targets,
- const RelayBuildConfig& cfg,
- const std::unordered_map<std::string, runtime::NDArray>& params) {
- if (params.size()) {
- func = BindParamsByName(func, params);
- }
- if (cfg.pass_enabled("SimplifyInference")) {
- func = CallPackedFunc("relay._ir_pass.infer_type", func, nullptr);
- func = CallPackedFunc("relay._ir_pass.simplify_inference", func);
- }
- if (cfg.pass_enabled("EliminateCommonSubexpr")) {
- auto fskip = PackedFunc([](TVMArgs args, TVMRetValue* rv) {
- Expr expr = args[0];
- if (expr.as<CallNode>()) {
- auto call_node = expr.as<CallNode>();
- auto op_node = call_node->op.as<OpNode>();
- if (op_node->name == "cast") {
- auto attrs = call_node->attrs.as<CastAttrs>();
- if (attrs->dtype == HalideIR::Int(32)) {
- *rv = true;
- }
+ relay::Module Optimize(
+ relay::Module relay_module,
+ const TargetsMap& targets,
+ const std::unordered_map<std::string, runtime::NDArray>& params) {
+ Array<Pass> pass_seqs;
+ pass_seqs.push_back(transform::SimplifyInference());
+ PackedFunc fskip = PackedFunc([](TVMArgs args, TVMRetValue* rv) {
+ Expr expr = args[0];
+ if (expr.as<CallNode>()) {
+ auto call_node = expr.as<CallNode>();
+ auto op_node = call_node->op.as<OpNode>();
+ if (op_node->name == "cast") {
+ auto attrs = call_node->attrs.as<CastAttrs>();
+ if (attrs->dtype == HalideIR::Int(32)) {
+ *rv = true;
}
}
- *rv = false;
- });
- func = CallPackedFunc("relay._ir_pass.infer_type", func, nullptr);
- func = CallPackedFunc("relay._ir_pass.eliminate_common_subexpr", func, fskip);
- }
- if (cfg.pass_enabled("CombineParallelConv2D")) {
- const int min_num_branches = 3;
- func = CallPackedFunc("relay._ir_pass.infer_type", func, nullptr);
- func = CallPackedFunc("relay._ir_pass.CombineParallelConv2D", func, min_num_branches);
- }
- if (cfg.pass_enabled("FoldConstant")) {
- func = CallPackedFunc("relay._ir_pass.FoldConstant", func);
- }
- if (cfg.pass_enabled("FoldScaleAxis")) {
- func = CallPackedFunc("relay._ir_pass.infer_type", func, nullptr);
- func = CallPackedFunc("relay._ir_pass.backward_fold_scale_axis", func);
- func = CallPackedFunc("relay._ir_pass.infer_type", func, nullptr);
- func = CallPackedFunc("relay._ir_pass.forward_fold_scale_axis", func);
- func = CallPackedFunc("relay._ir_pass.FoldConstant", func);
- }
- if (cfg.pass_enabled("CanonicalizeOps")) {
- func = CallPackedFunc("relay._ir_pass.infer_type", func, nullptr);
- func = CallPackedFunc("relay._ir_pass.canonicalize_ops", func);
+ }
+ *rv = false;
+ });
+ pass_seqs.push_back(transform::EliminateCommonSubexpr(fskip));
+ pass_seqs.push_back(transform::CombineParallelConv2D(3));
+ pass_seqs.push_back(transform::FoldConstant());
+ pass_seqs.push_back(transform::FoldScaleAxis());
+ pass_seqs.push_back(transform::CanonicalizeOps());
+
+ // Alter layout transformation is only applied to homogeneous execution for now.
+ if (targets.size() == 1) {
+ pass_seqs.push_back(transform::AlterOpLayout());
}
- if (cfg.pass_enabled("AlterOpLayout")) {
- if (targets.size() == 1) {
- func = CallPackedFunc("relay._ir_pass.infer_type", func, nullptr);
- for (const auto& kv : targets) {
- With<Target> tctx(kv.second);
- func = CallPackedFunc("relay._ir_pass.AlterOpLayout", func);
- }
- } else {
- LOG(WARNING) << "AlterOpLayout pass is not enabled for heterogeneous"
- << " execution yet.";
+ pass_seqs.push_back(transform::FoldConstant());
+
+ // Create a sequential pass and perform optimizations.
+ transform::Pass seq = transform::Sequential(pass_seqs);
+ if (targets.size() == 1) {
+ for (const auto& kv : targets) {
+ With<Target> tctx(kv.second);
+ relay_module = seq(relay_module);
}
+ } else {
+ relay_module = seq(relay_module);
}
- if (cfg.pass_enabled("FoldConstant")) {
- func = CallPackedFunc("relay._ir_pass.FoldConstant", func);
+
+ // Handle heterogeneous compilation.
+ transform::PassContext pass_ctx = PassContext::Current();
+ if (targets_.size() > 1) {
+ relay_module =
+ RunDeviceAnnotationPass(relay_module, pass_ctx->fallback_device);
}
- return func;
+
+ // Fuse the operations if needed.
+ relay_module = transform::FuseOps()(relay_module);
+ relay_module = transform::InferType()(relay_module);
+
+ return relay_module;
}
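For reference, roughly the same pipeline expressed through the Python bindings; this is a sketch only, since the C++ version above additionally supplies the int32-cast fskip callback and appends AlterOpLayout only when a single target is given. Here `mod` is assumed to be an existing relay.Module:

import tvm
from tvm import relay

passes = [
    relay.transform.SimplifyInference(),
    relay.transform.EliminateCommonSubexpr(),
    relay.transform.CombineParallelConv2D(min_num_branches=3),
    relay.transform.FoldConstant(),
    relay.transform.FoldScaleAxis(),
    relay.transform.CanonicalizeOps(),
    relay.transform.AlterOpLayout(),
    relay.transform.FoldConstant(),
]
with relay.build_config(opt_level=3):
    with tvm.target.create("llvm"):
        mod = relay.transform.Sequential(passes)(mod)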
/*!
if (name == "gpu") return Target::Create("cuda");
return Target::Create(name);
}
+
/*!
* \brief Update the target and fallback device required for heterogeneous
* compilation. CPU is used as the fallback device if it wasn't provided.
* Meanwhile, a CPU device type and "llvm" pair will be added to the target
* dictionary in this case.
*
- * \param targets dictionary
- * \param cfg
- * \return Map<tvm::Integer, tvm::Target>
+ * \param fallback_device The fallback device for heterogeneous execution.
*/
- TargetsMap UpdateHeterogeneousInputs(const TargetsMap& targets,
- const RelayBuildConfig& cfg) {
- TargetsMap device_target = targets;
+ void UpdateHeterogeneousInputs(int fallback_device) {
std::unordered_map<int64_t, tvm::Target> tmp_map;
- for (const auto& kv : targets) {
+ for (const auto& kv : targets_) {
tmp_map[kv.first->value] = kv.second;
}
- if (tmp_map.count(cfg.fallback_device) == 0) {
- device_target.Set(
- cfg.fallback_device,
- CreateDefaultTarget(cfg.fallback_device));
+ if (tmp_map.count(fallback_device) == 0) {
+ targets_.Set(fallback_device, CreateDefaultTarget(fallback_device));
}
- return device_target;
}
+
/*!
* \brief Execute the device annotation passes to update the input program and
* target information.
*
- * \param func
- * \param cfg
- * \param targets_map_ptr
- * \return Function
+ * \param relay_module The input Relay module.
+ * \param fallback_device The fallback device for heterogeneous execution.
+ *
+ * \return updated_module The updated module after device annotation.
*/
- Function RunDeviceAnnotationPass(Function func,
- const RelayBuildConfig& cfg,
- TargetsMap* targets_map_ptr) {
- func = CallPackedFunc("relay._ir_pass.infer_type", func, nullptr);
- func = CallPackedFunc("relay._ir_pass.RewriteDeviceAnnotation", func,
- cfg.fallback_device);
- auto device_map = CallPackedFunc<Map<Expr, Integer> >(
- "relay._ir_pass.CollectDeviceInfo", func, nullptr);
- if (device_map.size() == 0) {
- auto annotation_map = CallPackedFunc<Map<Expr, Integer> >(
- "relay._ir_pass.CollectDeviceAnnotationOps", func, nullptr);
- if (annotation_map.size() == 0) {
- targets_map_ptr->Set(
- 0, CreateDefaultTarget(cfg.fallback_device));
+ relay::Module RunDeviceAnnotationPass(const relay::Module& relay_module,
+ int fallback_device) {
+ UpdateHeterogeneousInputs(fallback_device);
+ auto rewrite = transform::RewriteAnnotatedOps(fallback_device);
+ auto updated_module = rewrite(relay_module);
+ CHECK(updated_module.defined());
+
+ tvm::Map<Expr, Integer> device_map;
+ for (const auto& it : updated_module->functions) {
+ device_map = relay::CollectDeviceInfo(it.second);
+ if (!device_map.empty()) break;
+ }
+
+ if (device_map.empty()) {
+ tvm::Map<Expr, Integer> annotation_map;
+ for (const auto& it : relay_module->functions) {
+ annotation_map = relay::CollectDeviceAnnotationOps(it.second);
+ if (!annotation_map.empty()) break;
+ }
+ // No op is annotated, so all ops fall back to the default device.
+ if (annotation_map.empty()) {
+ targets_.Set(0, CreateDefaultTarget(fallback_device));
} else {
+ // All ops are annotated to the same device type.
int64_t dev_type = -1;
for (auto kv : annotation_map) {
dev_type = kv.second->value;
<< "found. Please check the "
<< "RewriteAnnotation pass.";
}
- targets_map_ptr->Set(0, CreateDefaultTarget(dev_type));
+ targets_.Set(0, CreateDefaultTarget(dev_type));
}
}
- return func;
+ return updated_module;
}
/*!
* \brief Build relay function to runtime module
*
* \param func Relay Function
- * \param cfg Relay build config
* \param params parameters
*/
- void BuildRelay(Function func,
- const RelayBuildConfig& cfg,
- const std::unordered_map<std::string, tvm::runtime::NDArray> ¶ms) {
- // convert
- tvm_cfg_ = BuildConfig::Create();
- TargetsMap device_target;
- if (targets_.size() > 1) {
- device_target = UpdateHeterogeneousInputs(targets_, cfg);
- } else {
- device_target = targets_;
- }
- func = Optimize(func, targets_, cfg, params);
- if (device_target.size() > 1) {
- func = RunDeviceAnnotationPass(func, cfg, &device_target);
+ void BuildRelay(
+ Function func,
+ const std::unordered_map<std::string, tvm::runtime::NDArray>& params) {
+ if (params.size()) {
+ func = BindParamsByName(func, params);
}
- // TODO(@jroesch): use the passes directly.
- func = CallPackedFunc("relay._ir_pass.infer_type", func, nullptr);
- func = CallPackedFunc("relay._ir_pass.FuseOps", func, cfg.opt_level, nullptr);
- func = CallPackedFunc("relay._ir_pass.infer_type", func, nullptr);
+ // Perform Module->Module optimizations.
+ relay::Module relay_module = relay::ModuleNode::FromExpr(func);
+ relay_module = Optimize(relay_module, targets_, params);
+ CHECK(relay_module.defined());
+ // Get the updated function.
+ func = relay_module->Lookup(relay_module->entry_func->name_hint);
+
+ // Generate code for the updated function.
graph_codegen_ = std::unique_ptr<GraphCodegen>(new GraphCodegen());
- graph_codegen_->Init(nullptr, device_target);
+ graph_codegen_->Init(nullptr, targets_);
graph_codegen_->Codegen(func);
ret_.graph_json = graph_codegen_->GetJSON();
ret_.params = graph_codegen_->GetParams();
- ret_.mod = tvm::build(graph_codegen_->GetLoweredFunc(), target_host_, tvm_cfg_);
+ ret_.mod = tvm::build(graph_codegen_->GetLoweredFunc(), target_host_,
+ BuildConfig::Current());
}
protected:
TargetsMap targets_;
/*! \brief target host device */
tvm::Target target_host_;
- /*! \brief frontend optimization configure */
- RelayBuildConfig cfg_;
/*! \brief parameters */
std::unordered_map<std::string, runtime::NDArray> params_;
/*! \brief building output */
BuildOutput ret_;
- /*! \brief tvm building cfg */
- BuildConfig tvm_cfg_;
};
runtime::Module RelayBuildCreate() {
#include <tvm/relay/pass.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/attrs/transform.h>
+#include <tvm/relay/transform.h>
#include <tvm/tvm.h>
#include <tuple>
#include <vector>
// Limitations:
// 1. the altered op should have the same number of arguments as the previous one
// 2. do not support nested tuple arguments
-TVM_REGISTER_API("relay._ir_pass.AlterOpLayout")
-.set_body([](TVMArgs args, TVMRetValue *ret) {
+Expr AlterOpLayout(const Expr& expr) {
TransformMemorizer transformMemorizer(make_node<TransformMemorizerNode>());
auto fcontext = [&](const Call& call) -> NodeRef{
return transformMemorizer;
};
- *ret = ForwardRewrite(args[0], AlterOpLayoutRewrite, fcontext);
-});
+ return ForwardRewrite(expr, AlterOpLayoutRewrite, fcontext);
+}
+
+TVM_REGISTER_API("relay._ir_pass.AlterOpLayout")
+.set_body_typed(AlterOpLayout);
} // namespace alter_op_layout
+namespace transform {
+
+Pass AlterOpLayout() {
+ runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
+ [=](Function f, Module m, PassContext pc) {
+ return Downcast<Function>(relay::alter_op_layout::AlterOpLayout(f));
+ };
+ return CreateFunctionPass(pass_func, 3, "AlterOpLayout",
+ {ir::StringImm::make("InferType")});
+}
+
+TVM_REGISTER_API("relay._transform.AlterOpLayout")
+.set_body_typed(AlterOpLayout);
+
+} // namespace transform
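A short usage sketch for the new pass object (`mod` is assumed to be an existing relay.Module). AlterOpLayout consults the current target, so it is typically invoked inside a target scope, and its declared InferType requirement is satisfied automatically:

import tvm
from tvm import relay

with tvm.target.create("llvm"):
    mod = relay.transform.AlterOpLayout()(mod)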
+
} // namespace relay
} // namespace tvm
#include <tvm/relay/pass.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/attrs/nn.h>
+#include <tvm/relay/transform.h>
#include "pattern_util.h"
namespace tvm {
TVM_REGISTER_API("relay._ir_pass.canonicalize_ops")
.set_body_typed(CanonicalizeOps);
+namespace transform {
+
+Pass CanonicalizeOps() {
+ runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
+ [=](Function f, Module m, PassContext pc) {
+ return Downcast<Function>(CanonicalizeOps(f));
+ };
+ return CreateFunctionPass(pass_func, 3, "CanonicalizeOps",
+ {ir::StringImm::make("InferType")});
+}
+
+TVM_REGISTER_API("relay._transform.CanonicalizeOps")
+.set_body_typed(CanonicalizeOps);
+
+} // namespace transform
+
} // namespace relay
} // namespace tvm
#include <tvm/relay/attrs/nn.h>
#include <tvm/relay/attrs/transform.h>
#include <tvm/relay/op_attr_types.h>
+#include <tvm/relay/transform.h>
#include <unordered_map>
#include <unordered_set>
#include "./expr_subst.h"
TVM_REGISTER_API("relay._ir_pass.CombineParallelConv2D")
.set_body_typed(CombineParallelConv2D);
+namespace transform {
+
+Pass CombineParallelConv2D(uint64_t min_num_branches) {
+ runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
+ [=](Function f, Module m, PassContext pc) {
+ return Downcast<Function>(CombineParallelConv2D(f, min_num_branches));
+ };
+ return CreateFunctionPass(pass_func, 4, "CombineParallelConv2D",
+ {ir::StringImm::make("InferType")});
+}
+
+TVM_REGISTER_API("relay._transform.CombineParallelConv2D")
+.set_body_typed(CombineParallelConv2D);
+
+} // namespace transform
+
} // namespace relay
} // namespace tvm
[=](Function f, Module m, PassContext pc) {
return Downcast<Function>(DeadCodeElimination(f));
};
- return CreateFunctionPass(pass_func, 1, "dead_code_elimination", {});
+ return CreateFunctionPass(pass_func, 1, "DeadCodeElimination", {});
}
+TVM_REGISTER_API("relay._transform.DeadCodeElimination")
+.set_body_typed(DeadCodeElimination);
+
} // namespace transform
} // namespace relay
#include <tvm/relay/expr.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/pass.h>
+#include <tvm/relay/transform.h>
#include <memory>
#include <unordered_map>
[=](Function f, Module m, PassContext pc) {
return Downcast<Function>(RewriteAnnotatedOps(f, fallback_device));
};
- return CreateFunctionPass(pass_func, 1, "rewrite_annotated_ops", {});
+ return CreateFunctionPass(pass_func, 1, "RewriteAnnotatedOps",
+ {ir::StringImm::make("InferType")});
}
+TVM_REGISTER_API("relay._transform.RewriteDeviceAnnotation")
+.set_body_typed(RewriteAnnotatedOps);
+
} // namespace transform
} // namespace relay
} // namespace tvm
-
*/
#include <tvm/relay/pass.h>
#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
#include <unordered_map>
#include "./pattern_util.h"
TVM_REGISTER_API("relay._ir_pass.eliminate_common_subexpr")
.set_body_typed<Expr(Expr, PackedFunc)>(EliminateCommonSubexpr);
+namespace transform {
+
+Pass EliminateCommonSubexpr(PackedFunc fskip) {
+ runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
+ [=](Function f, Module m, PassContext pc) {
+ return Downcast<Function>(EliminateCommonSubexpr(f, fskip));
+ };
+ return CreateFunctionPass(pass_func, 3, "EliminateCommonSubexpr",
+ {ir::StringImm::make("InferType")});
+}
+
+TVM_REGISTER_API("relay._transform.EliminateCommonSubexpr")
+.set_body_typed(EliminateCommonSubexpr);
+
+} // namespace transform
+
} // namespace relay
} // namespace tvm
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/interpreter.h>
#include <tvm/relay/attrs/transform.h>
+#include <tvm/relay/transform.h>
namespace tvm {
namespace relay {
Pass FoldConstant() {
runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
[=](Function f, Module m, PassContext pc) {
- return Downcast<Function>(FoldConstant(f));
+ return Downcast<Function>(FoldConstant(f));
};
- return CreateFunctionPass(pass_func, 1, "fold_constant", {});
+ return CreateFunctionPass(pass_func, 2, "FoldConstant", {});
}
+TVM_REGISTER_API("relay._transform.FoldConstant")
+.set_body_typed(FoldConstant);
+
} // namespace transform
} // namespace relay
#include <tvm/relay/pass.h>
#include <tvm/relay/attrs/nn.h>
#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
#include "pattern_util.h"
#include "pass_util.h"
.set_attr<FForwardRewrite>("FScaleAxisForwardRewrite", Conv2DForwardRewrite);
-Expr ForwardFoldScaleAxis(Expr data) {
+Expr ForwardFoldScaleAxis(const Expr& data) {
auto message = ForwardPrep().Prepare(data);
auto fcontext = [&](const Call& call) -> NodeRef{
auto it = message.find(call.get());
RELAY_REGISTER_OP("nn.conv2d")
.set_attr<FBackwardTransform>("FScaleAxisBackwardTransform", Conv2DBackwardTransform);
-Expr BackwardFoldScaleAxis(Expr data) {
+Expr BackwardFoldScaleAxis(const Expr& data) {
return make_node<BackwardTransformerNode>()->Fold(data);
}
.set_body_typed<Expr(Expr)>(BackwardFoldScaleAxis);
} // namespace fold_scale_axis
+
+namespace transform {
+
+Pass ForwardFoldScaleAxis() {
+ runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
+ [=](Function f, Module m, PassContext pc) {
+ return Downcast<Function>(
+ relay::fold_scale_axis::ForwardFoldScaleAxis(f));
+ };
+ return CreateFunctionPass(pass_func, 3, "ForwardFoldScaleAxis",
+ {ir::StringImm::make("InferType")});
+}
+
+Pass BackwardFoldScaleAxis() {
+ runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
+ [=](Function f, Module m, PassContext pc) {
+ return Downcast<Function>(
+ relay::fold_scale_axis::BackwardFoldScaleAxis(f));
+ };
+ return CreateFunctionPass(pass_func, 3, "BackwardFoldScaleAxis",
+ {ir::StringImm::make("InferType")});
+}
+
+Pass FoldScaleAxis() {
+ // FoldScaleAxis pass contains the following three passes. Therefore, we can
+ // register it as a sequential pass.
+ Pass pass = Sequential(
+ {BackwardFoldScaleAxis(), ForwardFoldScaleAxis(), FoldConstant()},
+ "FoldScaleAxis");
+ return pass;
+}
+
+TVM_REGISTER_API("relay._transform.FoldScaleAxis")
+.set_body_typed(FoldScaleAxis);
+
+} // namespace transform
+
} // namespace relay
} // namespace tvm
fcontext,
fmulti_ref_trigger));
};
- return CreateFunctionPass(pass_func, 1, "forward_rewrite", {});
+ return CreateFunctionPass(pass_func, 1, "ForwardRewrite", {});
}
Pass ForwardRewrite(const FForwardRewrite& rewrite_func,
fcontext,
fmulti_ref_trigger));
};
- return CreateFunctionPass(pass_func, 1, "forward_rewrite", {});
+ return CreateFunctionPass(pass_func, 1, "ForwardRewriteFunc", {});
}
} // namespace transform
#include <tvm/relay/pass.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/op_attr_types.h>
+#include <tvm/relay/transform.h>
#include "./pattern_util.h"
#include "../../common/arena.h"
int opt_level = fuse_opt_level == -1 ? pc->opt_level : fuse_opt_level;
return Downcast<Function>(FuseOps(f, opt_level, m));
};
- return CreateFunctionPass(pass_func, 1, "fuse_ops", {});
+ return CreateFunctionPass(pass_func, 1, "FuseOps",
+ {ir::StringImm::make("InferType")});
}
+TVM_REGISTER_API("relay._transform.FuseOps")
+.set_body_typed(FuseOps);
+
} // namespace transform
} // namespace relay
}
TVM_REGISTER_API("relay._ir_pass.partial_evaluate")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
- *ret = PartialEval(args[0]);
- });
+.set_body_typed(PartialEval);
namespace transform {
[=](Function f, Module m, PassContext pc) {
return Downcast<Function>(PartialEval(f));
};
- return CreateFunctionPass(pass_func, 1, "partial_eval", {});
+ return CreateFunctionPass(pass_func, 1, "PartialEvaluate", {});
}
+TVM_REGISTER_API("relay._transform.PartialEvaluate")
+.set_body_typed(PartialEval);
+
} // namespace transform
} // namespace relay
using tvm::IRPrinter;
-/*!
- * \brief A data structure to map the names of specific optimizations to
- * numeric optimization levels
- */
-class OptPassLevel {
- public:
- /*!
- * \brief Get level for an optimization pass
- *
- * \param key pass name
- * \return int level
- */
- int operator[](const std::string& key) const {
- const auto data = CreateMap();
- auto it = data.find(key);
- if (it == data.end()) {
- return -1;
- }
- return it->second;
+namespace {
+
+// TODO(zhiics) Maybe we can use PackedFunc here so that parameters can be
+// handled because we need to register the pass for Python invocation anyway.
+Pass GetPass(const std::string& pass_name) {
+ if (pass_name == "InferType") {
+ return InferType();
+ } else if (pass_name == "AlterOpLayout") {
+ return AlterOpLayout();
+ } else if (pass_name == "CanonicalizeOps") {
+ return CanonicalizeOps();
+ } else if (pass_name == "CombineParallelConv2d") {
+ return CombineParallelConv2D();
+ } else if (pass_name == "DeadCodeElimination") {
+ return DeadCodeElimination();
+ } else if (pass_name == "EliminateCommonSubexpr") {
+ return EliminateCommonSubexpr();
+ } else if (pass_name == "FoldConstant") {
+ return FoldConstant();
+ } else if (pass_name == "BackwardFoldScaleAxis") {
+ return BackwardFoldScaleAxis();
+ } else if (pass_name == "ForwardFoldScaleAxis") {
+ return ForwardFoldScaleAxis();
+ } else if (pass_name == "FoldScaleAxis") {
+ return FoldScaleAxis();
+ } else if (pass_name == "PartialEvaluate") {
+ return PartialEval();
+ } else if (pass_name == "SimplifyInference") {
+ return SimplifyInference();
+ } else if (pass_name == "ToANormalForm") {
+ return ToANormalForm();
+ } else if (pass_name == "ToGraphNormalForm") {
+ return ToGraphNormalForm();
+ } else {
+ LOG(FATAL) << pass_name << " has not been registered yet." << "\n";
+ return Pass(nullptr);
}
+}
- private:
- static const std::unordered_map<std::string, int> CreateMap() {
- const std::unordered_map<std::string, int> m = {
- {"SimplifyInference", 0},
- {"OpFusion", 1},
- {"FoldConstant", 2},
- {"CombineParallelConv2D", 3},
- {"FoldScaleAxis", 3},
- {"AlterOpLayout", 3},
- {"CanonicalizeOps", 3},
- {"EliminateCommonSubexpr", 3}
- };
- return m;
- }
-};
+} // namespace
struct RelayPassContextThreadLocalEntry {
/*! \brief The default pass context. */
/* \brief The pass meta data.*/
PassInfo pass_info;
- /*!
- * \brief A helper struct to get the optimization pass name to opt level
- * mapping.
- */
- OptPassLevel opt_pass_level;
-
/*! \brief A list of passes that used to compose a sequential pass. */
tvm::Array<Pass> passes;
void VisitAttrs(tvm::AttrVisitor* v) final {
const Array<tvm::Expr>& disabled) const;
std::unordered_set<std::string> RequiredPasses(
- const Array<tvm::Expr>& disabled) const;
+ const Array<tvm::Expr>& required) const;
/*!
* \brief Perform optimizations on a series of passes. The aforementioned
}
// Module -> Module optimizations.
-// TODO(zhiics) Check and handle the required passes.
Module ModulePassNode::operator()(const Module& mod,
const PassContext& pass_ctx) const {
PassInfo pass_info = Info();
DLOG(INFO) << "Executing module pass : " << pass_info->name
<< " with opt level: " << pass_info->opt_level << "\n";
+
CHECK(mod.defined());
- auto updated_mod = pass_func(mod, pass_ctx);
+ Module updated_mod = mod;
+ // Execute the required passes in a DFS way.
+ // TODO(zhiics) We may need to validate the passes to detect cyclic
+ // dependencies.
+ for (const auto& it : pass_info->required) {
+ const auto* name = it.as<tvm::ir::StringImm>();
+ CHECK(name);
+ auto pass = GetPass(name->value);
+ updated_mod = pass(updated_mod, pass_ctx);
+ }
+
+ updated_mod = pass_func(updated_mod, pass_ctx);
CHECK(updated_mod.defined());
return updated_mod;
}
const PassContext& pass_ctx) const {
PassInfo pass_info = Info();
CHECK(mod.defined());
- Module new_mod = ModuleNode::make({}, mod->type_definitions);
DLOG(INFO) << "Executing module pass : " << pass_info->name
<< " with opt level: " << pass_info->opt_level << "\n";
+
+ Module updated_mod = mod;
+ // Execute the required passes in a DFS way.
+ // TODO(zhiics) We may need to validate the passes to detect cyclic
+ // dependencies.
+ for (const auto& it : pass_info->required) {
+ const auto* name = it.as<tvm::ir::StringImm>();
+ CHECK(name);
+ auto pass = GetPass(name->value);
+ updated_mod = pass(updated_mod, pass_ctx);
+ }
+
+ Module new_mod = ModuleNode::make({}, mod->type_definitions);
// Execute the pass function and return a new module.
for (const auto& it : mod->functions) {
- auto updated_func = SkipFunction(it.second) ? it.second : pass_func(it.second, mod, pass_ctx);
+ auto updated_func = SkipFunction(it.second)
+ ? it.second
+ : pass_func(it.second, updated_mod, pass_ctx);
new_mod->Add(it.first, updated_func);
}
std::unordered_set<std::string> ret;
for (const auto& it : disabled) {
const auto* str = it.as<tvm::ir::StringImm>();
- CHECK(str) << "disabled passes must be string.";
+ CHECK(str) << "Disabled pass name must be string.";
ret.emplace(str->value);
}
return ret;
std::unordered_set<std::string> ret;
for (const auto& it : required) {
const auto* str = it.as<tvm::ir::StringImm>();
- CHECK(str) << "disabled passes must be string.";
+ CHECK(str) << "Required pass name must be string.";
ret.emplace(str->value);
}
return ret;
PassContext ctx = PassContext::Current();
auto required = RequiredPasses(ctx->required_pass);
- auto disabled = DisabledPasses(ctx->required_pass);
+ auto disabled = DisabledPasses(ctx->disabled_pass);
if (disabled.count(pass_name)) {
return false;
if (required.count(pass_name)) {
return true;
}
- return ctx->opt_level >= opt_pass_level[pass_name];
+
+ const Pass pass = GetPass(pass_name);
+ PassInfo info = pass->Info();
+ return ctx->opt_level >= info->opt_level;
}
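A sketch of how this policy behaves from Python: FoldConstant is registered at opt_level 2, so under a lower optimization level it only runs when explicitly required, and disabled_pass always wins:

from tvm import relay

x = relay.var("x", shape=(2,))
body = relay.add(x, relay.add(relay.const(1.0), relay.const(1.0)))
mod = relay.Module({"main": relay.Function([x], body)})
seq = relay.transform.Sequential([relay.transform.FoldConstant()])

with relay.build_config(opt_level=0, required_pass=["FoldConstant"]):
    mod = seq(mod)    # runs despite the low opt_level
with relay.build_config(opt_level=3, disabled_pass=["FoldConstant"]):
    mod = seq(mod)    # skipped even though opt_level is high enough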
// TODO(zhiics): we currently only sequentially execute each pass in
// a Sequential without the consideration of their orders. The phase
-// ordering problem needed to be handled in the future.
+// ordering problem needs to be handled in the future.
Module SequentialNode::operator()(const Module& module,
const PassContext& pass_ctx) const {
- int opt_level = pass_ctx->opt_level;
- auto disabled = DisabledPasses(pass_ctx->disabled_pass);
Module mod = module;
for (const Pass& pass : passes) {
CHECK(pass.defined()) << "Found undefined pass for optimization.";
+
PassInfo info = pass->Info();
const auto& pass_name = info->name;
- const auto& pass_opt_level = info->opt_level;
- // Skip the pass if its optimization level is higher that the one of in the
- // pass context or if this pass is disabled.
- if (pass_opt_level > opt_level || disabled.count(pass_name)) {
- continue;
+ // Execute the pass if it is enabled.
+ if (PassEnabled(pass_name)) {
+ mod = pass(mod, pass_ctx);
}
- const auto* pn = pass.operator->();
- mod = (*pn)(mod, pass_ctx);
}
return mod;
}
TVM_REGISTER_API("relay._transform.RunPass")
.set_body([](TVMArgs args, TVMRetValue* ret) {
- *ret = args[0].operator Pass()(args[1]);
+ Pass pass = args[0];
+ Module mod = args[1];
+ *ret = pass(mod);
});
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<ModulePassNode>([](const ModulePassNode* node,
tvm::IRPrinter* p) {
- const PassInfoNode* pn = node->Info().operator->();
- p->stream << "Run Module pass: " << pn->name
- << " at the optimization level " << pn->opt_level;
+ const PassInfo info = node->Info();
+ p->stream << "Run Module pass: " << info->name
+ << " at the optimization level " << info->opt_level;
});
TVM_REGISTER_NODE_TYPE(FunctionPassNode);
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<FunctionPassNode>([](const FunctionPassNode* node,
tvm::IRPrinter* p) {
- const PassInfoNode* pn = node->Info().operator->();
- p->stream << "Run Function pass: " << pn->name
- << " at the optimization level " << pn->opt_level;
+ const PassInfo info = node->Info();
+ p->stream << "Run Function pass: " << info->name
+ << " at the optimization level " << info->opt_level;
});
TVM_REGISTER_NODE_TYPE(SequentialNode);
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<SequentialNode>([](const SequentialNode* node,
tvm::IRPrinter* p) {
- const PassInfoNode* seq_pn = node->Info().operator->();
- p->stream << "Run Sequential pass: " << seq_pn->name
- << " at the optimization level " << seq_pn->opt_level << ". ";
+ const PassInfo info = node->Info();
+ p->stream << "Run Sequential pass: " << info->name
+ << " at the optimization level " << info->opt_level << ". ";
p->stream << "The passes will be executed are: [";
for (const auto& it : node->passes) {
- const PassNode* pn = it.operator->();
- const PassInfoNode* pass_info_node = pn->Info().operator->();
- p->stream << pass_info_node->name << " ";
+ const PassInfo pass_info = it->Info();
+ p->stream << pass_info->name << " ";
}
p->stream << "]";
});
#include <tvm/relay/pass.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/attrs/nn.h>
+#include <tvm/relay/transform.h>
#include "./pattern_util.h"
namespace tvm {
TVM_REGISTER_API("relay._ir_pass.simplify_inference")
.set_body_typed(SimplifyInference);
+namespace transform {
+
+Pass SimplifyInference() {
+ runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
+ [=](Function f, Module m, PassContext pc) {
+ return Downcast<Function>(SimplifyInference(f));
+ };
+ return CreateFunctionPass(pass_func, 0, "SimplifyInference",
+ {ir::StringImm::make("InferType")});
+}
+
+TVM_REGISTER_API("relay._transform.SimplifyInference")
+.set_body_typed(SimplifyInference);
+
+} // namespace transform
+
} // namespace relay
} // namespace tvm
[=](Function f, Module m, PassContext pc) {
return Downcast<Function>(ToANormalForm(f, m));
};
- return CreateFunctionPass(pass_func, 1, "to_a_normal_form", {});
+ return CreateFunctionPass(pass_func, 1, "ToANormalForm", {});
}
+TVM_REGISTER_API("relay._transform.ToANormalForm")
+.set_body_typed(ToANormalForm);
+
} // namespace transform
} // namespace relay
[=](Function f, Module m, PassContext pc) {
return Downcast<Function>(ToGraphNormalForm(f));
};
- return CreateFunctionPass(pass_func, 1, "to_graph_normal_form", {});
+ return CreateFunctionPass(pass_func, 1, "ToGraphNormalForm", {});
}
+TVM_REGISTER_API("relay._transform.ToGraphNormalForm")
+.set_body_typed(ToGraphNormalForm);
+
} // namespace transform
} // namespace relay
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/pattern_functor.h>
#include <tvm/relay/pass.h>
+#include <tvm/relay/transform.h>
#include "./pass_util.h"
#include "type_solver.h"
#include "../ir/type_functor.h"
.set_body_typed<Expr(const Expr&, const Module&)>([](const Expr& expr, const Module& mod_ref) {
return InferType(expr, mod_ref);
});
+
+namespace transform {
+
+Pass InferType() {
+ runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
+ [=](Function f, Module m, PassContext pc) {
+ return Downcast<Function>(InferType(f, m));
+ };
+ return CreateFunctionPass(pass_func, 0, "InferType", {});
+}
+
+TVM_REGISTER_API("relay._transform.InferType")
+.set_body_typed<Pass()>([]() {
+ return InferType();
+});
+
+} // namespace transform
+
} // namespace relay
} // namespace tvm
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include <gtest/gtest.h>
+#include <topi/generic/injective.h>
+#include <tvm/build_module.h>
+#include <tvm/packed_func_ext.h>
+#include <tvm/relay/expr.h>
+#include <tvm/relay/module.h>
+#include <tvm/relay/pass.h>
+#include <tvm/relay/transform.h>
+#include <tvm/relay/type.h>
+#include <tvm/runtime/packed_func.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/tvm.h>
+
+TVM_REGISTER_GLOBAL("schedule")
+ .set_body([](tvm::TVMArgs args, tvm::TVMRetValue* rv) {
+ *rv = topi::generic::schedule_injective(args[0], args[1]);
+ });
+
+TEST(Relay, Sequential) {
+ using namespace tvm;
+ auto tensor_type = relay::TensorTypeNode::make({1, 2, 3}, ::tvm::Float(32));
+ auto c_data =
+ tvm::runtime::NDArray::Empty({1, 2, 3}, {kDLFloat, 32, 1}, {kDLCPU, 0});
+
+ // Create a function for optimization.
+ auto c = relay::ConstantNode::make(c_data);
+ auto a = relay::VarNode::make("a", tensor_type);
+ auto x = relay::VarNode::make("x", tensor_type);
+ auto add_op = relay::Op::Get("add");
+ auto y = relay::CallNode::make(add_op, {c, c});
+ y = relay::CallNode::make(add_op, {x, y});
+ auto z = relay::CallNode::make(add_op, {y, c});
+ auto z1 = relay::CallNode::make(add_op, {y, c});
+ auto z2 = relay::CallNode::make(add_op, {z, z1});
+ // The let expression and variable "a" should be dead-code eliminated.
+ auto z3 = relay::LetNode::make(a, c, z2);
+ relay::Function func =
+ relay::FunctionNode::make(relay::FreeVars(z3), z3, relay::Type(), {});
+
+ // Get schedule
+ auto reg = tvm::runtime::Registry::Get("relay.op._Register");
+ auto sch = tvm::runtime::Registry::Get("schedule");
+ if (!reg || !sch) {
+ LOG(FATAL) << "Register/schedule is not defined.";
+ }
+
+ (*reg)("add", "FTVMSchedule", *sch, 10);
+
+ // Run sequential passes.
+ tvm::Array<relay::transform::Pass> pass_seqs{
+ relay::transform::InferType(),
+ relay::transform::DeadCodeElimination(),
+ relay::transform::EliminateCommonSubexpr(),
+ relay::transform::AlterOpLayout()
+ };
+ relay::transform::Pass seq = relay::transform::Sequential(pass_seqs);
+ auto mod = relay::ModuleNode::FromExpr(func);
+ auto pass_ctx = relay::transform::PassContext::Create();
+ pass_ctx->opt_level = 3;
+ pass_ctx->fallback_device = 1;
+ {
+ tvm::With<relay::transform::PassContext> ctx_scope(pass_ctx);
+ tvm::With<tvm::Target> tctx(tvm::Target::Create("llvm"));
+ mod = seq(mod);
+ }
+
+ CHECK(mod.defined());
+ auto entry_func = mod->entry_func;
+ CHECK(entry_func.defined());
+ relay::Function f = mod->Lookup(entry_func->name_hint);
+ CHECK(f.defined());
+
+ // Expected function
+ auto c1 = relay::ConstantNode::make(c_data);
+ auto x1 = relay::VarNode::make("x", tensor_type);
+ auto y1 = relay::CallNode::make(add_op, {c1, c1});
+ y1 = relay::CallNode::make(add_op, {x1, y1});
+ auto zz = relay::CallNode::make(add_op, {y1, c1});
+ zz = relay::CallNode::make(add_op, {zz, zz});
+ relay::Function expected_func =
+ relay::FunctionNode::make(relay::FreeVars(zz), zz, relay::Type(), {});
+
+ // Infer type for the expected function.
+ auto expected = relay::InferType(expected_func, relay::Module(nullptr));
+ CHECK(relay::AlphaEqual(f, expected));
+}
+
+int main(int argc, char** argv) {
+ testing::InitGoogleTest(&argc, argv);
+ testing::FLAGS_gtest_death_test_style = "threadsafe";
+ return RUN_ALL_TESTS();
+}
def test_only_module_pass():
passes = [module_pass]
sequential = _transform.Sequential(opt_level=1, passes=passes)
- ret_mod = sequential(mod)
+ with relay.build_config(required_pass=["mod_transform"]):
+ ret_mod = sequential(mod)
# Check the subtract function.
sub_var, new_sub = extract_var_func(ret_mod, v_sub.name_hint)
check_func(new_sub, sub)
# Check the subtract function.
passes = [function_pass]
sequential = _transform.Sequential(opt_level=1, passes=passes)
- ret_mod = sequential(mod)
+ with relay.build_config(required_pass=["func_transform"]):
+ ret_mod = sequential(mod)
_, new_sub = extract_var_func(ret_mod, v_sub.name_hint)
check_func(new_sub, get_ref_sub())
mod = relay.Module({v_sub: sub, v_log: log})
passes = [module_pass, function_pass]
sequential = _transform.Sequential(opt_level=1, passes=passes)
- ret_mod = sequential(mod)
+ required = ["mod_transform", "func_transform"]
+ with relay.build_config(required_pass=required):
+ ret_mod = sequential(mod)
# Check the abs function is added.
abs_var, abs_func = get_var_func()
test_multiple_passes()
+def test_sequential_with_scoping():
+ shape = (1, 2, 3)
+ c_data = np.array(shape).astype("float32")
+ tp = relay.TensorType(shape, "float32")
+ def before():
+ c = relay.const(c_data)
+ x = relay.var("x", tp)
+ y = relay.add(c, c)
+ y = relay.multiply(y, relay.const(2, "float32"))
+ y = relay.add(x, y)
+ z = relay.add(y, c)
+ z1 = relay.add(y, c)
+ z2 = relay.add(z, z1)
+ return relay.Function([x], z2)
+
+ def expected():
+ x = relay.var("x", tp)
+ c_folded = (c_data + c_data) * 2
+ y = relay.add(x, relay.const(c_folded))
+ z = relay.add(y, relay.const(c_data))
+ z1 = relay.add(z, z)
+ return relay.Function([x], z1)
+
+ seq = _transform.Sequential([
+ relay.transform.InferType(),
+ relay.transform.FoldConstant(),
+ relay.transform.EliminateCommonSubexpr(),
+ relay.transform.AlterOpLayout()
+ ])
+
+ mod = relay.Module({"main": before()})
+ with relay.build_config(opt_level=3):
+ with tvm.target.create("llvm"):
+ mod = seq(mod)
+
+ zz = mod["main"]
+ zexpected = ir_pass.infer_type(expected())
+ assert relay.ir_pass.alpha_equal(zz, zexpected)
+
+
if __name__ == "__main__":
test_module_pass()
test_function_pass()
test_sequential_pass()
+ test_sequential_with_scoping()