class PassContext;
/*!
- * \brief PassContextNode contains the information that a pass can rely on, such as
- * analysis results.
+ * \brief PassContextNode contains the information that a pass can rely on,
+ * such as analysis results.
*/
class PassContextNode : public RelayNode {
public:
TVM_DECLARE_NODE_TYPE_INFO(PassContextNode, RelayNode);
};
+/*!
+ * \brief PassContext that is used to configure the pass behavior.
+ *
+ * \code
+ *
+ * auto new_ctx = PassContext::Create();
+ * ctx->opt_level = 2;
+ * ctx->fallback_device = kDLCPU;
+ * With<PassContext> scope(ctx);
+ * // pass context in effect.
+ *
+ * \endcode
+ */
class PassContext : public NodeRef {
public:
PassContext() {}
- explicit PassContext(tvm::NodePtr<Node> n) : NodeRef(n) {}
-
- /*
- * \brief Constructor of a `PassContext` object.
- *
- * \param opt_level The optimization level that will be applied.
- * \param fallback_device The fallback device used for heterogeneous
- * execution.
- * \param required_pass The passes that are required for a context to execute
- * other passes.
- * \param required_pass The passes that will be disabled during the
- * optimization under a context.
+ explicit PassContext(NodePtr<::tvm::Node> n) : NodeRef(n) {}
+ /*!
+ * \brief const accessor.
+ * \return const access pointer.
+ */
+ const PassContextNode* operator->() const {
+ CHECK(node_.get() != nullptr);
+ return static_cast<const PassContextNode*>(node_.get());
+ }
+ /*!
+ * \brief mutable accessor.
+ * \return mutable access pointer.
+ */
+ PassContextNode* operator->() {
+ CHECK(node_.get() != nullptr);
+ return static_cast<PassContextNode*>(node_.get());
+ }
+ /*!
+ * \brief Construct a PassContext containing the default configurations.
+ * \return The new PassContext.
+ */
+ TVM_DLL static PassContext Create();
+ /*!
+ * \brief Get the default pass context in the current scope.
+ * \return The pass context.
*/
- TVM_DLL PassContext(int opt_level,
- int fallback_device,
- tvm::Array<tvm::Expr> required_pass,
- tvm::Array<tvm::Expr> disabled_pass);
-
- // Get the currently used pass context.
TVM_DLL static PassContext Current();
- const PassContextNode* operator->() const;
-
+ // accessor.
using ContainerType = PassContextNode;
class Internal;
virtual PassInfo Info() const = 0;
/*!
- * \brief Execute the optimization pass using a functor. This functor
- * internally uses a current pass context.
+ * \brief Transform mod using the default PassContext in the current scope.
*
* \param mod The module that an optimization pass runs on.
*
- * \return The updated module.
+ * \return The transformed module.
*/
Module operator()(const Module& mod) const {
return this->operator()(mod, PassContext::Current());
}
/*!
- * \brief Execute the optimization pass using a functor under a given pass context.
+ * \brief Transform mod using a functor under a given pass context.
*
* \param mod The module that an optimization pass runs on.
- * \param pass_ctx The pass context that will be used to help the execution of
- * optimizations.
+ * \param pass_ctx The pass context that can provide information for the optimization.
*
- * \return The updated module.
+ * \return The transformed module.
*/
virtual Module operator()(const Module& mod,
const PassContext& pass_ctx) const = 0;
class Pass : public NodeRef {
public:
- Pass() = default;
- explicit Pass(NodePtr<tvm::Node> p) : NodeRef(p) {}
-
- PassNode* operator->() const {
- return static_cast<PassNode*>(this->node_.get());
+ /*!
+ * \brief Transform mod using the default PassContext in the current scope.
+ *
+ * \param mod The module that an optimization pass runs on.
+ *
+ * \return The transformed module.
+ */
+ Module operator()(const Module& mod) const {
+ const PassNode* node = operator->();
+ CHECK(node != nullptr);
+ return node->operator()(mod);
+ }
+ /*!
+ * \brief Transform mod using a functor under a given pass context.
+ *
+ * \param mod The module that an optimization pass runs on.
+ * \param pass_ctx The pass context that can provide information for the optimization.
+ *
+ * \return The transformed module.
+ */
+ Module operator()(const Module& mod,
+ const PassContext& pass_ctx) const {
+ const PassNode* node = operator->();
+ CHECK(node != nullptr);
+ return node->operator()(mod, pass_ctx);
}
- using ContainerType = PassNode;
+ TVM_DEFINE_NODE_REF_METHODS(Pass, NodeRef, PassNode);
};
class SequentialNode;
}
};
-PassContext::PassContext(int opt_level, int fallback_device,
- tvm::Array<tvm::Expr> required_pass,
- tvm::Array<tvm::Expr> disabled_pass) {
- auto ctx = make_node<PassContextNode>();
- ctx->opt_level = opt_level;
- ctx->fallback_device = fallback_device;
- ctx->required_pass = std::move(required_pass);
- ctx->disabled_pass = std::move(disabled_pass);
- node_ = std::move(ctx);
-}
-
-const PassContextNode* PassContext::operator->() const {
- return static_cast<const PassContextNode*>(node_.get());
-}
-
struct RelayPassContextThreadLocalEntry {
/*! \brief The default pass context. */
PassContext default_context;
}
}
+PassContext PassContext::Create() {
+ return PassContext(make_node<PassContextNode>());
+}
+
class ModulePass;
/*!
*
* \return true if the pass is enabled. Otherwise, false.
*/
- bool pass_enabled(const std::string& pass_name) const;
+ bool PassEnabled(const std::string& pass_name) const;
/*!
* \brief Resolve the pass dependency. It globs all required passes by
Module ModulePassNode::operator()(const Module& mod,
const PassContext& pass_ctx) const {
PassInfo pass_info = Info();
- LOG(INFO) << "Executing module pass : " << pass_info.operator->()->name
- << " with opt level: " << pass_info.operator->()->opt_level << "\n";
-
+ DLOG(INFO) << "Executing module pass : " << pass_info->name
+ << " with opt level: " << pass_info->opt_level << "\n";
CHECK(mod.defined());
auto updated_mod = pass_func(mod, pass_ctx);
CHECK(updated_mod.defined());
Module FunctionPassNode::operator()(const Module& mod,
const PassContext& pass_ctx) const {
PassInfo pass_info = Info();
- LOG(INFO) << "Executing function pass : " << pass_info.operator->()->name
- << " with opt level: " << pass_info.operator->()->opt_level << "\n";
CHECK(mod.defined());
Module new_mod = ModuleNode::make({}, mod->type_definitions);
-
+ DLOG(INFO) << "Executing module pass : " << pass_info->name
+ << " with opt level: " << pass_info->opt_level << "\n";
// Execute the pass function and return a new module.
for (const auto& it : mod->functions) {
auto updated_func = SkipFunction(it.second) ? it.second : pass_func(it.second, mod, pass_ctx);
return ret;
}
-bool SequentialNode::pass_enabled(const std::string& pass_name) const {
+bool SequentialNode::PassEnabled(const std::string& pass_name) const {
PassContext ctx = PassContext::Current();
- const PassContextNode* ctx_node = ctx.operator->();
- auto required = RequiredPasses(ctx_node->required_pass);
- auto disabled = DisabledPasses(ctx_node->required_pass);
+ auto required = RequiredPasses(ctx->required_pass);
+ auto disabled = DisabledPasses(ctx->required_pass);
if (disabled.count(pass_name)) {
return false;
if (required.count(pass_name)) {
return true;
}
- return ctx_node->opt_level >= opt_pass_level[pass_name];
+ return ctx->opt_level >= opt_pass_level[pass_name];
}
// TODO(zhiics): we currenlty only sequentially execute each pass in
// ordering problem needed to be handled in the future.
Module SequentialNode::operator()(const Module& module,
const PassContext& pass_ctx) const {
- const auto* ctx_node = pass_ctx.operator->();
- int opt_level = ctx_node->opt_level;
- auto disabled = DisabledPasses(ctx_node->disabled_pass);
+ int opt_level = pass_ctx->opt_level;
+ auto disabled = DisabledPasses(pass_ctx->disabled_pass);
Module mod = module;
for (const Pass& pass : passes) {
CHECK(pass.defined()) << "Found undefined pass for optimization.";
PassInfo info = pass->Info();
- const auto& pass_name = info.operator->()->name;
- const auto& pass_opt_level = info.operator->()->opt_level;
+ const auto& pass_name = info->name;
+ const auto& pass_opt_level = info->opt_level;
// Skip the pass if its optimization level is higher that the one of in the
// pass context or if this pass is disabled.
if (pass_opt_level > opt_level || disabled.count(pass_name)) {
TVM_REGISTER_API("relay._transform.RunPass")
.set_body([](TVMArgs args, TVMRetValue* ret) {
- Pass pass = args[0];
- Module mod = args[1];
- CHECK(pass.defined())
- << "Running an undefined pass is not allowed."
- << "\n";
-
- const auto* pn = pass.operator->();
- *ret = (*pn)(mod);
+ *ret = args[0].operator Pass()(args[1]);
});
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
TVM_REGISTER_API("relay._transform.PassContext")
.set_body([](TVMArgs args, TVMRetValue* ret) {
+ auto pctx = PassContext::Create();
int opt_level = args[0];
int fallback_device = args[1];
tvm::Array<tvm::Expr> required = args[2];
tvm::Array<tvm::Expr> disabled = args[3];
- *ret = PassContext(opt_level, fallback_device, required, disabled);
+ pctx->opt_level = opt_level;
+ pctx->fallback_device = fallback_device;
+ pctx->required_pass = std::move(required);
+ pctx->disabled_pass = std::move(disabled);
+ *ret = pctx;
});
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)