[REFACTOR][IR] Introduce include/tvm/target (#4721)
authorTianqi Chen <tqchen@users.noreply.github.com>
Thu, 16 Jan 2020 04:23:15 +0000 (20:23 -0800)
committerGitHub <noreply@github.com>
Thu, 16 Jan 2020 04:23:15 +0000 (20:23 -0800)
As part of Unified IR infra.
Introduce target folder to store all the compilation target related information.

12 files changed:
.gitignore
CMakeLists.txt
include/tvm/build_module.h
include/tvm/target/target.h [new file with mode: 0644]
include/tvm/target/target_info.h [moved from include/tvm/target_info.h with 90% similarity]
src/codegen/build_module.cc
src/pass/storage_access.cc
src/pass/storage_flatten.cc
src/pass/storage_rewrite.cc
src/pass/tensor_core.cc
src/target/target.cc [new file with mode: 0644]
src/target/target_info.cc [moved from src/lang/target_info.cc with 75% similarity]

index 2f124d9..068cb87 100644 (file)
@@ -65,7 +65,7 @@ docs/_build/
 docs/gen_modules
 
 # PyBuilder
-target/
+/target/
 
 # IPython Notebook
 .ipynb_checkpoints
index b823528..825da5a 100644 (file)
@@ -127,6 +127,7 @@ assign_source_group("Include" ${GROUP_INCLUDE})
 file(GLOB COMPILER_SRCS
     src/node/*.cc
     src/ir/*.cc
+    src/target/*.cc
     src/api/*.cc
     src/arithmetic/*.cc
     src/autotvm/*.cc
index 8b49fb7..8919188 100644 (file)
 #ifndef TVM_BUILD_MODULE_H_
 #define TVM_BUILD_MODULE_H_
 
+#include <tvm/target/target.h>
+
 #include <string>
 #include <vector>
 #include <utility>
 #include <unordered_map>
 #include <unordered_set>
+
 #include "runtime/packed_func.h"
 #include "schedule_pass.h"
 #include "lowered_func.h"
 namespace tvm {
 
 /*!
-* \brief Container for target device information.
-*   Use target::llvm, target::cuda etc functions instead of constructing directly.
-*/
-class TargetNode : public Object {
- public:
-  /*! \brief The name of the target device */
-  std::string target_name;
-  /*! \brief The name of the target device */
-  std::string device_name;
-  /*! \brief The type of the target device */
-  int device_type;
-  /*! \brief The maximum threads that a schedule should use for this device */
-  int max_num_threads = 1;
-  /*! \brief The warp size that should be used by the LowerThreadAllreduce pass */
-  int thread_warp_size = 1;
-  /*! \brief Keys for this target */
-  Array<PrimExpr> keys_array;
-  /*! \brief Options for this target */
-  Array<PrimExpr> options_array;
-  /*! \brief Collection of imported libs */
-  Array<PrimExpr> libs_array;
-
-  /*! \return the full device string to pass to codegen::Build */
-  TVM_DLL const std::string& str() const;
-
-  void VisitAttrs(AttrVisitor* v) {
-    v->Visit("target_name", &target_name);
-    v->Visit("device_name", &device_name);
-    v->Visit("device_type", &device_type);
-    v->Visit("max_num_threads", &max_num_threads);
-    v->Visit("thread_warp_size", &thread_warp_size);
-    v->Visit("keys_array", &keys_array);
-    v->Visit("options_array", &options_array);
-    v->Visit("libs_array", &libs_array);
-  }
-
-  /*! \brief Get the keys for this target as a vector of string */
-  TVM_DLL std::vector<std::string> keys() const;
-
-  /*! \brief Get the options for this target as a vector of string */
-  TVM_DLL std::vector<std::string> options() const;
-
-  /*! \brief Get the keys for this target as an unordered_set of string */
-  TVM_DLL std::unordered_set<std::string> libs() const;
-
-  static constexpr const char* _type_key = "Target";
-  TVM_DECLARE_FINAL_OBJECT_INFO(TargetNode, Object);
-
- private:
-  /*! \brief Internal string repr. */
-  mutable std::string str_repr_;
-};
-
-/*! \brief reference cpass to the target. */
-class Target : public ObjectRef {
- public:
-  Target() {}
-  explicit Target(ObjectPtr<Object> n) : ObjectRef(n) {}
-  /*!
-  * \brief Create a Target given a string
-  * \param target_str the string to parse
-  */
-  TVM_DLL static Target Create(const std::string& target_str);
-  /*!
-   * \brief Get the current target context from thread local storage.
-   * \param allow_not_defined If the context stack is empty and this is set to true, an
-   *   undefined Target will be returned. Otherwise, an empty context stack will cause a
-   *   runtime error.
-   * \return The target that is the current context. The target may not be defined if
-   * allow_not_defined is true.
-   */
-  TVM_DLL static tvm::Target Current(bool allow_not_defined = true);
-
-  const TargetNode* operator->() const {
-      return static_cast<const TargetNode*>(get());
-  }
-
-  using ContainerType = TargetNode;
-  class Internal;
- private:
-  // enable with syntax.
-  friend class Internal;
-  friend class With<Target>;
-  /*!
-   * \brief Push a new target context onto the thread local stack.
-   *  The Target on top of the stack is used to determine which
-   *  specialization to use when invoking a GenericFunc.
-   */
-  TVM_DLL void EnterWithScope();
-  /*!
-   * \brief Pop a target off the thread local context stack,
-   *  restoring the previous target as the current context.
-   */
-  TVM_DLL void ExitWithScope();
-};
-
-/*! \brief This namespace provides functions to construct Target instances */
-namespace target {
-/*! \return A target for LLVM */
-TVM_DLL Target llvm(const std::vector<std::string>& options =
-                   std::vector<std::string>());
-
-/*! \return A target for CUDA */
-TVM_DLL Target cuda(const std::vector<std::string>& options =
-                   std::vector<std::string>());
-
-/*! \return A target for ROCm */
-TVM_DLL Target rocm(const std::vector<std::string>& options =
-                   std::vector<std::string>());
-
-/*! \return A target for OpenCL */
-TVM_DLL Target opencl(const std::vector<std::string>& options =
-                     std::vector<std::string>());
-
-/*! \return A target for Metal */
-TVM_DLL Target metal(const std::vector<std::string>& options =
-                    std::vector<std::string>());
-
-/*! \return A target for rasp */
-TVM_DLL Target rasp(const std::vector<std::string>& options =
-                   std::vector<std::string>());
-
-/*! \return A target for Mali */
-TVM_DLL Target mali(const std::vector<std::string>& options =
-                   std::vector<std::string>());
-
-/*! \return A target for Intel Graphics */
-TVM_DLL Target intel_graphics(const std::vector<std::string>& options =
-                             std::vector<std::string>());
-
-/*! \return A target for stackvm */
-TVM_DLL Target stackvm(const std::vector<std::string>& options =
-                      std::vector<std::string>());
-
-/*! \return A target for external device */
-TVM_DLL Target ext_dev(const std::vector<std::string>& options =
-                   std::vector<std::string>());
-}  // namespace target
-
-/*!
  * \brief Container for build configuration options
  */
 class BuildConfigNode : public Object {
diff --git a/include/tvm/target/target.h b/include/tvm/target/target.h
new file mode 100644 (file)
index 0000000..fd8ab68
--- /dev/null
@@ -0,0 +1,181 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/target/target.h
+ * \brief Compilation target object.
+ */
+#ifndef TVM_TARGET_TARGET_H_
+#define TVM_TARGET_TARGET_H_
+
+#include <tvm/support/with.h>
+#include <tvm/node/container.h>
+#include <tvm/ir/expr.h>
+
+#include <string>
+#include <vector>
+#include <unordered_set>
+
+namespace tvm {
+/*!
+ * \brief Compilation target.
+ * \note Use target::llvm, target::cuda etc functions.
+ * \sa Target
+ */
+class TargetNode : public Object {
+ public:
+  /*! \brief The name of the target device */
+  std::string target_name;
+  /*! \brief The name of the target device */
+  std::string device_name;
+  /*! \brief The type of the target device */
+  int device_type;
+  /*! \brief The maximum threads that a schedule should use for this device */
+  int max_num_threads = 1;
+  /*! \brief The warp size that should be used by the LowerThreadAllreduce pass */
+  int thread_warp_size = 1;
+  /*! \brief Keys for this target */
+  Array<PrimExpr> keys_array;
+  /*! \brief Options for this target */
+  Array<PrimExpr> options_array;
+  /*! \brief Collection of imported libs */
+  Array<PrimExpr> libs_array;
+
+  /*! \return the full device string to pass to codegen::Build */
+  TVM_DLL const std::string& str() const;
+
+  void VisitAttrs(AttrVisitor* v) {
+    v->Visit("target_name", &target_name);
+    v->Visit("device_name", &device_name);
+    v->Visit("device_type", &device_type);
+    v->Visit("max_num_threads", &max_num_threads);
+    v->Visit("thread_warp_size", &thread_warp_size);
+    v->Visit("keys_array", &keys_array);
+    v->Visit("options_array", &options_array);
+    v->Visit("libs_array", &libs_array);
+  }
+
+  /*! \brief Get the keys for this target as a vector of string */
+  TVM_DLL std::vector<std::string> keys() const;
+
+  /*! \brief Get the options for this target as a vector of string */
+  TVM_DLL std::vector<std::string> options() const;
+
+  /*! \brief Get the keys for this target as an unordered_set of string */
+  TVM_DLL std::unordered_set<std::string> libs() const;
+
+  static constexpr const char* _type_key = "Target";
+  TVM_DECLARE_FINAL_OBJECT_INFO(TargetNode, Object);
+
+ private:
+  /*! \brief Internal string repr. */
+  mutable std::string str_repr_;
+};
+
+/*!
+ * \brief Managed reference class to TargetNode.
+ * \sa TargetNode
+ */
+class Target : public ObjectRef {
+ public:
+  Target() {}
+  explicit Target(ObjectPtr<Object> n) : ObjectRef(n) {}
+  /*!
+  * \brief Create a Target given a string
+  * \param target_str the string to parse
+  */
+  TVM_DLL static Target Create(const std::string& target_str);
+  /*!
+   * \brief Get the current target context from thread local storage.
+   * \param allow_not_defined If the context stack is empty and this is set to true, an
+   *   undefined Target will be returned. Otherwise, an empty context stack will cause a
+   *   runtime error.
+   * \return The target that is the current context. The target may not be defined if
+   * allow_not_defined is true.
+   */
+  TVM_DLL static tvm::Target Current(bool allow_not_defined = true);
+
+  const TargetNode* operator->() const {
+      return static_cast<const TargetNode*>(get());
+  }
+
+  using ContainerType = TargetNode;
+  class Internal;
+ private:
+  // enable with syntax.
+  friend class Internal;
+  friend class With<Target>;
+  /*!
+   * \brief Push a new target context onto the thread local stack.
+   *  The Target on top of the stack is used to determine which
+   *  specialization to use when invoking a GenericFunc.
+   */
+  TVM_DLL void EnterWithScope();
+  /*!
+   * \brief Pop a target off the thread local context stack,
+   *  restoring the previous target as the current context.
+   */
+  TVM_DLL void ExitWithScope();
+};
+
+/*! \brief This namespace provides functions to construct Target instances */
+namespace target {
+
+/*! \return A target for LLVM */
+TVM_DLL Target llvm(const std::vector<std::string>& options =
+                    std::vector<std::string>());
+
+/*! \return A target for CUDA */
+TVM_DLL Target cuda(const std::vector<std::string>& options =
+                    std::vector<std::string>());
+
+/*! \return A target for ROCm */
+TVM_DLL Target rocm(const std::vector<std::string>& options =
+                    std::vector<std::string>());
+
+/*! \return A target for OpenCL */
+TVM_DLL Target opencl(const std::vector<std::string>& options =
+                      std::vector<std::string>());
+
+/*! \return A target for Metal */
+TVM_DLL Target metal(const std::vector<std::string>& options =
+                     std::vector<std::string>());
+
+/*! \return A target for rasp */
+TVM_DLL Target rasp(const std::vector<std::string>& options =
+                    std::vector<std::string>());
+
+/*! \return A target for Mali */
+TVM_DLL Target mali(const std::vector<std::string>& options =
+                    std::vector<std::string>());
+
+/*! \return A target for Intel Graphics */
+TVM_DLL Target intel_graphics(const std::vector<std::string>& options =
+                              std::vector<std::string>());
+
+/*! \return A target for stackvm */
+TVM_DLL Target stackvm(const std::vector<std::string>& options =
+                       std::vector<std::string>());
+
+/*! \return A target for external device */
+TVM_DLL Target ext_dev(const std::vector<std::string>& options =
+                       std::vector<std::string>());
+}  // namespace target
+}  // namespace tvm
+#endif  // TVM_TARGET_TARGET_H_
similarity index 90%
rename from include/tvm/target_info.h
rename to include/tvm/target/target_info.h
index 0a42a76..4466476 100644 (file)
  */
 
 /*!
- * \file tvm/target_info.h
+ * \file tvm/target/target_info.h
  * \brief Various information about target.
  */
-#ifndef TVM_TARGET_INFO_H_
-#define TVM_TARGET_INFO_H_
+#ifndef TVM_TARGET_TARGET_INFO_H_
+#define TVM_TARGET_TARGET_INFO_H_
 
+#include <tvm/ir/expr.h>
 #include <string>
-#include "expr.h"
 
 namespace tvm {
 
@@ -33,7 +33,8 @@ namespace tvm {
  * \brief Memory information of special memory region.
  *  Use MemoryInfo as its container type
  */
-struct MemoryInfoNode : public Object {
+class MemoryInfoNode : public Object {
+ public:
   /*! \brief The addressable unit */
   int unit_bits;
   /*! \brief Maximum number of bits supported in the memory */
@@ -71,4 +72,4 @@ class MemoryInfo : public ObjectRef {
 TVM_DLL MemoryInfo GetMemoryInfo(const std::string& scope);
 
 }  // namespace tvm
-#endif  // TVM_TARGET_INFO_H_
+#endif  // TVM_TARGET_TARGET_INFO_H_
index 9f79342..771583b 100644 (file)
@@ -38,288 +38,8 @@ using runtime::TVMArgs;
 using runtime::TVMRetValue;
 using runtime::PackedFunc;
 
-TVM_REGISTER_NODE_TYPE(TargetNode);
 TVM_REGISTER_NODE_TYPE(GenericFuncNode);
 
-TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
-.set_dispatch<TargetNode>([](const ObjectRef& node, NodePrinter* p) {
-    auto* op = static_cast<const TargetNode*>(node.get());
-    p->stream << op->str();
-  });
-
-
-/*!
-* \brief Construct a Target node from the given name and options.
-* \param target_name The major target name. Should be one of
-* {"aocl", "aocl_sw_emu", "c", "cuda", "ext_dev", "hybrid", "llvm", "metal",
-*  "nvptx", "opencl", "opengl", "rocm", "sdaccel", "stackvm", "vulkan"}
-* \param options Additional options appended to the target
-* \return The constructed Target
-*/
-Target CreateTarget(const std::string& target_name,
-                    const std::vector<std::string>& options) {
-  auto t = make_object<TargetNode>();
-  t->target_name = target_name;
-
-  std::string libs_flag = "-libs=";
-  std::string device_flag = "-device=";
-  std::string keys_flag = "-keys=";
-  for (auto& item : options) {
-    t->options_array.push_back(ir::StringImmNode::make(item));
-
-    if (item.find(libs_flag) == 0) {
-      std::stringstream ss(item.substr(libs_flag.length()));
-      std::string lib_item;
-      while (std::getline(ss, lib_item, ',')) {
-        t->libs_array.push_back(ir::StringImmNode::make(lib_item));
-      }
-    } else if (item.find(device_flag) == 0) {
-      t->device_name = item.substr(device_flag.length());
-      t->keys_array.push_back(ir::StringImmNode::make(t->device_name));
-    } else if (item.find(keys_flag) == 0) {
-      std::stringstream ss(item.substr(keys_flag.length()));
-      std::string key_item;
-      while (std::getline(ss, key_item, ',')) {
-        t->keys_array.push_back(ir::StringImmNode::make(key_item));
-      }
-    }
-  }
-
-  if (t->device_name.length() > 0) {
-    t->keys_array.push_back(ir::StringImmNode::make(t->device_name));
-  }
-  t->device_type = kDLCPU;
-  t->thread_warp_size = 1;
-  if (target_name == "c" && t->device_name == "micro_dev") {
-    t->device_type = kDLMicroDev;
-  } else if (target_name == "c" || target_name == "llvm") {
-    t->keys_array.push_back(ir::StringImmNode::make("cpu"));
-  } else if (target_name == "cuda" || target_name == "nvptx") {
-    t->device_type = kDLGPU;
-    t->keys_array.push_back(ir::StringImmNode::make("cuda"));
-    t->keys_array.push_back(ir::StringImmNode::make("gpu"));
-    t->max_num_threads = 1024;
-    t->thread_warp_size = 32;
-  } else if (target_name == "rocm" || target_name == "opencl") {
-    // For now assume rocm schedule for opencl
-    if (target_name == "opencl") {
-      t->device_type = kDLOpenCL;
-    } else {
-      t->device_type = kDLROCM;
-    }
-    t->keys_array.push_back(ir::StringImmNode::make(target_name));
-    t->keys_array.push_back(ir::StringImmNode::make("gpu"));
-    t->max_num_threads = 256;
-    if (t->device_name == "intel_graphics") {
-      t->thread_warp_size = 16;
-    }
-  } else if (target_name == "metal" || target_name == "vulkan") {
-    if (target_name == "metal") {
-      t->device_type = kDLMetal;
-    } else {
-      t->device_type = kDLVulkan;
-    }
-    t->keys_array.push_back(ir::StringImmNode::make(target_name));
-    t->keys_array.push_back(ir::StringImmNode::make("gpu"));
-    t->max_num_threads = 256;
-  } else if (target_name == "sdaccel") {
-    t->device_type = kDLOpenCL;
-    t->keys_array.push_back(ir::StringImmNode::make("sdaccel"));
-    t->keys_array.push_back(ir::StringImmNode::make("hls"));
-  } else if (target_name == "aocl" || target_name == "aocl_sw_emu") {
-    t->device_type = kDLAOCL;
-    t->keys_array.push_back(ir::StringImmNode::make("aocl"));
-    t->keys_array.push_back(ir::StringImmNode::make("hls"));
-  } else if (target_name == "opengl") {
-    t->device_type = kOpenGL;
-    t->keys_array.push_back(ir::StringImmNode::make("opengl"));
-  } else if (target_name == "stackvm") {
-    t->device_type = kDLCPU;
-  } else if (target_name == "ext_dev") {
-    t->device_type = kDLExtDev;
-  } else if (target_name == "hybrid") {
-    t->device_type = kDLCPU;
-  } else {
-    LOG(ERROR) << "Unknown target name " << target_name;
-    return target::stackvm();
-  }
-
-  return Target(t);
-}
-
-TVM_REGISTER_GLOBAL("_TargetCreate")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
-  std::string target_name = args[0];
-  std::vector<std::string> options;
-  for (int i = 1; i < args.num_args; ++i) {
-    std::string arg = args[i];
-    options.push_back(arg);
-  }
-
-  *ret = CreateTarget(target_name, options);
-  });
-
-TVM_REGISTER_GLOBAL("_TargetFromString")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
-  std::string target_str = args[0];
-  *ret = Target::Create(target_str);
-  });
-
-std::vector<std::string> TargetNode::keys() const {
-  std::vector<std::string> result;
-  for (auto& expr : keys_array) {
-    result.push_back(expr.as<ir::StringImmNode>()->value);
-  }
-  return result;
-}
-
-std::vector<std::string> TargetNode::options() const {
-  std::vector<std::string> result;
-  for (auto& expr : options_array) {
-    result.push_back(expr.as<ir::StringImmNode>()->value);
-  }
-  return result;
-}
-
-std::unordered_set<std::string> TargetNode::libs() const {
-  std::unordered_set<std::string> result;
-  for (auto& expr : libs_array) {
-    result.insert(expr.as<ir::StringImmNode>()->value);
-  }
-  return result;
-}
-
-const std::string& TargetNode::str() const {
-  if (str_repr_.length() != 0) return str_repr_;
-  std::ostringstream result;
-  result << target_name;
-  for (const auto &x : options()) {
-    result << " " << x;
-  }
-  str_repr_ = result.str();
-  return str_repr_;
-}
-
-
-bool StartsWith(const std::string& str, const std::string& pattern) {
-  return str.compare(0, pattern.length(), pattern) == 0;
-}
-
-std::string GetDeviceName(const std::string& target_str) {
-  std::istringstream ss(target_str);
-  std::string target_name;
-  ss >> target_name;
-
-  std::string item;
-  while (ss >> item) {
-    if (StartsWith(item, "-device=")) {
-      return item.substr(std::string("-device=").length());
-    }
-  }
-
-  return "";
-}
-
-Target Target::Create(const std::string& target_str) {
-  if (target_str.length() == 0) {
-    LOG(ERROR) << "target_str must not be empty";
-  }
-
-  std::istringstream ss(target_str);
-  std::string target_name;
-
-  ss >> target_name;
-  auto device_name = GetDeviceName(target_str);
-
-  std::vector<std::string> options;
-  std::string item;
-  while (ss >> item) {
-    options.push_back(item);
-  }
-
-  return CreateTarget(target_name, options);
-}
-
-/*! \brief Entry to hold the Target context stack. */
-struct TVMTargetThreadLocalEntry {
-  /*! \brief The current target context */
-  std::stack<tvm::Target> context_stack;
-};
-
-/*! \brief Thread local store to hold the Target context stack. */
-typedef dmlc::ThreadLocalStore<TVMTargetThreadLocalEntry> TVMTargetThreadLocalStore;
-
-void Target::EnterWithScope() {
-  TVMTargetThreadLocalEntry *entry = TVMTargetThreadLocalStore::Get();
-  entry->context_stack.push(*this);
-}
-
-void Target::ExitWithScope() {
-  TVMTargetThreadLocalEntry *entry = TVMTargetThreadLocalStore::Get();
-  CHECK(!entry->context_stack.empty());
-  CHECK(entry->context_stack.top().same_as(*this));
-  entry->context_stack.pop();
-}
-
-tvm::Target Target::Current(bool allow_not_defined) {
-  TVMTargetThreadLocalEntry *entry = TVMTargetThreadLocalStore::Get();
-  if (entry->context_stack.size() > 0) {
-    return entry->context_stack.top();
-  }
-  CHECK(allow_not_defined)
-    << "Target context required. Please set it by constructing a TargetContext";
-
-  return Target();
-}
-
-namespace target {
-std::vector<std::string> MergeOptions(std::vector<std::string> opts,
-                                             const std::vector<std::string>& new_opts) {
-  opts.insert(opts.end(), new_opts.begin(), new_opts.end());
-  return opts;
-}
-
-Target llvm(const std::vector<std::string>& options) {
-  return CreateTarget("llvm", options);
-}
-
-Target cuda(const std::vector<std::string>& options) {
-  return CreateTarget("cuda", options);
-}
-
-Target rocm(const std::vector<std::string>& options) {
-  return CreateTarget("rocm", options);
-}
-
-Target opencl(const std::vector<std::string>& options) {
-  return CreateTarget("opencl", options);
-}
-
-Target metal(const std::vector<std::string>& options) {
-  return CreateTarget("metal", options);
-}
-
-Target mali(const std::vector<std::string>& options) {
-  return CreateTarget("opencl", MergeOptions(options, {
-    "-device=mali"
-  }));
-}
-
-Target intel_graphics(const std::vector<std::string>& options) {
-  return CreateTarget("opencl", MergeOptions(options, {
-    "-device=intel_graphics"
-  }));
-}
-
-Target stackvm(const std::vector<std::string>& options) {
-  return CreateTarget("stackvm", options);
-}
-
-Target ext_dev(const std::vector<std::string>& options) {
-  return CreateTarget("ext_dev", options);
-}
-}  // namespace target
-
 bool LLVMEnabled() {
   const runtime::PackedFunc* pf = runtime::Registry::Get("codegen.build_llvm");
   return pf != nullptr;
index d98299f..da153fc 100644 (file)
@@ -21,7 +21,7 @@
  * \file storage_access.cc
  */
 #include <tvm/ir_pass.h>
-#include <tvm/target_info.h>
+#include <tvm/target/target_info.h>
 #include <string>
 #include <utility>
 #include "ir_util.h"
index 08c61aa..a6d83a8 100644 (file)
@@ -30,7 +30,7 @@
 #include <tvm/expr_operator.h>
 #include <tvm/ir_pass.h>
 #include <tvm/buffer.h>
-#include <tvm/target_info.h>
+#include <tvm/target/target_info.h>
 #include <tvm/runtime/device_api.h>
 #include <unordered_map>
 #include "ir_util.h"
index 7a4b13c..4908420 100644 (file)
@@ -25,7 +25,7 @@
 #include <tvm/ir.h>
 #include <tvm/ir_pass.h>
 #include <tvm/ir_functor_ext.h>
-#include <tvm/target_info.h>
+#include <tvm/target/target_info.h>
 #include <map>
 #include <unordered_set>
 #include <unordered_map>
index 956f27c..002e422 100644 (file)
@@ -28,7 +28,7 @@
 #include <tvm/expr_operator.h>
 #include <tvm/ir_pass.h>
 #include <tvm/buffer.h>
-#include <tvm/target_info.h>
+#include <tvm/target/target_info.h>
 #include <tvm/build_module.h>
 #include <tvm/runtime/device_api.h>
 #include <unordered_map>
diff --git a/src/target/target.cc b/src/target/target.cc
new file mode 100644 (file)
index 0000000..014d3f9
--- /dev/null
@@ -0,0 +1,319 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ *  Compile executable modules.
+ * \file src/target/target.cc
+ */
+#include <dmlc/thread_local.h>
+
+#include <tvm/runtime/registry.h>
+#include <tvm/node/printer.h>
+#include <tvm/target/target.h>
+
+#include <tvm/ir.h>
+
+#include <algorithm>
+#include <stack>
+
+namespace tvm {
+
+using runtime::TVMArgs;
+using runtime::TVMRetValue;
+using runtime::PackedFunc;
+
+TVM_REGISTER_NODE_TYPE(TargetNode);
+
+TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
+.set_dispatch<TargetNode>([](const ObjectRef& node, NodePrinter* p) {
+    auto* op = static_cast<const TargetNode*>(node.get());
+    p->stream << op->str();
+  });
+
+/*!
+* \brief Construct a Target node from the given name and options.
+* \param target_name The major target name. Should be one of
+* {"aocl", "aocl_sw_emu", "c", "cuda", "ext_dev", "hybrid", "llvm", "metal",
+*  "nvptx", "opencl", "opengl", "rocm", "sdaccel", "stackvm", "vulkan"}
+* \param options Additional options appended to the target
+* \return The constructed Target
+*/
+Target CreateTarget(const std::string& target_name,
+                    const std::vector<std::string>& options) {
+  auto t = make_object<TargetNode>();
+  t->target_name = target_name;
+
+  std::string libs_flag = "-libs=";
+  std::string device_flag = "-device=";
+  std::string keys_flag = "-keys=";
+  for (auto& item : options) {
+    t->options_array.push_back(ir::StringImmNode::make(item));
+
+    if (item.find(libs_flag) == 0) {
+      std::stringstream ss(item.substr(libs_flag.length()));
+      std::string lib_item;
+      while (std::getline(ss, lib_item, ',')) {
+        t->libs_array.push_back(ir::StringImmNode::make(lib_item));
+      }
+    } else if (item.find(device_flag) == 0) {
+      t->device_name = item.substr(device_flag.length());
+      t->keys_array.push_back(ir::StringImmNode::make(t->device_name));
+    } else if (item.find(keys_flag) == 0) {
+      std::stringstream ss(item.substr(keys_flag.length()));
+      std::string key_item;
+      while (std::getline(ss, key_item, ',')) {
+        t->keys_array.push_back(ir::StringImmNode::make(key_item));
+      }
+    }
+  }
+
+  if (t->device_name.length() > 0) {
+    t->keys_array.push_back(ir::StringImmNode::make(t->device_name));
+  }
+  t->device_type = kDLCPU;
+  t->thread_warp_size = 1;
+  if (target_name == "c" && t->device_name == "micro_dev") {
+    t->device_type = kDLMicroDev;
+  } else if (target_name == "c" || target_name == "llvm") {
+    t->keys_array.push_back(ir::StringImmNode::make("cpu"));
+  } else if (target_name == "cuda" || target_name == "nvptx") {
+    t->device_type = kDLGPU;
+    t->keys_array.push_back(ir::StringImmNode::make("cuda"));
+    t->keys_array.push_back(ir::StringImmNode::make("gpu"));
+    t->max_num_threads = 1024;
+    t->thread_warp_size = 32;
+  } else if (target_name == "rocm" || target_name == "opencl") {
+    // For now assume rocm schedule for opencl
+    if (target_name == "opencl") {
+      t->device_type = kDLOpenCL;
+    } else {
+      t->device_type = kDLROCM;
+    }
+    t->keys_array.push_back(ir::StringImmNode::make(target_name));
+    t->keys_array.push_back(ir::StringImmNode::make("gpu"));
+    t->max_num_threads = 256;
+    if (t->device_name == "intel_graphics") {
+      t->thread_warp_size = 16;
+    }
+  } else if (target_name == "metal" || target_name == "vulkan") {
+    if (target_name == "metal") {
+      t->device_type = kDLMetal;
+    } else {
+      t->device_type = kDLVulkan;
+    }
+    t->keys_array.push_back(ir::StringImmNode::make(target_name));
+    t->keys_array.push_back(ir::StringImmNode::make("gpu"));
+    t->max_num_threads = 256;
+  } else if (target_name == "sdaccel") {
+    t->device_type = kDLOpenCL;
+    t->keys_array.push_back(ir::StringImmNode::make("sdaccel"));
+    t->keys_array.push_back(ir::StringImmNode::make("hls"));
+  } else if (target_name == "aocl" || target_name == "aocl_sw_emu") {
+    t->device_type = kDLAOCL;
+    t->keys_array.push_back(ir::StringImmNode::make("aocl"));
+    t->keys_array.push_back(ir::StringImmNode::make("hls"));
+  } else if (target_name == "opengl") {
+    t->device_type = kOpenGL;
+    t->keys_array.push_back(ir::StringImmNode::make("opengl"));
+  } else if (target_name == "stackvm") {
+    t->device_type = kDLCPU;
+  } else if (target_name == "ext_dev") {
+    t->device_type = kDLExtDev;
+  } else if (target_name == "hybrid") {
+    t->device_type = kDLCPU;
+  } else {
+    LOG(ERROR) << "Unknown target name " << target_name;
+    return target::stackvm();
+  }
+
+  return Target(t);
+}
+
+TVM_REGISTER_GLOBAL("_TargetCreate")
+.set_body([](TVMArgs args, TVMRetValue* ret) {
+  std::string target_name = args[0];
+  std::vector<std::string> options;
+  for (int i = 1; i < args.num_args; ++i) {
+    std::string arg = args[i];
+    options.push_back(arg);
+  }
+
+  *ret = CreateTarget(target_name, options);
+  });
+
+TVM_REGISTER_GLOBAL("_TargetFromString")
+.set_body([](TVMArgs args, TVMRetValue* ret) {
+  std::string target_str = args[0];
+  *ret = Target::Create(target_str);
+  });
+
+std::vector<std::string> TargetNode::keys() const {
+  std::vector<std::string> result;
+  for (auto& expr : keys_array) {
+    result.push_back(expr.as<ir::StringImmNode>()->value);
+  }
+  return result;
+}
+
+std::vector<std::string> TargetNode::options() const {
+  std::vector<std::string> result;
+  for (auto& expr : options_array) {
+    result.push_back(expr.as<ir::StringImmNode>()->value);
+  }
+  return result;
+}
+
+std::unordered_set<std::string> TargetNode::libs() const {
+  std::unordered_set<std::string> result;
+  for (auto& expr : libs_array) {
+    result.insert(expr.as<ir::StringImmNode>()->value);
+  }
+  return result;
+}
+
+const std::string& TargetNode::str() const {
+  if (str_repr_.length() != 0) return str_repr_;
+  std::ostringstream result;
+  result << target_name;
+  for (const auto &x : options()) {
+    result << " " << x;
+  }
+  str_repr_ = result.str();
+  return str_repr_;
+}
+
+
+bool StartsWith(const std::string& str, const std::string& pattern) {
+  return str.compare(0, pattern.length(), pattern) == 0;
+}
+
+std::string GetDeviceName(const std::string& target_str) {
+  std::istringstream ss(target_str);
+  std::string target_name;
+  ss >> target_name;
+
+  std::string item;
+  while (ss >> item) {
+    if (StartsWith(item, "-device=")) {
+      return item.substr(std::string("-device=").length());
+    }
+  }
+
+  return "";
+}
+
+Target Target::Create(const std::string& target_str) {
+  if (target_str.length() == 0) {
+    LOG(ERROR) << "target_str must not be empty";
+  }
+
+  std::istringstream ss(target_str);
+  std::string target_name;
+
+  ss >> target_name;
+  auto device_name = GetDeviceName(target_str);
+
+  std::vector<std::string> options;
+  std::string item;
+  while (ss >> item) {
+    options.push_back(item);
+  }
+
+  return CreateTarget(target_name, options);
+}
+
+/*! \brief Entry to hold the Target context stack. */
+struct TVMTargetThreadLocalEntry {
+  /*! \brief The current target context */
+  std::stack<tvm::Target> context_stack;
+};
+
+/*! \brief Thread local store to hold the Target context stack. */
+typedef dmlc::ThreadLocalStore<TVMTargetThreadLocalEntry> TVMTargetThreadLocalStore;
+
+void Target::EnterWithScope() {
+  TVMTargetThreadLocalEntry *entry = TVMTargetThreadLocalStore::Get();
+  entry->context_stack.push(*this);
+}
+
+void Target::ExitWithScope() {
+  TVMTargetThreadLocalEntry *entry = TVMTargetThreadLocalStore::Get();
+  CHECK(!entry->context_stack.empty());
+  CHECK(entry->context_stack.top().same_as(*this));
+  entry->context_stack.pop();
+}
+
+tvm::Target Target::Current(bool allow_not_defined) {
+  TVMTargetThreadLocalEntry *entry = TVMTargetThreadLocalStore::Get();
+  if (entry->context_stack.size() > 0) {
+    return entry->context_stack.top();
+  }
+  CHECK(allow_not_defined)
+    << "Target context required. Please set it by constructing a TargetContext";
+
+  return Target();
+}
+
+namespace target {
+std::vector<std::string> MergeOptions(std::vector<std::string> opts,
+                                             const std::vector<std::string>& new_opts) {
+  opts.insert(opts.end(), new_opts.begin(), new_opts.end());
+  return opts;
+}
+
+Target llvm(const std::vector<std::string>& options) {
+  return CreateTarget("llvm", options);
+}
+
+Target cuda(const std::vector<std::string>& options) {
+  return CreateTarget("cuda", options);
+}
+
+Target rocm(const std::vector<std::string>& options) {
+  return CreateTarget("rocm", options);
+}
+
+Target opencl(const std::vector<std::string>& options) {
+  return CreateTarget("opencl", options);
+}
+
+Target metal(const std::vector<std::string>& options) {
+  return CreateTarget("metal", options);
+}
+
+Target mali(const std::vector<std::string>& options) {
+  return CreateTarget("opencl", MergeOptions(options, {
+    "-device=mali"
+  }));
+}
+
+Target intel_graphics(const std::vector<std::string>& options) {
+  return CreateTarget("opencl", MergeOptions(options, {
+    "-device=intel_graphics"
+  }));
+}
+
+Target stackvm(const std::vector<std::string>& options) {
+  return CreateTarget("stackvm", options);
+}
+
+Target ext_dev(const std::vector<std::string>& options) {
+  return CreateTarget("ext_dev", options);
+}
+}  // namespace target
+}  // namespace tvm
similarity index 75%
rename from src/lang/target_info.cc
rename to src/target/target_info.cc
index 6bdcf88..6c332e7 100644 (file)
  */
 
 /*!
- * \file target_info.cc
+ * \file target/target_info.cc
  */
 #include <tvm/runtime/registry.h>
-#include <tvm/target_info.h>
-#include <tvm/packed_func_ext.h>
+#include <tvm/node/printer.h>
+#include <tvm/target/target_info.h>
 
 namespace tvm {
 
 TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
 .set_dispatch<MemoryInfoNode>([](const ObjectRef& node, NodePrinter* p) {
-    auto* op = static_cast<const MemoryInfoNode*>(node.get());
-    p->stream << "mem-info("
-              << "unit_bits=" << op->unit_bits << ", "
-              << "max_num_bits=" << op->max_num_bits << ", "
-              << "max_simd_bits=" << op->max_simd_bits << ", "
-              << "head_address=" << op->head_address << ")";
+  auto* op = static_cast<const MemoryInfoNode*>(node.get());
+  p->stream << "mem-info("
+            << "unit_bits=" << op->unit_bits << ", "
+            << "max_num_bits=" << op->max_num_bits << ", "
+            << "max_simd_bits=" << op->max_simd_bits << ", "
+            << "head_address=" << op->head_address << ")";
 });
 
 TVM_REGISTER_NODE_TYPE(MemoryInfoNode);