[PASS][RELAY] polish pass infra (#3319)
authorTianqi Chen <tqchen@users.noreply.github.com>
Fri, 7 Jun 2019 21:38:57 +0000 (14:38 -0700)
committerGitHub <noreply@github.com>
Fri, 7 Jun 2019 21:38:57 +0000 (14:38 -0700)
3rdparty/dmlc-core
src/relay/pass/pass_manager.cc

index 3943914..fbe142b 160000 (submodule)
@@ -1 +1 @@
-Subproject commit 3943914eed66470bd010df581e29e4dca4f7df6f
+Subproject commit fbe142b267a8edd1f1188fa2140d88f7ae308661
index 13e908d..05eb43d 100644 (file)
@@ -37,47 +37,6 @@ namespace transform {
 
 using tvm::IRPrinter;
 
-namespace {
-
-// TODO(zhiics) Maybe we can use PackedFunc here so that parameters can be
-// handled because we need to register the pass for Python invocation anyway.
-Pass GetPass(const std::string& pass_name) {
-  if (pass_name == "InferType") {
-    return InferType();
-  } else if (pass_name == "AlterOpLayout") {
-    return AlterOpLayout();
-  } else if (pass_name == "CanonicalizeOps") {
-    return CanonicalizeOps();
-  } else if (pass_name == "CombineParallelConv2d") {
-    return CombineParallelConv2D();
-  } else if (pass_name == "DeadCodeElimination") {
-    return DeadCodeElimination();
-  } else if (pass_name == "EliminateCommonSubexpr") {
-    return DeadCodeElimination();
-  } else if (pass_name == "FoldConstant") {
-    return FoldConstant();
-  } else if (pass_name == "BackwardFoldScaleAxis") {
-    return FoldScaleAxis();
-  } else if (pass_name == "ForwardFoldScaleAxis") {
-    return FoldScaleAxis();
-  } else if (pass_name == "FoldScaleAxis") {
-    return FoldScaleAxis();
-  } else if (pass_name == "PartialEvaluate") {
-    return SimplifyInference();
-  } else if (pass_name == "SimplifyInference") {
-    return SimplifyInference();
-  } else if (pass_name == "ToANormalForm") {
-    return ToANormalForm();
-  } else if (pass_name == "ToGraphNormalForm") {
-    return ToGraphNormalForm();
-  } else {
-    LOG(FATAL) << pass_name << " has not been registered yet." << "\n";
-    return Pass(nullptr);
-  }
-}
-
-}  // namespace
-
 struct RelayPassContextThreadLocalEntry {
   /*! \brief The default pass context. */
   PassContext default_context;
@@ -252,6 +211,7 @@ class SequentialNode : public PassNode {
 
   /*! \brief A list of passes that used to compose a sequential pass. */
   tvm::Array<Pass> passes;
+
   void VisitAttrs(tvm::AttrVisitor* v) final {
     v->Visit("pass_info", &pass_info);
     v->Visit("passes", &passes);
@@ -263,22 +223,13 @@ class SequentialNode : public PassNode {
   PassInfo Info() const { return pass_info; }
 
   /*!
-   * \brief Add a pass to the pass list.
-   *
-   * \param pass The candidate pass to be added.
-   */
-  void AddPass(const Pass& pass) {
-    passes.push_back(pass);
-  }
-
-  /*!
    * \brief Check if a pass is enabled.
    *
-   * \param pass_name The name of an optimization/analysis pass.
+   * \param info The pass information.
    *
    * \return true if the pass is enabled. Otherwise, false.
    */
-  bool PassEnabled(const std::string& pass_name) const;
+  bool PassEnabled(const PassInfo& info) const;
 
   /*!
    * \brief Resolve the pass dependency. It globs all required passes by
@@ -294,12 +245,6 @@ class SequentialNode : public PassNode {
    */
   void ResolveDependency(const Module& mod);
 
-  std::unordered_set<std::string> DisabledPasses(
-      const Array<tvm::Expr>& disabled) const;
-
-  std::unordered_set<std::string> RequiredPasses(
-      const Array<tvm::Expr>& required) const;
-
   /*!
    * \brief Perform optimizations on a series of passes. The aforementioned
    *        typical pass manager jobs could be done by it. This function could
@@ -317,7 +262,8 @@ class SequentialNode : public PassNode {
   TVM_DECLARE_NODE_TYPE_INFO(SequentialNode, PassNode);
 };
 
-PassInfo PassInfoNode::make(int opt_level, std::string name,
+PassInfo PassInfoNode::make(int opt_level,
+                            std::string name,
                             tvm::Array<tvm::Expr> required) {
   auto pass_info = make_node<PassInfoNode>();
   pass_info->opt_level = opt_level;
@@ -338,23 +284,13 @@ ModulePass ModulePassNode::make(
 // Module -> Module optimizations.
 Module ModulePassNode::operator()(const Module& mod,
                                   const PassContext& pass_ctx) const {
-  PassInfo pass_info = Info();
-  DLOG(INFO) << "Executing module pass : " << pass_info->name
-             << " with opt level: " << pass_info->opt_level << "\n";
-
+  const PassInfo& pass_info = Info();
+  DLOG(INFO) << "Executing module pass : "
+             << pass_info->name
+             << " with opt level: "
+             << pass_info->opt_level;
   CHECK(mod.defined());
-  Module updated_mod = mod;
-  // Execute the required passes in a DFS way.
-  // TODO(zhiics) We may need to pass validation to detect the cyclic
-  // dependency.
-  for (const auto& it : pass_info->required) {
-    const auto* name = it.as<tvm::ir::StringImm>();
-    CHECK(name);
-    auto pass = GetPass(name->value);
-    updated_mod = pass(updated_mod, pass_ctx);
-  }
-
-  updated_mod = pass_func(updated_mod, pass_ctx);
+  Module updated_mod = pass_func(mod, pass_ctx);
   CHECK(updated_mod.defined());
   return updated_mod;
 }
@@ -369,25 +305,15 @@ FunctionPass FunctionPassNode::make(
 }
 
 // Perform Module -> Module optimizations at the Function level.
-// TODO(zhiics) Check and handle the required passes.
 Module FunctionPassNode::operator()(const Module& mod,
                                     const PassContext& pass_ctx) const {
-  PassInfo pass_info = Info();
+  const PassInfo& pass_info = Info();
   CHECK(mod.defined());
-  DLOG(INFO) << "Executing module pass : " << pass_info->name
-             << " with opt level: " << pass_info->opt_level << "\n";
-
+  DLOG(INFO) << "Executing module pass : "
+             << pass_info->name
+             << " with opt level: "
+             << pass_info->opt_level;
   Module updated_mod = mod;
-  // Execute the required passes in a DFS way.
-  // TODO(zhiics) We may need to pass validation to detect the cyclic
-  // dependency.
-  for (const auto& it : pass_info->required) {
-    const auto* name = it.as<tvm::ir::StringImm>();
-    CHECK(name);
-    auto pass = GetPass(name->value);
-    updated_mod = pass(updated_mod, pass_ctx);
-  }
-
   Module new_mod = ModuleNode::make({}, mod->type_definitions);
   // Execute the pass function and return a new module.
   for (const auto& it : mod->functions) {
@@ -396,7 +322,6 @@ Module FunctionPassNode::operator()(const Module& mod,
                             : pass_func(it.second, updated_mod, pass_ctx);
     new_mod->Add(it.first, updated_func);
   }
-
   return new_mod;
 }
 
@@ -436,47 +361,40 @@ void SequentialNode::ResolveDependency(const Module& mod) {
              << "\n";
 }
 
-std::unordered_set<std::string> SequentialNode::DisabledPasses(
-    const Array<tvm::Expr>& disabled) const {
-  std::unordered_set<std::string> ret;
-  for (const auto& it : disabled) {
-    const auto* str = it.as<tvm::ir::StringImm>();
-    CHECK(str) << "Disabled pass name must be string.";
-    ret.emplace(str->value);
-  }
-  return ret;
-}
-
-std::unordered_set<std::string> SequentialNode::RequiredPasses(
-    const Array<tvm::Expr>& required) const {
-  std::unordered_set<std::string> ret;
-  for (const auto& it : required) {
-    const auto* str = it.as<tvm::ir::StringImm>();
-    CHECK(str) << "Required pass name must be string.";
-    ret.emplace(str->value);
+// linearly scan the pass array to match pass_name
+inline bool PassArrayContains(const Array<tvm::Expr>& pass_array,
+                              const std::string& pass_name) {
+  for (auto x : pass_array) {
+    auto* str_name = x.as<ir::StringImm>();
+    CHECK(str_name) << "pass name must be str";
+    if (str_name->value == pass_name) return true;
   }
-  return ret;
+  return false;
 }
 
-bool SequentialNode::PassEnabled(const std::string& pass_name) const {
+bool SequentialNode::PassEnabled(const PassInfo& info) const {
   PassContext ctx = PassContext::Current();
 
-  auto required = RequiredPasses(ctx->required_pass);
-  auto disabled = DisabledPasses(ctx->disabled_pass);
-
-  if (disabled.count(pass_name)) {
+  if (PassArrayContains(ctx->disabled_pass, info->name)) {
     return false;
   }
 
-  if (required.count(pass_name)) {
+  if (PassArrayContains(ctx->required_pass, info->name)) {
     return true;
   }
 
-  const Pass pass = GetPass(pass_name);
-  PassInfo info = pass->Info();
   return ctx->opt_level >= info->opt_level;
 }
 
+Pass GetPass(const std::string& pass_name) {
+  using tvm::runtime::Registry;
+  std::string fpass_name = "relay._transform." + pass_name;
+  const auto* f = Registry::Get(fpass_name);
+  CHECK(f != nullptr) << "Cannot find " << fpass_name
+                      << "to create the pass " << pass_name;
+  return (*f)();
+}
+
 // TODO(zhiics): we currenlty only sequentially execute each pass in
 // a Sequential without the consideration of their orders. The phase
 // ordering problem needs to be handled in the future.
@@ -485,13 +403,15 @@ Module SequentialNode::operator()(const Module& module,
   Module mod = module;
   for (const Pass& pass : passes) {
     CHECK(pass.defined()) << "Found undefined pass for optimization.";
-
-    PassInfo info = pass->Info();
-    const auto& pass_name = info->name;
-    // Execute the pass if it is enabled.
-    if (PassEnabled(pass_name)) {
-      mod = pass(mod, pass_ctx);
+    const PassInfo& pass_info = pass->Info();
+    if (!PassEnabled(pass_info))  continue;
+    // resolve dependencies
+    for (const auto& it : pass_info->required) {
+      const auto* name = it.as<tvm::ir::StringImm>();
+      CHECK(name);
+      mod = GetPass(name->value)(mod, pass_ctx);
     }
+    mod = pass(mod, pass_ctx);
   }
   return mod;
 }