using tvm::IRPrinter;
-namespace {
-
-// TODO(zhiics) Maybe we can use PackedFunc here so that parameters can be
-// handled because we need to register the pass for Python invocation anyway.
-Pass GetPass(const std::string& pass_name) {
- if (pass_name == "InferType") {
- return InferType();
- } else if (pass_name == "AlterOpLayout") {
- return AlterOpLayout();
- } else if (pass_name == "CanonicalizeOps") {
- return CanonicalizeOps();
- } else if (pass_name == "CombineParallelConv2d") {
- return CombineParallelConv2D();
- } else if (pass_name == "DeadCodeElimination") {
- return DeadCodeElimination();
- } else if (pass_name == "EliminateCommonSubexpr") {
- return DeadCodeElimination();
- } else if (pass_name == "FoldConstant") {
- return FoldConstant();
- } else if (pass_name == "BackwardFoldScaleAxis") {
- return FoldScaleAxis();
- } else if (pass_name == "ForwardFoldScaleAxis") {
- return FoldScaleAxis();
- } else if (pass_name == "FoldScaleAxis") {
- return FoldScaleAxis();
- } else if (pass_name == "PartialEvaluate") {
- return SimplifyInference();
- } else if (pass_name == "SimplifyInference") {
- return SimplifyInference();
- } else if (pass_name == "ToANormalForm") {
- return ToANormalForm();
- } else if (pass_name == "ToGraphNormalForm") {
- return ToGraphNormalForm();
- } else {
- LOG(FATAL) << pass_name << " has not been registered yet." << "\n";
- return Pass(nullptr);
- }
-}
-
-} // namespace
-
struct RelayPassContextThreadLocalEntry {
/*! \brief The default pass context. */
PassContext default_context;
/*! \brief A list of passes that used to compose a sequential pass. */
tvm::Array<Pass> passes;
+
void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("pass_info", &pass_info);
v->Visit("passes", &passes);
PassInfo Info() const { return pass_info; }
/*!
- * \brief Add a pass to the pass list.
- *
- * \param pass The candidate pass to be added.
- */
- void AddPass(const Pass& pass) {
- passes.push_back(pass);
- }
-
- /*!
* \brief Check if a pass is enabled.
*
- * \param pass_name The name of an optimization/analysis pass.
+ * \param info The pass information.
*
* \return true if the pass is enabled. Otherwise, false.
*/
- bool PassEnabled(const std::string& pass_name) const;
+ bool PassEnabled(const PassInfo& info) const;
/*!
* \brief Resolve the pass dependency. It globs all required passes by
*/
void ResolveDependency(const Module& mod);
- std::unordered_set<std::string> DisabledPasses(
- const Array<tvm::Expr>& disabled) const;
-
- std::unordered_set<std::string> RequiredPasses(
- const Array<tvm::Expr>& required) const;
-
/*!
* \brief Perform optimizations on a series of passes. The aforementioned
* typical pass manager jobs could be done by it. This function could
TVM_DECLARE_NODE_TYPE_INFO(SequentialNode, PassNode);
};
-PassInfo PassInfoNode::make(int opt_level, std::string name,
+PassInfo PassInfoNode::make(int opt_level,
+ std::string name,
tvm::Array<tvm::Expr> required) {
auto pass_info = make_node<PassInfoNode>();
pass_info->opt_level = opt_level;
// Module -> Module optimizations.
Module ModulePassNode::operator()(const Module& mod,
const PassContext& pass_ctx) const {
- PassInfo pass_info = Info();
- DLOG(INFO) << "Executing module pass : " << pass_info->name
- << " with opt level: " << pass_info->opt_level << "\n";
-
+ const PassInfo& pass_info = Info();
+ DLOG(INFO) << "Executing module pass : "
+ << pass_info->name
+ << " with opt level: "
+ << pass_info->opt_level;
CHECK(mod.defined());
- Module updated_mod = mod;
- // Execute the required passes in a DFS way.
- // TODO(zhiics) We may need to pass validation to detect the cyclic
- // dependency.
- for (const auto& it : pass_info->required) {
- const auto* name = it.as<tvm::ir::StringImm>();
- CHECK(name);
- auto pass = GetPass(name->value);
- updated_mod = pass(updated_mod, pass_ctx);
- }
-
- updated_mod = pass_func(updated_mod, pass_ctx);
+ Module updated_mod = pass_func(mod, pass_ctx);
CHECK(updated_mod.defined());
return updated_mod;
}
}
// Perform Module -> Module optimizations at the Function level.
-// TODO(zhiics) Check and handle the required passes.
Module FunctionPassNode::operator()(const Module& mod,
const PassContext& pass_ctx) const {
- PassInfo pass_info = Info();
+ const PassInfo& pass_info = Info();
CHECK(mod.defined());
- DLOG(INFO) << "Executing module pass : " << pass_info->name
- << " with opt level: " << pass_info->opt_level << "\n";
-
+ DLOG(INFO) << "Executing module pass : "
+ << pass_info->name
+ << " with opt level: "
+ << pass_info->opt_level;
Module updated_mod = mod;
- // Execute the required passes in a DFS way.
- // TODO(zhiics) We may need to pass validation to detect the cyclic
- // dependency.
- for (const auto& it : pass_info->required) {
- const auto* name = it.as<tvm::ir::StringImm>();
- CHECK(name);
- auto pass = GetPass(name->value);
- updated_mod = pass(updated_mod, pass_ctx);
- }
-
Module new_mod = ModuleNode::make({}, mod->type_definitions);
// Execute the pass function and return a new module.
for (const auto& it : mod->functions) {
: pass_func(it.second, updated_mod, pass_ctx);
new_mod->Add(it.first, updated_func);
}
-
return new_mod;
}
<< "\n";
}
-std::unordered_set<std::string> SequentialNode::DisabledPasses(
- const Array<tvm::Expr>& disabled) const {
- std::unordered_set<std::string> ret;
- for (const auto& it : disabled) {
- const auto* str = it.as<tvm::ir::StringImm>();
- CHECK(str) << "Disabled pass name must be string.";
- ret.emplace(str->value);
- }
- return ret;
-}
-
-std::unordered_set<std::string> SequentialNode::RequiredPasses(
- const Array<tvm::Expr>& required) const {
- std::unordered_set<std::string> ret;
- for (const auto& it : required) {
- const auto* str = it.as<tvm::ir::StringImm>();
- CHECK(str) << "Required pass name must be string.";
- ret.emplace(str->value);
+// linearly scan the pass array to match pass_name
+inline bool PassArrayContains(const Array<tvm::Expr>& pass_array,
+ const std::string& pass_name) {
+ for (auto x : pass_array) {
+ auto* str_name = x.as<ir::StringImm>();
+ CHECK(str_name) << "pass name must be str";
+ if (str_name->value == pass_name) return true;
}
- return ret;
+ return false;
}
-bool SequentialNode::PassEnabled(const std::string& pass_name) const {
+bool SequentialNode::PassEnabled(const PassInfo& info) const {
PassContext ctx = PassContext::Current();
- auto required = RequiredPasses(ctx->required_pass);
- auto disabled = DisabledPasses(ctx->disabled_pass);
-
- if (disabled.count(pass_name)) {
+ if (PassArrayContains(ctx->disabled_pass, info->name)) {
return false;
}
- if (required.count(pass_name)) {
+ if (PassArrayContains(ctx->required_pass, info->name)) {
return true;
}
- const Pass pass = GetPass(pass_name);
- PassInfo info = pass->Info();
return ctx->opt_level >= info->opt_level;
}
+Pass GetPass(const std::string& pass_name) {
+ using tvm::runtime::Registry;
+ std::string fpass_name = "relay._transform." + pass_name;
+ const auto* f = Registry::Get(fpass_name);
+ CHECK(f != nullptr) << "Cannot find " << fpass_name
+ << "to create the pass " << pass_name;
+ return (*f)();
+}
+
// TODO(zhiics): we currenlty only sequentially execute each pass in
// a Sequential without the consideration of their orders. The phase
// ordering problem needs to be handled in the future.
Module mod = module;
for (const Pass& pass : passes) {
CHECK(pass.defined()) << "Found undefined pass for optimization.";
-
- PassInfo info = pass->Info();
- const auto& pass_name = info->name;
- // Execute the pass if it is enabled.
- if (PassEnabled(pass_name)) {
- mod = pass(mod, pass_ctx);
+ const PassInfo& pass_info = pass->Info();
+ if (!PassEnabled(pass_info)) continue;
+ // resolve dependencies
+ for (const auto& it : pass_info->required) {
+ const auto* name = it.as<tvm::ir::StringImm>();
+ CHECK(name);
+ mod = GetPass(name->value)(mod, pass_ctx);
}
+ mod = pass(mod, pass_ctx);
}
return mod;
}