From: Junru Shao Date: Thu, 2 Jul 2020 19:23:57 +0000 (-0700) Subject: [Target] Migrate data structure of TargetNode (#5960) X-Git-Tag: upstream/0.7.0~455 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=6ce8a1cb5fbddf1acf0ed9e00eef2d3e5071f86f;p=platform%2Fupstream%2Ftvm.git [Target] Migrate data structure of TargetNode (#5960) --- diff --git a/include/tvm/target/target.h b/include/tvm/target/target.h index c85349d..30ae19a 100644 --- a/include/tvm/target/target.h +++ b/include/tvm/target/target.h @@ -28,6 +28,7 @@ #include #include #include +#include #include #include @@ -42,45 +43,50 @@ namespace tvm { */ class TargetNode : public Object { public: - /*! \brief The name of the target device */ - std::string target_name; - /*! \brief The name of the target device */ - std::string device_name; - /*! \brief The type of the target device */ - int device_type; - /*! \brief The maximum threads that a schedule should use for this device */ - int max_num_threads = 1; - /*! \brief The warp size that should be used by the LowerThreadAllreduce pass */ - int thread_warp_size = 1; + /*! \brief The id of the target device */ + TargetId id; + /*! \brief Tag of the the target, can be empty */ + String tag; /*! \brief Keys for this target */ - Array keys_array; - /*! \brief Options for this target */ - Array options_array; - /*! \brief Collection of imported libs */ - Array libs_array; + Array keys; + /*! \brief Collection of attributes */ + Map attrs; /*! \return the full device string to pass to codegen::Build */ TVM_DLL const std::string& str() const; void VisitAttrs(AttrVisitor* v) { - v->Visit("target_name", &target_name); - v->Visit("device_name", &device_name); - v->Visit("device_type", &device_type); - v->Visit("max_num_threads", &max_num_threads); - v->Visit("thread_warp_size", &thread_warp_size); - v->Visit("keys_array", &keys_array); - v->Visit("options_array", &options_array); - v->Visit("libs_array", &libs_array); + v->Visit("id", &id); + v->Visit("tag", &tag); + v->Visit("keys_", &keys); + v->Visit("attrs", &attrs); + v->Visit("_str_repr_", &str_repr_); } - /*! \brief Get the keys for this target as a vector of string */ - TVM_DLL std::vector keys() const; + template + Optional GetAttr( + const std::string& attr_key, + Optional default_value = Optional(nullptr)) const { + static_assert(std::is_base_of::value, + "Can only call GetAttr with ObjectRef types."); + auto it = attrs.find(attr_key); + if (it != attrs.end()) { + return Downcast>((*it).second); + } else { + return default_value; + } + } - /*! \brief Get the options for this target as a vector of string */ - TVM_DLL std::vector options() const; + template + Optional GetAttr(const std::string& attr_key, TObjectRef default_value) const { + return GetAttr(attr_key, Optional(default_value)); + } + + /*! \brief Get the keys for this target as a vector of string */ + TVM_DLL std::vector GetKeys() const; /*! \brief Get the keys for this target as an unordered_set of string */ - TVM_DLL std::unordered_set libs() const; + TVM_DLL std::unordered_set GetLibs() const; static constexpr const char* _type_key = "Target"; TVM_DECLARE_FINAL_OBJECT_INFO(TargetNode, Object); @@ -88,6 +94,7 @@ class TargetNode : public Object { private: /*! \brief Internal string repr. */ mutable std::string str_repr_; + friend class Target; }; /*! @@ -102,7 +109,17 @@ class Target : public ObjectRef { * \brief Create a Target given a string * \param target_str the string to parse */ - TVM_DLL static Target Create(const std::string& target_str); + TVM_DLL static Target Create(const String& target_str); + /*! + * \brief Construct a Target node from the given name and options. + * \param name The major target name. Should be one of + * {"aocl", "aocl_sw_emu", "c", "cuda", "ext_dev", "hexagon", "hybrid", "llvm", + * "metal", "nvptx", "opencl", "rocm", "sdaccel", "stackvm", "vulkan"} + * \param options Additional options appended to the target + * \return The constructed Target + */ + TVM_DLL static Target CreateTarget(const std::string& name, + const std::vector& options); /*! * \brief Get the current target context from thread local storage. * \param allow_not_defined If the context stack is empty and this is set to true, an diff --git a/include/tvm/target/target_id.h b/include/tvm/target/target_id.h index 93c88c75..e8d53a3 100644 --- a/include/tvm/target/target_id.h +++ b/include/tvm/target/target_id.h @@ -43,6 +43,8 @@ template struct ValueTypeInfoMaker; } +class Target; + /*! \brief Perform schema validation */ TVM_DLL void TargetValidateSchema(const Map& config); @@ -54,6 +56,10 @@ class TargetIdNode : public Object { public: /*! \brief Name of the target id */ String name; + /*! \brief Device type of target id */ + int device_type; + /*! \brief Default keys of the target */ + Array default_keys; /*! \brief Stores the required type_key and type_index of a specific attr of a target */ struct ValueTypeInfo { String type_key; @@ -62,6 +68,14 @@ class TargetIdNode : public Object { std::unique_ptr val; }; + void VisitAttrs(AttrVisitor* v) { + v->Visit("name", &name); + v->Visit("device_type", &device_type); + v->Visit("default_keys", &default_keys); + } + + Map ParseAttrsFromRawString(const std::vector& options); + static constexpr const char* _type_key = "TargetId"; TVM_DECLARE_FINAL_OBJECT_INFO(TargetIdNode, Object); @@ -72,9 +86,12 @@ class TargetIdNode : public Object { void ValidateSchema(const Map& config) const; /*! \brief A hash table that stores the type information of each attr of the target key */ std::unordered_map key2vtype_; + /*! \brief A hash table that stores the default value of each attr of the target key */ + std::unordered_map key2default_; /*! \brief Index used for internal lookup of attribute registry */ uint32_t index_; friend void TargetValidateSchema(const Map&); + friend class Target; friend class TargetId; template friend class AttrRegistry; @@ -91,6 +108,7 @@ class TargetIdNode : public Object { */ class TargetId : public ObjectRef { public: + TargetId() = default; /*! \brief Get the attribute map given the attribute name */ template static inline TargetIdAttrMap GetAttrMap(const String& attr_name); @@ -110,6 +128,7 @@ class TargetId : public ObjectRef { template friend class AttrRegistry; friend class TargetIdRegEntry; + friend class Target; }; /*! @@ -149,12 +168,30 @@ class TargetIdRegEntry { inline TargetIdRegEntry& set_attr(const String& attr_name, const ValueType& value, int plevel = 10); /*! + * \brief Set DLPack's device_type the target + * \param device_type Device type + */ + inline TargetIdRegEntry& set_device_type(int device_type); + /*! + * \brief Set DLPack's device_type the target + * \param keys The default keys + */ + inline TargetIdRegEntry& set_default_keys(std::vector keys); + /*! * \brief Register a valid configuration option and its ValueType for validation * \param key The configuration key * \tparam ValueType The value type to be registered */ template inline TargetIdRegEntry& add_attr_option(const String& key); + /*! + * \brief Register a valid configuration option and its ValueType for validation + * \param key The configuration key + * \param default_value The default value of the key + * \tparam ValueType The value type to be registered + */ + template + inline TargetIdRegEntry& add_attr_option(const String& key, ObjectRef default_value); /*! \brief Set name of the TargetId to be the same as registry if it is empty */ inline TargetIdRegEntry& set_name(); /*! @@ -286,6 +323,16 @@ inline TargetIdRegEntry& TargetIdRegEntry::set_attr(const String& attr_name, con return *this; } +inline TargetIdRegEntry& TargetIdRegEntry::set_device_type(int device_type) { + id_->device_type = device_type; + return *this; +} + +inline TargetIdRegEntry& TargetIdRegEntry::set_default_keys(std::vector keys) { + id_->default_keys = keys; + return *this; +} + template inline TargetIdRegEntry& TargetIdRegEntry::add_attr_option(const String& key) { CHECK(!id_->key2vtype_.count(key)) @@ -294,6 +341,14 @@ inline TargetIdRegEntry& TargetIdRegEntry::add_attr_option(const String& key) { return *this; } +template +inline TargetIdRegEntry& TargetIdRegEntry::add_attr_option(const String& key, + ObjectRef default_value) { + add_attr_option(key); + id_->key2default_[key] = default_value; + return *this; +} + inline TargetIdRegEntry& TargetIdRegEntry::set_name() { if (id_->name.empty()) { id_->name = name; diff --git a/python/tvm/autotvm/tophub.py b/python/tvm/autotvm/tophub.py index a11c16b..c0e4eed 100644 --- a/python/tvm/autotvm/tophub.py +++ b/python/tvm/autotvm/tophub.py @@ -103,11 +103,10 @@ def context(target, extra_files=None): tgt = _target.create(tgt) possible_names = [] - for opt in tgt.options: - if opt.startswith("-device"): - device = _alias(opt[8:]) - possible_names.append(device) - possible_names.append(tgt.target_name) + device = tgt.attrs.get("device", "") + if device != "": + possible_names.append(_alias(device)) + possible_names.append(tgt.id.name) all_packages = list(PACKAGE_VERSION.keys()) for name in possible_names: diff --git a/python/tvm/driver/build_module.py b/python/tvm/driver/build_module.py index 47e9a81..b107000 100644 --- a/python/tvm/driver/build_module.py +++ b/python/tvm/driver/build_module.py @@ -238,7 +238,7 @@ def _build_for_device(input_mod, target, target_host): """ target = _target.create(target) target_host = _target.create(target_host) - device_type = ndarray.context(target.target_name, 0).device_type + device_type = ndarray.context(target.id.name, 0).device_type mod_mixed = input_mod mod_mixed = tvm.tir.transform.Apply(lambda f: f.with_attr("target", target))(mod_mixed) @@ -402,7 +402,7 @@ def build(inputs, if not target_host: for tar, _ in target_input_mod.items(): tar = _target.create(tar) - device_type = ndarray.context(tar.target_name, 0).device_type + device_type = ndarray.context(tar.id.name, 0).device_type if device_type == ndarray.cpu(0).device_type: target_host = tar break diff --git a/python/tvm/relay/op/strategy/cuda.py b/python/tvm/relay/op/strategy/cuda.py index a1c88b8..d626a9d 100644 --- a/python/tvm/relay/op/strategy/cuda.py +++ b/python/tvm/relay/op/strategy/cuda.py @@ -68,7 +68,7 @@ def softmax_strategy_cuda(attrs, inputs, out_type, target): wrap_compute_softmax(topi.nn.softmax), wrap_topi_schedule(topi.cuda.schedule_softmax), name="softmax.cuda") - if target.target_name == "cuda" and "cudnn" in target.libs: + if target.id.name == "cuda" and "cudnn" in target.libs: strategy.add_implementation( wrap_compute_softmax(topi.cuda.softmax_cudnn), wrap_topi_schedule(topi.cuda.schedule_softmax_cudnn), @@ -145,7 +145,7 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target): dilation_h, dilation_w, pre_flag=False) if judge_winograd_shape: - if target.target_name == "cuda" and \ + if target.id.name == "cuda" and \ nvcc.have_tensorcore(tvm.gpu(0).compute_version) and \ judge_winograd_tensorcore: strategy.add_implementation( @@ -162,7 +162,7 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target): topi.cuda.schedule_conv2d_nhwc_winograd_direct), name="conv2d_nhwc_winograd_direct.cuda", plevel=5) - if target.target_name == "cuda": + if target.id.name == "cuda": if nvcc.have_tensorcore(tvm.gpu(0).compute_version): if (N % 16 == 0 and CI % 16 == 0 and CO % 16 == 0) or \ (N % 8 == 0 and CI % 16 == 0 and CO % 32 == 0) or \ @@ -181,7 +181,7 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target): else: raise RuntimeError("Unsupported conv2d layout {} for CUDA".format(layout)) # add cudnn implementation - if target.target_name == "cuda" and "cudnn" in target.libs: + if target.id.name == "cuda" and "cudnn" in target.libs: if layout in ["NCHW", "NHWC"] and padding[0] == padding[2] and \ padding[1] == padding[3]: strategy.add_implementation( @@ -209,7 +209,7 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target): else: # group_conv2d # add cudnn implementation, if any cudnn_impl = False - if target.target_name == "cuda" and "cudnn" in target.libs: + if target.id.name == "cuda" and "cudnn" in target.libs: if layout in ["NCHW", "NHWC"] and padding[0] == padding[2] and \ padding[1] == padding[3]: strategy.add_implementation( @@ -264,7 +264,7 @@ def conv2d_winograd_without_weight_transfrom_strategy_cuda(attrs, inputs, out_ty padding, stride_h, stride_w, dilation_h, dilation_w, pre_flag=True) - if target.target_name == "cuda" and \ + if target.id.name == "cuda" and \ nvcc.have_tensorcore(tvm.gpu(0).compute_version) and \ judge_winograd_tensorcore: strategy.add_implementation( @@ -362,7 +362,7 @@ def conv3d_strategy_cuda(attrs, inputs, out_type, target): plevel=10) N, _, _, _, _ = get_const_tuple(data.shape) _, _, _, CI, CO = get_const_tuple(kernel.shape) - if target.target_name == "cuda": + if target.id.name == "cuda": if nvcc.have_tensorcore(tvm.gpu(0).compute_version): if (N % 16 == 0 and CI % 16 == 0 and CO % 16 == 0) or \ (N % 8 == 0 and CI % 16 == 0 and CO % 32 == 0) or \ @@ -373,7 +373,7 @@ def conv3d_strategy_cuda(attrs, inputs, out_type, target): name="conv3d_ndhwc_tensorcore.cuda", plevel=20) - if target.target_name == "cuda" and "cudnn" in target.libs: + if target.id.name == "cuda" and "cudnn" in target.libs: strategy.add_implementation(wrap_compute_conv3d(topi.cuda.conv3d_cudnn, True), wrap_topi_schedule(topi.cuda.schedule_conv3d_cudnn), name="conv3d_cudnn.cuda", @@ -458,7 +458,7 @@ def dense_strategy_cuda(attrs, inputs, out_type, target): wrap_topi_schedule(topi.cuda.schedule_dense_large_batch), name="dense_large_batch.cuda", plevel=5) - if target.target_name == "cuda": + if target.id.name == "cuda": if nvcc.have_tensorcore(tvm.gpu(0).compute_version): if(i % 16 == 0 and b % 16 == 0 and o % 16 == 0) \ or (i % 16 == 0 and b % 8 == 0 and o % 32 == 0) \ @@ -468,7 +468,7 @@ def dense_strategy_cuda(attrs, inputs, out_type, target): wrap_topi_schedule(topi.cuda.schedule_dense_tensorcore), name="dense_tensorcore.cuda", plevel=20) - if target.target_name == "cuda" and "cublas" in target.libs: + if target.id.name == "cuda" and "cublas" in target.libs: strategy.add_implementation( wrap_compute_dense(topi.cuda.dense_cublas), wrap_topi_schedule(topi.cuda.schedule_dense_cublas), @@ -485,7 +485,7 @@ def batch_matmul_strategy_cuda(attrs, inputs, out_type, target): wrap_topi_schedule(topi.cuda.schedule_batch_matmul), name="batch_matmul.cuda", plevel=10) - if target.target_name == "cuda" and "cublas" in target.libs: + if target.id.name == "cuda" and "cublas" in target.libs: strategy.add_implementation( wrap_compute_batch_matmul(topi.cuda.batch_matmul_cublas), wrap_topi_schedule(topi.generic.schedule_extern), diff --git a/python/tvm/relay/op/strategy/rocm.py b/python/tvm/relay/op/strategy/rocm.py index b1213f1..a80b6ca 100644 --- a/python/tvm/relay/op/strategy/rocm.py +++ b/python/tvm/relay/op/strategy/rocm.py @@ -127,7 +127,7 @@ def dense_strategy_rocm(attrs, inputs, out_type, target): wrap_compute_dense(topi.rocm.dense), wrap_topi_schedule(topi.rocm.schedule_dense), name="dense.rocm") - if target.target_name == "rocm" and "rocblas" in target.libs: + if target.id.name == "rocm" and "rocblas" in target.libs: assert out_type.dtype == inputs[0].dtype, "Mixed precision not supported." strategy.add_implementation( wrap_compute_dense(topi.rocm.dense_rocblas), diff --git a/python/tvm/relay/qnn/op/legalizations.py b/python/tvm/relay/qnn/op/legalizations.py index 7246214..00866e0 100644 --- a/python/tvm/relay/qnn/op/legalizations.py +++ b/python/tvm/relay/qnn/op/legalizations.py @@ -229,18 +229,17 @@ def helper_change_dtypes_to_be_same(attrs, inputs, types, relay_op): def is_fast_int8_on_intel(): """ Checks whether the hardware has support for fast Int8 arithmetic operations. """ target = tvm.target.Target.current(allow_none=False) - intel_supported_arches = {'-mcpu=skylake-avx512', '-mcpu=cascadelake'} - return intel_supported_arches.intersection(set(target.options)) + return target.mcpu in {'skylake-avx512', 'cascadelake'} def is_fast_int8_on_arm(): """ Checks whether the hardware has support for fast Int8 arithmetic operations. """ target = tvm.target.Target.current(allow_none=False) - return '+v8.2a,+dotprod' in ' '.join(target.options) + return '+v8.2a,+dotprod' in target.mattr def is_aarch64_arm(): """ Checks whether we are compiling for an AArch64 target. """ target = tvm.target.Target.current(allow_none=False) - return 'aarch64' in ' '.join(target.options) + return 'aarch64' in target.attrs.get("target", "") ######################## # ARM CPU legalizations. diff --git a/python/tvm/relay/quantize/_calibrate.py b/python/tvm/relay/quantize/_calibrate.py index 9590e87..74a6f60 100644 --- a/python/tvm/relay/quantize/_calibrate.py +++ b/python/tvm/relay/quantize/_calibrate.py @@ -39,7 +39,7 @@ def _get_profile_runtime(mod): if tvm.target.Target.current(): target = tvm.target.Target.current() - ctx = tvm.context(target.target_name) + ctx = tvm.context(target.id.name) else: target = 'llvm' ctx = tvm.context(target) diff --git a/python/tvm/target/__init__.py b/python/tvm/target/__init__.py index 2553fed..18a9e7e 100644 --- a/python/tvm/target/__init__.py +++ b/python/tvm/target/__init__.py @@ -16,7 +16,7 @@ # under the License. """Target description and codgen module. -TVM's target string is in fomat `` [-option=value]...``. +TVM's target string is in fomat `` [-option=value]...``. Note ---- diff --git a/python/tvm/target/target.py b/python/tvm/target/target.py index 3335e12..a2a4501 100644 --- a/python/tvm/target/target.py +++ b/python/tvm/target/target.py @@ -23,6 +23,12 @@ from . import _ffi_api @tvm._ffi.register_object +class TargetId(Object): + """Id of a compilation target + """ + + +@tvm._ffi.register_object class Target(Object): """Target device information, use through TVM API. @@ -41,45 +47,15 @@ class Target(Object): # Always override new to enable class obj = Object.__new__(cls) obj._keys = None - obj._options = None obj._libs = None return obj @property def keys(self): if not self._keys: - self._keys = [str(k) for k in self.keys_array] + self._keys = [str(k) for k in self.keys_] return self._keys - @property - def options(self): - if not self._options: - self._options = [str(o) for o in self.options_array] - return self._options - - @property - def libs(self): - if not self._libs: - self._libs = [str(l) for l in self.libs_array] - return self._libs - - @property - def model(self): - for opt in self.options_array: - if opt.startswith('-model='): - return opt[7:] - return 'unknown' - - @property - def mcpu(self): - """Returns the mcpu from the target if it exists.""" - mcpu = '' - if self.options is not None: - for opt in self.options: - if 'mcpu' in opt: - mcpu = opt.split('=')[1] - return mcpu - def __enter__(self): _ffi_api.EnterTargetScope(self) return self @@ -102,6 +78,40 @@ class Target(Object): """ return _ffi_api.GetCurrentTarget(allow_none) + @property + def max_num_threads(self): + return int(self.attrs["max_num_threads"]) + + @property + def thread_warp_size(self): + return int(self.attrs["thread_warp_size"]) + + @property + def device_name(self): + return str(self.attrs.get("device", "")) + + @property + def model(self): + """Returns model from the target if it exists.""" + return str(self.attrs.get("model", "unknown")) + + @property + def mcpu(self): + """Returns the mcpu from the target if it exists.""" + return str(self.attrs.get("mcpu", "")) + + @property + def mattr(self): + """Returns the mattr from the target if it exists.""" + return self.attrs.get("mattr", "") + + @property + def libs(self): + if not self._libs: + self._libs = list(self.attrs.get("libs", "")) + return self._libs + + def _merge_opts(opts, new_opts): """Helper function to merge options""" @@ -167,7 +177,7 @@ def intel_graphics(model='unknown', options=None): options : str or list of str Additional options """ - opts = ["-device=intel_graphics", '-model=%s' % model] + opts = ["-device=intel_graphics", "-model=%s" % model, "-thread_warp_size=16"] opts = _merge_opts(opts, options) return _ffi_api.TargetCreate("opencl", *opts) @@ -216,7 +226,7 @@ def rasp(options=None): def vta(model='unknown', options=None): - opts = ["-device=vta", '-keys=cpu', '-model=%s' % model] + opts = ["-device=vta", '-keys=vta,cpu', '-model=%s' % model] opts = _merge_opts(opts, options) ret = _ffi_api.TargetCreate("ext_dev", *opts) return ret diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index e796f49..2c08ea1 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -56,7 +56,7 @@ bool LLVMEnabled() { /*! \return The default host target for a given device target */ Target DefaultTargetHost(Target target) { - if (target.defined() && target->device_type == kDLCPU) { + if (target.defined() && target->id->device_type == kDLCPU) { return target; } else { if (LLVMEnabled()) { @@ -232,14 +232,14 @@ std::pair SplitDevHostFuncs(IRModule mod_mixed, const Target auto mdevice = opt_device(mod_mixed); // some final misc checks. - auto keys = target->keys(); + auto keys = target->GetKeys(); bool target_is_gpu = std::find(keys.begin(), keys.end(), "gpu") != keys.end(); if (target_is_gpu && mdevice->functions.size() == 0) { LOG(WARNING) << "Specified target " << target->str() << " but cannot find device code. Did you forget to bind?"; } - if (target->device_type == target::llvm()->device_type && target_host == target) { + if (target->id->device_type == kDLCPU && target_host == target) { CHECK(mdevice->functions.empty()) << "No device code should be generated when target " << "and host_target are both llvm target." << "\n"; @@ -256,7 +256,7 @@ runtime::Module build(const Map& inputs, const Target& target_ Target target_host_val = target_host; if (!target_host.defined()) { for (const auto& it : inputs) { - if (it.first->device_type == kDLCPU || it.first->device_type == kDLMicroDev) { + if (it.first->id->device_type == kDLCPU || it.first->id->device_type == kDLMicroDev) { target_host_val = it.first; break; } @@ -295,7 +295,8 @@ runtime::Module build(const Map& inputs, const Target& target_ Map updated_input; for (const auto& it : inputs) { auto target = Target::Create(it.first); - if (target->device_name == "vta") { + Optional device = target->GetAttr("device"); + if (device.defined() && device.value() == "vta") { target = Target::Create("ext_dev"); } updated_input.Set(target, it.second); diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc index 34c3487..b589bcc 100644 --- a/src/relay/backend/build_module.cc +++ b/src/relay/backend/build_module.cc @@ -441,7 +441,7 @@ class RelayBuildModule : public runtime::ModuleNode { if (!target_host.defined()) target_host = (pf != nullptr) ? target::llvm() : target::stackvm(); - if (target_host.defined() && target_host->target_name == "llvm") { + if (target_host.defined() && target_host->id->name == "llvm") { // If we can decide the target is LLVM, we then create an empty LLVM module. ret_.mod = (*pf)(target_host->str(), "empty_module"); } else { @@ -467,7 +467,7 @@ class RelayBuildModule : public runtime::ModuleNode { Target target_host = target_host_; if (!target_host_.defined()) { for (const auto& it : targets_) { - if (it.second->device_type == kDLCPU) { + if (it.second->id->device_type == kDLCPU) { target_host = it.second; break; } diff --git a/src/target/codegen.cc b/src/target/codegen.cc index d0a3156..52bd1c2 100644 --- a/src/target/codegen.cc +++ b/src/target/codegen.cc @@ -47,7 +47,12 @@ runtime::Module Build(IRModule mod, const Target& target) { .value()) { mod = tir::transform::SkipAssert()(mod); } - std::string build_f_name = "target.build." + target->target_name; + std::string build_f_name; + if (target->id->name == "micro_dev") { + build_f_name = "target.build.c"; + } else { + build_f_name = "target.build." + target->id->name; + } // the build function. const PackedFunc* bf = runtime::Registry::Get(build_f_name); CHECK(bf != nullptr) << "target.build." << target << " is not enabled"; diff --git a/src/target/generic_func.cc b/src/target/generic_func.cc index 9ad9f56..b5842ee 100644 --- a/src/target/generic_func.cc +++ b/src/target/generic_func.cc @@ -102,7 +102,7 @@ void GenericFunc::CallPacked(TVMArgs args, TVMRetValue* ret) const { PackedFunc func; if (target.defined()) { - for (auto& k : target->keys()) { + for (auto& k : target->GetKeys()) { auto iter = node->dispatch_dict_.find(k); if (iter != node->dispatch_dict_.end()) { func = iter->second; diff --git a/src/target/source/codegen_aocl.cc b/src/target/source/codegen_aocl.cc index 2b77869..597fd37 100644 --- a/src/target/source/codegen_aocl.cc +++ b/src/target/source/codegen_aocl.cc @@ -62,8 +62,9 @@ runtime::Module BuildAOCL(IRModule mod, std::string target_str, bool emulation) // AOCL supports fp64. cmd += " -Dcl_khr_fp64"; Target target = Target::Create(target_str); - if (target->device_name != "") { - cmd += " -board=" + target->device_name; + Optional device = target->GetAttr("device"); + if (device.defined()) { + cmd += " -board=" + device.value(); } if (emulation) { cmd += " -march=emulator"; diff --git a/src/target/source/codegen_vhls.cc b/src/target/source/codegen_vhls.cc index e60e1f5..3d77dda 100644 --- a/src/target/source/codegen_vhls.cc +++ b/src/target/source/codegen_vhls.cc @@ -179,7 +179,8 @@ runtime::Module BuildSDAccel(IRModule mod, std::string target_str) { std::string xclbin; if (const auto* f = Registry::Get("tvm_callback_sdaccel_compile")) { Target target = Target::Create(target_str); - xclbin = (*f)(kernel_info, target->device_name).operator std::string(); + String device = target->GetAttr("device", "").value(); + xclbin = (*f)(kernel_info, device).operator std::string(); } else { LOG(FATAL) << "Cannot compile Vivado HLS code."; } diff --git a/src/target/target.cc b/src/target/target.cc index 2104c2e..5c61867 100644 --- a/src/target/target.cc +++ b/src/target/target.cc @@ -24,6 +24,7 @@ #include #include #include +#include #include #include @@ -35,6 +36,41 @@ using runtime::PackedFunc; using runtime::TVMArgs; using runtime::TVMRetValue; +Target Target::CreateTarget(const std::string& name, const std::vector& options) { + TargetId id = TargetId::Get(name); + ObjectPtr target = make_object(); + target->id = id; + // tag is always empty + target->tag = ""; + // parse attrs + target->attrs = id->ParseAttrsFromRawString(options); + String device_name = target->GetAttr("device", "").value(); + // create string representation + { + std::ostringstream str_repr; + str_repr << name; + for (const auto& s : options) { + str_repr << ' ' << s; + } + target->str_repr_ = str_repr.str(); + } + // set up keys + { + // user provided keys + Array keys = target->GetAttr>("keys").value_or({}); + // add `device_name` + if (!device_name.empty()) { + keys.push_back(device_name); + } + // add default keys + for (const auto& key : target->id->default_keys) { + keys.push_back(key); + } + target->keys = std::move(keys); + } + return Target(target); +} + TVM_REGISTER_NODE_TYPE(TargetNode); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) @@ -43,119 +79,15 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << op->str(); }); -/*! - * \brief Construct a Target node from the given name and options. - * \param target_name The major target name. Should be one of - * {"aocl", "aocl_sw_emu", "c", "cuda", "ext_dev", "hexagon", "hybrid", "llvm", - * "metal", "nvptx", "opencl", "rocm", "sdaccel", "stackvm", "vulkan"} - * \param options Additional options appended to the target - * \return The constructed Target - */ -Target CreateTarget(const std::string& target_name, const std::vector& options) { - auto t = make_object(); - t->target_name = target_name; - - std::string libs_flag = "-libs="; - std::string device_flag = "-device="; - std::string keys_flag = "-keys="; - for (auto& item : options) { - t->options_array.push_back(item); - - if (item.find(libs_flag) == 0) { - std::stringstream ss(item.substr(libs_flag.length())); - std::string lib_item; - while (std::getline(ss, lib_item, ',')) { - t->libs_array.push_back(lib_item); - } - } else if (item.find(device_flag) == 0) { - t->device_name = item.substr(device_flag.length()); - t->keys_array.push_back(t->device_name); - } else if (item.find(keys_flag) == 0) { - std::stringstream ss(item.substr(keys_flag.length())); - std::string key_item; - while (std::getline(ss, key_item, ',')) { - t->keys_array.push_back(key_item); - } - } - } - - if (t->device_name.length() > 0) { - t->keys_array.push_back(t->device_name); - } - t->device_type = kDLCPU; - t->thread_warp_size = 1; - if (target_name == "c" && t->device_name == "micro_dev") { - t->device_type = kDLMicroDev; - } else if (target_name == "c" || target_name == "llvm") { - t->keys_array.push_back("cpu"); - } else if (target_name == "cuda" || target_name == "nvptx") { - t->device_type = kDLGPU; - t->keys_array.push_back("cuda"); - t->keys_array.push_back("gpu"); - t->max_num_threads = 1024; - t->thread_warp_size = 32; - } else if (target_name == "rocm" || target_name == "opencl") { - // For now assume rocm schedule for opencl - if (target_name == "opencl") { - t->device_type = kDLOpenCL; - } else { // rocm - t->device_type = kDLROCM; - t->thread_warp_size = 64; - } - t->keys_array.push_back(target_name); - t->keys_array.push_back("gpu"); - t->max_num_threads = 256; - if (t->device_name == "intel_graphics") { - t->thread_warp_size = 16; - } - } else if (target_name == "metal" || target_name == "vulkan" || target_name == "webgpu") { - if (target_name == "metal") { - t->device_type = kDLMetal; - } else if (target_name == "vulkan") { - t->device_type = kDLVulkan; - } else { - t->device_type = kDLWebGPU; - } - t->keys_array.push_back(target_name); - t->keys_array.push_back("gpu"); - t->max_num_threads = 256; - } else if (target_name == "sdaccel") { - t->device_type = kDLOpenCL; - t->keys_array.push_back("sdaccel"); - t->keys_array.push_back("hls"); - } else if (target_name == "aocl" || target_name == "aocl_sw_emu") { - t->device_type = kDLAOCL; - t->keys_array.push_back("aocl"); - t->keys_array.push_back("hls"); - } else if (target_name == "stackvm") { - t->device_type = kDLCPU; - } else if (target_name == "ext_dev") { - t->device_type = kDLExtDev; - } else if (target_name == "hybrid") { - t->device_type = kDLCPU; - } else if (target_name == "hexagon") { - t->keys_array.push_back("hexagon"); - t->device_type = kDLHexagon; - } else if (target_name == "webgpu") { - t->keys_array.push_back("webgpu"); - t->device_type = kDLWebGPU; - } else { - LOG(ERROR) << "Unknown target name " << target_name << "; falling back to stackvm"; - return target::stackvm(); - } - - return Target(t); -} - TVM_REGISTER_GLOBAL("target.TargetCreate").set_body([](TVMArgs args, TVMRetValue* ret) { - std::string target_name = args[0]; + std::string name = args[0]; std::vector options; for (int i = 1; i < args.num_args; ++i) { std::string arg = args[i]; options.push_back(arg); } - *ret = CreateTarget(target_name, options); + *ret = Target::CreateTarget(name, options); }); TVM_REGISTER_GLOBAL("target.TargetFromString").set_body([](TVMArgs args, TVMRetValue* ret) { @@ -163,38 +95,28 @@ TVM_REGISTER_GLOBAL("target.TargetFromString").set_body([](TVMArgs args, TVMRetV *ret = Target::Create(target_str); }); -std::vector TargetNode::keys() const { +std::vector TargetNode::GetKeys() const { std::vector result; - for (auto& expr : keys_array) { + for (auto& expr : keys) { result.push_back(expr); } return result; } -std::vector TargetNode::options() const { - std::vector result; - for (auto& expr : options_array) { - result.push_back(expr); +std::unordered_set TargetNode::GetLibs() const { + Optional> libs = this->GetAttr>("libs"); + if (!libs.defined()) { + return {}; } - return result; -} - -std::unordered_set TargetNode::libs() const { std::unordered_set result; - for (auto& expr : libs_array) { - result.insert(expr); + for (const auto& item : libs.value()) { + result.insert(item); } return result; } const std::string& TargetNode::str() const { - if (str_repr_.length() != 0) return str_repr_; - std::ostringstream result; - result << target_name; - for (const auto& x : options()) { - result << " " << x; - } - str_repr_ = result.str(); + CHECK(!str_repr_.empty()); return str_repr_; } @@ -202,39 +124,14 @@ bool StartsWith(const std::string& str, const std::string& pattern) { return str.compare(0, pattern.length(), pattern) == 0; } -std::string GetDeviceName(const std::string& target_str) { - std::istringstream ss(target_str); - std::string target_name; - ss >> target_name; - - std::string item; - while (ss >> item) { - if (StartsWith(item, "-device=")) { - return item.substr(std::string("-device=").length()); - } +Target Target::Create(const String& target_str) { + std::vector splits; + std::istringstream is(target_str); + for (std::string s; is >> s; splits.push_back(s)) { } - - return ""; -} - -Target Target::Create(const std::string& target_str) { - if (target_str.length() == 0) { - LOG(ERROR) << "target_str must not be empty"; - } - - std::istringstream ss(target_str); - std::string target_name; - - ss >> target_name; - auto device_name = GetDeviceName(target_str); - - std::vector options; - std::string item; - while (ss >> item) { - options.push_back(item); - } - - return CreateTarget(target_name, options); + CHECK(!splits.empty()) << "ValueError: Cannot parse empty target string: \"" << target_str + << "\""; + return CreateTarget(splits[0], {splits.begin() + 1, splits.end()}); } /*! \brief Entry to hold the Target context stack. */ @@ -290,28 +187,45 @@ std::vector MergeOptions(std::vector opts, return opts; } -Target llvm(const std::vector& options) { return CreateTarget("llvm", options); } +Target llvm(const std::vector& options) { + return Target::CreateTarget("llvm", options); +} -Target cuda(const std::vector& options) { return CreateTarget("cuda", options); } +Target cuda(const std::vector& options) { + return Target::CreateTarget("cuda", options); +} -Target rocm(const std::vector& options) { return CreateTarget("rocm", options); } +Target rocm(const std::vector& options) { + return Target::CreateTarget("rocm", options); +} -Target opencl(const std::vector& options) { return CreateTarget("opencl", options); } +Target opencl(const std::vector& options) { + return Target::CreateTarget("opencl", options); +} -Target metal(const std::vector& options) { return CreateTarget("metal", options); } +Target metal(const std::vector& options) { + return Target::CreateTarget("metal", options); +} Target mali(const std::vector& options) { - return CreateTarget("opencl", MergeOptions(options, {"-device=mali"})); + return Target::CreateTarget("opencl", MergeOptions(options, {"-device=mali"})); } Target intel_graphics(const std::vector& options) { - return CreateTarget("opencl", MergeOptions(options, {"-device=intel_graphics"})); + return Target::CreateTarget( + "opencl", MergeOptions(options, {"-device=intel_graphics", "-thread_warp_size=16"})); } -Target stackvm(const std::vector& options) { return CreateTarget("stackvm", options); } +Target stackvm(const std::vector& options) { + return Target::CreateTarget("stackvm", options); +} -Target ext_dev(const std::vector& options) { return CreateTarget("ext_dev", options); } +Target ext_dev(const std::vector& options) { + return Target::CreateTarget("ext_dev", options); +} -Target hexagon(const std::vector& options) { return CreateTarget("hexagon", options); } +Target hexagon(const std::vector& options) { + return Target::CreateTarget("hexagon", options); +} } // namespace target } // namespace tvm diff --git a/src/target/target_id.cc b/src/target/target_id.cc index dc1255c..faecf03 100644 --- a/src/target/target_id.cc +++ b/src/target/target_id.cc @@ -28,6 +28,14 @@ namespace tvm { +TVM_REGISTER_NODE_TYPE(TargetIdNode); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << op->name; + }); + using TargetIdRegistry = AttrRegistry; TargetIdRegEntry& TargetIdRegEntry::RegisterOrGet(const String& target_id_name) { @@ -45,14 +53,14 @@ const AttrRegistryMapContainerMap& TargetId::GetAttrMapContainer( const TargetId& TargetId::Get(const String& target_id_name) { const TargetIdRegEntry* reg = TargetIdRegistry::Global()->Get(target_id_name); - CHECK(reg != nullptr) << "TargetId " << target_id_name << " is not registered"; + CHECK(reg != nullptr) << "ValueError: TargetId \"" << target_id_name << "\" is not registered"; return reg->id_; } void VerifyTypeInfo(const ObjectRef& obj, const TargetIdNode::ValueTypeInfo& info) { CHECK(obj.defined()) << "Object is None"; if (!runtime::ObjectInternal::DerivedFrom(obj.get(), info.type_index)) { - LOG(FATAL) << "AttributeError: expect type " << info.type_key << " but get " + LOG(FATAL) << "AttributeError: expect type \"" << info.type_key << "\" but get " << obj->GetTypeKey(); throw; } @@ -74,16 +82,16 @@ void VerifyTypeInfo(const ObjectRef& obj, const TargetIdNode::ValueTypeInfo& inf try { VerifyTypeInfo(kv.first, *info.key); } catch (const tvm::Error& e) { - LOG(FATAL) << "The key of map failed type checking, where key = " << kv.first - << ", value = " << kv.second << ", and the error is:\n" + LOG(FATAL) << "The key of map failed type checking, where key = \"" << kv.first + << "\", value = \"" << kv.second << "\", and the error is:\n" << e.what(); throw; } try { VerifyTypeInfo(kv.second, *info.val); } catch (const tvm::Error& e) { - LOG(FATAL) << "The value of map failed type checking, where key = " << kv.first - << ", value = " << kv.second << ", and the error is:\n" + LOG(FATAL) << "The value of map failed type checking, where key = \"" << kv.first + << "\", value = \"" << kv.second << "\", and the error is:\n" << e.what(); throw; } @@ -98,16 +106,18 @@ void TargetIdNode::ValidateSchema(const Map& config) const { const ObjectRef& obj = kv.second; if (name == kTargetId) { CHECK(obj->IsInstance()) - << "AttributeError: \"id\" is not a string, but its type is " << obj->GetTypeKey(); + << "AttributeError: \"id\" is not a string, but its type is \"" << obj->GetTypeKey() + << "\""; CHECK(Downcast(obj) == this->name) - << "AttributeError: \"id\" = " << obj << " is inconsistent with TargetId " << this->name; + << "AttributeError: \"id\" = \"" << obj << "\" is inconsistent with TargetId \"" + << this->name << "\""; continue; } auto it = key2vtype_.find(name); if (it == key2vtype_.end()) { std::ostringstream os; - os << "AttributeError: Invalid config option, cannot recognize \'" << name - << "\'. Candidates are:"; + os << "AttributeError: Invalid config option, cannot recognize \"" << name + << "\". Candidates are:"; for (const auto& kv : key2vtype_) { os << "\n " << kv.first; } @@ -118,8 +128,8 @@ void TargetIdNode::ValidateSchema(const Map& config) const { try { VerifyTypeInfo(obj, info); } catch (const tvm::Error& e) { - LOG(FATAL) << "AttributeError: Schema validation failed for TargetId " << this->name - << ", details:\n" + LOG(FATAL) << "AttributeError: Schema validation failed for TargetId \"" << this->name + << "\", details:\n" << e.what() << "\n" << "The config is:\n" << config; @@ -130,12 +140,12 @@ void TargetIdNode::ValidateSchema(const Map& config) const { inline String GetId(const Map& target, const char* name) { const String kTargetId = "id"; - CHECK(target.count(kTargetId)) << "AttributeError: \"id\" does not exist in " << name << "\n" + CHECK(target.count(kTargetId)) << "AttributeError: \"id\" does not exist in \"" << name << "\"\n" << name << " = " << target; const ObjectRef& obj = target[kTargetId]; - CHECK(obj->IsInstance()) << "AttributeError: \"id\" is not a string in " << name - << ", but its type is " << obj->GetTypeKey() << "\n" - << name << " = " << target; + CHECK(obj->IsInstance()) << "AttributeError: \"id\" is not a string in \"" << name + << "\", but its type is \"" << obj->GetTypeKey() << "\"\n" + << name << " = \"" << target << '"'; return Downcast(obj); } @@ -156,9 +166,292 @@ void TargetValidateSchema(const Map& config) { TargetId::Get(target_host_id)->ValidateSchema(target_host); } } catch (const tvm::Error& e) { - LOG(INFO) << e.what(); - throw e; + LOG(FATAL) << "AttributeError: schedule validation fails:\n" + << e.what() << "\nThe configuration is:\n" + << config; + } +} + +static inline size_t CountNumPrefixDashes(const std::string& s) { + size_t i = 0; + for (; i < s.length() && s[i] == '-'; ++i) { + } + return i; +} + +static inline int FindUniqueSubstr(const std::string& str, const std::string& substr) { + size_t pos = str.find_first_of(substr); + if (pos == std::string::npos) { + return -1; + } + size_t next_pos = pos + substr.size(); + CHECK(next_pos >= str.size() || str.find_first_of(substr, next_pos) == std::string::npos) + << "ValueError: At most one \"" << substr << "\" is allowed in " + << "the the given string \"" << str << "\""; + return pos; +} + +static inline ObjectRef ParseScalar(uint32_t type_index, const std::string& str) { + std::istringstream is(str); + if (type_index == Integer::ContainerType::_GetOrAllocRuntimeTypeIndex()) { + int v; + is >> v; + return is.fail() ? ObjectRef(nullptr) : Integer(v); + } else if (type_index == String::ContainerType::_GetOrAllocRuntimeTypeIndex()) { + std::string v; + is >> v; + return is.fail() ? ObjectRef(nullptr) : String(v); + } + return ObjectRef(nullptr); +} + +Map TargetIdNode::ParseAttrsFromRawString( + const std::vector& options) { + std::unordered_map attrs; + for (size_t iter = 0, end = options.size(); iter < end;) { + std::string s = options[iter++]; + // remove the prefix dashes + size_t n_dashes = CountNumPrefixDashes(s); + CHECK(0 < n_dashes && n_dashes < s.size()) + << "ValueError: Not an attribute key \"" << s << "\""; + s = s.substr(n_dashes); + // parse name-obj pair + std::string name; + std::string obj; + int pos; + if ((pos = FindUniqueSubstr(s, "=")) != -1) { + // case 1. --key=value + name = s.substr(0, pos); + obj = s.substr(pos + 1); + CHECK(!name.empty()) << "ValueError: Empty attribute key in \"" << options[iter - 1] << "\""; + CHECK(!obj.empty()) << "ValueError: Empty attribute in \"" << options[iter - 1] << "\""; + } else if (iter < end && options[iter][0] != '-') { + // case 2. --key value + name = s; + obj = options[iter++]; + } else { + // case 3. --boolean-key + name = s; + obj = "1"; + } + // check if `name` is invalid + auto it = key2vtype_.find(name); + if (it == key2vtype_.end()) { + std::ostringstream os; + os << "AttributeError: Invalid config option, cannot recognize \'" << name + << "\'. Candidates are:"; + for (const auto& kv : key2vtype_) { + os << "\n " << kv.first; + } + LOG(FATAL) << os.str(); + } + // then `name` is valid, let's parse them + // only several types are supported when parsing raw string + const auto& info = it->second; + ObjectRef parsed_obj(nullptr); + if (info.type_index != ArrayNode::_type_index) { + parsed_obj = ParseScalar(info.type_index, obj); + } else { + Array array; + std::string item; + bool failed = false; + uint32_t type_index = info.key->type_index; + for (std::istringstream is(obj); std::getline(is, item, ',');) { + ObjectRef parsed_obj = ParseScalar(type_index, item); + if (parsed_obj.defined()) { + array.push_back(parsed_obj); + } else { + failed = true; + break; + } + } + if (!failed) { + parsed_obj = std::move(array); + } + } + if (!parsed_obj.defined()) { + LOG(FATAL) << "ValueError: Cannot parse type \"" << info.type_key << "\"" + << ", where attribute key is \"" << name << "\"" + << ", and attribute is \"" << obj << "\""; + } + attrs[name] = std::move(parsed_obj); } + // set default attribute values if they do not exist + for (const auto& kv : key2default_) { + if (!attrs.count(kv.first)) { + attrs[kv.first] = kv.second; + } + } + return attrs; } +// TODO(@junrushao1994): remove some redundant attributes + +TVM_REGISTER_TARGET_ID("llvm") + .add_attr_option>("keys") + .add_attr_option>("libs") + .add_attr_option("device") + .add_attr_option("model") + .add_attr_option("system-lib") + .add_attr_option("mcpu") + .add_attr_option("mattr") + .add_attr_option("mtriple") + .add_attr_option("target") // FIXME: rename to mtriple + .set_default_keys({"cpu"}) + .set_device_type(kDLCPU); + +TVM_REGISTER_TARGET_ID("c") + .add_attr_option>("keys") + .add_attr_option>("libs") + .add_attr_option("device") + .add_attr_option("model") + .add_attr_option("system-lib") + .set_default_keys({"cpu"}) + .set_device_type(kDLCPU); + +TVM_REGISTER_TARGET_ID("micro_dev") + .add_attr_option>("keys") + .add_attr_option>("libs") + .add_attr_option("device") + .add_attr_option("model") + .add_attr_option("system-lib") + .set_default_keys({"micro_dev"}) + .set_device_type(kDLMicroDev); + +TVM_REGISTER_TARGET_ID("cuda") + .add_attr_option>("keys") + .add_attr_option>("libs") + .add_attr_option("device") + .add_attr_option("model") + .add_attr_option("system-lib") + .add_attr_option("max_num_threads", Integer(1024)) + .add_attr_option("thread_warp_size", Integer(32)) + .add_attr_option("mcpu") + .set_default_keys({"cuda", "gpu"}) + .set_device_type(kDLGPU); + +TVM_REGISTER_TARGET_ID("nvptx") + .add_attr_option>("keys") + .add_attr_option>("libs") + .add_attr_option("device") + .add_attr_option("model") + .add_attr_option("system-lib") + .add_attr_option("max_num_threads", Integer(1024)) + .add_attr_option("thread_warp_size", Integer(32)) + .add_attr_option("mcpu") + .set_default_keys({"cuda", "gpu"}) + .set_device_type(kDLGPU); + +TVM_REGISTER_TARGET_ID("rocm") + .add_attr_option>("keys") + .add_attr_option>("libs") + .add_attr_option("device") + .add_attr_option("model") + .add_attr_option("system-lib") + .add_attr_option("max_num_threads", Integer(256)) + .add_attr_option("thread_warp_size", Integer(64)) + .set_default_keys({"rocm", "gpu"}) + .set_device_type(kDLROCM); + +TVM_REGISTER_TARGET_ID("opencl") + .add_attr_option>("keys") + .add_attr_option>("libs") + .add_attr_option("device") + .add_attr_option("model") + .add_attr_option("system-lib") + .add_attr_option("max_num_threads", Integer(256)) + .add_attr_option("thread_warp_size") + .set_default_keys({"opencl", "gpu"}) + .set_device_type(kDLOpenCL); + +TVM_REGISTER_TARGET_ID("metal") + .add_attr_option>("keys") + .add_attr_option>("libs") + .add_attr_option("device") + .add_attr_option("model") + .add_attr_option("system-lib") + .add_attr_option("max_num_threads", Integer(256)) + .set_default_keys({"metal", "gpu"}) + .set_device_type(kDLMetal); + +TVM_REGISTER_TARGET_ID("vulkan") + .add_attr_option>("keys") + .add_attr_option>("libs") + .add_attr_option("device") + .add_attr_option("model") + .add_attr_option("system-lib") + .add_attr_option("max_num_threads", Integer(256)) + .set_default_keys({"vulkan", "gpu"}) + .set_device_type(kDLVulkan); + +TVM_REGISTER_TARGET_ID("webgpu") + .add_attr_option>("keys") + .add_attr_option>("libs") + .add_attr_option("device") + .add_attr_option("model") + .add_attr_option("system-lib") + .add_attr_option("max_num_threads", Integer(256)) + .set_default_keys({"webgpu", "gpu"}) + .set_device_type(kDLWebGPU); + +TVM_REGISTER_TARGET_ID("sdaccel") + .add_attr_option>("keys") + .add_attr_option>("libs") + .add_attr_option("device") + .add_attr_option("model") + .add_attr_option("system-lib") + .set_default_keys({"sdaccel", "hls"}) + .set_device_type(kDLOpenCL); + +TVM_REGISTER_TARGET_ID("aocl") + .add_attr_option>("keys") + .add_attr_option>("libs") + .add_attr_option("device") + .add_attr_option("model") + .add_attr_option("system-lib") + .set_default_keys({"aocl", "hls"}) + .set_device_type(kDLAOCL); + +TVM_REGISTER_TARGET_ID("aocl_sw_emu") + .add_attr_option>("keys") + .add_attr_option>("libs") + .add_attr_option("device") + .add_attr_option("model") + .add_attr_option("system-lib") + .set_default_keys({"aocl", "hls"}) + .set_device_type(kDLAOCL); + +TVM_REGISTER_TARGET_ID("hexagon") + .add_attr_option>("keys") + .add_attr_option>("libs") + .add_attr_option("device") + .add_attr_option("model") + .add_attr_option("system-lib") + .set_default_keys({"hexagon"}) + .set_device_type(kDLHexagon); + +TVM_REGISTER_TARGET_ID("stackvm") + .add_attr_option>("keys") + .add_attr_option>("libs") + .add_attr_option("device") + .add_attr_option("model") + .add_attr_option("system-lib") + .set_device_type(kDLCPU); + +TVM_REGISTER_TARGET_ID("ext_dev") + .add_attr_option>("keys") + .add_attr_option>("libs") + .add_attr_option("device") + .add_attr_option("model") + .add_attr_option("system-lib") + .set_device_type(kDLExtDev); + +TVM_REGISTER_TARGET_ID("hybrid") + .add_attr_option>("keys") + .add_attr_option>("libs") + .add_attr_option("device") + .add_attr_option("model") + .add_attr_option("system-lib") + .set_device_type(kDLCPU); + } // namespace tvm diff --git a/src/te/schedule/schedule_postproc_rewrite_for_tensor_core.cc b/src/te/schedule/schedule_postproc_rewrite_for_tensor_core.cc index 75605ad..7541662 100644 --- a/src/te/schedule/schedule_postproc_rewrite_for_tensor_core.cc +++ b/src/te/schedule/schedule_postproc_rewrite_for_tensor_core.cc @@ -1083,7 +1083,7 @@ Stmt SchedulePostProcRewriteForTensorCore(Stmt stmt, Schedule schedule, Map extern_buffer) { // Check if current lower target is CUDA auto target = tvm::Target::Current(true); - if (target.defined() && target->target_name != "cuda") { + if (target.defined() && target->id->name != "cuda") { return stmt; } diff --git a/src/tir/analysis/verify_memory.cc b/src/tir/analysis/verify_memory.cc index 12ec270..f8a5986 100644 --- a/src/tir/analysis/verify_memory.cc +++ b/src/tir/analysis/verify_memory.cc @@ -177,7 +177,7 @@ bool VerifyMemory(const PrimFunc& func) { if (func->GetAttr(tvm::attr::kCallingConv, Integer(CallingConv::kDefault)) == CallingConv::kDefault) { - MemoryAccessVerifier v(func, target.value()->device_type); + MemoryAccessVerifier v(func, target.value()->id->device_type); v.Run(); return !v.Failed(); } else { diff --git a/src/tir/transforms/lower_custom_datatypes.cc b/src/tir/transforms/lower_custom_datatypes.cc index 154023c..f5491da 100644 --- a/src/tir/transforms/lower_custom_datatypes.cc +++ b/src/tir/transforms/lower_custom_datatypes.cc @@ -139,7 +139,7 @@ Pass LowerCustomDatatypes() { auto target = f->GetAttr(tvm::attr::kTarget); CHECK(target.defined()) << "LowerCustomDatatypes: Require the target attribute"; - n->body = CustomDatatypesLowerer(target.value()->target_name)(std::move(n->body)); + n->body = CustomDatatypesLowerer(target.value()->id->name)(std::move(n->body)); return f; }; return CreatePrimFuncPass(pass_func, 0, "tir.LowerCustomDatatypes", {}); diff --git a/src/tir/transforms/lower_intrin.cc b/src/tir/transforms/lower_intrin.cc index 5ec4fe3..5372ef8 100644 --- a/src/tir/transforms/lower_intrin.cc +++ b/src/tir/transforms/lower_intrin.cc @@ -40,12 +40,11 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { using IRMutatorWithAnalyzer::VisitExpr_; using IRMutatorWithAnalyzer::VisitStmt_; - IntrinInjecter(arith::Analyzer* analyzer, std::string target_name) - : IRMutatorWithAnalyzer(analyzer) { - patterns_.push_back("tvm.intrin.rule." + target_name + "."); + IntrinInjecter(arith::Analyzer* analyzer, std::string target) : IRMutatorWithAnalyzer(analyzer) { + patterns_.push_back("tvm.intrin.rule." + target + "."); patterns_.push_back("tvm.intrin.rule.default."); fma_ = runtime::Registry::Get(patterns_[0] + "fma"); - if (target_name == "stackvm") { + if (target == "stackvm") { support_bitwise_op_ = false; } } @@ -275,9 +274,9 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { bool support_bitwise_op_{true}; }; -Stmt LowerIntrinStmt(Stmt stmt, const std::string target_name) { +Stmt LowerIntrinStmt(Stmt stmt, const std::string& target) { arith::Analyzer analyzer; - return IntrinInjecter(&analyzer, target_name)(std::move(stmt)); + return IntrinInjecter(&analyzer, target)(std::move(stmt)); } namespace transform { @@ -288,7 +287,7 @@ Pass LowerIntrin() { auto target = f->GetAttr(tvm::attr::kTarget); CHECK(target.defined()) << "LowerIntrin: Require the target attribute"; arith::Analyzer analyzer; - n->body = IntrinInjecter(&analyzer, target.value()->target_name)(std::move(n->body)); + n->body = IntrinInjecter(&analyzer, target.value()->id->name)(std::move(n->body)); return f; }; return CreatePrimFuncPass(pass_func, 0, "tir.LowerIntrin", {}); diff --git a/src/tir/transforms/lower_thread_allreduce.cc b/src/tir/transforms/lower_thread_allreduce.cc index 04b8953..17b4265 100644 --- a/src/tir/transforms/lower_thread_allreduce.cc +++ b/src/tir/transforms/lower_thread_allreduce.cc @@ -40,7 +40,7 @@ namespace tir { class ThreadAllreduceBuilder final : public StmtExprMutator { public: explicit ThreadAllreduceBuilder(const TargetNode* target) - : target_(target), warp_size_(target->thread_warp_size) {} + : target_(target), warp_size_(target->GetAttr("thread_warp_size", 1).value()) {} Stmt VisitStmt_(const AttrStmtNode* op) final { if (op->attr_key == attr::thread_extent) { @@ -484,11 +484,10 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { // Also, the warp/wavefront size differs (64 on rocm, 32 on cuda). bool is_warp_reduction(const std::vector& types) const { // Only cuda target supports warp reductions. - if ((target_->target_name != "cuda") && (target_->target_name != "rocm")) return false; + if ((target_->id->name != "cuda") && (target_->id->name != "rocm")) return false; // rocm only supports 32 bit operands for shuffling at the moment - if ((target_->target_name == "rocm") && - (std::any_of(types.begin(), types.end(), [](DataType ty) { + if ((target_->id->name == "rocm") && (std::any_of(types.begin(), types.end(), [](DataType ty) { if (ty.is_vector()) return true; return ty.bits() != 32; }))) { diff --git a/src/tir/transforms/lower_warp_memory.cc b/src/tir/transforms/lower_warp_memory.cc index 480c62c..8892c32 100644 --- a/src/tir/transforms/lower_warp_memory.cc +++ b/src/tir/transforms/lower_warp_memory.cc @@ -392,7 +392,8 @@ Pass LowerWarpMemory() { auto* n = f.CopyOnWrite(); auto target = f->GetAttr(tvm::attr::kTarget); CHECK(target.defined()) << "LowerWarpMemory: Require the target attribute"; - n->body = WarpMemoryRewriter(target.value()->thread_warp_size).Rewrite(std::move(n->body)); + int warp_size = target.value()->GetAttr("thread_warp_size", 1).value(); + n->body = WarpMemoryRewriter(warp_size).Rewrite(std::move(n->body)); return f; }; return CreatePrimFuncPass(pass_func, 0, "tir.LowerWarpMemory", {}); diff --git a/src/tir/transforms/make_packed_api.cc b/src/tir/transforms/make_packed_api.cc index bfcf0b7..191bb0a 100644 --- a/src/tir/transforms/make_packed_api.cc +++ b/src/tir/transforms/make_packed_api.cc @@ -51,7 +51,7 @@ PrimFunc MakePackedAPI(PrimFunc&& func, int num_unpacked_args) { auto target = func->GetAttr(tvm::attr::kTarget); CHECK(target.defined()) << "MakePackedAPI: Require the target attribute"; - int target_device_type = target.value()->device_type; + int target_device_type = target.value()->id->device_type; std::string name_hint = global_symbol.value(); diff --git a/tests/micro/test_runtime_micro_on_arm.py b/tests/micro/test_runtime_micro_on_arm.py index 301677e..ed7d62f 100644 --- a/tests/micro/test_runtime_micro_on_arm.py +++ b/tests/micro/test_runtime_micro_on_arm.py @@ -33,7 +33,7 @@ from tvm.relay.testing import resnet # Ex : export CMSIS_ST_PATH="/home/yourid/st/STM32Cube_FW_F7_V1.16.0/Drivers/CMSIS" DEV_CONFIG_A = micro.device.arm.stm32f746xx.generate_config("127.0.0.1", 6666) DEV_CONFIG_B = micro.device.arm.stm32f746xx.generate_config("127.0.0.1", 6666) -TARGET = 'c -device=micro_dev' +TARGET = 'micro_dev' def relay_micro_build(func, dev_config, params=None): """Create a graph runtime module with a micro device context from a Relay function. diff --git a/tests/python/unittest/test_runtime_micro.py b/tests/python/unittest/test_runtime_micro.py index 2eea3df..eb137a9 100644 --- a/tests/python/unittest/test_runtime_micro.py +++ b/tests/python/unittest/test_runtime_micro.py @@ -28,7 +28,7 @@ from tvm.relay.testing import resnet # # Use the host emulated micro device. DEV_CONFIG_A = micro.device.host.generate_config() DEV_CONFIG_B = micro.device.host.generate_config() -TARGET = 'c -device=micro_dev' +TARGET = 'micro_dev' def relay_micro_build(func, dev_config, params=None): """Create a graph runtime module with a micro device context from a Relay function. diff --git a/tests/python/unittest/test_target_target.py b/tests/python/unittest/test_target_target.py index da7bcee..fe3799b 100644 --- a/tests/python/unittest/test_target_target.py +++ b/tests/python/unittest/test_target_target.py @@ -57,8 +57,8 @@ def test_target_dispatch(): def test_target_string_parse(): target = tvm.target.create("cuda -model=unknown -libs=cublas,cudnn") - assert target.target_name == "cuda" - assert target.options == ['-model=unknown', '-libs=cublas,cudnn'] + assert target.id.name == "cuda" + assert target.model == "unknown" assert target.keys == ['cuda', 'gpu'] assert target.libs == ['cublas', 'cudnn'] assert str(target) == str(tvm.target.cuda(options="-libs=cublas,cudnn")) diff --git a/topi/include/topi/cuda/dense.h b/topi/include/topi/cuda/dense.h index 145d249..c8ceebf 100644 --- a/topi/include/topi/cuda/dense.h +++ b/topi/include/topi/cuda/dense.h @@ -62,7 +62,7 @@ inline tvm::te::Tensor dense_cuda(const Target& target, const tvm::te::Tensor& d auto in_dim = data->shape[1]; auto out_dim = weight->shape[0]; - if (target->libs().count("cublas")) { + if (target->GetLibs().count("cublas")) { CHECK_EQ(data->dtype, out_dtype) << "Mixed precision not supported."; auto mm = topi::contrib::cublas_matmul(data, weight, false, true); if (bias.defined()) { @@ -85,7 +85,7 @@ inline tvm::te::Tensor dense_cuda(const Target& target, const tvm::te::Tensor& d * \return A schedule for the given ops. */ inline Schedule schedule_dense(const Target& target, const Array& outs) { - if (target->target_name == "cuda" && target->libs().count("cublas")) { + if (target->id->name == "cuda" && target->GetLibs().count("cublas")) { return topi::generic::schedule_extern(target, outs); } diff --git a/topi/include/topi/cuda/injective.h b/topi/include/topi/cuda/injective.h index 5a5c5af..e7bce05 100644 --- a/topi/include/topi/cuda/injective.h +++ b/topi/include/topi/cuda/injective.h @@ -47,7 +47,7 @@ namespace cuda { inline Schedule schedule_injective_from_existing(Schedule sch, const Tensor& out) { auto fused = detail::Fuse(sch[out], sch[out]->op.as()->axis); auto target = Target::Current(false); - auto num_thread = target->max_num_threads; + int num_thread = target->GetAttr("max_num_threads").value(); IterVar bx, tx; sch[out].split(fused, num_thread, &bx, &tx); sch[out].bind(bx, thread_axis(Range(), "blockIdx.x")); diff --git a/topi/include/topi/cuda/pooling.h b/topi/include/topi/cuda/pooling.h index 87866f2..7e8f55d 100644 --- a/topi/include/topi/cuda/pooling.h +++ b/topi/include/topi/cuda/pooling.h @@ -56,7 +56,7 @@ inline Schedule schedule_pool(const Target& target, const Array& outs) { if (padded_input->op->IsInstance()) { s[padded_input].compute_inline(); } - auto num_thread = target->max_num_threads; + int num_thread = target->GetAttr("max_num_threads").value(); Tensor out; Tensor OL; if (detail::contains(s->outputs, pool->op)) { diff --git a/topi/include/topi/cuda/reduction.h b/topi/include/topi/cuda/reduction.h index 35ce346..377b922 100644 --- a/topi/include/topi/cuda/reduction.h +++ b/topi/include/topi/cuda/reduction.h @@ -69,7 +69,7 @@ Schedule ScheduleReduce(const Target& target, Operation op, Schedule sch, if (out_stage->op.as()->axis.size() > 0) { all_reduce = false; num_thread = 32; - if (target->target_name == "opencl") { + if (target->id->name == "opencl") { // Without this, CL_INVALID_WORK_GROUP_SIZE occurs with python tests. // Don't know why. num_thread = 16; @@ -79,7 +79,7 @@ Schedule ScheduleReduce(const Target& target, Operation op, Schedule sch, thread_y = tvm::te::thread_axis(Range(0, num_thread), "threadIdx.y"); } else { all_reduce = true; - num_thread = target->max_num_threads; + num_thread = target->GetAttr("max_num_threads").value(); thread_x = tvm::te::thread_axis(Range(0, num_thread), "threadIdx.x"); } diff --git a/topi/include/topi/rocm/dense.h b/topi/include/topi/rocm/dense.h index 72f8ee6..e2e04b4 100644 --- a/topi/include/topi/rocm/dense.h +++ b/topi/include/topi/rocm/dense.h @@ -63,7 +63,7 @@ inline tvm::te::Tensor dense_rocm(const Target& target, const tvm::te::Tensor& d auto in_dim = data->shape[1]; auto out_dim = weight->shape[0]; - if (target->libs().count("rocblas")) { + if (target->GetLibs().count("rocblas")) { CHECK_EQ(data->dtype, out_dtype) << "Mixed precision not supported."; auto mm = topi::contrib::rocblas_matmul(data, weight, false, true); if (bias.defined()) { @@ -86,7 +86,7 @@ inline tvm::te::Tensor dense_rocm(const Target& target, const tvm::te::Tensor& d * \return A schedule for the given ops. */ inline Schedule schedule_dense(const Target& target, const Array& outs) { - if (target->target_name == "rocm" && target->libs().count("rocblas")) { + if (target->id->name == "rocm" && target->GetLibs().count("rocblas")) { return topi::generic::schedule_extern(target, outs); } diff --git a/topi/python/topi/arm_cpu/conv2d_gemm.py b/topi/python/topi/arm_cpu/conv2d_gemm.py index c1587ba..68161c3 100644 --- a/topi/python/topi/arm_cpu/conv2d_gemm.py +++ b/topi/python/topi/arm_cpu/conv2d_gemm.py @@ -27,7 +27,7 @@ from .tensor_intrin import gemv_quantized, gemv_quantized_impl def is_aarch64_arm(): """ Checks whether we are compiling for an AArch64 target. """ target = tvm.target.Target.current(allow_none=False) - return 'aarch64' in ' '.join(target.options) + return 'aarch64' in target.attrs.get("target", "") # Compute function diff --git a/topi/python/topi/cuda/batch_matmul.py b/topi/python/topi/cuda/batch_matmul.py index 7d92edf..bcd98cc 100644 --- a/topi/python/topi/cuda/batch_matmul.py +++ b/topi/python/topi/cuda/batch_matmul.py @@ -69,7 +69,7 @@ def schedule_batch_matmul(cfg, outs): cfg.define_split("tile_k", k, num_outputs=2) cfg.define_knob("auto_unroll_max_step", [8, 16, 32, 64]) target = tvm.target.Target.current() - if target.target_name in ['nvptx', 'rocm']: + if target.id.name in ['nvptx', 'rocm']: # llvm-based backends cannot do non-explicit unrolling cfg.define_knob("unroll_explicit", [1]) else: diff --git a/topi/python/topi/cuda/conv1d.py b/topi/python/topi/cuda/conv1d.py index 3ddecbe..533cf74 100644 --- a/topi/python/topi/cuda/conv1d.py +++ b/topi/python/topi/cuda/conv1d.py @@ -72,7 +72,7 @@ def schedule_conv1d_ncw(cfg, outs): cfg.define_knob("auto_unroll_max_step", [64, 512, 1500]) target = tvm.target.Target.current() - if target.target_name in ['nvptx', 'rocm']: + if target.id.name in ['nvptx', 'rocm']: cfg.define_knob("unroll_explicit", [1]) else: cfg.define_knob("unroll_explicit", [0, 1]) @@ -197,7 +197,7 @@ def schedule_conv1d_nwc(cfg, outs): cfg.define_knob("auto_unroll_max_step", [64, 512, 1500]) target = tvm.target.Target.current() - if target.target_name in ['nvptx', 'rocm']: + if target.id.name in ['nvptx', 'rocm']: cfg.define_knob("unroll_explicit", [1]) else: cfg.define_knob("unroll_explicit", [0, 1]) diff --git a/topi/python/topi/cuda/conv1d_transpose_ncw.py b/topi/python/topi/cuda/conv1d_transpose_ncw.py index a2ac7e1..ffce584 100644 --- a/topi/python/topi/cuda/conv1d_transpose_ncw.py +++ b/topi/python/topi/cuda/conv1d_transpose_ncw.py @@ -124,7 +124,7 @@ def schedule_conv1d_transpose_ncw(cfg, outs): cfg.define_knob("auto_unroll_max_step", [64, 512, 1500]) target = tvm.target.Target.current() - if target.target_name in ['nvptx', 'rocm']: + if target.id.name in ['nvptx', 'rocm']: cfg.define_knob("unroll_explicit", [1]) else: cfg.define_knob("unroll_explicit", [0, 1]) diff --git a/topi/python/topi/cuda/conv2d_direct.py b/topi/python/topi/cuda/conv2d_direct.py index db6bff2..9d8146e 100644 --- a/topi/python/topi/cuda/conv2d_direct.py +++ b/topi/python/topi/cuda/conv2d_direct.py @@ -36,7 +36,7 @@ def schedule_direct_cuda(cfg, s, conv): cfg.define_knob("auto_unroll_max_step", [0, 512, 1500]) target = tvm.target.Target.current() - if target.target_name in ['nvptx', 'rocm']: + if target.id.name in ['nvptx', 'rocm']: cfg.define_knob("unroll_explicit", [1]) else: cfg.define_knob("unroll_explicit", [0, 1]) @@ -44,7 +44,7 @@ def schedule_direct_cuda(cfg, s, conv): # fallback support if cfg.is_fallback: ref_log = autotvm.tophub.load_reference_log( - target.target_name, target.model, 'conv2d_nchw.cuda') + target.id.name, target.model, 'conv2d_nchw.cuda') cfg.fallback_with_reference_log(ref_log) ##### space definition end ##### diff --git a/topi/python/topi/cuda/conv2d_nhwc.py b/topi/python/topi/cuda/conv2d_nhwc.py index 55714b2..c7c3f18 100644 --- a/topi/python/topi/cuda/conv2d_nhwc.py +++ b/topi/python/topi/cuda/conv2d_nhwc.py @@ -56,7 +56,7 @@ def schedule_conv2d_nhwc_direct(cfg, s, Conv): target = tvm.target.Target.current() if cfg.is_fallback: ref_log = autotvm.tophub.load_reference_log( - target.target_name, target.model, 'conv2d_nhwc.cuda') + target.id.name, target.model, 'conv2d_nhwc.cuda') cfg.fallback_with_reference_log(ref_log) tile_n = cfg["tile_n"].val diff --git a/topi/python/topi/cuda/conv2d_nhwc_tensorcore.py b/topi/python/topi/cuda/conv2d_nhwc_tensorcore.py index 790db0f..7703e40 100644 --- a/topi/python/topi/cuda/conv2d_nhwc_tensorcore.py +++ b/topi/python/topi/cuda/conv2d_nhwc_tensorcore.py @@ -134,7 +134,7 @@ def schedule_nhwc_tensorcore_cuda(cfg, s, Conv): target = tvm.target.Target.current() if cfg.is_fallback: ref_log = autotvm.tophub.load_reference_log( - target.target_name, target.model, 'conv2d_nhwc_tensorcore.cuda') + target.id.name, target.model, 'conv2d_nhwc_tensorcore.cuda') cfg.fallback_with_reference_log(ref_log) block_row_warps = cfg["block_row_warps"].val diff --git a/topi/python/topi/cuda/conv2d_transpose_nchw.py b/topi/python/topi/cuda/conv2d_transpose_nchw.py index 5ad4947..4dfcc03 100644 --- a/topi/python/topi/cuda/conv2d_transpose_nchw.py +++ b/topi/python/topi/cuda/conv2d_transpose_nchw.py @@ -177,7 +177,7 @@ def schedule_conv2d_transpose_nchw(cfg, outs): cfg.define_knob("auto_unroll_max_step", [64, 512, 1500]) target = tvm.target.Target.current() - if target.target_name in ['nvptx', 'rocm']: + if target.id.name in ['nvptx', 'rocm']: cfg.define_knob("unroll_explicit", [1]) else: cfg.define_knob("unroll_explicit", [0, 1]) diff --git a/topi/python/topi/cuda/conv2d_winograd.py b/topi/python/topi/cuda/conv2d_winograd.py index 881f63a..d976aaa 100644 --- a/topi/python/topi/cuda/conv2d_winograd.py +++ b/topi/python/topi/cuda/conv2d_winograd.py @@ -193,7 +193,7 @@ def schedule_winograd_cuda(cfg, s, output, pre_computed): cfg.define_split("tile_rc", rc, num_outputs=2) cfg.define_knob("auto_unroll_max_step", [0, 128, 1500]) target = tvm.target.Target.current() - if target.target_name in ['nvptx', 'rocm']: + if target.id.name in ['nvptx', 'rocm']: cfg.define_knob("unroll_explicit", [1]) else: cfg.define_knob("unroll_explicit", [0, 1]) diff --git a/topi/python/topi/cuda/conv3d_direct.py b/topi/python/topi/cuda/conv3d_direct.py index 50b73d6..0b80e79 100644 --- a/topi/python/topi/cuda/conv3d_direct.py +++ b/topi/python/topi/cuda/conv3d_direct.py @@ -43,7 +43,7 @@ def schedule_direct_conv3d_cuda(cfg, s, conv, layout, workload_name): cfg.define_knob("auto_unroll_max_step", [0, 512, 1500]) target = tvm.target.Target.current() - if target.target_name in ['nvptx', 'rocm']: + if target.id.name in ['nvptx', 'rocm']: cfg.define_knob("unroll_explicit", [1]) else: cfg.define_knob("unroll_explicit", [0, 1]) @@ -51,7 +51,7 @@ def schedule_direct_conv3d_cuda(cfg, s, conv, layout, workload_name): # fallback support if cfg.is_fallback: ref_log = autotvm.tophub.load_reference_log( - target.target_name, target.model, workload_name) + target.id.name, target.model, workload_name) cfg.fallback_with_reference_log(ref_log) ##### space definition end ##### diff --git a/topi/python/topi/cuda/conv3d_ndhwc_tensorcore.py b/topi/python/topi/cuda/conv3d_ndhwc_tensorcore.py index e3c7513..68b0145 100644 --- a/topi/python/topi/cuda/conv3d_ndhwc_tensorcore.py +++ b/topi/python/topi/cuda/conv3d_ndhwc_tensorcore.py @@ -141,7 +141,7 @@ def schedule_ndhwc_tensorcore_cuda(cfg, s, Conv): target = tvm.target.Target.current() if cfg.is_fallback: ref_log = autotvm.tophub.load_reference_log( - target.target_name, target.model, 'conv3d_ndhwc_tensorcore.cuda') + target.id.name, target.model, 'conv3d_ndhwc_tensorcore.cuda') cfg.fallback_with_reference_log(ref_log) block_row_warps = cfg["block_row_warps"].val diff --git a/topi/python/topi/cuda/conv3d_winograd.py b/topi/python/topi/cuda/conv3d_winograd.py index 5876243..e8b5037 100644 --- a/topi/python/topi/cuda/conv3d_winograd.py +++ b/topi/python/topi/cuda/conv3d_winograd.py @@ -321,7 +321,7 @@ def schedule_winograd_cuda(cfg, s, output, pre_computed): cfg.define_split("tile_rc", rc, num_outputs=2) cfg.define_knob("auto_unroll_max_step", [0, 128, 1500]) target = tvm.target.Target.current() - if target.target_name in ['nvptx', 'rocm']: + if target.id.name in ['nvptx', 'rocm']: cfg.define_knob("unroll_explicit", [1]) else: cfg.define_knob("unroll_explicit", [0, 1]) @@ -478,7 +478,7 @@ def schedule_winograd_no_depth_cuda(cfg, s, output, pre_computed): cfg.define_split("tile_rz", rz, num_outputs=2) cfg.define_knob("auto_unroll_max_step", [0, 128, 1500]) target = tvm.target.Target.current() - if target.target_name in ['nvptx', 'rocm']: + if target.id.name in ['nvptx', 'rocm']: cfg.define_knob("unroll_explicit", [1]) else: cfg.define_knob("unroll_explicit", [0, 1]) diff --git a/topi/python/topi/cuda/correlation.py b/topi/python/topi/cuda/correlation.py index a383e4e..6d9be95 100644 --- a/topi/python/topi/cuda/correlation.py +++ b/topi/python/topi/cuda/correlation.py @@ -81,7 +81,7 @@ def _schedule_correlation_nchw(cfg, s, correlation): cfg.define_knob("auto_unroll_max_step", [0, 512, 1500]) target = tvm.target.Target.current() - if target.target_name in ['nvptx', 'rocm']: + if target.id.name in ['nvptx', 'rocm']: cfg.define_knob("unroll_explicit", [1]) else: cfg.define_knob("unroll_explicit", [0, 1]) diff --git a/topi/python/topi/cuda/deformable_conv2d.py b/topi/python/topi/cuda/deformable_conv2d.py index 8c31835..6def731 100644 --- a/topi/python/topi/cuda/deformable_conv2d.py +++ b/topi/python/topi/cuda/deformable_conv2d.py @@ -71,7 +71,7 @@ def _schedule_direct_cuda(cfg, s, conv): cfg.define_knob("auto_unroll_max_step", [0, 512, 1500]) target = tvm.target.Target.current() - if target.target_name in ['nvptx', 'rocm']: + if target.id.name in ['nvptx', 'rocm']: cfg.define_knob("unroll_explicit", [1]) else: cfg.define_knob("unroll_explicit", [0, 1]) diff --git a/topi/python/topi/cuda/dense_tensorcore.py b/topi/python/topi/cuda/dense_tensorcore.py index 3546847..a6d1c05 100644 --- a/topi/python/topi/cuda/dense_tensorcore.py +++ b/topi/python/topi/cuda/dense_tensorcore.py @@ -95,7 +95,7 @@ def _schedule_dense_tensorcore(cfg, s, C): target = tvm.target.Target.current() if cfg.is_fallback: ref_log = autotvm.tophub.load_reference_log( - target.target_name, target.model, 'dense_tensorcore.cuda') + target.id.name, target.model, 'dense_tensorcore.cuda') cfg.fallback_with_reference_log(ref_log) # Deal with op fusion, such as bias and relu diff --git a/topi/python/topi/cuda/depthwise_conv2d.py b/topi/python/topi/cuda/depthwise_conv2d.py index b7cb32d..f9ef8b6 100644 --- a/topi/python/topi/cuda/depthwise_conv2d.py +++ b/topi/python/topi/cuda/depthwise_conv2d.py @@ -61,7 +61,7 @@ def schedule_depthwise_conv2d_nchw(cfg, outs): cfg.define_knob("auto_unroll_max_step", [0, 256, 1500]) target = tvm.target.Target.current() - if target.target_name in ['nvptx', 'rocm']: + if target.id.name in ['nvptx', 'rocm']: cfg.define_knob("unroll_explicit", [1]) else: cfg.define_knob("unroll_explicit", [0, 1]) @@ -69,7 +69,7 @@ def schedule_depthwise_conv2d_nchw(cfg, outs): # fallback support if cfg.is_fallback: ref_log = autotvm.tophub.load_reference_log( - target.target_name, target.model, 'depthwise_conv2d_nchw.cuda') + target.id.name, target.model, 'depthwise_conv2d_nchw.cuda') cfg.fallback_with_reference_log(ref_log) # TODO(lmzheng): A bug here, set unroll_explicit to False as workaround cfg['unroll_explicit'].val = 0 @@ -169,7 +169,7 @@ def schedule_depthwise_conv2d_nhwc(outs): # num_thread here could be 728, it is larger than cuda.max_num_threads num_thread = tvm.arith.Analyzer().simplify(temp.shape[3]).value target = tvm.target.Target.current() - if target and (target.target_name not in ["cuda", "nvptx"]): + if target and (target.id.name not in ["cuda", "nvptx"]): num_thread = target.max_num_threads xoc, xic = s[Output].split(c, factor=num_thread) s[Output].reorder(xoc, b, h, w, xic) diff --git a/topi/python/topi/cuda/group_conv2d_nchw.py b/topi/python/topi/cuda/group_conv2d_nchw.py index c5cf72b..e5cbe3e 100644 --- a/topi/python/topi/cuda/group_conv2d_nchw.py +++ b/topi/python/topi/cuda/group_conv2d_nchw.py @@ -83,7 +83,7 @@ def _schedule_group_conv2d_nchw_direct(cfg, s, conv): cfg.define_knob("auto_unroll_max_step", [0, 512, 1500]) target = tvm.target.Target.current() - if target.target_name in ['nvptx', 'rocm']: + if target.id.name in ['nvptx', 'rocm']: cfg.define_knob("unroll_explicit", [1]) else: cfg.define_knob("unroll_explicit", [0, 1]) diff --git a/topi/python/topi/cuda/reduction.py b/topi/python/topi/cuda/reduction.py index d885c09..9d3c529 100644 --- a/topi/python/topi/cuda/reduction.py +++ b/topi/python/topi/cuda/reduction.py @@ -36,7 +36,7 @@ def _schedule_reduce(op, sch, is_idx_reduce=False): all_reduce = False num_thread = 32 target = tvm.target.Target.current() - if target and target.target_name == "opencl": + if target and target.id.name == "opencl": # without it, CL_INVALID_WORK_GROUP_SIZE occurred when running test_topi_reduce.py # don't know why num_thread = 16 diff --git a/topi/python/topi/cuda/softmax.py b/topi/python/topi/cuda/softmax.py index 5f7402b..910d0f3 100644 --- a/topi/python/topi/cuda/softmax.py +++ b/topi/python/topi/cuda/softmax.py @@ -59,9 +59,9 @@ def schedule_softmax(outs): # # TODO(tvm-team) Fix nvptx codegen or deprecate nvptx backend. def sched_warp_softmax(): - if tgt.target_name == "nvptx" or tgt.target_name == "rocm": + if tgt.id.name == "nvptx" or tgt.id.name == "rocm": return softmax.dtype == "float32" or softmax.dtype == "int32" - if tgt.target_name != "cuda": + if tgt.id.name != "cuda": # this is used as the gpu schedule for other arches which may not have warp reductions return False return True diff --git a/topi/python/topi/cuda/vision.py b/topi/python/topi/cuda/vision.py index eb49328..c5e2a6e 100644 --- a/topi/python/topi/cuda/vision.py +++ b/topi/python/topi/cuda/vision.py @@ -53,7 +53,7 @@ def schedule_reorg(outs): The computation schedule for reorg. """ target = tvm.target.Target.current(allow_none=False) - cpp_target = cpp.TEST_create_target(target.target_name) + cpp_target = cpp.TEST_create_target(target.id.name) return cpp.cuda.schedule_injective(cpp_target, outs) def schedule_nms(outs): diff --git a/topi/python/topi/generic/default.py b/topi/python/topi/generic/default.py index 59e5a25..93a1dd2 100644 --- a/topi/python/topi/generic/default.py +++ b/topi/python/topi/generic/default.py @@ -24,7 +24,7 @@ def default_schedule(outs, auto_inline): """Default schedule for llvm.""" target = tvm.target.Target.current(allow_none=False) outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs - if target.target_name not in ("llvm", "c"): + if target.id.name not in ("llvm", "c"): raise RuntimeError("schedule not registered for '%s'" % target) s = te.create_schedule([x.op for x in outs]) if auto_inline: diff --git a/topi/python/topi/generic/injective.py b/topi/python/topi/generic/injective.py index fa6aee4..a60b1e7 100644 --- a/topi/python/topi/generic/injective.py +++ b/topi/python/topi/generic/injective.py @@ -54,7 +54,7 @@ def schedule_injective(outs): The computation schedule for the op. """ target = tvm.target.Target.current(allow_none=False) - if target.target_name != "llvm": + if target.id.name != "llvm": raise RuntimeError("schedule_injective not registered for '%s'" % target) outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs x = outs[0] diff --git a/topi/python/topi/generic/vision.py b/topi/python/topi/generic/vision.py index edf1a48..a1db9ab 100644 --- a/topi/python/topi/generic/vision.py +++ b/topi/python/topi/generic/vision.py @@ -37,7 +37,7 @@ def schedule_reorg(outs): The computation schedule for the op. """ target = tvm.target.Target.current(allow_none=False) - cpp_target = cpp.TEST_create_target(target.target_name) + cpp_target = cpp.TEST_create_target(target.id.name) return cpp.generic.default_schedule(cpp_target, outs, False) def schedule_get_valid_counts(outs): diff --git a/topi/python/topi/intel_graphics/depthwise_conv2d.py b/topi/python/topi/intel_graphics/depthwise_conv2d.py index 6508099..bc2b27b 100644 --- a/topi/python/topi/intel_graphics/depthwise_conv2d.py +++ b/topi/python/topi/intel_graphics/depthwise_conv2d.py @@ -62,7 +62,7 @@ def schedule_depthwise_conv2d_nchw(cfg, outs): cfg.define_knob("auto_unroll_max_step", [0, 256, 1500]) target = tvm.target.Target.current() - if target.target_name in ['nvptx', 'rocm']: + if target.id.name in ['nvptx', 'rocm']: cfg.define_knob("unroll_explicit", [1]) else: cfg.define_knob("unroll_explicit", [0, 1]) @@ -70,7 +70,7 @@ def schedule_depthwise_conv2d_nchw(cfg, outs): # fallback support if cfg.is_fallback: ref_log = autotvm.tophub.load_reference_log( - target.target_name, target.model, 'depthwise_conv2d_nchw.intel_graphics') + target.id.name, target.model, 'depthwise_conv2d_nchw.intel_graphics') cfg.fallback_with_reference_log(ref_log) cfg['unroll_explicit'].val = 0 ##### space definition end ##### @@ -170,7 +170,7 @@ def schedule_depthwise_conv2d_nhwc(outs): # num_thread here could be 728, it is larger than cuda.max_num_threads num_thread = tvm.arith.Analyzer().simplify(temp.shape[3]).value target = tvm.target.Target.current() - if target and (target.target_name not in ["cuda", "nvptx"]): + if target and (target.id.name not in ["cuda", "nvptx"]): num_thread = target.max_num_threads xoc, xic = s[Output].split(c, factor=num_thread) s[Output].reorder(xoc, b, h, w, xic)