[Target] Rename target_id => target_kind (#6199)
author Junru Shao <junrushao1994@gmail.com>
Tue, 4 Aug 2020 01:50:32 +0000 (18:50 -0700)
committer GitHub <noreply@github.com>
Tue, 4 Aug 2020 01:50:32 +0000 (18:50 -0700)
50 files changed:
include/tvm/target/target.h
include/tvm/target/target_kind.h [moved from include/tvm/target/target_id.h with 63% similarity]
include/tvm/topi/cuda/dense.h
include/tvm/topi/cuda/reduction.h
include/tvm/topi/rocm/dense.h
python/tvm/auto_scheduler/measure_record.py
python/tvm/autotvm/tophub.py
python/tvm/driver/build_module.py
python/tvm/relay/op/strategy/cuda.py
python/tvm/relay/op/strategy/rocm.py
python/tvm/relay/quantize/_calibrate.py
python/tvm/target/__init__.py
python/tvm/target/target.py
python/tvm/topi/cuda/batch_matmul.py
python/tvm/topi/cuda/conv1d.py
python/tvm/topi/cuda/conv1d_transpose_ncw.py
python/tvm/topi/cuda/conv2d_direct.py
python/tvm/topi/cuda/conv2d_nhwc.py
python/tvm/topi/cuda/conv2d_nhwc_tensorcore.py
python/tvm/topi/cuda/conv2d_transpose_nchw.py
python/tvm/topi/cuda/conv2d_winograd.py
python/tvm/topi/cuda/conv3d_direct.py
python/tvm/topi/cuda/conv3d_ndhwc_tensorcore.py
python/tvm/topi/cuda/conv3d_winograd.py
python/tvm/topi/cuda/correlation.py
python/tvm/topi/cuda/deformable_conv2d.py
python/tvm/topi/cuda/dense_tensorcore.py
python/tvm/topi/cuda/depthwise_conv2d.py
python/tvm/topi/cuda/group_conv2d_nchw.py
python/tvm/topi/cuda/reduction.py
python/tvm/topi/cuda/softmax.py
python/tvm/topi/cuda/vision.py
python/tvm/topi/generic/default.py
python/tvm/topi/generic/injective.py
python/tvm/topi/generic/vision.py
python/tvm/topi/intel_graphics/depthwise_conv2d.py
src/auto_scheduler/search_task.cc
src/driver/driver_api.cc
src/relay/backend/build_module.cc
src/target/codegen.cc
src/target/target.cc
src/target/target_kind.cc [moved from src/target/target_id.cc with 83% similarity]
src/te/schedule/schedule_postproc_rewrite_for_tensor_core.cc
src/tir/analysis/verify_memory.cc
src/tir/transforms/lower_custom_datatypes.cc
src/tir/transforms/lower_intrin.cc
src/tir/transforms/lower_thread_allreduce.cc
src/tir/transforms/make_packed_api.cc
tests/cpp/target_test.cc
tests/python/unittest/test_target_target.py
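
The change is a mechanical rename of the public target API: the TargetId / target->id family becomes TargetKind / target->kind across the C++ and Python surfaces, as the hunks below show. For orientation, a minimal before/after sketch of the C++ call-site pattern, assuming a valid Target object named "target":

  // Old spelling (before this commit):
  //   if (target->id->name == "cuda") { ... }
  // New spelling (after this commit):
  if (target->kind->name == "cuda") {
    // kind-specific dispatch goes here
  }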

index 618095f..4a83579 100644 (file)
@@ -28,7 +28,7 @@
 #include <tvm/ir/transform.h>
 #include <tvm/node/container.h>
 #include <tvm/support/with.h>
-#include <tvm/target/target_id.h>
+#include <tvm/target/target_kind.h>
 
 #include <string>
 #include <unordered_set>
@@ -43,8 +43,8 @@ namespace tvm {
  */
 class TargetNode : public Object {
  public:
-  /*! \brief The id of the target device */
-  TargetId id;
+  /*! \brief The kind of the target device */
+  TargetKind kind;
   /*! \brief Tag of the the target, can be empty */
   String tag;
   /*! \brief Keys for this target */
@@ -56,7 +56,7 @@ class TargetNode : public Object {
   TVM_DLL const std::string& str() const;
 
   void VisitAttrs(AttrVisitor* v) {
-    v->Visit("id", &id);
+    v->Visit("kind", &kind);
     v->Visit("tag", &tag);
     v->Visit("keys", &keys);
     v->Visit("attrs", &attrs);
similarity index 63%
rename from include/tvm/target/target_id.h
rename to include/tvm/target/target_kind.h
index a0f275c..7f660be 100644 (file)
  */
 
 /*!
- * \file tvm/target/target_id.h
- * \brief Target id registry
+ * \file tvm/target/target_kind.h
+ * \brief Target kind registry
  */
-#ifndef TVM_TARGET_TARGET_ID_H_
-#define TVM_TARGET_TARGET_ID_H_
+#ifndef TVM_TARGET_TARGET_KIND_H_
+#define TVM_TARGET_TARGET_KIND_H_
 
 #include <tvm/ir/expr.h>
 #include <tvm/ir/transform.h>
@@ -49,14 +49,14 @@ class Target;
 TVM_DLL void TargetValidateSchema(const Map<String, ObjectRef>& config);
 
 template <typename>
-class TargetIdAttrMap;
+class TargetKindAttrMap;
 
-/*! \brief Target Id, specifies the kind of the target */
-class TargetIdNode : public Object {
+/*! \brief Target kind, specifies the kind of the target */
+class TargetKindNode : public Object {
  public:
-  /*! \brief Name of the target id */
+  /*! \brief Name of the target kind */
   String name;
-  /*! \brief Device type of target id */
+  /*! \brief Device type of target kind */
   int device_type;
   /*! \brief Default keys of the target */
   Array<String> default_keys;
@@ -71,8 +71,8 @@ class TargetIdNode : public Object {
 
   Optional<String> StringifyAttrsToRaw(const Map<String, ObjectRef>& attrs) const;
 
-  static constexpr const char* _type_key = "TargetId";
-  TVM_DECLARE_FINAL_OBJECT_INFO(TargetIdNode, Object);
+  static constexpr const char* _type_key = "TargetKind";
+  TVM_DECLARE_FINAL_OBJECT_INFO(TargetKindNode, Object);
 
  private:
   /*! \brief Stores the required type_key and type_index of a specific attr of a target */
@@ -88,7 +88,7 @@ class TargetIdNode : public Object {
   /*! \brief Perform schema validation */
   void ValidateSchema(const Map<String, ObjectRef>& config) const;
   /*! \brief Verify if the obj is consistent with the type info */
-  void VerifyTypeInfo(const ObjectRef& obj, const TargetIdNode::ValueTypeInfo& info) const;
+  void VerifyTypeInfo(const ObjectRef& obj, const TargetKindNode::ValueTypeInfo& info) const;
   /*! \brief A hash table that stores the type information of each attr of the target key */
   std::unordered_map<String, ValueTypeInfo> key2vtype_;
   /*! \brief A hash table that stores the default value of each attr of the target key */
@@ -97,67 +97,67 @@ class TargetIdNode : public Object {
   uint32_t index_;
   friend void TargetValidateSchema(const Map<String, ObjectRef>&);
   friend class Target;
-  friend class TargetId;
+  friend class TargetKind;
   template <typename, typename>
   friend class AttrRegistry;
   template <typename>
   friend class AttrRegistryMapContainerMap;
-  friend class TargetIdRegEntry;
+  friend class TargetKindRegEntry;
   template <typename, typename, typename>
   friend struct detail::ValueTypeInfoMaker;
 };
 
 /*!
- * \brief Managed reference class to TargetIdNode
- * \sa TargetIdNode
+ * \brief Managed reference class to TargetKindNode
+ * \sa TargetKindNode
  */
-class TargetId : public ObjectRef {
+class TargetKind : public ObjectRef {
  public:
-  TargetId() = default;
+  TargetKind() = default;
   /*! \brief Get the attribute map given the attribute name */
   template <typename ValueType>
-  static inline TargetIdAttrMap<ValueType> GetAttrMap(const String& attr_name);
+  static inline TargetKindAttrMap<ValueType> GetAttrMap(const String& attr_name);
   /*!
-   * \brief Retrieve the TargetId given its name
-   * \param target_id_name Name of the target id
-   * \return The TargetId requested
+   * \brief Retrieve the TargetKind given its name
+   * \param target_kind_name Name of the target kind
+   * \return The TargetKind requested
    */
-  TVM_DLL static const TargetId& Get(const String& target_id_name);
-  TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(TargetId, ObjectRef, TargetIdNode);
+  TVM_DLL static const TargetKind& Get(const String& target_kind_name);
+  TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(TargetKind, ObjectRef, TargetKindNode);
 
  private:
   /*! \brief Mutable access to the container class  */
-  TargetIdNode* operator->() { return static_cast<TargetIdNode*>(data_.get()); }
-  TVM_DLL static const AttrRegistryMapContainerMap<TargetId>& GetAttrMapContainer(
+  TargetKindNode* operator->() { return static_cast<TargetKindNode*>(data_.get()); }
+  TVM_DLL static const AttrRegistryMapContainerMap<TargetKind>& GetAttrMapContainer(
       const String& attr_name);
   template <typename, typename>
   friend class AttrRegistry;
-  friend class TargetIdRegEntry;
+  friend class TargetKindRegEntry;
   friend class Target;
 };
 
 /*!
- * \brief Map<TargetId, ValueType> used to store meta-information about TargetId
+ * \brief Map<TargetKind, ValueType> used to store meta-information about TargetKind
  * \tparam ValueType The type of the value stored in map
  */
 template <typename ValueType>
-class TargetIdAttrMap : public AttrRegistryMap<TargetId, ValueType> {
+class TargetKindAttrMap : public AttrRegistryMap<TargetKind, ValueType> {
  public:
-  using TParent = AttrRegistryMap<TargetId, ValueType>;
+  using TParent = AttrRegistryMap<TargetKind, ValueType>;
   using TParent::count;
   using TParent::get;
   using TParent::operator[];
-  explicit TargetIdAttrMap(const AttrRegistryMapContainerMap<TargetId>& map) : TParent(map) {}
+  explicit TargetKindAttrMap(const AttrRegistryMapContainerMap<TargetKind>& map) : TParent(map) {}
 };
 
 /*!
- * \brief Helper structure to register TargetId
- * \sa TVM_REGISTER_TARGET_ID
+ * \brief Helper structure to register TargetKind
+ * \sa TVM_REGISTER_TARGET_KIND
  */
-class TargetIdRegEntry {
+class TargetKindRegEntry {
  public:
   /*!
-   * \brief Register additional attributes to target_id.
+   * \brief Register additional attributes to target_kind.
    * \param attr_name The name of the attribute.
    * \param value The value to be set.
    * \param plevel The priority level of this attribute,
@@ -170,25 +170,25 @@ class TargetIdRegEntry {
    * \tparam ValueType The type of the value to be set.
    */
   template <typename ValueType>
-  inline TargetIdRegEntry& set_attr(const String& attr_name, const ValueType& value,
-                                    int plevel = 10);
+  inline TargetKindRegEntry& set_attr(const String& attr_name, const ValueType& value,
+                                      int plevel = 10);
   /*!
    * \brief Set DLPack's device_type the target
    * \param device_type Device type
    */
-  inline TargetIdRegEntry& set_device_type(int device_type);
+  inline TargetKindRegEntry& set_device_type(int device_type);
   /*!
    * \brief Set DLPack's device_type the target
    * \param keys The default keys
    */
-  inline TargetIdRegEntry& set_default_keys(std::vector<String> keys);
+  inline TargetKindRegEntry& set_default_keys(std::vector<String> keys);
   /*!
    * \brief Register a valid configuration option and its ValueType for validation
    * \param key The configuration key
    * \tparam ValueType The value type to be registered
    */
   template <typename ValueType>
-  inline TargetIdRegEntry& add_attr_option(const String& key);
+  inline TargetKindRegEntry& add_attr_option(const String& key);
   /*!
    * \brief Register a valid configuration option and its ValueType for validation
    * \param key The configuration key
@@ -196,26 +196,26 @@ class TargetIdRegEntry {
    * \tparam ValueType The value type to be registered
    */
   template <typename ValueType>
-  inline TargetIdRegEntry& add_attr_option(const String& key, ObjectRef default_value);
-  /*! \brief Set name of the TargetId to be the same as registry if it is empty */
-  inline TargetIdRegEntry& set_name();
+  inline TargetKindRegEntry& add_attr_option(const String& key, ObjectRef default_value);
+  /*! \brief Set name of the TargetKind to be the same as registry if it is empty */
+  inline TargetKindRegEntry& set_name();
   /*!
    * \brief Register or get a new entry.
-   * \param target_id_name The name of the TargetId.
+   * \param target_kind_name The name of the TargetKind.
    * \return the corresponding entry.
    */
-  TVM_DLL static TargetIdRegEntry& RegisterOrGet(const String& target_id_name);
+  TVM_DLL static TargetKindRegEntry& RegisterOrGet(const String& target_kind_name);
 
  private:
-  TargetId id_;
+  TargetKind kind_;
   String name;
 
   /*! \brief private constructor */
-  explicit TargetIdRegEntry(uint32_t reg_index) : id_(make_object<TargetIdNode>()) {
-    id_->index_ = reg_index;
+  explicit TargetKindRegEntry(uint32_t reg_index) : kind_(make_object<TargetKindNode>()) {
+    kind_->index_ = reg_index;
   }
   /*!
-   * \brief update the attribute TargetIdAttrMap
+   * \brief update the attribute TargetKindAttrMap
    * \param key The name of the attribute
    * \param value The value to be set
    * \param plevel The priority level
@@ -223,21 +223,21 @@ class TargetIdRegEntry {
   TVM_DLL void UpdateAttr(const String& key, TVMRetValue value, int plevel);
   template <typename, typename>
   friend class AttrRegistry;
-  friend class TargetId;
+  friend class TargetKind;
 };
 
-#define TVM_TARGET_ID_REGISTER_VAR_DEF \
-  static DMLC_ATTRIBUTE_UNUSED ::tvm::TargetIdRegEntry& __make_##TargetId
+#define TVM_TARGET_KIND_REGISTER_VAR_DEF \
+  static DMLC_ATTRIBUTE_UNUSED ::tvm::TargetKindRegEntry& __make_##TargetKind
 
 /*!
- * \def TVM_REGISTER_TARGET_ID
- * \brief Register a new target id, or set attribute of the corresponding target id.
+ * \def TVM_REGISTER_TARGET_KIND
+ * \brief Register a new target kind, or set attribute of the corresponding target kind.
  *
- * \param TargetIdName The name of target id
+ * \param TargetKindName The name of target kind
  *
  * \code
  *
- *  TVM_REGISTER_TARGET_ID("llvm")
+ *  TVM_REGISTER_TARGET_KIND("llvm")
  *  .set_attr<TPreCodegenPass>("TPreCodegenPass", a-pre-codegen-pass)
  *  .add_attr_option<Bool>("system_lib")
  *  .add_attr_option<String>("mtriple")
@@ -245,9 +245,9 @@ class TargetIdRegEntry {
  *
  * \endcode
  */
-#define TVM_REGISTER_TARGET_ID(TargetIdName)                    \
-  TVM_STR_CONCAT(TVM_TARGET_ID_REGISTER_VAR_DEF, __COUNTER__) = \
-      ::tvm::TargetIdRegEntry::RegisterOrGet(TargetIdName).set_name()
+#define TVM_REGISTER_TARGET_KIND(TargetKindName)                  \
+  TVM_STR_CONCAT(TVM_TARGET_KIND_REGISTER_VAR_DEF, __COUNTER__) = \
+      ::tvm::TargetKindRegEntry::RegisterOrGet(TargetKindName).set_name()
 
 namespace detail {
 template <typename Type, template <typename...> class Container>
@@ -266,7 +266,7 @@ struct ValueTypeInfoMaker {};
 
 template <typename ValueType>
 struct ValueTypeInfoMaker<ValueType, std::false_type, std::false_type> {
-  using ValueTypeInfo = TargetIdNode::ValueTypeInfo;
+  using ValueTypeInfo = TargetKindNode::ValueTypeInfo;
 
   ValueTypeInfo operator()() const {
     uint32_t tindex = ValueType::ContainerType::_GetOrAllocRuntimeTypeIndex();
@@ -281,7 +281,7 @@ struct ValueTypeInfoMaker<ValueType, std::false_type, std::false_type> {
 
 template <typename ValueType>
 struct ValueTypeInfoMaker<ValueType, std::true_type, std::false_type> {
-  using ValueTypeInfo = TargetIdNode::ValueTypeInfo;
+  using ValueTypeInfo = TargetKindNode::ValueTypeInfo;
 
   ValueTypeInfo operator()() const {
     using key_type = ValueTypeInfoMaker<typename ValueType::value_type>;
@@ -297,7 +297,7 @@ struct ValueTypeInfoMaker<ValueType, std::true_type, std::false_type> {
 
 template <typename ValueType>
 struct ValueTypeInfoMaker<ValueType, std::false_type, std::true_type> {
-  using ValueTypeInfo = TargetIdNode::ValueTypeInfo;
+  using ValueTypeInfo = TargetKindNode::ValueTypeInfo;
   ValueTypeInfo operator()() const {
     using key_type = ValueTypeInfoMaker<typename ValueType::key_type>;
     using val_type = ValueTypeInfoMaker<typename ValueType::mapped_type>;
@@ -314,13 +314,13 @@ struct ValueTypeInfoMaker<ValueType, std::false_type, std::true_type> {
 }  // namespace detail
 
 template <typename ValueType>
-inline TargetIdAttrMap<ValueType> TargetId::GetAttrMap(const String& attr_name) {
-  return TargetIdAttrMap<ValueType>(GetAttrMapContainer(attr_name));
+inline TargetKindAttrMap<ValueType> TargetKind::GetAttrMap(const String& attr_name) {
+  return TargetKindAttrMap<ValueType>(GetAttrMapContainer(attr_name));
 }
 
 template <typename ValueType>
-inline TargetIdRegEntry& TargetIdRegEntry::set_attr(const String& attr_name, const ValueType& value,
-                                                    int plevel) {
+inline TargetKindRegEntry& TargetKindRegEntry::set_attr(const String& attr_name,
+                                                        const ValueType& value, int plevel) {
   CHECK_GT(plevel, 0) << "plevel in set_attr must be greater than 0";
   runtime::TVMRetValue rv;
   rv = value;
@@ -328,39 +328,39 @@ inline TargetIdRegEntry& TargetIdRegEntry::set_attr(const String& attr_name, con
   return *this;
 }
 
-inline TargetIdRegEntry& TargetIdRegEntry::set_device_type(int device_type) {
-  id_->device_type = device_type;
+inline TargetKindRegEntry& TargetKindRegEntry::set_device_type(int device_type) {
+  kind_->device_type = device_type;
   return *this;
 }
 
-inline TargetIdRegEntry& TargetIdRegEntry::set_default_keys(std::vector<String> keys) {
-  id_->default_keys = keys;
+inline TargetKindRegEntry& TargetKindRegEntry::set_default_keys(std::vector<String> keys) {
+  kind_->default_keys = keys;
   return *this;
 }
 
 template <typename ValueType>
-inline TargetIdRegEntry& TargetIdRegEntry::add_attr_option(const String& key) {
-  CHECK(!id_->key2vtype_.count(key))
+inline TargetKindRegEntry& TargetKindRegEntry::add_attr_option(const String& key) {
+  CHECK(!kind_->key2vtype_.count(key))
       << "AttributeError: add_attr_option failed because '" << key << "' has been set once";
-  id_->key2vtype_[key] = detail::ValueTypeInfoMaker<ValueType>()();
+  kind_->key2vtype_[key] = detail::ValueTypeInfoMaker<ValueType>()();
   return *this;
 }
 
 template <typename ValueType>
-inline TargetIdRegEntry& TargetIdRegEntry::add_attr_option(const String& key,
-                                                           ObjectRef default_value) {
+inline TargetKindRegEntry& TargetKindRegEntry::add_attr_option(const String& key,
+                                                               ObjectRef default_value) {
   add_attr_option<ValueType>(key);
-  id_->key2default_[key] = default_value;
+  kind_->key2default_[key] = default_value;
   return *this;
 }
 
-inline TargetIdRegEntry& TargetIdRegEntry::set_name() {
-  if (id_->name.empty()) {
-    id_->name = name;
+inline TargetKindRegEntry& TargetKindRegEntry::set_name() {
+  if (kind_->name.empty()) {
+    kind_->name = name;
   }
   return *this;
 }
 
 }  // namespace tvm
 
-#endif  // TVM_TARGET_TARGET_ID_H_
+#endif  // TVM_TARGET_TARGET_KIND_H_
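
For reference, a hedged sketch of how the renamed registration macro is exercised, mirroring the usage example embedded in the doc comment above; the TPreCodegenPass attribute and the pass value are illustrative placeholders, not part of this commit:

  // Illustrative only: registers (or extends) the "llvm" target kind under the new macro name.
  TVM_REGISTER_TARGET_KIND("llvm")
      .set_attr<TPreCodegenPass>("TPreCodegenPass", a_pre_codegen_pass)  // placeholder pass
      .add_attr_option<Bool>("system_lib")
      .add_attr_option<String>("mtriple");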
index 34af343..447486d 100644 (file)
@@ -86,7 +86,7 @@ inline tvm::te::Tensor dense_cuda(const Target& target, const tvm::te::Tensor& d
  * \return A schedule for the given ops.
  */
 inline Schedule schedule_dense(const Target& target, const Array<Tensor>& outs) {
-  if (target->id->name == "cuda" && target->GetLibs().count("cublas")) {
+  if (target->kind->name == "cuda" && target->GetLibs().count("cublas")) {
     return topi::generic::schedule_extern(target, outs);
   }
 
index 18d4484..acfcc76 100644 (file)
@@ -70,7 +70,7 @@ Schedule ScheduleReduce(const Target& target, Operation op, Schedule sch,
   if (out_stage->op.as<ComputeOpNode>()->axis.size() > 0) {
     all_reduce = false;
     num_thread = 32;
-    if (target->id->name == "opencl") {
+    if (target->kind->name == "opencl") {
       // Without this, CL_INVALID_WORK_GROUP_SIZE occurs with python tests.
       // Don't know why.
       num_thread = 16;
index e279152..a1e4d14 100644 (file)
@@ -86,7 +86,7 @@ inline tvm::te::Tensor dense_rocm(const Target& target, const tvm::te::Tensor& d
  * \return A schedule for the given ops.
  */
 inline Schedule schedule_dense(const Target& target, const Array<Tensor>& outs) {
-  if (target->id->name == "rocm" && target->GetLibs().count("rocblas")) {
+  if (target->kind->name == "rocm" && target->GetLibs().count("rocblas")) {
     return topi::generic::schedule_extern(target, outs);
   }
 
index 2a4cd29..dd40f21 100644 (file)
@@ -144,7 +144,7 @@ def load_best(filename, workload_key=None, target=None):
             continue
         if workload_key and inp.task.workload_key != workload_key:
             continue
-        if target and inp.task.target.id.name != target.id.name:
+        if target and inp.task.target.kind.name != target.kind.name:
             continue
 
         costs = [v.value for v in res.costs]
index c0e4eed..063932d 100644 (file)
@@ -106,7 +106,7 @@ def context(target, extra_files=None):
         device = tgt.attrs.get("device", "")
         if device != "":
             possible_names.append(_alias(device))
-        possible_names.append(tgt.id.name)
+        possible_names.append(tgt.kind.name)
 
         all_packages = list(PACKAGE_VERSION.keys())
         for name in possible_names:
index 663a17a..ff4b56b 100644 (file)
@@ -156,8 +156,10 @@ def lower(sch,
     """
     # config setup
     pass_ctx = PassContext.current()
-    instrument_bound_checkers = bool(pass_ctx.config.get("tir.instrument_bound_checkers", False))
-    disable_vectorize = bool(pass_ctx.config.get("tir.disable_vectorize", False))
+    instrument_bound_checkers = bool(pass_ctx.config.get(
+        "tir.instrument_bound_checkers", False))
+    disable_vectorize = bool(pass_ctx.config.get(
+        "tir.disable_vectorize", False))
     add_lower_pass = pass_ctx.config.get("tir.add_lower_pass", [])
 
     lower_phase0 = [x[1] for x in add_lower_pass if x[0] == 0]
@@ -239,14 +241,16 @@ def _build_for_device(input_mod, target, target_host):
     """
     target = _target.create(target)
     target_host = _target.create(target_host)
-    device_type = ndarray.context(target.id.name, 0).device_type
+    device_type = ndarray.context(target.kind.name, 0).device_type
 
     mod_mixed = input_mod
-    mod_mixed = tvm.tir.transform.Apply(lambda f: f.with_attr("target", target))(mod_mixed)
+    mod_mixed = tvm.tir.transform.Apply(
+        lambda f: f.with_attr("target", target))(mod_mixed)
 
     opt_mixed = [tvm.tir.transform.VerifyMemory()]
     if len(mod_mixed.functions) == 1:
-        opt_mixed += [tvm.tir.transform.Apply(lambda f: f.with_attr("tir.is_entry_func", True))]
+        opt_mixed += [tvm.tir.transform.Apply(
+            lambda f: f.with_attr("tir.is_entry_func", True))]
 
     if PassContext.current().config.get("tir.detect_global_barrier", False):
         opt_mixed += [tvm.tir.transform.ThreadSync("global")]
@@ -258,7 +262,6 @@ def _build_for_device(input_mod, target, target_host):
                   tvm.tir.transform.SplitHostDevice()]
     mod_mixed = tvm.transform.Sequential(opt_mixed)(mod_mixed)
 
-
     # device optimizations
     opt_device = tvm.transform.Sequential(
         [tvm.tir.transform.Filter(
@@ -289,7 +292,8 @@ def _build_for_device(input_mod, target, target_host):
             "Specified target %s, but cannot find device code, did you do "
             "bind?" % target)
 
-    rt_mod_dev = codegen.build_module(mod_dev, target) if len(mod_dev.functions) != 0 else None
+    rt_mod_dev = codegen.build_module(mod_dev, target) if len(
+        mod_dev.functions) != 0 else None
     return mod_host, rt_mod_dev
 
 
@@ -383,7 +387,8 @@ def build(inputs,
     elif isinstance(inputs, tvm.IRModule):
         input_mod = inputs
     elif not isinstance(inputs, (dict, container.Map)):
-        raise ValueError("inputs must be Schedule, IRModule or dict of target to IRModule")
+        raise ValueError(
+            "inputs must be Schedule, IRModule or dict of target to IRModule")
 
     if not isinstance(inputs, (dict, container.Map)):
         target = _target.Target.current() if target is None else target
@@ -403,7 +408,7 @@ def build(inputs,
     if not target_host:
         for tar, _ in target_input_mod.items():
             tar = _target.create(tar)
-            device_type = ndarray.context(tar.id.name, 0).device_type
+            device_type = ndarray.context(tar.kind.name, 0).device_type
             if device_type == ndarray.cpu(0).device_type:
                 target_host = tar
                 break
index 21c3c83..b2a0ff4 100644 (file)
@@ -68,7 +68,7 @@ def softmax_strategy_cuda(attrs, inputs, out_type, target):
         wrap_compute_softmax(topi.nn.softmax),
         wrap_topi_schedule(topi.cuda.schedule_softmax),
         name="softmax.cuda")
-    if target.id.name == "cuda" and "cudnn" in target.libs:
+    if target.kind.name == "cuda" and "cudnn" in target.libs:
         strategy.add_implementation(
             wrap_compute_softmax(topi.cuda.softmax_cudnn),
             wrap_topi_schedule(topi.cuda.schedule_softmax_cudnn),
@@ -145,7 +145,7 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target):
                                                                              dilation_h, dilation_w,
                                                                              pre_flag=False)
             if judge_winograd_shape:
-                if target.id.name == "cuda" and \
+                if target.kind.name == "cuda" and \
                     nvcc.have_tensorcore(tvm.gpu(0).compute_version) and \
                     judge_winograd_tensorcore:
                     strategy.add_implementation(
@@ -162,7 +162,7 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target):
                             topi.cuda.schedule_conv2d_nhwc_winograd_direct),
                         name="conv2d_nhwc_winograd_direct.cuda",
                         plevel=5)
-            if target.id.name == "cuda":
+            if target.kind.name == "cuda":
                 if nvcc.have_tensorcore(tvm.gpu(0).compute_version):
                     if (N % 16 == 0 and CI % 16 == 0 and CO % 16 == 0) or \
                             (N % 8 == 0 and CI % 16 == 0 and CO % 32 == 0) or \
@@ -181,7 +181,7 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target):
         else:
             raise RuntimeError("Unsupported conv2d layout {} for CUDA".format(layout))
         # add cudnn implementation
-        if target.id.name == "cuda" and "cudnn" in target.libs:
+        if target.kind.name == "cuda" and "cudnn" in target.libs:
             if layout in ["NCHW", "NHWC"] and padding[0] == padding[2] and \
                     padding[1] == padding[3]:
                 strategy.add_implementation(
@@ -209,7 +209,7 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target):
     else: # group_conv2d
         # add cudnn implementation, if any
         cudnn_impl = False
-        if target.id.name == "cuda" and "cudnn" in target.libs:
+        if target.kind.name == "cuda" and "cudnn" in target.libs:
             if layout in ["NCHW", "NHWC"] and padding[0] == padding[2] and \
                     padding[1] == padding[3]:
                 strategy.add_implementation(
@@ -264,7 +264,7 @@ def conv2d_winograd_without_weight_transfrom_strategy_cuda(attrs, inputs, out_ty
                                                       padding, stride_h, stride_w,
                                                       dilation_h, dilation_w,
                                                       pre_flag=True)
-        if target.id.name == "cuda" and \
+        if target.kind.name == "cuda" and \
             nvcc.have_tensorcore(tvm.gpu(0).compute_version) and \
             judge_winograd_tensorcore:
             strategy.add_implementation(
@@ -362,7 +362,7 @@ def conv3d_strategy_cuda(attrs, inputs, out_type, target):
             plevel=10)
         N, _, _, _, _ = get_const_tuple(data.shape)
         _, _, _, CI, CO = get_const_tuple(kernel.shape)
-        if target.id.name == "cuda":
+        if target.kind.name == "cuda":
             if nvcc.have_tensorcore(tvm.gpu(0).compute_version):
                 if (N % 16 == 0 and CI % 16 == 0 and CO % 16 == 0) or \
                 (N % 8 == 0 and CI % 16 == 0 and CO % 32 == 0) or \
@@ -373,7 +373,7 @@ def conv3d_strategy_cuda(attrs, inputs, out_type, target):
                         name="conv3d_ndhwc_tensorcore.cuda",
                         plevel=20)
 
-    if target.id.name == "cuda" and "cudnn" in target.libs:
+    if target.kind.name == "cuda" and "cudnn" in target.libs:
         strategy.add_implementation(wrap_compute_conv3d(topi.cuda.conv3d_cudnn, True),
                                     wrap_topi_schedule(topi.cuda.schedule_conv3d_cudnn),
                                     name="conv3d_cudnn.cuda",
@@ -458,7 +458,7 @@ def dense_strategy_cuda(attrs, inputs, out_type, target):
                 wrap_topi_schedule(topi.cuda.schedule_dense_large_batch),
                 name="dense_large_batch.cuda",
                 plevel=5)
-        if target.id.name == "cuda":
+        if target.kind.name == "cuda":
             if nvcc.have_tensorcore(tvm.gpu(0).compute_version):
                 if(i % 16 == 0 and b % 16 == 0 and o % 16 == 0) \
                         or (i % 16 == 0 and b % 8 == 0 and o % 32 == 0) \
@@ -468,7 +468,7 @@ def dense_strategy_cuda(attrs, inputs, out_type, target):
                         wrap_topi_schedule(topi.cuda.schedule_dense_tensorcore),
                         name="dense_tensorcore.cuda",
                         plevel=20)
-    if target.id.name == "cuda" and "cublas" in target.libs:
+    if target.kind.name == "cuda" and "cublas" in target.libs:
         strategy.add_implementation(
             wrap_compute_dense(topi.cuda.dense_cublas),
             wrap_topi_schedule(topi.cuda.schedule_dense_cublas),
@@ -485,7 +485,7 @@ def batch_matmul_strategy_cuda(attrs, inputs, out_type, target):
         wrap_topi_schedule(topi.cuda.schedule_batch_matmul),
         name="batch_matmul.cuda",
         plevel=10)
-    if target.id.name == "cuda" and "cublas" in target.libs:
+    if target.kind.name == "cuda" and "cublas" in target.libs:
         strategy.add_implementation(
             wrap_compute_batch_matmul(topi.cuda.batch_matmul_cublas),
             wrap_topi_schedule(topi.generic.schedule_extern),
index e70298a..01cf621 100644 (file)
@@ -127,7 +127,7 @@ def dense_strategy_rocm(attrs, inputs, out_type, target):
         wrap_compute_dense(topi.rocm.dense),
         wrap_topi_schedule(topi.rocm.schedule_dense),
         name="dense.rocm")
-    if target.id.name == "rocm" and "rocblas" in target.libs:
+    if target.kind.name == "rocm" and "rocblas" in target.libs:
         assert out_type.dtype == inputs[0].dtype, "Mixed precision not supported."
         strategy.add_implementation(
             wrap_compute_dense(topi.rocm.dense_rocblas),
index 74a6f60..8f553dd 100644 (file)
@@ -39,7 +39,7 @@ def _get_profile_runtime(mod):
 
     if tvm.target.Target.current():
         target = tvm.target.Target.current()
-        ctx = tvm.context(target.id.name)
+        ctx = tvm.context(target.kind.name)
     else:
         target = 'llvm'
         ctx = tvm.context(target)
index 55bb110..09d7bd8 100644 (file)
@@ -16,7 +16,7 @@
 # under the License.
 """Target description and codgen module.
 
-TVM's target string is in fomat ``<target_id> [-option=value]...``.
+TVM's target string is in fomat ``<target_kind> [-option=value]...``.
 
 Note
 ----
index 9be7f83..597f8a5 100644 (file)
@@ -25,8 +25,8 @@ from . import _ffi_api
 
 
 @tvm._ffi.register_object
-class TargetId(Object):
-    """Id of a compilation target
+class TargetKind(Object):
+    """Kind of a compilation target
     """
 
 
@@ -45,6 +45,7 @@ class Target(Object):
     - :py:func:`tvm.target.mali` create Mali target
     - :py:func:`tvm.target.intel_graphics` create Intel Graphics target
     """
+
     def __enter__(self):
         _ffi_api.EnterTargetScope(self)
         return self
@@ -163,7 +164,8 @@ def intel_graphics(model='unknown', options=None):
     options : str or list of str
         Additional options
     """
-    opts = ["-device=intel_graphics", "-model=%s" % model, "-thread_warp_size=16"]
+    opts = ["-device=intel_graphics", "-model=%s" %
+            model, "-thread_warp_size=16"]
     opts = _merge_opts(opts, options)
     return _ffi_api.TargetCreate("opencl", *opts)
 
@@ -280,7 +282,7 @@ def hexagon(cpu_ver='v66', sim_args=None, hvx=128):
                 i = sim_args.index('hvx_length') + len('hvx_length') + 1
                 sim_hvx = sim_args[i:i+3]
                 if sim_hvx != str(codegen_hvx):
-                    print('WARNING: sim hvx {} and codegen hvx {} mismatch!' \
+                    print('WARNING: sim hvx {} and codegen hvx {} mismatch!'
                           .format(sim_hvx, codegen_hvx))
             elif codegen_hvx != 0:
                 # If --hvx_length was not given, add it if HVX is enabled
@@ -313,10 +315,10 @@ def hexagon(cpu_ver='v66', sim_args=None, hvx=128):
             # Parse options into correct order
             cpu_attr = {x: str(m.groupdict()[x] or '') for x in m.groupdict()}
             sim_args = cpu_attr['base_version'] +  \
-                       cpu_attr['sub_version']  +  \
-                       cpu_attr['l2_size'] +       \
-                       cpu_attr['rev'] + ' ' +     \
-                       cpu_attr['pre'] + cpu_attr['post']
+                cpu_attr['sub_version'] +  \
+                cpu_attr['l2_size'] +       \
+                cpu_attr['rev'] + ' ' +     \
+                cpu_attr['pre'] + cpu_attr['post']
 
         return sim_cpu + ' ' + validate_hvx_length(hvx, sim_args)
 
index bcd98cc..373b1ec 100644 (file)
@@ -69,7 +69,7 @@ def schedule_batch_matmul(cfg, outs):
         cfg.define_split("tile_k", k, num_outputs=2)
         cfg.define_knob("auto_unroll_max_step", [8, 16, 32, 64])
         target = tvm.target.Target.current()
-        if target.id.name in ['nvptx', 'rocm']:
+        if target.kind.name in ['nvptx', 'rocm']:
             # llvm-based backends cannot do non-explicit unrolling
             cfg.define_knob("unroll_explicit", [1])
         else:
index 533cf74..c099d25 100644 (file)
@@ -72,7 +72,7 @@ def schedule_conv1d_ncw(cfg, outs):
             cfg.define_knob("auto_unroll_max_step", [64, 512, 1500])
 
             target = tvm.target.Target.current()
-            if target.id.name in ['nvptx', 'rocm']:
+            if target.kind.name in ['nvptx', 'rocm']:
                 cfg.define_knob("unroll_explicit", [1])
             else:
                 cfg.define_knob("unroll_explicit", [0, 1])
@@ -197,7 +197,7 @@ def schedule_conv1d_nwc(cfg, outs):
             cfg.define_knob("auto_unroll_max_step", [64, 512, 1500])
 
             target = tvm.target.Target.current()
-            if target.id.name in ['nvptx', 'rocm']:
+            if target.kind.name in ['nvptx', 'rocm']:
                 cfg.define_knob("unroll_explicit", [1])
             else:
                 cfg.define_knob("unroll_explicit", [0, 1])
index ffce584..dbfe1f5 100644 (file)
@@ -124,7 +124,7 @@ def schedule_conv1d_transpose_ncw(cfg, outs):
             cfg.define_knob("auto_unroll_max_step", [64, 512, 1500])
 
             target = tvm.target.Target.current()
-            if target.id.name in ['nvptx', 'rocm']:
+            if target.kind.name in ['nvptx', 'rocm']:
                 cfg.define_knob("unroll_explicit", [1])
             else:
                 cfg.define_knob("unroll_explicit", [0, 1])
index 9d8146e..8a26a82 100644 (file)
@@ -36,7 +36,7 @@ def schedule_direct_cuda(cfg, s, conv):
     cfg.define_knob("auto_unroll_max_step", [0, 512, 1500])
 
     target = tvm.target.Target.current()
-    if target.id.name in ['nvptx', 'rocm']:
+    if target.kind.name in ['nvptx', 'rocm']:
         cfg.define_knob("unroll_explicit", [1])
     else:
         cfg.define_knob("unroll_explicit", [0, 1])
@@ -44,7 +44,7 @@ def schedule_direct_cuda(cfg, s, conv):
     # fallback support
     if cfg.is_fallback:
         ref_log = autotvm.tophub.load_reference_log(
-            target.id.name, target.model, 'conv2d_nchw.cuda')
+            target.kind.name, target.model, 'conv2d_nchw.cuda')
         cfg.fallback_with_reference_log(ref_log)
     ##### space definition end #####
 
index c7c3f18..23607b1 100644 (file)
@@ -56,7 +56,7 @@ def schedule_conv2d_nhwc_direct(cfg, s, Conv):
     target = tvm.target.Target.current()
     if cfg.is_fallback:
         ref_log = autotvm.tophub.load_reference_log(
-            target.id.name, target.model, 'conv2d_nhwc.cuda')
+            target.kind.name, target.model, 'conv2d_nhwc.cuda')
         cfg.fallback_with_reference_log(ref_log)
 
     tile_n = cfg["tile_n"].val
index 7703e40..a82508b 100644 (file)
@@ -134,7 +134,7 @@ def schedule_nhwc_tensorcore_cuda(cfg, s, Conv):
     target = tvm.target.Target.current()
     if cfg.is_fallback:
         ref_log = autotvm.tophub.load_reference_log(
-            target.id.name, target.model, 'conv2d_nhwc_tensorcore.cuda')
+            target.kind.name, target.model, 'conv2d_nhwc_tensorcore.cuda')
         cfg.fallback_with_reference_log(ref_log)
 
     block_row_warps = cfg["block_row_warps"].val
index 4dfcc03..d0a683e 100644 (file)
@@ -177,7 +177,7 @@ def schedule_conv2d_transpose_nchw(cfg, outs):
             cfg.define_knob("auto_unroll_max_step", [64, 512, 1500])
 
             target = tvm.target.Target.current()
-            if target.id.name in ['nvptx', 'rocm']:
+            if target.kind.name in ['nvptx', 'rocm']:
                 cfg.define_knob("unroll_explicit", [1])
             else:
                 cfg.define_knob("unroll_explicit", [0, 1])
index d976aaa..f5259ba 100644 (file)
@@ -193,7 +193,7 @@ def schedule_winograd_cuda(cfg, s, output, pre_computed):
     cfg.define_split("tile_rc", rc, num_outputs=2)
     cfg.define_knob("auto_unroll_max_step", [0, 128, 1500])
     target = tvm.target.Target.current()
-    if target.id.name in ['nvptx', 'rocm']:
+    if target.kind.name in ['nvptx', 'rocm']:
         cfg.define_knob("unroll_explicit", [1])
     else:
         cfg.define_knob("unroll_explicit", [0, 1])
index 0b80e79..e3dd6f9 100644 (file)
@@ -43,7 +43,7 @@ def schedule_direct_conv3d_cuda(cfg, s, conv, layout, workload_name):
     cfg.define_knob("auto_unroll_max_step", [0, 512, 1500])
 
     target = tvm.target.Target.current()
-    if target.id.name in ['nvptx', 'rocm']:
+    if target.kind.name in ['nvptx', 'rocm']:
         cfg.define_knob("unroll_explicit", [1])
     else:
         cfg.define_knob("unroll_explicit", [0, 1])
@@ -51,7 +51,7 @@ def schedule_direct_conv3d_cuda(cfg, s, conv, layout, workload_name):
     # fallback support
     if cfg.is_fallback:
         ref_log = autotvm.tophub.load_reference_log(
-            target.id.name, target.model, workload_name)
+            target.kind.name, target.model, workload_name)
         cfg.fallback_with_reference_log(ref_log)
     ##### space definition end #####
 
index 68b0145..bc4f0e1 100644 (file)
@@ -141,7 +141,7 @@ def schedule_ndhwc_tensorcore_cuda(cfg, s, Conv):
     target = tvm.target.Target.current()
     if cfg.is_fallback:
         ref_log = autotvm.tophub.load_reference_log(
-            target.id.name, target.model, 'conv3d_ndhwc_tensorcore.cuda')
+            target.kind.name, target.model, 'conv3d_ndhwc_tensorcore.cuda')
         cfg.fallback_with_reference_log(ref_log)
 
     block_row_warps = cfg["block_row_warps"].val
index e8b5037..3e6b1c1 100644 (file)
@@ -321,7 +321,7 @@ def schedule_winograd_cuda(cfg, s, output, pre_computed):
     cfg.define_split("tile_rc", rc, num_outputs=2)
     cfg.define_knob("auto_unroll_max_step", [0, 128, 1500])
     target = tvm.target.Target.current()
-    if target.id.name in ['nvptx', 'rocm']:
+    if target.kind.name in ['nvptx', 'rocm']:
         cfg.define_knob("unroll_explicit", [1])
     else:
         cfg.define_knob("unroll_explicit", [0, 1])
@@ -478,7 +478,7 @@ def schedule_winograd_no_depth_cuda(cfg, s, output, pre_computed):
     cfg.define_split("tile_rz", rz, num_outputs=2)
     cfg.define_knob("auto_unroll_max_step", [0, 128, 1500])
     target = tvm.target.Target.current()
-    if target.id.name in ['nvptx', 'rocm']:
+    if target.kind.name in ['nvptx', 'rocm']:
         cfg.define_knob("unroll_explicit", [1])
     else:
         cfg.define_knob("unroll_explicit", [0, 1])
index 6d9be95..dbaabb7 100644 (file)
@@ -81,7 +81,7 @@ def _schedule_correlation_nchw(cfg, s, correlation):
     cfg.define_knob("auto_unroll_max_step", [0, 512, 1500])
 
     target = tvm.target.Target.current()
-    if target.id.name in ['nvptx', 'rocm']:
+    if target.kind.name in ['nvptx', 'rocm']:
         cfg.define_knob("unroll_explicit", [1])
     else:
         cfg.define_knob("unroll_explicit", [0, 1])
@@ -143,8 +143,10 @@ def _schedule_correlation_nchw(cfg, s, correlation):
         s[load].bind(tx, te.thread_axis("threadIdx.x"))
 
     # unroll
-    s[output].pragma(kernel_scope, 'auto_unroll_max_step', cfg['auto_unroll_max_step'].val)
-    s[output].pragma(kernel_scope, 'unroll_explicit', cfg['unroll_explicit'].val)
+    s[output].pragma(kernel_scope, 'auto_unroll_max_step',
+                     cfg['auto_unroll_max_step'].val)
+    s[output].pragma(kernel_scope, 'unroll_explicit',
+                     cfg['unroll_explicit'].val)
 
 
 @autotvm.register_topi_schedule("correlation_nchw.cuda")
index 6def731..d97d501 100644 (file)
@@ -71,7 +71,7 @@ def _schedule_direct_cuda(cfg, s, conv):
     cfg.define_knob("auto_unroll_max_step", [0, 512, 1500])
 
     target = tvm.target.Target.current()
-    if target.id.name in ['nvptx', 'rocm']:
+    if target.kind.name in ['nvptx', 'rocm']:
         cfg.define_knob("unroll_explicit", [1])
     else:
         cfg.define_knob("unroll_explicit", [0, 1])
index a6d1c05..bb51c40 100644 (file)
@@ -95,7 +95,7 @@ def _schedule_dense_tensorcore(cfg, s, C):
     target = tvm.target.Target.current()
     if cfg.is_fallback:
         ref_log = autotvm.tophub.load_reference_log(
-            target.id.name, target.model, 'dense_tensorcore.cuda')
+            target.kind.name, target.model, 'dense_tensorcore.cuda')
         cfg.fallback_with_reference_log(ref_log)
 
     # Deal with op fusion, such as bias and relu
index f9ef8b6..f2f7a04 100644 (file)
@@ -61,7 +61,7 @@ def schedule_depthwise_conv2d_nchw(cfg, outs):
             cfg.define_knob("auto_unroll_max_step", [0, 256, 1500])
 
             target = tvm.target.Target.current()
-            if target.id.name in ['nvptx', 'rocm']:
+            if target.kind.name in ['nvptx', 'rocm']:
                 cfg.define_knob("unroll_explicit", [1])
             else:
                 cfg.define_knob("unroll_explicit", [0, 1])
@@ -69,7 +69,7 @@ def schedule_depthwise_conv2d_nchw(cfg, outs):
             # fallback support
             if cfg.is_fallback:
                 ref_log = autotvm.tophub.load_reference_log(
-                    target.id.name, target.model, 'depthwise_conv2d_nchw.cuda')
+                    target.kind.name, target.model, 'depthwise_conv2d_nchw.cuda')
                 cfg.fallback_with_reference_log(ref_log)
                 # TODO(lmzheng): A bug here, set unroll_explicit to False as workaround
                 cfg['unroll_explicit'].val = 0
@@ -169,7 +169,7 @@ def schedule_depthwise_conv2d_nhwc(outs):
         # num_thread here could be 728, it is larger than cuda.max_num_threads
         num_thread = tvm.arith.Analyzer().simplify(temp.shape[3]).value
         target = tvm.target.Target.current()
-        if target and (target.id.name not in ["cuda", "nvptx"]):
+        if target and (target.kind.name not in ["cuda", "nvptx"]):
             num_thread = target.max_num_threads
         xoc, xic = s[Output].split(c, factor=num_thread)
         s[Output].reorder(xoc, b, h, w, xic)
index e5cbe3e..ab7db66 100644 (file)
@@ -83,7 +83,7 @@ def _schedule_group_conv2d_nchw_direct(cfg, s, conv):
     cfg.define_knob("auto_unroll_max_step", [0, 512, 1500])
 
     target = tvm.target.Target.current()
-    if target.id.name in ['nvptx', 'rocm']:
+    if target.kind.name in ['nvptx', 'rocm']:
         cfg.define_knob("unroll_explicit", [1])
     else:
         cfg.define_knob("unroll_explicit", [0, 1])
index 9d3c529..38e3086 100644 (file)
@@ -36,7 +36,7 @@ def _schedule_reduce(op, sch, is_idx_reduce=False):
         all_reduce = False
         num_thread = 32
         target = tvm.target.Target.current()
-        if target and target.id.name == "opencl":
+        if target and target.kind.name == "opencl":
             # without it, CL_INVALID_WORK_GROUP_SIZE occurred when running test_topi_reduce.py
             # don't know why
             num_thread = 16
index 910d0f3..ef97651 100644 (file)
@@ -59,9 +59,9 @@ def schedule_softmax(outs):
     #
     # TODO(tvm-team) Fix nvptx codegen or deprecate nvptx backend.
     def sched_warp_softmax():
-        if tgt.id.name == "nvptx" or tgt.id.name == "rocm":
+        if tgt.kind.name == "nvptx" or tgt.kind.name == "rocm":
             return softmax.dtype == "float32" or softmax.dtype == "int32"
-        if tgt.id.name != "cuda":
+        if tgt.kind.name != "cuda":
             # this is used as the gpu schedule for other arches which may not have warp reductions
             return False
         return True
index c5e2a6e..7536654 100644 (file)
@@ -53,7 +53,7 @@ def schedule_reorg(outs):
         The computation schedule for reorg.
     """
     target = tvm.target.Target.current(allow_none=False)
-    cpp_target = cpp.TEST_create_target(target.id.name)
+    cpp_target = cpp.TEST_create_target(target.kind.name)
     return cpp.cuda.schedule_injective(cpp_target, outs)
 
 def schedule_nms(outs):
index 93a1dd2..f03c497 100644 (file)
@@ -24,7 +24,7 @@ def default_schedule(outs, auto_inline):
     """Default schedule for llvm."""
     target = tvm.target.Target.current(allow_none=False)
     outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
-    if target.id.name not in ("llvm", "c"):
+    if target.kind.name not in ("llvm", "c"):
         raise RuntimeError("schedule not registered for '%s'" % target)
     s = te.create_schedule([x.op for x in outs])
     if auto_inline:
index a60b1e7..6360f8b 100644 (file)
@@ -54,7 +54,7 @@ def schedule_injective(outs):
         The computation schedule for the op.
     """
     target = tvm.target.Target.current(allow_none=False)
-    if target.id.name != "llvm":
+    if target.kind.name != "llvm":
         raise RuntimeError("schedule_injective not registered for '%s'" % target)
     outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
     x = outs[0]
index a1db9ab..d0855a0 100644 (file)
@@ -37,7 +37,7 @@ def schedule_reorg(outs):
       The computation schedule for the op.
     """
     target = tvm.target.Target.current(allow_none=False)
-    cpp_target = cpp.TEST_create_target(target.id.name)
+    cpp_target = cpp.TEST_create_target(target.kind.name)
     return cpp.generic.default_schedule(cpp_target, outs, False)
 
 def schedule_get_valid_counts(outs):
index bc2b27b..ffeb9af 100644 (file)
@@ -62,7 +62,7 @@ def schedule_depthwise_conv2d_nchw(cfg, outs):
             cfg.define_knob("auto_unroll_max_step", [0, 256, 1500])
 
             target = tvm.target.Target.current()
-            if target.id.name in ['nvptx', 'rocm']:
+            if target.kind.name in ['nvptx', 'rocm']:
                 cfg.define_knob("unroll_explicit", [1])
             else:
                 cfg.define_knob("unroll_explicit", [0, 1])
@@ -70,7 +70,7 @@ def schedule_depthwise_conv2d_nchw(cfg, outs):
             # fallback support
             if cfg.is_fallback:
                 ref_log = autotvm.tophub.load_reference_log(
-                    target.id.name, target.model, 'depthwise_conv2d_nchw.intel_graphics')
+                    target.kind.name, target.model, 'depthwise_conv2d_nchw.intel_graphics')
                 cfg.fallback_with_reference_log(ref_log)
                 cfg['unroll_explicit'].val = 0
             ##### space definition end #####
@@ -170,7 +170,7 @@ def schedule_depthwise_conv2d_nhwc(outs):
         # num_thread here could be 728, it is larger than cuda.max_num_threads
         num_thread = tvm.arith.Analyzer().simplify(temp.shape[3]).value
         target = tvm.target.Target.current()
-        if target and (target.id.name not in ["cuda", "nvptx"]):
+        if target and (target.kind.name not in ["cuda", "nvptx"]):
             num_thread = target.max_num_threads
         xoc, xic = s[Output].split(c, factor=num_thread)
         s[Output].reorder(xoc, b, h, w, xic)
index 9cc21f2..e632d4e 100644 (file)
@@ -44,7 +44,7 @@ HardwareParams::HardwareParams(int num_cores, int vector_unit_bytes, int cache_l
 
 HardwareParams HardwareParamsNode::GetDefaultHardwareParams(const Target& target,
                                                             const Target& target_host) {
-  if (target->id->name == "llvm") {
+  if (target->kind->name == "llvm") {
     return HardwareParams(tvm::runtime::threading::MaxConcurrency(), 64, 64);
   } else {
     LOG(FATAL) << "No default hardware parameters for target: " << target;
index 2c08ea1..142bdfc 100644 (file)
@@ -56,7 +56,7 @@ bool LLVMEnabled() {
 
 /*! \return The default host target for a given device target */
 Target DefaultTargetHost(Target target) {
-  if (target.defined() && target->id->device_type == kDLCPU) {
+  if (target.defined() && target->kind->device_type == kDLCPU) {
     return target;
   } else {
     if (LLVMEnabled()) {
@@ -239,7 +239,7 @@ std::pair<IRModule, IRModule> SplitDevHostFuncs(IRModule mod_mixed, const Target
                  << " but cannot find device code. Did you forget to bind?";
   }
 
-  if (target->id->device_type == kDLCPU && target_host == target) {
+  if (target->kind->device_type == kDLCPU && target_host == target) {
     CHECK(mdevice->functions.empty()) << "No device code should be generated when target "
                                       << "and host_target are both llvm target."
                                       << "\n";
@@ -256,7 +256,7 @@ runtime::Module build(const Map<Target, IRModule>& inputs, const Target& target_
   Target target_host_val = target_host;
   if (!target_host.defined()) {
     for (const auto& it : inputs) {
-      if (it.first->id->device_type == kDLCPU || it.first->id->device_type == kDLMicroDev) {
+      if (it.first->kind->device_type == kDLCPU || it.first->kind->device_type == kDLMicroDev) {
         target_host_val = it.first;
         break;
       }
index bfcc2a6..4d84c48 100644 (file)
@@ -444,7 +444,7 @@ class RelayBuildModule : public runtime::ModuleNode {
       if (!target_host.defined())
         target_host = (pf != nullptr) ? target::llvm() : target::stackvm();
 
-      if (target_host.defined() && target_host->id->name == "llvm") {
+      if (target_host.defined() && target_host->kind->name == "llvm") {
         // If we can decide the target is LLVM, we then create an empty LLVM module.
         ret_.mod = (*pf)(target_host->str(), "empty_module");
       } else {
@@ -470,7 +470,7 @@ class RelayBuildModule : public runtime::ModuleNode {
     Target target_host = target_host_;
     if (!target_host_.defined()) {
       for (const auto& it : targets_) {
-        if (it.second->id->device_type == kDLCPU) {
+        if (it.second->kind->device_type == kDLCPU) {
           target_host = it.second;
           break;
         }
index 52bd1c2..0ac4993 100644 (file)
@@ -48,10 +48,10 @@ runtime::Module Build(IRModule mod, const Target& target) {
     mod = tir::transform::SkipAssert()(mod);
   }
   std::string build_f_name;
-  if (target->id->name == "micro_dev") {
+  if (target->kind->name == "micro_dev") {
     build_f_name = "target.build.c";
   } else {
-    build_f_name = "target.build." + target->id->name;
+    build_f_name = "target.build." + target->kind->name;
   }
   // the build function.
   const PackedFunc* bf = runtime::Registry::Get(build_f_name);
index fb9c34e..94b5b03 100644 (file)
@@ -24,7 +24,7 @@
 #include <tvm/node/repr_printer.h>
 #include <tvm/runtime/registry.h>
 #include <tvm/target/target.h>
-#include <tvm/target/target_id.h>
+#include <tvm/target/target_kind.h>
 #include <tvm/tir/expr.h>
 
 #include <algorithm>
@@ -37,13 +37,13 @@ using runtime::TVMArgs;
 using runtime::TVMRetValue;
 
 Target Target::CreateTarget(const std::string& name, const std::vector<std::string>& options) {
-  TargetId id = TargetId::Get(name);
+  TargetKind kind = TargetKind::Get(name);
   ObjectPtr<TargetNode> target = make_object<TargetNode>();
-  target->id = id;
+  target->kind = kind;
   // tag is always empty
   target->tag = "";
   // parse attrs
-  target->attrs = id->ParseAttrsFromRaw(options);
+  target->attrs = kind->ParseAttrsFromRaw(options);
   String device_name = target->GetAttr<String>("device", "").value();
   // set up keys
   {
@@ -58,7 +58,7 @@ Target Target::CreateTarget(const std::string& name, const std::vector<std::stri
       keys.push_back(device_name);
     }
     // add default keys
-    for (const auto& key : target->id->default_keys) {
+    for (const auto& key : target->kind->default_keys) {
       keys.push_back(key);
     }
     // de-duplicate keys
@@ -127,7 +127,7 @@ std::unordered_set<std::string> TargetNode::GetLibs() const {
 const std::string& TargetNode::str() const {
   if (str_repr_.empty()) {
     std::ostringstream os;
-    os << id->name;
+    os << kind->name;
     if (!this->keys.empty()) {
       os << " -keys=";
       bool is_first = true;
@@ -140,7 +140,7 @@ const std::string& TargetNode::str() const {
         os << s;
       }
     }
-    if (Optional<String> attrs_str = id->StringifyAttrsToRaw(attrs)) {
+    if (Optional<String> attrs_str = kind->StringifyAttrsToRaw(attrs)) {
       os << ' ' << attrs_str.value();
     }
     str_repr_ = os.str();
similarity index 83%
rename from src/target/target_id.cc
rename to src/target/target_kind.cc
index 735ec16..0bef651 100644 (file)
  */
 
 /*!
- * \file src/target/target_id.cc
- * \brief Target id registry
+ * \file src/target/target_kind.cc
+ * \brief Target kind registry
  */
-#include <tvm/target/target_id.h>
+#include <tvm/target/target_kind.h>
 
 #include <algorithm>
 
 
 namespace tvm {
 
-TVM_REGISTER_NODE_TYPE(TargetIdNode);
+TVM_REGISTER_NODE_TYPE(TargetKindNode);
 
 TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
-    .set_dispatch<TargetIdNode>([](const ObjectRef& node, ReprPrinter* p) {
-      auto* op = static_cast<const TargetIdNode*>(node.get());
+    .set_dispatch<TargetKindNode>([](const ObjectRef& node, ReprPrinter* p) {
+      auto* op = static_cast<const TargetKindNode*>(node.get());
       p->stream << op->name;
     });
 
-using TargetIdRegistry = AttrRegistry<TargetIdRegEntry, TargetId>;
+using TargetKindRegistry = AttrRegistry<TargetKindRegEntry, TargetKind>;
 
-TargetIdRegEntry& TargetIdRegEntry::RegisterOrGet(const String& target_id_name) {
-  return TargetIdRegistry::Global()->RegisterOrGet(target_id_name);
+TargetKindRegEntry& TargetKindRegEntry::RegisterOrGet(const String& target_kind_name) {
+  return TargetKindRegistry::Global()->RegisterOrGet(target_kind_name);
 }
 
-void TargetIdRegEntry::UpdateAttr(const String& key, TVMRetValue value, int plevel) {
-  TargetIdRegistry::Global()->UpdateAttr(key, id_, value, plevel);
+void TargetKindRegEntry::UpdateAttr(const String& key, TVMRetValue value, int plevel) {
+  TargetKindRegistry::Global()->UpdateAttr(key, kind_, value, plevel);
 }
 
-const AttrRegistryMapContainerMap<TargetId>& TargetId::GetAttrMapContainer(
+const AttrRegistryMapContainerMap<TargetKind>& TargetKind::GetAttrMapContainer(
     const String& attr_name) {
-  return TargetIdRegistry::Global()->GetAttrMap(attr_name);
+  return TargetKindRegistry::Global()->GetAttrMap(attr_name);
 }
 
-const TargetId& TargetId::Get(const String& target_id_name) {
-  const TargetIdRegEntry* reg = TargetIdRegistry::Global()->Get(target_id_name);
-  CHECK(reg != nullptr) << "ValueError: TargetId \"" << target_id_name << "\" is not registered";
-  return reg->id_;
+const TargetKind& TargetKind::Get(const String& target_kind_name) {
+  const TargetKindRegEntry* reg = TargetKindRegistry::Global()->Get(target_kind_name);
+  CHECK(reg != nullptr) << "ValueError: TargetKind \"" << target_kind_name
+                        << "\" is not registered";
+  return reg->kind_;
 }
 
-void TargetIdNode::VerifyTypeInfo(const ObjectRef& obj,
-                                  const TargetIdNode::ValueTypeInfo& info) const {
+void TargetKindNode::VerifyTypeInfo(const ObjectRef& obj,
+                                    const TargetKindNode::ValueTypeInfo& info) const {
   CHECK(obj.defined()) << "Object is None";
   if (!runtime::ObjectInternal::DerivedFrom(obj.get(), info.type_index)) {
     LOG(FATAL) << "AttributeError: expect type \"" << info.type_key << "\" but get "
@@ -102,17 +103,17 @@ void TargetIdNode::VerifyTypeInfo(const ObjectRef& obj,
   }
 }
 
-void TargetIdNode::ValidateSchema(const Map<String, ObjectRef>& config) const {
-  const String kTargetId = "id";
+void TargetKindNode::ValidateSchema(const Map<String, ObjectRef>& config) const {
+  const String kTargetKind = "kind";
   for (const auto& kv : config) {
     const String& name = kv.first;
     const ObjectRef& obj = kv.second;
-    if (name == kTargetId) {
+    if (name == kTargetKind) {
       CHECK(obj->IsInstance<StringObj>())
-          << "AttributeError: \"id\" is not a string, but its type is \"" << obj->GetTypeKey()
+          << "AttributeError: \"kind\" is not a string, but its type is \"" << obj->GetTypeKey()
           << "\"";
       CHECK(Downcast<String>(obj) == this->name)
-          << "AttributeError: \"id\" = \"" << obj << "\" is inconsistent with TargetId \""
+          << "AttributeError: \"kind\" = \"" << obj << "\" is inconsistent with TargetKind \""
           << this->name << "\"";
       continue;
     }
@@ -131,7 +132,7 @@ void TargetIdNode::ValidateSchema(const Map<String, ObjectRef>& config) const {
     try {
       VerifyTypeInfo(obj, info);
     } catch (const tvm::Error& e) {
-      LOG(FATAL) << "AttributeError: Schema validation failed for TargetId \"" << this->name
+      LOG(FATAL) << "AttributeError: Schema validation failed for TargetKind \"" << this->name
                  << "\", details:\n"
                  << e.what() << "\n"
                  << "The config is:\n"
@@ -141,12 +142,13 @@ void TargetIdNode::ValidateSchema(const Map<String, ObjectRef>& config) const {
   }
 }
 
-inline String GetId(const Map<String, ObjectRef>& target, const char* name) {
-  const String kTargetId = "id";
-  CHECK(target.count(kTargetId)) << "AttributeError: \"id\" does not exist in \"" << name << "\"\n"
-                                 << name << " = " << target;
-  const ObjectRef& obj = target[kTargetId];
-  CHECK(obj->IsInstance<StringObj>()) << "AttributeError: \"id\" is not a string in \"" << name
+inline String GetKind(const Map<String, ObjectRef>& target, const char* name) {
+  const String kTargetKind = "kind";
+  CHECK(target.count(kTargetKind))
+      << "AttributeError: \"kind\" does not exist in \"" << name << "\"\n"
+      << name << " = " << target;
+  const ObjectRef& obj = target[kTargetKind];
+  CHECK(obj->IsInstance<StringObj>()) << "AttributeError: \"kind\" is not a string in \"" << name
                                       << "\", but its type is \"" << obj->GetTypeKey() << "\"\n"
                                       << name << " = \"" << target << '"';
   return Downcast<String>(obj);
@@ -157,16 +159,16 @@ void TargetValidateSchema(const Map<String, ObjectRef>& config) {
     const String kTargetHost = "target_host";
     Map<String, ObjectRef> target = config;
     Map<String, ObjectRef> target_host;
-    String target_id = GetId(target, "target");
-    String target_host_id;
+    String target_kind = GetKind(target, "target");
+    String target_host_kind;
     if (config.count(kTargetHost)) {
       target.erase(kTargetHost);
       target_host = Downcast<Map<String, ObjectRef>>(config[kTargetHost]);
-      target_host_id = GetId(target_host, "target_host");
+      target_host_kind = GetKind(target_host, "target_host");
     }
-    TargetId::Get(target_id)->ValidateSchema(target);
+    TargetKind::Get(target_kind)->ValidateSchema(target);
     if (!target_host.empty()) {
-      TargetId::Get(target_host_id)->ValidateSchema(target_host);
+      TargetKind::Get(target_host_kind)->ValidateSchema(target_host);
     }
   } catch (const tvm::Error& e) {
     LOG(FATAL) << "AttributeError: schedule validation fails:\n"
@@ -230,7 +232,7 @@ static inline Optional<String> Join(const std::vector<String>& array, char separ
   return String(os.str());
 }
 
-Map<String, ObjectRef> TargetIdNode::ParseAttrsFromRaw(
+Map<String, ObjectRef> TargetKindNode::ParseAttrsFromRaw(
     const std::vector<std::string>& options) const {
   std::unordered_map<String, ObjectRef> attrs;
   for (size_t iter = 0, end = options.size(); iter < end;) {
@@ -313,7 +315,7 @@ Map<String, ObjectRef> TargetIdNode::ParseAttrsFromRaw(
   return attrs;
 }
 
-Optional<String> TargetIdNode::StringifyAttrsToRaw(const Map<String, ObjectRef>& attrs) const {
+Optional<String> TargetKindNode::StringifyAttrsToRaw(const Map<String, ObjectRef>& attrs) const {
   std::ostringstream os;
   std::vector<String> keys;
   for (const auto& kv : attrs) {
@@ -348,7 +350,7 @@ Optional<String> TargetIdNode::StringifyAttrsToRaw(const Map<String, ObjectRef>&
 
 // TODO(@junrushao1994): remove some redundant attributes
 
-TVM_REGISTER_TARGET_ID("llvm")
+TVM_REGISTER_TARGET_KIND("llvm")
     .add_attr_option<Array<String>>("keys")
     .add_attr_option<Array<String>>("libs")
     .add_attr_option<String>("device")
@@ -361,7 +363,7 @@ TVM_REGISTER_TARGET_ID("llvm")
     .set_default_keys({"cpu"})
     .set_device_type(kDLCPU);
 
-TVM_REGISTER_TARGET_ID("c")
+TVM_REGISTER_TARGET_KIND("c")
     .add_attr_option<Array<String>>("keys")
     .add_attr_option<Array<String>>("libs")
     .add_attr_option<String>("device")
@@ -370,7 +372,7 @@ TVM_REGISTER_TARGET_ID("c")
     .set_default_keys({"cpu"})
     .set_device_type(kDLCPU);
 
-TVM_REGISTER_TARGET_ID("micro_dev")
+TVM_REGISTER_TARGET_KIND("micro_dev")
     .add_attr_option<Array<String>>("keys")
     .add_attr_option<Array<String>>("libs")
     .add_attr_option<String>("device")
@@ -379,7 +381,7 @@ TVM_REGISTER_TARGET_ID("micro_dev")
     .set_default_keys({"micro_dev"})
     .set_device_type(kDLMicroDev);
 
-TVM_REGISTER_TARGET_ID("cuda")
+TVM_REGISTER_TARGET_KIND("cuda")
     .add_attr_option<Array<String>>("keys")
     .add_attr_option<Array<String>>("libs")
     .add_attr_option<String>("device")
@@ -391,7 +393,7 @@ TVM_REGISTER_TARGET_ID("cuda")
     .set_default_keys({"cuda", "gpu"})
     .set_device_type(kDLGPU);
 
-TVM_REGISTER_TARGET_ID("nvptx")
+TVM_REGISTER_TARGET_KIND("nvptx")
     .add_attr_option<Array<String>>("keys")
     .add_attr_option<Array<String>>("libs")
     .add_attr_option<String>("device")
@@ -403,7 +405,7 @@ TVM_REGISTER_TARGET_ID("nvptx")
     .set_default_keys({"cuda", "gpu"})
     .set_device_type(kDLGPU);
 
-TVM_REGISTER_TARGET_ID("rocm")
+TVM_REGISTER_TARGET_KIND("rocm")
     .add_attr_option<Array<String>>("keys")
     .add_attr_option<Array<String>>("libs")
     .add_attr_option<String>("device")
@@ -414,7 +416,7 @@ TVM_REGISTER_TARGET_ID("rocm")
     .set_default_keys({"rocm", "gpu"})
     .set_device_type(kDLROCM);
 
-TVM_REGISTER_TARGET_ID("opencl")
+TVM_REGISTER_TARGET_KIND("opencl")
     .add_attr_option<Array<String>>("keys")
     .add_attr_option<Array<String>>("libs")
     .add_attr_option<String>("device")
@@ -425,7 +427,7 @@ TVM_REGISTER_TARGET_ID("opencl")
     .set_default_keys({"opencl", "gpu"})
     .set_device_type(kDLOpenCL);
 
-TVM_REGISTER_TARGET_ID("metal")
+TVM_REGISTER_TARGET_KIND("metal")
     .add_attr_option<Array<String>>("keys")
     .add_attr_option<Array<String>>("libs")
     .add_attr_option<String>("device")
@@ -435,7 +437,7 @@ TVM_REGISTER_TARGET_ID("metal")
     .set_default_keys({"metal", "gpu"})
     .set_device_type(kDLMetal);
 
-TVM_REGISTER_TARGET_ID("vulkan")
+TVM_REGISTER_TARGET_KIND("vulkan")
     .add_attr_option<Array<String>>("keys")
     .add_attr_option<Array<String>>("libs")
     .add_attr_option<String>("device")
@@ -445,7 +447,7 @@ TVM_REGISTER_TARGET_ID("vulkan")
     .set_default_keys({"vulkan", "gpu"})
     .set_device_type(kDLVulkan);
 
-TVM_REGISTER_TARGET_ID("webgpu")
+TVM_REGISTER_TARGET_KIND("webgpu")
     .add_attr_option<Array<String>>("keys")
     .add_attr_option<Array<String>>("libs")
     .add_attr_option<String>("device")
@@ -455,7 +457,7 @@ TVM_REGISTER_TARGET_ID("webgpu")
     .set_default_keys({"webgpu", "gpu"})
     .set_device_type(kDLWebGPU);
 
-TVM_REGISTER_TARGET_ID("sdaccel")
+TVM_REGISTER_TARGET_KIND("sdaccel")
     .add_attr_option<Array<String>>("keys")
     .add_attr_option<Array<String>>("libs")
     .add_attr_option<String>("device")
@@ -464,7 +466,7 @@ TVM_REGISTER_TARGET_ID("sdaccel")
     .set_default_keys({"sdaccel", "hls"})
     .set_device_type(kDLOpenCL);
 
-TVM_REGISTER_TARGET_ID("aocl")
+TVM_REGISTER_TARGET_KIND("aocl")
     .add_attr_option<Array<String>>("keys")
     .add_attr_option<Array<String>>("libs")
     .add_attr_option<String>("device")
@@ -473,7 +475,7 @@ TVM_REGISTER_TARGET_ID("aocl")
     .set_default_keys({"aocl", "hls"})
     .set_device_type(kDLAOCL);
 
-TVM_REGISTER_TARGET_ID("aocl_sw_emu")
+TVM_REGISTER_TARGET_KIND("aocl_sw_emu")
     .add_attr_option<Array<String>>("keys")
     .add_attr_option<Array<String>>("libs")
     .add_attr_option<String>("device")
@@ -482,7 +484,7 @@ TVM_REGISTER_TARGET_ID("aocl_sw_emu")
     .set_default_keys({"aocl", "hls"})
     .set_device_type(kDLAOCL);
 
-TVM_REGISTER_TARGET_ID("hexagon")
+TVM_REGISTER_TARGET_KIND("hexagon")
     .add_attr_option<Array<String>>("keys")
     .add_attr_option<Array<String>>("libs")
     .add_attr_option<String>("device")
@@ -491,7 +493,7 @@ TVM_REGISTER_TARGET_ID("hexagon")
     .set_default_keys({"hexagon"})
     .set_device_type(kDLHexagon);
 
-TVM_REGISTER_TARGET_ID("stackvm")
+TVM_REGISTER_TARGET_KIND("stackvm")
     .add_attr_option<Array<String>>("keys")
     .add_attr_option<Array<String>>("libs")
     .add_attr_option<String>("device")
@@ -499,7 +501,7 @@ TVM_REGISTER_TARGET_ID("stackvm")
     .add_attr_option<Bool>("system-lib")
     .set_device_type(kDLCPU);
 
-TVM_REGISTER_TARGET_ID("ext_dev")
+TVM_REGISTER_TARGET_KIND("ext_dev")
     .add_attr_option<Array<String>>("keys")
     .add_attr_option<Array<String>>("libs")
     .add_attr_option<String>("device")
@@ -507,7 +509,7 @@ TVM_REGISTER_TARGET_ID("ext_dev")
     .add_attr_option<Bool>("system-lib")
     .set_device_type(kDLExtDev);
 
-TVM_REGISTER_TARGET_ID("hybrid")
+TVM_REGISTER_TARGET_KIND("hybrid")
     .add_attr_option<Array<String>>("keys")
     .add_attr_option<Array<String>>("libs")
     .add_attr_option<String>("device")
src/te/schedule/schedule_postproc_rewrite_for_tensor_core.cc
index 7541662..7c4a3c7 100644 (file)
@@ -1083,7 +1083,7 @@ Stmt SchedulePostProcRewriteForTensorCore(Stmt stmt, Schedule schedule,
                                           Map<Tensor, Buffer> extern_buffer) {
   // Check if current lower target is CUDA
   auto target = tvm::Target::Current(true);
-  if (target.defined() && target->id->name != "cuda") {
+  if (target.defined() && target->kind->name != "cuda") {
     return stmt;
   }
 
src/tir/analysis/verify_memory.cc
index f8a5986..dfad549 100644 (file)
@@ -177,7 +177,7 @@ bool VerifyMemory(const PrimFunc& func) {
 
   if (func->GetAttr<Integer>(tvm::attr::kCallingConv, Integer(CallingConv::kDefault)) ==
       CallingConv::kDefault) {
-    MemoryAccessVerifier v(func, target.value()->id->device_type);
+    MemoryAccessVerifier v(func, target.value()->kind->device_type);
     v.Run();
     return !v.Failed();
   } else {
src/tir/transforms/lower_custom_datatypes.cc
index f5491da..7fd2352 100644 (file)
@@ -139,7 +139,7 @@ Pass LowerCustomDatatypes() {
     auto target = f->GetAttr<Target>(tvm::attr::kTarget);
     CHECK(target.defined()) << "LowerCustomDatatypes: Require the target attribute";
 
-    n->body = CustomDatatypesLowerer(target.value()->id->name)(std::move(n->body));
+    n->body = CustomDatatypesLowerer(target.value()->kind->name)(std::move(n->body));
     return f;
   };
   return CreatePrimFuncPass(pass_func, 0, "tir.LowerCustomDatatypes", {});
src/tir/transforms/lower_intrin.cc
index f3fe945..8774fc3 100644 (file)
@@ -311,7 +311,7 @@ Pass LowerIntrin() {
     arith::Analyzer analyzer;
     auto mtriple = target.value()->GetAttr<runtime::String>("mtriple", "");
     n->body =
-        IntrinInjecter(&analyzer, target.value()->id->name, mtriple.value())(std::move(n->body));
+        IntrinInjecter(&analyzer, target.value()->kind->name, mtriple.value())(std::move(n->body));
     return f;
   };
   return CreatePrimFuncPass(pass_func, 0, "tir.LowerIntrin", {});
src/tir/transforms/lower_thread_allreduce.cc
index 17b4265..bd216bb 100644 (file)
@@ -484,10 +484,11 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
   // Also, the warp/wavefront size differs (64 on rocm, 32 on cuda).
   bool is_warp_reduction(const std::vector<DataType>& types) const {
     // Only cuda target supports warp reductions.
-    if ((target_->id->name != "cuda") && (target_->id->name != "rocm")) return false;
+    if ((target_->kind->name != "cuda") && (target_->kind->name != "rocm")) return false;
 
     // rocm only supports 32 bit operands for shuffling at the moment
-    if ((target_->id->name == "rocm") && (std::any_of(types.begin(), types.end(), [](DataType ty) {
+    if ((target_->kind->name == "rocm") &&
+        (std::any_of(types.begin(), types.end(), [](DataType ty) {
           if (ty.is_vector()) return true;
           return ty.bits() != 32;
         }))) {
src/tir/transforms/make_packed_api.cc
index 9519fa6..3fae2bb 100644 (file)
@@ -51,7 +51,7 @@ PrimFunc MakePackedAPI(PrimFunc&& func, int num_unpacked_args) {
 
   auto target = func->GetAttr<Target>(tvm::attr::kTarget);
   CHECK(target.defined()) << "MakePackedAPI: Require the target attribute";
-  int target_device_type = target.value()->id->device_type;
+  int target_device_type = target.value()->kind->device_type;
 
   std::string name_hint = global_symbol.value();
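The preceding scheduling and lowering hunks are mechanical call-site updates: every target->id access becomes target->kind. A hedged sketch of the resulting accessor pattern (both helpers are illustrative):

    #include <tvm/target/target.h>

    // Mirrors the warp-reduction check in the lower_thread_allreduce hunk above.
    bool IsCudaOrRocm(const tvm::Target& target) {
      return target.defined() &&
             (target->kind->name == "cuda" || target->kind->name == "rocm");
    }

    // The device type now also comes from the kind, as in MakePackedAPI above.
    int DeviceTypeOf(const tvm::Target& target) {
      return target->kind->device_type;
    }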
 
tests/cpp/target_test.cc
index 6b8e0b1..8bee707 100644 (file)
 
 #include <dmlc/logging.h>
 #include <gtest/gtest.h>
-#include <tvm/target/target_id.h>
+#include <tvm/target/target_kind.h>
 
 #include <cmath>
 #include <string>
 
 using namespace tvm;
 
-TVM_REGISTER_TARGET_ID("TestTargetId")
+TVM_REGISTER_TARGET_KIND("TestTargetKind")
     .set_attr<std::string>("Attr1", "Value1")
     .add_attr_option<Bool>("my_bool")
     .add_attr_option<Array<String>>("your_names")
     .add_attr_option<Map<String, Integer>>("her_maps");
 
-TEST(TargetId, GetAttrMap) {
-  auto map = tvm::TargetId::GetAttrMap<std::string>("Attr1");
-  auto target_id = tvm::TargetId::Get("TestTargetId");
-  std::string result = map[target_id];
+TEST(TargetKind, GetAttrMap) {
+  auto map = tvm::TargetKind::GetAttrMap<std::string>("Attr1");
+  auto target_kind = tvm::TargetKind::Get("TestTargetKind");
+  std::string result = map[target_kind];
   CHECK_EQ(result, "Value1");
 }
 
-TEST(TargetId, SchemaValidation) {
+TEST(TargetKind, SchemaValidation) {
   tvm::Map<String, ObjectRef> target;
   {
     tvm::Array<String> your_names{"junru", "jian"};
@@ -50,7 +50,7 @@ TEST(TargetId, SchemaValidation) {
     target.Set("my_bool", Bool(true));
     target.Set("your_names", your_names);
     target.Set("her_maps", her_maps);
-    target.Set("id", String("TestTargetId"));
+    target.Set("kind", String("TestTargetKind"));
   }
   TargetValidateSchema(target);
   tvm::Map<String, ObjectRef> target_host(target.begin(), target.end());
@@ -58,7 +58,7 @@ TEST(TargetId, SchemaValidation) {
   TargetValidateSchema(target);
 }
 
-TEST(TargetId, SchemaValidationFail) {
+TEST(TargetKind, SchemaValidationFail) {
   tvm::Map<String, ObjectRef> target;
   {
     tvm::Array<String> your_names{"junru", "jian"};
@@ -70,7 +70,7 @@ TEST(TargetId, SchemaValidationFail) {
     target.Set("your_names", your_names);
     target.Set("her_maps", her_maps);
     target.Set("ok", ObjectRef(nullptr));
-    target.Set("id", String("TestTargetId"));
+    target.Set("kind", String("TestTargetKind"));
   }
   bool failed = false;
   try {
tests/python/unittest/test_target_target.py
index fb365c8..4258da9 100644 (file)
@@ -18,19 +18,23 @@ import tvm
 from tvm import te
 from tvm.target import cuda, rocm, mali, intel_graphics, arm_cpu, vta, bifrost, hexagon
 
+
 @tvm.target.generic_func
 def mygeneric(data):
     # default generic function
     return data + 1
 
+
 @mygeneric.register(["cuda", "gpu"])
 def cuda_func(data):
     return data + 2
 
+
 @mygeneric.register("rocm")
 def rocm_func(data):
     return data + 3
 
+
 @mygeneric.register("cpu")
 def rocm_func(data):
     return data + 10
@@ -58,7 +62,7 @@ def test_target_dispatch():
 def test_target_string_parse():
     target = tvm.target.create("cuda -model=unknown -libs=cublas,cudnn")
 
-    assert target.id.name == "cuda"
+    assert target.kind.name == "cuda"
     assert target.model == "unknown"
     assert set(target.keys) == set(['cuda', 'gpu'])
     assert set(target.libs) == set(['cublas', 'cudnn'])
@@ -70,7 +74,8 @@ def test_target_string_parse():
 
 
 def test_target_create():
-    targets = [cuda(), rocm(), mali(), intel_graphics(), arm_cpu('rk3399'), vta(), bifrost()]
+    targets = [cuda(), rocm(), mali(), intel_graphics(),
+               arm_cpu('rk3399'), vta(), bifrost()]
     for tgt in targets:
         assert tgt is not None
 
@@ -78,4 +83,4 @@ def test_target_create():
 if __name__ == "__main__":
     test_target_dispatch()
     test_target_string_parse()
-    test_target_create()
\ No newline at end of file
+    test_target_create()