[C++] Cleanup transform API nits (#3253)
authorTianqi Chen <tqchen@users.noreply.github.com>
Wed, 29 May 2019 01:12:17 +0000 (18:12 -0700)
committerGitHub <noreply@github.com>
Wed, 29 May 2019 01:12:17 +0000 (18:12 -0700)
include/tvm/relay/transform.h
src/relay/pass/pass_manager.cc

index 4d6921a..1c1b608 100644 (file)
@@ -76,8 +76,8 @@ namespace transform {
 class PassContext;
 
 /*!
- * \brief PassContextNode contains the information that a pass can rely on, such as
- * analysis results.
+ * \brief PassContextNode contains the information that a pass can rely on,
+ * such as analysis results.
  */
 class PassContextNode : public RelayNode {
  public:
@@ -110,32 +110,51 @@ class PassContextNode : public RelayNode {
   TVM_DECLARE_NODE_TYPE_INFO(PassContextNode, RelayNode);
 };
 
+/*!
+ * \brief PassContext that is used to configure the pass behavior.
+ *
+ * \code
+ *
+ *  auto new_ctx = PassContext::Create();
+ *  ctx->opt_level = 2;
+ *  ctx->fallback_device = kDLCPU;
+ *  With<PassContext> scope(ctx);
+ *  // pass context in effect.
+ *
+ * \endcode
+ */
 class PassContext : public NodeRef {
  public:
   PassContext() {}
-  explicit PassContext(tvm::NodePtr<Node> n) : NodeRef(n) {}
-
-  /*
-   * \brief Constructor of a `PassContext` object.
-   *
-   * \param opt_level The optimization level that will be applied.
-   * \param fallback_device The fallback device used for heterogeneous
-   *        execution.
-   * \param required_pass The passes that are required for a context to execute
-   *        other passes.
-   * \param required_pass The passes that will be disabled during the
-   *        optimization under a context.
+  explicit PassContext(NodePtr<::tvm::Node> n) : NodeRef(n) {}
+  /*!
+   * \brief const accessor.
+   * \return const access pointer.
+   */
+  const PassContextNode* operator->() const {
+    CHECK(node_.get() != nullptr);
+    return static_cast<const PassContextNode*>(node_.get());
+  }
+  /*!
+   * \brief mutable accessor.
+   * \return mutable access pointer.
+   */
+  PassContextNode* operator->() {
+    CHECK(node_.get() != nullptr);
+    return static_cast<PassContextNode*>(node_.get());
+  }
+  /*!
+   * \brief Construct a PassContext containing the default configurations.
+   * \return The new PassContext.
+   */
+  TVM_DLL static PassContext Create();
+  /*!
+   * \brief Get the default pass context in the current scope.
+   * \return The pass context.
    */
-  TVM_DLL PassContext(int opt_level,
-                      int fallback_device,
-                      tvm::Array<tvm::Expr> required_pass,
-                      tvm::Array<tvm::Expr> disabled_pass);
-
-  // Get the currently used pass context.
   TVM_DLL static PassContext Current();
 
-  const PassContextNode* operator->() const;
-
+  // accessor.
   using ContainerType = PassContextNode;
   class Internal;
 
@@ -204,25 +223,23 @@ class PassNode : public RelayNode {
   virtual PassInfo Info() const = 0;
 
   /*!
-   * \brief Execute the optimization pass using a functor. This functor
-   * internally uses a current pass context.
+   * \brief Transform mod using the default PassContext in the current scope.
    *
    * \param mod The module that an optimization pass runs on.
    *
-   * \return The updated module.
+   * \return The transformed module.
    */
   Module operator()(const Module& mod) const {
     return this->operator()(mod, PassContext::Current());
   }
 
   /*!
-   * \brief Execute the optimization pass using a functor under a given pass context.
+   * \brief Transform mod using a functor under a given pass context.
    *
    * \param mod The module that an optimization pass runs on.
-   * \param pass_ctx The pass context that will be used to help the execution of
-   *        optimizations.
+   * \param pass_ctx The pass context that can provide information for the optimization.
    *
-   * \return The updated module.
+   * \return The transformed module.
    */
   virtual Module operator()(const Module& mod,
                             const PassContext& pass_ctx) const = 0;
@@ -235,14 +252,34 @@ class PassNode : public RelayNode {
 
 class Pass : public NodeRef {
  public:
-  Pass() = default;
-  explicit Pass(NodePtr<tvm::Node> p) : NodeRef(p) {}
-
-  PassNode* operator->() const {
-    return static_cast<PassNode*>(this->node_.get());
+  /*!
+   * \brief Transform mod using the default PassContext in the current scope.
+   *
+   * \param mod The module that an optimization pass runs on.
+   *
+   * \return The transformed module.
+   */
+  Module operator()(const Module& mod) const {
+    const PassNode* node = operator->();
+    CHECK(node != nullptr);
+    return node->operator()(mod);
+  }
+  /*!
+   * \brief Transform mod using a functor under a given pass context.
+   *
+   * \param mod The module that an optimization pass runs on.
+   * \param pass_ctx The pass context that can provide information for the optimization.
+   *
+   * \return The transformed module.
+   */
+  Module operator()(const Module& mod,
+                    const PassContext& pass_ctx) const {
+    const PassNode* node = operator->();
+    CHECK(node != nullptr);
+    return node->operator()(mod, pass_ctx);
   }
 
-  using ContainerType = PassNode;
+  TVM_DEFINE_NODE_REF_METHODS(Pass, NodeRef, PassNode);
 };
 
 class SequentialNode;
index ea4c976..a9c671a 100644 (file)
@@ -74,21 +74,6 @@ class OptPassLevel {
   }
 };
 
-PassContext::PassContext(int opt_level, int fallback_device,
-                         tvm::Array<tvm::Expr> required_pass,
-                         tvm::Array<tvm::Expr> disabled_pass) {
-  auto ctx = make_node<PassContextNode>();
-  ctx->opt_level = opt_level;
-  ctx->fallback_device = fallback_device;
-  ctx->required_pass = std::move(required_pass);
-  ctx->disabled_pass = std::move(disabled_pass);
-  node_ = std::move(ctx);
-}
-
-const PassContextNode* PassContext::operator->() const {
-  return static_cast<const PassContextNode*>(node_.get());
-}
-
 struct RelayPassContextThreadLocalEntry {
   /*! \brief The default pass context. */
   PassContext default_context;
@@ -129,6 +114,10 @@ PassContext PassContext::Current() {
   }
 }
 
+PassContext PassContext::Create() {
+  return PassContext(make_node<PassContextNode>());
+}
+
 class ModulePass;
 
 /*!
@@ -291,7 +280,7 @@ class SequentialNode : public PassNode {
    *
    * \return true if the pass is enabled. Otherwise, false.
    */
-  bool pass_enabled(const std::string& pass_name) const;
+  bool PassEnabled(const std::string& pass_name) const;
 
   /*!
    * \brief Resolve the pass dependency. It globs all required passes by
@@ -353,9 +342,8 @@ ModulePass ModulePassNode::make(
 Module ModulePassNode::operator()(const Module& mod,
                                   const PassContext& pass_ctx) const {
   PassInfo pass_info = Info();
-  LOG(INFO) << "Executing module pass : " << pass_info.operator->()->name
-            << " with opt level: " << pass_info.operator->()->opt_level << "\n";
-
+  DLOG(INFO) << "Executing module pass : " << pass_info->name
+             << " with opt level: " << pass_info->opt_level << "\n";
   CHECK(mod.defined());
   auto updated_mod = pass_func(mod, pass_ctx);
   CHECK(updated_mod.defined());
@@ -376,11 +364,10 @@ FunctionPass FunctionPassNode::make(
 Module FunctionPassNode::operator()(const Module& mod,
                                     const PassContext& pass_ctx) const {
   PassInfo pass_info = Info();
-  LOG(INFO) << "Executing function pass : " << pass_info.operator->()->name
-            << " with opt level: " << pass_info.operator->()->opt_level << "\n";
   CHECK(mod.defined());
   Module new_mod = ModuleNode::make({}, mod->type_definitions);
-
+  DLOG(INFO) << "Executing module pass : " << pass_info->name
+             << " with opt level: " << pass_info->opt_level << "\n";
   // Execute the pass function and return a new module.
   for (const auto& it : mod->functions) {
     auto updated_func = SkipFunction(it.second) ? it.second : pass_func(it.second, mod, pass_ctx);
@@ -448,12 +435,11 @@ std::unordered_set<std::string> SequentialNode::RequiredPasses(
   return ret;
 }
 
-bool SequentialNode::pass_enabled(const std::string& pass_name) const {
+bool SequentialNode::PassEnabled(const std::string& pass_name) const {
   PassContext ctx = PassContext::Current();
 
-  const PassContextNode* ctx_node = ctx.operator->();
-  auto required = RequiredPasses(ctx_node->required_pass);
-  auto disabled = DisabledPasses(ctx_node->required_pass);
+  auto required = RequiredPasses(ctx->required_pass);
+  auto disabled = DisabledPasses(ctx->required_pass);
 
   if (disabled.count(pass_name)) {
     return false;
@@ -462,7 +448,7 @@ bool SequentialNode::pass_enabled(const std::string& pass_name) const {
   if (required.count(pass_name)) {
     return true;
   }
-  return ctx_node->opt_level >= opt_pass_level[pass_name];
+  return ctx->opt_level >= opt_pass_level[pass_name];
 }
 
 // TODO(zhiics): we currenlty only sequentially execute each pass in
@@ -470,15 +456,14 @@ bool SequentialNode::pass_enabled(const std::string& pass_name) const {
 // ordering problem needed to be handled in the future.
 Module SequentialNode::operator()(const Module& module,
                                   const PassContext& pass_ctx) const {
-  const auto* ctx_node = pass_ctx.operator->();
-  int opt_level = ctx_node->opt_level;
-  auto disabled = DisabledPasses(ctx_node->disabled_pass);
+  int opt_level = pass_ctx->opt_level;
+  auto disabled = DisabledPasses(pass_ctx->disabled_pass);
   Module mod = module;
   for (const Pass& pass : passes) {
     CHECK(pass.defined()) << "Found undefined pass for optimization.";
     PassInfo info = pass->Info();
-    const auto& pass_name = info.operator->()->name;
-    const auto& pass_opt_level = info.operator->()->opt_level;
+    const auto& pass_name = info->name;
+    const auto& pass_opt_level = info->opt_level;
     // Skip the pass if its optimization level is higher that the  one of in the
     // pass context or if this pass is disabled.
     if (pass_opt_level > opt_level || disabled.count(pass_name)) {
@@ -540,14 +525,7 @@ TVM_REGISTER_API("relay._transform.CreateModulePass")
 
 TVM_REGISTER_API("relay._transform.RunPass")
 .set_body([](TVMArgs args, TVMRetValue* ret) {
-  Pass pass = args[0];
-  Module mod = args[1];
-  CHECK(pass.defined())
-      << "Running an undefined pass is not allowed."
-      << "\n";
-
-  const auto* pn = pass.operator->();
-  *ret = (*pn)(mod);
+  *ret = args[0].operator Pass()(args[1]);
 });
 
 TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
@@ -602,11 +580,16 @@ TVM_REGISTER_NODE_TYPE(PassContextNode);
 
 TVM_REGISTER_API("relay._transform.PassContext")
 .set_body([](TVMArgs args, TVMRetValue* ret) {
+  auto pctx = PassContext::Create();
   int opt_level = args[0];
   int fallback_device = args[1];
   tvm::Array<tvm::Expr> required = args[2];
   tvm::Array<tvm::Expr> disabled = args[3];
-  *ret = PassContext(opt_level, fallback_device, required, disabled);
+  pctx->opt_level = opt_level;
+  pctx->fallback_device = fallback_device;
+  pctx->required_pass = std::move(required);
+  pctx->disabled_pass = std::move(disabled);
+  *ret = pctx;
 });
 
 TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)