.. autofunction:: tvm.relay.build_module.build
-.. autofunction:: tvm.relay.build_module.build_config
-
.. autofunction:: tvm.relay.build_module.optimize
.. autofunction:: tvm.relay.build_module.create_executor
-.. autoclass:: tvm.relay.build_module.BuildConfig
- :members:
-
-.. autofunction:: tvm.relay.build_module.build_config
- :members:
-
.. autoclass:: tvm.relay.build_module.GraphExecutor
:members:
--- /dev/null
+.. Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+.. http://www.apache.org/licenses/LICENSE-2.0
+
+.. Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied. See the License for the
+ specific language governing permissions and limitations
+ under the License.
+
+tvm.relay.transform
+----------------------
+
+.. automodule:: tvm.relay.transform
+
+.. autofunction:: tvm.relay.transform.build_config
+
+.. autofunction:: tvm.relay.transform.module_pass
+
+.. autofunction:: tvm.relay.transform.function_pass
+
+.. autoclass:: tvm.relay.transform.Pass
+ :members:
+
+.. autoclass:: tvm.relay.transform.PassInfo
+ :members:
+
+.. autoclass:: tvm.relay.transform.PassContext
+ :members:
+
+.. autoclass:: tvm.relay.transform.ModulePass
+ :members:
+
+.. autoclass:: tvm.relay.transform.FunctionPass
+ :members:
+
+.. autoclass:: tvm.relay.transform.Sequential
+ :members:
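+
+A minimal usage sketch: ``build_config`` creates a ``PassContext`` that is
+installed with a ``with`` statement and consulted by the passes running
+inside the scope (``func``, ``target``, and ``params`` are assumed to be
+defined elsewhere):
+
+.. code-block:: python
+
+    with tvm.relay.transform.build_config(opt_level=3):
+        graph, lib, params = tvm.relay.build(func, target, params=params)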
#ifndef TVM_RELAY_TRANSFORM_H_
#define TVM_RELAY_TRANSFORM_H_
+#include <tvm/base.h>
#include <tvm/packed_func_ext.h>
#include <tvm/relay/error.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/module.h>
#include <string>
+#include <unordered_map>
#include <vector>
namespace tvm {
*/
ErrorReporter err_reporter;
+ /*! \brief The default optimization level. */
+ int opt_level{2};
+
+ /*! \brief CPU is the default fallback device for heterogeneous execution. */
+ int fallback_device{static_cast<int>(kDLCPU)};
+
+ /*! \brief The list of required passes. */
+ tvm::Array<tvm::Expr> required_pass;
+ /*! \brief The list of disabled passes. */
+ tvm::Array<tvm::Expr> disabled_pass;
+
PassContextNode() = default;
void VisitAttrs(tvm::AttrVisitor* v) final {
+ v->Visit("opt_level", &opt_level);
+ v->Visit("fallback_device", &fallback_device);
+ v->Visit("required_pass", &required_pass);
+ v->Visit("disabled_pass", &disabled_pass);
}
- TVM_DLL static PassContext make();
-
static constexpr const char* _type_key = "relay.PassContext";
TVM_DECLARE_NODE_TYPE_INFO(PassContextNode, RelayNode);
};
-TVM_DEFINE_NODE_REF(PassContext, PassContextNode)
+class PassContext : public NodeRef {
+ public:
+ PassContext() {}
+ explicit PassContext(tvm::NodePtr<Node> n) : NodeRef(n) {}
+
+  /*!
+ * \brief Constructor of a `PassContext` object.
+ *
+ * \param opt_level The optimization level that will be applied.
+ * \param fallback_device The fallback device used for heterogeneous
+ * execution.
+ * \param required_pass The passes that are required for a context to execute
+ * other passes.
+   * \param disabled_pass The passes that will be disabled during the
+ * optimization under a context.
+ */
+ TVM_DLL PassContext(int opt_level,
+ int fallback_device,
+ tvm::Array<tvm::Expr> required_pass,
+ tvm::Array<tvm::Expr> disabled_pass);
+
+ // Get the currently used pass context.
+ TVM_DLL static PassContext Current();
+
+ const PassContextNode* operator->() const;
+
+ using ContainerType = PassContextNode;
+ class Internal;
+
+ private:
+ // The entry of a pass context scope.
+ TVM_DLL void EnterWithScope();
+ // The exit of a pass context scope.
+ TVM_DLL void ExitWithScope();
+
+ // Classes to get the Python `with` like syntax.
+ friend class Internal;
+ friend class tvm::With<PassContext>;
+};
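+
+/*!
+ * A usage sketch (not part of the API surface): the RAII helper
+ * `With<PassContext>` calls EnterWithScope/ExitWithScope, so the context
+ * below is what PassContext::Current() returns only inside the block.
+ *
+ * \code
+ * PassContext ctx(3, static_cast<int>(kDLCPU), {}, {});
+ * {
+ *   With<PassContext> scope(ctx);
+ *   // PassContext::Current() now returns `ctx`.
+ * }
+ * // The previous (or default) context is restored here.
+ * \endcode
+ */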
/*!
* \brief The meta data of a pass.
virtual PassInfo Info() const = 0;
/*!
- * \brief Set the context information for a pass.
+ * \brief Execute the optimization pass using a functor. This functor
+   * internally uses the current pass context.
+ *
+ * \param mod The module that an optimization pass runs on.
*
- * \param pass_ctx The context information for a certain pass.
+ * \return The updated module.
*/
- virtual void SetContext(const PassContext& pass_ctx) = 0;
+ Module operator()(const Module& mod) const {
+ return this->operator()(mod, PassContext::Current());
+ }
/*!
- * \brief Execute the optimization pass using a functor.
+ * \brief Execute the optimization pass using a functor under a given pass context.
*
* \param mod The module that an optimization pass runs on.
+ * \param pass_ctx The pass context that will be used to help the execution of
+ * optimizations.
*
* \return The updated module.
*/
- virtual Module operator()(const Module& mod) const = 0;
+ virtual Module operator()(const Module& mod,
+ const PassContext& pass_ctx) const = 0;
void VisitAttrs(tvm::AttrVisitor* v) override {}
public:
/*!
* \brief The constructor of `Sequential`.
+ *
* \param passes The passes to apply.
* \param pass_info The pass metadata.
- * \param disabled The passes that will not be applied.
*/
TVM_DLL Sequential(tvm::Array<Pass> passes,
- PassInfo pass_info,
- tvm::Array<tvm::Expr> disabled);
+ PassInfo pass_info);
+/*!
+ * \brief The constructor of `Sequential`.
+ *
+ * \param passes The passes to apply.
+   * \param name The name of the sequential pass. It defaults to "sequential".
+   * This constructor allows users to provide only a list of passes and
+   * execute them under a given context.
+ */
+ TVM_DLL Sequential(tvm::Array<Pass> passes, std::string name = "sequential");
+
Sequential() = default;
explicit Sequential(tvm::NodePtr<::tvm::Node> n) : Pass(n) {}
from . import adt
from . import ir_pass
from . import transform
-from .build_module import build, build_config, create_executor
+from .build_module import build, create_executor
+from .transform import build_config
from . import prelude
from . import parser
from . import debug
from . import ir_pass
from . import ty as _ty
from . import expr as _expr
+from . import transform as _transform
from .backend import interpreter as _interpreter
from .backend.vm import VMExecutor
-class BuildConfig(object):
- """Configuration scope to set a build config option.
-
- Parameters
- ----------
- kwargs
- Keyword arguments of configurations to set.
- """
- current = None
- defaults = {
- "opt_level": 2,
- "add_pass": None,
- "disable_pass": None,
- "fallback_device": None,
- }
-
- def __init__(self, **kwargs):
- self._old_scope = None
- for k, _ in kwargs.items():
- if k not in BuildConfig.defaults:
- raise ValueError("invalid argument %s, candidates are %s" %
- (k, BuildConfig.defaults.keys()))
- self._attr = kwargs
-
- def __getattr__(self, name):
- if name not in self._attr:
- return BuildConfig.defaults[name]
- return self._attr[name]
-
- def __enter__(self):
- # pylint: disable=protected-access
- self._old_scope = BuildConfig.current
- attr = BuildConfig.current._attr.copy()
- attr.update(self._attr)
- self._attr = attr
- BuildConfig.current = self
- return self
-
- def __exit__(self, ptype, value, trace):
- assert self._old_scope
- BuildConfig.current = self._old_scope
-
-
-BuildConfig.current = BuildConfig()
-
-
-def build_config(**kwargs):
- """Configure the build behavior by setting config variables.
-
- Parameters
- ----------
- opt_level: int, default=2
- Optimization level. See OPT_PASS_LEVEL for level of each pass.
-
- add_pass: set of str
- Optimization pass to be added regardless of optimization level.
-
- disable_pass: set of str
- Optimization pass to be disabled during optimization.
-
- fallback_device : str or tvm.TVMContext
- The fallback device. It is also used as the default device for
- operators without specified device during heterogeneous execution.
-
- Returns
- -------
- config: BuildConfig
- The build configuration
- """
- return BuildConfig(**kwargs)
-
-
def _update_target(target):
target = target if target else _target.current_target()
if target is None:
return graph_json, mod, params
def _setup_build_config(self, params):
- cfg = BuildConfig.current
+ cfg = _transform.PassContext.current()
# Set opt_level.
self.set_opt_level(cfg.opt_level)
self.set_fallback_device(cfg.fallback_device)
# Add required passes.
- if cfg.add_pass:
+ if cfg.required_pass:
passes = set()
- if isinstance(cfg.add_pass, (list, tuple, set)):
- passes = set(cfg.add_pass)
+ if isinstance(cfg.required_pass, (list, tuple, set)):
+ passes = set(cfg.required_pass)
else:
raise TypeError("add_pass must be list, tuple, or set, but " +
- "got {}".format(type(cfg.add_pass)))
+ "got {}".format(type(cfg.required_pass)))
for pass_name in passes:
self.add_pass(pass_name)
# Add disabled passes.
- if cfg.disable_pass:
+ if cfg.disabled_pass:
passes = set()
- if isinstance(cfg.disable_pass, (list, tuple, set)):
- passes = set(cfg.disable_pass)
+ if isinstance(cfg.disabled_pass, (list, tuple, set)):
+ passes = set(cfg.disabled_pass)
else:
raise TypeError("disable_pass must be list, tuple, or set, " +
- "but got {}".format(type(cfg.disable_pass)))
+ "but got {}".format(type(cfg.disabled_pass)))
for pass_name in passes:
self.disable_pass(pass_name)
fallback_device : str or tvm.TVMContext
The fallback device used for heterogeneous execution.
"""
- if isinstance(fallback_device, str):
+ if isinstance(fallback_device, (int, str)):
fallback_device = _nd.context(fallback_device)
if not isinstance(fallback_device, TVMContext):
- raise TypeError("fallback_device is expected to be str " +
- "TVMContext, or dict of device name to target, " +
- "but received: {}".format(type(fallback_device)))
+ raise TypeError("fallback_device is expected to be str, int, or " +
+ "TVMContext but received: {}".format(type(fallback_device)))
self._set_fallback_device(fallback_device.device_type)
from . import _quantize
from .. import expr as _expr
from .. import ir_pass as _ir_pass
-from .. import build_module as _build
+from .. import transform as _transform
from .. import op as _op
from ... import make as _make
from ..base import NodeBase, register_relay_node
"FoldConstant",
"CanonicalizeOps"]
- cfg = _build.build_config(add_pass=opt_passes)
+ cfg = _transform.build_config(required_pass=opt_passes)
if params:
name_dict = {}
bind_dict[arg] = _expr.const(v)
func = _expr.bind(func, bind_dict)
- if "SimplifyInference" in cfg.add_pass:
+ if "SimplifyInference" in cfg.required_pass:
func = _ir_pass.infer_type(func)
func = _ir_pass.simplify_inference(func)
- if "FoldConstant" in cfg.add_pass:
+ if "FoldConstant" in cfg.required_pass:
func = _ir_pass.fold_constant(func)
- if "FoldScaleAxis" in cfg.add_pass:
+ if "FoldScaleAxis" in cfg.required_pass:
func = _ir_pass.infer_type(func)
func = _ir_pass.backward_fold_scale_axis(func)
func = _ir_pass.infer_type(func)
func = _ir_pass.forward_fold_scale_axis(func)
func = _ir_pass.fold_constant(func)
- if "CanonicalizeOps" in cfg.add_pass:
+ if "CanonicalizeOps" in cfg.required_pass:
func = _ir_pass.infer_type(func)
func = _ir_pass.canonicalize_ops(func)
- if "FoldConstant" in cfg.add_pass:
+ if "FoldConstant" in cfg.required_pass:
func = _ir_pass.fold_constant(func)
return func
"""
import types
+from tvm._ffi.runtime_ctypes import TVMContext
from . import _transform
from .base import RelayNode, register_relay_node
+from .. import nd as _nd
@register_relay_node
    Each pass context contains auxiliary information used to help an
    optimization pass, such as the error reporter that records errors
    raised during the optimization.
+
+    Parameters
+    ----------
+ opt_level : Optional[int]
+ The optimization level of this pass.
+
+ fallback_device : Optional[Union[int, str, TVMContext]]
+ The fallback device type. It is also used as the default device for
+ operators that are not annotated during heterogeneous execution.
+
+ required_pass : Optional[Union[List[str], Set[str], Tuple[str]]]
+ The list of passes that are required by a certain pass.
+
+ disabled_pass : Optional[Union[List[str], Set[str], Tuple[str]]]
+ The list of passes that are disabled.
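+
+    Examples
+    --------
+    A minimal sketch of scoping a pass context; attribute access on the
+    current context goes through the visited node fields:
+
+    .. code-block:: python
+
+        with relay.transform.PassContext(opt_level=3):
+            assert relay.transform.PassContext.current().opt_level == 3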
"""
+ def __init__(self,
+ opt_level=2,
+ fallback_device=_nd.cpu(),
+ required_pass=None,
+ disabled_pass=None):
+ if isinstance(fallback_device, str):
+ fallback_device = _nd.context(fallback_device).device_type
+ elif isinstance(fallback_device, TVMContext):
+ fallback_device = fallback_device.device_type
+ if not isinstance(fallback_device, int):
+            raise TypeError("fallback_device is expected to be of type " +
+                            "int, str, or TVMContext.")
+
+        if required_pass and not isinstance(required_pass, (list, tuple, set)):
+            raise TypeError("required_pass is expected to be of type " +
+                            "list, tuple, or set.")
+        required = list(required_pass) if required_pass else []
- def __init__(self):
- self.__init_handle_by_constructor__(_transform.PassContext)
+        if disabled_pass and not isinstance(disabled_pass, (list, tuple, set)):
+            raise TypeError("disabled_pass is expected to be of type " +
+                            "list, tuple, or set.")
+        disabled = list(disabled_pass) if disabled_pass else []
+
+ self.__init_handle_by_constructor__(_transform.PassContext, opt_level,
+ fallback_device, required,
+ disabled)
+
+ def __enter__(self):
+ _transform.EnterPassContext(self)
+ return self
+
+ def __exit__(self, ptype, value, trace):
+ _transform.ExitPassContext(self)
+
+ @staticmethod
+ def current():
+ """Return the current pass context."""
+ return _transform.GetCurrentPassContext()
+
+
+def build_config(opt_level=2,
+ fallback_device=_nd.cpu(),
+ required_pass=None,
+ disabled_pass=None):
+ """Configure the build behavior by setting config variables.
+
+ Parameters
+ ----------
+ opt_level: int, optional
+        Optimization level. The minimum ``opt_level`` required to enable
+        each pass is listed below:
+
+ .. code-block:: python
+
+ OPT_PASS_LEVEL = {
+ "SimplifyInference": 0,
+ "OpFusion": 1,
+ "FoldConstant": 2,
+ "CombineParallelConv2D": 3,
+ "FoldScaleAxis": 3,
+ "AlterOpLayout": 3,
+ "CanonicalizeOps": 3,
+ "EliminateCommonSubexpr": 3,
+ }
+
+ fallback_device : int, str, or tvm.TVMContext, optional
+ The fallback device. It is also used as the default device for
+ operators without specified device during heterogeneous execution.
+
+ required_pass: set of str, optional
+ Optimization passes that are required regardless of optimization level.
+
+ disabled_pass: set of str, optional
+ Optimization passes to be disabled during optimization.
+
+ Returns
+ -------
+ pass_context: PassContext
+ The pass context for optimizations.
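+
+    Examples
+    --------
+    A minimal sketch, assuming ``func``, ``target``, and ``params`` are
+    already defined:
+
+    .. code-block:: python
+
+        with relay.transform.build_config(opt_level=3):
+            graph, lib, params = relay.build(func, target, params=params)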
+ """
+ return PassContext(opt_level, fallback_device, required_pass,
+ disabled_pass)
@register_relay_node
conveniently interact with the base class.
"""
- def set_pass_context(self, pass_ctx):
- """Setup the pass context for analysis and optimizations. This context
- could be shared by different passes for sequential passes.
-
- Parameters
- ----------
- pass_ctx : PassContext
- The context that is used to help perform a certain pass or a series
- of passes.
- """
- if not isinstance(pass_ctx, PassContext):
- raise TypeError("pass_ctx is expected to be the PassContext type")
- _transform.SetContext(self, pass_ctx)
-
@property
def info(self):
"""Get the pass meta."""
required : Optional[List[str]]
The list of passes that the sequential pass is dependent on.
-
- disabled : Optional[List[str]]
- A list of disabled passes.
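+
+    Examples
+    --------
+    A construction sketch; ``pass_a`` and ``pass_b`` are hypothetical Pass
+    objects (e.g. created with ``module_pass`` or ``function_pass``):
+
+    .. code-block:: python
+
+        seq = relay.transform.Sequential([pass_a, pass_b], opt_level=2)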
"""
def __init__(self,
passes=None,
opt_level=2,
name="sequential",
- required=None,
- disabled=None):
+ required=None):
passes = passes if passes else []
if not isinstance(passes, (list, tuple)):
raise TypeError("passes must be a list of Pass objects.")
- disabled = disabled if disabled else []
- if not isinstance(disabled, (list, tuple)):
- raise TypeError("disabled must be a list or tuple of pass names")
-
required = required if required else []
if not isinstance(required, (list, tuple)):
raise TypeError("Required is expected to be the type of list/tuple.")
self.__init_handle_by_constructor__(_transform.Sequential,
- passes, opt_level, name, required,
- disabled)
+ passes, opt_level, name, required)
def module_pass(pass_func=None, opt_level=None, name=None, required=None):
* \file src/relay/pass/pass_manager.cc
* \brief Relay pass manager implementation.
*/
+#include <dmlc/thread_local.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/transform.h>
+#include <tvm/runtime/device_api.h>
+
+#include <algorithm>
+#include <stack>
+#include <unordered_set>
namespace tvm {
namespace relay {
using tvm::IRPrinter;
+/*!
+ * \brief A data structure to map the names of specific optimizations to
+ * numeric optimization levels
+ */
+class OptPassLevel {
+ public:
+ /*!
+ * \brief Get level for an optimization pass
+ *
+ * \param key pass name
+ * \return int level
+ */
+ int operator[](const std::string& key) const {
+    // Build the lookup table once; the mapping never changes.
+    static const auto data = CreateMap();
+ auto it = data.find(key);
+ if (it == data.end()) {
+ return -1;
+ }
+ return it->second;
+ }
+
+ private:
+ static const std::unordered_map<std::string, int> CreateMap() {
+ const std::unordered_map<std::string, int> m = {
+ {"SimplifyInference", 0},
+ {"OpFusion", 1},
+ {"FoldConstant", 2},
+ {"CombineParallelConv2D", 3},
+ {"FoldScaleAxis", 3},
+ {"AlterOpLayout", 3},
+ {"CanonicalizeOps", 3},
+ {"EliminateCommonSubexpr", 3}
+ };
+ return m;
+ }
+};
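+
+// Usage sketch: OptPassLevel()["FoldConstant"] evaluates to 2, while an
+// unregistered pass name yields -1 and is therefore never blocked by the
+// opt_level comparison in SequentialNode::pass_enabled.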
+
+PassContext::PassContext(int opt_level, int fallback_device,
+ tvm::Array<tvm::Expr> required_pass,
+ tvm::Array<tvm::Expr> disabled_pass) {
+ auto ctx = make_node<PassContextNode>();
+ ctx->opt_level = opt_level;
+ ctx->fallback_device = fallback_device;
+ ctx->required_pass = std::move(required_pass);
+ ctx->disabled_pass = std::move(disabled_pass);
+ node_ = std::move(ctx);
+}
+
+const PassContextNode* PassContext::operator->() const {
+ return static_cast<const PassContextNode*>(node_.get());
+}
+
+struct RelayPassContextThreadLocalEntry {
+ /*! \brief The default pass context. */
+ PassContext default_context;
+
+ /*! \brief The current pass context. */
+ std::stack<PassContext> context_stack;
+
+ RelayPassContextThreadLocalEntry() {
+ default_context = PassContext(make_node<PassContextNode>());
+ }
+};
+
+/*! \brief Thread local store to hold the pass context. */
+typedef dmlc::ThreadLocalStore<RelayPassContextThreadLocalEntry>
+ RelayPassContextThreadLocalStore;
+
+void PassContext::EnterWithScope() {
+ RelayPassContextThreadLocalEntry* entry =
+ RelayPassContextThreadLocalStore::Get();
+ entry->context_stack.push(*this);
+}
+
+void PassContext::ExitWithScope() {
+ RelayPassContextThreadLocalEntry* entry =
+ RelayPassContextThreadLocalStore::Get();
+ CHECK(!entry->context_stack.empty());
+ CHECK(entry->context_stack.top().same_as(*this));
+ entry->context_stack.pop();
+}
+
+PassContext PassContext::Current() {
+ RelayPassContextThreadLocalEntry* entry =
+ RelayPassContextThreadLocalStore::Get();
+ if (!entry->context_stack.empty()) {
+ return entry->context_stack.top();
+ } else {
+ return entry->default_context;
+ }
+}
+
class ModulePass;
/*!
}
/*!
- * \brief Run a module pass on a certain module.
+   * \brief Run a module pass under a given pass context.
*
- * \param mod The module that an optimization pass runs on.
+   * \param mod The module that an optimization pass is applied to.
+   * \param pass_ctx The pass context that the pass executes under.
*
* \return Return the updated module.
*/
- Module operator()(const Module& mod) const final;
+ Module operator()(const Module& mod, const PassContext& pass_ctx) const final;
/*!
* \brief Get the pass information/meta data.
*/
PassInfo Info() const { return pass_info; }
- /*!
- * \brief Set the context information for a module pass.
- *
- * \param pass_ctx The context information for a module pass.
- */
- void SetContext(const PassContext& pass_ctx) final;
-
TVM_DLL static ModulePass make(
runtime::TypedPackedFunc<Module(Module, PassContext)> pass_func,
PassInfo pass_info);
static constexpr const char* _type_key = "relay.ModulePass";
TVM_DECLARE_NODE_TYPE_INFO(ModulePassNode, PassNode);
-
- private:
- /*!
- * \brief The context information that is used to help perform a module pass.
- */
- PassContext pass_ctx_;
};
RELAY_DEFINE_NODE_REF(ModulePass, ModulePassNode, Pass);
}
/*!
- * \brief Run a function pass on a certain module.
+   * \brief Run a function pass under a given pass context.
*
- * \param mod The module that an optimization pass runs on.
+   * \param mod The module that an optimization pass is applied to.
+   * \param pass_ctx The pass context that the pass executes under.
*
* \return Return the updated module.
*/
- Module operator()(const Module& mod) const final;
+ Module operator()(const Module& mod, const PassContext& pass_ctx) const final;
/*!
* \brief Get the pass information/meta data.
*/
PassInfo Info() const { return pass_info; }
- /*!
- * \brief Set the context information for a function-level pass.
- *
- * \param pass_ctx The context information for a function-level pass.
- */
- void SetContext(const PassContext& pass_ctx) final;
-
TVM_DLL static FunctionPass make(
runtime::TypedPackedFunc<Function(Function, PassContext)> pass_func,
PassInfo pass_info);
* \return Return true if the function will be skipped, otherwise false.
*/
bool SkipFunction(const Function& func) const;
-
- /*!
- * \brief The context information that is used to help perform a module pass.
- */
- PassContext pass_ctx_;
};
RELAY_DEFINE_NODE_REF(FunctionPass, FunctionPassNode, Pass);
/*! \brief The pass meta data. */
PassInfo pass_info;
- /*! \brief A list of passes that used to compose a sequential pass. */
- tvm::Array<Pass> passes;
/*!
- * \brief A list of disabled passes that should be excluded when executing the
- * sequential pass.
+   * \brief A helper struct that maps optimization pass names to their
+   * optimization levels.
*/
- tvm::Array<tvm::Expr> disabled;
+ OptPassLevel opt_pass_level;
+ /*! \brief A list of passes that used to compose a sequential pass. */
+ tvm::Array<Pass> passes;
void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("pass_info", &pass_info);
v->Visit("passes", &passes);
- v->Visit("disabled", &disabled);
}
/*!
}
/*!
+ * \brief Check if a pass is enabled.
+ *
+ * \param pass_name The name of an optimization/analysis pass.
+ *
+ * \return true if the pass is enabled. Otherwise, false.
+ */
+ bool pass_enabled(const std::string& pass_name) const;
+
+ /*!
* \brief Resolve the pass dependency. It globs all required passes by
* a given pass and executes them.
*
*/
void ResolveDependency(const Module& mod);
- TVM_DLL std::vector<std::string> DisabledPasses() const;
+ std::unordered_set<std::string> DisabledPasses(
+ const Array<tvm::Expr>& disabled) const;
+
+  std::unordered_set<std::string> RequiredPasses(
+      const Array<tvm::Expr>& required) const;
/*!
 * \brief Perform optimizations on a series of passes. The aforementioned
 * typical pass manager jobs could be done by it. This function could be
 * overloaded to focus on different metrics, i.e. performance, memory
 * footprint, etc.
*
- * \param mod The module that an optimization pass runs on.
+ * \param mod The module that these passes are applied on.
+ * \param pass_ctx The context that these passes execute on.
*
* \return Return the updated module.
*/
- Module operator()(const Module& mod) const final;
-
- /*!
- * \brief Set the context information for a sequential pass.
- *
- * \param pass_ctx The context information for a sequential pass.
- */
- void SetContext(const PassContext& pass_ctx) final;
+ Module operator()(const Module& mod, const PassContext& pass_ctx) const final;
static constexpr const char* _type_key = "relay.Sequential";
TVM_DECLARE_NODE_TYPE_INFO(SequentialNode, PassNode);
-
- private:
- /*!
- * \brief The context information that is used to help perform a module pass.
- */
- PassContext pass_ctx_;
};
PassInfo PassInfoNode::make(int opt_level, std::string name,
return PassInfo(pass_info);
}
-PassContext PassContextNode::make() {
- auto ctx = make_node<PassContextNode>();
- return PassContext(ctx);
-}
-
ModulePass ModulePassNode::make(
runtime::TypedPackedFunc<Module(Module, PassContext)> pass_func,
PassInfo pass_info) {
}
// Module -> Module optimizations.
-// TODO(zhiics) 1. Check and handle the required passes.
-// 2. Probably use CoW for all places that use module instead of
-// returning the updated one.
-Module ModulePassNode::operator()(const Module& mod) const {
+// TODO(zhiics) Check and handle the required passes.
+Module ModulePassNode::operator()(const Module& mod,
+ const PassContext& pass_ctx) const {
PassInfo pass_info = Info();
LOG(INFO) << "Executing module pass : " << pass_info.operator->()->name
<< " with opt level: " << pass_info.operator->()->opt_level << "\n";
+
CHECK(mod.defined());
- auto updated_mod = pass_func(mod, pass_ctx_);
+ auto updated_mod = pass_func(mod, pass_ctx);
CHECK(updated_mod.defined());
return updated_mod;
}
-void ModulePassNode::SetContext(const PassContext& pass_ctx) {
- pass_ctx_ = pass_ctx;
-}
-
FunctionPass FunctionPassNode::make(
runtime::TypedPackedFunc<Function(Function, PassContext)> pass_func,
PassInfo pass_info) {
// Perform Module -> Module optimizations at the Function level.
// TODO(zhiics) Check and handle the required passes.
-Module FunctionPassNode::operator()(const Module& mod) const {
+Module FunctionPassNode::operator()(const Module& mod,
+ const PassContext& pass_ctx) const {
PassInfo pass_info = Info();
LOG(INFO) << "Executing function pass : " << pass_info.operator->()->name
<< " with opt level: " << pass_info.operator->()->opt_level << "\n";
CHECK(mod.defined());
- std::vector<std::pair<GlobalVar, Function>> updated_funcs;
- ModuleNode* mod_node = mod.operator->();
- for (const auto& it : mod_node->functions) {
- if (!SkipFunction(it.second)) {
- auto updated_func = pass_func(it.second, pass_ctx_);
- CHECK(updated_func.defined());
- updated_funcs.push_back({std::move(it.first), std::move(updated_func)});
- }
- }
+ Module new_mod = ModuleNode::make({}, mod->type_definitions);
- // Update the optimized functions.
- for (const auto& it : updated_funcs) {
- mod_node->Update(it.first, it.second);
+ // Execute the pass function and return a new module.
+ for (const auto& it : mod->functions) {
+ auto updated_func =
+ SkipFunction(it.second) ? it.second : pass_func(it.second, pass_ctx);
+ new_mod->Add(it.first, updated_func);
}
- return GetRef<Module>(mod_node);
-}
-
-void FunctionPassNode::SetContext(const PassContext& pass_ctx) {
- pass_ctx_ = pass_ctx;
+ return new_mod;
}
// TODO(zhiics) Create an enum attribute for FunctionNode
return pval && pval->value != 0;
}
-Sequential::Sequential(tvm::Array<Pass> passes,
- PassInfo pass_info,
- tvm::Array<tvm::Expr> disabled) {
+Sequential::Sequential(tvm::Array<Pass> passes, PassInfo pass_info) {
auto n = make_node<SequentialNode>();
n->passes = std::move(passes);
n->pass_info = std::move(pass_info);
- n->disabled = std::move(disabled);
node_ = std::move(n);
}
-const SequentialNode* Sequential::operator->() const {
- return static_cast<const SequentialNode*>(this->node_.get());
+Sequential::Sequential(tvm::Array<Pass> passes, std::string name) {
+ auto n = make_node<SequentialNode>();
+ n->passes = std::move(passes);
+ PassInfo pass_info = PassInfoNode::make(2, std::move(name), {});
+ n->pass_info = std::move(pass_info);
+ node_ = std::move(n);
}
-// TODO(jroesch, zhiics): we currenlty only sequentially execute each pass in
-// a Sequential without the consideration of their orders. The phase
-// ordering problem needed to be handled in the future.
-Module SequentialNode::operator()(const Module& module) const {
- Module mod = module;
- for (const Pass& pass : passes) {
- CHECK(pass.defined()) << "Found undefined pass for optimization.";
- const auto* pn = pass.operator->();
- mod = (*pn)(mod);
- }
- return mod;
+const SequentialNode* Sequential::operator->() const {
+ return static_cast<const SequentialNode*>(this->node_.get());
}
void SequentialNode::ResolveDependency(const Module& mod) {
<< "\n";
}
-std::vector<std::string> SequentialNode::DisabledPasses() const {
- std::vector<std::string> ret;
+std::unordered_set<std::string> SequentialNode::DisabledPasses(
+ const Array<tvm::Expr>& disabled) const {
+ std::unordered_set<std::string> ret;
for (const auto& it : disabled) {
const auto* str = it.as<tvm::ir::StringImm>();
CHECK(str) << "disabled passes must be string.";
- ret.push_back(str->value);
+ ret.emplace(str->value);
}
return ret;
}
-void SequentialNode::SetContext(const PassContext& pass_ctx) {
- pass_ctx_ = pass_ctx;
+std::unordered_set<std::string> SequentialNode::RequiredPasses(
+ const Array<tvm::Expr>& required) const {
+ std::unordered_set<std::string> ret;
+ for (const auto& it : required) {
+ const auto* str = it.as<tvm::ir::StringImm>();
+    CHECK(str) << "required passes must be string.";
+ ret.emplace(str->value);
+ }
+ return ret;
+}
+
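+// Enablement precedence for the check below: an explicitly disabled pass
+// always loses, an explicitly required pass always wins, and otherwise the
+// pass runs iff the context's opt_level reaches the pass's registered level.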
+bool SequentialNode::pass_enabled(const std::string& pass_name) const {
+ PassContext ctx = PassContext::Current();
+
+ const PassContextNode* ctx_node = ctx.operator->();
+ auto required = RequiredPasses(ctx_node->required_pass);
+  auto disabled = DisabledPasses(ctx_node->disabled_pass);
+
+ if (disabled.count(pass_name)) {
+ return false;
+ }
+
+ if (required.count(pass_name)) {
+ return true;
+ }
+ return ctx_node->opt_level >= opt_pass_level[pass_name];
+}
+
+// TODO(zhiics): we currently only sequentially execute each pass in a
+// Sequential without considering their order. The phase ordering problem
+// needs to be handled in the future.
+Module SequentialNode::operator()(const Module& module,
+ const PassContext& pass_ctx) const {
+ const auto* ctx_node = pass_ctx.operator->();
+ int opt_level = ctx_node->opt_level;
+ auto disabled = DisabledPasses(ctx_node->disabled_pass);
+ Module mod = module;
+ for (const Pass& pass : passes) {
+ CHECK(pass.defined()) << "Found undefined pass for optimization.";
+ PassInfo info = pass->Info();
+ const auto& pass_name = info.operator->()->name;
+ const auto& pass_opt_level = info.operator->()->opt_level;
+    // Skip the pass if its optimization level is higher than the one in the
+    // pass context, or if the pass is explicitly disabled.
+ if (pass_opt_level > opt_level || disabled.count(pass_name)) {
+ continue;
+ }
+ const auto* pn = pass.operator->();
+ mod = (*pn)(mod, pass_ctx);
+ }
+ return mod;
}
Pass CreateModulePass(
int opt_level = args[1];
std::string name = args[2];
tvm::Array<tvm::Expr> required = args[3];
- tvm::Array<tvm::Expr> disabled = args[4];
PassInfo pass_info = PassInfoNode::make(opt_level, name, required);
- *ret = Sequential(passes, pass_info, disabled);
+ *ret = Sequential(passes, pass_info);
});
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
p->stream << "]";
});
-TVM_REGISTER_API("relay._transform.SetContext")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
- Pass pass = args[0];
- PassContext pass_ctx = args[1];
- pass->SetContext(pass_ctx);
-});
-
TVM_REGISTER_NODE_TYPE(PassContextNode);
TVM_REGISTER_API("relay._transform.PassContext")
-.set_body_typed(PassContextNode::make);
+.set_body([](TVMArgs args, TVMRetValue* ret) {
+ int opt_level = args[0];
+ int fallback_device = args[1];
+ tvm::Array<tvm::Expr> required = args[2];
+ tvm::Array<tvm::Expr> disabled = args[3];
+ *ret = PassContext(opt_level, fallback_device, required, disabled);
+});
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<PassContextNode>([](const PassContextNode* node,
- tvm::IRPrinter* p) {
- p->stream << "TODO(zhiics): printing context";
- LOG(FATAL) << "PassContext printer has not been implemented yet."
- << "\n";
+ tvm::IRPrinter* p) {
+  p->stream << "Pass context information:\n";
+  p->stream << "\topt_level: " << node->opt_level << "\n";
+  p->stream << "\tfallback device: "
+            << runtime::DeviceName(node->fallback_device) << "\n";
+
+  p->stream << "\trequired passes: [";
+  for (const auto& it : node->required_pass) {
+    p->stream << it << " ";
+  }
+  p->stream << "]\n";
+
+  p->stream << "\tdisabled passes: [";
+  for (const auto& it : node->disabled_pass) {
+    p->stream << it << " ";
+  }
+  p->stream << "]";
});
+class PassContext::Internal {
+ public:
+ static void EnterScope(PassContext pass_ctx) {
+ pass_ctx.EnterWithScope();
+ }
+
+ static void ExitScope(PassContext pass_ctx) {
+ pass_ctx.ExitWithScope();
+ }
+};
+
+TVM_REGISTER_API("relay._transform.GetCurrentPassContext")
+.set_body_typed(PassContext::Current);
+
+TVM_REGISTER_API("relay._transform.EnterPassContext")
+.set_body_typed(PassContext::Internal::EnterScope);
+
+TVM_REGISTER_API("relay._transform.ExitPassContext")
+.set_body_typed(PassContext::Internal::ExitScope);
+
} // namespace transform
} // namespace relay
} // namespace tvm
def get_tvm_output(func, x, params, target, ctx,
out_shape=(1, 1000), input_name='image', dtype='float32'):
- with relay.build_module.build_config(opt_level=3):
+ with relay.transform.build_config(opt_level=3):
graph, lib, params = relay.build(func, target, params=params)
m = graph_runtime.create(graph, lib, ctx)
# set inputs
dtype_dict = {input_name: input_data.dtype}
func, params = relay.frontend.from_coreml(coreml_model, shape_dict)
- with relay.build_module.build_config(opt_level=3):
+ with relay.transform.build_config(opt_level=3):
graph, lib, params = relay.build(func, target, params=params)
from tvm.contrib import graph_runtime
def get_tvm_output(xs, target, ctx, dtype='float32'):
shape_dict = {name: x.shape for (name, x) in zip(keras_model.input_names, xs)}
func, params = relay.frontend.from_keras(keras_model, shape_dict)
- with relay.build_module.build_config(opt_level=2):
+ with relay.transform.build_config(opt_level=2):
graph, lib, params = relay.build(func, target, params=params)
m = graph_runtime.create(graph, lib, ctx)
for name, x in zip(keras_model.input_names, xs):
# target x86 CPU
target = "llvm"
-with relay.build_module.build_config(opt_level=3):
+with relay.transform.build_config(opt_level=3):
graph, lib, params = relay.build(func, target, params=params)
######################################################################