[C++][API] Consistent RAII scoping API. (#3231)
authorTianqi Chen <tqchen@users.noreply.github.com>
Fri, 24 May 2019 16:29:14 +0000 (09:29 -0700)
committerGitHub <noreply@github.com>
Fri, 24 May 2019 16:29:14 +0000 (09:29 -0700)
22 files changed:
include/tvm/arithmetic.h
include/tvm/base.h
include/tvm/build_module.h
python/tvm/build_module.py
python/tvm/target.py
src/api/api_arith.cc
src/arithmetic/analyzer.cc
src/arithmetic/rewrite_simplify.cc
src/arithmetic/stmt_simplify.cc
src/codegen/build_module.cc
src/codegen/codegen_aocl.cc
src/codegen/codegen_vhls.cc
src/codegen/llvm/codegen_llvm.cc
src/codegen/spirv/codegen_spirv.cc
src/relay/backend/build_module.cc
src/relay/backend/compile_engine.cc
src/relay/backend/vm/compiler.cc
src/relay/pass/fold_constant.cc
src/relay/pass/partial_eval.cc
tests/cpp/build_module_test.cc
tests/cpp/relay_build_module_test.cc
topi/src/topi.cc

index 6eec767..600e3c5 100644 (file)
@@ -290,14 +290,14 @@ class CanonicalSimplifier {
 };
 
 /*!
- * \brief A RAII constraint context.
+ * \brief Constraint context.
  *
  * \code
  *
  *  Var("x");
  *  arith::Analyzer analyzer;
  *  {
- *    arith::ConstraintContext cctx(&analyzer, x % 3 == 0);
+ *    With<arith::ConstraintContext> scope(&analyzer, x % 3 == 0);
  *    CHECK_EQ(analyzer.modular_set(x)->coeff, 3);
  *  }
  *  // constraint no longer in effect.
@@ -306,19 +306,24 @@ class CanonicalSimplifier {
  * \endcode
  */
 class ConstraintContext {
- public:
+ private:
+  // declare friend to enable with.
+  friend class With<ConstraintContext>;
   /*!
    * \brief Construct a constraint context.
    * \param analyzer The analyzer.
    * \param constraint The constraint to be applied.
    */
-  ConstraintContext(Analyzer* analyzer, const Expr& constraint) DMLC_THROW_EXCEPTION;
-  /*! \brief destructor */
-  ~ConstraintContext() DMLC_THROW_EXCEPTION {
-    exit_();
-  }
-
- private:
+  ConstraintContext(Analyzer* analyzer, Expr constraint)
+      : analyzer_(analyzer), constraint_(constraint) {}
+  // enter the scope.
+  void EnterWithScope();
+  // exit the scope.
+  void ExitWithScope();
+  /*! \brief The analyzer */
+  Analyzer* analyzer_;
+  /*! \brief The constraint */
+  Expr constraint_;
   /*! \brief function to be called in recovery */
   std::function<void()> exit_;
 };
index 049a427..f358f7f 100644 (file)
@@ -102,6 +102,50 @@ using ::tvm::AttrVisitor;
   };
 
 /*!
+ * \brief RAII wrapper function to enter and exit a context object
+ *        similar to python's with syntax.
+ *
+ * \code
+ * // context class
+ * class MyContext {
+ *  private:
+ *    friend class With<MyContext>;
+      MyContext(arguments);
+ *    void EnterWithScope();
+ *    void ExitWithScope();
+ * };
+ *
+ * {
+ *   With<MyContext> scope(arguments);
+ *   // effect take place.
+ * }
+ * \endcode
+ *
+ * \tparam ContextType Type of the context object.
+ */
+template<typename ContextType>
+class With {
+ public:
+  /*!
+   * \brief constructor.
+   *  Enter the scope of the context.
+   */
+  template<typename ...Args>
+  explicit With(Args&& ...args)
+      : ctx_(std::forward<Args>(args)...) {
+    ctx_.EnterWithScope();
+  }
+  /*! \brief destructor, leaves the scope of the context. */
+  ~With() DMLC_THROW_EXCEPTION {
+    ctx_.ExitWithScope();
+  }
+
+ private:
+  /*! \brief internal context type. */
+  ContextType ctx_;
+};
+
+/*!
  * \brief save the node as well as all the node it depends on as json.
  *  This can be used to serialize any TVM object
  *
index 7fb456c..187a745 100644 (file)
@@ -37,7 +37,7 @@ namespace tvm {
 
 /*!
 * \brief Container for target device information.
-* Use target::llvm, target::cuda etc functions instead of constructing directly.
+*   Use target::llvm, target::cuda etc functions instead of constructing directly.
 */
 class TargetNode : public Node {
  public:
@@ -89,65 +89,47 @@ class TargetNode : public Node {
   mutable std::string str_repr_;
 };
 
+/*! \brief reference cpass to the target. */
 class Target : public NodeRef {
  public:
   Target() {}
   explicit Target(NodePtr<Node> n) : NodeRef(n) {}
-
   /*!
   * \brief Create a Target given a string
   * \param target_str the string to parse
   */
-  TVM_DLL static Target create(const std::string& target_str);
-
-  /*!
-  * \brief Push a new target context onto the thread local stack. The Target on top of
-  * the stack is used to determine which specialization to use when invoking a GenericFunc.
-  * \param target The target to set as the current context.
-  */
-  TVM_DLL static void EnterTargetScope(const tvm::Target& target);
-
-  /*!
-  * \brief Pop a target off the thread local context stack, restoring the previous target
-  * as the current context.
-  */
-  TVM_DLL static void ExitTargetScope();
-
+  TVM_DLL static Target Create(const std::string& target_str);
   /*!
-  * \brief Get the current target context from thread local storage.
-  * \param allow_not_defined If the context stack is empty and this is set to true, an
-  * undefined Target will be returned. Otherwise, an empty context stack will cause a
-  * runtime error.
-  * \return The target that is the current context. The target may not be defined if
-  * allow_not_defined is true.
-  */
-  TVM_DLL static tvm::Target current_target(bool allow_not_defined = true);
+   * \brief Get the current target context from thread local storage.
+   * \param allow_not_defined If the context stack is empty and this is set to true, an
+   *   undefined Target will be returned. Otherwise, an empty context stack will cause a
+   *   runtime error.
+   * \return The target that is the current context. The target may not be defined if
+   * allow_not_defined is true.
+   */
+  TVM_DLL static tvm::Target Current(bool allow_not_defined = true);
 
-  inline const TargetNode* operator->() const {
+  const TargetNode* operator->() const {
       return static_cast<const TargetNode*>(node_.get());
   }
 
   using ContainerType = TargetNode;
-};
-
-/*!
- * \brief RAII container to provide a scoped target context. Pushes a target onto the
- * context stack when constructed, and pops it when destructed.
- */
-struct TargetContext {
+  class Internal;
+ private:
+  // enable with syntax.
+  friend class Internal;
+  friend class With<Target>;
   /*!
-   * \brief Enter a new target context. The given target becomes the new current context.
-   * When the TargetContext is destructed, the previous context is restored.
-   * \param target The target to set as the new current context.
+   * \brief Push a new target context onto the thread local stack.
+   *  The Target on top of the stack is used to determine which
+   *  specialization to use when invoking a GenericFunc.
    */
-  explicit TargetContext(const tvm::Target& target) {
-    Target::EnterTargetScope(target);
-  }
-
-  /*! \brief Destructor. Pops the context off the thread local stack. */
-  ~TargetContext() {
-    Target::ExitTargetScope();
-  }
+  TVM_DLL void EnterWithScope();
+  /*!
+   * \brief Pop a target off the thread local context stack,
+   *  restoring the previous target as the current context.
+   */
+  TVM_DLL void ExitWithScope();
 };
 
 /*! \brief This namespace provides functions to construct Target instances */
@@ -190,11 +172,9 @@ TVM_DLL Target stackvm(const std::vector<std::string>& options =
 
 }  // namespace target
 
-class BuildConfig;
-
 /*!
-* \brief Container for build configuration options
-*/
+ * \brief Container for build configuration options
+ */
 class BuildConfigNode : public Node {
  public:
   /*!
@@ -271,70 +251,49 @@ class BuildConfigNode : public Node {
 };
 
 /*!
-* \brief Container for build configuration options
-*/
+ * \brief Build configuration for compilations.
+ */
 class BuildConfig : public ::tvm::NodeRef {
  public:
   BuildConfig() {}
   explicit BuildConfig(NodePtr<::tvm::Node> n) : NodeRef(n) {}
-
   const BuildConfigNode* operator->() const {
     return static_cast<const BuildConfigNode*>(node_.get());
   }
-
   BuildConfigNode* operator->() {
     return static_cast<BuildConfigNode*>(node_.get());
   }
-
   /*!
-   * \brief Push a new BuildConfig context onto the thread local stack.
-   * \param build_config The configuration to set as the current context.
+   * \brief Construct a BuildConfig containing a empty build config node.
+   * \return The new BuildConfig
    */
-  TVM_DLL static void EnterBuildConfigScope(const tvm::BuildConfig& build_config);
-
-  /*!
-   * \brief Pop a build config off the thread local context stack, restoring the previous
-   * configuration as the current context.
-   */
-  TVM_DLL static void ExitBuildConfigScope();
-
+  TVM_DLL static BuildConfig Create();
   /*!
    * \brief Get the current BuildConfig context from thread local storage, or a default
    * configuration if a BuildConfig scope has not been entered.
    * \return The configuration that is the current context.
    */
-  TVM_DLL static tvm::BuildConfig Current();
+  TVM_DLL static BuildConfig Current();
 
   using ContainerType = BuildConfigNode;
-};
+  class Internal;
 
-/*!
- * \brief RAII container to provide a scoped BuildConfig context. Pushes a configuration onto the
- * context stack when constructed, and pops it when destructed.
- */
-struct BuildConfigContext {
+ private:
+  // Enable with syntax.
+  friend class With<BuildConfig>;
   /*!
-   * \brief Enter a new BuildConfig context. The given BuildConfig becomes the new current
-   * context. When the BuildConfigContext is destructed, the previous context is restored.
-   * \param build_config The BuildConfig to set as the new current context.
+   * \brief Push a new BuildConfig context onto the thread local stack.
    */
-  explicit BuildConfigContext(const tvm::BuildConfig& build_config) {
-    BuildConfig::EnterBuildConfigScope(build_config);
-  }
+  TVM_DLL void EnterWithScope();
 
-  /*! \brief Destructor. Pops the context off the thread local stack. */
-  ~BuildConfigContext() {
-    BuildConfig::ExitBuildConfigScope();
-  }
+  /*!
+   * \brief Pop a build config off the thread local context stack,
+   * restoring the previous configuration as the current context.
+   */
+  TVM_DLL void ExitWithScope();
 };
 
 /*!
-* \brief Construct a BuildConfig containing a new BuildConfigNode
-* \return The new BuildConfig
-*/
-TVM_DLL BuildConfig build_config();
-
-/*!
 * \brief Build a LoweredFunc given a schedule, args and binds
 * \param sch The schedule to lower.
 * \param args The arguments to the function.
index a28ab98..76170a8 100644 (file)
@@ -187,7 +187,7 @@ class BuildConfig(NodeBase):
     def __exit__(self, ptype, value, trace):
         if self.dump_pass_ir:
             BuildConfig._dump_ir.exit()
-        _api_internal._ExitBuildConfigScope()
+        _api_internal._ExitBuildConfigScope(self)
 
     def __setattr__(self, name, value):
         if name in BuildConfig._node_defaults:
index eff0088..828fff8 100644 (file)
@@ -133,7 +133,7 @@ class Target(NodeBase):
         return self
 
     def __exit__(self, ptype, value, trace):
-        _api_internal._ExitTargetScope()
+        _api_internal._ExitTargetScope(self)
 
 
 @register_node
index 55a7064..4d5d8bd 100644 (file)
@@ -123,8 +123,8 @@ TVM_REGISTER_API("arith._CreateAnalyzer")
         return PackedFunc([self](TVMArgs args, TVMRetValue *ret) {
             // can't use make_shared due to noexcept(false) decl in destructor,
             // see https://stackoverflow.com/a/43907314
-            auto ctx =
-                std::shared_ptr<ConstraintContext>(new ConstraintContext(self.get(), args[0]));
+            auto ctx = std::shared_ptr<With<ConstraintContext> >(
+                new With<ConstraintContext>(self.get(), args[0]));
             auto fexit = [ctx](TVMArgs, TVMRetValue*) mutable {
               ctx.reset();
             };
index 420d6f9..bd8c700 100644 (file)
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -54,10 +54,12 @@ void Analyzer::Bind(const VarExpr& v, const Range& range) {
   // skip rewrite simplify
 }
 
-ConstraintContext::ConstraintContext(Analyzer* analyzer, const Expr& constraint) {
+
+void ConstraintContext::EnterWithScope() {
+  CHECK(exit_ == nullptr);
   // entering the scope.
-  auto f0 = analyzer->const_int_bound.EnterConstraint(constraint);
-  auto f1 = analyzer->modular_set.EnterConstraint(constraint);
+  auto f0 = analyzer_->const_int_bound.EnterConstraint(constraint_);
+  auto f1 = analyzer_->modular_set.EnterConstraint(constraint_);
   // recovery function.
   exit_ = [f0, f1]() {
     if (f1 != nullptr) f1();
@@ -65,6 +67,11 @@ ConstraintContext::ConstraintContext(Analyzer* analyzer, const Expr& constraint)
   };
 }
 
+void ConstraintContext::ExitWithScope() {
+  CHECK(exit_ != nullptr);
+  exit_();
+}
+
 bool Analyzer::CanProveGreaterEqual(const Expr& expr, int64_t lower_bound) {
   if (const auto* ptr = expr.as<ir::IntImm>()) {
     return ptr->value > lower_bound;
index 58d2b83..0de2a25 100644 (file)
@@ -1200,11 +1200,11 @@ Mutate_(const Select* op, const Expr& self) {
   Expr cond = Mutate(op->condition);
   Expr true_value, false_value;
   {
-    ConstraintContext constraint(parent_, cond);
+    With<ConstraintContext> constraint(parent_, cond);
     true_value = Mutate(op->true_value);
   }
   {
-    ConstraintContext constraint(parent_, Mutate(Not::make(cond)));
+    With<ConstraintContext> constraint(parent_, Mutate(Not::make(cond)));
     false_value = Mutate(op->false_value);
   }
   if (is_zero(cond)) {
@@ -1237,11 +1237,11 @@ Mutate_(const Call* op, const Expr& self) {
     Expr cond = Mutate(op->args[0]);
     Expr true_value, false_value;
     {
-      ConstraintContext constraint(parent_, cond);
+      With<ConstraintContext> constraint(parent_, cond);
       true_value = Mutate(op->args[1]);
     }
     {
-      ConstraintContext constraint(parent_, Mutate(Not::make(cond)));
+      With<ConstraintContext> constraint(parent_, Mutate(Not::make(cond)));
       false_value = Mutate(op->args[2]);
     }
     if (is_zero(cond)) {
index c793214..403187e 100644 (file)
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -48,11 +48,11 @@ class StmtSimplifier : public IRMutator {
     Expr condition = this->Mutate(op->condition);
     Stmt then_case, else_case;
     {
-      ConstraintContext ctx(&analyzer_, condition);
+      With<ConstraintContext> ctx(&analyzer_, condition);
       then_case = this->Mutate(op->then_case);
     }
     if (op->else_case.defined()) {
-      ConstraintContext ctx(&analyzer_, Mutate(Not::make(condition)));
+      With<ConstraintContext> ctx(&analyzer_, Mutate(Not::make(condition)));
       else_case = this->Mutate(op->else_case);
     }
     if (is_one(condition)) return then_case;
@@ -94,7 +94,7 @@ class StmtSimplifier : public IRMutator {
   Stmt Mutate_(const AssertStmt* op, const Stmt& s) final {
     Expr condition = this->Mutate(op->condition);
     Expr message = this->Mutate(op->message);
-    ConstraintContext ctx(&analyzer_, condition);
+    With<ConstraintContext> ctx(&analyzer_, condition);
     Stmt body = this->Mutate(op->body);
 
     if (condition.same_as(op->condition) &&
index ac6b797..834b4ee 100644 (file)
@@ -18,7 +18,6 @@
  */
 
 /*!
- *  Copyright (c) 2017 by Contributors
  *  Compile executable modules.
  * \file build_module.cc
  */
@@ -148,8 +147,7 @@ TVM_REGISTER_API("_TargetCreate")
 TVM_REGISTER_API("_TargetFromString")
 .set_body([](TVMArgs args, TVMRetValue* ret) {
   std::string target_str = args[0];
-
-  *ret = Target::create(target_str);
+  *ret = Target::Create(target_str);
   });
 
 std::vector<std::string> TargetNode::keys() const {
@@ -207,7 +205,7 @@ std::string GetDeviceName(const std::string& target_str) {
   return "";
 }
 
-Target Target::create(const std::string& target_str) {
+Target Target::Create(const std::string& target_str) {
   if (target_str.length() == 0) {
     LOG(ERROR) << "target_str must not be empty";
   }
@@ -231,25 +229,24 @@ Target Target::create(const std::string& target_str) {
 struct TVMTargetThreadLocalEntry {
   /*! \brief The current target context */
   std::stack<tvm::Target> context_stack;
-
-  TVMTargetThreadLocalEntry() {
-  }
 };
 
 /*! \brief Thread local store to hold the Target context stack. */
 typedef dmlc::ThreadLocalStore<TVMTargetThreadLocalEntry> TVMTargetThreadLocalStore;
 
-void Target::EnterTargetScope(const tvm::Target& target) {
+void Target::EnterWithScope() {
   TVMTargetThreadLocalEntry *entry = TVMTargetThreadLocalStore::Get();
-  entry->context_stack.push(target);
+  entry->context_stack.push(*this);
 }
 
-void Target::ExitTargetScope() {
+void Target::ExitWithScope() {
   TVMTargetThreadLocalEntry *entry = TVMTargetThreadLocalStore::Get();
+  CHECK(!entry->context_stack.empty());
+  CHECK(entry->context_stack.top().same_as(*this));
   entry->context_stack.pop();
 }
 
-tvm::Target Target::current_target(bool allow_not_defined) {
+tvm::Target Target::Current(bool allow_not_defined) {
   TVMTargetThreadLocalEntry *entry = TVMTargetThreadLocalStore::Get();
   if (entry->context_stack.size() > 0) {
     return entry->context_stack.top();
@@ -574,7 +571,7 @@ runtime::Module build(const Map<std::string, Array<LoweredFunc>>& inputs,
                       const BuildConfig& config) {
   Map<Target, Array<LoweredFunc>> updated_input;
   for (const auto& it : inputs) {
-    auto target = Target::create(it.first);
+    auto target = Target::Create(it.first);
     updated_input.Set(target, it.second);
   }
   return build(updated_input, target_host, config);
@@ -589,33 +586,35 @@ runtime::Module build(const Array<LoweredFunc>& funcs,
   return build(inputs, target_host, config);
 }
 
-BuildConfig build_config() {
+BuildConfig BuildConfig::Create() {
   return BuildConfig(make_node<BuildConfigNode>());
 }
 
 /*! \brief Entry to hold the BuildConfig context stack. */
 struct TVMBuildConfigThreadLocalEntry {
   /*! \brief The default build config if the stack is empty */
-  tvm::BuildConfig default_config;
+  BuildConfig default_config;
 
   /*! \brief The current build config context */
-  std::stack<tvm::BuildConfig> context_stack;
+  std::stack<BuildConfig> context_stack;
 
   TVMBuildConfigThreadLocalEntry() :
-    default_config(build_config()) {
+      default_config(BuildConfig::Create()) {
   }
 };
 
 /*! \brief Thread local store to hold the BuildConfig context stack. */
 typedef dmlc::ThreadLocalStore<TVMBuildConfigThreadLocalEntry> TVMBuildConfigThreadLocalStore;
 
-void BuildConfig::EnterBuildConfigScope(const tvm::BuildConfig& build_config) {
+void BuildConfig::EnterWithScope() {
   TVMBuildConfigThreadLocalEntry *entry = TVMBuildConfigThreadLocalStore::Get();
-  entry->context_stack.push(build_config);
+  entry->context_stack.push(*this);
 }
 
-void BuildConfig::ExitBuildConfigScope() {
+void BuildConfig::ExitWithScope() {
   TVMBuildConfigThreadLocalEntry *entry = TVMBuildConfigThreadLocalStore::Get();
+  CHECK(!entry->context_stack.empty());
+  CHECK(entry->context_stack.top().same_as(*this));
   entry->context_stack.pop();
 }
 
@@ -714,7 +713,7 @@ GenericFunc& GenericFunc::register_func(const std::vector<std::string>& tags,
 
 void GenericFunc::CallPacked(TVMArgs args, TVMRetValue* ret) const {
   auto node = static_cast<GenericFuncNode*>(node_.get());
-  auto target = Target::current_target(true);
+  auto target = Target::Current(true);
   PackedFunc func;
 
   if (target.defined()) {
@@ -740,16 +739,21 @@ TVM_REGISTER_API("_GetCurrentBuildConfig")
   *ret = BuildConfig::Current();
   });
 
+class BuildConfig::Internal {
+ public:
+  static void EnterScope(BuildConfig target) {
+    target.EnterWithScope();
+  }
+  static void ExitScope(BuildConfig target) {
+    target.ExitWithScope();
+  }
+};
+
 TVM_REGISTER_API("_EnterBuildConfigScope")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
-  BuildConfig target = args[0];
-  BuildConfig::EnterBuildConfigScope(target);
-  });
+.set_body_typed(BuildConfig::Internal::EnterScope);
 
 TVM_REGISTER_API("_ExitBuildConfigScope")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
-  BuildConfig::ExitBuildConfigScope();
-  });
+.set_body_typed(BuildConfig::Internal::ExitScope);
 
 TVM_REGISTER_API("_BuildConfigSetAddLowerPass")
 .set_body([](TVMArgs args, TVMRetValue* ret) {
@@ -836,18 +840,23 @@ TVM_REGISTER_API("_GenericFuncCallFunc")
 TVM_REGISTER_API("_GetCurrentTarget")
 .set_body([](TVMArgs args, TVMRetValue* ret) {
   bool allow_not_defined = args[0];
-  *ret = Target::current_target(allow_not_defined);
+  *ret = Target::Current(allow_not_defined);
   });
 
+class Target::Internal {
+ public:
+  static void EnterScope(Target target) {
+    target.EnterWithScope();
+  }
+  static void ExitScope(Target target) {
+    target.ExitWithScope();
+  }
+};
+
 TVM_REGISTER_API("_EnterTargetScope")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
-  Target target = args[0];
-  Target::EnterTargetScope(target);
-  });
+.set_body_typed(Target::Internal::EnterScope);
 
 TVM_REGISTER_API("_ExitTargetScope")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
-  Target::ExitTargetScope();
-  });
+.set_body_typed(Target::Internal::ExitScope);
 
 }  // namespace tvm
index 6f899cb..03b9b68 100644 (file)
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -54,7 +54,7 @@ runtime::Module BuildAOCL(Array<LoweredFunc> funcs, std::string target_str,
   std::string cmd = "aoc aocl.cl";
   // AOCL supports fp64.
   cmd += " -Dcl_khr_fp64";
-  Target target = Target::create(target_str);
+  Target target = Target::Create(target_str);
   if (target->device_name != "") {
     cmd += " -board=" + target->device_name;
   }
index a18312f..4d86cc5 100644 (file)
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -155,7 +155,7 @@ runtime::Module BuildSDAccel(Array<LoweredFunc> funcs, std::string target_str) {
 
   std::string xclbin;
   if (const auto* f = Registry::Get("tvm_callback_sdaccel_compile")) {
-    Target target = Target::create(target_str);
+    Target target = Target::Create(target_str);
     xclbin = (*f)(kernel_info, target->device_name).operator std::string();
   } else {
     LOG(FATAL) << "Cannot compile Vivado HLS code.";
index bedcdc7..1e56583 100644 (file)
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -1142,7 +1142,7 @@ void CodeGenLLVM::VisitStmt_(const AttrStmt* op) {
 }
 
 void CodeGenLLVM::VisitStmt_(const AssertStmt* op) {
-  arith::ConstraintContext cctx(analyzer_.get(), op->condition);
+  With<arith::ConstraintContext> cctx(analyzer_.get(), op->condition);
   this->VisitStmt(op->body);
 }
 
index e6fc008..fd113ca 100644 (file)
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -626,7 +626,7 @@ void CodeGenSPIRV::VisitStmt_(const AttrStmt* op) {
 }
 
 void CodeGenSPIRV::VisitStmt_(const AssertStmt* op) {
-  arith::ConstraintContext cctx(analyzer_.get(), op->condition);
+  With<arith::ConstraintContext> cctx(analyzer_.get(), op->condition);
   this->VisitStmt(op->body);
 }
 
index 8a0c32f..3b14910 100644 (file)
@@ -445,7 +445,7 @@ class RelayBuildModule : public runtime::ModuleNode {
       if (targets.size() == 1) {
         func = CallPackedFunc("relay._ir_pass.infer_type", func, nullptr);
         for (const auto& kv : targets) {
-          TargetContext tctx(kv.second);
+          With<Target> tctx(kv.second);
           func = CallPackedFunc("relay._ir_pass.AlterOpLayout", func);
         }
       } else {
@@ -466,9 +466,9 @@ class RelayBuildModule : public runtime::ModuleNode {
    */
   Target CreateDefaultTarget(int device_type) {
     std::string name = runtime::DeviceName(device_type);
-    if (name == "cpu") return Target::create("llvm");
-    if (name == "gpu") return Target::create("cuda");
-    return Target::create(name);
+    if (name == "cpu") return Target::Create("llvm");
+    if (name == "gpu") return Target::Create("cuda");
+    return Target::Create(name);
   }
   /*!
    * \brief Update the target and fallback device required for heterogeneous
@@ -548,7 +548,7 @@ class RelayBuildModule : public runtime::ModuleNode {
                   const RelayBuildConfig& cfg,
                   const std::unordered_map<std::string, tvm::runtime::NDArray> &params) {
     // convert
-    tvm_cfg_ = build_config();
+    tvm_cfg_ = BuildConfig::Create();
     TargetsMap device_target;
     if (targets_.size() > 1) {
       device_target = UpdateHeterogeneousInputs(targets_, cfg);
index a824c45..f11dd28 100644 (file)
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -344,7 +344,7 @@ class CompileEngineImpl : public CompileEngineNode {
       cache_[key] = value;
     }
     // Enforce use the target.
-    TargetContext target_ctx(key->target);
+    With<Target> target_scope(key->target);
 
     CHECK(!value->cached_func.defined());
     auto spair = CreateSchedule(key->source_func, key->target);
@@ -371,7 +371,7 @@ class CompileEngineImpl : public CompileEngineNode {
       cache_node->funcs = (*f)(
           spair.first, all_args, cache_node->func_name, key->source_func);
     } else {
-      tvm::BuildConfig bcfg = tvm::build_config();
+      tvm::BuildConfig bcfg = BuildConfig::Create();
       std::unordered_map<Tensor, Buffer> binds;
       cache_node->funcs = tvm::lower(spair.first, all_args, cache_node->func_name, binds, bcfg);
     }
index 97f03c6..602e927 100644 (file)
@@ -364,7 +364,7 @@ struct VMCompiler : ExprFunctor<void(const Expr& expr)> {
 
     // Next generate the invoke instruction.
     CHECK(func->IsPrimitive());
-    auto target = Target::create("llvm");
+    auto target = Target::Create("llvm");
     auto key = CCacheKeyNode::make(func, target);
     auto cfunc = engine->Lower(key);
     // TODO(jroesch): support lowered funcs for multiple targets
@@ -502,7 +502,7 @@ void PopulatePackedFuncMap(const std::vector<LoweredFunc>& lowered_funcs,
   runtime::Module mod;
   if (lowered_funcs.size() > 0) {
     // TODO(@jroesch): we need to read target from build config
-    Target target = Target::create("llvm");
+    Target target = Target::Create("llvm");
     if (const auto* f = runtime::Registry::Get("relay.backend.build")) {
       mod = (*f)(tvm::Array<LoweredFunc>(lowered_funcs.begin(), lowered_funcs.end()), target);
     } else {
index 45aa449..c085d80 100644 (file)
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -203,10 +203,10 @@ Expr FoldConstant(const Expr& expr) {
   DLContext ctx;
   ctx.device_type = kDLCPU;
   ctx.device_id = 0;
-  Target target = Target::create("llvm");
+  Target target = Target::Create("llvm");
   // use a fresh build context
   // in case we are already in a build context.
-  BuildConfigContext fresh_build_ctx(build_config());
+  With<BuildConfig> fresh_build_ctx(BuildConfig::Create());
 
   return ConstantFolder(CreateInterpreter(
       Module(nullptr), ctx, target)).Mutate(expr);
index 5349532..ad86174 100644 (file)
@@ -375,10 +375,10 @@ DLContext CPUContext() {
 }
 
 FInterpreter CPUInterpreter() {
-  Target target = Target::create("llvm");
+  Target target = Target::Create("llvm");
   // use a fresh build context
   // in case we are already in a build context.
-  BuildConfigContext fresh_build_ctx(build_config());
+  With<BuildConfig> fresh_build_ctx(BuildConfig::Create());
 
   return CreateInterpreter(Module(nullptr), CPUContext(), target);
 }
index 393714d..6dbd78e 100644 (file)
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -50,14 +50,14 @@ TEST(BuildModule, Basic) {
   auto args = Array<Tensor>({ A, B, C });
   std::unordered_map<Tensor, Buffer> binds;
 
-  auto config = build_config();
+  auto config = BuildConfig::Create();
   auto target = target::llvm();
 
   auto lowered = lower(s, args, "func", binds, config);
   auto module = build(lowered, target, Target(), config);
 
-  auto mali_target = Target::create("opencl -model=Mali-T860MP4@800Mhz -device=mali");
-  CHECK_EQ(mali_target->str(), "opencl -model=Mali-T860MP4@800Mhz -device=mali"); 
+  auto mali_target = Target::Create("opencl -model=Mali-T860MP4@800Mhz -device=mali");
+  CHECK_EQ(mali_target->str(), "opencl -model=Mali-T860MP4@800Mhz -device=mali");
 }
 
 TEST(BuildModule, Heterogeneous) {
@@ -105,7 +105,7 @@ TEST(BuildModule, Heterogeneous) {
   auto s1 = topi::cuda::schedule_injective(target_cuda, {elemwise_add});
   auto s2 = create_schedule({elemwise_sub->op});
 
-  auto config = build_config();
+  auto config = BuildConfig::Create();
   auto args1 = Array<Tensor>({A, B, elemwise_add});
   auto args2 = Array<Tensor>({copy, C, elemwise_sub});
 
index a1ab299..3f46eed 100644 (file)
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -75,7 +75,7 @@ TEST(Relay, BuildModule) {
   auto json_f = build_mod.GetFunction("get_graph_json", false);
   auto mod_f = build_mod.GetFunction("get_module", false);
   Map<tvm::Integer, tvm::Target> targets;
-  Target llvm_tgt = Target::create("llvm");
+  Target llvm_tgt = Target::Create("llvm");
   targets.Set(0, llvm_tgt);
   build_f(func, targets, llvm_tgt);
   std::string json = json_f();
index d3e0bc9..57a2743 100644 (file)
@@ -94,7 +94,7 @@ inline bool IsTensorType(TVMArgValue arg) {
 
 TVM_REGISTER_GLOBAL("topi.TEST_create_target")
 .set_body([](TVMArgs args, TVMRetValue *rv) {
-  *rv = tvm::Target::create(args[0]);
+  *rv = tvm::Target::Create(args[0]);
   });
 
 /* Ops from broadcast.h */
@@ -640,7 +640,7 @@ using FTVMScheduleBuilder = std::function<
  */
 inline PackedFunc WrapSchedule(FTVMScheduleBuilder builder) {
   return PackedFunc([builder](TVMArgs args, TVMRetValue* ret) {
-    auto target = Target::current_target(false);
+    auto target = Target::Current(false);
     Array<Tensor> outs;
     NodeRef argNodeRef = args[0];
     if (argNodeRef->type_index() == outs->type_index()) {
@@ -712,7 +712,7 @@ using FTVMDenseOpBuilder = std::function<tvm::Tensor(const Target& target,
 */
 inline PackedFunc WrapDenseOp(FTVMDenseOpBuilder builder) {
   return PackedFunc([builder](TVMArgs args, TVMRetValue* ret) {
-    auto target = Target::current_target(false);
+    auto target = Target::Current(false);
     Tensor data = args[0];
     Tensor weight = args[1];
     Tensor bias = args[2];